import os
import numpy as np
import scipy as sp
from scipy import stats
from scipy.optimize import curve_fit
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
import json
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import seaborn as sns
from joblib import Parallel, delayed
from tqdm import tqdm

from IPython import embed as shell

# from tools_mcginley import utils
# from mcginley_pipeline import behavior

# Module-level plotting defaults applied on import.
# NOTE(review): per seaborn's documentation, sns.plotting_context() called with
# no arguments only *returns* the current context parameters and changes no
# setting, so this line looks like a no-op. If a context change was intended,
# sns.set_context(...) is probably what was meant — confirm.
sns.plotting_context()
# Use the "tab10" colour cycle as the seaborn palette for all figures below.
sns.set_palette("tab10")

def add_previous_trial_info(df, columns, n_back=15):
    """Attach values from preceding trials to each row of ``df``.

    A row has an ``s``-back predecessor when the row ``s`` positions above it
    has ``trial == current trial - s`` (no gap in the trial numbering). For
    such rows the columns ``{c}_p{s}`` (s = 1..n_back) receive the value of
    ``c`` from that earlier row; all other rows are left NaN. ``{c}_p`` holds
    the same 1-back value as ``{c}_p1`` and is kept for existing callers.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``trial`` column. Modified in place and returned.
    columns : list of str
        Columns whose previous-trial values are copied.
    n_back : int, default 15
        How many trials back to look (15 was previously hard-coded).

    Returns
    -------
    pandas.DataFrame
        ``df`` with ``previous_trial_exists``, the ``*_p`` / ``*_p{s}``
        columns and, when ``reward_obtained`` is among ``columns``,
        ``reward_obtained_p1_diff``.
    """
    repeats = (df['trial'].diff() == 1)
    df['previous_trial_exists'] = repeats
    df.loc[repeats, ['{}_p'.format(c) for c in columns]] = df.shift(1).loc[repeats, columns].values

    for s in range(1, n_back + 1):
        # Shift the frame once per lag and reuse it for both the test and the values.
        shifted = df.shift(s)
        ind = (shifted['trial'] == (df['trial'] - s))
        df.loc[ind, ['{}_p{}'.format(c, s) for c in columns]] = shifted.loc[ind, columns].values

    # Difference in reward_obtained_p1 between consecutive rows whose
    # reward_obtained_p1 is != 0 (NaN counts as != 0). Only computable when
    # reward_obtained was requested; previously this raised KeyError otherwise.
    if 'reward_obtained_p1' in df.columns:
        df['reward_obtained_p1_diff'] = 0.0
        nonzero = df['reward_obtained_p1'] != 0
        df.loc[nonzero, 'reward_obtained_p1_diff'] = df.loc[nonzero, 'reward_obtained_p1'].diff()

    return df

def add_next_trial_info(df, columns):
    """Copy values of ``columns`` from the following trial into ``{c}_n``.

    A row has a successor when the next row's ``trial`` is exactly one
    larger. ``next_trial_exists`` records that flag; ``{c}_n`` is filled only
    for those rows. ``df`` is modified in place and returned.
    """
    has_next = df['trial'].diff(-1) == -1
    df['next_trial_exists'] = has_next

    next_cols = ['{}_n'.format(col) for col in columns]
    following = df.shift(-1)
    df.loc[has_next, next_cols] = following.loc[has_next, columns].values
    return df

def mean_confidence_interval(data, confidence=0.95):
    """Return the half-width of a Student-t confidence interval for the mean.

    Despite the name, only the half-width ``h`` is returned (the interval is
    ``mean +/- h``), which is what the error-bar code in this module uses.

    Parameters
    ----------
    data : array-like of numbers
    confidence : float, default 0.95
        Two-sided confidence level.

    Returns
    -------
    float
        ``sem(data) * t_{(1 + confidence) / 2, n - 1}``; NaN when fewer than
        two values are given.
    """
    a = np.asarray(data, dtype=float)
    n = len(a)
    se = sp.stats.sem(a)
    # Two-sided critical value of the t distribution with n - 1 dof.
    return se * sp.stats.t.ppf((1 + confidence) / 2., n - 1)

def _fit_survival(df, fitter):
    """Fit a lifelines univariate ``fitter`` instance to ``df`` and return it.

    Durations come from ``df['trial_dur']`` and the event indicator from
    ``df['choice']``, passed positionally as lifelines'
    ``fit(durations, event_observed)``.
    """
    durations, events = df['trial_dur'], df['choice']
    return fitter.fit(durations, events)


def fit_kaplanmeier(df):
    """Return a lifelines ``KaplanMeierFitter`` fitted to ``df['trial_dur']`` / ``df['choice']``."""
    from lifelines import KaplanMeierFitter
    return _fit_survival(df, KaplanMeierFitter())


def fit_weibull(df):
    """Return a lifelines ``WeibullFitter`` fitted to ``df['trial_dur']`` / ``df['choice']``."""
    from lifelines import WeibullFitter
    return _fit_survival(df, WeibullFitter())


def fit_log_normal(df):
    """Return a lifelines ``LogNormalFitter`` fitted to ``df['trial_dur']`` / ``df['choice']``."""
    from lifelines import LogNormalFitter
    return _fit_survival(df, LogNormalFitter())


def fit_log_logistic(df):
    """Return a lifelines ``LogLogisticFitter`` fitted to ``df['trial_dur']`` / ``df['choice']``."""
    from lifelines import LogLogisticFitter
    return _fit_survival(df, LogLogisticFitter())


def fit_gen_gamma(df):
    """Return a lifelines ``GeneralizedGammaFitter`` fitted to ``df['trial_dur']`` / ``df['choice']``."""
    from lifelines import GeneralizedGammaFitter
    return _fit_survival(df, GeneralizedGammaFitter())

def assign_random_rts(df, target_mean=5, target_max=11, nr_trials=int(1e6)):
    """Assign surrogate noise durations to false alarms and catch trials.

    ``catch_session`` is added: the max of ``catch`` within each
    (``subject_id``, ``session_id``), i.e. 1 for sessions containing catch
    trials. Noise durations are drawn from an exponential distribution with
    mean ``target_mean`` (``nr_trials`` samples per session type):

    * sessions without catch trials: draws are clipped at ``target_max``.
      Each FA (``fa == 1``) gets a draw > its RT (RT rounded up to 0.01 s);
      FAs with RT >= ``target_max`` get ``target_max``.
    * sessions with catch trials: draws >= ``target_max`` are discarded.
      Each non-catch FA gets a draw > its RT, and catch trials get 14.

    ``df`` is modified in place and returned. NumPy's global RNG is used,
    so seed ``np.random`` for reproducibility.

    Raises
    ------
    ValueError
        If a non-catch FA in a catch session has an RT that no retained draw
        exceeds. Previously this opened an interactive IPython shell.
    """
    df['catch_session'] = df.groupby(['subject_id', 'session_id'])['catch'].transform('max')

    # --- sessions without catch trials ---
    target_times = np.random.exponential(target_mean, nr_trials)
    target_times[target_times > target_max] = target_max
    ind = (df['catch_session'] == 0) & (df['fa'] == 1)
    rts = (np.ceil(df.loc[ind, 'rt'] * 100) / 100).to_numpy()
    noise_durs = np.zeros(len(rts))
    for rt in np.unique(rts):
        sel = rts == rt
        if rt < target_max:
            noise_durs[sel] = np.random.choice(a=target_times[target_times > rt], size=np.sum(sel), replace=True)
        else:
            # No clipped draw can exceed an RT at/after the clip value.
            noise_durs[sel] = target_max
    df.loc[ind, 'noise_dur'] = noise_durs

    # --- sessions with catch trials ---
    target_times = np.random.exponential(target_mean, nr_trials)
    target_times = target_times[target_times < target_max]
    ind = (df['catch_session'] == 1) & (df['fa'] == 1) & (df['catch'] == 0)
    rts = (np.ceil(df.loc[ind, 'rt'] * 100) / 100).to_numpy()
    noise_durs = np.zeros(len(rts))
    for rt in np.unique(rts):
        sel = rts == rt
        candidates = target_times[target_times > rt]
        if candidates.size == 0 and sel.any():
            raise ValueError(
                'no simulated noise duration below {} exceeds FA rt {}; '
                'cannot assign noise_dur'.format(target_max, rt))
        noise_durs[sel] = np.random.choice(a=candidates, size=np.sum(sel), replace=True)
    df.loc[ind, 'noise_dur'] = noise_durs

    # Catch trials in catch sessions get the fixed catch noise duration.
    df.loc[(df['catch_session'] == 1) & (df['catch'] == 1), 'noise_dur'] = 14

    return df

def add_sdt_columns(df, signal_dur=3, catch_noise_dur=14):
    """Derive signal-detection outcome columns from ``rt`` and ``noise_dur``.

    ``outcome`` codes:
      0 = hit: response after noise offset, within ``signal_dur`` s of it;
      1 = miss: no response on a trial with ``noise_dur < catch_noise_dur``;
      2 = false alarm: response at or before noise offset;
      3 = correct reject: everything else (including responses later than
          ``noise_dur + signal_dur``).

    Also adds one-hot ``hit``/``miss``/``fa``/``cr``, ``stimulus`` (hit or
    miss), ``response`` (hit or FA), ``correct`` (hit or CR), ``rt2`` (RT
    relative to noise offset, hits only) and ``trial_dur`` (hit/FA: rt,
    miss: noise_dur + signal_dur, CR: noise_dur).

    ``df`` is modified in place and returned. ``signal_dur`` and
    ``catch_noise_dur`` default to the previously hard-coded 3 and 14.
    Uses ``np.nan`` (``np.NaN`` no longer exists in NumPy 2.0).
    """
    rt, noise_dur = df['rt'], df['noise_dur']
    responded = rt.notna()

    df['outcome'] = 3
    df.loc[(rt > noise_dur) & (rt <= noise_dur + signal_dur) & responded, 'outcome'] = 0
    df.loc[(noise_dur < catch_noise_dur) & ~responded, 'outcome'] = 1
    df.loc[(rt <= noise_dur) & responded, 'outcome'] = 2

    outcome = df['outcome']
    for code, name in enumerate(['hit', 'miss', 'fa', 'cr']):
        df[name] = (outcome == code).astype(int)

    df['stimulus'] = ((outcome == 0) | (outcome == 1)).astype(int)
    df['response'] = ((outcome == 0) | (outcome == 2)).astype(int)
    df['correct'] = ((outcome == 0) | (outcome == 3)).astype(int)

    is_hit, is_miss = outcome == 0, outcome == 1
    is_fa, is_cr = outcome == 2, outcome == 3

    df['rt2'] = np.nan
    df.loc[is_hit, 'rt2'] = rt[is_hit] - noise_dur[is_hit]

    df['trial_dur'] = np.nan
    df.loc[is_hit, 'trial_dur'] = rt[is_hit]
    df.loc[is_miss, 'trial_dur'] = noise_dur[is_miss] + signal_dur
    df.loc[is_fa, 'trial_dur'] = rt[is_fa]
    df.loc[is_cr, 'trial_dur'] = noise_dur[is_cr]

    return df

def fast_hits_to_false_alarms(df, rt_cutoff=0.05):
    """Reclassify hits faster than ``rt_cutoff`` after signal onset as false alarms.

    Hits (``outcome == 0``) with ``rt2 < rt_cutoff`` get ``outcome = 2`` and
    ``rt2 = NaN``. Then ``hit``, ``fa``, ``stimulus``, ``response`` and
    ``correct`` are recomputed; ``miss`` and ``cr`` are unaffected by this
    change and are read as-is. A summary line with the percentage of
    affected trials is printed. ``df`` is modified in place and returned.
    """
    ind = (df['outcome'] == 0) & (df['rt2'] < rt_cutoff)
    print('turning {}% of trials that are hits AND RT < {} into FAs'.format(round(ind.mean()*100,3), rt_cutoff))
    df.loc[ind, 'outcome'] = 2
    # np.nan rather than np.NaN, which was removed in NumPy 2.0.
    df.loc[ind, 'rt2'] = np.nan

    # Keep derived indicator columns consistent with the new outcomes.
    df['hit'] = (df['outcome'] == 0).astype(int)
    df['fa'] = (df['outcome'] == 2).astype(int)
    df['stimulus'] = ((df['hit'] == 1) | (df['miss'] == 1)).astype(int)
    df['response'] = ((df['hit'] == 1) | (df['fa'] == 1)).astype(int)
    df['correct'] = ((df['hit'] == 1) | (df['cr'] == 1)).astype(int)

    return df

def prepare_data_logistic_regression(df, min_dur=0.15, regularize=False):
    """Add a correct-reject copy of each signal trial with enough pre-signal noise.

    Every hit or miss whose ``noise_dur`` is at least ``min_dur`` is
    duplicated as a noise-only row: stimulus=0, choice=0, correct=1, cr=1,
    hit=0, miss=0, rt2=NaN. The result is sorted by subject_id, session_id,
    block_id, trial and stimulus (so the copy precedes its original) and
    gets a fresh RangeIndex.

    ``regularize`` is currently ignored; it is kept for call compatibility.

    NOTE(review): the copied rows keep their original ``outcome``,
    ``response`` and ``fa`` values. Confirm that downstream code relies only
    on the columns reset here.
    """
    signal_trials = ((df['hit'] == 1) | (df['miss'] == 1)) & (df['noise_dur'] >= min_dur)
    crs = df.loc[signal_trials, :].copy()
    crs['stimulus'] = 0
    crs['choice'] = 0
    crs['correct'] = 1
    crs['hit'] = 0
    crs['miss'] = 0
    crs['cr'] = 1
    # np.nan rather than np.NaN, which was removed in NumPy 2.0.
    crs['rt2'] = np.nan

    df = pd.concat((df, crs), axis=0).reset_index(drop=True)
    df = df.sort_values(['subject_id', 'session_id', 'block_id', 'trial', 'stimulus']).reset_index(drop=True)
    return df

def trial_to_discrete_time(df, dt=0.1, columns=None):
    """Expand each trial (row of ``df``) into rows on a regular time grid.

    For a trial of duration ``trial_dur``, the grid runs over bins
    ``0, dt, ..., n*dt`` with ``n = floor(trial_dur / dt)`` (times rounded to
    3 decimals). ``event`` is 1 in the last bin of hit/FA trials. ``signal``
    is 1 from bin ``floor(noise_dur / dt)`` onward on hit/miss trials. Both
    are 0 elsewhere.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``trial_dur``, ``noise_dur``, ``hit``, ``fa``, ``miss`` and the
        ``columns`` to copy.
    dt : float, default 0.1
        Bin width in seconds.
    columns : list of str, optional
        Trial-level columns copied onto every bin. Defaults to the
        previously hard-coded subject/session/trial/block/reward/walk/pupil set.

    Returns
    -------
    pandas.DataFrame
        Concatenated per-trial grids (an empty frame if ``df`` has no rows).
    """
    if columns is None:
        columns = ['subject_id', 'session_id', 'trial', 'block_id', 'reward', 'walk',
                   'pupil_trial_start', 'pupil_trial_start_c1']
    bins_per_s = 1 / dt

    dfs = []
    for _, d_i in df.iterrows():
        # Build the grid from integer bin indices: np.arange(0, end + dt, dt)
        # with a float step sometimes yields one extra bin past the trial end
        # (e.g. end = 0.2, dt = 0.1 gives 4 bins), adding a spurious
        # post-response row.
        n_end = int(np.floor(d_i['trial_dur'] * bins_per_s))
        steps = np.arange(n_end + 1)

        df_trial = pd.DataFrame({'time': (steps / bins_per_s).round(3)})
        df_trial['signal'] = np.zeros(df_trial.shape[0])
        df_trial['event'] = np.zeros(df_trial.shape[0])

        # Response happens in the final bin.
        if (d_i['hit'] == 1) | (d_i['fa'] == 1):
            df_trial.loc[steps == n_end, 'event'] = 1

        # Signal is on from the bin containing noise offset.
        if (d_i['hit'] == 1) | (d_i['miss'] == 1):
            n_target = np.floor(d_i['noise_dur'] * bins_per_s)
            df_trial.loc[steps >= n_target, 'signal'] = 1

        for c in columns:
            df_trial[c] = d_i[c]

        dfs.append(df_trial)

    if not dfs:
        return pd.DataFrame(columns=['time', 'signal', 'event'] + list(columns))
    return pd.concat(dfs)

def make_discrete_time_dataframe(df, groupby=None, dt=0.1, n_jobs=48):
    """Convert all trials to discrete-time format, one parallel job per group.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial-level data accepted by ``trial_to_discrete_time``.
    groupby : list of str, optional
        Columns defining the work units; defaults to
        ``['subject_id', 'session_id']``. ``None`` replaces the former
        mutable-list default.
    dt : float, default 0.1
        Bin width passed to ``trial_to_discrete_time``.
    n_jobs : int, default 48
        joblib worker count.

    Returns
    -------
    pandas.DataFrame
        Concatenated result. ``reset_index()`` keeps the per-trial bin index
        as an ``index`` column.
    """
    if groupby is None:
        groupby = ['subject_id', 'session_id']
    res = Parallel(n_jobs=n_jobs)(delayed(trial_to_discrete_time)(df=data, dt=dt)
                                  for _, data in tqdm(df.groupby(groupby)))
    return pd.concat(res).reset_index()

def simulate_data(rate=1, noise_dur_mean=5, noise_dur_max=11, rt_cutoff=0.15, nr_trials=100000,
                  exclude_fast_trials=True, fast_hits_to_fas=True):
    """Simulate trials from a responder with exponentially distributed RTs.

    Noise durations are exponential (mean ``noise_dur_mean``, clipped at
    ``noise_dur_max``). RTs are exponential with mean ``1 / rate``. RTs of
    14 s or later, or at/after ``noise_dur + 3``, become "no response"
    (NaN). The result is then classified with ``add_sdt_columns``.

    Parameters
    ----------
    exclude_fast_trials : bool, default True
        Drop trials with a response faster than ``rt_cutoff``.
    fast_hits_to_fas : bool, default True
        Reclassify hits with ``rt2 < rt_cutoff`` as FAs via
        ``fast_hits_to_false_alarms``.
        Both flags were previously hard-coded to 1 inside the function.

    Returns
    -------
    pandas.DataFrame
    """
    noise_durs = np.random.exponential(noise_dur_mean, nr_trials)
    noise_durs[noise_durs > noise_dur_max] = noise_dur_max
    rts = np.random.exponential(1 / rate, nr_trials)

    df = pd.DataFrame({'noise_dur': noise_durs, 'rt': rts})

    # Late responses count as no response (np.nan: np.NaN is gone in NumPy 2.0).
    df.loc[df['rt'] >= 14, 'rt'] = np.nan
    df.loc[df['rt'] >= (df['noise_dur'] + 3), 'rt'] = np.nan

    if exclude_fast_trials:
        # .copy() so the columns added below are not written into a slice.
        df = df.loc[(df['rt'] >= rt_cutoff) | df['rt'].isna(), :].copy()

    df = add_sdt_columns(df)

    if fast_hits_to_fas:
        df = fast_hits_to_false_alarms(df, rt_cutoff=rt_cutoff)

    return df


# def simulate_data(df_emp=None, rate=None, method='licks_ps', censor_signals=0, noise_dur_mean=5, noise_dur_max=11, 
#                     sig_dur=3, rt_cutoff=0.15, exclude_fast_trials=True, fast_hits_to_fas=True):

#     # censor signals:
#     if df_emp is not None:
#         if censor_signals:
#             print('censoring signals!')
#             df_emp.loc[df_emp['hit']==1, 'choice'] = 0
#             df_emp.loc[df_emp['hit']==1, 'trial_dur'] = df_emp.loc[df_emp['hit']==1, 'noise_dur']
#             df_emp.loc[df_emp['miss']==1, 'choice'] = 0
#             df_emp.loc[df_emp['miss']==1, 'trial_dur'] = df_emp.loc[df_emp['miss']==1, 'noise_dur']
#             df_emp.loc[df_emp['trial_dur']<0.01, 'trial_dur'] = 0.01

#     # number of trials:
#     if df_emp is not None:
#         nr_trials = df_emp.shape[0] * 5
#     else:
#         nr_trials = 1000000

#     # generate trials:
#     noise_durs = np.random.exponential(noise_dur_mean, nr_trials)
#     noise_durs[noise_durs>noise_dur_max] = noise_dur_max

#     # generate rts:
#     if method == 'survival':
#         if df_emp is not None:
#             lnf = fit_log_normal(df_emp)
#             rts = np.random.lognormal(mean=lnf.mu_, sigma=lnf.sigma_, size=nr_trials)
#     elif method == 'mean_rt':
#         if df_emp is not None:
#             mean_rt = df_emp['rt'].mean()
#             rts = np.random.exponential(mean_rt, nr_trials)
#     elif method == 'licks_ps':
#         if df_emp is not None:
#             rate = (df_emp['hit'].sum() + df_emp['fa'].sum()) / df_emp['trial_dur'].sum()
#             rts = np.random.exponential(1/rate, nr_trials)
#         else:
#             rts = np.random.exponential(1/rate, nr_trials)

#     # dataframe:
#     df = pd.DataFrame({'noise_dur': noise_durs, 
#                         'rt': rts,})

#     # fix columns:
#     df.loc[df['rt']>=14, 'rt'] = np.NaN
#     df.loc[df['rt']>=(df['noise_dur']+3), 'rt'] = np.NaN

#     # exclude fast trials:
#     if exclude_fast_trials:
#         df = df.loc[(df['rt']>=rt_cutoff)|df['rt'].isna(),:]

#     # add sdt columns
#     df = add_sdt_columns(df)

#     # fast hits to false alarms:
#     if fast_hits_to_fas:
#         df = fast_hits_to_false_alarms(df, rt_cutoff=rt_cutoff)

#     # # add meta:
#     # df[groupby] = df_emp[groupby].iloc[0]

#     return df

def sdt(df, regularize=1):
    """Compute yes/no signal-detection measures from ``signal``/``event`` flags.

    Rows are tallied into the four cells ``signal in {1, 0}`` x
    ``event in {1, 0}`` (hit, miss, false alarm, correct rejection). When
    ``regularize`` is truthy, 0.5 is added to every cell.

    Returns
    -------
    pandas.DataFrame
        One row: ``hr`` and ``far`` (the rates computed below), ``d``
        (z(hr) - z(far)) and ``c`` (-(z(hr) + z(far)) / 2), with
        z = ``norm.isf(1 - rate)``.

    NOTE(review): both rates are divided by 0.01 before the z-transform, so
    any cell fraction above 0.01 gives a "rate" > 1 and a NaN z-score. This
    looks like a per-bin to per-time-unit conversion tied to a specific bin
    width; confirm it is intended.
    """
    signal_on, signal_off = df['signal'] == 1, df['signal'] == 0
    event_on, event_off = df['event'] == 1, df['event'] == 0

    cells = [
        (signal_on & event_on).sum(),    # hits
        (signal_on & event_off).sum(),   # misses
        (signal_off & event_on).sum(),   # false alarms
        (signal_off & event_off).sum(),  # correct rejections
    ]
    if regularize:
        cells = [count + 0.5 for count in cells]
    n_hit, n_miss, n_fa, n_cr = cells

    hit_rate = n_hit / (n_hit + n_miss) / 0.01
    fa_rate = n_fa / (n_fa + n_cr) / 0.01

    z_hit = sp.stats.norm.isf(1 - hit_rate)
    z_fa = sp.stats.norm.isf(1 - fa_rate)

    return pd.DataFrame({'hr': [hit_rate], 'far': [fa_rate], 'd': [z_hit - z_fa], 'c': [-(z_hit + z_fa) / 2], })

def tr_sdt(df, signal_dur=3, dt=0.01):
    """Time-resolved signal-detection measures for one set of trials.

    The noise period is modelled with a Kaplan-Meier fit. FAs
    (``outcome == 2``) end it with an event at their ``rt``; hits, misses and
    CRs (outcomes 0, 1, 3) are right-censored at their ``noise_dur``. The
    survival curve is evaluated on a 0-14 s grid of step ``dt``. From it, the
    regularized probability of an FA within the next ``signal_dur`` seconds,
    given none so far, is computed. That probability is looked up at each
    signal trial's ``noise_dur`` and averaged to give ``tr_far``.
    ``tr_hr`` is the regularized hit fraction ``(hits + 0.5) / (hits + misses + 1)``.

    Returns
    -------
    pandas.DataFrame
        One row with ``tr_hr``, ``tr_far``, ``tr_d`` (z(hr) - z(far)) and
        ``tr_c`` (-(z(hr) + z(far)) / 2).
    """
    from lifelines import KaplanMeierFitter

    outcome = df['outcome']
    is_hit, is_miss = outcome == 0, outcome == 1
    is_fa, is_cr = outcome == 2, outcome == 3

    # Durations/events of the noise period (vectorized .sum() instead of builtin sum).
    T = np.concatenate((df.loc[is_fa, 'rt'],
                        df.loc[is_hit, 'noise_dur'],
                        df.loc[is_miss, 'noise_dur'],
                        df.loc[is_cr, 'noise_dur'],))
    E = np.concatenate((np.ones(is_fa.sum()),
                        np.zeros(is_hit.sum()),
                        np.zeros(is_miss.sum()),
                        np.zeros(is_cr.sum()),))
    kmf = KaplanMeierFitter()
    kmf.fit_right_censoring(T, E)
    n_trials = T.shape[0]

    # Survival on a 0..14 s grid; round() guards against float truncation in int().
    n_grid = int(round(14 / dt)) + 1
    km_far = kmf.predict(np.linspace(0, 14, n_grid)).reset_index()
    km_far.columns = ['time', 'KM']

    # Expected number of trials still without an FA, and with one.
    km_far['KM_t'] = km_far['KM'] * n_trials
    km_far['CD_t'] = n_trials - km_far['KM_t']

    # Regularized P(FA within the next signal_dur s | no FA yet).
    shift = int(round(signal_dur / dt))
    km_far['CD_shift_t'] = km_far['CD_t'].shift(-shift)
    km_far['CD_conditional_t'] = (km_far['CD_shift_t'] - km_far['CD_t'] + 0.5) / (n_trials - km_far['CD_t'] + 1)

    n_hit, n_miss = is_hit.sum(), is_miss.sum()
    hr = (n_hit + 0.5) / (n_hit + n_miss + 1)
    signal_noise_durs = df.loc[is_hit | is_miss, 'noise_dur']
    far = np.mean(km_far['CD_conditional_t'].iloc[km_far['time'].searchsorted(signal_noise_durs)])

    z_hr, z_far = sp.stats.norm.ppf(hr), sp.stats.norm.ppf(far)
    tr_d = z_hr - z_far
    tr_c = -0.5 * (z_hr + z_far)

    return pd.DataFrame({'tr_hr': [hr], 'tr_far': [far], 'tr_d': [tr_d], 'tr_c': [tr_c], })

def sdt_ps(df, regularize=1):
    """Rate-based (per-second) detection measures.

    Hit rate = hits per second of signal time (hit ``rt2`` plus 3 s per
    miss). FA rate = FAs per second of noise time (``noise_dur`` of hits and
    misses + ``rt`` of FAs + 12 s per CR). With ``regularize`` truthy, 1 is
    added to each count and 3 s to each duration.

    NOTE(review): CRs are charged 12 s here, whereas ``add_sdt_columns``
    treats 14 s as the catch noise duration; confirm the 12.

    Returns
    -------
    pandas.DataFrame
        One row with ``h_ps``, ``fa_ps``, ``r_ps``, ``c_ps``, ``d_ps``,
        ``d2_ps``, ``d3_ps``, ``rr`` and ``d4_ps`` (log10 of the hit/FA rate
        ratio). ``d4_ps`` was computed but previously dropped; it is
        appended last so existing column positions are unchanged.
    """
    n_hit = df['hit'].sum()
    n_fa = df['fa'].sum()

    signal_time = df.loc[df['hit'] == 1, 'rt2'].sum() + (df['miss'].sum() * 3)
    noise_time = ((df.loc[df['hit'] == 1, 'noise_dur']).sum() +
                  (df.loc[df['miss'] == 1, 'noise_dur']).sum() +
                  (df.loc[df['fa'] == 1, 'rt']).sum() +
                  ((df['cr'] == 1).sum() * 12))

    if regularize:
        n_hit += 1
        n_fa += 1
        signal_time += 3
        noise_time += 3

    h_ps = n_hit / signal_time
    fa_ps = n_fa / noise_time
    r_ps = (n_hit + n_fa) / (signal_time + noise_time)

    d_ps = h_ps - fa_ps
    d2_ps = (h_ps - fa_ps) / fa_ps
    d3_ps = (h_ps - fa_ps) / (h_ps + fa_ps)
    d4_ps = np.log10(h_ps / fa_ps)
    c_ps = -(h_ps + fa_ps) / 2

    # Reward rate: hits per second, with each FA charged an extra 14 s.
    rr = n_hit / (signal_time + noise_time + (n_fa * 14))

    return pd.DataFrame({'h_ps': [h_ps], 'fa_ps': [fa_ps], 'r_ps': [r_ps], 'c_ps': [c_ps], 'd_ps': [d_ps],
                         'd2_ps': [d2_ps], 'd3_ps': [d3_ps], 'rr': [rr], 'd4_ps': [d4_ps]})

def compute_metrics(df, groupby):
    """Aggregate per-group behavioural metrics and merge in ``tr_sdt`` results.

    ``df`` is modified in place. ``reward_ml`` is added (2 for hits with
    ``reward == 0``, 12 for hits with ``reward == 1``, else 0) and
    ``trial_dur`` is overwritten:
    hit = noise_dur + rt2, miss = noise_dur + 3, FA = rt + 14, CR = noise_dur.

    Per group of ``groupby`` the result holds:
    column means; ``hit``/``fa``/``miss`` means over non-catch trials;
    median ``rt2``; median FA ``rt`` (``fa_rt``); ``rr`` = total reward_ml /
    total trial_dur; optional pupil std/"stability" columns (only when the
    pupil columns exist); and the ``tr_sdt`` measures.

    Returns
    -------
    pandas.DataFrame
        One row per group, with the group keys as columns.
    """
    is_hit = df['outcome'] == 0
    df['reward_ml'] = 0
    df.loc[(df['reward'] == 0) & is_hit, 'reward_ml'] = 2
    df.loc[(df['reward'] == 1) & is_hit, 'reward_ml'] = 12

    # Float init avoids pandas' incompatible-dtype upcast on the assignments below.
    df['trial_dur'] = 0.0
    df.loc[df['outcome'] == 0, 'trial_dur'] = df.loc[df['outcome'] == 0, 'noise_dur'] + df.loc[df['outcome'] == 0, 'rt2']
    df.loc[df['outcome'] == 1, 'trial_dur'] = df.loc[df['outcome'] == 1, 'noise_dur'] + 3
    # NOTE(review): FAs are charged rt + 14 s, presumably a timeout penalty — confirm.
    df.loc[df['outcome'] == 2, 'trial_dur'] = df.loc[df['outcome'] == 2, 'rt'] + 14
    df.loc[df['outcome'] == 3, 'trial_dur'] = df.loc[df['outcome'] == 3, 'noise_dur']

    if 'subject_id' not in groupby:
        df = df.select_dtypes(exclude=['object'])

    # numeric_only=True: pandas >= 2 raises on non-numeric columns otherwise.
    grouped = df.groupby(groupby)
    df_res1 = grouped.mean(numeric_only=True)
    df_res1_median = grouped.median(numeric_only=True)
    df_res1_std = grouped.std(numeric_only=True)
    df_res1_sum = grouped.sum(numeric_only=True)
    df_res1_fa_median = df.loc[df['outcome'] == 2, :].groupby(groupby)['rt'].median()

    non_catch = df.loc[df['catch'] != 1, :].groupby(groupby)[['hit', 'fa', 'miss']].mean()
    for col in ('hit', 'fa', 'miss'):
        df_res1[col] = non_catch[col]
    df_res1['rt2'] = df_res1_median['rt2']
    df_res1['fa_rt'] = df_res1_fa_median
    df_res1['rr'] = df_res1_sum['reward_ml'] / df_res1_sum['trial_dur']

    df_res2 = grouped.apply(tr_sdt)

    # Pupil columns are optional; skip these metrics when they are absent.
    try:
        df_res1['pupil_trial_start_std'] = df_res1_std['pupil_trial_start']
        df_res1['pupil_trial_start_c1_std'] = df_res1_std['pupil_trial_start_c1']
        df_res1['pupil_trial_start_stability'] = df_res1['pupil_trial_start'] / (1 / df_res1['pupil_trial_start_std'])
        df_res1['pupil_trial_start_c1_stability'] = df_res1['pupil_trial_start_c1'] / (1 / df_res1['pupil_trial_start_c1_std'])
    except KeyError:
        pass

    return pd.merge(df_res1, df_res2, on=groupby).reset_index()

def bootstrap_sdt(df, groupby, iteration):
    """Compute one bootstrap replicate of ``sdt`` per group.

    Within each ``groupby`` group, rows are resampled with replacement to the
    group's own size, and ``sdt`` is applied to each resampled group. The
    replicate number ``iteration`` is then written into ``subject_id``
    (presumably so replicates can be stacked and treated like subjects
    downstream — confirm).
    """
    resampled = df.groupby(groupby).sample(frac=1, replace=True)
    per_group = resampled.groupby(groupby).apply(sdt)
    df_sdt = per_group.reset_index()
    df_sdt['subject_id'] = iteration
    return df_sdt

def _plot_block_traces(ax, df, x_col, measure, error, scale_sem, colors, ls):
    """Draw one mean trace with an error band per ``block_id`` of ``df`` onto ``ax``."""
    for b, d in df.groupby(['block_id']):
        # pandas >= 2 yields 1-tuples for a one-element list key; older versions yield the bare key.
        b = int(b[0] if isinstance(b, tuple) else b)
        per_subject = d.groupby(['subject_id', x_col]).mean()
        mean = per_subject.groupby([x_col])[measure].mean().reset_index()
        if error == 'se':
            sem = per_subject.groupby([x_col])[measure].sem().reset_index()
        else:  # 'ci': width of the central 66% range across rows
            ci = 66 / 100
            sem = (d.groupby(x_col).quantile(1 - (1 - ci) / 2) - d.groupby(x_col).quantile((1 - ci) / 2)).reset_index()
        x = mean[x_col]
        ax.fill_between(x, mean[measure] - (sem[measure] * scale_sem), mean[measure] + (sem[measure] * scale_sem),
                        color=colors[b % 2], alpha=0.1)
        ax.plot(x, mean[measure], lw=1, ls=ls, color=colors[b % 2])


def plot_baselines_across_trials(dfs, xs, measure, error='se', scale_sem=1):
    """Plot ``measure`` against trial number, one trace per block, for two datasets.

    ``dfs[0]`` is drawn dotted against column ``xs[0]`` and ``dfs[1]`` solid
    against ``xs[1]``. Block parity selects the colour. The band is
    ``scale_sem`` times the error chosen by ``error``: ``'se'`` is the SEM
    across subjects of per-subject means; ``'ci'`` is the width of the
    central 66% range across rows. Tick labels and block separators are
    laid out for sessions with 1, 6 or 7 blocks (counted in ``dfs[0]``).

    Raises
    ------
    ValueError
        If ``error`` is not ``'se'`` or ``'ci'``. Previously this hit a
        NameError or reused a stale band.

    Returns the matplotlib Figure.
    """
    if error not in ('se', 'ci'):
        raise ValueError("error must be 'se' or 'ci', got {!r}".format(error))

    nr_blocks = len(dfs[0]['block_id'].unique())
    print(nr_blocks)
    colors = [sns.color_palette()[0], sns.color_palette()[1]]

    fig = plt.figure(figsize=(max((1.5, nr_blocks / 2.5)), 1.75))
    ax = fig.add_subplot(111)

    _plot_block_traces(ax, dfs[0], xs[0], measure, error, scale_sem, colors, ls=':')
    _plot_block_traces(ax, dfs[1], xs[1], measure, error, scale_sem, colors, ls='-')

    if nr_blocks == 1:
        trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
        plt.text(x=10, y=0.95, s='block 0', transform=trans, size=7)
        plt.xticks([0, 30], [0, 30])
    elif nr_blocks == 6:
        # Separators between consecutive 60-trial blocks (none before the first).
        for t in [120, 180, 240, 300, 360]:
            plt.axvline(t, color='k', ls='--', lw=0.5)
        plt.xticks(list(range(60, 391, 30)), [0, 30] * 6)
    elif nr_blocks == 7:
        for t in range(0, 361, 60):
            plt.axvline(t, color='k', ls='--', lw=0.5)
        plt.xticks(list(range(0, 391, 30)), [0, 30] * 7)

    ax.set_xlabel('Trial #')
    ax.set_ylabel(measure, color='black')

    sns.despine(trim=False)
    plt.tight_layout()

    return fig

def func_lin(x, a, b):
    """Linear model: intercept ``a`` plus slope ``b`` times ``x``."""
    slope_term = b * x
    return a + slope_term

def func_quad(x, a, b, c):
    """Quadratic model ``a + b*x + c*x**2``."""
    linear_part = a + (b * x)
    return linear_part + (c * (x ** 2))

# def func(x, a, b, c, d):
#     return a * np.exp(b * x) + c * np.exp(d * x)

# def func(x, a, b,):
#     return a * np.exp(b * x)

# def func(x, a, b, c):
#     return np.exp(a * x) + b * x + c

# def func(x, a, b, c, d):
#     return a * np.exp(b * x) + c * x + d

def func(x, a, b, c):
    """Log-plus-linear model ``a*log(x) + b*x + c`` (the ``curve_fit`` model used below)."""
    log_term = a * np.log(x)
    return log_term + b * x + c


def plot_behavior_since_hit_mean(dfs, x_measure, y_measure, line, error='se', scale_sem=1, plot=True, fit=True):
    """Plot ``y_measure`` against ``x_measure``, averaged over two datasets.

    For each reward level (0, then 1), the per-``x_measure`` means of
    ``dfs[0]`` and ``dfs[1]`` are averaged, as are the widths of their
    central 66% ranges. The shaded band is that averaged width times
    ``scale_sem`` on each side; the factor is applied once, where previously
    it was effectively squared. The region from ``line`` to the right edge
    of the axes is shaded green once.

    ``error`` and ``fit`` are accepted for call compatibility with
    ``plot_behavior_since_hit`` and are not used.

    Returns the Figure, or None when ``plot`` is False.
    """
    print()
    print()
    print()
    print()
    print('#################################################################')
    print(y_measure)
    print('#################################################################')

    if plot:
        fig = plt.figure(figsize=(1.5, 1.75))
        ax = fig.add_subplot(111)
    else:
        fig = None

    ci = 66 / 100
    q_hi, q_lo = 1 - (1 - ci) / 2, (1 - ci) / 2

    def summarize(d):
        # Per-x mean and per-x width of the central range, both indexed by x_measure.
        grouped = d.groupby([x_measure])
        return grouped.mean(), grouped.quantile(q_hi) - grouped.quantile(q_lo)

    for r, c, z in zip([0, 1], sns.color_palette(), [0, 1]):
        mean0, sem0 = summarize(dfs[0].loc[dfs[0]['reward'] == r, :])
        mean1, sem1 = summarize(dfs[1].loc[dfs[1]['reward'] == r, :])
        mean = (mean0 + mean1) / 2
        sem = (sem0 + sem1) / 2

        x = mean.index
        y = mean.loc[:, y_measure]
        s = sem.loc[:, y_measure] * scale_sem

        if plot:
            plt.fill_between(x, y - s, y + s, color=c, alpha=0.2, zorder=z)
            plt.plot(x, y, lw=1, color=c, zorder=z)

    if plot:
        ax.axvspan(line, ax.get_xlim()[1], color='green', alpha=0.1, lw=0)
        plt.xticks([0, 30, 60], [0, 30, 60])
        plt.xlabel('Trial from 1st hit')
        plt.ylabel(y_measure)
        sns.despine(trim=False)
        plt.tight_layout()
    return fig


def _since_hit_plateau(popt):
    """Return the trial at which the fitted ``func`` curve levels off.

    ``func`` = a*log(x) + b*x + c has its turning point at x = -a/b. The
    cutoff lies 1/20 of the way from the turning-point value toward the value
    at x = 2. The result is the first grid point of ``linspace(0, 50, 501)``
    (floored) where the curve is at/below the cutoff (a < 0) or at/above it
    (a > 0).

    Raises
    ------
    ValueError
        If ``a`` is exactly 0 (no turning point).
    IndexError
        If the curve never reaches the cutoff on the grid.
    """
    a, b = popt[0], popt[1]
    turning = func(a / -b, *popt)
    at_two = func(2, *popt)
    cutoff = turning + (at_two - turning) / 20
    xx = np.linspace(0, 50, 501)
    yy = func(xx, *popt)
    if a < 0:
        idx = np.where(yy <= cutoff)[0]
    elif a > 0:
        idx = np.where(yy >= cutoff)[0]
    else:
        raise ValueError('log coefficient is 0; fitted curve has no turning point')
    return np.floor(xx[idx[0]])


def _since_hit_percent_changes(lines):
    """Percent changes between the first two fitted curves ("low" = reward 0, "high" = reward 1).

    Returns (low-end -> high-start, high max vs start, high min vs start,
    high-end -> low-start, low max vs start, low min vs start), each in %.
    All six are NaN when fewer than two curves were fitted.
    """
    if len(lines) < 2:
        return (np.nan,) * 6
    low, high = lines[0], lines[1]
    low_start, low_end, low_max, low_min = low[0], low[-1], max(low), min(low)
    high_start, high_end, high_max, high_min = high[0], high[-1], max(high), min(high)
    return ((high_start - low_end) / low_end * 100,
            (high_max - high_start) / high_start * 100,
            (high_min - high_start) / high_start * 100,
            (low_start - high_end) / high_end * 100,
            (low_max - low_start) / low_start * 100,
            (low_min - low_start) / low_start * 100)


def plot_behavior_since_hit(dfs, x_measure, y_measure, line, error='se', scale_sem=1, plot=True, fit=True):
    """Plot ``y_measure`` vs ``x_measure`` per reward level, with optional curve fits.

    Each dataset in ``dfs`` is drawn with its own line style (``':'`` for
    the first, ``'-'`` for the second). The band is ``scale_sem`` times the
    per-x SEM (``error='se'``) or the width of the central 66% range
    (``error='ci'``); the factor is applied once, where previously it was
    effectively squared. With ``fit``, ``func`` is fitted per reward level
    via ``curve_fit``, and its plateau trial is marked with a vertical line.

    Fit results are reset for every dataset, so the returned values describe
    the last dataset in ``dfs``.

    Returns
    -------
    tuple
        ``(fig, (plateau_reward0, plateau_reward1), percent_changes)``.
        ``fig`` is None when ``plot`` is False. A failed fit/plateau yields
        0. Missing plateaus or curves yield NaN, where previously fit=False
        or a failed fit raised IndexError. ``percent_changes`` is described
        in ``_since_hit_percent_changes``.

    Raises
    ------
    ValueError
        If ``error`` is not ``'se'`` or ``'ci'``.
    """
    if error not in ('se', 'ci'):
        raise ValueError("error must be 'se' or 'ci', got {!r}".format(error))

    print()
    print()
    print()
    print()
    print('#################################################################')
    print(y_measure)
    print('#################################################################')

    if plot:
        fig = plt.figure(figsize=(1.5, 1.75))
        ax = fig.add_subplot(111)
    else:
        fig = None

    lines, extremes = [], []
    for plt_nr, df in enumerate(dfs):
        grouped = df.groupby(['reward', x_measure])
        mean = grouped.mean().reset_index()
        if error == 'se':
            sem = grouped.sem().reset_index()
        else:
            ci = 66 / 100
            sem = (grouped.quantile(1 - (1 - ci) / 2) - grouped.quantile((1 - ci) / 2)).reset_index()

        lines, extremes = [], []
        for r, c, z in zip([0, 1], sns.color_palette(), [0, 1]):
            ind = (mean['reward'] == r)
            x = mean.loc[ind, x_measure]
            y = mean.loc[ind, y_measure]
            s = sem.loc[ind, y_measure] * scale_sem

            if plot:
                plt.fill_between(x, y - s, y + s, color=c, alpha=0.2, zorder=z)
                plt.plot(x, y, lw=1, color=c, zorder=z, ls=[':', '-'][plt_nr])

            if not fit:
                continue

            # Fitting is best-effort: report the failure and record a 0 plateau.
            try:
                popt, _ = curve_fit(func, x, y)
            except Exception as e:
                print(e)
                extremes.append(0)
                continue
            print(popt)
            curve_linexp = func(x, *popt)
            lines.append(np.array(curve_linexp))
            if plot:
                plt.plot(x, curve_linexp, lw=0.75, ls='--', color='black')

            # Computed fresh per fit, so a previous level's plateau is never reused.
            try:
                extreme = _since_hit_plateau(popt)
            except (ValueError, IndexError) as e:
                print(e)
                extremes.append(0)
                continue
            extremes.append(extreme)
            if plot:
                plt.axvline(extreme, lw=1, ls='--', color=c, )

    have_extremes = len(extremes) >= 2
    if plot:
        trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
        if have_extremes:
            plt.text(x=15, y=0.95, s='{} {}'.format(int(extremes[0]), int(extremes[1])), transform=trans, size=7)
        ax.axvspan(line, ax.get_xlim()[1], color='green', alpha=0.1, lw=0)
        plt.xticks([0, 30, 60], [0, 30, 60])
        plt.xlabel('Trial from 1st hit')
        plt.ylabel(y_measure)
        sns.despine(trim=False)
        plt.tight_layout()

    extreme_pair = (extremes[0], extremes[1]) if have_extremes else (np.nan, np.nan)
    return (fig, extreme_pair, _since_hit_percent_changes(lines))

def _anova_term_text(res, term):
    """Format the ``AnovaRM`` table row for ``term`` as ``F(num,den)=F, p=p``."""
    row = res.loc[res['index'] == term].iloc[0]
    return 'F({},{})={}, p={}'.format(int(row['Num DF']),
                                      int(row['Den DF']),
                                      round(float(row['F Value']), 1),
                                      round(float(row['Pr > F']), 3))


def plot_scalars_across_blocks(df, x, y, error='sem', collapse=False):
    """Plot the mean of ``y`` per level of ``x``, with within-subject error bars and RM-ANOVA.

    Error bars come from a copy of the data in which each subject's mean of
    ``y`` has been subtracted. The caller's ``df`` is no longer modified;
    collapse mode used to demean it in place.

    * ``collapse=True``: a single black trace over all rows (SEM error bars)
      and a one-way ``AnovaRM`` over ``x``, annotated as ``C: ...``.
    * ``collapse=False``: one trace per reward level (0, 1), with SEM
      (``error='sem'``) or t-CI half-width (``error='conf'``) error bars, and
      a two-way ``AnovaRM`` over ``x`` and ``reward``. For ``x == 'block_id'``
      the blocks are first merged for the ANOVA: with six blocks, (1,2) (3,4)
      (5,6) become 1, 2, 3; otherwise (1,2,3) become 1 and (4,5) become 2.
      If ``pingouin`` is installed, its ``rm_anova`` table is also printed.
      It is optional and no longer suppresses the annotation when missing.

    Statistics are best-effort: failures are printed and the plot is still returned.

    Raises
    ------
    ValueError
        If ``collapse`` is False and ``error`` is not ``'sem'`` or ``'conf'``.
    """
    from statsmodels.stats.anova import AnovaRM

    if not collapse and error not in ('sem', 'conf'):
        raise ValueError("error must be 'sem' or 'conf', got {!r}".format(error))

    fig = plt.figure(figsize=(1.5, 1.75))
    ax = fig.add_subplot(1, 1, 1)
    c0 = sns.color_palette()[0]
    c1 = sns.color_palette()[1]

    # Subject-demeaned copy for within-subject error bars.
    df_sem = df.copy()
    df_sem[y] = df_sem[y] - df_sem.groupby('subject_id')[y].transform('mean')

    if collapse:
        mean = df.groupby([x])[y].mean().reset_index()
        sem = df_sem.groupby([x])[y].sem().reset_index()
        plt.errorbar(x=mean[x], y=mean[y], yerr=sem[y], fmt='-o', color='black',
                     markerfacecolor='lightgrey', linewidth=1.5, elinewidth=1)

        try:
            aovrm = AnovaRM(df_sem.copy(), y, 'subject_id', within=[x], aggregate_func='mean')
            res = aovrm.fit().anova_table.reset_index()
            plt.text(x=0.1, y=0.1, s='C: ' + _anova_term_text(res, x), size=6, transform=ax.transAxes)
        except Exception as e:
            print(e)

    else:
        for r, color in ((0, c0), (1, c1)):
            mean = df.loc[df['reward'] == r, :].groupby([x])[y].mean().reset_index()
            grouped = df_sem.loc[df_sem['reward'] == r, :].groupby([x])[y]
            if error == 'sem':
                sem = grouped.sem().reset_index()
            else:
                sem = grouped.apply(mean_confidence_interval).reset_index()
            plt.errorbar(x=mean[x], y=mean[y], yerr=sem[y], fmt='-o', color=color,
                         markerfacecolor='lightgrey', linewidth=1.5, elinewidth=1)

        df_a = df.copy()
        if x == 'block_id':
            if (df_a['block_id'] == 6).sum() > 0:
                mapping = {2: 1, 3: 2, 4: 2, 5: 3, 6: 3}
            else:
                mapping = {2: 1, 3: 1, 4: 2, 5: 2}
            df_a['block_id'] = df_a['block_id'].replace(mapping)

        try:
            fitted = AnovaRM(df_a, y, 'subject_id', within=[x, 'reward'], aggregate_func='mean').fit()
            res = fitted.anova_table.reset_index()
            print(y)
            print(fitted.summary())
            plt.text(x=0.1, y=0.1, s='B: {}\nR: {}\nInt: {}'.format(
                _anova_term_text(res, x),
                _anova_term_text(res, 'reward'),
                _anova_term_text(res, '{}:reward'.format(x)),
            ), size=6, transform=ax.transAxes)
        except Exception as e:
            print(e)

        # Optional cross-check with pingouin; kept separate so a missing
        # package cannot suppress the statsmodels annotation above.
        try:
            from pingouin import rm_anova
            print(rm_anova(df_a, y, within=[x, 'reward'], subject='subject_id'))
        except Exception as e:
            print(e)

    if x == 'block_id':
        plt.xticks([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6])
        plt.xlabel('Block #')
    elif x == 'block':
        plt.xticks([1, 2, 3, 4], [1, 2, 3, 4])
    ax.legend([], [], frameon=False)
    plt.ylabel(y)
    sns.despine(trim=False)
    plt.tight_layout()

    return fig

def plot_scalars_across_blocks_ptrial(df, x, y, error='sem', collapse=False):
    """Plot ``y`` across levels of ``x``, split by reward and previous-trial hit.

    Four lines are drawn, one per (``reward``, ``hit_p``) combination. Colour
    encodes ``reward`` (palette colour 0 for reward == 0, colour 1 for
    reward == 1). Line style encodes ``hit_p``: thin dashed for hit_p == 0,
    thicker solid for hit_p == 1. Points are means of ``y``. Error bars are
    SEMs computed after removing each subject's overall mean of ``y``, so they
    are within-subject error bars.

    A repeated-measures ANOVA (within factors ``x``, ``reward``, ``hit_p``;
    subject = ``subject_id``) is fitted and its summary printed. When
    ``x == 'block_id'`` the blocks are pooled for the ANOVA only:
    - if a block 6 exists: 1-2 -> 1, 3-4 -> 2, 5-6 -> 3;
    - otherwise: 1-3 -> 1, 4-5 -> 2.
    ANOVA failures are printed rather than raised, so the figure is always
    returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs columns ``subject_id``, ``reward``, ``hit_p``, ``x`` and ``y``.
    x : str
        Column used for the x-axis and as the first ANOVA factor.
    y : str
        Column plotted and analysed.
    error, collapse :
        Unused; kept for interface compatibility.

    Returns
    -------
    matplotlib.figure.Figure
    """
    from statsmodels.stats.anova import AnovaRM

    fig = plt.figure(figsize=(1.5,1.75))
    ax = fig.add_subplot(1,1,1)
    palette = sns.color_palette()

    # Within-subject error bars: remove each subject's overall mean of y before
    # computing the SEM (the plotted means themselves use the raw values).
    df_sem = df.copy()
    for _, d in df_sem.groupby('subject_id'):
        df_sem.loc[d.index, y] = df_sem.loc[d.index, y] - df_sem.loc[d.index, y].mean()

    # Colour follows reward, dash/width follow hit_p (same draw order as before).
    for reward in (0, 1):
        for hit_p in (0, 1):
            sel = (df['reward']==reward) & (df['hit_p']==hit_p)
            mean = df.loc[sel, :].groupby([x])[y].mean().reset_index()
            sem = df_sem.loc[sel, :].groupby([x])[y].sem().reset_index()
            plt.errorbar(x=mean[x], y=mean[y], yerr=sem[y],
                         fmt='-o' if hit_p else '--o', color=palette[reward],
                         markerfacecolor='lightgrey',
                         linewidth=1.5 if hit_p else 1, elinewidth=1)

    try:
        df_a = df.copy()
        if x == 'block_id':
            # Pool blocks into coarser levels for the ANOVA. Mappings are applied
            # in ascending order of the old block number, as in the original.
            if (df_a['block_id']==6).any():
                pooling = [(2, 1), (3, 2), (4, 2), (5, 3), (6, 3)]
            else:
                pooling = [(2, 1), (3, 1), (4, 2), (5, 2)]
            for old, new in pooling:
                df_a.loc[df_a['block_id']==old, 'block_id'] = new
        aovrm = AnovaRM(df_a, y, 'subject_id', within=[x, 'reward', 'hit_p'], aggregate_func='mean')
        # fit once (the original fitted the same model twice)
        anova = aovrm.fit()
        print(y)
        print(anova.summary())
    except Exception as e:
        # Best effort: a failing ANOVA (e.g. unbalanced design) must not lose the figure.
        print(e)

    if x == 'block_id':
        plt.xticks([1,2,3,4,5,6], [1,2,3,4,5,6])
        plt.xlabel('Block #')
    elif x == 'block':
        plt.xticks([1,2,3,4], [1,2,3,4])
    ax.legend([],[], frameon=False)
    plt.ylabel(y)
    sns.despine(trim=False)
    plt.tight_layout()

    return fig

def sequential_mixed_regression(df, x_measure, y_measure, order=5):
    """Select a polynomial order for ``y_measure`` ~ ``x_measure`` by BIC of mixed models.

    For every order ``o`` in ``0..order`` a linear mixed model with a random
    intercept per ``subject_id`` is fitted with ``reml=False``::

        o == 0: y ~ 1
        o == 1: y ~ 1 + x
        o >= 2: y ~ 1 + x + np.power(x, 2) + ... + np.power(x, o)

    Each fit's BIC is collected and printed. Previously orders above 3 silently
    refitted the cubic model; every order is now fitted with its own polynomial.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``subject_id``, ``x_measure`` and ``y_measure``. Not modified.
    x_measure, y_measure : str
        Predictor and outcome column names (used verbatim in the formula).
    order : int
        Highest polynomial order to try.

    Returns
    -------
    int
        The highest order ``o >= 1`` whose BIC is more than 10 below that of
        order ``o - 1``; 0 if no step improves by that much.
    """
    import statsmodels.formula.api as smf

    bics = []
    for o in range(order + 1):
        print(o)
        # fixed-effect terms of a degree-o polynomial in x
        terms = ['1']
        if o >= 1:
            terms.append(x_measure)
        terms.extend('np.power({}, {})'.format(x_measure, k) for k in range(2, o + 1))
        formula = '{} ~ {}'.format(y_measure, ' + '.join(terms))
        md = smf.mixedlm(formula, df, groups=df["subject_id"], re_formula='~1')
        mdf = md.fit(reml=False)
        bics.append(mdf.bic)

    print()
    print(bics)
    print()

    model = 0
    for o in range(1, order + 1):
        if (bics[o]-bics[o-1]) < -10:
            model = o

    return model

def permutationTest_correlation(a, b, tail=0, nrand=10000):
    """Permutation test for a difference between two Pearson correlations.

    ``a`` and ``b`` each hold paired observations in two columns (column 0
    against column 1). The observed statistic is ``r(a) - r(b)``. Its null
    distribution comes from stacking the rows of both samples, shuffling the
    row order ``nrand`` times with NumPy's global random state, and splitting
    back into groups of the original sizes.

    Parameters
    ----------
    a, b : numpy.ndarray
        Arrays of shape (n_a, 2) and (n_b, 2).
    tail : {0, 1, -1}
        0 tests r(a) != r(b) (two-sided), 1 tests r(a) > r(b),
        -1 tests r(a) < r(b).
    nrand : int
        Number of permutations.

    Returns
    -------
    tuple of (float, float)
        The observed correlation difference and the fraction of permutations
        at least as extreme as the observed value.
    """
    def _corr(rows):
        return sp.stats.pearsonr(rows[:, 0], rows[:, 1])[0]

    n_a = a.shape[0]
    observed = _corr(a) - _corr(b)

    pooled = np.vstack((a, b))
    order = np.arange(pooled.shape[0])
    null = np.zeros(nrand)
    for k in range(nrand):
        np.random.shuffle(order)
        null[k] = _corr(pooled[order[:n_a]]) - _corr(pooled[order[n_a:]])

    if tail == 0:
        extreme = np.abs(null) >= np.abs(observed)
    else:
        extreme = tail * null >= tail * observed
    p_value = np.count_nonzero(extreme) / float(nrand)

    return (observed, p_value)

def sequential_regression(df, x_measure, y_measure, order=5):
    """Regress centred powers of ``x_measure`` out of ``y_measure``, one power at a time.

    Starting from ``y = df[y_measure]``, for each power ``o`` in ``0..order``:
    1. fit an OLS without intercept of the current ``y`` on
       ``(x - mean(x)) ** o``;
    2. record the fit statistics;
    3. replace ``y`` by the residuals before moving to the next power.

    ``df`` is modified in place: its ``'y'`` column ends up holding the final
    residuals and its ``'x'`` column the last regressor.

    Returns
    -------
    F_values, p_values, model_dfs, resid_dfs : list
        Per-power ``fvalue``, ``f_pvalue``, ``df_model`` and ``df_resid``.
    model : int
        Highest power whose F-test p-value is below 0.05 (0 if none).
    """
    import statsmodels.api as sm

    df['y'] = df[y_measure]
    F_values, p_values, model_dfs, resid_dfs = [], [], [], []
    for power in range(order + 1):
        df['x'] = (df[x_measure] - df[x_measure].mean()) ** power
        fitted = sm.OLS(df['y'], df['x']).fit()
        F_values.append(fitted.fvalue)
        p_values.append(fitted.f_pvalue)
        model_dfs.append(fitted.df_model)
        resid_dfs.append(fitted.df_resid)
        # carry only the unexplained part forward to the next power
        df['y'] = fitted.resid

    significant = [power for power, p in enumerate(p_values) if p < 0.05]
    model = significant[-1] if significant else 0

    return F_values, p_values, model_dfs, resid_dfs, model

def arbitrary_poly(x, *params):
    """Evaluate ``params[0] + params[1]*x + params[2]*x**2 + ...`` at ``x``.

    Works element-wise on NumPy arrays and returns 0 when no coefficients are
    given. Serves as the model function passed to ``curve_fit``.
    """
    return sum(coef * (x ** power) for power, coef in enumerate(params))

def baseline_pupil_behavior(df_res, x_measure='pupil_trial_start',
                            bin_by='pupil_trial_start_bin',
                            measures = ('tr_c', 'tr_d', 'rt2', 'hit', 'reward_ml'),
                            color='grey', axes=None):
    """Plot each behavioural measure against binned baseline pupil size.

    For every measure, one panel shows per-bin means (± SEM on both axes) of
    ``x_measure`` vs. the measure.
    - All bins except the highest ``bin_by`` value are drawn as circles. They
      are also used for a polynomial fit, whose order is chosen by
      ``sequential_regression`` on the bin means (orders 0-5) and fitted with
      ``curve_fit`` / ``arbitrary_poly``. Fit failures are printed and skipped.
    - The highest bin is drawn separately with an 'x' marker.

    Parameters
    ----------
    df_res : pandas.DataFrame
        Needs ``bin_by``, ``x_measure`` and every column in ``measures``.
    x_measure : str
        Pupil column plotted on the x-axis.
    bin_by : str
        Bin-label column.
    measures : sequence of str
        Behavioural columns, one panel each.
    color : matplotlib colour
        Colour for markers and fit line.
    axes : Axes or sequence/array of Axes, optional
        Panels to draw into (one per measure). A single Axes is accepted when
        there is one measure.

    Returns
    -------
    matplotlib.figure.Figure if ``axes`` was None, otherwise ``axes``.
    """
    if axes is None:
        fig, axes = plt.subplots(nrows=1, ncols=len(measures), figsize=(len(measures)*1.5,1.75))
        return_fig = True
    else:
        return_fig = False
    # plt.subplots returns a bare Axes (not an array) for a single panel;
    # normalise so indexing by measure always works.
    axes_list = np.atleast_1d(axes).ravel()

    # Split off the highest bin once (Series.max skips NaN, unlike builtin max).
    is_top = df_res[bin_by] == df_res[bin_by].max()
    regular = df_res.loc[~is_top, :]
    top = df_res.loc[is_top, :]

    for i, m in enumerate(measures):
        ax = axes_list[i]

        grouped = regular.groupby([bin_by])
        x = grouped[x_measure].mean()
        y = grouped[m].mean()
        x_sem = grouped[x_measure].sem()
        y_sem = grouped[m].sem()
        ax.errorbar(x, y, xerr=x_sem, yerr=y_sem, fmt='o', color=color, markerfacecolor='lightgrey', markeredgewidth=0.75)
        ax.set_ylabel(m)
        ax.set_xlabel('Pupil size (% max)')

        # Pick the polynomial order on the bin means, then fit that polynomial.
        df_seq_reg = regular.groupby([bin_by])[[x_measure, m]].mean()
        F_values, p_values, model_dfs, resid_dfs, model = sequential_regression(
            df=df_seq_reg, x_measure=x_measure, y_measure=m, order=5)
        print()
        print(df_seq_reg)
        print(F_values)
        print(model_dfs)
        print(resid_dfs)
        print(p_values)
        try:
            popt, pcov = curve_fit(arbitrary_poly, x, y, p0=[1]*(model+1))
            xx = np.linspace(min(x), max(x), 100)
            ax.plot(xx, arbitrary_poly(xx, *popt), lw=1.5, ls='-', color=color, zorder=0)
        except (RuntimeError, ValueError, TypeError) as e:
            # Best effort: non-convergence (RuntimeError), NaN/empty data or
            # LinAlgError (ValueError), or more parameters than bins (TypeError)
            # just skip the fit line.
            print(e)

        grouped = top.groupby([bin_by])
        x = grouped[x_measure].mean()
        y = grouped[m].mean()
        x_sem = grouped[x_measure].sem()
        y_sem = grouped[m].sem()
        ax.errorbar(x, y, xerr=x_sem, yerr=y_sem, fmt='x', color=color)

    sns.despine(trim=False)
    plt.tight_layout()

    if return_fig:
        return fig
    else:
        return axes

def sdt_bars(df, m):
    """Box plot of measure ``m`` for reward == 0 vs. reward == 1, with per-subject lines.

    The title reports two things:
    - the percentage of subjects whose ``m`` changes in the majority direction;
    - the p-value of a paired t-test between the two conditions.

    Rows of the two conditions are paired by position (``ttest_rel`` and the
    fraction both pair element-wise). This presumably expects one row per
    subject and condition, in the same subject order; TODO confirm with callers.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(1.5,1.75))
    ax = fig.add_subplot(111)
    sns.boxplot(x='reward', y=m, data=df, ax=ax, linewidth=0.5)
    for _, d in df.groupby(['subject_id']):
        ax.plot([0,1], [d.loc[d['reward']==0,m], d.loc[d['reward']==1,m]],
                color='grey', linewidth=0.5, alpha=0.5, linestyle='-', zorder=-1)

    values_0 = np.array(df.loc[df['reward']==0,m])
    values_1 = np.array(df.loc[df['reward']==1,m])
    fraction = round(((values_1 - values_0) > 0).mean()*100, 2)
    fraction = max(fraction, 100-fraction)
    p = sp.stats.ttest_rel(values_0, values_1)[1]
    # Compare the unrounded p-value: rounding to 3 decimals first turned
    # 0.0005 <= p < 0.001 into "p = 0.001" instead of "p < 0.001".
    if p < 0.001:
        plt.title('{}% of mice\np < {}'.format(fraction, 0.001))
    else:
        plt.title('{}% of mice\np = {}'.format(fraction, round(p, 3)))
    sns.despine(trim=False)
    plt.tight_layout()
    return fig

def _km_times_events(d, signal_dur):
    """Build Kaplan-Meier durations and event flags from the trials in ``d``.

    - Outcome 0 and 2 trials are observed events at their ``rt``.
    - Outcome 1 trials are censored at ``start_s + signal_dur``.
    - Outcome 3 trials are censored at ``start_s``.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Durations ``T`` and event indicators ``E`` (1 = event, 0 = censored).
    """
    outcome = d['outcome']
    T = np.concatenate((d.loc[(outcome==0), 'rt'],
                        d.loc[(outcome==1), 'start_s']+signal_dur,
                        d.loc[(outcome==2), 'rt'],
                        d.loc[(outcome==3), 'start_s'],
                        ))
    E = np.concatenate((np.ones(sum(outcome==0)),
                        np.zeros(sum(outcome==1)),
                        np.ones(sum(outcome==2)),
                        np.zeros(sum(outcome==3)),
                        ))
    return T, E


def kaplan_meier_plots(df, nbins=10, dt=0.1):
    """Plot Kaplan-Meier survival curves per reward condition and noise-duration bin.

    Steps:

    1. False-alarm trials (``fa == 1``) get a simulated ``noise_dur``. It is
       drawn from an exponential distribution (mean 5 s), conditioned to
       exceed the trial's ``rt`` rounded up to 0.01 s.
       - In ``catch_session == 0`` draws are clipped at 11 s, and FAs with
         ``rt >= 11`` get 11 s.
       - In ``catch_session == 1`` draws above 11 s are discarded, and catch
         trials (``catch == 1``) get ``noise_dur = 14``.
    2. Trials with ``noise_dur < 11`` are split into ``nbins`` quantile bins.
       ``noise_dur == 11`` and ``noise_dur == 14`` get their own bins.
    3. Within each bin, ``start_s`` is the shortest ``noise_dur``. The ``rt`` of
       outcome-0 trials is shifted by ``start_s - noise_dur``.
    4. One curve is fitted per (``reward``, ``start_s``):
       - the 11 s curve is drawn dashed before 11 s and solid after;
       - other bins are drawn from ``start_s`` on, offset vertically to meet the
         11 s curve of the same reward at ``start_s``;
       - the 14 s (catch) bin is not drawn.

    Uses NumPy's global random state (unseeded here). The caller's ``df`` is not
    modified; all added columns and the ``rt`` shift live on a copy.

    NOTE(review): assumes ``noise_dur`` already exists for non-FA trials, and
    that ``reward`` takes the values 0 and 1 with both present in the 11 s bin
    (``km11`` is indexed by reward). Confirm against callers.

    Parameters
    ----------
    df : pandas.DataFrame
    nbins : int
        Number of quantile bins for ``noise_dur < 11``.
    dt : float
        Time step (s) at which the survival functions are evaluated.

    Returns
    -------
    matplotlib.figure.Figure
    """
    from lifelines import KaplanMeierFitter

    # Work on a copy: the steps below add columns and shift 'rt', which would
    # otherwise leak into (and accumulate in) the caller's DataFrame.
    df = df.copy()

    nr_trials = int(1e6)
    target_mean = 5
    target_max = 11
    signal_dur = 3

    # assign random signal onset times to FA trials
    for c in [0,1]:
        target_times = np.random.exponential(target_mean, nr_trials)
        if c == 0:
            target_times[target_times > target_max] = target_max
            ind = (df['catch_session']==c)&(df['fa']==1)
            rts = np.ceil(df.loc[ind, 'rt']*100)/100
            noise_durs = np.zeros(len(rts))
            for rt in np.unique(rts):
                if rt < target_max:
                    noise_durs[rts==rt] = np.random.choice(a=target_times[target_times>rt], size=np.sum(rts==rt), replace=True)
                else:
                    noise_durs[rts==rt] = target_max
            df.loc[ind, 'noise_dur'] = noise_durs
        elif c == 1:
            target_times = target_times[target_times < target_max]
            ind = (df['catch_session']==c)&(df['fa']==1)&(df['catch']==0)
            rts = np.ceil(df.loc[ind, 'rt']*100)/100
            noise_durs = np.zeros(len(rts))
            for rt in np.unique(rts):
                noise_durs[rts==rt] = np.random.choice(a=target_times[target_times>rt], size=np.sum(rts==rt), replace=True)
            df.loc[ind, 'noise_dur'] = noise_durs
            df.loc[(df['catch_session']==c)&(df['catch']==1), 'noise_dur'] = 14

    # bin:
    df.loc[(df['noise_dur']<11), 'bins'] = pd.qcut(df.loc[(df['noise_dur']<11), 'noise_dur'], nbins, labels=False)
    df.loc[df['noise_dur']==11, 'bins'] = nbins+1
    df.loc[df['noise_dur']==14, 'bins'] = nbins+2

    # shift times so every trial in a bin starts at the bin's shortest noise_dur:
    df['start'] = df['noise_dur'].copy()
    df['start_s'] = df.groupby(['bins'])['start'].transform(lambda x: x.min())
    df.loc[df['outcome']==0, 'rt'] = df.loc[df['outcome']==0, 'rt'] - (df.loc[df['outcome']==0, 'start']-df.loc[df['outcome']==0, 'start_s'])

    # reference curves for the 11 s noise duration, one per reward condition:
    km11 = []
    for r, d in df.loc[df['start_s']==target_max,:].groupby(['reward']):
        T, E = _km_times_events(d, signal_dur)
        kmf = KaplanMeierFitter()
        kmf.fit_right_censoring(T, E)
        km_hr = kmf.predict(np.linspace(0,target_max+signal_dur,int(((target_max+signal_dur)/dt)+1))).reset_index()
        km_hr.columns = ['time', 'KM']
        km11.append(km_hr.copy())

    # fit and plot, all on the returned figure:
    fig = plt.figure(figsize=(4,2))
    for (r,b), d in df.groupby(['reward', 'start_s']):
        if b == 14:
            # catch-trial bin is not plotted
            continue
        r = int(r)

        T, E = _km_times_events(d, signal_dur)
        kmf = KaplanMeierFitter()
        kmf.fit_right_censoring(T, E)
        km_hr = kmf.predict(np.linspace(0, b+signal_dur,int(((b+signal_dur)/dt)+1))).reset_index()
        km_hr.columns = ['time', 'KM']

        after = km_hr['time']>=b
        if b == target_max:
            plt.plot(km_hr.loc[after, 'time'],
                     km_hr.loc[after, 'KM'], color=sns.color_palette()[r], ls='-')
            plt.plot(km_hr.loc[~after, 'time'],
                     km_hr.loc[~after, 'KM'], color=sns.color_palette()[r], ls='--')
        else:
            # align this bin's curve with the 11 s reference at its start time
            offset = (km_hr.loc[after, 'KM'].iloc[0] -
                      km11[r]['KM'].iloc[km11[r]['time'].searchsorted(b)])
            plt.plot(km_hr.loc[after, 'time'],
                     km_hr.loc[after, 'KM']-offset, color=sns.color_palette()[r])
        if r == 0:
            plt.axvline(b, color='k', lw=0.5, ls='--')

    plt.ylabel('P(survival)')
    plt.xlabel('Time (s)')
    sns.despine(trim=False, offset=3)
    plt.tight_layout()
    return fig

def cox_prepare_data(df, conditions=None):
    """Convert trial data into start/stop ("long") format for a time-varying Cox model.

    Every trial contributes a noise epoch (``signal == 0``). Signal trials
    (``hit`` or ``miss``) also contribute a subsequent signal epoch
    (``signal == 1``). Start/stop columns are produced by lifelines'
    ``to_long_format``.

    * Noise epoch: duration ``noise_dur`` on hit/miss trials and ``trial_dur``
      on fa/cr trials. ``event == 1`` only for false alarms (``fa == 1``).
    * Signal epoch: duration ``trial_dur - noise_dur``, shifted to begin at the
      stop time of the same trial's noise epoch (matched on the original
      index). ``event == 1`` only for hits.

    The original row index is stored in ``id``. Rows whose ``stop`` equals 0 are
    dropped, and the result is sorted by ``id`` and ``start``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``fa``, ``cr``, ``hit``, ``miss``, ``noise_dur`` and ``trial_dur``.
    conditions : list of str, optional
        Currently unused; kept for interface compatibility. The default is now
        ``None`` instead of a shared mutable list.

    Returns
    -------
    pandas.DataFrame
    """
    from lifelines.utils import to_long_format

    # part 1: noise epoch of every trial
    df_noise = df.copy()
    df_noise['event'] = 0
    df_noise.loc[df_noise['fa']==1, 'event'] = 1
    df_noise.loc[(df_noise['miss']==1)|(df_noise['hit']==1), 'duration'] = df_noise.loc[(df_noise['miss']==1)|(df_noise['hit']==1), 'noise_dur']
    df_noise.loc[(df_noise['fa']==1)|(df_noise['cr']==1), 'duration'] = df_noise.loc[(df_noise['fa']==1)|(df_noise['cr']==1), 'trial_dur']
    df_noise['id'] = np.array(df_noise.index)
    df_noise = to_long_format(df_noise, duration_col="duration")
    df_noise['signal'] = 0

    # part 2: signal epoch of hit/miss trials, appended after the noise epoch
    df_signal = df.loc[(df['miss']==1)|(df['hit']==1)].copy()
    df_signal['event'] = 0
    df_signal.loc[df_signal['hit']==1, 'event'] = 1
    df_signal['duration'] = df_signal['trial_dur']-df_signal['noise_dur']
    df_signal['id'] = np.array(df_signal.index)
    df_signal = to_long_format(df_signal, duration_col="duration")
    # shift by the noise epoch's stop time (aligned on the shared index)
    noise_stop = df_noise.loc[(df_noise['hit']==1)|(df_noise['miss']==1), 'stop']
    df_signal.loc[:,'start'] = df_signal.loc[:,'start'] + noise_stop
    df_signal.loc[:,'stop'] = df_signal.loc[:,'stop'] + noise_stop
    df_signal['signal'] = 1

    df_cox = pd.concat((df_noise, df_signal), axis=0).sort_values(by=['id', 'start'])
    df_cox = df_cox.loc[df_cox['stop']!=0,:]

    return df_cox

def hazard_stratified(df, groupby=['subject_id', 'reward'], split='reward', measure='coef', n_jobs=64):
    """Fit a Cox model (``~signal``) per group and compare ``measure`` across ``split``.

    Left panel: baseline cumulative hazard (mean ± SEM over groups) for each
    level of ``split``.

    Right panel: box plot of ``measure`` per level of ``split``, with one grey
    line per subject, annotated with a paired t-test of level 1 vs. level 0 of
    ``split``. The test previously always used the ``reward`` column, even when
    ``split`` was something else.

    NOTE(review): relies on a module-level ``behavior`` whose import is
    commented out at the top of this file; confirm it is in scope.
    NOTE(review): the red reference line is at 1 (the null for a hazard ratio);
    if ``measure`` is a log-hazard coefficient the null would be 0. Confirm.

    Returns
    -------
    matplotlib.figure.Figure
    """
    baseline, params, test_ph = behavior.cox_group(df, groupby=groupby, formula='Surv(start,stop,event)~signal', n_jobs=n_jobs)

    fig = plt.figure(figsize=(6,3))
    ax = fig.add_subplot(121)
    for r, d in baseline.groupby(split):
        times = np.array(d.columns, dtype=float)
        plt.fill_between(times, d.mean(axis=0)-d.sem(axis=0), d.mean(axis=0)+d.sem(axis=0), alpha=0.2)
        plt.plot(times, d.mean(axis=0))
    ax.set_title('Baseline cum. hazard')
    ax.set_xlabel('Time (s)')

    ax = fig.add_subplot(122)
    sns.boxplot(x=split, y=measure, data=params)
    # assumes two rows per subject, ordered by the levels of ``split``
    for s, d in params.groupby(['subject_id']):
        plt.plot([0,1], [d[measure].iloc[0], d[measure].iloc[1]], color='grey', alpha=0.2, lw=0.75)
    t_value, p_value = sp.stats.ttest_rel(params.loc[params[split]==1,measure], params.loc[params[split]==0,measure])
    trans = transforms.blended_transform_factory(
        ax.transData, ax.transAxes)
    plt.text(x=0.5, y=0.9, s='t = {}, p = {}'.format(round(t_value,3), round(p_value,3)), size=6, ha='center', transform=trans)
    plt.axhline(1, lw=0.5, color='r')
    ax.set_title('Hazard ratio (signal vs. noise)')

    plt.tight_layout()
    sns.despine(trim=False, offset=3)
    return fig

def hazards_one_model(df, groupby=['subject_id'], measure='coef', n_jobs=64):
    """Fit ``hazard ~ signal * reward`` per group and summarise the coefficients.

    Left panel: baseline cumulative hazard (mean ± SEM over groups).

    Right panel: box plot of ``measure`` per covariate. Each box is annotated
    with a t-test of ``measure`` against 1 (written as a paired test against a
    vector of ones). The boxplot previously ignored ``measure`` and always showed
    ``coef``. Annotations used groupby's alphabetical order while seaborn draws
    boxes in order of appearance; both now share one explicit order.

    NOTE(review): relies on a module-level ``behavior`` whose import is
    commented out at the top of this file; confirm it is in scope.
    NOTE(review): the null value 1 suits hazard ratios; if ``measure`` is a
    log-hazard coefficient the null would be 0. Confirm.

    Returns
    -------
    matplotlib.figure.Figure
    """
    baseline, params, test_ph = behavior.cox_group(df, groupby=groupby,
                                        formula='Surv(start,stop,event)~signal*reward', n_jobs=n_jobs)

    fig = plt.figure(figsize=(6,3))
    ax = fig.add_subplot(121)
    times = np.array(baseline.columns, dtype=float)
    plt.fill_between(times,
                        baseline.mean(axis=0)-baseline.sem(axis=0),
                        baseline.mean(axis=0)+baseline.sem(axis=0), alpha=0.2)
    plt.plot(times, baseline.mean(axis=0))
    ax.set_title('Baseline cum. hazard')
    ax.set_xlabel('Time (s)')

    ax = fig.add_subplot(122)
    # one explicit order shared by the boxes and their text annotations
    covariates = list(params['covariate'].dropna().unique())
    sns.boxplot(x='covariate', y=measure, data=params, order=covariates)
    trans = transforms.blended_transform_factory(
        ax.transData, ax.transAxes)
    for i, c in enumerate(covariates):
        p = params.loc[params['covariate']==c, :]
        t_value, p_value = sp.stats.ttest_rel(p[measure], np.ones(len(p)))
        plt.text(x=i, y=0.9, s='t = {}\np = {}'.format(round(t_value,3), round(p_value,3)), size=6, ha='center', transform=trans)
    plt.axhline(1, color='r', lw=0.75)
    plt.title('hazard ~ 1 + signal * reward')
    plt.ylabel('Hazard ratio')
    plt.tight_layout()
    sns.despine(trim=False, offset=3)
    return fig

def hazards_across_trials(df, n_jobs=64):
    """Plot the signal hazard ratio (± robust SE) and a baseline-hazard column per trial.

    A Cox model ``Surv(start,stop,event)~signal`` is fitted for each value of
    ``trial`` via ``behavior.fit_cox_group``.

    NOTE(review): other functions here call ``behavior.cox_group``; confirm that
    ``fit_cox_group`` exists, and that ``behavior`` is in scope (its import is
    commented out at the top of this file).
    NOTE(review): the baseline curve is the column labelled ``4``, presumably a
    time point; confirm.

    Returns
    -------
    matplotlib.figure.Figure
    """
    baseline, params, test_ph = behavior.fit_cox_group(df, groupby=['trial'],
                                        formula='Surv(start,stop,event)~signal', n_jobs=n_jobs)

    signal_rows = params.loc[params['covariate']=='signal', :]
    trials = signal_rows['trial']
    hr = signal_rows['exp(coef)']
    se = signal_rows['robust se']

    fig = plt.figure(figsize=(4,2))
    plt.fill_between(trials, hr - se, hr + se, alpha=0.2)
    plt.plot(trials, hr, label='signal HR')
    plt.plot(baseline.loc[:, baseline.columns==4], label='baseline hazard')

    plt.legend()
    plt.xlabel('Trial (#)')
    plt.tight_layout()
    sns.despine(trim=False, offset=3)
    return fig

def sdt_across_x(df_sdt, df_sdt_c, x_measure='session_id'):
    """Plot ``d_c`` and ``c`` (mean ± SEM) as a function of ``x_measure``.

    Top row uses ``df_sdt``. Bottom row uses ``df_sdt_c``, with y-labels
    prefixed by 'Δ '. Only the plotted column is aggregated, so non-numeric
    columns in the inputs no longer break the mean/SEM (pandas >= 2 raises on
    them).

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(4,4))
    plt_nr = 1
    for data, label_prefix in ((df_sdt, ''), (df_sdt_c, 'Δ ')):
        for measure in ['d_c', 'c']:
            ax = fig.add_subplot(2,2,plt_nr)
            grouped = data.groupby([x_measure])[measure]
            mean = grouped.mean()
            sem = grouped.sem()
            x = np.array(mean.index)
            ax.fill_between(x, mean-sem, mean+sem, alpha=0.2)
            ax.plot(x, mean)
            ax.set_xlabel(x_measure)
            ax.set_ylabel(label_prefix + measure)
            plt_nr += 1
    sns.despine(trim=False)
    plt.tight_layout()
    return fig



def plot_physio_since_hit(df, hit_nr, n_jobs=64):
    """Plot SDT measures ``d_c`` and ``c`` against trials since a given hit in the block.

    Only trials with ``trial_<hit_nr>`` > 0 are used. Measures are computed by
    ``behavior.sdt_group`` per (``c_bin``, ``reward``, ``trial_<hit_nr>``), then
    averaged over the remaining rows per (``reward``, ``trial_<hit_nr>``)
    (mean ± SEM). Reward 0 is drawn in grey and reward 1 in green.

    Parameters
    ----------
    df : pandas.DataFrame
    hit_nr : int
        Selects the ``trial_<hit_nr>`` column; the x-label reads
        "since ``hit_nr + 1`` hit(s)".
    n_jobs : int
        Workers passed to ``behavior.sdt_group``. The function previously
        referenced an undefined ``n_jobs``. The default of 64 matches the other
        group-fitting helpers in this module.

    NOTE(review): relies on a module-level ``behavior`` whose import is
    commented out at the top of this file; confirm it is in scope.

    Returns
    -------
    matplotlib.figure.Figure
    """
    trial_measure = 'trial_{}'.format(hit_nr)
    measures = ['d_c', 'c']

    groupby = ['c_bin', 'reward', trial_measure]
    df_sdt = behavior.sdt_group(df=df.loc[(df[trial_measure]>0)], groupby=groupby,
                                    min_dur=0.1, nr_sim_trials=5000, n_jobs=n_jobs)
    # aggregate only the plotted columns (whole-frame mean fails on non-numeric data)
    grouped = df_sdt.groupby(['reward', trial_measure])[measures]
    mean = grouped.mean().reset_index()
    sem = grouped.sem().reset_index()

    fig = plt.figure(figsize=(4,2))
    for plt_nr, measure in enumerate(measures, start=1):
        ax = fig.add_subplot(1,2,plt_nr)
        for r, c in zip([0,1], ['grey', 'green']):
            ind = (mean['reward']==r)
            ax.fill_between(mean.loc[ind, trial_measure], mean.loc[ind, measure]-sem.loc[ind, measure], mean.loc[ind, measure]+sem.loc[ind, measure], color=c, alpha=0.2)
            ax.plot(mean.loc[ind, trial_measure], mean.loc[ind, measure], color=c)
        ax.set_xlabel('Trial # since {} hit(s) in block'.format(hit_nr+1))
        ax.set_ylabel(measure)
    sns.despine(trim=False)
    plt.tight_layout()

    return fig

def _formula_covariates(formula):
    """Return the plain covariate columns named in a ``+``-separated formula.

    Terms are stripped of whitespace. The intercept ``1``, empty terms, any term
    mentioning ``signal``, and interaction (``:``) or power (``**``) terms are
    dropped.
    """
    terms = [t.strip() for t in formula.split('+')]
    return [t for t in terms
            if t and t != '1' and 'signal' not in t and ':' not in t and '**' not in t]


def cox_regression(df, formula='1 + signal', effect_coding=True):
    """Fit a penalised (0.1) time-varying Cox model and return its hazard ratios.

    Covariates named in ``formula`` (see ``_formula_covariates``) are used to
    drop rows with missing values. When ``effect_coding`` is True they are also
    z-scored. Previously the intercept was only recognised as the exact text
    ``'1 '``, so formulas such as ``'1+x'`` failed with ``KeyError: '1'``.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data with ``id``, ``event``, ``start``, ``stop`` and the
        formula's covariates. Not modified.
    formula : str
        Formula passed to ``CoxTimeVaryingFitter.fit``.
    effect_coding : bool
        Z-score the covariates before fitting.

    Returns
    -------
    pandas.DataFrame
        One row of ``hazard_ratios_`` plus ``aic`` (partial AIC) and ``ll``
        (log-likelihood).
    """
    from lifelines import CoxTimeVaryingFitter
    ctv = CoxTimeVaryingFitter(penalizer=0.1)

    covars = _formula_covariates(formula)

    # Drop rows with a NaN in any covariate. Use one mask and an explicit copy,
    # so the z-scoring below never writes into the caller's frame or a view.
    df = df.loc[~df[covars].isna().any(axis=1), :].copy()

    if effect_coding:
        for c in covars:
            print('z-scoring: {}'.format(c))
            df[c] = (df[c]-df[c].mean()) / df[c].std()
    else:
        print('no effect coding!')

    ctv.fit(df, id_col="id", event_col="event", start_col="start", stop_col="stop", formula=formula)
    params = pd.DataFrame(ctv.hazard_ratios_).T
    params['aic'] = ctv.AIC_partial_
    params['ll'] = ctv.log_likelihood_

    return params

def cox_regression_group(df, formula, groupby, effect_coding=True, n_jobs=48):
    """Run ``cox_regression`` for every group of ``df`` in parallel and stack the results.

    Parameters
    ----------
    groupby : list of str
        Grouping columns. They are added to the output, one row per group, in
        groupby order.

    Returns
    -------
    pandas.DataFrame
        The concatenated per-group results with a fresh integer index.
    """
    grouped = df.groupby(groupby)
    group_keys = [key for key, _ in grouped]
    jobs = (delayed(cox_regression)(df=chunk, formula=formula, effect_coding=effect_coding)
            for _, chunk in tqdm(grouped))
    per_group = Parallel(n_jobs=n_jobs)(jobs)

    stacked = pd.concat(per_group).reset_index()
    stacked[groupby] = pd.DataFrame(group_keys, columns=groupby)
    return stacked

def logistic_regression(df, formula='choice ~ 1 + stimulus * reward', link_function='logit', start_params=None, effect_coding=True):
    """Fit an OLS/logit/probit model with statsmodels and return its parameters as one row.

    Parameters
    ----------
    df : pandas.DataFrame
        Data for the formula. Not modified; z-scoring happens on a copy.
    formula : str
        Patsy formula.
    link_function : {'linear', 'logit', 'probit'}
        'linear' fits OLS; 'logit' and 'probit' fit binary-choice models with
        ``maxiter=500``.
    start_params : array-like, optional
        Accepted but not passed to the optimiser. Kept for interface
        compatibility with ``logistic_regression_group``.
    effect_coding : bool
        If True, z-score the predictor columns ``stimulus``, ``trial``,
        ``hit_p``, ``reward``, ``walk``, ``walk_p``, ``pupil_trial_start`` and
        ``dist_from_optimal``. Columns absent from ``df`` are skipped.

    Returns
    -------
    pandas.DataFrame
        One row with the fitted coefficients plus ``aic``, ``bic`` and ``r2``.
        ``r2`` is ``rsquared`` for OLS and ``prsquared`` otherwise. The row is
        all-NaN if the optimiser did not converge or a ``LinAlgError`` occurred.

    Raises
    ------
    ValueError
        If ``link_function`` is not one of the supported values.
    """
    import statsmodels.formula.api as smf

    if link_function not in ('linear', 'logit', 'probit'):
        raise ValueError("link_function must be 'linear', 'logit' or 'probit', got {!r}".format(link_function))

    # work on a copy so z-scoring never leaks into the caller's DataFrame
    df = df.copy()
    if effect_coding:
        print('effect coding!')
        for col in ('stimulus', 'trial', 'hit_p', 'reward', 'walk', 'walk_p',
                    'pupil_trial_start', 'dist_from_optimal'):
            if col in df.columns:
                df[col] = (df[col]-df[col].mean()) / df[col].std()
    else:
        print('no effect coding!')

    def _failed_fit(exog_names):
        # all-NaN row with the same layout as a successful fit
        params = pd.DataFrame({n: [np.nan] for n in exog_names}, index=[0])
        params['aic'] = np.nan
        params['bic'] = np.nan
        params['r2'] = np.nan
        return params

    model = None
    try:
        if link_function == 'linear':
            print('linear!!')
            model = smf.ols(formula=formula, data=df)
            fit = model.fit()
        elif link_function == 'logit':
            print('logit!!')
            model = smf.logit(formula=formula, data=df)
            fit = model.fit(maxiter=500)
        else:
            print('probit!!')
            model = smf.probit(formula=formula, data=df)
            fit = model.fit(maxiter=500)

        # OLS is solved directly and its results need not carry mle_retvals;
        # treat a missing record as converged instead of raising AttributeError.
        retvals = getattr(fit, 'mle_retvals', None) or {}
        if not retvals.get('converged', True):
            params = _failed_fit(model.exog_names)
        else:
            params = pd.DataFrame(fit.params).T
            params['aic'] = fit.aic
            params['bic'] = fit.bic
            params['r2'] = fit.rsquared if link_function == 'linear' else fit.prsquared

    except np.linalg.LinAlgError as err:
        print(err)
        params = _failed_fit(model.exog_names if model is not None else [])

    return params

def logistic_regression_group(df, formula, groupby, link_function='logit', effect_coding=True, n_jobs=48):
    """Fit ``logistic_regression`` per group (in parallel) and stack the results.

    1. A pooled fit on the whole ``df`` runs first and is printed.
    2. Its coefficients (every column except the trailing ``aic``, ``bic`` and
       ``r2``) are printed and forwarded as ``start_params`` to every per-group
       fit.

    Parameters
    ----------
    groupby : list of str
        Grouping columns. They are added to the output, one row per group, in
        groupby order.

    Returns
    -------
    pandas.DataFrame
        The concatenated per-group results with a fresh integer index.
    """
    pooled = logistic_regression(df=df, formula=formula, link_function=link_function, effect_coding=effect_coding)
    print(pooled)
    start_params = pooled.values[0, :-3]
    print(start_params)

    grouped = df.groupby(groupby)
    jobs = (delayed(logistic_regression)(df=chunk, formula=formula, link_function=link_function,
                                         start_params=start_params, effect_coding=effect_coding)
            for _, chunk in tqdm(grouped))
    per_group = Parallel(n_jobs=n_jobs)(jobs)

    stacked = pd.concat(per_group).reset_index()
    stacked[groupby] = pd.DataFrame([key for key, _ in grouped], columns=groupby)
    return stacked

def glm_behavior(df, y, groupby, link_function='logit', effect_coding=False, n_jobs=48):
    """Fit a nested series of per-group GLMs to behavioral data and compare them.

    The predictors are preprocessed on a copy of ``df``:
    - optional 0 -> -1 recoding;
    - ``block_id`` shifted by -3.5;
    - ``pupil`` z-scored within each group, with its square added as ``pupil_sqr``.

    ``logistic_regression_group`` is then run for a sequence of increasingly
    rich formulas that depends on ``link_function``.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial table. It is not modified; a copy is preprocessed.
    y : str
        Dependent variable of the 'linear' and 'logit' formulas. The 'probit'
        formulas always use the column 'event'.
    groupby : list of str
        Columns defining the units that are fitted separately. They are also
        used for the per-group z-scoring of ``pupil``.
    link_function : {'linear', 'logit', 'probit'}
        Selects the formula sequence and is forwarded to the fits.
    effect_coding : bool
        If True, recode 0 to -1 in 'signal' (when present), 'reward',
        'correct_p', 'choice_p', 'false_start_p' and 'walk'.
    n_jobs : int
        Number of parallel workers per group fit.

    Returns
    -------
    pandas.DataFrame
        All fits stacked. The column 'model' is the position in the formula
        sequence. Models >= 1 also get 'bic_d' (BIC minus model 0's BIC) and
        'r2_d' (percent change in r2 relative to model 0).

    Raises
    ------
    ValueError
        If ``link_function`` is not 'linear', 'logit' or 'probit'.
    """
    nested_formulas = {
        'linear': [
            "{} ~ 1 + block_id".format(y),
            "{} ~ 1 + block_id + reward".format(y),
            "{} ~ 1 + block_id + reward + walk".format(y),
            "{} ~ 1 + block_id + reward + walk + pupil".format(y),
            "{} ~ 1 + block_id + reward + walk + pupil + pupil_sqr".format(y),
        ],
        'logit': [
            "{} ~ 1 + noise_dur + choice_p + correct_p".format(y),
            "{} ~ 1 + noise_dur + choice_p + correct_p + block_id".format(y),
            "{} ~ 1 + noise_dur + choice_p + correct_p + block_id + reward".format(y),
            "{} ~ 1 + noise_dur + choice_p + correct_p + block_id + reward + pupil".format(y),
            "{} ~ 1 + noise_dur + choice_p + correct_p + block_id + reward + pupil + pupil_sqr".format(y),
            "{} ~ 1 + noise_dur + choice_p + correct_p + block_id + reward + pupil + pupil_sqr + walk".format(y),
        ],
        # NOTE(review): the probit formulas ignore `y` and always model 'event',
        # and models 0 and 1 are identical -- confirm both are intended.
        'probit': [
            "event ~ 1 + signal + choice_p + correct_p",
            "event ~ 1 + signal + choice_p + correct_p",
            "event ~ 1 + signal * block_id + choice_p + correct_p",
            "event ~ 1 + signal * block_id + signal * reward + choice_p + correct_p",
            "event ~ 1 + signal * block_id + signal * reward + signal * pupil + signal * pupil_sqr + choice_p + correct_p",
            "event ~ 1 + signal * block_id + signal * reward + signal * pupil + signal * pupil_sqr + signal * walk + choice_p + correct_p",
        ],
    }
    if link_function not in nested_formulas:
        raise ValueError("link_function must be one of {}, got {!r}".format(sorted(nested_formulas), link_function))

    # Work on a copy so repeated calls do not compound the preprocessing
    # (e.g. shifting block_id by -3.5 again) on the caller's frame.
    df = df.copy()

    if effect_coding:
        coded_columns = ['reward', 'correct_p', 'choice_p', 'false_start_p', 'walk']
        if 'signal' in df.columns:
            coded_columns = ['signal'] + coded_columns
        for col in coded_columns:
            df.loc[df[col] == 0, col] = -1

    df['block_id'] = df['block_id'] - 3.5
    # z-score pupil within each group:
    for _, d in df.groupby(groupby):
        pupil = df.loc[d.index, 'pupil']
        df.loc[d.index, 'pupil'] = (pupil - pupil.mean()) / pupil.std()
    df['pupil_sqr'] = df['pupil'] ** 2

    # fit models (in order, model 0 is the reference):
    # NOTE(review): effect_coding is not forwarded, so logistic_regression_group
    # uses its own default (True) -- confirm this is intended.
    dfs_res = [logistic_regression_group(df=df, formula=formula, groupby=groupby,
                                         link_function=link_function, n_jobs=n_jobs)
               for formula in nested_formulas[link_function]]
    df_res0 = dfs_res[0]

    for i, res in enumerate(dfs_res):
        res['model'] = i
    # BIC difference and percent r2 change relative to model 0:
    for res in dfs_res[1:]:
        res['bic_d'] = res['bic'] - df_res0['bic']
        res['r2_d'] = (res['r2'] - df_res0['r2']) / df_res0['r2'] * 100

    df_res = pd.concat(dfs_res, axis=0)

    # numeric_only: non-numeric group-key columns make .mean() fail on newer pandas
    print(df_res.groupby(['model']).mean(numeric_only=True))

    return df_res

# def logistic_regression(df, formula, order, groupby=['subject_id', 'reward'], dt=0.1, n_jobs=32):
    
#     # columns:
#     columns = list(set(formula.split(' ')))
#     columns.remove('~')
#     columns.remove('1')
#     columns.remove('+')
#     if 'x' in columns:
#         columns.remove('*')
#     columns = [c for c in columns if not '_2' in c]
#     if groupby is not None:
#         columns.extend(groupby)
#     columns.extend(['noise_dur', 'trial_dur', 'hit', 'fa', 'miss', 'cr'])

#     print(columns)

#     # make stimulus 1 and -1:
#     columns_to_adjust = ['stimulus', 'stimulus_p', 'choice_p', 'walk']
#     for c in columns_to_adjust:
#         df.loc[df[c]==0, c] = -1
#         print(c)
#         print(df[c].unique())

#     # run:
#     df_lr = behavior.logistic_regression_group(df=df[columns], groupby=groupby, 
#                                                 formula=formula, dt=dt, n_jobs=n_jobs)
    
#     # remove subject with NaNs:
#     a = df_lr.copy()
#     failed_subjects = df_lr.loc[df_lr['Intercept'].isna(), 'subject_id'].unique()
#     df_lr = df_lr.loc[~df_lr['subject_id'].isin(failed_subjects),:]

#     # adjust intercept:
#     df_lr['Intercept'] = df_lr['Intercept'] * dt

#     # rearrange:
#     df_lr = df_lr.loc[:, order+groupby].melt(id_vars=groupby)

#     # plot:
        
#     fig = plt.figure(figsize=(6,6))
#     # gs = fig3.add_gridspec(1, 6)
#     ax = fig.add_subplot(111)
#     sns.boxplot(x='variable', y='value', hue='reward', order=order, 
#                 linewidth=0.5, palette=[sns.color_palette("muted")[0], sns.color_palette("muted")[1]], 
#                 data=df_lr, ax=ax)
#     # sns.stripplot(x='variable', y='value', hue='reward', order=order, 
#     #                 jitter=True, dodge=True, size=4, color=".3", linewidth=0, 
#     #                 data=df_lr, ax=ax)
#     plt.axhline(0, color='k', lw=1, ls='--')
#     plt.ylabel('coefficient')
#     import matplotlib.transforms as transforms
#     trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
#     for i, v in enumerate(order):
#         # t,p = sp.stats.wilcoxon(df_lr.loc[(df_lr['variable']==v)&(df_lr['reward']==1), 'value'], df_lr.loc[(df_lr['variable']==v)&(df_lr['reward']==0), 'value'])
#         t,p = sp.stats.ttest_rel(df_lr.loc[(df_lr['variable']==v)&(df_lr['reward']==1), 'value'], df_lr.loc[(df_lr['variable']==v)&(df_lr['reward']==0), 'value'])
#         plt.text(x=i, y=0.9, s='p={}'.format(round(p,3)), size=6, rotation=45, transform=trans)
#     plt.xticks(rotation=90)
#     plt.xlabel('')
#     sns.despine(trim=False)
#     plt.tight_layout()

#     # fig.savefig('/home/jwdegee/att_eff/figs/logistic_regression_coefficients.pdf')
    
#     return fig

# def logistic_regression(df):

#     df = df.reset_index(drop=True)
#     df_1 = df.loc[np.where((df['trial'].diff()==1))[0],:].reset_index()
#     df_0 = df.loc[np.where((df['trial'].diff()==1))[0]-1,:].reset_index()
#     df_1[['outcome_0', 'stimulus_p', 'choice_p', 'correct_0', 'hit_0', 'fa_0', 'miss_0', 'cr_0']] = df_0[['outcome', 'stimulus', 'choice', 'correct', 'hit', 'fa', 'miss', 'cr']]

#     groupby = ['subject_id']

#     # formula:
#     formula = ("choice ~ 1 + " 
#                     # "p_signal + "
#                     "stimulus * reward + "
#                     "stimulus * stimulus-1 + " 
#                     "stimulus * choice-1 + " 
#                     "stimulus * trial + "
#                     "stimulus * walk + "
#                     "stimulus * pupil_trial_start + "
#                     "stimulus * pupil_trial_start_2"
#             )

#     # columns:
#     columns = list(set(formula.split(' ')))
#     columns.remove('~')
#     columns.remove('1')
#     columns.remove('+')
#     columns.remove('*')
#     columns.remove('pupil_trial_start_2')
#     columns.extend(groupby)
#     columns.extend(['time_target', 'time_trial_end', 'hit', 'fa', 'miss', 'cr'])

#     # run:
#     df_lr = behavior.logistic_regression_group(df=df_1[columns], groupby=groupby, 
#                                                 formula=formula, min_dur=0.1, regularize=False, n_jobs=n_jobs)


#     # groupby = ['subject_id', 'reward']
#     # n_jobs = 48
#     # df_sdt = behavior.sdt_group(df=df.loc[(df['trial']%60)>=15], groupby=groupby,
#     #                                 min_dur=0.1, min_trials=5, nr_sim_trials=5000, n_jobs=n_jobs)

#     # imp.reload(behavior)

#     # df_lr = behavior.logistic_regression_group(df=df_1.loc[(df_1['trial']%60)>=15], groupby=groupby, 
#     #                                             formula=(
#     #                                                     "choice ~ 1 + " 
#     #                                                             "p_signal + "
#     #                                                             "stimulus * reward"
#     #                                                     ), 
#     #                                             min_dur=0.1, regularize=False, n_jobs=n_jobs)

#     # print(sp.stats.pearsonr(df_lr['stimulus'], df_sdt.groupby('subject_id').mean()['d']))
#     # print(sp.stats.pearsonr(df_lr['stimulus'], df_sdt.groupby('subject_id').mean()['d_c']))
#     # print(sp.stats.pearsonr(df_lr['stimulus:reward'], np.array(df_sdt.loc[df_sdt['reward']==1,'d'])-np.array(df_sdt.loc[df_sdt['reward']==0,'d'])))
#     # print(sp.stats.pearsonr(df_lr['stimulus:reward'], np.array(df_sdt.loc[df_sdt['reward']==1,'d_c'])-np.array(df_sdt.loc[df_sdt['reward']==0,'d_c'])))

#     # for i, v in enumerate(order):
#     #     t,p = sp.stats.wilcoxon(df_lr.loc[df_lr['variable']==v, 'value'], np.zeros(df_lr.loc[df_lr['variable']==v, 'value'].shape[0]))


#     order = ['Intercept', 'time', 'p_signal', 'stimulus', 'reward', 'stimulus_p', 'choice_p', 'trial', 'walk', 'pupil_trial_start', 'pupil_trial_start_2',
#             'stimulus:reward', 'stimulus:stimulus_p', 'stimulus:choice_p', 'stimulus:trial', 'stimulus:walk', 'stimulus:pupil_trial_start', 'stimulus:pupil_trial_start_2',]

#     # order = [c + '_t' for c in order]

#     df_lr = df_lr.loc[:, order+groupby].melt(id_vars=groupby)

#     fig1 = plt.figure(figsize=(6,6))
#     ax = fig1.add_subplot(111)
#     sns.boxplot(x='variable', y='value', order=order, linewidth=0.75, data=df_lr, ax=ax)
#     plt.axhline(0, color='k', lw=1, ls='--')
#     plt.ylabel('coefficient')
#     import matplotlib.transforms as transforms
#     trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
#     for i, v in enumerate(order):
#         t,p = sp.stats.wilcoxon(df_lr.loc[df_lr['variable']==v, 'value'], np.zeros(df_lr.loc[df_lr['variable']==v, 'value'].shape[0]))
#         plt.text(x=i, y=0.9, s='p={}'.format(round(p,3)), size=6, rotation=45, transform=trans)
#     plt.xticks(rotation=90)
#     plt.xlabel('')
#     sns.despine(trim=False)
#     plt.tight_layout()
    

#     groupby = ['subject_id']
#     res = Parallel(n_jobs=n_jobs)(delayed(behavior.sdt_prepare_data)(df=data, min_dur=0.1)
#                                         for ids, data in tqdm(df_1.loc[(df_1['trial']%60)>=15].groupby(groupby)) if data['choice'].sum() >= 5)
#     df_sdt = pd.concat(res).reset_index(drop=True)

#     fig2 = plt.figure(figsize=(4,1.75))
#     for i, var in enumerate(['reward', 'stimulus_p', 'choice_p', 'walk']):
#         ax = fig2.add_subplot(1,4,i+1)
#         d_ = df_sdt.groupby(['subject_id', 'stimulus', var]).mean()['choice'].reset_index()
#         sns.pointplot(x=var, y='choice', hue='stimulus', color="xkcd:plum", data=d_, ax=ax)
#         plt.ylabel('')
#         ax.legend([],[], frameon=False)
#     sns.despine(trim=False)
#     plt.tight_layout()
    
#     return fig1, fig2



def moving_average(x, w):
    """Boxcar-smooth ``x`` with a window of ``w`` samples.

    Uses ``np.convolve`` in 'same' mode, so the output has length
    ``max(len(x), w)``. Every window sum is divided by ``w``, even near the
    edges where the window overhangs the data. Edge values are therefore
    pulled toward zero.
    """
    window = np.ones(w)
    window_sums = np.convolve(x, window, mode='same')
    return window_sums / w

def _kernel_tone_coincidence(M, b, start):
    """Per time sample, compute the fraction of active tones that continue a b-step tone sequence.

    ``M`` has one row per tone column and one column per time sample. For each
    sample ``i >= start``, tone ``k`` counts if it is active at ``i`` and tone
    ``k - s`` was active at sample ``i - s`` for every ``s = 1..b``. The tone
    index wraps around via ``np.roll``. The count is divided by the number of
    tones active at ``i``; samples before ``start`` stay 0.
    NOTE(review): assumes M holds 0/1 tone indicators ordered by frequency --
    confirm. Samples with no active tone give 0/0 = nan.
    """
    m = np.zeros(M.shape[1])
    for i in range(start, M.shape[1]):
        tones = M[:, i]
        for shift in range(1, b + 1):
            tones = tones * np.roll(M[:, i - shift], shift=+shift)
        m[i] = tones.sum() / M[:, i].sum()
    return m


def _kernel_locked_epochs(df, df_input, ind, locking, cutoff, measure, meta_columns):
    """Cut epochs of ``df_input[measure]`` around the ``locking`` time of the trials selected by ``ind``.

    Epochs span ``cutoff + 2`` s starting ``cutoff + 1.5`` s before the lock,
    at 50 Hz, without baseline correction. They are indexed by ``meta_columns``.
    NOTE(review): `utils` is not imported in the visible part of this file
    (its import is commented out at the top) -- confirm it is in scope.
    """
    e = utils.make_epochs(df=df_input, df_meta=df.loc[ind, :], locking=locking, start=-cutoff-1.5, dur=cutoff+2,
                          measure=measure, fs=50, baseline=False, b_start=-1, b_dur=1)
    e[meta_columns] = df.loc[ind, meta_columns].reset_index(drop=True)
    return e.set_index(meta_columns)


def make_epochs_kernels(df, cutoff=2):
    """Build input-signal epochs around hits, misses and false alarms for one session.

    The session is identified by the first row's ``subject_id`` and
    ``session_id``. Its ``*_df_input.csv`` file is read, a tone-coincidence
    measure ``m1`` is added, and epochs of both 'input' and 'm1' are cut for
    four trial selections:
    - hits with ``time_target >= cutoff``, locked to 'hit_locking';
    - misses with ``time_target >= cutoff``, locked to 'miss_locking';
    - false alarms with ``time_report >= cutoff``, locked to 'fa_locking';
    - hits with ``rt >= cutoff``, locked to 'hit_locking2'.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial table of a single session.
    cutoff : float
        Minimum duration (s) for a trial to be included; also sets the epoch window.

    Returns
    -------
    tuple of four lists
        ``(epochs_h, epochs_m, epochs_fa, epochs_h2)``. Each list holds one
        DataFrame per measure ('input', then 'm1'), indexed by
        subject_id/session_id/trial/reward/difficulty. If the input file
        cannot be read, a list of four empty DataFrames is returned instead.
    """
    subject_id, session_id = (df['subject_id'].iloc[0], df['session_id'].iloc[0])
    input_path = '/media/external1/projects/att_eff/preprocess/{}/{}/{}_{}_df_input.csv'.format(
        subject_id, session_id, subject_id, session_id)
    try:
        df_input = pd.read_csv(input_path)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError also covers pandas'
        # ParserError / EmptyDataError and UnicodeDecodeError.
        print('{} session {} failed!'.format(subject_id, session_id))
        # NOTE(review): this failure value (list of 4 DataFrames) is shaped
        # differently from the success value (tuple of 4 lists) -- confirm callers cope.
        return [pd.DataFrame([]), pd.DataFrame([]), pd.DataFrame([]), pd.DataFrame([])]

    tone_columns = [c for c in df_input.columns if 'tone' in c]
    M = np.array(df_input[tone_columns]).T

    bandwidths = [1]
    for b in bandwidths:
        df_input['m{}'.format(b)] = _kernel_tone_coincidence(M, b, max(bandwidths))

    meta_columns = ['subject_id', 'session_id', 'trial', 'reward', 'difficulty']
    # (trial selection, locking column) for hits, misses, false alarms, hits by rt:
    selections = [
        ((df['outcome'] == 'hit') & (df['time_target'] >= cutoff), 'hit_locking'),
        ((df['outcome'] == 'miss') & (df['time_target'] >= cutoff), 'miss_locking'),
        ((df['outcome'] == 'fa') & (df['time_report'] >= cutoff), 'fa_locking'),
        ((df['outcome'] == 'hit') & (df['rt'] >= cutoff), 'hit_locking2'),
    ]
    epochs_h, epochs_m, epochs_fa, epochs_h2 = [], [], [], []
    buckets = (epochs_h, epochs_m, epochs_fa, epochs_h2)
    for measure in ['input', 'm1']:
        for (ind, locking), bucket in zip(selections, buckets):
            bucket.append(_kernel_locked_epochs(df, df_input, ind, locking, cutoff, measure, meta_columns))

    return epochs_h, epochs_m, epochs_fa, epochs_h2

def _physio_group_mean_sem(epochs, levels, random):
    """Return ``(means, sems)`` of ``epochs`` per combination of the index ``levels``.

    With ``random=True``, trials are first averaged within each
    ``subject_id`` (plus ``levels``), so the SEM is taken across subjects.
    Otherwise it is taken across trials. With empty ``levels``, a single
    mean/SEM time course (Series) is returned.
    """
    data = epochs.groupby(['subject_id'] + levels).mean() if random else epochs
    if not levels:
        return data.mean(axis=0), data.sem(axis=0)
    grouped = data.groupby(levels)
    return grouped.mean(), grouped.sem()


def _physio_plot_band(x, mean, sem, color, ls):
    """Draw ``mean`` as a line with a shaded +/- 1 SEM band on the current axes."""
    plt.fill_between(x, mean - sem, mean + sem, color=color, alpha=0.2)
    plt.plot(x, mean, color=color, ls=ls, lw=1)


def _physio_mark_time_zero():
    """Draw the vertical t=0 line and label the x axis on the current axes."""
    plt.axvline(0, color='k', lw=0.5)
    # plt.xlim(-2.5,5)
    plt.xlabel('Time (s)')


def _physio_anova_p_values(epochs):
    """Run a reward x outcome repeated-measures ANOVA at every time point.

    Trials are averaged per subject/reward/outcome, and only outcome codes 0
    and 2 enter the ANOVA. NaN p-values are replaced by 1.

    Returns ``(timepoints, p_values_reward, p_values_outcome)``.
    """
    from statsmodels.stats.anova import AnovaRM

    subject_means = epochs.groupby(['subject_id', 'reward', 'outcome']).mean()
    p_values_reward = []
    p_values_outcome = []
    for j in range(subject_means.shape[1]):
        data = subject_means.iloc[:, j].reset_index()
        data = data.loc[data['outcome'].isin([0, 2]), :]
        aovrm = AnovaRM(data, data.columns[-1], 'subject_id', within=['reward', 'outcome'], aggregate_func='mean')
        table = aovrm.fit().anova_table
        for effect, store in (('reward', p_values_reward), ('outcome', p_values_outcome)):
            p = float(table.loc[effect, 'Pr > F'])
            store.append(1 if np.isnan(p) else p)
    return np.array(subject_means.columns, dtype=float), p_values_reward, p_values_outcome


def _physio_draw_fdr_bars(ax, timepoints, p_values, frac, color, fdr_correction):
    """Mark runs of FDR-significant (alpha=0.05) time points with a thick horizontal bar.

    The bar is drawn on the current pyplot axes, at height ``frac`` of
    ``ax``'s y-range above its bottom.
    """
    reject, _ = fdr_correction(np.array(p_values), alpha=0.05, method='indep')
    reject = reject.astype(int)
    # force the first/last samples to "not significant" so every run has a start and an end
    reject[0] = 0
    reject[-1] = 0
    starts = np.where(np.diff(reject) == 1)[0]
    ends = np.where(np.diff(reject) == -1)[0]
    ylim = ax.get_ylim()
    height = ylim[0] + (frac * (ylim[1] - ylim[0]))
    for s, e in zip(starts, ends):
        bar_x = np.linspace(timepoints[int(s) + 1], timepoints[int(e) + 1], 3)
        plt.plot(bar_x, np.ones(3) * height, lw=4, color=color)


def plot_physio_responses(df, epochs, epochs_baseline=None, nan=False, random=False, split='sdt_reward', slope=False, ax=None):
    """Plot mean +/- SEM physiological time courses, split by trial condition.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial table with subject_id, session_id, reward, trial and outcome
        columns, used to match trials against ``epochs.index``. It also needs
        'trial_dur' for ``nan`` masking and 'outcome_p' for the 'sdt-1' splits.
        NOTE(review): after matching, rows of ``df`` are assumed to be in the
        same order as rows of ``epochs`` -- confirm.
    epochs : pandas.DataFrame
        One row per trial. Columns are time points (convertible to float), and
        index levels include subject_id, reward and outcome.
    epochs_baseline : pandas.DataFrame, optional
        Indexed like ``epochs``. Its mean over time in [-1, 0) is subtracted
        per trial.
    nan : str or bool
        'forward' masks samples later than the trial duration; 'backward'
        masks samples earlier than minus the trial duration. Any other value
        applies no masking.
    random : bool
        If True, average within subjects first (SEM across subjects).
    split : str
        One of 'none', 'reward', 'sdt_reward', 'sdt-1_reward', 'sdt-1', 'sdt'.
        'sdt_reward' additionally draws FDR-corrected ANOVA significance bars
        when statsmodels and mne are available.
    slope : bool
        If True, plot the temporal derivative (difference times the rounded
        sampling rate).
    ax : matplotlib Axes, optional
        Axes whose y-range places the significance bars. Defaults to the
        current axes. Drawing itself uses the current pyplot axes.

    Returns
    -------
    matplotlib Axes
        ``ax``, or the current axes when ``ax`` was None.
    """
    # Without this, ax.get_ylim() below raised and the ANOVA bars were silently lost.
    if ax is None:
        ax = plt.gca()

    # match trials present in both df and epochs:
    index1 = pd.MultiIndex.from_arrays([df[col] for col in ['subject_id', 'session_id', 'reward', 'trial', 'outcome']])
    index2 = epochs.index
    keep_epochs = index2.isin(index1)
    df = df.loc[index1.isin(index2), :]
    epochs = epochs.loc[keep_epochs, :]
    if epochs_baseline is not None:
        epochs_baseline = epochs_baseline.loc[keep_epochs, :]
    print(epochs.shape)
    print(df.shape)

    # to NaN (vectorized; the former per-row chained assignment
    # `epochs.iloc[i].loc[...] = NaN` is a no-op under copy-on-write):
    if nan in ('forward', 'backward'):
        t = np.array(epochs.columns, dtype=float)[None, :]
        trial_dur = np.array(df['trial_dur'].iloc[:epochs.shape[0]], dtype=float)[:, None]
        outside = (t > trial_dur) if nan == 'forward' else (t < -trial_dur)
        epochs = epochs.mask(outside)

    if epochs_baseline is not None:
        base_cols = epochs_baseline.columns
        baselines = np.array(epochs_baseline.loc[:, (base_cols >= -1) & (base_cols < 0)].mean(axis=1))
        epochs = epochs - np.atleast_2d(baselines).T

    if slope:
        sample_dt = np.diff(np.array(epochs.columns, dtype=float)).mean()
        epochs = epochs.diff(axis=1) * round(1 / sample_dt)

    x = np.array(epochs.columns, dtype=float)

    # plot:
    if split == 'none':
        means, sems = _physio_group_mean_sem(epochs, [], random)
        _physio_plot_band(x, means, sems, 'black', '-')
        _physio_mark_time_zero()

    elif split == 'reward':
        means, sems = _physio_group_mean_sem(epochs, ['reward'], random)
        for i, color, ls in zip([0, 1], sns.color_palette(), ['-', '-']):
            _physio_plot_band(x, means.iloc[i], sems.iloc[i], color, ls)
        _physio_mark_time_zero()

    elif split == 'sdt_reward':
        means, sems = _physio_group_mean_sem(epochs, ['reward', 'outcome'], random)
        palette = sns.color_palette("tab10")
        outcomes = means.index.get_level_values('outcome')
        rewards = means.index.get_level_values('reward')
        for i in range(len(means)):
            # line style by outcome code, color by reward code (both used as list indices)
            ls = ['-', '--', '--'][outcomes[i]]
            color = [palette[0], palette[1]][rewards[i]]
            try:
                _physio_plot_band(x, means.iloc[i], sems.iloc[i], color, ls)
            except Exception:
                pass  # best-effort: skip conditions that cannot be drawn
        _physio_mark_time_zero()

        # add ANOVA significance bars (optional dependencies: statsmodels, mne):
        try:
            import mne
            timepoints, p_values_reward, p_values_outcome = _physio_anova_p_values(epochs)
            _physio_draw_fdr_bars(ax, timepoints, p_values_reward, 0.10, 'orange', mne.stats.fdr_correction)
            _physio_draw_fdr_bars(ax, timepoints, p_values_outcome, 0.05, 'green', mne.stats.fdr_correction)
        except ImportError:
            pass  # optional dependency missing: plot without significance bars
        except Exception as err:
            print('ANOVA significance bars skipped: {}'.format(err))

    elif split == 'sdt-1_reward':
        epochs = epochs.assign(outcome_p=np.array(df['outcome_p'])).set_index(['outcome_p'], append=True)
        means, sems = _physio_group_mean_sem(epochs, ['reward', 'outcome_p'], random)
        means = means.loc[means.index.get_level_values('outcome_p') != 3, :]
        sems = sems.loc[sems.index.get_level_values('outcome_p') != 3, :]
        colors = ['mediumseagreen', 'lightslategrey', 'lightsalmon', 'seagreen', 'slategrey', 'salmon']
        styles = [':', ':', ':', '-', '-', '-']
        for i, color, ls in zip([0, 1, 2, 3, 4, 5], colors, styles):
            _physio_plot_band(x, means.iloc[i], sems.iloc[i], color, ls)
        _physio_mark_time_zero()

    elif split == 'sdt-1':
        epochs = epochs.assign(outcome_p=np.array(df['outcome_p'])).set_index(['outcome_p'], append=True)
        means, sems = _physio_group_mean_sem(epochs, ['outcome_p'], random)
        means = means.loc[means.index.get_level_values('outcome_p') != 3, :]
        sems = sems.loc[sems.index.get_level_values('outcome_p') != 3, :]
        for i, color, ls in zip([0, 1, 2], ['seagreen', 'slategrey', 'salmon'], ['-', '-', '-']):
            _physio_plot_band(x, means.iloc[i], sems.iloc[i], color, ls)
        _physio_mark_time_zero()

    elif split == 'sdt':
        # NOTE(review): if epochs.index already has an 'outcome' level (it is
        # matched on one above), appending a second makes groupby('outcome')
        # ambiguous -- confirm this split still works.
        epochs = epochs.assign(outcome=np.array(df['outcome'])).set_index(['outcome'], append=True)
        means, sems = _physio_group_mean_sem(epochs, ['outcome'], random)
        print(means)
        means = means.loc[means.index.get_level_values('outcome') != 3, :]
        sems = sems.loc[sems.index.get_level_values('outcome') != 3, :]
        for i, color, ls in zip([0, 1, 2], ['seagreen', 'slategrey', 'salmon'], ['-', '-', '-']):
            _physio_plot_band(x, means.iloc[i], sems.iloc[i], color, ls)
        _physio_mark_time_zero()

    sns.despine(trim=False)
    plt.tight_layout()
    return ax