from pathlib import Path
from itertools import chain

# Default target: the three summary tables written by the summarize_* rules
# (mapping statistics, signature overlap, and per-base depth).
rule all:
  input:
    mapping_summary="outputs/minimap/summary.csv",
    overlap_summary="outputs/minimap/overlap/summary.csv",
    depth_summary="outputs/minimap/depth/summary.csv",

# Stream one SRA run ({sra_id}) into a gzipped FASTQ file.
#
# fastq-dump writes the reads to stdout (-Z). The perl filter rewrites the
# first ".1 " or ".2 " on each line to "/1 " or "/2 ", and gzip compresses
# the stream into the output file.
#
# The command is a raw string (r'''...''') so that the backslashes in the
# perl expression ("\." and "\/") reach the shell untouched without relying
# on Python's handling of unrecognised escape sequences. The trailing
# backslash-newline pairs are then interpreted by the shell as ordinary line
# continuations, so the executed command is unchanged.
rule download_sra_dataset:
  output:
    "outputs/01_RAW_FASTQ/{sra_id}.fastq.gz",
  conda: "env/sra.yml"
  shell: r'''
    fastq-dump --skip-technical  \
               --readids \
               --read-filter pass \
               --dumpbase \
               --split-spot \
               --clip \
               -Z \
               {wildcards.sra_id} | \
               perl -ne 's/\.([12]) /\/$1 /; print $_' | \
               gzip > {output}
  '''

# Download the GCA_002731465.1 genomic FASTA from the NCBI FTP site and store
# it as the TOBG reference used by the mapping and sketching rules.
#
# The destination is taken from {output} rather than repeated in the command,
# so the declared output and the file wget actually writes cannot drift apart.
rule download_TOBG:
  output: "outputs/01_FASTA/TOBG.fa.gz"
  shell: """
    wget https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/002/731/465/GCA_002731465.1_ASM273146v1/GCA_002731465.1_ASM273146v1_genomic.fna.gz \
           -O {output}
  """

#####################
# Minimap
#####################

# Map the reads of one SRA run against the TOBG FASTA with minimap2 (short-read
# preset), drop records that have the "unmapped" flag set (-F 4), and write a
# sorted BAM.
#
# Note on naming: `query` is the FASTA passed as minimap2's first positional
# argument and `metagenome` is the read file passed second.
rule minimap:
  input:
    query = "outputs/01_FASTA/TOBG.fa.gz",
    metagenome = "outputs/01_RAW_FASTQ/{sra_id}.fastq.gz",
  output:
    bam="outputs/minimap/{sra_id}.bam",
  threads: 4
  conda: "env/minimap2.yml"
  shell: """
    minimap2 -ax sr -t {threads} {input.query} {input.metagenome} | \
        samtools view -b -F 4 - | samtools sort - > {output.bam}
  """

# Write the summary-number section of `samtools stats` for one BAM.
#
# Only lines starting with "SN" are kept and their first column (the "SN" tag)
# is cut away, leaving "name:<TAB>value" records for summarize_mapping.
#
# The rule reserves 4 threads, so they are handed to samtools via -@ instead
# of being reserved and left idle.
rule samtools_stats:
  output:
    stats="outputs/minimap/{sra_id}.stats",
  input:
    bam="outputs/minimap/{sra_id}.bam",
  conda: "env/minimap2.yml"
  threads: 4
  shell: """
    samtools stats -@ {threads} -F 4 {input.bam} | grep ^SN | cut -f 2- > {output.stats}
  """

# Convert the BAM of one run back to FASTQ; the result is the input of
# compute_sigs.
#
# The rule reserves 4 threads, so they are handed to samtools via -@ instead
# of being reserved and left idle.
rule samtools_fastq:
  output:
    mapped="outputs/minimap/{sra_id}.mapped.fastq",
  input:
    bam="outputs/minimap/{sra_id}.bam",
  conda: "env/minimap2.yml"
  threads: 4
  shell: """
    samtools bam2fq -@ {threads} {input.bam} > {output.mapped}
  """

# Write a per-position depth table for one BAM. With -aa, samtools reports
# every position, including those with zero coverage, which summarize_depth
# relies on to count uncovered bases.
#
# The command is not given any thread option, so the rule declares a single
# thread; reserving more would only keep the scheduler from running other
# jobs in parallel.
rule samtools_depth:
  output:
    depth="outputs/minimap/depth/{sra_id}.txt",
  input:
    bam="outputs/minimap/{sra_id}.bam",
  conda: "env/minimap2.yml"
  threads: 1
  shell: """
    samtools depth -aa {input.bam} > {output.depth}
  """

# Merge the per-run .stats files into one CSV with one row per SRA run and one
# column per statistic name.
#
# Each non-blank line of a .stats file is split on tabs; the first field is
# the statistic name (its last character is dropped) and the second field is
# the value, kept as text. Any further fields on the line are ignored.
#
# NOTE: the same list of SRA ids appears in summarize_sigs and summarize_depth;
# keep the three in sync.
rule summarize_mapping:
  input:
    expand("outputs/minimap/{sra_id}.stats",
           sra_id=("SRR1509798", "SRR1509792", "SRR1509799",
                   "SRR1509793", "SRR1509794", "ERR3256923",
                   "SRR070081", "SRR070083", "SRR070084",
                   "SRR5868539", "SRR5868540", "SRR304680"))
  output: "outputs/minimap/summary.csv"
  run:
    import pandas as pd

    table = {}
    for stats_path in input:
      # Run id = file name up to the first dot.
      run_id = stats_path.rsplit("/", 1)[-1].partition(".")[0]

      stripped = (raw.strip() for raw in Path(stats_path).read_text().split("\n"))
      pairs = (row.split("\t")[:2] for row in stripped if row)
      # name[:-1] removes the final character of the statistic name.
      table[run_id] = {name[:-1]: value for name, value in pairs}

    # Runs become rows after the transpose.
    pd.DataFrame(table).T.to_csv(output[0])

# Sketch the mapped reads of one run with sourmash at k=21, 31 and 51
# (scaled 1000, abundance tracking on).
rule compute_sigs:
  input:
    "outputs/minimap/{sra_id}.mapped.fastq"
  output:
    "outputs/minimap/sigs/{sra_id}.sig"
  conda:
    "env/sourmash.yml"
  shell:
    "sourmash compute -k 21,31,51 --scaled 1000 --track-abundance -o {output} {input}"

# Sketch the TOBG FASTA with sourmash using the same parameters as
# compute_sigs (k=21, 31, 51; scaled 1000; abundance tracking on).
rule compute_sig_mag:
  input:
    "outputs/01_FASTA/TOBG.fa.gz"
  output:
    "outputs/sigs/TOBG.sig"
  conda:
    "env/sourmash.yml"
  shell:
    "sourmash compute -k 21,31,51 --scaled 1000 --track-abundance -o {output} {input}"

# Compare the TOBG signature (first) with one run's mapped-read signature
# (second) at k=31 and capture the text report; summarize_sigs later reads
# the "first contained ..." line from it.
rule sig_overlaps:
  input:
    mag="outputs/sigs/TOBG.sig",
    metagenome="outputs/minimap/sigs/{sra_id}.sig"
  output:
    "outputs/minimap/overlap/{sra_id}.txt"
  conda:
    "env/sourmash.yml"
  shell:
    "sourmash sig overlap -k 31 {input.mag} {input.metagenome} > {output}"

# Collect one "containment" value per run from the overlap reports into a CSV
# with one row per SRA run.
#
# For each report, the text after the last ":" on a line starting with
# "first contained" is stored (as text). If several lines match, the last one
# wins; if none matches, the run still gets a row but no value.
#
# NOTE: the same list of SRA ids appears in summarize_mapping and
# summarize_depth; keep the three in sync.
rule summarize_sigs:
  input:
    expand("outputs/minimap/overlap/{sra_id}.txt",
           sra_id=("SRR1509798", "SRR1509792", "SRR1509799",
                   "SRR1509793", "SRR1509794", "ERR3256923",
                   "SRR070081", "SRR070083", "SRR070084",
                   "SRR5868539", "SRR5868540", "SRR304680"))
  output: "outputs/minimap/overlap/summary.csv"
  run:
    import pandas as pd

    table = {}
    for report_path in input:
      # Run id = file name up to the first dot.
      run_id = report_path.rsplit("/", 1)[-1].partition(".")[0]

      record = {}
      for raw in Path(report_path).read_text().split("\n"):
        text = raw.strip()
        # Blank lines can never match the prefix, so no separate skip is needed.
        if text.startswith("first contained"):
          record["containment"] = text.split(":")[-1].strip()
      table[run_id] = record

    # Runs become rows after the transpose.
    pd.DataFrame(table).T.to_csv(output[0])

# Summarise the per-position depth tables into a CSV with one row per SRA run.
#
# Columns written for every run:
#   genome bp       number of rows (positions) in the depth table
#   missed          number of positions whose coverage is 0
#   percent missed  100 * missed / genome bp
#   coverage        mean coverage over all positions
#
# A depth table without any rows yields NaN for the two ratios instead of
# aborting the whole summary with a division by zero.
#
# NOTE: the same list of SRA ids appears in summarize_mapping and
# summarize_sigs; keep the three in sync.
rule summarize_depth:
  output: "outputs/minimap/depth/summary.csv"
  input:
    expand("outputs/minimap/depth/{sra_id}.txt",
           sra_id=("SRR1509798", "SRR1509792", "SRR1509799",
                   "SRR1509793", "SRR1509794", "ERR3256923",
                   "SRR070081", "SRR070083", "SRR070084",
                   "SRR5868539", "SRR5868540", "SRR304680"))
  run:
    import pandas as pd

    runs = {}
    for sra_stat in input:
      # The depth file has no header line; name the three columns explicitly.
      data = pd.read_table(sra_stat, names=["contig", "pos", "coverage"])
      sra_id = sra_stat.split("/")[-1].split(".")[0]

      coverage = data["coverage"]
      genome_bp = int(len(data))
      missed = int((coverage == 0).sum())

      if genome_bp:
        percent_missed = 100 * missed / genome_bp
        mean_coverage = coverage.sum() / genome_bp
      else:
        # No positions at all: the ratios are undefined.
        percent_missed = float("nan")
        mean_coverage = float("nan")

      runs[sra_id] = {
        "genome bp": genome_bp,
        "missed": missed,
        "percent missed": percent_missed,
        "coverage": mean_coverage,
      }

    # Runs become rows after the transpose.
    pd.DataFrame(runs).T.to_csv(output[0])
