#!/bin/python3

#----------------------------------------------------------
# Setup
#----------------------------------------------------------

import os
import re

import numpy as np
import pandas as pd
import scipy as sp
import scipy.stats


# User configuration. Every path below is an empty placeholder and has to be
# filled in before the workflow is run.

# Output Directory (the process changes into it below, so the relative
# output paths declared by the rules are created inside it)
OUTPUT_DIR = ''

# S_het Path (handed to the RSS MASH rules as params.s_het_path)
S_HET_PATH = ''

# Monotonicity Estimates (CSV read with pandas to pick the traits to analyse)
MOM_ESTIMATES = ''

# Selected phenotypes (CSV read with pandas; its field_id column is used)
SEL_PHENOS = ''

# LoF Summary Statistics (handed to the trait rules as params.lof_dir)
LOF_DIR = ''

# CNV Summary Statistics (handed to the trait rules as params.cnv_dir)
CNV_DIR = ''

# Duplication Burden LD (handed to the RSS MASH rules as params.dup_ld_dir)
DUP_LD_DIR = ''

# Simulation Directory (handed to the simulation rules as params.sim_dir)
SIM_DIR = ''


# Runs at parse time; with OUTPUT_DIR left as '' this raises FileNotFoundError.
# NOTE(review): Snakemake's `workdir:` directive may be the more conventional
# way to set the working directory -- confirm before changing.
os.chdir(OUTPUT_DIR)

#----------------------------------------------------------
# Extract Traits for Analysis
#----------------------------------------------------------

# Restrict the estimates table to the rows whose selection is 'All'.
mom = pd.read_csv(MOM_ESTIMATES)
mom = mom[mom.selection == 'All']

# Two-sided critical value of the standard normal at the 5% level.
# The submodule is reached through an explicit `import scipy.stats`:
# `import scipy as sp` on its own does not guarantee that `sp.stats` is
# loaded on every SciPy release.
z_crit = scipy.stats.norm.ppf(1 - (0.05 / 2))

# Keep a trait only if both gamma_bar estimates exceed the critical value in
# absolute value once divided by their `_se` column (presumably the standard
# error -- confirm against the file that produces MOM_ESTIMATES).
mom = mom[np.abs(mom.gamma_bar_1_hat / mom.gamma_bar_1_hat_se) > z_crit]
mom = mom[np.abs(mom.gamma_bar_2_hat / mom.gamma_bar_2_hat_se) > z_crit]

sel_phenos = pd.read_csv(SEL_PHENOS)

# Traits to analyse: field IDs present in both tables. np.intersect1d returns
# them sorted and without duplicates.
field_ids = np.intersect1d(sel_phenos.field_id, mom.field_id).tolist()

# Simulation grid: replicates numbered 0..99 for each experiment name.
experiments = ['xi_both', 'm_only', 'nm_only']
experiment_numbers = list(range(100))

# Epoch indices 0..4.
epochs = list(range(5))

#----------------------------------------------------------
# Rules
#----------------------------------------------------------

# Default target: requesting these two summary tables pulls in the whole
# simulation branch (xi_mash_simulations.csv) and the whole trait branch
# (xi_mash.csv).
rule all:
    input:
        'xi_mash_simulations.csv',
        'xi_mash.csv'

#----------------------------------------------------------
# RSS MASH under Simulation
#----------------------------------------------------------

# Produces the MCEM CSV of one simulation replicate, identified by the
# {experiment} and {experiment_number} wildcards, by running
# rss_mash_mcem_simulation.sh. No input files are declared; the data
# locations reach the script through params.
rule run_rss_mash_mcem_simulation:
    threads: 1
    params:
        source=workflow.source_path('rss_mash_mcem_simulation.py'),
        source_import=workflow.source_path('load_summary_statistics.py'),
        s_het_path=S_HET_PATH,
        sim_dir=SIM_DIR,
        dup_ld_dir=DUP_LD_DIR
    output:
        '{experiment}.rss_mash_mcem.number_{experiment_number}.csv'
    script:
        'rss_mash_mcem_simulation.sh'


# Turns the MCEM CSV of one simulation replicate into the matching samples
# CSV by running rss_mash_simulation.sh. The replicate is identified by the
# {experiment} and {experiment_number} wildcards shared by input and output.
rule run_rss_mash_simulation:
    threads: 1
    params:
        source=workflow.source_path('rss_mash_simulation.py'),
        source_import=workflow.source_path('load_summary_statistics.py'),
        s_het_path=S_HET_PATH,
        sim_dir=SIM_DIR,
        dup_ld_dir=DUP_LD_DIR
    input:
        '{experiment}.rss_mash_mcem.number_{experiment_number}.csv'
    output:
        '{experiment}.rss_mash_samples.number_{experiment_number}.csv'
    script:
        'rss_mash_simulation.sh'


# Aggregates the samples CSVs of every simulation replicate (each experiment
# name combined with each experiment number) into xi_mash_simulations.csv by
# running calculate_xi_mash_simulation.sh.
rule calculate_xi_mash_simulation:
    threads: 1
    params:
        source=workflow.source_path('calculate_xi_mash_simulation.py')
    input:
        expand(
            '{experiment}.rss_mash_samples.number_{experiment_number}.csv',
            experiment=experiments,
            experiment_number=experiment_numbers
        )
    output:
        'xi_mash_simulations.csv'
    script:
        'calculate_xi_mash_simulation.sh'

#----------------------------------------------------------
# RSS MASH on Traits
#----------------------------------------------------------

# Creates the two empty placeholder files that epoch 0 of run_rss_mash_mcem
# takes as its inputs. Both are flagged temp(), so Snakemake removes them
# once no remaining job needs them.
rule run_rss_mash_mcem_start:
    output:
        pi=temp('{field_id}.rss_mash_mcem.start.csv'),
        gamma_list=temp('{field_id}.rss_mash_mcem.start.pkl')
    shell:
        '''
        touch {output.pi}
        touch {output.gamma_list}
        '''


def _rss_mash_mcem_previous_state(wildcards, extension):
    """Return the state file that the requested epoch starts from.

    Epoch 0 starts from the placeholder ``start`` file; every other epoch
    starts from the file written by the epoch immediately before it.

    Args:
        wildcards: Object exposing ``field_id`` and ``epoch`` attributes;
            ``epoch`` must be convertible with ``int()``.
        extension: File extension without the leading dot.

    Returns:
        The relative path of the file holding the previous state.

    Raises:
        ValueError: If ``wildcards.epoch`` is not an integer literal.
    """
    epoch = int(wildcards.epoch)
    stage = 'start' if epoch == 0 else f'epoch_{epoch - 1}'
    return f'{wildcards.field_id}.rss_mash_mcem.{stage}.{extension}'


def run_rss_mash_mcem_input_pi(wildcards):
    """Input function: the ``.csv`` file the requested epoch starts from."""
    return _rss_mash_mcem_previous_state(wildcards, 'csv')


def run_rss_mash_mcem_input_gamma_list(wildcards):
    """Input function: the ``.pkl`` file the requested epoch starts from."""
    return _rss_mash_mcem_previous_state(wildcards, 'pkl')


# Runs one epoch for one trait via rss_mash_mcem.sh. The inputs come from the
# two input functions: epoch 0 reads the placeholder start files, any later
# epoch reads the files written by the epoch before it, so epochs are chained
# one after another.
rule run_rss_mash_mcem:
    input:
        pi=run_rss_mash_mcem_input_pi,
        gamma_list=run_rss_mash_mcem_input_gamma_list
    output:
        pi='{field_id}.rss_mash_mcem.epoch_{epoch}.csv',
        gamma_list='{field_id}.rss_mash_mcem.epoch_{epoch}.pkl'
    # The input functions call int(wildcards.epoch) and subtract one. Limiting
    # the wildcard to digits keeps a non-numeric value from raising inside
    # them and a negative value from requesting ever-lower epochs.
    wildcard_constraints:
        epoch=r'\d+'
    threads: 1
    params:
        source=workflow.source_path('rss_mash_mcem.py'),
        source_import=workflow.source_path('load_summary_statistics.py'),
        s_het_path=S_HET_PATH,
        lof_dir=LOF_DIR,
        cnv_dir=CNV_DIR,
        dup_ld_dir=DUP_LD_DIR
    script:
        'rss_mash_mcem.sh'


# Copies the CSV of the final epoch to the epoch-free file name that
# run_rss_mash takes as input. The final epoch is the last entry of the
# module-level `epochs` list, so changing the number of epochs there is
# enough; no epoch index is hard-coded here.
rule run_rss_mash_mcem_finish:
    input:
        '{field_id}.rss_mash_mcem.epoch_' + str(epochs[-1]) + '.csv'
    output:
        '{field_id}.rss_mash_mcem.csv'
    shell:
        '''
        cp {input} {output}
        '''


# Turns the final MCEM CSV of one trait ({field_id}) into its samples CSV by
# running rss_mash.sh; data locations reach the script through params.
rule run_rss_mash:
    threads: 1
    params:
        source=workflow.source_path('rss_mash.py'),
        source_import=workflow.source_path('load_summary_statistics.py'),
        s_het_path=S_HET_PATH,
        lof_dir=LOF_DIR,
        cnv_dir=CNV_DIR,
        dup_ld_dir=DUP_LD_DIR
    input:
        '{field_id}.rss_mash_mcem.csv'
    output:
        '{field_id}.rss_mash_samples.csv'
    script:
        'rss_mash.sh'


# Aggregates the samples CSVs of every selected trait (one per entry of
# field_ids) into xi_mash.csv by running calculate_xi_mash.sh.
rule calculate_xi_mash:
    threads: 1
    params:
        source=workflow.source_path('calculate_xi_mash.py')
    input:
        expand(
            '{field_id}.rss_mash_samples.csv',
            field_id=field_ids
        )
    output:
        'xi_mash.csv'
    script:
        'calculate_xi_mash.sh'
