'''
Here we compare the ML regression on the first part of the samples to the regression on the second part of the samples,
where the event is out of the tableau (i.e. its information cannot be conveyed by self-pulsing).
'''

import numpy as np
from matplotlib import pyplot as plt
# import scipy.io as sio

# load data
# Select which measured device to analyse (index into `structures`).
i_structure = 2
structures = ['1 MRR', '2 MRRs', '3 MRRs']
# Saved-map file for each structure index, in the same order as `structures`.
_savenames = {
    0: 'SavedMaps_noPulseRead_newChip_RING-lowQ_2023-12-15_16-26_2RingsScissor.pkl',
    1: 'SavedMaps_noPulseRead_newChip_SCISSOR-2-1_2023-12-18_16-09_2RingsScissor.pkl',
    2: 'SavedMaps_noPulseRead_newChip_SCISSOR-3_2023-12-19_13-19_2RingsScissor.pkl',
}
# Fail early with a clear message instead of a NameError on `savename` further down.
if i_structure not in _savenames:
    raise ValueError(f'i_structure must be one of {sorted(_savenames)}, got {i_structure!r}')
savename = _savenames[i_structure]

# NOTE: allow_pickle=True lets np.load unpickle the file, which can execute arbitrary
# code -- only load trusted, locally produced result files.
loaded_data = np.load(savename, allow_pickle=True)
# Assumes the file holds exactly these 11 objects in this order (unpacking fails otherwise).
OV_scoresTest_drop, OV_scoresTrain_drop, OV_scoresTest_through, OV_scoresTrain_through, OV_storedPulse_drop, OV_storedPulse_through, OV_stored_drop, OV_stored_through, OV_X_drop, OV_X_through, OV_y = loaded_data

# Example indices for quick manual inspection; they are reassigned before any real use.
i_freq = 15
i_pow = 10

# X = OV_X_through[i_freq][i_pow][0]
# X = OV_X_drop[i_freq][i_pow][0]
#
# plt.figure()
# plt.plot( np.concatenate( np.concatenate([ X, -0.2*np.ones(len(X))[:, None ] ], axis=1) ) )

# Now extract all the points from OV_stored_drop and OV_stored_through into dense arrays
# indexed as (i_freq, i_pow, i_rep, i_samp, i_feat). The nested lists are read as
# OV_stored_*[i_freq][i_pow][0][i_rep][i_samp][0][:, 0]; the size along each axis is
# taken from the first entry of the drop list.
_extended_shape = (
    len(OV_stored_drop),
    len(OV_stored_drop[0]),
    len(OV_stored_drop[0][0][0]),
    len(OV_stored_drop[0][0][0][0]),
    len(OV_stored_drop[0][0][0][0][0][0]),
)
X_extended_drop = np.zeros(_extended_shape)
X_extended_through = np.zeros(_extended_shape)

# Best effort: entries that are missing or have an unexpected shape are left as zeros.
# Only lookup/shape errors are tolerated, so KeyboardInterrupt and real bugs still surface.
_n_skipped = {'drop': 0, 'through': 0}
for i_freq, i_pow, i_rep, i_samp in np.ndindex(*_extended_shape[:4]):
    try:
        X_extended_drop[i_freq, i_pow, i_rep, i_samp, :] = OV_stored_drop[i_freq][i_pow][0][i_rep][i_samp][0][:, 0]
    except (IndexError, KeyError, TypeError, ValueError):
        _n_skipped['drop'] += 1
    try:
        X_extended_through[i_freq, i_pow, i_rep, i_samp, :] = OV_stored_through[i_freq][i_pow][0][i_rep][i_samp][0][:, 0]
    except (IndexError, KeyError, TypeError, ValueError):
        _n_skipped['through'] += 1
if any(_n_skipped.values()):
    print(f'Entries left as zeros (missing/malformed): {_n_skipped}')


# Split each sample sequence into two halves along the sample axis. Only the first data
# repetition (i_rep = 0) is used, so there is one sample per y value.
#   first half  -> samples [0, n//2)      -> X_extDrop / X_extThrough
#   second half -> samples [n//2 + 1, n)  -> X_extDrop_alt / X_extThrough_alt
# The sample at index n//2 belongs to neither half; for even n the second half is one
# sample shorter than the first.
# (An earlier variant concatenated the halves of two repetitions along axis 2.)
def _split_halves(stack):
    """Return (first half, second half) of repetition 0 of `stack` along the sample axis."""
    first_rep = stack[:, :, 0]
    middle = len(stack[0, 0, 0]) // 2
    return first_rep[:, :, :middle, :], first_rep[:, :, middle + 1:, :]


X_extDrop, X_extDrop_alt = _split_halves(X_extended_drop)
X_extThrough, X_extThrough_alt = _split_halves(X_extended_through)

# Targets evenly spaced in [0, 1], one per sample of each half.
n_first, n_second = len(X_extDrop[0, 0]), len(X_extDrop_alt[0, 0])
y = np.linspace(0, 1, n_first)
y_alt = np.linspace(0, 1, n_second)

i_freq = 5
i_pow = 15

# loops
# Output ports to analyse ('both' concatenates through and drop features).
port_list = ['drop', 'through']

# Feature-coarsening options: -1 keeps all features; k >= 1 sums the features in k
# contiguous groups (after dropping the last n_features % k columns).
n_splits_list = [-1]

# Row of each sample-sample correlation matrix stored in hypermap_corrMem*.
# NOTE(review): later plots label this "Correlation with initial state"; confirm that
# sample 5 is the intended reference sample.
i_ref_sample = 5

_unknown_ports = set(port_list) - {'drop', 'through', 'both'}
if _unknown_ports:
    raise ValueError(f'Unsupported ports in port_list: {sorted(_unknown_ports)}')
if any(k != -1 and k < 1 for k in n_splits_list):
    raise ValueError(f'n_splits_list entries must be -1 or >= 1, got {n_splits_list}')

# initialize storage of results; points whose analysis fails keep the initial value 1
_n_freq, _n_pow = len(OV_stored_drop), len(OV_stored_drop[0])
_map_shape = (_n_freq, _n_pow, len(port_list), len(n_splits_list))
hypermap_corrMean = np.ones(_map_shape)
hypermap_corrMean_alt = np.ones(_map_shape)
hypermap_corrMem = np.ones(_map_shape + (len(X_extDrop[0, 0]),))
# The second half starts at n//2 + 1, so it can be one sample shorter than the first;
# size its map from X_extDrop_alt so corr_X_alt rows fit.
hypermap_corrMem_alt = np.ones(_map_shape + (len(X_extDrop_alt[0, 0]),))


def _port_features(port, i_f, i_p):
    """Return the (first-half, second-half) feature arrays of `port` at point (i_f, i_p)."""
    if port == 'drop':
        return X_extDrop[i_f, i_p], X_extDrop_alt[i_f, i_p]
    if port == 'through':
        return X_extThrough[i_f, i_p], X_extThrough_alt[i_f, i_p]
    # 'both' (port_list was validated above)
    return (np.concatenate([X_extThrough[i_f, i_p], X_extDrop[i_f, i_p]], axis=1),
            np.concatenate([X_extThrough_alt[i_f, i_p], X_extDrop_alt[i_f, i_p]], axis=1))


def _coarse_grain(features, n_splits):
    """Reduce `features` (n_samples, n_features) according to `n_splits`.

    -1 returns the input unchanged; k >= 1 sums the features in k equal contiguous
    groups after dropping the last n_features % k columns.
    """
    if n_splits == -1:
        return features
    # Explicit end index: a negative end `-(n % k)` would select nothing when n % k == 0.
    n_keep = features.shape[1] - features.shape[1] % n_splits
    return features[:, :n_keep].reshape(features.shape[0], n_splits, -1).sum(axis=2)


for i_freq in range(_n_freq):
    for i_pow in range(_n_pow):
        print(i_freq, i_pow)
        try:
            # y = OV_y[i_freq][i_pow][0]
            for i_port, port in enumerate(port_list):
                X_, X_alt_ = _port_features(port, i_freq, i_pow)
                for i_split, n_splits in enumerate(n_splits_list):
                    X = _coarse_grain(X_, n_splits)
                    X_alt = _coarse_grain(X_alt_, n_splits)
                    # Sample-by-sample correlation matrices (rows of X are samples).
                    corr_X = np.corrcoef(X)
                    corr_X_alt = np.corrcoef(X_alt)
                    # Mean correlation over the first half of the rows (all columns).
                    hypermap_corrMean[i_freq, i_pow, i_port, i_split] = np.mean(corr_X[:len(corr_X) // 2, :])
                    hypermap_corrMean_alt[i_freq, i_pow, i_port, i_split] = np.mean(corr_X_alt[:len(corr_X_alt) // 2, :])

                    hypermap_corrMem[i_freq, i_pow, i_port, i_split] = corr_X[i_ref_sample]
                    hypermap_corrMem_alt[i_freq, i_pow, i_port, i_split] = corr_X_alt[i_ref_sample]
        except (IndexError, TypeError, ValueError) as err:
            # Best effort: the remaining entries for this point keep their initial value 1.
            print(f'Skipping (i_freq={i_freq}, i_pow={i_pow}): {err}')

def _nan_min_with_index(values):
    """Return the NaN-ignoring minimum of `values` and its N-d index tuple."""
    lowest = np.nanmin(values)
    position = np.unravel_index(np.nanargmin(values, axis=None), values.shape)
    return lowest, position


# Lowest mean correlation (and its (freq, pow, port, split) index) for each half.
min_corr, ind_min = _nan_min_with_index(hypermap_corrMean)
# print(f'Min corr.: {min_corr} at indices: {ind_min}')

min_corr_alt, ind_min_alt = _nan_min_with_index(hypermap_corrMean_alt)
# print(f'Min alt. corr.: {min_corr_alt} at indices: {ind_min_alt}')

# Index into n_splits_list of the feature-coarsening option to display.
chosen_split = 0
best_indices = []  # per port: (i_freq, i_pow) with the largest correlation difference
maps = []          # per port: the plotted difference map
corrPlots = []     # per port: [perturbed, unperturbed] correlation traces at that point
plt.rcParams.update({'font.size': 14})
for i_port, port in enumerate(port_list):
    ### save and plot results
    # Map of (unperturbed - perturbed) mean correlation over (freq, pow). NaNs count as
    # correlation 1 (the maps' initial value). The fill value must be passed as `nan=`:
    # nan_to_num's second positional parameter is `copy`, so the former `(x, 1)` turned
    # NaNs into 0.
    toPlot = (np.nan_to_num(hypermap_corrMean_alt[:, :, i_port, chosen_split], nan=1.0)
              - np.nan_to_num(hypermap_corrMean[:, :, i_port, chosen_split], nan=1.0))
    maps.append(toPlot)

    plt.figure()
    plot_ = plt.imshow(toPlot.T, interpolation="none", origin='lower')
    plt.title(f'Comparison of perturbed and unperturbed delayed responses \n{structures[i_structure]} - {port} port', fontsize=12)
    plt.xlabel('Input frequency index')
    plt.ylabel('Input power index')
    cbar = plt.colorbar(plot_)
    cbar.set_label("Difference of mean correlation")

    # (freq, pow) point with the largest difference; its correlation traces are plotted next.
    ind_maxCorrDiff = np.unravel_index(np.nanargmax(toPlot, axis=None), toPlot.shape)
    best_indices.append(ind_maxCorrDiff)
    i_freq_best, i_pow_best = ind_maxCorrDiff
    toPlot_ = hypermap_corrMem[i_freq_best, i_pow_best, i_port, chosen_split]
    toPlot_alt = hypermap_corrMem_alt[i_freq_best, i_pow_best, i_port, chosen_split]
    corrPlots.append([toPlot_, toPlot_alt])

    plt.figure(figsize=(8, 5))
    plt.plot(toPlot_, label='Perturbed')
    plt.plot(toPlot_alt, label='Unperturbed')
    plt.title(f'{port} port')
    plt.xlabel('Perturbation time (a.u.)')
    plt.ylabel('Correlation with initial state')
    plt.ylim([-1, 1])
    plt.legend(loc=3)

# visualize chosen traces
import scipy.io as sio

# Raw-trace folder for each structure index (same order as `structures`).
_trace_folders = {
    0: 'newChip_RING-lowQ_2023-12-15_16-26',
    1: 'newChip_SCISSOR-2-1_2023-12-18_16-09',
    2: 'newChip_SCISSOR-3_2023-12-19_13-19',
}
# Fail early with a clear message instead of a NameError on `filename`.
if i_structure not in _trace_folders:
    raise ValueError(f'i_structure must be one of {sorted(_trace_folders)}, got {i_structure!r}')
filename = _trace_folders[i_structure]


def _load_trace(indices, key):
    """Load variable `key` from the raw map file measured at (i_freq, i_pow) = `indices`."""
    i_f, i_p = indices
    # loadmat appends '.mat' to the name by default (appendmat=True).
    return sio.loadmat(f'{filename}/SPmap_2outs_iFreq{i_f}_iPow{i_p}')[key]


# best_indices[0] belongs to the drop port, best_indices[1] to the through port.
trace_drop = _load_trace(best_indices[0], 'out_drop')
trace_through = _load_trace(best_indices[1], 'out_through')

plt.figure()
plt.plot(trace_drop)
plt.figure()
plt.plot(trace_through)

print(f'Chosen indices for drop and through: {best_indices}')


def _plot_samples(data, indices, title, one_after_another):
    """Plot every sample of `data` at (i_freq, i_pow) = `indices` in a new figure.

    With `one_after_another` the samples are drawn on consecutive x ranges (one sample
    length each) instead of on top of each other.
    """
    i_f, i_p = indices
    # Points per sample, taken from the data itself rather than from a variable left
    # over from the correlation loop (which differs when features are coarse-grained).
    n_points = data.shape[-1]
    plt.figure()
    for i_samp in range(len(data[0][0])):
        trace = data[i_f, i_p, i_samp]
        if one_after_another:
            plt.plot(np.arange(n_points * i_samp, n_points * (i_samp + 1)), trace)
        else:
            plt.plot(trace)
    plt.title(title)


# plot samples one after another (each dataset: overlaid, then concatenated);
# best_indices[0] is the drop port's point, best_indices[1] the through port's.
for _data, _indices, _title in [
    (X_extDrop, best_indices[0], 'Perturbed drop'),
    (X_extThrough, best_indices[1], 'Perturbed through'),
    (X_extDrop_alt, best_indices[0], 'UNperturbed drop'),
    (X_extThrough_alt, best_indices[1], 'UNperturbed through'),
]:
    _plot_samples(_data, _indices, _title, one_after_another=False)
    _plot_samples(_data, _indices, _title, one_after_another=True)

# save data
# Pickled list: a header of field names followed by the values in the same order.
saved_data = [['structure', 'output_ports', 'X_drop_perturbed', 'X_drop_unperturbed', 'X_through_perturbed', 'X_through_unperturbed',
               'maps_of_correlation_trends_perturbed', 'maps_of_correlation_trends_unperturbed'],
              structures[i_structure],
              port_list,
              X_extDrop,
              X_extDrop_alt,
              X_extThrough,
              X_extThrough_alt,
              hypermap_corrMem,
              hypermap_corrMem_alt]

import os
import pickle

_out_dir = './paper_data_and_example_figures'
# The saved-map filenames all end in '_2RingsScissor.pkl' (18 characters), stripped here.
_out_base = _out_dir + '/RawishData_andCorrelationData_' + savename[:-18]
os.makedirs(_out_dir, exist_ok=True)

# Context manager guarantees the pickle is flushed and the file handle closed.
with open(_out_base + '.pkl', 'wb') as file_:
    pickle.dump(saved_data, file_)

# np.save('./paper_data_and_example_figures/RawishData_andCorrelationData_'+savename, saved_data)

import scipy.io as sio
sio.savemat(_out_base + '.mat',
    {
        'structure': structures[i_structure],
        'output_ports': port_list,
        'X_drop_perturbed': X_extDrop,
        'X_drop_unperturbed': X_extDrop_alt,
        'X_through_perturbed': X_extThrough,
        'X_through_unperturbed': X_extThrough_alt,
        'maps_of_correlation_trends_perturbed': hypermap_corrMem,
        'maps_of_correlation_trends_unperturbed': hypermap_corrMem_alt,
    }
)
# data_loaded = sio.loadmat('./paper_data_and_example_figures/ComparedCorrelation_'+savename[:-4]+'.mat')

plt.show()


