#!/bin/python3

#----------------------------------------------------------
# Setup
#----------------------------------------------------------

import os
import re


# Output Directory
# Every rule output below is a relative path, so results land here once the
# os.chdir call at the end of this section has run.
# Placeholder: must be filled in before the workflow is used.
OUTPUT_DIR = ''

# LoF Summary Statistics
# Directory that is listed to derive field_ids; also handed to the estimation
# rules as the `lof_dir` param.
LOF_DIR = ''

# CNV Summary Statistics
# Handed to the estimation rules as the `cnv_dir` param.
CNV_DIR = ''

# Duplication Burden LD
# Handed to estimate_mean / estimate_monotonicity as `dup_ld_dir` and to
# estimate_genetic_correlation as `del_ld_dir`.
DUP_LD_DIR = ''

# Duplication Burden LD (Uncorrected)
# Handed to the three simulation rules as `dup_ld_dir`.
DUP_LD_DIR_UNCORRECTED = ''

# S_het Estimates
# Handed to the simulation, mean and monotonicity rules as `s_het_path`.
S_HET_PATH = ''

# Simulation Directory
# Handed to the three simulation rules as `sim_dir`.
SIM_DIR = ''


# Make OUTPUT_DIR the working directory for the rest of the workflow.
# NOTE: with the empty placeholder above, os.chdir('') raises
# FileNotFoundError, so OUTPUT_DIR has to be set first.
os.chdir(OUTPUT_DIR)

#----------------------------------------------------------
# Extract Traits for Analysis
#----------------------------------------------------------

# Global wildcard constraints, applied to every rule in the workflow:
#   field_id - the letter 'p' followed by one or more digits
#   chrom    - one or more digits
# NOTE(review): no rule in the visible file uses a {chrom} wildcard.
wildcard_constraints:
    field_id=r'p\d+',
    chrom=r'\d+'

# Trait field IDs: one per distinct file stem in LOF_DIR, where the stem is
# the file name with everything from the first '.' onward removed.
# The set removes duplicates (several files may share a stem); sorting makes
# the order reproducible. Iteration order of a set of strings can differ
# between interpreter runs, and this list feeds expand(), so it decides the
# row order of every combined CSV built with `cat`.
# NOTE(review): stems are not checked against the field_id wildcard
# constraint (p\d+); a stray file in LOF_DIR yields a target no rule can
# build.
field_ids = sorted({re.sub(r'\..*$', '', x) for x in os.listdir(LOF_DIR)})

# NOTE(review): range(23) is 0..22, so it includes 0. Nothing in the visible
# file uses `chroms`; confirm whether range(1, 23) was intended.
chroms = list(range(23))

# Experiment names expanded by combine_mom_simulations and
# combine_ngd_simulations respectively; each experiment is run once per
# batch index in experiment_batches (0..9).
experiments_mom = ['gamma_bar', 'var_gamma_bar', 'sigma2', 'rho']
experiments_mle = ['gamma_bar', 'sigma2', 'rho']
experiment_batches = list(range(10))

#----------------------------------------------------------
# Rules
#----------------------------------------------------------

# First rule in the file, so Snakemake uses it as the default target.
# It requests the four combined estimate tables, one MLE table per field ID,
# and the three combined simulation tables.
# 'del_mom_estimates.csv' is not listed here; it is declared as a second
# output of combine_genetic_correlation and so is produced together with
# 'genetic_correlation.csv'.
rule all:
    input:
        'monotonicity.csv',
        'genetic_correlation.csv',
        'mom_estimates.csv',
        'sq_eff_estimates.csv',
        expand('{field_id}.gamma_2_hat_mle.csv', field_id=field_ids),
        'ngd_simulations.csv',
        'mom_simulations.csv',
        'mse_simulations.csv'

#----------------------------------------------------------
# NGD Simulations
#----------------------------------------------------------

# Run one batch of one NGD simulation experiment.
# Produces a single temporary CSV per (experiment, experiment_batch) pair;
# the rule declares no input files. The work is delegated to
# ngd_simulations.sh, which receives the paths below as params:
#   source / source_import_* - Python files resolved via workflow.source_path
#   s_het_path, sim_dir      - S_HET_PATH and SIM_DIR
#   dup_ld_dir               - the *uncorrected* duplication LD directory
# How the shell script consumes these params is not visible in this file.
rule ngd_simulations:
    output:
        temp('{experiment}.ngd_simulations.batch_{experiment_batch}.csv')
    threads: 1
    params:
        source=workflow.source_path('ngd_simulations.py'),
        source_import_1=workflow.source_path('load_summary_statistics.py'),
        source_import_2=workflow.source_path('monotonicity_ngd.py'),
        s_het_path=S_HET_PATH,
        sim_dir=SIM_DIR,
        dup_ld_dir=DUP_LD_DIR_UNCORRECTED
    script:
        'ngd_simulations.sh'

# Merge all NGD simulation batches into 'ngd_simulations.csv'.
# Inputs are every (experiment, batch) combination of experiments_mle and
# experiment_batches. The shell step writes the header line first and then
# appends the input files in the order given.
# Assumes the per-batch files contain data rows only (no header of their
# own) - TODO confirm against ngd_simulations.py.
rule combine_ngd_simulations:
    input:
        expand('{experiment}.ngd_simulations.batch_{experiment_batch}.csv', experiment=experiments_mle, experiment_batch=experiment_batches)
    output:
        'ngd_simulations.csv'
    shell:
        '''
        echo "experiment,iteration,gamma_bar_1,gamma_bar_2,sigma2_11,sigma_12,sigma2_22,rho,phi,phi_hat,phi_prime_hat,phi_prime_hat_se,gamma_bar_1_hat,gamma_bar_1_hat_se,gamma_bar_2_hat,gamma_bar_2_hat_se,sigma2_11_hat,sigma_12_hat,sigma2_22_hat,ll,llr" > {output}

        cat {input} >> {output}
        '''

#----------------------------------------------------------
# MoM Simulations
#----------------------------------------------------------

# Run one batch of one MoM simulation experiment.
# Produces a single temporary CSV per (experiment, experiment_batch) pair;
# the rule declares no input files. The work is delegated to
# mom_simulations.sh, which receives the Python source paths (resolved via
# workflow.source_path), S_HET_PATH, SIM_DIR and the *uncorrected*
# duplication LD directory as params.
# How the shell script consumes these params is not visible in this file.
rule mom_simulations:
    output:
        temp('{experiment}.mom_simulations.batch_{experiment_batch}.csv')
    threads: 1
    params:
        source=workflow.source_path('mom_simulations.py'),
        source_import_1=workflow.source_path('load_summary_statistics.py'),
        source_import_2=workflow.source_path('monotonicity_ngd.py'),
        s_het_path=S_HET_PATH,
        sim_dir=SIM_DIR,
        dup_ld_dir=DUP_LD_DIR_UNCORRECTED
    script:
        'mom_simulations.sh'

# Merge all MoM simulation batches into 'mom_simulations.csv'.
# Inputs are every (experiment, batch) combination of experiments_mom and
# experiment_batches. The shell step writes the header line first and then
# appends the input files in the order given.
# Assumes the per-batch files contain data rows only (no header of their
# own) - TODO confirm against mom_simulations.py.
rule combine_mom_simulations:
    input:
        expand('{experiment}.mom_simulations.batch_{experiment_batch}.csv', experiment=experiments_mom, experiment_batch=experiment_batches)
    output:
        'mom_simulations.csv'
    shell:
        '''
        echo "experiment,iteration,gamma_bar_1,gamma_bar_2,sigma2_11,sigma_12,sigma2_22,rho,phi,phi_hat,phi_hat_se,gamma_bar_1_hat,gamma_bar_1_hat_se,gamma_bar_2_hat,gamma_bar_2_hat_se,sigma2_11_hat,sigma2_11_hat_se,sigma_12_hat,sigma2_22_hat,sigma2_22_hat_se" > {output}

        cat {input} >> {output}
        '''

#----------------------------------------------------------
# Mean Squared Effect Simulations
#----------------------------------------------------------

# Run one batch of the mean-squared-effect simulation.
# Unlike the other simulation rules the experiment name is fixed in the
# output file name ('var_gamma_bar_sq'); only the batch index is a wildcard.
# The rule declares no input files and delegates to mse_simulations.sh.
# NOTE(review): `source` points at 'mom_simulations.py' even though the
# script is 'mse_simulations.sh' - confirm this is intentional and not a
# copy-paste leftover from rule mom_simulations.
rule mse_simulations:
    output:
        temp('var_gamma_bar_sq.mse_simulations.batch_{experiment_batch}.csv')
    threads: 1
    params:
        source=workflow.source_path('mom_simulations.py'),
        source_import_1=workflow.source_path('load_summary_statistics.py'),
        source_import_2=workflow.source_path('monotonicity_ngd.py'),
        s_het_path=S_HET_PATH,
        sim_dir=SIM_DIR,
        dup_ld_dir=DUP_LD_DIR_UNCORRECTED
    script:
        'mse_simulations.sh'

# Merge all mean-squared-effect simulation batches into
# 'mse_simulations.csv'.
# Inputs are one file per entry of experiment_batches. The shell step writes
# the header line first and then appends the input files in the order given.
# Assumes the per-batch files contain data rows only (no header of their
# own) - TODO confirm against the simulation script.
rule combine_mse_simulations:
    input:
        expand('var_gamma_bar_sq.mse_simulations.batch_{experiment_batch}.csv', experiment_batch=experiment_batches)
    output:
        'mse_simulations.csv'
    shell:
        '''
        echo "experiment,iteration,gamma_bar_sq_1,gamma_bar_sq_2,gamma2_1_hat,gamma2_1_hat_se,gamma2_2_hat,gamma2_2_hat_se" > {output}

        cat {input} >> {output}
        '''

#----------------------------------------------------------
# Estimate Means
#----------------------------------------------------------

# Estimate mean for a field
#
# Runs estimate_mean.sh once per field_id and declares three outputs:
#   estimates - '{field_id}.mean.csv'            (temporary)
#   sq_eff    - '{field_id}.sq_eff.csv'          (temporary)
#   mle       - '{field_id}.gamma_2_hat_mle.csv' (kept; requested by rule all)
# A failed job is retried up to 3 times. No runtime resource is set here, so
# retries run with the same resources as the first attempt.
# The rule declares no input files; data locations (LOF_DIR, CNV_DIR,
# DUP_LD_DIR, S_HET_PATH) and Python source paths are passed as params.

rule estimate_mean:
    output:
        estimates=temp('{field_id}.mean.csv'),
        sq_eff=temp('{field_id}.sq_eff.csv'),
        mle='{field_id}.gamma_2_hat_mle.csv'
    threads: 1
    retries: 3
    params:
        source=workflow.source_path('estimate_mean.py'),
        source_import=workflow.source_path('load_summary_statistics.py'),
        lof_dir=LOF_DIR,
        cnv_dir=CNV_DIR,
        dup_ld_dir=DUP_LD_DIR,
        s_het_path=S_HET_PATH
    script:
        'estimate_mean.sh'

# Combine means across fields
#
# Builds two tables from the per-field outputs of estimate_mean:
#   mom    - 'mom_estimates.csv'    from every '{field_id}.mean.csv'
#   sq_eff - 'sq_eff_estimates.csv' from every '{field_id}.sq_eff.csv'
# For each table the shell step writes the header line and then appends the
# per-field files in the order produced by expand() over field_ids.
# Assumes the per-field files contain data rows only (no header of their
# own) - TODO confirm against estimate_mean.py.

rule combine_mean:
    input:
        mom_fields=expand('{field_id}.mean.csv', field_id=field_ids),
        sq_eff_fields=expand('{field_id}.sq_eff.csv', field_id=field_ids)
    output:
        mom='mom_estimates.csv',
        sq_eff='sq_eff_estimates.csv'
    shell:
        '''
        echo "field_id,class_1,class_2,selection,phi_hat,phi_hat_se,gamma_bar_1_hat,gamma_bar_1_hat_se,gamma_bar_2_hat,gamma_bar_2_hat_se,sigma2_11_hat,sigma2_11_hat_se,sigma_12_hat,sigma2_22_hat,sigma2_22_hat_se" > {output.mom}

        cat {input.mom_fields} >> {output.mom}

        echo "field_id,class_1,class_2,selection,gamma2_1_hat,gamma2_1_hat_se,gamma2_2_hat,gamma2_2_hat_se" > {output.sq_eff}

        cat {input.sq_eff_fields} >> {output.sq_eff}
        '''

#----------------------------------------------------------
# Estimate Monotonicity
#----------------------------------------------------------

def get_walltime_monotonicity(wildcards, attempt):
    """Return the runtime request for one estimate_monotonicity attempt.

    Used as the callable `runtime` resource of rule estimate_monotonicity,
    so the request grows linearly with each retry: 120 on the first
    attempt, 240 on the second, and so on. Snakemake reads a bare number
    for `runtime` as minutes.

    Args:
        wildcards: Wildcards of the job; not used.
        attempt: Attempt counter supplied by Snakemake (1 for the first run).

    Returns:
        ``attempt * 120``.
    """
    runtime_per_attempt = 120
    return attempt * runtime_per_attempt

# Estimate monotonicity for a field
#
# Runs estimate_monotonicity.sh once per field_id and writes the temporary
# file '{field_id}.monotonicity.csv'.
# A failed job is retried up to 3 times; because `runtime` is computed by
# get_walltime_monotonicity (attempt * 120), every retry asks for more
# runtime than the previous attempt.
# The rule declares no input files; data locations (LOF_DIR, CNV_DIR,
# DUP_LD_DIR, S_HET_PATH) and Python source paths are passed as params.

rule estimate_monotonicity:
    output:
        temp('{field_id}.monotonicity.csv')
    threads: 1
    retries: 3
    resources:
        runtime=get_walltime_monotonicity
    params:
        source=workflow.source_path('estimate_monotonicity.py'),
        source_import_1=workflow.source_path('load_summary_statistics.py'),
        source_import_2=workflow.source_path('monotonicity_ngd.py'),
        lof_dir=LOF_DIR,
        cnv_dir=CNV_DIR,
        dup_ld_dir=DUP_LD_DIR,
        s_het_path=S_HET_PATH
    script:
        'estimate_monotonicity.sh'

# Combine monotonicities across fields
#
# Builds 'monotonicity.csv' from every '{field_id}.monotonicity.csv'.
# The shell step writes the header line and then appends the per-field files
# in the order produced by expand() over field_ids.
# Assumes the per-field files contain data rows only (no header of their
# own) - TODO confirm against estimate_monotonicity.py.

rule combine_monotonicity:
    input:
        fields=expand('{field_id}.monotonicity.csv', field_id=field_ids)
    output:
        traits='monotonicity.csv'
    shell:
        '''
        echo "field_id,class_1,class_2,phi_hat,phi_prime_hat,phi_prime_hat_se,gamma_bar_1_hat,gamma_bar_1_hat_se,gamma_bar_2_hat,gamma_bar_2_hat_se,sigma2_11_hat,sigma_12_hat,sigma2_22_hat,ll,llr" > {output.traits}

        cat {input.fields} >> {output.traits}
        '''

#----------------------------------------------------------
# Estimate Genetic Correlation
#----------------------------------------------------------

def get_walltime_genetic_correlation(wildcards, attempt):
    """Return the runtime request for one estimate_genetic_correlation attempt.

    Used as the callable `runtime` resource of rule
    estimate_genetic_correlation, so the request grows linearly with each
    retry: 120 on the first attempt, 240 on the second, and so on.
    Snakemake reads a bare number for `runtime` as minutes.

    Args:
        wildcards: Wildcards of the job; not used.
        attempt: Attempt counter supplied by Snakemake (1 for the first run).

    Returns:
        ``attempt * 120``.
    """
    runtime_per_attempt = 120
    return attempt * runtime_per_attempt

# Estimate genetic correlation for a field
#
# Runs estimate_genetic_correlation.sh once per field_id and writes two
# temporary files:
#   corr - '{field_id}.gen_corr.csv'
#   mom  - '{field_id}.del_mom.csv'
# A failed job is retried up to 3 times; because `runtime` is computed by
# get_walltime_genetic_correlation (attempt * 120), every retry asks for
# more runtime than the previous attempt.
# The rule declares no input files; data locations and Python source paths
# are passed as params.
# NOTE(review): `del_ld_dir` is set to DUP_LD_DIR (the duplication LD
# directory); the file defines no separate deletion LD constant - confirm
# this is the intended directory.
# NOTE(review): the only import param is named `source_import_2` (there is
# no `source_import_1`), and neither load_summary_statistics.py nor
# S_HET_PATH is passed here, unlike the other estimation rules - confirm
# against estimate_genetic_correlation.sh.

rule estimate_genetic_correlation:
    output:
        corr=temp('{field_id}.gen_corr.csv'),
        mom=temp('{field_id}.del_mom.csv')
    threads: 1
    retries: 3
    resources:
        runtime=get_walltime_genetic_correlation
    params:
        source=workflow.source_path('estimate_genetic_correlation.py'),
        source_import_2=workflow.source_path('genetic_correlation_ngd.py'),
        lof_dir=LOF_DIR,
        cnv_dir=CNV_DIR,
        del_ld_dir=DUP_LD_DIR
    script:
        'estimate_genetic_correlation.sh'

# Combine genetic correlations across fields
#
# Builds two tables from the per-field outputs of
# estimate_genetic_correlation:
#   traits        - 'genetic_correlation.csv' from every '{field_id}.gen_corr.csv'
#   mom_estimates - 'del_mom_estimates.csv'   from every '{field_id}.del_mom.csv'
# For each table the shell step writes the header line and then appends the
# per-field files in the order produced by expand() over field_ids.
# Assumes the per-field files contain data rows only (no header of their
# own) - TODO confirm against estimate_genetic_correlation.py.
# NOTE(review): the second header starts with "iteration" whereas every
# other per-field table in this file starts with "field_id" - confirm the
# first column of the '{field_id}.del_mom.csv' rows.

rule combine_genetic_correlation:
    input:
        fields=expand('{field_id}.gen_corr.csv', field_id=field_ids),
        mom_estimates=expand('{field_id}.del_mom.csv', field_id=field_ids)
    output:
        traits='genetic_correlation.csv',
        mom_estimates='del_mom_estimates.csv'
    shell:
        '''
        echo "field_id,class_1,class_2,corr_hat,corr_prime_hat,corr_prime_hat_se,sigma2_11_hat,sigma_12_hat,sigma2_22_hat,ll,llr" > {output.traits}

        cat {input.fields} >> {output.traits}

        echo "iteration,class_1,class_2,rho_hat,rho_hat_se,gamma_bar_1_hat,gamma_bar_1_hat_se,gamma_bar_2_hat,gamma_bar_2_hat_se,sigma2_11_hat,sigma2_11_hat_se,sigma_12_hat,sigma2_22_hat,sigma2_22_hat_se" > {output.mom_estimates}

        cat {input.mom_estimates} >> {output.mom_estimates}
        '''
