#!/usr/bin/env python

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import inspect, os
from scipy.io import netcdf
import sys

# Any command-line argument at all (its value is ignored) switches the script
# from showing the figure interactively to writing it out as a PDF.
print("Add any argument to save a PDF")
makePDF = len(sys.argv) > 1

# Output files to plot, one curve per run directory. The directory names encode
# the resolution settings of each run (ntheta/nzeta, mpol/ntor, ns).
files = [
    '20200716-01-032_nthetanzetaPlasma128_nthetanzetaCoil128_mpolntor32_ns3/regcoil_out.test.nc',
    '20200716-01-033_nthetanzetaPlasma128_nthetanzetaCoil128_mpolntor32_ns6/regcoil_out.test.nc',
    '20200716-01-034_nthetanzetaPlasma256_nthetanzetaCoil128_mpolntor32_ns3/regcoil_out.test.nc',
    '20200716-01-035_nthetanzetaPlasma128_nthetanzetaCoil256_mpolntor32_ns3/regcoil_out.test.nc',
    '20200716-01-041_nthetanzetaPlasma128_nthetanzetaCoil129_mpolntor64_ns3/regcoil_out.test.nc',
]

# Separate run whose lambda grid (per its directory name, multiples of 10) is
# expected to contain each value of specialLambdas, which get highlighted.
specialLambdasFile = '20200716-01-039_nthetanzetaPlasma128_nthetanzetaCoil128_mpolntor32_ns3_multipleOf10lambdas/regcoil_out.test.nc'

specialLambdas = [1e-9, 1e-15, 1e-18]

# Per-run results; element i of each list comes from files[i], and every array
# is ordered by increasing lambda.
lambda_many = []
chi2_B_many = []
chi2_M_many = []
max_B_many = []
max_M_many = []
# Largest lambda seen over all runs (becomes np.inf if a sentinel is replaced).
max_lambda = 0

for filename in files:
    # The context manager closes the file even if one of the variables is missing.
    with netcdf.netcdf_file(filename, 'r', mmap=False) as f:
        # We use 'lambdas' instead of 'lambda' to avoid conflict with python's keyword lambda.
        lambdas = f.variables['lambda'][()]
        # Sort by lambda and apply the same permutation to every quantity so the
        # arrays stay aligned and each curve is drawn in lambda order.
        permutation = np.argsort(lambdas)
        lambdas = lambdas[permutation]
        # A huge final lambda (> 1e199) is presumably the file's stand-in for
        # lambda = infinity; store it as np.inf. (lambdas is a fresh copy from
        # the fancy indexing above, so this does not touch the file data.)
        if lambdas[-1] > 1.0e199:
            lambdas[-1] = np.inf

        lambda_many.append(lambdas)
        max_lambda = np.max((max_lambda, np.max(lambdas)))

        chi2_B_many.append(f.variables['chi2_B'][()][permutation])
        max_B_many.append(f.variables['max_Bnormal'][()][permutation])

        chi2_M_many.append(f.variables['chi2_M'][()][permutation])
        max_M_many.append(f.variables['max_M'][()][permutation])

    print("Read data from file "+filename)

# Now read in the special points: chi2_B / chi2_M at each value in
# specialLambdas, taken from the dedicated run in specialLambdasFile.
with netcdf.netcdf_file(specialLambdasFile, 'r', mmap=False) as f:
    # We use 'lambdas' instead of 'lambda' to avoid conflict with python's keyword lambda.
    lambdas = f.variables['lambda'][()]
    chi2_B_allspecial = f.variables['chi2_B'][()]
    chi2_M_allspecial = f.variables['chi2_M'][()]

# Report the file actually read here (previously this printed the last entry of
# `files` by mistake).
print("Read data from file "+specialLambdasFile)

chi2_B_special = []
chi2_M_special = []
for specialLambda in specialLambdas:
    # The special lambdas span many decades, so compare with a relative
    # tolerance: a fixed absolute tolerance of 1e-20 is a ~1% window at 1e-18
    # but only ~1e-11 relative at 1e-9. With atol=0 a lambda of 0 never matches.
    matches = np.flatnonzero(np.isclose(lambdas, specialLambda, rtol=1e-6, atol=0.0))
    if matches.size == 0:
        # Fail here with a clear message rather than with an IndexError when
        # the plot labels index chi2_*_special.
        raise ValueError("lambda = %g not found in %s" % (specialLambda, specialLambdasFile))
    # Keep exactly one point per special lambda so the lists stay aligned with
    # specialLambdas even if the file repeats a value.
    j = matches[0]
    chi2_B_special.append(chi2_B_allspecial[j])
    chi2_M_special.append(chi2_M_allspecial[j])


##########################################################
# Make plots
##########################################################

# Trade-off plot: for every run, f_B is plotted against f_M with one point per
# lambda, on log-log axes.
data_M = chi2_M_many
label_M = r'$f_M$ [Amperes$^2$ meters$^2$]'
data_B = chi2_B_many
label_B = r'$f_B$ [Tesla$^2$ meters$^2$]'


#matplotlib.rcParams.update({'font.size': 9})
fig = plt.figure(figsize=(5,4.5))
fig.patch.set_facecolor('white')

# One blue curve per run. data_M and data_B are parallel to `files`, so walk the
# three lists together instead of indexing by position. (The curve labels are
# only visible if a legend is added; none is drawn here.)
for filename, curve_M, curve_B in zip(files, data_M, data_B):
    plt.loglog(curve_M, curve_B, '-b', label=filename)

# Highlight the special lambdas as red dots on the log-log axes created above.
plt.plot(chi2_M_special, chi2_B_special, '.r', ms=7)
plt.xlabel(label_M)
plt.ylabel(label_B)
plt.grid(True)

plt.xlim([100, 3.0e13])
#plt.ylim([1.0e-8, 2.0])
plt.ylim([3.0e-9, 2.0])

# Label the highlighted points; the hard-coded texts assume chi2_*_special is
# ordered like specialLambdas. The first label sits just above its point, the
# other two just to the right (a factor of 1.2 on the log axis).
plt.text(chi2_M_special[0], chi2_B_special[0]*1.2, r'$\lambda=10^{-9}$')
plt.text(chi2_M_special[1]*1.2, chi2_B_special[1], r'$\lambda=10^{-15}$')
plt.text(chi2_M_special[2]*1.2, chi2_B_special[2], r'$\lambda=10^{-18}$')

#plt.subplots_adjust(left=0.05,bottom=0.08,right=0.99,top=0.93,wspace=0.15,hspace=0.22)

# Uncomment to stamp the generating script and working directory on the figure.
#titleString = "Plot generated by "+ os.path.abspath(inspect.getfile(inspect.currentframe())) + "\nRun in "+os.getcwd()
#plt.figtext(0.5,0.99,titleString,horizontalalignment='center',verticalalignment='top')

# Explanatory annotations placed at fixed data coordinates.
myfontsize=10
plt.text(1.5e2, 0.18, r'High regularization $\lambda$', fontsize=myfontsize)
plt.text(1.5e2, 1.25e-8, 'Ideal solutions\nwould be here', fontsize=myfontsize, va='bottom')
plt.text(1.5e11, 1.25e-8, r'Low regularization $\lambda$', fontsize=myfontsize, ha='right', va='bottom')

plt.tight_layout()

# Either write the figure next to this script (e.g. foo.py -> foo.py.pdf) or
# show it interactively.
if makePDF:
    print("Saving PDF")
    plt.savefig(__file__+'.pdf')
else:
    plt.show()

