"""
Code Author: Meryem Beyza Avci
Corresponding Author of the Article: Arif E. Cetin
Affiliated with NBS Lab
Article Title: "A Pre-Diagnostic Tool For a Rare Disease: Familial Mediterranean Fever"
Date: 06.09.2024

This code is used for the analyses and calculations described in the article.
"""

from OceanView import Spectrometer
from Model import *
from ExtendedUtils import *
from time import sleep
from PyQt5 import QtGui, QtWidgets, QtCore
from PyQt5.QtCore import QCoreApplication
from threading import Thread
import numpy as np
from PyQt5.QtWidgets import QLineEdit, QPushButton, QMessageBox, QFileDialog
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import serial.tools.list_ports
import serial
import os
import time
import datetime

# NOTE(review): a `global` statement at module level has no effect in Python;
# it neither creates nor initializes these names. Functions that rebind one of
# them must declare `global` themselves (SaveFile does so for patient_dir_path).
global window_start_index, window_end_index, count, control_initial_area, control_patient_area, IgG_area, sensor_patient_area, patient_dir_path

# Shared application state used by every function below
# (Database presumably comes from `from Model import *` — confirm).
database = Database()
# Serial settings for the link to the pump controller (Arduino Nano).
BAUD_RATE = 9600
TIMEOUT = 0  # pyserial: timeout=0 means non-blocking reads
# Created without a port; ConnectToThePort assigns the port and opens it.
Device = serial.Serial(baudrate = BAUD_RATE, timeout = TIMEOUT) # Initialize the serial communication
# NOTE(review): Import and SaveFile only use local variables named `count`;
# this module-level value appears unused in the visible source.
count = 0

#Safely stop the device connection.
def SafeClose():
    """Disconnect from the spectrometer if a connection is currently open."""
    is_connected = database.GetDeviceState()
    if not is_connected:
        return
    DisconnectTheDevice()
    print("Successfully disconnected from the device.")

#Connect to or disconnect from the device
def ConnectDisconnect(ui):
    """Toggle the spectrometer connection and refresh the UI to match.

    When connected: stop acquisition if it is running, disconnect, mark the
    device as disconnected and disable the controls. When disconnected: try to
    connect; if ConnectTheDevice reports False, leave everything unchanged.
    """
    if not database.GetDeviceState():
        connect_result = ConnectTheDevice()
        if connect_result == False:
            return
        database.SetDeviceState(True)
        SetUIWidgets(ui, True)
        return

    if database.GetRunningState():
        StopTheDevice(ui)
    DisconnectTheDevice()
    database.SetDeviceState(False)
    SetUIWidgets(ui, False)

#Update the enabled state of the UI controls
def SetUIWidgets(ui, status):
    """Update the spectrometer controls after connecting or disconnecting.

    Args:
        ui: Main-window object holding the Qt widgets.
        status: True after a successful connection (show the device name,
            fill the module list, enable the controls); False after
            disconnecting (reset the labels, clear the module list, disable
            the controls).
    """
    device = database.GetDeviceInstance()

    if status:
        ui.deviceName_lbl.setText("Device: " + device.models[0])
        ui.connection_btn.setText("Disconnect")
        ui.label_5.setPixmap(QtGui.QPixmap("resources/connected_16x16.png"))

        # Register every module reported by the device, then show all known
        # serial numbers in sorted order.
        for module in device.instances:
            database.AddModuleSerial(module.serial_number)

        ui.modules_cb.clear()
        module_serials = sorted(database.GetModuleSerials())
        ui.modules_cb.addItems(module_serials)
    else:
        ui.deviceName_lbl.setText("Device: Unknown")
        ui.connection_btn.setText("Connect")
        ui.modules_cb.clear()
        database.moduleSerials = []
        ui.controlInitial_btn.setEnabled(False)

    # These controls follow the connection state in both directions.
    for widget in (
        ui.start_btn,
        ui.startStop_btn,
        ui.dataUpdate_sb,
        ui.integrationTime_sp,
        ui.modules_cb,
        ui.boxcarWidth_sp,
        ui.biosensorModeControls_gb,
    ):
        widget.setEnabled(status)

#Connect to the device
def ConnectTheDevice():
    """Create an OceanDevice, connect to it and store it in the database.

    Returns:
        True when the device connected and was stored; False if ``Connect``
        reported failure or any error occurred while setting the device up.
    """
    try:
        device = OceanDevice()  # Generate a device object
        if device.Connect() == False:
            return False

        device.AssignDefaultValues()
        database.SetDeviceInstance(device)  # Store the device to the database
    except Exception as err:
        # Narrowed from a bare `except:` so Ctrl+C / SystemExit are not swallowed.
        print("Error - No device was found! (" + str(err) + ")")
        return False
    return True

#Disconnect from the device
def DisconnectTheDevice():
    """Close the stored spectrometer connection and forget the instance."""
    database.GetDeviceInstance().Disconnect()
    database.SetDeviceInstance(None)

#Run the device and get the spectrum
def StartTheDevice(ui, graph):
    """Acquire spectra continuously and redraw *graph* until acquisition stops.

    The loop polls the database's running flag, which StopTheDevice clears.
    If a redraw reports that no data was available, the device is toggled off
    through ConnectDisconnect.
    """
    SetDeviceRunningUI(ui, True)
    database.SetRunningState(True)
    ui.controlInitial_btn.setEnabled(True)

    device = database.GetDeviceInstance()
    x_limits = [GetMinWavelengthValue(device), GetMaxWavelengthValue(device)]
    graph.axes.set_xlim(x_limits)

    while database.GetRunningState():
        # The sample rate is stored in milliseconds.
        sleep(device.GetSampleRate() / 1000)

        boxcar = ui.boxcarWidth_sp.value() if ui.smoothen_cb.isChecked() else 0

        SetGraphAxis(graph, device)
        database.SetCurrentData(GetSpectrum(), boxcar)
        if UpdateTheGraph(graph) == False:
            ConnectDisconnect(ui)

#Update the graph axis based on the received spectrum
def SetGraphAxis(graph, device):
    """Reset *graph* for a new spectrum frame.

    Clears the axes, then applies the axis labels, the transmission range
    [0, 1] and the device's wavelength range.

    ``axes.clear()`` also resets the axis limits, so the x-range has to be
    re-applied on every frame. Otherwise the range set once before the
    acquisition loop would be lost after the first refresh.

    Args:
        graph: Canvas-like object exposing ``axes``.
        device: The connected spectrometer, used for the wavelength range.
    """
    graph.axes.clear()
    graph.axes.set_xlabel("Wavelength (nm)", fontsize=14)
    graph.axes.set_ylabel("Transmission", fontsize=14)
    graph.axes.set_xlim([GetMinWavelengthValue(device), GetMaxWavelengthValue(device)])
    graph.axes.set_ylim([0, 1])

#Stop receiving data from the device.
def StopTheDevice(ui):
    """Stop continuous acquisition and restore the idle button states."""
    SetDeviceRunningUI(ui, False)
    # StartTheDevice polls this flag; clearing it ends its acquisition loop.
    database.SetRunningState(False)

#One-time data request to the device. Automatically stops the device after receiving one spectrum of data.
def StartStopDevice(ui, graph):
    """Acquire (approximately) a single spectrum, then stop.

    Starts the acquisition loop on a worker thread named "StartStopSub" and
    waits one sample period plus a 0.7 s margin. It then stops acquisition and
    passes the worker's name to CatchTheThread (which, per the original
    comment, joins it to the main thread).
    """
    worker_name = "StartStopSub"
    Thread(target=StartTheDevice, args=(ui, graph), name=worker_name).start()

    wait_seconds = database.GetDeviceInstance().GetSampleRate() / 1000 + 0.7
    sleep(wait_seconds)

    StopTheDevice(ui)
    CatchTheThread(worker_name)

#Update the UI while the device is running or after stopping.
def SetDeviceRunningUI(ui, status):
    """Toggle the acquisition buttons: while running, only Stop is enabled."""
    idle = not status
    for button in (ui.start_btn, ui.startStop_btn):
        button.setEnabled(idle)
    ui.stop_btn.setEnabled(status)

#Returns the spectrum
def GetSpectrum():
    """Return the stored device's current spectrum as a NumPy array."""
    spectrum = database.GetDeviceInstance().ReturnSpectrum()
    return np.array(spectrum)

#Update the graph data
def UpdateTheGraph(graph, color = None):
    """Plot the current normalized spectrum of every module on *graph*.

    Args:
        graph: Canvas-like object exposing ``axes`` and ``draw()``.
        color: Line colour. Defaults to "#f72809" when not given; previously
            this argument was silently ignored.

    Returns:
        False when there is no spectrum data, True after drawing.
    """
    data = database.GetCurrentData()

    # Check for emptiness before normalizing: Normalize indexes data[:, 1],
    # which may fail on an empty array.
    if not len(data):
        return False
    data = Normalize(data)

    if color is None:
        color = "#f72809"

    for wavelengths, intensities in data:
        graph.axes.plot(wavelengths, intensities, color=color)
    graph.draw()

    return True

def Normalize(data):
    """Scale the intensity rows ``data[:, 1]`` in place by their maximum.

    The array is modified in place (through a view) and also returned. Callers
    that re-read the same array object afterwards therefore see normalized
    values.
    """
    intensities = data[:, 1]
    intensities /= GetMaxDataValue(intensities)
    return data

def SampleRateChanged(value):
    """Forward a new sample rate (read as milliseconds by the acquisition loop) to the device."""
    database.GetDeviceInstance().SetSampleRate(value)
    
def IntegrationTimeChanged(selectedModuleIndex, value):
    """Record a new integration time for one module, then apply it to the device."""
    database.SetIntegrationTime(selectedModuleIndex, value)
    active_device = database.GetDeviceInstance()
    active_device.SetIntegrationTime(selectedModuleIndex, value)

def SaveFile(ui, file_name):
    """Write the current normalized spectrum to ``<patient folder>/<file_name>.txt``.

    The "initial_control" step starts a new measurement series. It reads the
    patient file number from the UI and creates ``patients/<number>``, also
    creating ``patients`` itself when missing. Every later step writes into
    that same folder.

    Each output line holds one wavelength, left-aligned in a 20-character
    column, followed by its intensity.

    Raises:
        RuntimeError: if called before any "initial_control" step has chosen
            the patient folder.
    """
    global patient_dir_path

    currentData = Normalize(database.GetCurrentData())
    column_width = 20

    if file_name == "initial_control":
        patient_file_name = ui.patientFileNumber_le.text()
        patient_dir_path = os.path.join("patients", patient_file_name)
        # makedirs also creates "patients/" on first use and tolerates an
        # already existing folder (os.mkdir failed when the parent was missing).
        os.makedirs(patient_dir_path, exist_ok=True)
    elif "patient_dir_path" not in globals():
        raise RuntimeError(
            "No patient folder selected; run the initial control step first."
        )

    with open(os.path.join(patient_dir_path, file_name + ".txt"), "w") as f:
        # One (wavelengths, intensities) pair per spectrometer module.
        for wavelengths, intensities in currentData:
            for wavelength, intensity in zip(wavelengths, intensities):
                f.write('{:<{width}}{}'.format(str(wavelength), str(intensity), width=column_width))
                f.write("\n")

def UpdateStatus(ui, mode, boolean):
    """Update the status icon of a protocol step and unlock the next step.

    Args:
        ui: Main-window object holding the Qt widgets.
        mode: Step name ("initial_control", "initial_sensor", "AG_sensor",
            "IgG_sensor", "patient_control", "patient_sensor"). Unknown names
            are ignored.
        boolean: True when the step is done (green icon, next button enabled),
            False while it is pending (yellow icon). The two sensor read-out
            steps ignore this flag and are always shown as done.
    """
    # mode -> (status label attribute, button unlocked when done, honours `boolean`)
    steps = {
        "initial_control": ("controlStatus_lbl", "sensorInitial_btn", True),
        "initial_sensor": ("sensorStatus_lbl", "sensorAG_btn", False),
        "AG_sensor": ("sensorAG_lbl", "sensorIgG_btn", True),
        "IgG_sensor": ("sensorIgG_lbl", "controlPatient_btn", True),
        "patient_control": ("patientControl_lbl", "sensorPatient_btn", True),
        "patient_sensor": ("sensorPatient_lbl", "runTest_btn", False),
    }
    if mode not in steps:
        return

    label_name, next_button_name, uses_flag = steps[mode]
    done = bool(boolean) or not uses_flag

    # Forward slashes work on every platform Qt supports; the old backslash
    # literals contained invalid escape sequences ("\g", "\y").
    icon_path = ("resources/green_circle_48px.png" if done
                 else "resources/yellow_circle_48px.png")
    getattr(ui, label_name).setPixmap(QtGui.QPixmap(icon_path))
    if done:
        getattr(ui, next_button_name).setEnabled(True)
        
def GetDataFromGraphs(ui, graph, mode):
    """Run one measurement step of the test protocol and save its spectrum.

    Shows the step as pending and, except for the two sensor read-out steps,
    runs the pump and waits for it. It then shows the step as done and writes
    the current spectrum via SaveFile.

    Args:
        ui: Main-window object.
        graph: Unused; kept for interface compatibility.
        mode: Step name, e.g. "initial_control", "AG_sensor", "patient_sensor".
    """
    duration = 0.1  # minutes (the wait below is duration * 60 seconds)
    # Control steps use valve 1, all other pumped steps valve 0.
    valve = 1 if mode in ("initial_control", "patient_control") else 0

    UpdateStatus(ui, mode, False)
    QCoreApplication.processEvents()  # Force label update

    if mode not in ("initial_sensor", "patient_sensor"):
        try:
            StartThePump(ui, duration, valve)
            sleep(duration * 60)  # wait for the pump run to finish
        except Exception as err:
            # Best effort: a pump problem must not prevent saving the spectrum,
            # but it should not disappear silently either.
            print("Warning - pump step failed: " + str(err))

    UpdateStatus(ui, mode, True)
    SaveFile(ui, mode)

#After the sensor data is retrieved, the analyze function is called.
#The function saves the data and, for sensor data, fills the area under the curve inside the integration window (60 nm) for illustrative purposes.
def Analyze(graph, mode):
    """Store the current spectrum as control or sensor data.

    For sensor data, also shade the integration window on *graph*.

    Args:
        graph: Canvas-like object exposing ``axes`` and ``draw()``.
        mode: "Biosensor_Control" or "Biosensor_Sensor". Any other value only
            normalizes the current data.
    """
    currentData = Normalize(database.GetCurrentData())

    if mode == "Biosensor_Control":
        database.AddControlData(currentData)
    elif mode == "Biosensor_Sensor":
        database.AddSensorData(currentData)
        # Design notes from the original author:
        # Update the integration values based on the index that we are changing.
        # If the index equals 0, then the peak value should be taken from the new value.
        # The parameters (integration, peak value) should be updated by the
        # UpdateControlData function, and the new graphs should be visualized.

        data = database.GetCurrentData()
        maxIndex = database.GetMaxDataIndex()

        x = data[maxIndex][0]
        y = data[maxIndex][1]
        # Boolean mask selecting the integration window. Plain `bool` replaces
        # np.bool8, which was deprecated in NumPy 1.24 and removed in 2.0.
        ym = np.zeros(len(y), dtype=bool)
        ym[database.windowStartIndex : database.windowEndIndex] = True

        graph.axes.fill_between(x, y, where=ym, color='gray')
        graph.draw()

def RunTest(ui, graphMain, graphIntegralControl, graphIntegralSensor, folder_path="patients"):
    """Analyze the most recently created patient folder.

    Args:
        ui: Main-window object (passed through to Import).
        graphMain, graphIntegralControl, graphIntegralSensor: Canvases
            passed through to Import.
        folder_path: Directory containing one sub-folder per patient. Defaults
            to "patients" relative to the working directory, the location
            SaveFile writes to, instead of a developer-specific absolute path.
    """
    folders = [entry.path for entry in os.scandir(folder_path) if entry.is_dir()]
    if not folders:
        print("No patient folders found in: " + str(folder_path))
        return

    latest_folder = max(folders, key=os.path.getctime)
    print("Latest created folder:", latest_folder)
    Import(ui, graphMain, graphIntegralControl, graphIntegralSensor, latest_folder)
    
def UpdateIntegrationTimeView(ui):
    """Show the stored integration time of the module selected in the combo box."""
    stored_time = database.GetIntegrationTime(ui.modules_cb.currentIndex())
    ui.integrationTime_sp.setValue(stored_time)

def SaveData():
    """Unfinished stub; see the review note below before using it."""
    # NOTE(review): this reads the peak wavelength into an unused local and
    # then evaluates the bare attribute `database.Get` (no call). Unless
    # Database defines an attribute named `Get`, calling SaveData raises
    # AttributeError. It is not called anywhere in the visible source.
    # TODO: finish or remove.
    peakW = database.GetPeakWavelength()
    controlData = database.Get

#Import Data Offline
def Import(ui, graphMain, graphIntegralControl, graphIntegralSensor, parent):
    """Load a patient's six spectra, plot them with their integrals, and run Detect.

    The integration window is derived once from the initial-control spectrum
    (see SetIntegralWindow) and reused for all six files.

    Args:
        ui: Main-window object (passed through to Detect).
        graphMain: Canvas for the spectra; the integration window is shaded on it.
        graphIntegralControl: Canvas whose ``axes2`` shows the control integrals.
        graphIntegralSensor: Canvas whose ``axes3`` shows the sensor integrals.
        parent: "archive" to choose the folder with a dialog; otherwise the
            patient folder path itself.
    """
    if parent == "archive":
        folder_path = QFileDialog.getExistingDirectory(None, 'Select Patient Folder')
        if not folder_path:
            return  # The dialog was cancelled; there is nothing to import.
    else:
        folder_path = parent

    # (file name, bar x-position, True = control integral plot / False = sensor)
    measurements = [
        ('initial_control.txt', 0.3, True),
        ('initial_sensor.txt', 0.1, False),
        ('AG_sensor.txt', 0.36, False),
        ('IgG_sensor.txt', 0.62, False),
        ('patient_control.txt', 0.7, True),
        ('patient_sensor.txt', 0.9, False),
    ]
    bar_color = "gray"
    areas = {}

    for index, (file_name, x_axis, is_control) in enumerate(measurements):
        x, y = _ReadSpectrumFile(os.path.join(folder_path, file_name))

        if index == 0:
            # The window is positioned relative to the initial-control peak.
            peak_index = y.index(max(y))
            window_start_index, window_end_index = SetIntegralWindow(x, y, graphMain, peak_index)

        area = CalculateIntegral(x, y, window_start_index, window_end_index)
        areas[file_name] = area

        graphMain.axes.set_xlim([600, 800])
        graphMain.axes.set_ylim([0, 1.05])
        graphMain.axes.plot(x, y, color='gray')
        graphMain.draw()

        if is_control:
            graphIntegralControl.axes2.bar(x_axis, area, color=bar_color, width=0.13)
            graphIntegralControl.draw()
        else:
            graphIntegralSensor.axes3.bar(x_axis, area, color=bar_color, width=0.13)
            graphIntegralSensor.draw()

    Detect(ui,
           areas['initial_control.txt'],
           areas['patient_control.txt'],
           areas['IgG_sensor.txt'],
           areas['patient_sensor.txt'],
           parent,
           folder_path)


def _ReadSpectrumFile(file_path):
    """Parse a whitespace-separated (wavelength, intensity) text file into two float lists.

    Blank lines (e.g. a trailing empty line) are skipped.
    """
    x = []
    y = []
    with open(file_path, 'r') as file:
        for row in file:
            fields = row.split()
            if not fields:
                continue
            x.append(float(fields[0]))
            y.append(float(fields[1]))
    return x, y

def CalculateIntegral(x, y, window_start_index, window_end_index):
    """Return the sum of ``y`` over the half-open index range [start, end).

    This is the area under the curve without any spacing factor; ``x`` is
    accepted for interface symmetry but not used.
    """
    window = np.array(y[window_start_index:window_end_index])
    return window.sum()
    
def SetIntegralWindow(x, y, graph, peak_index, shift=20, integral_window_width=50):
    """Locate the integration window relative to a peak and shade it on *graph*.

    The window starts ``shift`` nm to the right of ``x[peak_index]`` and spans
    ``integral_window_width`` nm. Both ends snap to the nearest sample in *x*;
    on ties, the first matching sample wins.

    Args:
        x: Wavelengths (nm).
        y: Intensities, same length as *x*.
        graph: Canvas-like object exposing ``axes.fill_between`` and ``draw()``.
        peak_index: Index of the reference (control) peak in *x*.
        shift: Offset in nm from the peak to the window start (default 20).
        integral_window_width: Window width in nm (default 50).

    Returns:
        ``(window_start_index, window_end_index)`` as a half-open index range.
    """
    wavelengths = np.asarray(x, dtype=float)
    window_start_wavelength = wavelengths[peak_index] + shift
    window_end_wavelength = window_start_wavelength + integral_window_width

    # argmin returns the first minimum, matching the original min()-based search.
    window_start_index = int(np.argmin(np.abs(wavelengths - window_start_wavelength)))
    window_end_index = int(np.argmin(np.abs(wavelengths - window_end_wavelength)))

    # Boolean mask selecting the window. Plain `bool` replaces np.bool8,
    # which was removed in NumPy 2.0.
    in_window = np.zeros(len(y), dtype=bool)
    in_window[window_start_index:window_end_index] = True
    graph.axes.fill_between(x, y, where=in_window, color='gray')
    graph.draw()

    return window_start_index, window_end_index

def Detect(ui, control_initial_area, control_patient_area, IgG_area, sensor_patient_area, parent, folder_path):
    """Compute the FMF ratio from the four integrals and show the verdict.

    For live tests, also write an SI report.

        A      = sensor_patient_area / IgG_area
        B      = control_patient_area / control_initial_area
        result = A / B

    Verdicts:

    - 14.1 < result < 24.9 is shown as "POSITIVE".
    - 2 < result < 4.9 is shown as "FMF NEGATIVE".
    - Any other value is shown as "INCONCLUSIVE", so a verdict from a previous
      test never stays on screen.

    When *parent* is not "archive", the values are written to
    ``<folder_path>/SI.txt``.
    """
    A = sensor_patient_area / IgG_area
    B = control_patient_area / control_initial_area

    print("A: " + str(A))
    print("B: " + str(B))

    result = A / B

    print("Result (A/B): " + str(result))

    if 14.1 < result < 24.9:
        style, label, result_fmf = "color: red; font-size: 24px;", "POSITIVE", "FMF POSITIVE"
    elif 2 < result < 4.9:
        style, label, result_fmf = "color: green; font-size: 24px;", "FMF NEGATIVE", "FMF NEGATIVE"
    else:
        style, label, result_fmf = "color: gray; font-size: 24px;", "INCONCLUSIVE", "INCONCLUSIVE"

    ui.testStatus_lbl.setStyleSheet(style)
    ui.testStatus_lbl.setAlignment(QtCore.Qt.AlignHCenter)
    ui.testStatus_lbl.setText(label)
    ui.testStatus_lbl.setVisible(True)

    if parent != "archive":
        report_lines = [
            "File Name: " + str(ui.patientFileNumber_le.text()),
            "Test Date: " + str(datetime.date.today()),
            "SI Value of Patient Sensor: " + str(sensor_patient_area),
            "SI Value of IgG: " + str(IgG_area),
            "A: " + str(A),
            "SI Value of Patient Control: " + str(control_patient_area),
            "SI Value of Initial Control: " + str(control_initial_area),
            "B: " + str(B),
            "Result (A/B): " + str(result),
            "Result: " + result_fmf,
        ]
        with open(os.path.join(folder_path, "SI" + ".txt"), "w") as f:
            f.write("\n".join(report_lines) + "\n")
#Pump Controls
def ConnectToThePort(ui):
    """Find the pump's Arduino Nano among the serial ports, open it and enable the pump controls.

    The board is recognised by the USB id 'VID:PID=0403:6001' in the port's
    hardware id. If no port matches, nothing changes. If the matching port
    cannot be opened, a warning dialog is shown instead of letting the
    exception escape the Qt slot.
    """
    uniqueID = 'VID:PID=0403:6001'  # Unique ID for Arduino Nano (both clone and original)

    for port in serial.tools.list_ports.comports():  # List available COM ports
        if uniqueID not in port.hwid:
            continue

        try:
            Device.port = port.device
            Device.close()
            Device.open()  # Connect with Arduino
        except serial.SerialException as err:
            QMessageBox(QMessageBox.Warning, "Error!", "Could not open the pump port: " + str(err)).exec_()
            return

        ui.label_11.setPixmap(QtGui.QPixmap("resources/connected_16x16.png"))
        ui.connectionPump_btn.setText("Disconnect")
        for control in (ui.startPump_btn, ui.slow_btn, ui.medium_btn,
                        ui.high_btn, ui.duration_sb, ui.stopPump_btn):
            control.setEnabled(True)
        break

def StartThePump(ui, duration, valve):
    """Send a run command "<voltage>~<duration_ms>~<valve>" to the pump controller.

    Args:
        ui: Unused; kept for interface compatibility.
        duration: Run time in minutes (number or numeric string).
        valve: Valve selector forwarded as-is to the controller.

    Failures are reported on the console and otherwise ignored (best effort),
    so callers are never interrupted by a pump problem.
    """
    voltage = 230  # Fixed drive value sent to the controller (units per firmware).
    try:
        # The Arduino side counts in milliseconds, so minutes * 60 * 1000.
        duration_ms = int(float(duration) * 60000)
        command = str(voltage) + "~" + str(duration_ms) + "~" + str(valve)
        Device.write(bytes(command, 'UTF-8'))  # Send the command to the Arduino Nano.
    except Exception as err:
        print("Error - could not start the pump: " + str(err))
    
def StopThePump(ui):
    '''Stop the pump by sending the "0~0" command to the controller.

    Triggered by the 'Stop Pump' button in the main window. Shows a success
    dialog when the command was written, or a warning dialog when it could
    not be sent.
    '''
    voltage = 0
    duration = 0
    command = str(voltage) + "~" + str(duration)  # "0~0"
    try:
        Device.write(bytes(command, 'UTF-8'))  # Send the command to the Arduino Nano.
    except Exception:
        # Narrowed from a bare `except:` so Ctrl+C / SystemExit still propagate.
        QMessageBox(QMessageBox.Warning, "Error!", "Could not send the data!").exec_()  # Alert
    else:
        QMessageBox(QMessageBox.Information, "Success!", "Pump has been stopped!").exec_()  # Alert