# see README or genome_grist/__main__.py for top-level targets.
import snakemake
srcdir = workflow.source_path

import urllib
import os, csv, tempfile, shutil, sys, gzip
from functools import lru_cache
import polars

# reject obsolete single-sample config: 'sample' was renamed to 'samples'.
if config.get('sample'):
    print(f"ERROR: this is an old config file. Please use 'samples' instead of 'sample'.",
          file=sys.stderr)
    sys.exit(-1)

SAMPLES=config['samples']
print(f'samples: {SAMPLES}', file=sys.stderr)
assert isinstance(SAMPLES, list), "config 'samples' must be a list."

# validate sample names: '.' and '/' would break the wildcard constraints
# (see 'wildcard_constraints' below) and output-path construction.
fail = False
for sample in SAMPLES:
    if '.' in sample:
        print(f"sample name '{sample}' contains a period; please remove",
              file=sys.stderr)
        fail = True
    if '/' in sample:
        print(f"sample name '{sample}' contains a forward slash ('/'); please remove",
              file=sys.stderr)
        fail = True

if fail:
    sys.exit(-1)

# all output paths are built under 'outdir'; strip trailing '/' so that
# f-string path construction ('{outdir}/...') doesn't double up slashes.
outdir = config.get('outdir', 'outputs/')
outdir = outdir.rstrip('/')
print('outdir:', outdir, file=sys.stderr)

# memory (in bytes) handed to trim-low-abund.py for k-mer abundance trimming.
ABUNDTRIM_MEMORY = float(config.get('metagenome_trim_memory', '1e9'))

# directory where downloaded GenBank genomes are cached across runs.
GENBANK_CACHE = config.get('genbank_cache', './genbank_cache/')
GENBANK_CACHE = os.path.normpath(GENBANK_CACHE)

### collect databases

# reject obsolete glob-pattern database config; lists are now required.
if config.get('sourmash_database_glob_pattern'):
    print(f"ERROR: this is an old config file. Please provide a list of 'sourmash_databases' instead of a 'sourmash_database_glob_pattern'",
          file=sys.stderr)
    sys.exit(-1)

sourmash_databases = config.get('sourmash_databases', [])
# optional CSVs describing locally-available genomes (consumed by
# ListGatherGenomes below); each row must have 'ident' and 'genome_filename'.
local_databases_info = config.get('local_databases_info', [])

# at least one sourmash database is required for gather to do anything.
SOURMASH_DB_LIST = sourmash_databases
if not SOURMASH_DB_LIST:
    print(f"ERROR: no sourmash_databases specified in config file!",
          file=sys.stderr)
    sys.exit(-1)

# k-mer size(s) used when searching the databases; a list by default.
SOURMASH_DB_KSIZE = config.get('sourmash_database_ksize', ['31'])
# minimum match size (bp) for prefetch/gather.
SOURMASH_DATABASE_THRESHOLD_BP = config.get('sourmash_database_threshold_bp',
                                            1e5)
# k-mer sizes used when sketching metagenome reads.
# (fix: the default previously contained ' 51' with a stray leading space,
# which produced an invalid 'k= 51' entry in the sketch parameter string.)
SOURMASH_COMPUTE_KSIZES = config.get('sourmash_compute_ksizes',
                                     ['21', '31', '51'])
SOURMASH_COMPUTE_SCALED = config.get('sourmash_scaled', '1000')
SOURMASH_COMPUTE_TYPE = config.get('sourmash_sigtype', 'DNA')
assert SOURMASH_COMPUTE_TYPE in ('DNA', 'protein')

# memory (in bytes) requested for 'sourmash prefetch' jobs.
PREFETCH_MEMORY = float(config.get('prefetch_memory', '20e9'))

# reject obsolete 'database_taxonomy' config key (renamed to 'taxonomies').
if config.get('database_taxonomy'):
    print(f"ERROR: this is an old config file. Please use 'taxonomies' instead of 'database_taxonomy'.",
          file=sys.stderr)
    sys.exit(-1)

TAXONOMY_DB = config.get('taxonomies', ["NO_TAXONOMIES_SPECIFIED_IN_CONFIG"])
if not isinstance(TAXONOMY_DB, list):
    print(f"ERROR: 'taxonomies' should be a list.", file=sys.stderr)
    sys.exit(-1)

# optional sourmash picklist used to subset the databases (see
# 'sourmash_collect_wc' rule).
PICKLIST = config.get('picklist', '')

# make a `sourmash sketch` -p param string, e.g. "k=21,k=31,scaled=1000,abund".
def make_param_str(ksizes, scaled):
    """Build a sourmash '-p' parameter string from k-mer sizes + scaled.

    'ksizes' is an iterable of k-mer sizes (ints or strings); stray
    whitespace in string values is stripped so config typos like ' 51'
    do not produce an invalid parameter string.
    """
    ks = ",".join(f'k={str(k).strip()}' for k in ksizes)
    return f"{ks},scaled={scaled},abund"

# genome identifiers to exclude from downstream processing.
IGNORE_IDENTS = config.get('skip_genomes', [])

###
#
# print out unknown config variables...
#

# warn (but do not fail) on config parameters this workflow doesn't know
# about -- typically typos or leftovers from older config formats.
all_config_keys = set(config.keys())
known_config_keys = { 'samples', 'outdir', 'tempdir',
                      'metagenome_trim_memory',
                      'genbank_cache',
                      'sourmash_databases',
                      'local_databases_info',
                      'sourmash_database_threshold_bp',
                      'sourmash_database_ksize',
                      'sourmash_compute_ksizes',
                      'sourmash_scaled',
                      'sourmash_sigtype',
                      'prefetch_memory',
                      'taxonomies',
                      'prevent_sra_download',
                      'skip_genomes',
                      'picklist' }

unknown_config_keys = all_config_keys - known_config_keys

if unknown_config_keys:
    print(f'**', file=sys.stderr)
    print(f'** WARNING: {len(unknown_config_keys)} unknown parameters found in config.', file=sys.stderr)
    print(f'** The following config parameters are being ignored:', file=sys.stderr)
    print(f"**    {', '.join(unknown_config_keys)}", file=sys.stderr)
    print(f'**', file=sys.stderr)
    print(f'** Please see docs at https://dib-lab.github.io/genome-grist/configuring/',
          file=sys.stderr)
    print(f'**', file=sys.stderr)

###

# utility function
def load_csv(filename):
    """Yield rows of a CSV file (gzipped or not) as dicts, one per row."""
    opener = gzip.open if filename.endswith('.gz') else open
    with opener(filename, "rt") as fp:
        yield from csv.DictReader(fp)

###

# mark top-level rules:
_toplevel_rules = []
def toplevel(fn):
    """Decorator: register a rule function (named '__<rule>') as a
    user-facing top-level target; the '__' prefix is stripped."""
    name = fn.__name__
    assert name.startswith('__')
    _toplevel_rules.append(name[2:])
    return fn

# global wildcard constraints: 'size' is numeric; 'sample' may not contain
# '.' or '/' (mirrors the sample-name validation at the top of this file).
wildcard_constraints:
    size="\\d+",
    sample='[^./]+'                   # should be everything but /.

# list all registered top-level targets and exit.
@toplevel
rule print_rules:
    run:
        print("\nTop level rules are: \n", file=sys.stderr)
        print("* " + "\n* ".join(_toplevel_rules), file=sys.stderr)
        print("\nPlease see documentation for details.\n\n",
              file=sys.stderr)

# remove all post-trim outputs so gather/mapping/reporting can be redone.
@toplevel
rule clean_gather:
    # the '|| true' here makes this command succeed even when dirs do not exist
    # (adding -fr to rm seemed too dangerous!)
    shell: """
        (rm -r {outdir}/{{gather,genomes,mapping,leftover,reports}}/ \
            {outdir}/.kernel.set || true)
    """

# download raw reads + their sourmash sketches for all samples.
@toplevel
rule download_reads:
    input:
        expand(f"{outdir}/raw/{{sample}}_1.fastq.gz", sample=SAMPLES),
        expand(f"{outdir}/raw/{{sample}}_2.fastq.gz", sample=SAMPLES),
        expand(f"{outdir}/raw/{{sample}}.raw.sig.zip", sample=SAMPLES),

# adapter-trim reads for all samples.
@toplevel
rule trim_reads:
    input:
        url_file = expand(f"{outdir}/trim/{{sample}}.trim.fq.gz",
                          sample=SAMPLES)

# estimate the number of distinct k-mers in each trimmed sample.
@toplevel
rule estimate_distinct_kmers:
    input:
        url_file = expand(f"{outdir}/trim/{{sample}}.trim.fq.gz.kmer-report.txt",
                          sample=SAMPLES)

# count reads and bases in each trimmed sample.
@toplevel
rule count_trimmed_reads:
    input:
        url_file = expand(f"{outdir}/trim/{{sample}}.trim.fq.gz.reads-report.txt",
                          sample=SAMPLES)

# build sourmash sketches of trimmed reads for all samples.
@toplevel
rule smash_reads:
    input:
        url_file = expand(f"{outdir}/sigs/{{sample}}.trim.sig.zip",
                          sample=SAMPLES)

# run 'sourmash prefetch' for all samples.
@toplevel
rule prefetch_reads:
    input:
        expand(outdir + "/gather/{sample}.prefetch.csv.gz",
               sample=SAMPLES)

# aggregate per-sample info into YAML summaries.
@toplevel
rule summarize_sample_info:
    input:
        expand(outdir + '/{sample}.info.yaml', sample=SAMPLES)

# k-mer abundance-trim reads for all samples (optional stage).
@toplevel
rule abundtrim_reads:
    input:
        url_file = expand(f"{outdir}/abundtrim/{{sample}}.abundtrim.fq.gz",
                          sample=SAMPLES)

# run 'sourmash gather' for all samples; a checkpoint, since downstream
# genome lists depend on the gather results.
@toplevel
checkpoint gather_reads:
    input:
        expand(f'{outdir}/gather/{{sample}}.gather.csv.gz',
               sample=SAMPLES),
        expand(f'{outdir}/gather/{{sample}}.gather.parquet',
               sample=SAMPLES),

# per-sample (wildcard) version of the gather checkpoint, used by the
# Checkpoint_* helper classes below to wait for a single sample's results.
checkpoint gather_reads_wc:
    input:
        gather_csv = f'{outdir}/gather/{{sample}}.gather.csv.gz',
        gather_pq = f'{outdir}/gather/{{sample}}.gather.parquet',
    output:
        touch(f"{outdir}/gather/.gather.{{sample}}")   # checkpoints need an output ;)

# NOTE(review): appears unused in this chunk of the file; kept for
# compatibility with code elsewhere.
_gather_csv_cache = {}
class Checkpoint_GatherResults:
    """Given a pattern containing {ident} and {sample}, this class
    will generate the list of {ident} for that {sample}'s gather results.

    Alternatively, you can omit {sample} from the pattern and include the
    list of sample names in the second argument to the constructor.
    """
    def __init__(self, pattern, samples=None):
        self.pattern = pattern
        self.samples = samples

    def get_genome_idents(self, sample):
        """Return genome identifiers from this sample's gather parquet file.

        The identifier is the first whitespace-separated token of each
        'match_name' value.
        """
        gather_parquet = f'{outdir}/gather/{sample}.gather.parquet'
        if not os.path.exists(gather_parquet):
            # fix: previously this condition was only printed, and
            # polars.read_parquet then failed with a confusing error;
            # fail loudly with the real problem instead.
            raise FileNotFoundError(
                f"gather parquet output file does not exist!? '{gather_parquet}'")

        df = polars.read_parquet(gather_parquet, columns=['match_name'])

        genome_idents = [ name.split()[0]
                          for name in df.select('match_name')['match_name'] ]

        return genome_idents

    def __call__(self, w):
        # get 'sample' from wildcards?
        if self.samples is None:
            return self.do_sample(w)
        else:
            assert not hasattr(w, 'sample'), "if 'samples' provided to constructor, cannot also be in rule inputs"

            # expand the pattern once per sample and concatenate the results.
            ret = []
            for sample in self.samples:
                d = dict(sample=sample)
                w = snakemake.io.Wildcards(fromdict=d)

                x = self.do_sample(w)
                ret.extend(x)

            return ret

    def do_sample(self, w):
        # wait for the results of 'gather_reads_wc'; this will trigger
        # exception until that rule has been run.
        checkpoints.gather_reads_wc.get(**w)

        # parse hitlist_genomes,
        genome_idents = self.get_genome_idents(w.sample)

        p = expand(self.pattern, ident=genome_idents, **w)

        return p


class ListGatherGenomes(Checkpoint_GatherResults):
    """Provide list of the source genome files for either local or
    genbank databases, depending on the source of the genome info.

    This is used to get the filenames for the various genome files; the
    assumption is that snakemake already knows how to generate those :).

    Note: a key thing here is that the filenames themselves are correct,
    so we are not renaming any files, merely copying them with their
    existing names into the output subdirectory 'genomes/'.
    """
    def __init__(self, samples=None):
        self.samples = samples

    def _load_local_database_info(self):
        """Load 'local_databases_info' CSVs into a dict keyed by ident.

        Each row gains an 'info_filename' entry: '<ident>.info.csv' next
        to the (normalized) genome file.
        """
        local_info = {}
        for filename in local_databases_info:
            for row in load_csv(filename):
                ident = row['ident']

                # fix: removed a dead duplicate of the dirname computation
                # that ran before normpath and was immediately overwritten.
                row['genome_filename'] = os.path.normpath(row['genome_filename'])

                genome_dir = os.path.dirname(row['genome_filename'])
                info_filename = os.path.join(genome_dir, f'{ident}.info.csv')
                row['info_filename'] = os.path.normpath(info_filename)

                local_info[ident] = row

        if local_info:
            # report to stderr, consistent with the rest of this file.
            print(f"Loaded info on {len(local_info)} local genomes.",
                  file=sys.stderr)
        return local_info

    def do_sample(self, w):
        # wait for the results of 'gather_reads_wc'; this will trigger
        # exception until that rule has been run.
        checkpoints.gather_reads_wc.get(**w)

        sample = w.sample

        local_info = self._load_local_database_info()
        genome_filenames = []

        gather_parquet = f'{outdir}/gather/{sample}.gather.parquet'
        if not os.path.exists(gather_parquet):
            # fail loudly; polars.read_parquet would otherwise raise a
            # confusing error for the same underlying problem.
            raise FileNotFoundError(
                f"gather parquet output file does not exist!? '{gather_parquet}'")

        df = polars.read_parquet(gather_parquet)
        for name in df.select('match_name')['match_name']:
            ident = name.split()[0]

            # if in local information, use that as genome source.
            if ident in local_info:
                info = local_info[ident]
                genome_filenames.append(info['genome_filename'])
                genome_filenames.append(info['info_filename'])

            # genbank: point at genbank_genomes
            else:
                genome_filenames.append(f'{GENBANK_CACHE}/{ident}_genomic.fna.gz')
                genome_filenames.append(f'{GENBANK_CACHE}/{ident}.info.csv')

        return genome_filenames

class Checkpoint_GenomeFiles(Checkpoint_GatherResults):
    """Like Checkpoint_GatherResults, but waits until this sample's genomes
    have been copied into the output directory before expanding the pattern.
    """
    def do_sample(self, w):
        # wait for the copy checkpoint; raises until that rule has run.
        checkpoints.copy_sample_genomes_to_output_wc.get(**w)

        # parse hitlist_genomes,
        genome_idents = self.get_genome_idents(w.sample)

        # if 'ident' is already fixed by the calling rule's wildcards, fill
        # in the pattern for just that ident; otherwise expand over all.
        if 'ident' in dict(w):
            p = expand(self.pattern, **w)
        else:
            p = expand(self.pattern, ident=genome_idents, **w)

        return p


# attach taxonomic lineages to gather results for all samples.
@toplevel
rule gather_to_tax:
    input:
        expand(f'{outdir}/gather/{{sample}}.gather.with-lineages.csv.gz',
               sample=SAMPLES)
# build per-sample gather HTML reports.
@toplevel
rule summarize_gather:
    input:
        expand(f'{outdir}/reports/report-gather-{{sample}}.html',
               sample=SAMPLES)

# build per-sample taxonomy HTML reports (plus lineage CSVs).
@toplevel
rule summarize_tax:
    input:
        expand(f'{outdir}/reports/report-taxonomy-{{sample}}.html',
               sample=SAMPLES),
        expand(f'{outdir}/gather/{{sample}}.gather.with-lineages.csv.gz',
               sample=SAMPLES)

# combine per-genome info CSVs for all samples.
@toplevel
rule combine_genome_info:
    input:
        expand(f"{outdir}/gather/{{sample}}.genomes.info.csv",
               sample=SAMPLES)

# download all matched genomes from GenBank into the cache.
@toplevel
rule download_genbank_genomes:
    input:
        Checkpoint_GatherResults(f"{GENBANK_CACHE}/{{ident}}_genomic.fna.gz",
                                 samples=SAMPLES)

# copy/retrieve all matched genomes into the output 'genomes/' dir.
@toplevel
rule retrieve_genomes:
    input:
        expand(f"{outdir}/genomes/.genomes.{{sample}}", sample=SAMPLES)

# map reads to matched genomes and summarize both full and leftover mappings.
@toplevel
rule map_reads:
    input:
        expand(f"{outdir}/mapping/{{sample}}.summary.csv", sample=SAMPLES),
        expand(f"{outdir}/leftover/{{sample}}.summary.csv", sample=SAMPLES)

# build consensus genomes from the mapped and leftover BAMs.
@toplevel
rule build_consensus:
    input:
        Checkpoint_GatherResults(outdir + f"/mapping/{{sample}}.x.{{ident}}.consensus.fa.gz"),
        Checkpoint_GatherResults(outdir + f"/leftover/{{sample}}.x.{{ident}}.consensus.fa.gz"),

# build mapping + gather HTML reports.
@toplevel
rule summarize_mapping:
    input:
        expand(f'{outdir}/reports/report-mapping-{{sample}}.html',
               sample=SAMPLES),
        expand(f'{outdir}/reports/report-gather-{{sample}}.html',
               sample=SAMPLES)

# build all reports + summaries; the main "do everything" target.
@toplevel
rule summarize:
    input:
        expand(f'{outdir}/reports/report-mapping-{{sample}}.html',
               sample=SAMPLES),
        expand(f'{outdir}/reports/report-gather-{{sample}}.html',
               sample=SAMPLES),
        expand(f'{outdir}/reports/report-taxonomy-{{sample}}.html',
               sample=SAMPLES),
        expand(f'{outdir}/gather/{{sample}}.gather.with-lineages.csv.gz',
               sample=SAMPLES),
        expand(outdir + '/{sample}.info.yaml', sample=SAMPLES)

# generate spacegraphcats config files for all samples.
@toplevel
rule make_sgc_conf:
    input:
        expand(f"{outdir}/sgc/{{sample}}.conf", sample=SAMPLES)

# print out the configuration
@toplevel
rule showconf:
    run:
        import yaml
        print('# full aggregated configuration:')
        print(yaml.dump(config).strip())
        print('# END')

# check config files only
# (config validation happens at parse time above, so 'run: pass' suffices.)
@toplevel
rule check:
    run:
        pass

# zip up the key result files (summaries, gather output, reports).
@toplevel
rule zip:
    shell: """
        ZIPFILE=$(basename "{outdir}").zip
        rm -f $ZIPFILE
        zip -r $ZIPFILE {outdir}/leftover/*.summary.csv \
                {outdir}/mapping/*.summary.csv {outdir}/*.yaml \
                {outdir}/gather/*.csv.gz {outdir}/gather/*.out \
                {outdir}/gather/*.genomes.info.csv {outdir}/reports/
        echo "Created $ZIPFILE"
    """


# download SRA IDs with prefetch
# @CTB: make it possible to feed in sralite files directly in the config?
rule download_sra_prefetch_wc:
    output:
        sralite = temp(outdir + f"/sra/{{sample}}.sralite"),
    params:
        # if 'prevent_sra_download' is set, fail with a helpful message
        # instead of hitting the network.
        do_not_run_me = "true" if config.get("prevent_sra_download", False) else "false",
    # NOTE(review): this benchmark path is shared with download_sra_extract_wc
    # and download_sra_process_wc below; each rule overwrites the previous
    # timing file -- confirm this is intended.
    benchmark:
        outdir + "/benchmarks/download_sra_{sample}.txt"
    conda: "env/sra.yml"
    resources:
        mem_mb=40000,
    shell: '''
        if {params.do_not_run_me}; then
            echo "** genome-grist is trying to download from SRA for sample {wildcards.sample},"
            echo "** but 'prevent_sra_download' is set to true in config."
            echo "** Does '{outdir}/trim/{wildcards.sample}.trim.fq.gz' exist?"
            exit -1
        fi

        echo 'configuring SRA toolkit to use sra-lite format'
        vdb-config -Q yes

        echo 'running sra prefetch for {wildcards.sample}'
        prefetch {wildcards.sample} -o {output.sralite} --eliminate-quals
        '''

# extract reads from SRA prefetch file
# always produces r1, r2, and unpaired files (empty ones are created as
# needed) so downstream rules have a uniform set of inputs.
rule download_sra_extract_wc:
    input:
        sralite = outdir + f"/sra/{{sample}}.sralite",
    output:
        temp_r1 =  temp(outdir + f"/sra/{{sample}}_1.fastq"),
        temp_r2 =  temp(outdir + f"/sra/{{sample}}_2.fastq"),
        temp_unp = temp(outdir + f"/sra/{{sample}}.fastq"),
    params:
        sra_dir = outdir + f"/sra"
    benchmark:
        outdir + "/benchmarks/download_sra_{sample}.txt"
    threads: 6
    conda: "env/sra.yml"
    resources:
        mem_mb=40000,
    shell: '''
        echo running fasterq-dump for {wildcards.sample}

        fasterq-dump {input.sralite} -e {threads} -p --split-files \
           -O {params.sra_dir}

        # make unpaired file if needed
        if [ -f {output.temp_r1} -a -f {output.temp_r2} -a \\! -f {output.temp_unp} ];
          then
            echo "no unpaired; creating empty unpaired file {output.temp_unp} for simplicity"
            touch {output.temp_unp}
          # make r1, r2 files if needed
        elif [ -f {output.temp_unp} -a \\! -f {output.temp_r1} -a \\! -f {output.temp_r2} ];
          then
            echo "unpaired file found; creating empty r1 ({output.temp_r1}) and r2 ({output.temp_r2}) files for simplicity"
            touch {output.temp_r1}
            touch {output.temp_r2}
        fi

        echo "finished extracting raw reads from SRAlite file."
        '''

# process output of SRA prefetch/dump
# normalizes read names ('.1'/'.2' suffixes -> '/1'/'/2'), gzips, and
# moves the files to their protected permanent location under raw/.
rule download_sra_process_wc:
    input:
        temp_r1 =  outdir + f"/sra/{{sample}}_1.fastq",
        temp_r2 =  outdir + f"/sra/{{sample}}_2.fastq",
        temp_unp = outdir + f"/sra/{{sample}}.fastq",
    output:
        r1  = protected(outdir + "/raw/{sample}_1.fastq.gz"),
        r2  = protected(outdir + "/raw/{sample}_2.fastq.gz"),
        unp = protected(outdir + "/raw/{sample}_unpaired.fastq.gz"),
    benchmark:
        outdir + "/benchmarks/download_sra_{sample}.txt"
    threads: 6
    conda: "env/sra.yml"
    shell: '''
        # now process the files and move to a permanent location
        echo processing R1...
        seqtk seq -C {input.temp_r1} | \
            perl -ne 's/\\.([12])$/\\/$1/; print $_' | \
            gzip -c > {output.r1} &

        echo processing R2...
        seqtk seq -C {input.temp_r2} | \
            perl -ne 's/\\.([12])$/\\/$1/; print $_' | \
            gzip -c > {output.r2} &

        echo processing unpaired...
        seqtk seq -C {input.temp_unp} | \
            perl -ne 's/\\.([12])$/\\/$1/; print $_' | \
            gzip -c > {output.unp} &
        wait
        echo "finished extracting raw reads from SRAlite file."
        '''

# compute sourmash signature from raw reads
rule smash_raw_reads_wc:
    input:
        r1 = ancient(outdir + "/raw/{sample}_1.fastq.gz"),
        r2 = ancient(outdir + "/raw/{sample}_2.fastq.gz"),
        #unp = ancient(outdir + "/raw/{sample}_unpaired.fastq.gz"),
    output:
        r1 = temp(outdir + "/raw/{sample}.raw.r1.sig.zip"),
        r2 = temp(outdir + "/raw/{sample}.raw.r2.sig.zip"),
        #unp = temp(outdir + "/raw/{sample}.raw.unp.sig.gz"),
        merged = outdir + "/raw/{sample}.raw.sig.zip",
    conda: "env/sourmash.yml"
    params:
        param_str = make_param_str(ksizes=[SOURMASH_DB_KSIZE],
                                   scaled=SOURMASH_COMPUTE_SCALED),
        #action = "translate" if SOURMASH_COMPUTE_TYPE == "protein" else "dna"
    benchmark:
        outdir + "/benchmarks/sketch_raw_{sample}.txt"
    shell: """
        sourmash scripts singlesketch --input-moltype dna \
            -p {params.param_str} {input.r1} -o {output.r1}
        sourmash scripts singlesketch --input-moltype dna \
            -p {params.param_str} {input.r2} -o {output.r2}
        sourmash sig merge {output.r1} {output.r2} \
           -o {output.merged} --name 'rawreads:{wildcards.sample}'
    """

# adapter trimming
# fastp on the paired reads; output is a single interleaved gzipped file.
rule trim_adapters_wc:
    input:
        r1 = ancient(outdir + "/raw/{sample}_1.fastq.gz"),
        r2 = ancient(outdir + "/raw/{sample}_2.fastq.gz"),
    output:
        interleaved = protected(outdir + '/trim/{sample}.trim.fq.gz'),
        json=outdir + "/trim/{sample}.trim.json",
        html=outdir + "/trim/{sample}.trim.html",
    conda: 'env/trim.yml'
    benchmark:
        outdir + "/benchmarks/trim_{sample}.txt"
    threads: 4
    resources:
        mem_mb=5000,
        runtime_min=600,
    shadow: "shallow"
    shell: """
        fastp --in1 {input.r1} --in2 {input.r2} \
             --detect_adapter_for_pe  --qualified_quality_phred 4 \
             --length_required 25 --correction --thread {threads} \
             --json {output.json} --html {output.html} \
             --low_complexity_filter --stdout | gzip -9 > {output.interleaved}
    """

# adapter trimming for the singleton reads
rule trim_unpaired_adapters_wc:
    input:
        unp = ancient(outdir + "/raw/{sample}_unpaired.fastq.gz"),
    output:
        unp = protected(outdir + '/trim/{sample}_unpaired.trim.fq.gz'),
        json = protected(outdir + '/trim/{sample}_unpaired.trim.json'),
        html = protected(outdir + '/trim/{sample}_unpaired.trim.html'),
    threads: 4
    resources:
        mem_mb=5000,
        runtime_min=600,
    shadow: "shallow"
    conda: 'env/trim.yml'
    shell: """
        fastp --in1 {input.unp} --out1 {output.unp} \
            --detect_adapter_for_se  --qualified_quality_phred 4 \
            --low_complexity_filter --thread {threads} \
            --length_required 25 --correction \
            --json {output.json} --html {output.html}
    """

# k-mer abundance trimming - optional
# NOTE(review): params.ksize is SOURMASH_DB_KSIZE, which defaults to a
# list; with more than one ksize configured, '-k {params.ksize}' would
# render multiple space-separated values -- confirm single-ksize use.
rule kmer_trim_reads_wc:
    input: 
        interleaved = ancient(outdir + '/trim/{sample}.trim.fq.gz'),
    output:
        protected(outdir + "/abundtrim/{sample}.abundtrim.fq.gz")
    conda: 'env/abundtrim.yml'
    benchmark:
        outdir + "/benchmarks/kmertrim_{sample}.txt"
    resources:
        mem_mb = int(ABUNDTRIM_MEMORY / 1e6),
    params:
        mem = ABUNDTRIM_MEMORY,
        ksize = SOURMASH_DB_KSIZE,
    shell: """
        trim-low-abund.py -C 3 -Z 18 -M {params.mem} -k {params.ksize} -V \
            {input.interleaved} -o {output} --gzip
    """

# count k-mers
rule estimate_distinct_kmers_wc:
    message: """
        Count distinct k-mers for {wildcards.sample} using 'unique-kmers.py' from the khmer package.
    """
    conda: 'env/abundtrim.yml'  # needs khmer package
    input:
        outdir + "/trim/{sample}.trim.fq.gz"
    output:
        report = outdir + "/trim/{sample}.trim.fq.gz.kmer-report.txt",
    params:
        ksize = SOURMASH_DB_KSIZE,
    benchmark:
        outdir + "/benchmarks/countkmers_{sample}.txt"
    shell: """
        unique-kmers.py {input} -k {params.ksize} -R {output.report}
    """

# count reads and bases
# produces a tiny two-line CSV: header + 'n_reads,n_bases'.
rule count_trimmed_reads_wc:
    message: """
        Count reads & bp in trimmed file for {wildcards.sample}.
    """
    input:
        outdir + "/trim/{sample}.trim.fq.gz"
    output:
        report = outdir + "/trim/{sample}.trim.fq.gz.reads-report.txt",
    # from Matt Bashton, in https://bioinformatics.stackexchange.com/questions/935/fast-way-to-count-number-of-reads-and-number-of-bases-in-a-fastq-file
    shell: """
        gzip -dc {input} |
             awk 'NR%4==2{{c++; l+=length($0)}}
                  END{{
                        print "n_reads,n_bases"
                        print c","l
                      }}' > {output.report}
    """

# map trimmed reads and produce a bam
# keeps mapped reads only (-F 4) and sorts the BAM.
rule minimap_wc:
    input:
        dep = f"{outdir}/genomes/.genomes.{{sample}}",
        query = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz"),
        metagenome = outdir + "/trim/{sample}.trim.fq.gz",
    output:
        bam = outdir + "/mapping/{sample}.x.{ident}.bam",
    conda: "env/minimap2.yml"
    threads: 4
    benchmark:
        outdir + "/benchmarks/minimap_{sample}_{ident}.txt"
    shell: """
        minimap2 -ax sr -t {threads} {input.query} {input.metagenome} | \
            samtools view -b -F 4 - | samtools sort - > {output.bam}
    """

# extract FASTQ from BAM
rule bam_to_fastq_wc:
    input:
        bam = outdir + "/mapping/{bam}.bam",
    output:
        mapped = outdir + "/mapping/{bam}.mapped.fq.gz",
    conda: "env/minimap2.yml"
    shell: """
        samtools bam2fq {input.bam} | gzip > {output.mapped}
    """

# get per-base depth information from BAM
# '-aa' reports all positions, including zero-depth ones.
rule bam_to_depth_wc:
    input:
        bam = outdir + "/{dir}/{bam}.bam",
    output:
        depth = outdir + "/{dir}/{bam}.depth.txt",
    conda: "env/minimap2.yml"
    shell: """
        samtools depth -aa {input.bam} > {output.depth}
    """

# wild card rule for getting _covered_ regions from BAM
# regions with depth >= 1, merged when within 5 bp of each other.
rule bam_covered_regions_wc:
    input:
        bam = outdir + "/{dir}/{bam}.bam",
    output:
        regions = outdir + "/{dir}/{bam}.regions.bed",
    conda: "env/covtobed.yml"
    shell: """
        covtobed {input.bam} -l 100 -m 1 | \
            bedtools merge -d 5 -c 4 -o mean > {output.regions}
    """

# calculating SNPs/etc.
# decompresses the reference to a temp file for bcftools, then calls
# variants and indexes the resulting VCF.
rule samtools_mpileup_wc:
    input:
        query = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz"),
        bam = outdir + "/{dir}/{sample}.x.{ident}.bam",
    output:
        bcf = outdir + "/{dir}/{sample}.x.{ident}.bcf",
        vcf = outdir + "/{dir}/{sample}.x.{ident}.vcf.gz",
        vcfi = outdir + "/{dir}/{sample}.x.{ident}.vcf.gz.csi",
    conda: "env/bcftools.yml"
    shell: """
        genomefile=$(mktemp -t grist.genome.XXXXXXX)
        gunzip -c {input.query} > $genomefile
        bcftools mpileup -Ou -f $genomefile {input.bam} | bcftools call -mv -Ob -o {output.bcf}
        rm $genomefile
        bcftools view {output.bcf} | bgzip > {output.vcf}
        bcftools index {output.vcf}
    """

# calculating mapped reads
# -F 260 excludes unmapped reads and secondary alignments.
rule samtools_count_wc:
    input:
        query = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz"),
        bam = outdir + "/{dir}/{sample}.x.{ident}.bam",
    output:
        mapcount = outdir + "/{dir}/{sample}.x.{ident}.count_mapped_reads.txt",
    conda: "env/bcftools.yml"
    shell: """
        samtools view -c -F 260 {input.bam} > {output.mapcount}
    """

# build new consensus
# applies called variants to the reference, masking uncovered regions.
rule build_new_consensus_wc:
    input:
        vcf = outdir + "/{dir}/{sample}.x.{ident}.vcf.gz",
        query = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz"),
        regions = outdir + "/{dir}/{sample}.x.{ident}.regions.bed",
    output:
        mask = outdir + "/{dir}/{sample}.x.{ident}.mask.bed",
        genomefile = outdir + "/{dir}/{sample}.x.{ident}.fna.gz.sizes",
        consensus = outdir + "/{dir}/{sample}.x.{ident}.consensus.fa.gz",
    conda: "env/bcftools.yml"
    shell: """
        genomefile=$(mktemp -t grist.genome.XXXXXXX)
        gunzip -c {input.query} > $genomefile
        samtools faidx $genomefile
        cut -f1,2 ${{genomefile}}.fai > {output.genomefile}
        bedtools complement -i {input.regions} -g {output.genomefile} > {output.mask}
        bcftools consensus -f $genomefile {input.vcf} -m {output.mask} | \
            gzip > {output.consensus}
        rm $genomefile
    """

# summarize depth into a CSV
rule summarize_samtools_depth_wc:
    input:
        depth = Checkpoint_GatherResults(outdir + f"/{{dir}}/{{sample}}.x.{{ident}}.depth.txt"),
        vcf = Checkpoint_GatherResults(outdir + f"/{{dir}}/{{sample}}.x.{{ident}}.vcf.gz"),
        mapcount = Checkpoint_GatherResults(outdir + f"/{{dir}}/{{sample}}.x.{{ident}}.count_mapped_reads.txt")
    output:
        csv = f"{outdir}/{{dir}}/{{sample}}.summary.csv"
    shell: """
        python -m genome_grist.summarize_mapping {wildcards.sample} \
             {input.depth} -o {output.csv}
    """

# compute sourmash signature from trim reads
rule smash_trim_wc:
    input:
        metagenome = ancient(outdir + "/trim/{sample}.trim.fq.gz"),
    output:
        sig = outdir + "/sigs/{sample}.trim.sig.zip",
    conda: "env/sourmash.yml"
    params:
        param_str = make_param_str(ksizes=SOURMASH_COMPUTE_KSIZES,
                                   scaled=SOURMASH_COMPUTE_SCALED),
    threads: 1
    benchmark:
        outdir + "/benchmarks/sketch_{sample}.txt"
    shell: """
        sourmash scripts singlesketch -p {params.param_str} \
           -o {output.sig} --name {wildcards.sample} {input.metagenome} \
           --input-moltype={SOURMASH_COMPUTE_TYPE}
    """

# configure ipython kernel for papermill
# installs a 'genome_grist' jupyter kernel used by the notebook rules below.
rule set_kernel:
    input:
        srcdir('env/papermill.yml')
    output:
        touch(f"{outdir}/.kernel.set")
    conda: 'env/papermill.yml'
    shell: """
        python -m ipykernel install --user --name genome_grist
    """


# papermill -> gather reporting notebook + html
# CTB: it's not clear why 'ancient' is needed here, but it prevents this rule
# from being run over and over!
rule make_gather_notebook_wc:
    input:
        nb = srcdir('notebooks/report-gather.ipynb'),
        gather_csv = f'{outdir}/gather/{{sample}}.gather.csv.gz',
        genomes_info_csv = ancient(f"{outdir}/gather/{{sample}}.genomes.info.csv"),
        kernel_set = rules.set_kernel.output,
    output:
        nb = outdir + f'/reports/report-gather-{{sample}}.ipynb',
        html = outdir + f'/reports/report-gather-{{sample}}.html',
    params:
        cwd = outdir + '/reports/',
        outdir = outdir,
    conda: 'env/papermill.yml'
    shell: """
        papermill {input.nb} {output.nb} -k genome_grist \
              -p sample_id {wildcards.sample:q} -p render '' -p outdir {outdir:q}\
              --cwd {params.cwd}
        python -m nbconvert {output.nb} --to html --stdout --no-input \
             --ExecutePreprocessor.kernel_name=genome_grist > {output.html}
    """

# papermill -> taxonomy reporting notebook + html
# CTB: it's not clear why 'ancient' is needed here, but it prevents this rule
# from being run over and over!
rule make_taxonomy_notebook_wc:
    input:
        nb = srcdir('notebooks/report-taxonomy.ipynb'),
        taxcsv = ancient(f'{outdir}/gather/{{sample}}.gather.with-lineages.csv.gz'),
        kernel_set = f'{outdir}/.kernel.set'
    output:
        nb = outdir + f'/reports/report-taxonomy-{{sample}}.ipynb',
        html = outdir + f'/reports/report-taxonomy-{{sample}}.html',
    params:
        cwd = outdir + '/reports/',
        outdir = outdir,
    conda: 'env/papermill.yml'
    shell: """
        ls -la {input.kernel_set}
        papermill {input.nb} {output.nb} -k genome_grist \
              -p sample_id {wildcards.sample:q} -p render '' -p outdir {outdir:q}\
              --cwd {params.cwd}
        python -m nbconvert {output.nb} --to html --stdout --no-input \
             --ExecutePreprocessor.kernel_name=genome_grist > {output.html}
    """

# papermill -> reporting notebook + html
# CTB: it's not clear why 'ancient' is needed here, but it prevents this rule
# from being run over and over!
rule make_mapping_notebook_wc:
    input:
        nb = srcdir('notebooks/report-mapping.ipynb'),
        all_csv = ancient(f"{outdir}/mapping/{{sample}}.summary.csv"),
        depth_csv = ancient(f"{outdir}/leftover/{{sample}}.summary.csv"),
        gather_csv = f'{outdir}/gather/{{sample}}.gather.csv.gz',
        genomes_info_csv = ancient(f"{outdir}/gather/{{sample}}.genomes.info.csv"),
        kernel_set = rules.set_kernel.output,
    output:
        nb = outdir + f'/reports/report-mapping-{{sample}}.ipynb',
        html = outdir + f'/reports/report-mapping-{{sample}}.html',
    params:
        cwd = outdir + '/reports/',
        outdir = outdir,
    conda: 'env/papermill.yml'
    shell: """
        papermill {input.nb} {output.nb} -k genome_grist \
              -p sample_id {wildcards.sample:q} -p render '' \
              -p outdir {outdir:q} --cwd {params.cwd}
        python -m nbconvert {output.nb} --to html --stdout --no-input \
             --ExecutePreprocessor.kernel_name=genome_grist > {output.html}
    """

# convert mapped reads to leftover reads
#
# runs genome_grist.subtract_gather for this sample; presumably it writes
# per-genome '.leftover.fq.gz' files under {outdir}/mapping/ -- TODO confirm
# (only a flag file is declared as output here).
rule extract_leftover_reads_wc:
    input:
        # gather results (gzipped CSV) consumed by subtract_gather.
        csv = f'{outdir}/gather/{{sample}}.gather.csv.gz',
        # not passed to the command; listed only to force the CSV->parquet
        # conversion to run for this sample first.
        pq = f'{outdir}/gather/{{sample}}.gather.parquet',
        # per-genome mapped reads, expanded over all gather matches.
        mapped = Checkpoint_GatherResults(f"{outdir}/mapping/{{sample}}.x.{{ident}}.mapped.fq.gz"),
    output:
        # flag file marking completion for this sample.
        touch(f"{outdir}/leftover/.leftover.{{sample}}")
    conda: "env/sourmash.yml"
    params:
        outdir = outdir,
    benchmark:
        outdir + "/benchmarks/subtract_gather_{sample}.txt",
    shell: """
        python -Werror -Wignore::DeprecationWarning -m genome_grist.subtract_gather \
            {wildcards.sample:q} {input.csv} --outdir={params.outdir:q}
    """

# rule for mapping leftover reads to genomes -> BAM
#
# short-read mapping with minimap2 ('-ax sr'); unmapped reads are dropped
# (samtools view -F 4) and the result is sorted into a BAM.  note: the
# leftover fastq is referenced directly by path in the shell command rather
# than being a declared input -- only the per-sample completion flag is.
rule map_leftover_reads_wc:
    input:
        all_csv = f"{outdir}/mapping/{{sample}}.summary.csv",
        query = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz"),
        # completion flag produced by extract_leftover_reads_wc.
        leftover_reads_flag = f"{outdir}/leftover/.leftover.{{sample}}",
    output:
        bam=outdir + "/leftover/{sample}.x.{ident}.bam",
    conda: "env/minimap2.yml"
    benchmark:
        outdir + "/benchmarks/leftovermap_{sample}_{ident}.txt"
    threads: 4
    shell: """
        minimap2 -ax sr -t {threads} {input.query} \
     {outdir}/mapping/{wildcards.sample}.x.{wildcards.ident}.leftover.fq.gz | \
            samtools view -b -F 4 - | samtools sort - > {output.bam}
    """

# build a combined manifest CSV of sourmash databases
#
# produces one manifest covering all configured databases, either via
# 'sig check' (when a picklist restricts which signatures to keep) or via
# 'sig collect' (no picklist).
rule sourmash_collect_wc:
    input:
        db = SOURMASH_DB_LIST,
    output:
        mf_csv = outdir + "/gather/db.collect.mf.csv",
    conda: "env/sourmash.yml"
    params:
        ksize = SOURMASH_DB_KSIZE,
        # e.g. '--dna', derived from the configured moltype string.
        moltype = f"--{SOURMASH_COMPUTE_TYPE.lower()}",
        # NOTE(review): threshold_bp is declared here but not used in the
        # shell command below -- confirm whether it should be passed.
        threshold_bp = SOURMASH_DATABASE_THRESHOLD_BP,
        picklist = f"--picklist {PICKLIST}" if PICKLIST else "",
        # shell-level boolean ('true'/'false' commands) selecting the branch.
        run_sig_check = "true" if PICKLIST else "false"
    # here, 'collect' does not take a picklist, but 'check' does. so
    # pick which one we run based on whether we have a picklist.
    shell: """
        if {params.run_sig_check}; then
           sourmash sig check -F csv -k {params.ksize} {params.moltype} \
              {params.picklist} {input.db} -m {output.mf_csv} --abspath
        else
           sourmash sig collect -F csv -k {params.ksize} {params.moltype} \
              {input.db} -o {output.mf_csv} --abspath
        fi
    """

# eliminate IGNORE_IDENTs in the input manifest
#
# filters the collected manifest down to db.mf.csv, removing any identifiers
# the user asked to ignore.
rule sourmash_ignore_idents_wc:
    input:
        mf_csv = outdir + "/gather/db.collect.mf.csv",
    output:
        mf_csv = outdir + "/gather/db.mf.csv",
    params:
        # a single space (not '') when there is nothing to ignore, so the
        # quoted '--idents {params.idents:q}' argument below still receives
        # a (harmless) value instead of collapsing to nothing.
        idents = ",".join(IGNORE_IDENTS) if IGNORE_IDENTS else ' '
    shell:
        """
        python -Werror -Wignore::DeprecationWarning -m genome_grist.remove_idents_from_manifest \
            --idents {params.idents:q} {input} -o {output}
    """

# run sourmash gather with known hashes on prefetched matches.
#
# runs the 'fastgather' plugin command, which writes an uncompressed gather
# CSV and prefetch CSV; the shell then (1) fails hard if no gather CSV was
# produced (no matches for this sample), (2) gzips both CSVs to match the
# declared outputs, and (3) extracts the matched signatures from the
# database into a zip, using the gather CSV as a '::prefetch' picklist.
rule sourmash_gather_wc:
    message: """
       Run gather for {wildcards.sample}
    """
    input:
        sig = outdir + "/sigs/{sample}.trim.sig.zip",
        db = outdir + "/gather/db.mf.csv",
    output:
        prefetch_csv = outdir + "/gather/{sample}.prefetch.csv.gz",
        gather_csv = outdir + "/gather/{sample}.gather.csv.gz",
        matches = outdir + "/gather/{sample}.matches.sig.zip",
    conda: "env/sourmash.yml"
    params:
        ksize = SOURMASH_DB_KSIZE,
        threshold_bp = SOURMASH_DATABASE_THRESHOLD_BP,
        # uncompressed paths written by fastgather, gzipped by the shell.
        prefetch_csv_ungz = outdir + "/gather/{sample}.prefetch.csv",
        gather_csv_ungz = outdir + "/gather/{sample}.gather.csv",
    benchmark:
        outdir + "/benchmarks/gather_{sample}.txt"
    resources:
        # PREFETCH_MEMORY is in bytes; snakemake wants MB.
        mem_mb=int(PREFETCH_MEMORY / 1e6)
    threads: 64
    shell: """
        sourmash scripts fastgather {input.sig} {input.db} -c {threads} \
          -k {params.ksize} -m {SOURMASH_COMPUTE_TYPE} \
          -o {params.gather_csv_ungz} \
          --output-prefetch {params.prefetch_csv_ungz} \
          --threshold-bp {params.threshold_bp}

        if [ ! -f {params.gather_csv_ungz} ]; then
             echo "** ERROR: gather didn't find anything for sample '{wildcards.sample}'. Failing."
             exit 1
        fi

        # gzip CSV files
        gzip {params.gather_csv_ungz} {params.prefetch_csv_ungz}

        # extract matches
        sourmash sig cat --picklist {output.gather_csv}::prefetch {input.db} \
          -o {output.matches}
    """

# convert csv.gz to parquet
#
# generic wildcard rule: any '<prefix>.csv.gz' can be converted to
# '<prefix>.parquet' (e.g. the {sample}.gather.csv.gz files that other
# rules request as .parquet).  the input/output patterns were previously
# hard-coded to a placeholder path that could never match the parquet
# files downstream rules require.
rule convert_csv_to_parquet:
    input:
        csv = "{prefix}.csv.gz",
    output:
        parquet = "{prefix}.parquet",
    run:
        # scan_csv builds a lazy query; sink_parquet streams the result to
        # disk without loading the whole CSV into memory.
        lazy_df = polars.scan_csv(input[0])
        lazy_df.sink_parquet(output[0], compression="zstd")

# report on known and unknown hashes
#
# summarizes how many of the query's hashes were matched ("known") vs
# unmatched ("unknown") by the gather results, writing a report text file.
rule report_query_known_unknown_wc:
    message: """
       Report on "known" and "unknown" hashes for {wildcards.sample}
    """
    input:
        query = outdir + "/sigs/{sample}.trim.sig.zip",
        # used as a '::prefetch' picklist by the command below.
        gather_csv = outdir + "/gather/{sample}.gather.csv.gz",
        # not passed to the command; listed only to force the CSV->parquet
        # conversion to run for this sample.
        gather_pq = outdir + "/gather/{sample}.gather.parquet",
        db = outdir + "/gather/db.mf.csv",
    output:
        report = outdir + "/gather/{sample}.prefetch.report.txt",
    conda: "env/sourmash.yml"
    params:
        ksize = SOURMASH_DB_KSIZE,
        moltype = SOURMASH_COMPUTE_TYPE,
    shell: """
         python -Werror -Wignore::DeprecationWarning -m genome_grist.summarize_gather_hashes \
          --query {input.query} --picklist {input.gather_csv}::prefetch --db {input.db} \
          -k {params.ksize} --moltype {params.moltype} --report {output.report}
    """

# run sourmash tax annotate on the gather results
#
# 'tax annotate' appears to take an output *directory* (-o, trailing slash)
# and name the with-lineages file itself; the final gzip then compresses it
# to match the declared .csv.gz output.  the uncompressed gather CSV from
# 'gunzip -k' is declared as an output so snakemake removes any stale copy
# before the rule runs (otherwise gunzip would refuse to overwrite it).
rule summarize_tax_wc:
    input:
        gather_csv = f"{outdir}/gather/{{sample}}.gather.csv.gz",
        tax_csv = TAXONOMY_DB,
    output:
        # gzipped with-lineages CSV (unnamed output), produced by the
        # final gzip in the shell command.
        f"{outdir}/gather/{{sample}}.gather.with-lineages.csv.gz",
        gather_csv_ungz = f"{outdir}/gather/{{sample}}.gather.csv",
    conda: "env/sourmash.yml"
    params:
        o_param = f"{outdir}/gather/"
    shell: """
        gunzip -k {input.gather_csv}
        sourmash tax annotate -g {output.gather_csv_ungz} -t {input.tax_csv} \
            -o {params.o_param} --fail-on-missing-taxonomy
        gzip {outdir}/gather/{wildcards.sample}.gather.with-lineages.csv
    """


# download genbank genome & make an info.csv file for entry.
rule download_genbank_genomes_wc:
    output:
        csvfile = f'{GENBANK_CACHE}/{{ident}}.info.csv',
        genome = f'{GENBANK_CACHE}/{{ident}}_genomic.fna.gz'
    retries: 3
    shell: """
        sourmash scripts get-genomes {wildcards.ident} --output-dir {GENBANK_CACHE}
    """

# copy genbank + local genomes for this sample into output dir
#
# a checkpoint (rather than a plain rule), presumably because the set of
# genome files it produces determines downstream DAG construction via the
# Checkpoint_* helpers -- see their definitions earlier in the file.
checkpoint copy_sample_genomes_to_output_wc:
    input:
        # note: a key thing here is that the filenames themselves are correct,
        # so we are simply copying from (multiple) directories into one.
        # this is why the genome filenames need to be {acc}_genomic.fna.gz.
        genomes = ListGatherGenomes(),
    output:
        # flag file marking that this sample's genomes have been copied.
        touch(f"{outdir}/genomes/.genomes.{{sample}}")
    shell: """
        mkdir -p {outdir}/genomes/
        cp {input} {outdir}/genomes/
    """

# combined info.csv per sample
#
# concatenates the per-genome info.csv files (keeping only the ident and
# display_name fields) into one CSV for the sample.
rule make_combined_info_csv_wc:
    input:
        # one info.csv per genome matched for this sample.
        csvs = Checkpoint_GenomeFiles(f'{outdir}/genomes/{{ident}}.info.csv'),
    output:
        genomes_info_csv = f"{outdir}/gather/{{sample}}.genomes.info.csv",
    shell: """
        python -Werror -Wignore::DeprecationWarning -m genome_grist.combine_csvs \
             --fields ident,display_name \
             {input.csvs} > {output}
    """

# summarize_reads_info
#
# gathers per-sample statistics (distinct k-mer count, read/base counts,
# known/unknown hash counts) from three report files and writes them as a
# single YAML dict to {outdir}/{sample}.info.yaml.
rule summarize_reads_info_wc:
    input:
        kmers = outdir + "/trim/{sample}.trim.fq.gz.kmer-report.txt",
        reads = outdir + "/trim/{sample}.trim.fq.gz.reads-report.txt",
        gather_report = outdir + "/gather/{sample}.prefetch.report.txt",
    output:
        outdir + '/{sample}.info.yaml',
    run:
        import yaml

        d = {}

        # the kmer report is free text; the count we want is the last
        # whitespace-separated token on its fourth line.
        with open(str(input.kmers), 'rt') as fp:
            kmer_line = fp.readlines()[3].strip()
            d['kmers'] = int(kmer_line.split()[-1])

        # reads report: single-row CSV with n_reads / n_bases columns.
        with open(str(input.reads), 'rt') as fp:
            row = next(csv.DictReader(fp))
            d['n_reads'] = int(row['n_reads'])
            d['n_bases'] = int(row['n_bases'])

        # prefetch report: single-row CSV with hash-count columns.
        with open(str(input.gather_report), 'rt') as fp:
            row = next(csv.DictReader(fp))
            d['total_hashes'] = int(row['total_hashes'])
            d['known_hashes'] = int(row['known_hashes'])
            d['unknown_hashes'] = int(row['unknown_hashes'])

        d['sample'] = wildcards.sample

        with open(str(output), 'wt') as fp:
            yaml.dump(d, fp)


# create a spacegraphcats config file
rule create_sgc_conf_wc:
    input:
        csv = outdir + "/gather/{sample}.gather.csv.gz",
        pq = outdir + "/gather/{sample}.gather.parquet",
        queries = Checkpoint_GenomeFiles(f"{outdir}/genomes/{{ident}}_genomic.fna.gz")
    output:
        conf = outdir + "/sgc/{sample}.conf"
    run:
        query_list = "\n- ".join(input.queries)
        with open(output.conf, 'wt') as fp:
           print(f"""\
catlas_base: {wildcards.sample}
input_sequences:
- {outdir}/trim/{wildcards.sample}.trim.fq.gz
ksize: 31
radius: 1
search:
- {query_list}
""", file=fp)
