import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
import csv
import pandas as pd

from tqdm import tqdm

def last_4chars(x, n=9):
    """Return the last ``n`` items of ``x`` (the last 9 by default).

    Used as the ``sorted`` key for the PIV ``.txt`` frame files so they are
    ordered by the tail of their name (presumably a zero-padded frame number
    followed by ".txt" -- confirm against the PIV export naming scheme).
    Despite its name, the default keeps 9 characters, not 4; the name is
    kept so existing callers keep working.

    Args:
        x: Any sliceable sequence, typically a file name string.
        n: Number of trailing items to keep. If ``n`` exceeds ``len(x)``,
            the whole of ``x`` is returned.

    Returns:
        The trailing slice of ``x``. ``n == 0`` yields an empty slice (a
        plain ``x[-0:]`` would wrongly return everything).
    """
    return x[max(len(x) - n, 0):]

# --- acquisition / calibration constants ------------------------------------
fps = 1e3  # camera frame rate [frames/s]
dt = 1 / fps  # [s] time step between two consecutive frames
path = Path(r"C:\Users\lucas\Documents\UB\Stage_M1\newdata\piv\piv")

# Experiment folders to process: only the entries whose name begins with "SC1"
# (listing order is preserved).
folderlist = [entry for entry in os.listdir(path) if entry.startswith("SC1")]
num_folder = len(folderlist)

print(folderlist)

umpix = 0.222  # spatial calibration [um/pixel]

# Per-experiment metadata table (semicolon-separated). The processing loop
# looks rows up by 'exp_name' and reads the 'x0' / 'y0' columns.
data_exp_info = pd.read_csv(
    r'C:\Users\lucas\Documents\UB\Stage_M1\newdata\photoshock_indexes.csv',
    sep=';',
)

# Set to True to display each region's velocity time series while processing.
SHOW_PLOTS = False

offset = np.array([-1, 1])  # vertical probe offsets, in units of 20 um

# Output folder for the per-region CSVs; created up front so a fresh setup
# does not fail at open() after all frames have been processed.
newpath = Path(r"C:\Users\lucas\Documents\UB\Stage_M1\newdata\piv\velocities")
newpath.mkdir(parents=True, exist_ok=True)
header = ['time [s]', 'norm u [um/s]', 'u_x [um/s]', 'u_y [um/s]']


def _video_height_px(exp_name):
    """Return the camera frame height [px] for experiment ``exp_name``.

    The token between the first '_' and the following '-' of the folder
    name is compared with '09': those recordings use a 400 px tall frame,
    all others 496 px.
    """
    date = exp_name.split('_')[1].split('-')[0]
    return 400 if date == '09' else 496


def _load_piv_frame(filename):
    """Load one PIV text file and return its columns (x, y, u, v, flags)."""
    a = np.loadtxt(filename)
    return a[:, 0], a[:, 1], a[:, 2], a[:, 3], a[:, 4]


def _region_indices(x, y, x0, y0, yoff):
    """Return the indices of PIV vectors inside the probe window for ``yoff``.

    ``x0``/``y0`` are the reference position in pixels (``y0`` already
    mirrored). The window centre is moved by +5 um in x and ``yoff * 20`` um
    in y, then spans +-10 um in x and +-5 um in y (strict bounds).
    NOTE(review): this assumes the PIV ``x``/``y`` columns are in um --
    confirm against the PIV export settings.
    """
    cx = x0 + 5/umpix
    cy = y0 + yoff*20/umpix
    xmin, xmax = cx*umpix - 10, cx*umpix + 10
    ymin, ymax = cy*umpix - 5, cy*umpix + 5
    return np.where((x > xmin) & (x < xmax) & (y > ymin) & (y < ymax))[0]


for exp_name in tqdm(folderlist):
    folder = Path(path, exp_name)

    # PIV frames of this experiment, ordered by the tail of their file name.
    txtlist = sorted(
        (name for name in os.listdir(folder) if name.endswith(".txt")),
        key=last_4chars,
    )
    if not txtlist:
        raise FileNotFoundError(f"No .txt PIV files found in {folder}")
    num_txt = len(txtlist)

    temp_data = data_exp_info[data_exp_info['exp_name'] == exp_name]
    if temp_data.empty:
        raise ValueError(
            f"No row for exp_name={exp_name!r} in photoshock_indexes.csv"
        )

    # Reference position in pixels. y0 is mirrored as (frame height - y0),
    # presumably converting image rows to the PIV y axis -- confirm.
    x0 = temp_data['x0'].values[0]
    y0 = _video_height_px(exp_name) - temp_data['y0'].values[0]

    offsets = offset.tolist()
    indices = {}
    # Per-region time series, keyed by vertical offset.
    velocity = {yoff: [] for yoff in offsets}    # mean |u|
    velocity_x = {yoff: [] for yoff in offsets}  # mean u_x
    velocity_y = {yoff: [] for yoff in offsets}  # mean u_y

    for ii, name in enumerate(tqdm(txtlist)):
        x, y, u, v, _ = _load_piv_frame(Path(folder, name))
        if ii == 0:
            # Windows are selected once on the first frame's grid and reused
            # for every frame (assumes the PIV grid does not change).
            indices = {
                yoff: _region_indices(x, y, x0, y0, yoff) for yoff in offsets
            }
        for yoff in offsets:
            inds = indices[yoff]
            if len(inds) > 0:
                u_r, v_r = u[inds], v[inds]
                velocity[yoff].append(np.mean(np.sqrt(u_r**2 + v_r**2)))
                velocity_x[yoff].append(np.mean(u_r))
                velocity_y[yoff].append(np.mean(v_r))
            else:
                # Empty window: record NaN so every series stays aligned with t.
                velocity[yoff].append(np.nan)
                velocity_x[yoff].append(np.nan)
                velocity_y[yoff].append(np.nan)

    t = np.arange(num_txt) * dt

    for yoff in offsets:
        if SHOW_PLOTS:
            plt.figure()
            plt.plot(t, velocity[yoff], label='|u|')
            plt.plot(t, velocity_x[yoff], label='u_x')
            plt.plot(t, velocity_y[yoff], label='u_y')
            plt.title(f"Velocity at pos y={yoff}")
            plt.legend()
            plt.xlabel('Time [s]')
            plt.ylabel('Velocity [um/s]')
            plt.xlim(.8, 2)
            plt.grid()
            plt.show()

        data = np.transpose(
            [t, velocity[yoff], velocity_x[yoff], velocity_y[yoff]]
        )
        filename = Path(newpath, f'velocities_{exp_name}_posy{yoff}.csv')
        with open(filename, 'w', encoding='UTF8', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(data)
            