'''
(Alessio Lugnan, end of July 2025)
Here I try classification using only a single pixel at a time.
In this version, each pixel defines a different dataset
'''

import numpy as np
# from matplotlib import pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import mat73
import scipy.io as sio
from skimage.measure import block_reduce
# import time

# ---------------- experiment configuration ----------------
n_train=55000   # number of training samples; remaining non-test samples form the validation set
# pix_cutoff = 80   # pixel index from which to start classification (main results were calculated using 80)
pix_cutoff = 0   # pixel index from which to start classification (main results were calculated using 80)
# ind_selected_pow = 5
n_ports = 6   # number of measured output ports (mapped to physical port numbers inside get_samplesets_1pix)
n_classes = 10   # MNIST / Fashion-MNIST class count
freq_ind = np.arange(2, 12)   # available frequency indices (not used below; kept for reference)
# freq_ind = np.arange(5, 7)
pow_ind = np.arange(1, 6)   # available input-power indices
# pow_ind = np.arange(4, 5)
n_freq = 11
n_pow = len(pow_ind)
scaler = StandardScaler()   # shared scaler instance, re-fitted per pixel in the classification loop

# Dataset selection: uncomment exactly one folder_path / AWGdata_path pair.
# folder_path = '/media/alessio_lugnan/Elements/foradori_parsed/X4Y1_processed_110725/awg_mnist/'
# AWGdata_path = '/media/alessio_lugnan/Elements/awg_mnist_010324.mat'

# folder_path = '/media/alessio_lugnan/Elements/foradori_parsed/X4Y1_processed_100725/awg_mnist_double/'
# AWGdata_path = '/media/alessio_lugnan/Elements/awg_mnist_double_140524.mat'

# folder_path = '/media/alessio_lugnan/Elements/foradori_parsed/X4Y1_processed_100725/awg_mnist_box/'
# AWGdata_path = '/media/alessio_lugnan/Elements/awg_mnist_box_100524.mat'

# folder_path = '/media/alessio_lugnan/Elements/foradori_parsed/X4Y1_processed_100725/awg_mnist_double_box/'
# AWGdata_path = '/media/alessio_lugnan/Elements/awg_mnist_double_box_140524.mat'

folder_path = '/media/alessio_lugnan/Elements/foradori_parsed/X4Y1_processed_100725/awg_mnist_inverse/'
AWGdata_path = '/media/alessio_lugnan/Elements/awg_mnist_inverse_280524.mat'

# folder_path = '/media/alessio_lugnan/Elements/foradori_parsed/X4Y1_processed_100725/awg_mnist_double_inverse/'
# AWGdata_path = '/media/alessio_lugnan/Elements/awg_mnist_double_inverse_280524.mat'

# folder_path = '/media/alessio_lugnan/Elements/foradori_parsed/X4Y1_processed_110725/awg_fashion_mnist/'
# AWGdata_path = '/media/alessio_lugnan/Elements/awg_fashion_mnist_040324.mat'

# folder_path = '/media/alessio_lugnan/Elements/foradori_parsed/X4Y1_processed_110725/awg_fashion_mnist_double/'
# AWGdata_path = '/media/alessio_lugnan/Elements/awg_fashion_mnist_double_280524.mat'

# load and visualize original AWG pattern
AWGdata = sio.loadmat(AWGdata_path)
y = AWGdata['label'][0]   # NB: here the 100 is a warm start index, and was already considered for the X array in the parsing code
np.random.seed(seed=15)
ind_start = np.arange(len(y))
# NOTE(review): the seed is set but ind_shuf is never actually shuffled, so it
# is the identity permutation and y = y[ind_shuf] is a no-op. If a shuffle was
# intended, a np.random.shuffle(ind_shuf) call seems to be missing — confirm
# before relying on train/test splits being randomized.
ind_shuf = ind_start.copy()
y = y[ind_shuf]

# Leftover index defaults; i_port / i_freq / ind_port / ind_freq appear unused
# by the code below (the loops define their own local indices).
i_port = 0
i_freq = 0
ind_port=1
ind_freq=3
# ind_pow=5
downsample_factor = 3   # temporal block-averaging factor applied in get_samplesets_1pix
ind_selected_pow=5      # default 1-based power index (overridden by the i_pow loop below)

def get_samplesets_1pix( n_train, folder_path, ind_shuf, y, n_ports, ind_selected_pow, downsample_factor=1, pix_cutoff=0 ):
    '''Load and preprocess the single-pixel datasets of one measurement.

    This version is for data parsed by Alessandro Foradori.

    Parameters
    ----------
    n_train : int
        Number of samples kept for the training subset; the remaining
        non-test samples become the validation set.
    folder_path : str
        Directory containing the 'processed_*.npz' files (must end with '/').
    ind_shuf : ndarray
        Sample permutation applied to the feature rows (must match the order
        already applied to y by the caller).
    y : ndarray
        Labels, already permuted with ind_shuf.
    n_ports : int
        Unused here; kept for interface compatibility with sibling scripts.
    ind_selected_pow : int
        1-based input-power index selecting which file to load.
    downsample_factor : int, optional
        Temporal block-averaging factor; 1 disables downsampling.
    pix_cutoff : int, optional
        First pixel (column) index to keep. Defaults to 0 (previously this
        default captured the module-level pix_cutoff, which is 0 here).

    Returns
    -------
    tuple
        (X_traintrain, X_val, X_test, y_traintrain, y_val, y_test, n_pix).
        Feature arrays have shape (n_samples, n_pix, n_loaded_sets); the last
        10000 samples form the test set.

    Raises
    ------
    FileNotFoundError
        If no dataset file could be loaded at all (the original code fell
        through to a NameError on X_sp in that case).
    '''
    port_numbers = [1, 2, 3, 5, 7, 9]
    count = 0
    X_sp = None
    n_pix = 0
    # Only port 0 / frequency 0 are scanned in this single-pixel baseline.
    for i_port in range(0, 1):
        for i_freq in range(0, 1):
            print(f'Load i_port {i_port} i_freq {i_freq}, for i_pow {ind_selected_pow-1}')
            filename = f'processed_iPort{port_numbers[i_port]}_iFreq{i_freq+1}_iPow{ind_selected_pow}_iSet0.npz'
            try:
                X = np.load(folder_path + filename)['features']
            except (OSError, KeyError) as err:
                # The original code silently swallowed ALL errors here; keep
                # the best-effort behavior but report what was skipped.
                print(f'Skipping {filename}: {err}')
                continue
            if downsample_factor != 1:
                # Block-average along the time axis, padding with the global mean.
                X = block_reduce(X, block_size=(1, downsample_factor), func=np.mean, cval=np.mean(X)).astype('int16')
            X = X[ind_shuf]
            X = X[:, pix_cutoff:]
            n_pix = np.shape(X)[1]
            if count == 0:
                X_sp = X[:, :, None]
            else:
                X_sp = np.concatenate([X_sp, X[:, :, None]], axis=2)
            count += 1
    if count == 0:
        raise FileNotFoundError(f'No dataset file could be loaded from {folder_path}')
    # Last 10000 samples are the held-out test set.
    X_train = X_sp[:-10000, :].copy()
    X_test = X_sp[-10000:, :].copy()
    y_train = y[:-10000].copy()
    y_test = y[-10000:].copy()
    # split training and validation sets
    X_traintrain = X_train.copy()[:n_train]
    X_val = X_train.copy()[n_train:]
    y_traintrain = y_train.copy()[:n_train]
    y_val = y_train.copy()[n_train:]

    return X_traintrain, X_val, X_test, y_traintrain, y_val, y_test, n_pix

############# check classification ###################
# One near-unregularized logistic regression per pixel; scores_sklearn collects
# a [val, train, test] accuracy triplet per (power, pixel) combination.
scores_sklearn = []
for i_pow in range(4, 5):   # only the highest available input power (0-based index 4)
    X_traintrain, X_val, X_test, y_traintrain, y_val, y_test, n_pix = get_samplesets_1pix(
        n_train=n_train, folder_path=folder_path, ind_shuf=ind_shuf,
        y=y, n_ports=n_ports, ind_selected_pow=i_pow+1,
        downsample_factor=downsample_factor, pix_cutoff=pix_cutoff)

    # C=1e6 makes the L2 penalty effectively negligible.
    classifier = LogisticRegression(penalty='l2', dual=False, tol=0.0001, C=1e6, fit_intercept=True, intercept_scaling=1, class_weight=None, random_state=None,
                                    solver='lbfgs', max_iter=10000, verbose=0, warm_start=False, n_jobs=None, l1_ratio=None)
    for i_pix in range(n_pix):
        # Fit the scaler on this pixel's training slice and reuse the fitted
        # transform for all three evaluation splits.
        classifier.fit(scaler.fit_transform(X_traintrain[:, i_pix, :]), y_traintrain)
        score_train = classifier.score(scaler.transform(X_traintrain[:, i_pix, :]), y_traintrain)
        score_test = classifier.score(scaler.transform(X_test[:, i_pix, :]), y_test)
        score = classifier.score(scaler.transform(X_val[:, i_pix, :]), y_val)
        print(f'Sklearn Scores for i_pix {i_pix}: val {score}, train {score_train}, test {score_test}')
        scores_sklearn.append([score, score_train, score_test])

    # BUGFIX: the original rebound scores_sklearn to an ndarray inside this
    # i_pow loop, which would break .append() on any further iteration
    # (masked so far because range(4, 5) has a single element). Keep the list
    # and use a temporary array for the per-power summary.
    scores_arr = np.asarray(scores_sklearn)
    best_pix = int(np.argmax(scores_arr[:, 0]))
    print(f'Best val accuracy for i_pow {i_pow} is for pixel {best_pix}, '
          f'yielding {scores_arr[best_pix]}')

# Flatten the collected per-pixel triplets into an (n_entries, 3) array.
scores_sklearn = np.array(scores_sklearn)

# An earlier variant encoded the best test accuracy in the file name:
# np.savez(f'Results_ML_singlePix2_ExpBASELINE_Mnist10ns_nTrain{n_train}_accu{scores_sklearn[np.argmax(scores_sklearn[:,2])][2]}',
#          scores_sklearn=scores_sklearn,
#          folder_path=folder_path,
#          n_train=n_train, pix_cutoff=pix_cutoff)
saved_arrays = dict(scores_sklearn=scores_sklearn,
                    folder_path=folder_path,
                    n_train=n_train,
                    pix_cutoff=pix_cutoff)
np.savez(f'Results_ML_singlePix2_ExpBASELINE_NOpIXcUTOFF_Mnist10ns_nTrain{n_train}',
         **saved_arrays)

# ind_best = np.unravel_index(np.argmax(scores_sklearn[:,:,2]), np.shape(scores_sklearn[:,:,2]))
# score_tuple_best = scores_sklearn[ind_best]
# print(f'Best score for i_pow {ind_best[0]}, pixel {ind_best[1]}: {score_tuple_best}')

############################ train classifier over 1st classifier's output ###################################
# Xf_train = np.reshape(np.swapaxes(predictions_skl_train, 0, 1),
#                       (len(predictions_skl_train[0]), len(predictions_skl_train)*len(predictions_skl_train[0,0])))
# Xf_val = np.reshape(np.swapaxes(predictions_skl_val, 0, 1),
#                       (len(predictions_skl_val[0]), len(predictions_skl_val)*len(predictions_skl_val[0,0])))
# Xf_test = np.reshape(np.swapaxes(predictions_skl_test, 0, 1),
#                       (len(predictions_skl_test[0]), len(predictions_skl_test)*len(predictions_skl_test[0,0])))
#
# reg_strengths = np.logspace(-2,6,13)
# scores_reg = []
# for i_reg in range(len(reg_strengths)):
#     classifier_f = LogisticRegression(penalty='l2', dual=False, tol=0.0001, C=1/reg_strengths[i_reg], fit_intercept=True, intercept_scaling=1, class_weight=None, random_state=None,
#                                     solver='lbfgs', max_iter=10000, verbose=0, warm_start=False, n_jobs=None, l1_ratio=None)
#     # t1_f = time.time()
#     epochs_info = classifier_f.fit(scaler.fit_transform(Xf_train), y_traintrain)
#     scoreF_train = classifier_f.score(scaler.transform(Xf_train), y_traintrain)
#     scoreF_test = classifier_f.score(scaler.transform(Xf_test), y_test)
#     scoreF = classifier_f.score(scaler.transform(Xf_val), y_val)
#     scores_reg.append(scoreF)
#     # t2_f = time.time()
#     print(f'Final reg. optimization with reg. strength {reg_strengths[i_reg]}: val {scoreF}, train {scoreF_train}, test {scoreF_test}')