#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 15 12:41:24 2024

@author: ludovicocoletta
"""

import os 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import scipy.stats as stats

def scatter_plot_with_regression_line(x,y):
    """Fit a least-squares straight line to (x, y) and return it with 95% bands.

    The line is evaluated on 100 evenly spaced abscissae spanning
    ``[min(x), max(x)]``. The band half-widths use the two-sided 95%
    Student's t quantile with ``n - 2`` degrees of freedom.

    Reference for the formulas:
    https://stackoverflow.com/questions/27164114/show-confidence-limits-and-prediction-limits-in-scatter-plot

    Parameters
    ----------
    x, y : array-like
        One-dimensional samples of equal length. Any input accepted by
        ``np.asarray(..., dtype=float)`` works (ndarray, list, Series).
        At least three samples are needed for ``n - 2`` to be positive;
        with fewer, the returned half-widths are not meaningful.

    Returns
    -------
    x_line : ndarray, shape (100,)
        Abscissae at which the fitted line is evaluated.
    y_line : ndarray, shape (100,)
        Fitted line evaluated at ``x_line``.
    ci : ndarray, shape (100,)
        Half-width of the 95% confidence interval of the fitted mean.
    pi : ndarray, shape (100,)
        Half-width of the 95% prediction interval for a new observation.
    """
    # Coerce once so that lists and other array-likes behave like ndarrays
    # for the attribute access and element-wise arithmetic below.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    slope, intercept = np.polyfit(x, y, 1)        # linear model adjustment
    y_model = np.polyval([slope, intercept], x)   # fitted values at the samples

    x_mean = np.mean(x)
    n = x.size                        # number of samples
    m = 2                             # number of parameters (slope, intercept)
    dof = n - m                       # degrees of freedom
    t = stats.t.ppf(0.975, dof)       # two-sided 95% Student's t quantile

    residual = y - y_model
    std_error = (np.sum(residual**2) / dof)**.5   # standard deviation of the error

    # Sum of squared deviations of x; shared by both interval formulas.
    sxx = np.sum((x - x_mean)**2)

    # Grid on which the adjusted model and its bands are evaluated.
    x_line = np.linspace(np.min(x), np.max(x), 100)
    y_line = np.polyval([slope, intercept], x_line)

    # Distance-from-mean term: the bands widen away from x_mean.
    spread = (x_line - x_mean)**2 / sxx

    # confidence interval
    ci = t * std_error * (1/n + spread)**.5
    # predicting interval (adds the variance of a single new observation)
    pi = t * std_error * (1 + 1/n + spread)**.5

    return x_line,y_line, ci, pi

def main():
    """Correlate three predictor columns with the outcome column and plot each.

    Reads ``intersection_with_hubs.csv`` from the working directory. Columns
    0, 1 and 4 are each compared with column 5: the column name and the
    result of ``scipy.stats.pearsonr`` are printed, then a scatter plot with
    the fitted regression line and its 95% confidence band is saved as a PNG
    in ``fig_clinical_validation/``.

    The colour scheme, x-axis settings and output file name are chosen from
    the substrings ``'SEM'`` / ``'PHONO'`` in the predictor's column name;
    any other column is treated as total lesion volume.
    NOTE(review): the column positions are hard-coded; confirm they match
    the CSV layout.
    """
    out_dir = 'fig_clinical_validation'
    os.makedirs(out_dir, exist_ok=True)

    df = pd.read_csv('intersection_with_hubs.csv')
    columns = df.columns.to_list()
    data = df.to_numpy(dtype=float)

    index_of_int = [0, 1, 4, 5]  # we know this a priori; the last entry is the outcome
    outcome = data[:, index_of_int[-1]]

    for col_idx in index_of_int[:-1]:
        col_name = columns[col_idx]
        predictor = data[:, col_idx]

        # Print the name of the column actually analysed (not the loop
        # counter's column), so the label matches the statistics below.
        print(col_name)
        print(stats.pearsonr(predictor, outcome))

        # Colours are [scatter face colour, confidence band colour].
        if 'SEM' in col_name:
            func_of_int = 'SEMANTIC'
            colors = ['lime', 'limegreen']
        elif 'PHONO' in col_name:
            func_of_int = 'PHONO'
            colors = ['lime', 'limegreen']
        else:
            func_of_int = 'TOT_VOL'
            colors = ['red', 'darkred']

        # The prediction interval is returned as well but is not drawn.
        x_line, y_line, ci, _ = scatter_plot_with_regression_line(predictor, outcome)

        fig, ax = plt.subplots()
        ax.plot(x_line, y_line, color='black')
        ax.fill_between(x_line, y_line + ci, y_line - ci, color=colors[1], label='95% confidence interval')
        ax.scatter(predictor, outcome, 100, c=colors[0], edgecolors='black')
        ax.set_ylim([0, 100])

        if func_of_int == 'TOT_VOL':
            ax.set_xlim([0, 440])
            ax.set_xticks([0, 200, 400])
            ax.set_xlabel('Lesion volume (cm3)', fontsize=15, labelpad=0)
        else:
            ax.set_xlim([0, 100])
            ax.set_xlabel('Overlap between lesion\nand DES derived network (%)', fontsize=15, labelpad=0)
            ax.set_xticks([0, 50, 100])

        ax.set_yticks([0, 50, 100])
        ax.tick_params(labelsize=15)
        ax.set_ylabel('Symptom severity\n(WAB-R AQ)', fontsize=15, labelpad=0)

        # Keep only the left/bottom spines and offset them from the data area.
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_position(('outward', 5))
        ax.spines['bottom'].set_position(('outward', 5))
        # Tick sides are set after the spines have been moved.
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')

        # Save through the figure object and close it, so figures do not
        # accumulate in pyplot's registry across iterations.
        fig.savefig(os.path.join(out_dir, func_of_int + '_original_corr.png'),
                    format='png', dpi=900, transparent=True, bbox_inches='tight')
        plt.close(fig)

# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
