import glob
import re

import numpy as np
import pandas as pd

# Style file handed to every plotting rule as an input and forwarded to the
# plotting scripts via --plot_styles (presumably a Matplotlib style sheet,
# going by the extension).
plot_styles = "styles/paperdraft.mplstyle"
# File extension interpolated into the names of the requested plot files.
plot_filetype = "pdf"

# Table of critical-mass tuning runs. Columns referenced in this file:
# Npv, beta, mpv, m, nsteps, measure_spectrum.
critical_mass_ensembles = pd.read_csv("metadata/critical_mass_tuning.csv")
# Distinct rows that remain once the per-run columns are dropped; each row is
# later turned into one critical-mass result file.
critical_mass_targets = critical_mass_ensembles.drop(columns=["measure_spectrum", "m", "nsteps"]).drop_duplicates()

# Kludge to pick the right filename for both Nf=1 and Nf=2 cases: request the
# (Npv=5, beta=2.35, mpv=0.5) extrapolation plot when that target is present in
# the tuning metadata, and the (Npv=15, beta=2.7, mpv=0.5) plot otherwise.
mass_extrapolation_filename_template = f"assets/plots/mass_extrapolation_{{Npv}}pv_mpv{{mpv}}_beta{{beta}}.{plot_filetype}"
# Compare only the three identifying columns. Testing whole-row dict equality
# would silently stop matching as soon as the metadata CSV gained any
# additional column.
if (
    (critical_mass_targets.Npv == 5)
    & (critical_mass_targets.beta == 2.35)
    & (critical_mass_targets.mpv == 0.5)
).any():
    mass_extrapolation_filename = mass_extrapolation_filename_template.format(Npv=5, beta=2.35, mpv=0.5)
else:
    mass_extrapolation_filename = mass_extrapolation_filename_template.format(Npv=15, beta=2.7, mpv=0.5)

# Fallback parameter grids, kept when no production metadata file exists.
Npvs = [5, 10, 15]
mpvs = [0.5]

try:
    production_ensembles = pd.read_csv("metadata/production.csv")
except FileNotFoundError:
    # Without production metadata there is nothing further to request.
    production_ensembles = pd.DataFrame()
    production_targets = []
else:
    # Replace the fallback grids with the values present in the metadata.
    Npvs = sorted(set(production_ensembles.Npv))
    mpvs = sorted(set(production_ensembles.mpv))

    # Fixed summary plots first, then one interpolation plot per (Npv, mpv).
    production_targets = [
        f"assets/plots/volume_extrapolation_sym.{plot_filetype}",
        f"assets/plots/g2_plaquette_comparison_sym.{plot_filetype}",
        f"assets/plots/beta_interpolation_finite_a_combined.{plot_filetype}",
    ]
    production_targets.extend(
        f"assets/plots/beta_interpolation_finite_a_{Npv}pv_mpv{mpv}_sym.{plot_filetype}"
        for Npv in Npvs
        for mpv in mpvs
    )

# Threshold used by single_flows: a flow file is kept only when its index
# multiplied by the ensemble's trajectory_length exceeds this value
# (presumably measured in molecular dynamics time units, going by the name).
thermalisation_mdtu = 2000

# Operator labels expanded into the inputs of the combined finite-a
# interpolation plot.
operators = ["plaq", "sym"]
# Forwarded as --order to src/fit_beta_against_g2.py.
interpolate_fit_order = 3

def single_ensemble_metadata(wildcards):
    """Return the production metadata row describing exactly one ensemble.

    The row is selected by the ``Npv``, ``beta``, ``mpv`` and ``L`` wildcards,
    each converted to a number before being compared with the table.

    Raises:
        ValueError: if the number of matching rows is anything other than one.
    """
    same_npv = production_ensembles.Npv == int(wildcards.Npv)
    same_beta = production_ensembles.beta == float(wildcards.beta)
    same_mpv = production_ensembles.mpv == float(wildcards.mpv)
    same_extent = production_ensembles.L == int(wildcards.L)
    matching = production_ensembles[same_npv & same_beta & same_mpv & same_extent]

    if len(matching) == 1:
        return matching.iloc[0]

    raise ValueError(f"Expected 1 ensemble for {wildcards=}; found {len(matching)}")


# Aggregate target: the phase diagram plot, the collated critical masses, one
# critical-mass extrapolation plot, and the production plots (an empty list
# when metadata/production.csv could not be read).
rule all:
    input:
        f"assets/plots/phasediagram.{plot_filetype}",
        "intermediary_data/critical_mass/target_mass.csv",
        mass_extrapolation_filename,
        production_targets,


# Draw the phase diagram plot with src/phasediagram.py.
# - The raw out_hmc_* files are globbed when the workflow is parsed and listed
#   as inputs, but the script itself is given only the directory name.
# - The output string is not an f-string, so {plot_filetype} here is a
#   Snakemake wildcard rather than the module-level variable.
# - The combined plot is discarded by writing it to /dev/null.
rule phasediagram:
    input:
        datafiles=glob.glob("raw_data/phasediagram/*/out_hmc_*"),
        script="src/phasediagram.py",
        plot_styles=plot_styles,
    output:
        "assets/plots/phasediagram.{plot_filetype}"
    conda:
        "envs/environment.yml"
    priority:
        10
    shell:
        "python {input.script} --input_dirname raw_data/phasediagram --threepanel_plot_filename {output} --combined_plot_filename /dev/null --plot_styles {input.plot_styles}"


# Path template of the data file written by fit_mpcac; mass_inputs fills the
# same template from metadata rows to request these files.
mpcac_datafile = "intermediary_data/critical_mass/{Npv}pv/beta{beta}/m{m}/mpv{mpv}/mpcac_{Npv}pv_beta{beta}_m{m}_mpv{mpv}_{nsteps}steps.json.gz"
# Run src/mpcac.py on one raw out_corr_* file, producing a data file and a plot
# (named effmass_*) for that (Npv, beta, m, mpv, nsteps) combination.
rule fit_mpcac:
    input:
        datafile="raw_data/critical_mass/{Npv}pv/beta{beta}/m{m}/mpv{mpv}/out_corr_{Npv}pv_beta{beta}_m{m}_mpv{mpv}_{nsteps}steps_0",
        script="src/mpcac.py",
    output:
        datafile=mpcac_datafile,
        plotfile=f"intermediary_data/critical_mass/{{Npv}}pv/beta{{beta}}/m{{m}}/mpv{{mpv}}/effmass_{{Npv}}pv_beta{{beta}}_m{{m}}_mpv{{mpv}}_{{nsteps}}steps.{plot_filetype}",
    conda:
        "envs/environment.yml"
    shell:
        "python {input.script} {input.datafile} --output_filename {output.datafile} --plot_filename {output.plotfile} --Npv {wildcards.Npv} --mpv {wildcards.mpv}"


def mass_inputs(wildcards):
    """List the mpcac result files needed for one critical-mass fit.

    Selects the tuning runs whose ``Npv``, ``beta`` and ``mpv`` match the
    wildcards and whose ``measure_spectrum`` flag is set, then fills the
    ``mpcac_datafile`` template once per selected row.
    """
    table = critical_mass_ensembles
    selected = table[
        (table.Npv == int(wildcards["Npv"]))
        & (table.beta == float(wildcards["beta"]))
        & (table.mpv == float(wildcards["mpv"]))
        & (table.measure_spectrum)
    ]

    paths = []
    for row in selected.to_dict("records"):
        # Every column of the row is available to the template by name.
        paths.append(mpcac_datafile.format(**row))
    return paths


# Path template of the data file written by critical_mass; also filled from
# each row of critical_mass_targets to build critical_mass_files.
critical_mass_datafile = "intermediary_data/critical_mass/{Npv}pv/beta{beta}/mpv{mpv}/critical_mf.json.gz"
# Pass the mpcac result files selected by mass_inputs for one (Npv, beta, mpv)
# to src/critical_mf.py, which writes a data file and an extrapolation plot.
rule critical_mass:
    input:
        datafiles=mass_inputs,
        script="src/critical_mf.py",
    output:
        datafile=critical_mass_datafile,
        plotfile=f"intermediary_data/critical_mass/{{Npv}}pv/beta{{beta}}/mpv{{mpv}}/mf_extrapolation.{plot_filetype}",
    conda:
        "envs/environment.yml"
    shell:
        "python {input.script} {input.datafiles} --output_filename {output.datafile} --plot_filename {output.plotfile}"


# Copy a critical-mass extrapolation plot out of intermediary_data into
# assets/plots, under a file name that encodes Npv, mpv and beta.
# Neither string is an f-string, so {plot_filetype} is a wildcard here.
rule get_critical_mass_plot:
    input:
        "intermediary_data/critical_mass/{Npv}pv/beta{beta}/mpv{mpv}/mf_extrapolation.{plot_filetype}",
    output:
        "assets/plots/mass_extrapolation_{Npv}pv_mpv{mpv}_beta{beta}.{plot_filetype}",
    shell:
        "cp {input} {output}"


# One critical-mass result file per row of critical_mass_targets.
critical_mass_files = [
    critical_mass_datafile.format(**target) for target in critical_mass_targets.to_dict("records")
]


# Gather every critical-mass result file into a single CSV using
# src/collate_critical_mf.py.
rule collate_critical_masses:
    input:
        datafiles=critical_mass_files,
        script="src/collate_critical_mf.py",
    output:
        csv="intermediary_data/critical_mass/target_mass.csv",
    conda:
        "envs/environment.yml"
    shell:
        "python {input.script} {input.datafiles} --output_file {output.csv}"


def parse_flow_filename(filename):
    """Extract the index and job ID encoded in a Wilson flow output filename.

    The path is expected to contain ``/out_wflow_n<index>_<jobid>_0``, where
    both ``<index>`` and ``<jobid>`` are runs of decimal digits.

    Args:
        filename: Path to the flow output file; it must include at least one
            directory component, since the pattern requires a preceding ``/``.

    Returns:
        A tuple ``(index, jobid)`` of ints.

    Raises:
        ValueError: if ``filename`` does not follow the expected pattern.
    """
    match = re.match(r".*/out_wflow_n([0-9]+)_([0-9]+)_0", filename)
    if match is None:
        # Fail with the offending path instead of an AttributeError on None.
        raise ValueError(f"Unrecognised Wilson flow filename: {filename!r}")

    index, jobid = match.groups()
    return int(index), int(jobid)


def single_flows(wildcards):
    """List the raw Wilson flow files of one ensemble that pass the cut.

    Files are globbed from the ensemble's raw-data directory. A file is kept
    when its index multiplied by the ensemble's ``trajectory_length`` exceeds
    ``thermalisation_mdtu``. The result is ordered by increasing index.
    """
    metadata = single_ensemble_metadata(wildcards)
    pattern = f"raw_data/wilson_flow/{int(metadata.Npv)}pv/beta{metadata.beta}/m{metadata.m}/mpv{metadata.mpv}/L{int(wildcards.L)}/out_wflow_*"

    kept = []
    for filename in glob.glob(pattern):
        index, _jobid = parse_flow_filename(filename)
        if index * metadata.trajectory_length > thermalisation_mdtu:
            kept.append((index, filename))

    # Sort on the index alone; the stable sort keeps glob order between ties.
    kept.sort(key=lambda entry: entry[0])
    return [filename for _index, filename in kept]


# Concatenate the flow files selected by single_flows into one file per
# (Npv, beta, mpv, L). The output is always created by `touch`, so it is left
# empty when no input files were selected; `cat` runs only for a non-empty list.
rule collate_flows:
    input:
        datafiles=single_flows,
    output:
        datafile="intermediary_data/wilson_flow/{Npv}pv/beta{beta}/mpv{mpv}/out_wflow_{Npv}pv_beta{beta}_mpv{mpv}_L{L}",
    shell:
        "touch {output.datafile} && if [[ '{input.datafiles}' != '' ]]; then cat {input.datafiles} > {output.datafile}; fi"


def volume_extrapolation_ensembles(wildcards):
    """List the collated flow files entering one infinite-volume extrapolation.

    Production ensembles are matched on ``Npv``, ``beta`` and ``mpv``; only
    rows whose ``use`` entry is truthy contribute, one file per ``L`` value.
    """
    same_npv = production_ensembles.Npv == int(wildcards.Npv)
    same_beta = production_ensembles.beta == float(wildcards.beta)
    same_mpv = production_ensembles.mpv == float(wildcards.mpv)
    candidates = production_ensembles[same_npv & same_beta & same_mpv]

    paths = []
    for record in candidates.to_dict("records"):
        if not record["use"]:
            # Ensemble is excluded from the extrapolation by the metadata.
            continue
        paths.append(
            f"intermediary_data/wilson_flow/{wildcards.Npv}pv/beta{wildcards.beta}/mpv{wildcards.mpv}/out_wflow_{wildcards.Npv}pv_beta{wildcards.beta}_mpv{wildcards.mpv}_L{record['L']}"
        )
    return paths


# Run src/extrapolate_infinite_volume.py on the collated flow files returned by
# volume_extrapolation_ensembles for one (Npv, mpv, beta), at the time and
# operator given by the wildcards; all five wildcards are forwarded as options.
rule extrapolate_infinite_volume:
    input:
        data=volume_extrapolation_ensembles,
        script="src/extrapolate_infinite_volume.py",
    output:
        "intermediary_data/beta_function/infinite_volume/{Npv}pv/mpv{mpv}/beta{beta}/t{time}_{operator}.json.gz",
    conda:
        "envs/environment.yml"
    shell:
        "python {input.script} {input.data} --output_filename {output} --operator {wildcards.operator} --time {wildcards.time} --Npv {wildcards.Npv} --mpv {wildcards.mpv} --beta {wildcards.beta}"


def volume_extrapolation_plot_inputs(wildcards):
    """List the infinite-volume result files shown in the combined plot.

    One path is produced for every pairing of a fixed (Npv, mpv, beta) set
    with a fixed time, parameter sets varying slowest. ``wildcards`` is not
    read: the ``{operator}`` placeholder is left unformatted in each returned
    path (presumably substituted by Snakemake afterwards; confirm against the
    Snakemake version in use).
    """
    flow_times = [2.5, 3.5, 4.5, 6.0]
    parameter_sets = ((5, 0.5, 2.35), (5, 0.5, 2.5), (10, 0.5, 2.4), (15, 0.5, 2.7))

    paths = []
    for Npv, mpv, beta in parameter_sets:
        for time in flow_times:
            paths.append(
                f"intermediary_data/beta_function/infinite_volume/{Npv}pv/mpv{mpv}/beta{beta}/t{time}_{{operator}}.json.gz"
            )
    return paths


# Plot the infinite-volume results at four fixed times for one
# (Npv, mpv, beta, operator). The doubled braces in the expand() pattern keep
# Npv, mpv, beta and operator as wildcards; only `time` is expanded.
rule plot_volume_extrapolation:
    input:
        data=expand(
            "intermediary_data/beta_function/infinite_volume/{{Npv}}pv/mpv{{mpv}}/beta{{beta}}/t{time}_{{operator}}.json.gz",
            time=[2.5, 3.5, 4.5, 6.0],
        ),
        script="src/plot_infinite_volume_extrapolation.py",
        plot_styles=plot_styles,
    output:
        "intermediary_data/beta_function/infinite_volume/{Npv}pv/mpv{mpv}/beta{beta}_{operator}.{plot_filetype}",
    conda:
        "envs/environment.yml"
    shell:
        "python {input.script} {input.data} --plot_filename {output} --plot_styles {input.plot_styles}"


# Same plotting script as plot_volume_extrapolation, but fed with the fixed
# list of parameter sets and times from volume_extrapolation_plot_inputs and
# written to assets/plots, one plot per operator.
rule plot_volume_extrapolation_combined:
    input:
        data=volume_extrapolation_plot_inputs,
        script="src/plot_infinite_volume_extrapolation.py",
        plot_styles=plot_styles,
    output:
        "assets/plots/volume_extrapolation_{operator}.{plot_filetype}",
    conda:
        "envs/environment.yml"
    shell:
        "python {input.script} {input.data} --plot_filename {output} --plot_styles {input.plot_styles}"


def g2_comparison_inputs(wildcards):
    """List the t=6.0 infinite-volume result files for the g2 comparison plot.

    One path is produced per distinct (Npv, mpv, beta) combination in the
    production metadata, in set iteration order. ``wildcards`` is not read:
    the ``{operator}`` placeholder is left unformatted in each returned path.
    """
    # Building a set collapses ensembles that differ only in other columns.
    parameter_sets = set(production_ensembles[["Npv", "mpv", "beta"]].itertuples(index=False))

    paths = []
    for Npv, mpv, beta in parameter_sets:
        paths.append(
            f"intermediary_data/beta_function/infinite_volume/{Npv}pv/mpv{mpv}/beta{beta}/t6.0_{{operator}}.json.gz"
        )
    return paths


# Pass the result files listed by g2_comparison_inputs to
# src/plot_g2_against_beta0.py, producing one comparison plot per operator.
rule g2_comparison:
    input:
        data=g2_comparison_inputs,
        script="src/plot_g2_against_beta0.py",
        plot_styles=plot_styles,
    output:
        "assets/plots/g2_plaquette_comparison_{operator}.{plot_filetype}",
    conda:
        "envs/environment.yml"
    shell:
        "python {input.script} {input.data} --plot_filename {output} --plot_styles {input.plot_styles}"


def finite_a_betas(wildcards):
    """List the infinite-volume result files for every beta of one (Npv, mpv).

    Only ``Npv`` and ``mpv`` are read from ``wildcards`` to select the betas.
    The ``{Npv}``, ``{mpv}``, ``{time}`` and ``{operator}`` placeholders are
    left unformatted in each returned path.
    """
    same_npv = production_ensembles.Npv == int(wildcards.Npv)
    same_mpv = production_ensembles.mpv == float(wildcards.mpv)
    distinct_betas = set(production_ensembles[same_npv & same_mpv].beta)

    paths = []
    for beta in distinct_betas:
        paths.append(
            f"intermediary_data/beta_function/infinite_volume/{{Npv}}pv/mpv{{mpv}}/beta{beta}/t{{time}}_{{operator}}.json.gz"
        )
    return paths


# Run src/fit_beta_against_g2.py on the result files of every beta available
# for one (Npv, mpv), at one time and operator. The fit order comes from the
# module-level interpolate_fit_order and is forwarded as --order.
rule interpolate_finite_a:
    input:
        data=finite_a_betas,
        script="src/fit_beta_against_g2.py",
    output:
        "intermediary_data/beta_interpolation/{Npv}pv/mpv{mpv}/t{time}_{operator}.json.gz",
    params:
        fit_order=interpolate_fit_order,
    conda:
        "envs/environment.yml"
    shell:
        "python {input.script} {input.data} --order {params.fit_order} --output_filename {output}"


# Times whose interpolation results are shown in the finite-a plots; shared by
# the per-(Npv, mpv) plot rule and the combined plot rule.
finite_a_plot_times = [2.5, 3.5, 4.5, 6.0]

# Plot the interpolation results at every time in finite_a_plot_times for one
# (Npv, mpv, operator). The doubled braces in the expand() pattern keep Npv,
# mpv and operator as wildcards; only `time` is expanded.
rule plot_finite_a_interpolation:
    input:
        data=expand(
            "intermediary_data/beta_interpolation/{{Npv}}pv/mpv{{mpv}}/t{time}_{{operator}}.json.gz",
            time=finite_a_plot_times,
        ),
        script="src/plot_beta_against_g2.py",
        plot_styles=plot_styles,
    output:
        "assets/plots/beta_interpolation_finite_a_{Npv}pv_mpv{mpv}_{operator}.{plot_filetype}",
    conda:
        "envs/environment.yml"
    shell:
        "python {input.script} {input.data} --plot_styles {input.plot_styles} --plot_filename {output}"


# Single plot combining the interpolation results for every combination of
# time, Npv, operator and mpv. Every placeholder in the pattern is expanded, so
# the only wildcard left in this rule is {plot_filetype} in the output.
rule plot_finite_a_interpolation_combined:
    input:
        data=expand(
            "intermediary_data/beta_interpolation/{Npv}pv/mpv{mpv}/t{time}_{operator}.json.gz",
            time=finite_a_plot_times,
            Npv=Npvs,
            operator=operators,
            mpv=mpvs,
        ),
        script="src/plot_beta_against_g2.py",
        plot_styles=plot_styles,
    output:
        "assets/plots/beta_interpolation_finite_a_combined.{plot_filetype}",
    conda:
        "envs/environment.yml"
    shell:
        "python {input.script} {input.data} --plot_styles {input.plot_styles} --plot_filename {output}"
