from glob import glob

# k-mer sizes requested in `rule all` for the scaled_1000, smol_1000, exact
# and cmash targets.
ALL_K = [21, 31, 51]
# k-mer sizes requested in `rule all` for the mash_screen and cmash_paper targets.
MASH_K = [21, 31]
# Values substituted for {num} in the mash_screen, cmash and cmash_paper
# target names of `rule all`.
NUM = [1000, 100000]

# Default target: lists every result file the workflow is expected to
# produce for sample SRR606249.
rule all:
  input:
    # outputs/scaled_1000/: per-sample containment table, similarity plot
    # and gather results, one per k in ALL_K
    expand("outputs/scaled_1000/containments_SRR606249-k{k}.csv", k=ALL_K),
    expand("outputs/scaled_1000/similarity-k{k}.matrix.png", k=ALL_K),
    expand("outputs/scaled_1000/SRR606249-k{k}.csv", k=ALL_K),
    # outputs/smol_1000/: gather and search results
    expand("outputs/smol_1000/SRR606249-k{k}.csv", k=ALL_K),
    expand("outputs/smol_1000/search-SRR606249-k{k}.csv", k=ALL_K),
    # mash screen against the locally built DBs and against RefSeq
    expand("outputs/mash_screen/SRR606249-k{k}-s{num}-m{m}.tsv", k=MASH_K, num=NUM, m=(1,3)),
    expand("outputs/mash_screen/SRR606249-k{k}-refseq.tsv", k=MASH_K),
    # exact k-mer set containment
    expand("outputs/exact/SRR606249-k{k}.csv", k=ALL_K),
    # This path has no {k} placeholder, so it is listed once as a plain
    # string instead of being repeated by expand() for every k.
    "outputs/cmash/SRR606249.csv",
    expand("outputs/cmash/SRR606249-k{k}-n{num}.csv", k=ALL_K, num=NUM),
    expand("outputs/cmash_paper/SRR606249-k{k}-n{num}.csv", k=MASH_K, num=NUM),

### Download the podar metagenome

# Fetch one FASTQ file of run SRR606249 with wget; {i} is the part of the
# remote file name that follows "SRR606249_".
rule download_reads:
  output:
    "data/reads/SRR606249_{i}.fastq.gz"
  shell: """
    wget -qO {output[0]} \
      ftp://ftp.sra.ebi.ac.uk/vol1/fastq/SRR606/SRR606249/SRR606249_{wildcards.i}.fastq.gz
  """

### Download mash simulated data (from screen paper)
#   https://mash.readthedocs.io/en/latest/data.html#data-files
# Download art.fastq.gz from the mash screen data server into data/mash/.
rule download_simulated_reads:
  output:
    "data/mash/art.fastq.gz"
  shell:
    "wget -qO {output[0]} https://obj.umiacs.umd.edu/mash/screen/art.fastq.gz"

# Download the two sra_meta tables (95idy and 80idy_3x) for one {mol} value
# from the mash screen data server.
rule download_tables:
  output:
    # the order matters: the shell command addresses these as
    # {output[0]} and {output[1]}
    "data/mash/sra_meta_{mol}_95idy.tsv.gz",
    "data/mash/sra_meta_{mol}_80idy_3x.tsv.gz"
  shell: """
    wget -qO {output[0]} https://obj.umiacs.umd.edu/mash/screen/tables/sra_meta_{wildcards.mol}_95idy.tsv.gz
    wget -qO {output[1]} https://obj.umiacs.umd.edu/mash/screen/tables/sra_meta_{wildcards.mol}_80idy_3x.tsv.gz
  """

# Download fig5.tsv from the mash screen data server into data/mash/.
rule download_fig5_tsv:
  output:
    "data/mash/fig5.tsv"
  shell:
    "wget -qO {output} https://obj.umiacs.umd.edu/mash/screen/fig5/fig5.tsv"

### Download precomputed databases (might not need them)

# Download the gzipped RefSeq88n mash sketch and decompress it in place.
rule download_mash_screen_db:
  output: "data/mash/RefSeq88n.msh"
  # The archive is written to "{output}.gz" and gunzip is pointed at that
  # same path, so download and decompression cannot drift apart if the
  # output location is ever changed.
  shell: """
    wget -qO {output}.gz https://obj.umiacs.umd.edu/mash/screen/RefSeq88n.msh.gz
    gunzip {output}.gz
  """

# Stream a tar.gz archive from box.com and unpack it into data/cmash.
# NOTE(review): the declared output is expected to be a member of that
# archive; this cannot be verified from the workflow itself.
rule download_cmash_db:
  output:
    "data/cmash/cmash_db_n1000_k60.h5"
  shell:
    "curl -L https://ucla.box.com/shared/static/27xulklfvo60g7heogvi2y43jti5tqxo.gz | tar xzf - -C data/cmash"

### Download reference genomes present in podar

# One download produces all of data/refs/0.fa .. data/refs/63.fa: the
# archive from osf.io is streamed through tar and unpacked into data/refs/.
rule download_podar:
  output:
    expand("data/refs/{i}.fa", i=range(64))
  shell:
    "curl -L https://osf.io/8uxj9/download | tar xzf - -C data/refs/"

# the mash screen paper found another 4 contaminant genomes
# Maps a reference index (a string, looked up with the {i} wildcard of
# download_mash_screen_extra_refs) to the URL of its gzipped FASTA file.
extra_refs = {
  # Propionibacterium acnes HL072PA1
  '64': 'https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/144/245/GCF_000144245.1_ASM14424v1/GCF_000144245.1_ASM14424v1_genomic.fna.gz',

  # Escherichia coli strain 2014C-3250
  '65': 'https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/001/645/245/GCF_001645245.1_ASM164524v1/GCF_001645245.1_ASM164524v1_genomic.fna.gz',

  # Proteiniclasticum ruminis DSM 24773
  '66': 'https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/701/905/GCF_000701905.1_ASM70190v1/GCF_000701905.1_ASM70190v1_genomic.fna.gz',

  # Streptococcus parasanguinis strain C1A
  '67': 'https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/724/645/GCF_000724645.1_ASM72464v1/GCF_000724645.1_ASM72464v1_genomic.fna.gz',
}

# Download and decompress one of the extra references listed in extra_refs.
# The wildcard constraint restricts {i} to 64..67, i.e. the keys of that dict.
rule download_mash_screen_extra_refs:
  output:
    "data/refs/{i,[6][4-7]}.fa"
  params:
    url = lambda wildcards: extra_refs[wildcards.i]
  shell:
    "curl -L {params.url} | gunzip -c > {output}"

### Compute exact k-mer sets

# Compile the `exact` tool from scripts/exact with cargo (release profile).
rule build_exact:
  conda: "envs/exact.yml"
  output:
    "scripts/exact/target/release/exact"
  shell: """
    cargo build --release --manifest-path scripts/exact/Cargo.toml
  """

# Run `exact generate` on one reference genome for k-mer size {ksize}.
rule kmer_sets_refs:
  input:
    # the compiled binary is an input so that it gets built first
    exact = "scripts/exact/target/release/exact",
    ref = "data/refs/{sample}.fa"
  output:
    "outputs/exact/refs/{sample}-k{ksize}.set"
  params:
    ksize = "{ksize}",
  shell: """
    {input.exact} generate -k {params.ksize} {input.ref} {output}
  """

# Run `exact generate` on the "_2" FASTQ file of a sample for k-mer size {ksize}.
rule kmer_sets_reads:
  input:
    # the compiled binary is an input so that it gets built first
    exact = "scripts/exact/target/release/exact",
    reads = "data/reads/{sample}_2.fastq.gz"
  output:
    "outputs/exact/reads/{sample}-k{ksize}.set"
  params:
    ksize = "{ksize}",
  shell: """
    {input.exact} generate -k {params.ksize} {input.reads} {output}
  """

### Compute smol sketches

# Compile the `smol` tool from scripts/smol with cargo (release profile).
rule build_smol:
  conda: "envs/exact.yml"
#  input: "scripts/smol/src/main.rs"
  output:
    "scripts/smol/target/release/smol"
  shell: """
    cargo build --release --manifest-path scripts/smol/Cargo.toml
  """

# Run `smol compute` on one reference genome, forwarding the {ksize} and
# {scaled} wildcards as -k and --scaled.
rule smol_refs:
  input:
    smol = "scripts/smol/target/release/smol",
    ref = "data/refs/{sample}.fa"
  output:
    "outputs/smol_{scaled}/refs/{sample}-k{ksize}.smol"
  params:
    scaled = "{scaled}",
    ksize = "{ksize}",
  shell: """
    {input.smol} compute -k {params.ksize} -o {output} --scaled {params.scaled} {input.ref}
  """

# Run `smol compute` on the "_2" FASTQ file of a sample, forwarding the
# {ksize} and {scaled} wildcards as -k and --scaled.
rule smol_reads:
  input:
    smol = "scripts/smol/target/release/smol",
    reads = "data/reads/{sample}_2.fastq.gz"
  output:
    "outputs/smol_{scaled}/reads/{sample}-k{ksize}.smol"
  params:
    scaled = "{scaled}",
    ksize = "{ksize}",
  shell: """
    {input.smol} compute -k {params.ksize} --scaled {params.scaled} -o {output} {input.reads}
  """

### Compute sourmash signatures for each reference genome and the reads

# Run `sourmash compute` on one reference genome. A single .sig file is
# written for k = 21, 31 and 51 together, which is why the output name
# carries no k-mer size.
rule compute_refs:
  input:
    "data/refs/{sample}.fa"
  output:
    "outputs/scaled_{scaled}/refs/{sample}.sig"
  conda: "envs/sourmash.yml"
  params:
    scaled = "{scaled}"
  shell: """
    sourmash compute -k 21,31,51 \
             --name-from-first \
             --scaled {params.scaled} \
             --track-abundance \
             -o {output} \
             {input}
  """

# Run `sourmash compute` on the "_2" FASTQ file of a sample, with the same
# options as compute_refs (k = 21, 31 and 51 in one .sig file).
rule compute_reads:
#  input: expand("data/reads/{{sample}}_{i}.fastq.gz", i=(1,2))
  input:
    "data/reads/{sample}_2.fastq.gz"
  output:
    "outputs/scaled_{scaled}/reads/{sample}.sig"
  conda: "envs/sourmash.yml"
  params:
    scaled = "{scaled}"
  shell: """
    sourmash compute -k 21,31,51 \
             --name-from-first \
             --scaled {params.scaled} \
             --track-abundance \
             -o {output} \
             {input}
  """

### exact containment

# Run `exact containment` with the reads k-mer set first, followed by the
# k-mer sets of references 0..67; the tool's stdout becomes the CSV.
rule kmer_sets_containment:
  input:
    exact = "scripts/exact/target/release/exact",
    mg = "outputs/exact/reads/{sample}-k{ksize}.set",
    refs = expand("outputs/exact/refs/{i}-k{{ksize}}.set", i=range(68))
  output:
    "outputs/exact/{sample}-k{ksize}.csv"
  shell: """
    {input.exact} containment {input.mg} {input.refs} > {output}
  """


### smol containment

# Run the `gather` subcommand of scripts/smol.py (the Python script, not the
# compiled binary used by smol_search) on the reads sketch against the
# sketches of references 0..67.
rule smol_gather:
  input:
    smol = "scripts/smol.py",
    mg = "outputs/smol_{scaled}/reads/{sample}-k{ksize}.smol",
    refs = expand("outputs/smol_{{scaled}}/refs/{i}-k{{ksize}}.smol", i=range(68))
  output:
    "outputs/smol_{scaled}/{sample}-k{ksize}.csv"
  conda: "envs/smol.yml"
  shell: """
    {input.smol} gather --threshold 0 -o {output} {input.mg} {input.refs}
  """

# Run the `search` subcommand of the compiled smol binary on the reads
# sketch against the sketches of references 0..67.
rule smol_search:
  input:
    smol = "scripts/smol/target/release/smol",
    mg = "outputs/smol_{scaled}/reads/{sample}-k{ksize}.smol",
    refs = expand("outputs/smol_{{scaled}}/refs/{i}-k{{ksize}}.smol", i=range(68))
  output:
    "outputs/smol_{scaled}/search-{sample}-k{ksize}.csv"
  shell: """
    {input.smol} search --threshold 0 -o {output} {input.mg} {input.refs}
  """

### Run gather

# Run `sourmash gather` for one k-mer size, with the reads signature as the
# query and the signatures of references 0..67 as the search targets.
rule gather_sigs:
  input:
    query = "outputs/scaled_{scaled}/reads/{sample}.sig",
    sigs = expand("outputs/scaled_{{scaled}}/refs/{i}.sig", i=range(0, 68))
  # raw string: keeps "\d" as a regex token instead of an invalid string escape
  output: r"outputs/scaled_{scaled,\d+}/{sample}-k{ksize}.csv"
  params:
    ksize = "{ksize}"
  conda: "envs/sourmash.yml"
  # The .sig files are computed with -k 21,31,51 (compute_refs and
  # compute_reads), so -k must be passed explicitly. Otherwise the {ksize}
  # in the output name would not select the sketch that is compared.
  # containment_sigs passes -k in the same way.
  shell: """
    sourmash gather -k {params.ksize} --threshold-bp 0 -o {output} {input.query} {input.sigs}
  """

# Run `sourmash search --containment` for one reference signature ({i})
# against the reads signature at k-mer size {ksize}.
rule containment_sigs:
  input:
    metagenome = "outputs/scaled_{scaled}/reads/{sample}.sig",
    sig = "outputs/scaled_{scaled}/refs/{i}.sig",
  output:
    "outputs/scaled_{scaled}/containments/{i}_{sample}-k{ksize}.csv"
  conda: "envs/sourmash.yml"
  params:
    ksize = "{ksize}"
  shell: """
    sourmash search --threshold 0 --containment -k {params.ksize} -o {output} {input.sig} {input.metagenome}
  """

# Merge the per-reference CSVs written by containment_sigs into a single
# "containment,filename" table for one sample and k-mer size.
rule containment_sample:
  input: expand("outputs/scaled_{{scaled}}/containments/{i}_{{sample}}-k{{ksize}}.csv", i=range(0, 68))
  # raw string: keeps "\d" as a regex token instead of an invalid string escape
  output: r"outputs/scaled_{scaled,\d+}/containments_{sample}-k{ksize}.csv"
  run:
    rows = ['containment,filename']
    for ref in input:
      with open(ref, 'r') as f:
        # keep only non-blank lines so the last entry is real content
        lines = [line for line in f if line.strip()]

      # Nothing to report when the file is empty (previously an IndexError)
      # or when its last line is still the header row, which contains the
      # "similarity" column name.
      if not lines or "similarity" in lines[-1]:
        continue

      # the first column of the last row is the reported value
      containment = lines[-1].split(',')[0]
      # per-reference files are named "{i}_{sample}-k{ksize}.csv";
      # the part before the first "_" is the reference index
      filename = os.path.basename(ref).split("_")[0]
      rows.append(f"{containment},{filename}.fa")

    with open(output[0], 'w') as o:
      o.write("\n".join(rows))

### Build DB and run mash screen

# Run `mash sketch` on one reference genome; the {ksize}, {num} and {cutoff}
# wildcards are forwarded as -k, -s and -m respectively.
rule mash_sketch:
  input:
    "data/refs/{sample}.fa"
  output:
    "outputs/mash_screen/refs/{sample}-k{ksize}-s{num}-m{cutoff}.msh"
  params:
    cutoff = "{cutoff}",
    num = "{num}",
    ksize = "{ksize}",
  conda: "envs/mash.yml"
  shell: """
    mash sketch -k {params.ksize} \
                -s {params.num} \
                -m {params.cutoff} \
                -o {output} \
                {input}
  """

# Combine the sketches of references 0..67 (same k/s/m settings) into one
# file with `mash paste`.
rule build_mash_db:
  input:
    expand("outputs/mash_screen/refs/{i}-k{{ksize}}-s{{num}}-m{{cutoff}}.msh", i=range(68))
  output:
    "outputs/mash_screen/db-k{ksize}-s{num}-m{cutoff}.msh"
  conda: "envs/mash.yml"
  shell: """
    mash paste {output} {input}
  """

# Run `mash screen` with the "_2" FASTQ file of a sample against the pasted
# database that was built with the same k/s/m settings.
rule mash_screen:
  input:
    db = "outputs/mash_screen/db-k{ksize}-s{num}-m{cutoff}.msh",
    query = "data/reads/{sample}_2.fastq.gz",
  output:
    "outputs/mash_screen/{sample}-k{ksize}-s{num}-m{cutoff}.tsv"
  threads: 24
  conda: "envs/mash.yml"
  shell: """
    mash screen -p {threads} {input.db} {input.query} > {output}
  """

# Run `mash screen` against the downloaded RefSeq88n sketch and store the
# result sorted with `sort -gr`.
# Note: {ksize} only appears in the output name; the command does not use it.
rule mash_screen_refseq:
  input:
    db = "data/mash/RefSeq88n.msh",
    query = "data/reads/{sample}_2.fastq.gz",
  output:
    "outputs/mash_screen/{sample}-k{ksize}-refseq.tsv"
  threads: 24
  conda: "envs/mash.yml"
  shell: """
    mash screen -p {threads} {input.db} {input.query} | sort -gr > {output}
  """

### Build DB and run cmash

# Build a CMash database with MakeStreamingDNADatabase.py at -k 51.
# The genome list handed to the tool comes from a glob over data/refs/*.fa,
# not from {input}; {input} only makes sure references 0..67 exist first.
# NOTE(review): output.tst is declared but never named on the command line;
# presumably the tool writes it next to the database. Confirm.
rule build_cmash_db:
  input:
    expand("data/refs/{i}.fa", i=range(68))
  output:
    db = "outputs/cmash/db.h5",
    tst = "outputs/cmash/db.tst",
  threads: 24
  conda: "envs/cmash.yml"
  shell: """
    MakeStreamingDNADatabase.py -t {threads} -k 51 <(ls -1 data/refs/*.fa) {output.db}
  """

# Query the k=51 CMash database with the "_2" FASTQ file of a sample, using
# the fixed range argument 21-51-10.
rule cmash:
  input:
    db = "outputs/cmash/db.h5",
    query = "data/reads/{sample}_2.fastq.gz"
  output:
    "outputs/cmash/{sample}.csv"
  threads: 24
  conda: "envs/cmash.yml"
  shell: """
     StreamingQueryDNADatabase.py \
         -c 0 \
         -t {threads} \
         {input.query} \
         {input.db} \
         {output} \
         21-51-10
  """

#### new CMash (unpublished, in development)

# Build a CMash database with MakeStreamingDNADatabase.py; the {ksize} and
# {num} wildcards are forwarded as -k and -n.
# The genome list handed to the tool comes from a glob over data/refs/*.fa,
# not from {input}; {input} only makes sure references 0..67 exist first.
rule build_cmash_db_k:
  input: expand("data/refs/{i}.fa", i=range(0, 68))
  output:
    # raw strings: keep "\d" as a regex token instead of an invalid string escape
    db = r"outputs/cmash/db-k{ksize,\d+}-n{num,\d+}.h5",
    # NOTE(review): not named on the command line; presumably written by
    # the tool next to the database. Confirm.
    tst = r"outputs/cmash/db-k{ksize,\d+}-n{num,\d+}.tst",
  params:
    ksize = "{ksize}",
    num = "{num}",
  conda: "envs/cmash.yml"
  threads: 24
  shell: """
    MakeStreamingDNADatabase.py \
        -k {params.ksize} \
        -n {params.num} \
        -t {threads} \
        <(ls -1 data/refs/*.fa) \
        {output.db}
  """

# Query the per-k CMash database with the "_2" FASTQ file of a sample. The
# range argument is "{k}-{k}-1", i.e. it starts and ends at {ksize}.
rule cmash_k:
  input:
    query = "data/reads/{sample}_2.fastq.gz",
    db = "outputs/cmash/db-k{ksize}-n{num}.h5"
  # raw string: keeps "\d" as a regex token instead of an invalid string escape
  output: r"outputs/cmash/{sample}-k{ksize,\d+}-n{num,\d+}.csv"
  params:
    k = "{ksize}",
  conda: "envs/cmash.yml"
  threads: 24
  shell: """
     StreamingQueryDNADatabase.py \
         -c 0 \
         -t {threads} \
         {input.query} \
         {input.db} \
         {output} \
         {params.k}-{params.k}-1
  """

#### manuscript CMash

# Build a database with MakeDNADatabase.py from the cmash_paper environment;
# the {ksize} and {num} wildcards are forwarded as -k and -n.
# The genome list handed to the tool comes from a glob over data/refs/*.fa,
# not from {input}; {input} only makes sure references 0..67 exist first.
rule build_cmash_paper_db_k:
  input: expand("data/refs/{i}.fa", i=range(0, 68))
  output:
    # raw string: keeps "\d" as a regex token instead of an invalid string escape
    db = r"outputs/cmash_paper/db-k{ksize,\d+}-n{num,\d+}.h5",
  params:
    ksize = "{ksize}",
    num = "{num}",
  conda: "envs/cmash_paper.yml"
  threads: 24
  shell: """
    MakeDNADatabase.py \
        -k {params.ksize} \
        -n {params.num} \
        -t {threads} \
        <(ls -1 data/refs/*.fa) \
        {output.db}
  """

# Run MakeNodeGraph.py on the "_2" FASTQ file of a sample, then rename the
# file it leaves in the output directory to the declared output path.
rule cmash_paper_query_bf_k:
  input:
    query = "data/reads/{sample}_2.fastq.gz",
  # raw string: keeps "\d" as a regex token instead of an invalid string escape
  output: r"outputs/cmash_paper/{sample}-k{ksize,\d+}.ng"
  params:
    k = "{ksize}",
    # directory that is handed to MakeNodeGraph.py as its last argument
    outdir = lambda w, output: os.path.dirname(output[0]),
    # base name of the reads file, used to build the name of the file
    # that the `mv` below picks up
    query_base = lambda w, input: os.path.basename(input.query)
  conda: "envs/cmash_paper.yml"
  threads: 1
  shell: """
     MakeNodeGraph.py \
         -t {threads} \
         -fp 0.01 \
         -k {params.k} \
         {input.query} \
         {params.outdir}
     mv {params.outdir}/{params.query_base}.NodeGraph.K{params.k} {output[0]}
  """

# Run QueryDNADatabase.py with the "_2" FASTQ file of a sample against the
# cmash_paper database, passing the prebuilt node graph via -ng.
rule cmash_paper_k:
  input:
    query = "data/reads/{sample}_2.fastq.gz",
    query_ng = "outputs/cmash_paper/{sample}-k{ksize}.ng",
    db = "outputs/cmash_paper/db-k{ksize}-n{num}.h5"
  # raw string: keeps "\d" as a regex token instead of an invalid string escape
  output: r"outputs/cmash_paper/{sample}-k{ksize,\d+}-n{num,\d+}.csv"
  params:
    # not referenced by the shell command below; {ksize} only takes effect
    # through the node graph and database inputs
    k = "{ksize}",
  conda: "envs/cmash_paper.yml"
  threads: 24
  shell: """
     QueryDNADatabase.py \
         -ct 0 \
         -t {threads} \
         -ng {input.query_ng} \
         {input.query} \
         {input.db} \
         {output}
  """

### Evaluating similarity and containment of the samples

# Run `sourmash compare` at k-mer size {ksize} over the signatures of
# references 0..67.
rule compare_sigs_similarity:
  input: expand("outputs/scaled_{{scaled}}/refs/{i}.sig", i=range(0, 68))
  # raw string: keeps "\d" as a regex token instead of an invalid string escape
  output: r"outputs/scaled_{scaled,\d+}/similarity-k{ksize,\d+}"
  params:
    ksize = "{ksize}"
  conda: "envs/sourmash.yml"
  shell: """
    sourmash compare -k {params.ksize} -o {output} {input}
  """

# Run `sourmash plot` on a comparison file, writing into the directory of
# the declared output.
# NOTE(review): the command never names the .matrix.png file itself; it is
# presumably derived by sourmash from the input file name. Confirm.
rule plot_sigs_similarity:
  input: "outputs/scaled_{scaled}/{metric}"
  # raw string: keeps "\d" as a regex token instead of an invalid string escape
  output: r"outputs/scaled_{scaled,\d+}/{metric}.matrix.png"
  params:
    outdir = lambda w, output: os.path.dirname(output[0])
  conda: "envs/sourmash.yml"
  shell: """
    sourmash plot --output-dir {params.outdir} {input}
  """
