import pickle
import dill
import numpy as np
import matplotlib.pyplot as plt


# Test problem and model whose saved results are visualised by this script.
test_problem = 'stove'
model = 'leverage'

# Directory holding the saved results for this problem/model pair:
dirname = f'results_problem={test_problem}_model={model}'

# Load necessary data (time grid and right-hand-side time functions).
# NOTE: pickle/dill can execute arbitrary code on load; these files are
# expected to be results produced by this project — never load untrusted ones.
# `with` guarantees each file handle is closed after loading.
with open('results/FE_data_problem=' + test_problem + '.pickle', 'rb') as fh:
    FE_results = pickle.load(fh)
grid_t = FE_results['grid_t']
with open('results/rhs_t_problem=' + test_problem + '.dill', 'rb') as fh:
    rhs_t = dill.load(fh)

# tested parameters (they are formatted into the result file names, so their
# types matter: `iterations` must stay an int, `tol` renders as '1e-08'):
n_rand = 10
tol = 10**(-8)
iterations = 10**5
nt_rand = 15

# determine ordering:
# `ordering` holds run indices sorted by ascending `which_error` for k=15.
which_error = 'rel_C0_L2_errors'
k = 15
ordering_file = (f'{dirname}/iterations={iterations}_nrand={n_rand}'
                 f'_ninit=ntrand={nt_rand}_k={k}_tol={tol}.pickle')
# `with` closes the handle once the results are loaded.
with open(ordering_file, 'rb') as fh:
    results = pickle.load(fh)
ordering = np.argsort(results[which_error])

# load results:
# Rank (position within `ordering`, i.e. among runs sorted by ascending
# error) of the single run whose chosen points and error curves are plotted.
# NOTE(review): hand-picked value; presumably must be < iterations.
plot_rank = 97729
L2_errors = []
for k in [13, 15]:
    results_file = (f'{dirname}/iterations={iterations}_nrand={n_rand}'
                    f'_ninit=ntrand={nt_rand}_k={k}_tol={tol}.pickle')
    # `with` closes each handle once its results are loaded.
    with open(results_file, 'rb') as fh:
        results = pickle.load(fh)
    # Same as `arr[ordering, :][plot_rank]`, but picks the one row directly
    # instead of first building a fully reordered copy of the whole array.
    row = ordering[plot_rank]
    # Overwritten on every pass: after the loop it holds the k=15 points.
    chosen_points = results['chosen_time_points'][row, :]
    L2_errors.append(results['rel_L2_errors_over_time'][row, :])


# Figure 7 left: right-hand-side time functions together with the chosen
# (end) time points.
plt.figure(figsize=plt.figaspect(0.5))
for rhs in rhs_t:
    plt.plot(grid_t, rhs(grid_t))
# Size the marker baseline from the data instead of a hard-coded 10, so it
# always matches the number of chosen points.
plt.plot(chosen_points, np.zeros(len(chosen_points)), marker=6,
         linestyle='None', label='chosen (end) time points', color='black')
# Additionally mark each chosen point on every RHS curve where it is non-zero.
for rhs in rhs_t:
    for t in chosen_points:
        value = rhs(t)  # evaluate once, reuse for the test and the plot
        if value != 0:
            plt.plot(t, value, marker='.', linestyle='None', color='black')
plt.legend()
plt.show()

# Figure 7 right: relative L2 error over time (log scale), one curve per k,
# in the same order the results were loaded.
plt.figure()
for idx, curve_label in enumerate(('k=13', 'k=15')):
    plt.semilogy(L2_errors[idx], label=curve_label)
plt.legend()
plt.show()