import os

# ---------------------------------------------------------------------------
# Workflow configuration
# ---------------------------------------------------------------------------
# Root directory that every input/output path in this workflow is built from.
DATA_ROOT = "data"

# Dataset partitions; each value fills the `{part}` wildcard.
PARTS = ["main", "supp"]

# Model indices (1..10 inclusive) that fill the `{index}` wildcard of the
# regression results.
MODEL_TO_RUN = [*range(1, 11)]
# Model indices (2..10 inclusive) whose results feed the AIC figure.
MODEL_AIC = [*range(2, 11)]

# Raw regression table, one file per partition (`{part}` is a wildcard).
REG_TABLE = os.path.join(DATA_ROOT, "reg_table_{part}.fst")

############################################
# Generate more variables
# Regression table after add_new_var.R has generated additional variables,
# one file per partition.
REG_TABLE_NEW_VAR = os.path.join(DATA_ROOT, "output", "reg_new_var_{part}.fst")


rule add_new_variable_all:
    """Aggregate target: build the augmented table for every partition."""
    input:
        expand(REG_TABLE_NEW_VAR, part=PARTS),


rule add_new_variable:
    """Run add_new_var.R on the raw regression table of one `{part}`."""
    input:
        table=REG_TABLE,
    output:
        table=REG_TABLE_NEW_VAR,
    shell:
        "Rscript add_new_var.R {input} {output}"

############################################
# Create formula
# Model formula for one partition, stored as an .RDS file.
FORMULA = os.path.join(DATA_ROOT, "output", "model_formula_{part}.RDS")


rule construct_the_model_formula_all:
    """Aggregate target: build the model formula for every partition."""
    input:
        expand(FORMULA, part=PARTS),


rule construct_the_model_formula:
    """Build the formula for one `{part}`.

    The script receives, in order: the augmented table, the partition name
    (the `{part}` wildcard via params), and the output path.
    """
    input:
        table=REG_TABLE_NEW_VAR,
    params:
        "{part}",
    output:
        formula=FORMULA,
    shell:
        "Rscript construct_the_model_formula.R {input} {params} {output}"

############################################
# Fit regression models (one per partition and model index)
# One fitted model per (partition, model index) pair.
MODEL_RESULT = os.path.join(
    DATA_ROOT, "output", "reg_results", "reg_results_{part}_{index}.RDS"
)
# NOTE: MODEL_TO_RUN and MODEL_AIC come from the configuration at the top of
# the file. Do not re-assign them here: a second assignment overrides the
# top-level values for every rule below and makes edits there ineffective.


rule run_regression_model_all:
    """Aggregate target: fit every model in MODEL_TO_RUN for every partition."""
    input:
        expand(MODEL_RESULT, part=PARTS, index=MODEL_TO_RUN),


rule run_regression_model:
    """Fit model `{index}` for partition `{part}`.

    The script receives its arguments positionally, in this order: augmented
    table, model formula, model index, output path. Keep the input order.
    """
    input:
        table=REG_TABLE_NEW_VAR,
        formula=FORMULA,
    params:
        "{index}",
    output:
        MODEL_RESULT,
    shell:
        "Rscript run_regression_model.R {input} {params} {output}"


############################################
# Generate figures
# All figures are written under <DATA_ROOT>/output/figures.
FIG_ROOT = os.path.join(DATA_ROOT, "output", "figures")

# Correlation figure for one partition, drawn from the augmented table.
FIG_CORRELATION = os.path.join(FIG_ROOT, "correlation_{part}.png")


rule gen_fig_corr:
    """Draw the correlation figure for one `{part}` with gen_fig_corr.R."""
    input:
        table=REG_TABLE_NEW_VAR,
    output:
        figure=FIG_CORRELATION,
    shell:
        "Rscript gen_fig_corr.R {input} {output}"

# Linear-association figure for one partition, drawn from the augmented table.
FIG_LINEAR = os.path.join(FIG_ROOT, "linear_association_{part}.png")


rule gen_fig_linear:
    """Draw the linear-association figure for one `{part}` with gen_fig_linear.R."""
    input:
        table=REG_TABLE_NEW_VAR,
    output:
        figure=FIG_LINEAR,
    shell:
        "Rscript gen_fig_linear.R {input} {output}"

# AIC figure for one partition, built from the fitted model results.
FIG_AIC = os.path.join(FIG_ROOT, "aic_{part}.png")


rule gen_fig_aic:
    """Draw the AIC figure for one `{part}` from the models in MODEL_AIC.

    `allow_missing=True` makes `expand` fill in only `{index}` and leave
    `{part}` as a wildcard, so this rule still runs once per partition.
    """
    input:
        results=expand(MODEL_RESULT, index=MODEL_AIC, allow_missing=True),
    output:
        figure=FIG_AIC,
    shell:
        "Rscript gen_fig_aic.R {input} {output}"

# Top-level target: every model result and every figure for every partition.
# Without an explicit target, Snakemake builds the FIRST rule in the file,
# which is `add_new_variable_all`, not this one. The rule is therefore marked
# as the default target, so a plain `snakemake` run builds the whole pipeline.
# (The `default_target` directive needs a Snakemake release that supports it.)
rule all:
    input:
        expand(MODEL_RESULT, part=PARTS, index=MODEL_TO_RUN),
        expand(FIG_CORRELATION, part=PARTS),
        expand(FIG_LINEAR, part=PARTS),
        expand(FIG_AIC, part=PARTS),
    default_target: True