###################################
############ Libraries ############
###################################

import pathlib

################################
############ Config ############
################################

# Segment length threshold forwarded to the clustering/phasing R script
# (required config key: no default is supplied here).
segment_length_threshold = config["segmentLengthThreshold"]
# Segment length threshold forwarded to the acrocentric tangle removal script.
tangle_segment_length_threshold = config.get('tangleSegmentLengthThreshold', 250000)

# Base directory that get_script_path() resolves helper scripts against
# (required config key: no default is supplied here).
scripts_dir = config["scripts_dir"]
# When true, the calc_mem resource callables divide the memory request by the
# job's thread count.
per_thread_memory = config.get("per_thread_memory", False)

## Experimental Features
# Toggles inclusion of the Strand-seq breakpoint rules and their outputs.
calc_breakpoints = config.get('calc_breakpoints', False)


################################
########## Variables ###########
################################

# File-name suffixes of the index files declared as outputs of the bwa index
# rule and as inputs of the alignment rules.
bwa_index_suffices = ["amb", "ann", "bwt", "pac", "sa"]

################################
########## Functions ###########
################################

def get_script_path(*args):
    """Return the resolved POSIX path of a script below ``scripts_dir``.

    ``args`` are the path components appended to ``scripts_dir``. The path is
    resolved strictly, so a missing script raises at workflow parse time
    instead of when the rule's shell command runs.
    """
    script = pathlib.Path(scripts_dir, *args)
    return script.resolve(strict=True).as_posix()

def make_mem_calculator(scale=1024, per_thread_memory=False):
    """Build a factory for Snakemake ``mem_mb`` resource callables.

    Args:
        scale: Multiplier applied to the linear term (1024 turns a value
            expressed in GB into MB).
        per_thread_memory: If true, the request is divided by the job's
            thread count, for schedulers that expect memory per thread.

    Returns:
        ``calc_mem_mb(m=1, b=None)``, which in turn returns a callable
        ``(wildcards, threads, attempt) -> int`` computing
        ``(m * (attempt - 1) + b) * scale``. ``b`` defaults to ``m``, so the
        first attempt requests ``m * scale`` and each retry adds another
        ``m * scale``.
    """
    def calc_mem_mb(m=1, b=None):
        if b is None:
            b = m

        def _calc_mem_mb(wildcards, threads, attempt):
            retries = attempt - 1
            total = (m * retries + b) * scale
            # Resource values should be integers; true division would yield a
            # float. Ceil-divide so the per-thread share is never rounded
            # below what the job actually asked for.
            divisor = max(threads, 1) if per_thread_memory else 1
            return int(-(-total // divisor))

        return _calc_mem_mb

    return calc_mem_mb

def calc_walltime(m=1, b=None):
    """Return a Snakemake resource callable producing an ``HH:59:00`` walltime.

    The hour field is ``m * (attempt - 1) ** 2 + b`` (``b`` defaults to ``m``),
    zero-padded to two digits, so the request grows quadratically with the
    number of retries.
    """
    base_hours = m if b is None else b

    def _calc_walltime(wildcards, attempt):
        retries = attempt - 1
        hours = m * retries * retries + base_hours
        return f'{hours:02}:59:00'

    return _calc_walltime

# Workflow-wide memory calculator used in the rules' `resources` sections.
# With scale=1024 the slope/intercept given to calc_mem() are multiplied by
# 1024 before being returned as mem_mb.
calc_mem = make_mem_calculator(scale=1024, per_thread_memory=per_thread_memory)


# Sample to Input #

def sample2ss(wildcards):
    """Look up the raw Strand-seq read file for a sample / library / pair.

    Reads ``MAP_SAMPLE_TO_INPUT[sample]['strandseq_pairs'][lib][pair]`` using
    the rule's wildcards.
    """
    pairs_by_library = MAP_SAMPLE_TO_INPUT[wildcards.sample]['strandseq_pairs']
    return pairs_by_library[wildcards.lib][wildcards.pair]

def sample2unmergedss(wildcards, file):
    """Return the path pattern of one unmerged mate FASTA for alignment.

    ``file`` selects the mate: 1 gives ``file1``, anything else ``file2``.
    Samples flagged ``hpc`` in MAP_SAMPLE_TO_INPUT use the
    homopolymer-compressed reads, all others the renamed reads. The returned
    string still contains the ``{sample}``/``{lib}`` placeholders
    (presumably filled in by Snakemake from the job's wildcards -- confirm).
    """
    mate = 'file1' if file == 1 else 'file2'
    use_hpc = MAP_SAMPLE_TO_INPUT[wildcards.sample]['hpc']
    suffix = '.homopolymer-compressed.fasta' if use_hpc else '.renamed.fasta'
    return 'ss/unmerged/{sample}/{lib}_' + mate + suffix

def sample2mergedss(wildcards):
    """Return the path pattern of the combined (merged) reads of a library.

    Samples flagged ``hpc`` in MAP_SAMPLE_TO_INPUT use the
    homopolymer-compressed FASTA, all others the renamed FASTA. The returned
    string still contains the ``{sample}``/``{lib}`` placeholders
    (presumably filled in by Snakemake from the job's wildcards -- confirm).
    """
    use_hpc = MAP_SAMPLE_TO_INPUT[wildcards.sample]['hpc']
    suffix = '.homopolymer-compressed.fasta' if use_hpc else '.renamed.fasta'
    return 'ss/merged/{sample}/{lib}.combined' + suffix

def sample2mem(wildcards, ext='bam'):
    """List the filtered bwa-mem alignment files of every library of a sample.

    ``ext`` is the file extension appended after ``.mdup.filt.``. Only ``lib``
    is expanded here; ``{sample}`` is left as a placeholder in each path.
    """
    libraries = MAP_SAMPLE_TO_INPUT[wildcards.sample]['strandseq_libs']
    pattern = "bwa_alignments/mem/{{sample}}/{lib}.mdup.filt." + ext
    return expand(pattern, lib=libraries)

def sample2fastmap(wildcards):
    """List the bwa-fastmap result files of every library of a sample.

    Only ``lib`` is expanded here; ``{sample}`` is left as a placeholder in
    each path.
    """
    libraries = MAP_SAMPLE_TO_INPUT[wildcards.sample]['strandseq_libs']
    pattern = "bwa_alignments/fastmap/{{sample}}/{lib}_maximal_unique_exact_match.tsv"
    return expand(pattern, lib=libraries)

###################################
############ Rules ################
###################################

# Global wildcard patterns: sample names may contain neither '/' nor '_',
# and library ids are purely numeric.
wildcard_constraints:
    sample='[^/_]+',
    lib='[0-9]+'

# NOTE(review): SAMPLES, MAP_SAMPLE_TO_INPUT and ALL_OUTPUT are used below but
# not defined in this file; presumably they come from these includes -- confirm.
include: 'rules/sample_table.smk'
include: 'rules/collect_output.smk'

# Optional breakpoint calling, enabled via the `calc_breakpoints` config flag.
if calc_breakpoints:
    include: 'rules/strandseq_breakpoints.smk'
    ALL_OUTPUT.append(expand("breakpoints/{sample}_breakpoint_windows.csv", sample=SAMPLES))

# Inferred File Pairs
ALL_OUTPUT.append(expand('library_ids/{sample}_library_ids.csv', sample=SAMPLES))

###############################
############ All ##############
###############################

rule all:
    """Default target: request every path collected in ALL_OUTPUT."""
    input:
        ALL_OUTPUT

################################################
################# Library IDs ##################
################################################

# Export Library code-> file CSV
rule export_library_ids:
    """
    Write the sample's 'strandseq_pairs' mapping to a CSV with one row per
    library; the library key is stored in the 'library' index column.
    NOTE(review): `pandas` is not imported in this file; presumably an
    included rules file imports it -- confirm.
    """
    output:
        'library_ids/{sample}_library_ids.csv'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    run:
        print(f'Exporting library ids for {wildcards.sample}')
        library_table = pandas.DataFrame.from_dict(
            MAP_SAMPLE_TO_INPUT[wildcards.sample]['strandseq_pairs'],
            orient='index'
        )
        library_table.to_csv(output[0], index=True, index_label='library')

################################################
############ Unmerged SS Processing ############
################################################

# Unmerged reads are aligned in bwa mem paired end mode, and alignments are
# used for the initial clustering step which assigns
# unitig chromosome and orientation, and calls library strand state


rule unzip_ss:
    """
    Convert one raw Strand-seq read file (resolved by sample2ss) into a
    temporary FASTA that keeps only the read name and the sequence.
    """
    input:
        sample2ss
    output:
        temp("ss/unmerged/{sample}/{lib}_{pair}.fasta")
    log:
        "log/unzip_ss_{sample}_{lib}_{pair}.log"
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        (time bioawk -c fastx '{{print \">\"$name; print $seq}}' <(cat {input}) > {output}) > {log} 2>&1
        '''

# TODO I don't think this step is necessary anymore
rule add_ss_libname_unmerged:
    """
    Append "_<lib>" to every read name of an unmerged mate FASTA so reads
    remain attributable to their library.
    """
    input:
        "ss/unmerged/{sample}/{lib}_{pair}.fasta"
    output:
        temp("ss/unmerged/{sample}/{lib}_{pair}.renamed.fasta")
    log:
        "log/add_ss_libname_{sample}_{lib}_{pair}.log"
    conda:
        'envs/env_cl.yaml'
    threads:
        1
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        (bioawk -c fastx -v libname={wildcards.lib} '{{print \">\"$name"_"libname; print $seq}}' <(cat {input}) > {output}) > {log} 2>&1
        '''

rule homopolymer_compress_unmerged_ss:
    """Run `seqtk hpc` on the renamed unmerged mate FASTA."""
    input:
        "ss/unmerged/{sample}/{lib}_{pair}.renamed.fasta"
    output:
        "ss/unmerged/{sample}/{lib}_{pair}.homopolymer-compressed.fasta"
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        seqtk hpc {input} > {output}
        '''


##############################################
############ Merged SS Processing ############
##############################################
# Merged SS reads are used with bwa fastmap, for longer exact matches.
# How useful merging actually is for that step is untested, but it feels like
# the right thing to do as fastmap does not have a paired alignment mode

# TODO check the name of the merged reads, does it simply take the name of the first read in the pair?


rule pear_merge_mates:
    """
    Merge the two mates of a Strand-seq library with PEAR. The output prefix
    handed to `pear -o` corresponds to the four declared output files.
    """
    input:
        fq1=lambda wildcards: MAP_SAMPLE_TO_INPUT[wildcards.sample]['strandseq_pairs'][wildcards.lib]['file1'],
        fq2=lambda wildcards: MAP_SAMPLE_TO_INPUT[wildcards.sample]['strandseq_pairs'][wildcards.lib]['file2']
    output:
        "ss/merged/{sample}/{lib}.assembled.fastq",
        "ss/merged/{sample}/{lib}.discarded.fastq",
        "ss/merged/{sample}/{lib}.unassembled.forward.fastq",
        "ss/merged/{sample}/{lib}.unassembled.reverse.fastq"
    log:
        "log/pear_merge_mates_{sample}_{lib}.log"
    benchmark:
        "benchmark/pear_merge_mates_{sample}_{lib}.benchmark"
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        "(pear -f {input.fq1} -r {input.fq2} -o ss/merged/{wildcards.sample}/{wildcards.lib}) > {log} 2>&1"

rule concat_assembled_with_first_pair_of_unassembled:
    """
    Concatenate the PEAR-assembled reads with the unassembled forward reads
    and write them as a single temporary FASTA (read name and sequence only).
    """
    input:
        "ss/merged/{sample}/{lib}.assembled.fastq",
        "ss/merged/{sample}/{lib}.unassembled.forward.fastq"
    output:
        temp("ss/merged/{sample}/{lib}.combined.fasta")
    log:
        "log/concat_merged_with_first_unmerged_{sample}_{lib}.log"
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        "(time bioawk -c fastx '{{print \">\"$name; print $seq}}' <(cat {input}) > {output}) > {log} 2>&1"

rule add_ss_libname_merged:
    """Append "_<lib>" to every read name of the combined merged-read FASTA."""
    input:
        "ss/merged/{sample}/{lib}.combined.fasta"
    output:
        temp("ss/merged/{sample}/{lib}.combined.renamed.fasta")
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        bioawk -c fastx -v libname={wildcards.lib} '{{print \">\"$name"_"libname; print $seq}}' <(cat {input}) > {output}
        '''

rule homopolymer_compress_merged_ss:
    """Run `seqtk hpc` on the renamed combined merged-read FASTA."""
    input:
        "ss/merged/{sample}/{lib}.combined.renamed.fasta" # concat_assembled_with_first_pair_of_unassembled
    output:
        "ss/merged/{sample}/{lib}.combined.homopolymer-compressed.fasta"
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        seqtk hpc {input} > {output}
        '''

#######################################
############ Index Unitigs ############
#######################################
rule gfa_to_fasta:
    """
    Extract the unitig sequences from the sample's GFA into a FASTA: every
    segment record becomes ">name" followed by its sequence.

    Only lines that START with "S" are taken (same selection as
    extract_coverage_from_hifiasm_gfa); an unanchored `grep S` would also
    pick up any other record that merely contains an "S" somewhere.
    """
    input: lambda wildcards: MAP_SAMPLE_TO_INPUT[wildcards.sample]['gfa']
    output: "fasta/{sample}/{sample}_unitigs-hpc.fasta"
    resources:
        mem_mb=calc_mem(),
        walltime=calc_walltime()
    log: "log/gfa_to_fasta_{sample}.log"
    benchmark: "benchmark/gfa_to_fasta_{sample}.benchmark"
    shell:
        '''
        (time grep "^S" {input} | awk '{{print \">\"$2\"\\n\"$3}}' > {output}) > {log} 2>&1
        '''

rule bwa_index_unitigs:
    """
    Build the bwa index of the unitig FASTA; one output file is declared per
    suffix in `bwa_index_suffices`.
    """
    input:
        "fasta/{sample}/{sample}_unitigs-hpc.fasta" # gfa_to_fasta
    output:
        expand("fasta/{{sample}}/{{sample}}_unitigs-hpc.fasta.{bwa_index_suffix}", bwa_index_suffix=bwa_index_suffices)
    log:
        "log/bwa_index_unitigs_{sample}.log"
    benchmark:
        "benchmark/bwa_index_unitigs_{sample}.benchmark"
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(1, 8),
        mem_mb=calc_mem(8)
    shell:
        "(time bwa index {input}) > {log} 2>&1"

#####################################################################
############ bwa mem Unmerged SS Reads to Assembly Unitigs ##########
#####################################################################


# This block of rules also filters out supplementary and secondary alignments,
# then filters out duplicates and pairs where either read is unmapped, and
# finally keeps only first-mate reads in proper pairs.
# https://broadinstitute.github.io/picard/explain-flags.html
rule bwa_align_unmerged_ss_to_unitigs:
    """
    Align both mates of a Strand-seq library to the unitigs with `bwa mem`,
    tagging reads with the library id as read group, and pipe the result
    through `samtools view -F 2304` (drops secondary and supplementary
    alignments) into a temporary BAM. bwa and samtools log separately.
    """
    input:
        unitigs="fasta/{sample}/{sample}_unitigs-hpc.fasta", # gfa_to_fasta
        unitigs_index=expand("fasta/{{sample}}/{{sample}}_unitigs-hpc.fasta.{bwa_index_suffix}", bwa_index_suffix=bwa_index_suffices), # bwa_index_unitigs
        mate1=lambda wildcards: sample2unmergedss(wildcards, 1),
        mate2=lambda wildcards: sample2unmergedss(wildcards, 2)
    output:
        temp("temp_bwa_alignments/mem/{sample}/{lib}.bam")
    log:
        bwa="log/bwa_align_unmerged_ss_to_unitigs_bwa_{sample}_{lib}.log",
        samtools="log/bwa_align_unmerged_ss_to_unitigs_samtools_{sample}_{lib}.log"
    benchmark:
        "benchmark/bwa_align_unmerged_ss_to_unitigs_{sample}_{lib}.benchmark"
    conda:
        'envs/env_cl.yaml'
    threads:
        6
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem(12)
    shell:
        '''
        bwa mem -t {threads} -R "@RG\\tID:{wildcards.lib}" -v 2 {input.unitigs} {input.mate1} {input.mate2} 2> {log.bwa} | samtools view -b -F 2304 /dev/stdin > {output} 2> {log.samtools}
        '''

rule bwa_sort_unmerged_ss_alignments:
    """Sort the raw per-library alignment BAM with `samtools sort`."""
    input:
        "temp_bwa_alignments/mem/{sample}/{lib}.bam" # bwa_align
    output:
        temp("bwa_alignments/mem/{sample}/{lib}.bam")
    log:
        "log/bwa_sort_unmerged_{sample}_{lib}.log"
    benchmark:
        "benchmark/bwa_sort_unmerged_{sample}_{lib}.benchmark"
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        (time samtools sort -o {output} {input}) > {log} 2>&1
        '''

rule unmerged_mark_duplicates:
    """Mark duplicate reads in the sorted BAM with `sambamba markdup`."""
    input:
        "bwa_alignments/mem/{sample}/{lib}.bam" # bwa_sort
    output:
        temp("bwa_alignments/mem/{sample}/{lib}.mdup.bam")
    log:
        "log/unmerged_mark_duplicates_{sample}_{lib}.log"
    benchmark:
        "benchmark/unmerged_mark_duplicates_{sample}_{lib}.benchmark"
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem(2)
    shell:
        '''
        (time sambamba markdup {input} {output}) > {log} 2>&1
        '''

# Some filtering happens at the alignment step to save a little time in duplicate marking, but this step can also
# stand alone for the filtering.
rule unmerged_filter_bam:
    """
    Filter the duplicate-marked BAM with `samtools view`:
    -F 3340 drops unmapped reads, reads with an unmapped mate, secondary,
    duplicate and supplementary alignments; -f 66 keeps only first-mate
    reads of proper pairs.

    This rule has its own log and benchmark paths (they used to overwrite
    those of unmerged_mark_duplicates). stderr goes to the log only; the old
    `2> {log} 2>&1` re-pointed stderr at stdout, i.e. into the output BAM.
    """
    input: rules.unmerged_mark_duplicates.output
    output: "bwa_alignments/mem/{sample}/{lib}.mdup.filt.bam"
    conda:'envs/env_cl.yaml'
    resources:
        mem_mb=calc_mem(2),
        walltime=calc_walltime()
    log: "log/unmerged_filter_bam_{sample}_{lib}.log"
    benchmark: "benchmark/unmerged_filter_bam_{sample}_{lib}.benchmark"
    shell:
        '''
        samtools view -b -F 3340 -f 66 {input} > {output} 2> {log}
        '''

# Do I still need this rule from Maryam? Does it make filtering and import of the BAMs in the R program easier?
rule unmerged_bwa_index:
    """
    Create the `.bai` index of the filtered BAM with `samtools index`
    (despite the rule name, bwa is not involved here).
    """
    input:
        rules.unmerged_filter_bam.output[0]
    output:
        rules.unmerged_filter_bam.output[0] + '.bai'
    log:
        "log/unmerged_bwa_index_{sample}_{lib}.log"
    benchmark:
        "benchmark/unmerged_bwa_index_{sample}_{lib}.benchmark"
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        (time samtools index {input}) > {log} 2>&1
        '''

#############################################################
############ fastmap merged Strand-seq reads ################
#############################################################

rule fastmap_ss_reads_to_unitigs:
    """
    Run `bwa fastmap` (-w 1 -l 75) with the merged Strand-seq reads of one
    library against the indexed unitigs; the text output is stored as TSV.
    """
    input:
        unitigs="fasta/{sample}/{sample}_unitigs-hpc.fasta",# gfa_to_fasta
        unitigs_index=expand("fasta/{{sample}}/{{sample}}_unitigs-hpc.fasta.{bwa_index_suffix}",bwa_index_suffix=bwa_index_suffices),# bwa_index_unitigs
        SSreads=sample2mergedss
    output:
        "bwa_alignments/fastmap/{sample}/{lib}_maximal_unique_exact_match.tsv" # fastmap doesn't output bam
    log:
        "log/map_SS_reads_to_unitigs_{sample}_{lib}.log"
    benchmark:
        "benchmark/map_SS_reads_to_unitigs_{sample}_{lib}.benchmark"
    conda:
        'envs/env_cl.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem(6)
    shell:
        "(bwa fastmap -w 1 -l 75 {input.unitigs} {input.SSreads} > {output}) > {log} 2>&1"

###################################################
############ Acrocentric Tangle Removal ###########
###################################################

rule remove_acrocentric_tangle:
    """
    Run explode_largest_component.R on the sample's GFA with the tangle
    segment length threshold; produces the exploded GFA and a table of
    connected components.
    """
    input:
        lambda wildcards: MAP_SAMPLE_TO_INPUT[wildcards.sample]['gfa']
    output:
        gfa = 'gfa/gfa/{sample}_exploded.gfa',
        ccs = 'gfa/ccs/{sample}_exploded_ccs.tsv'
    params:
        script = get_script_path('R','explode_largest_component.R'),
        segment_length_threshold = tangle_segment_length_threshold
    log:
        "log/remove_acrocentric_tangle_{sample}.log"
    benchmark:
        "benchmark/remove_acrocentric_tangle_{sample}.benchmark"
    conda:
        'envs/env_Renv.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        (Rscript --vanilla {params.script} \\
        --gfa {input} \\
        --output-gfa {output.gfa} \\
        --output-ccs {output.ccs} \\
        --segment-length-threshold {params.segment_length_threshold}) > {log} 2>&1
        '''


#######################################
############ Phase Unitigs ############
#######################################

# In an effort to avoid using R Bioconductor packages, which can be hard to install, output to SAM and read with
# non-Bioconductor R packages.
rule counting_bam_to_sam:
    """
    Convert the filtered BAM to a temporary SAM (header included) for the
    alignment counting script.

    This rule has its own log and benchmark paths (they used to overwrite
    those of unmerged_mark_duplicates). stderr goes to the log only; the old
    `2> {log} 2>&1` re-pointed stderr at stdout, i.e. into the output SAM.
    """
    input: rules.unmerged_bwa_index.input
    output: temp('temp/sams/{sample}/{lib}.mdup.filt.sam')
    conda:'envs/env_cl.yaml'
    resources:
        mem_mb=calc_mem(2),
        walltime=calc_walltime()
    log: "log/counting_bam_to_sam_{sample}_{lib}.log"
    benchmark: "benchmark/counting_bam_to_sam_{sample}_{lib}.benchmark"
    shell:
        '''
        samtools view --with-header {input} > {output} 2> {log}
        '''

rule count_sseq_alignments:
    """
    Run count_alignments.snakemake.R over the SAM files and fastmap tables
    of all libraries of a sample, writing one CSV of bwa-mem counts and one
    of fastmap counts.
    """
    input:
        sam= lambda wildcards: expand('temp/sams/{{sample}}/{lib}.mdup.filt.sam', lib=MAP_SAMPLE_TO_INPUT[wildcards.sample]['strandseq_libs']),
        # bai=lambda wildcards: sample2mem(wildcards, 'bam.bai'),
        fastmap=sample2fastmap,
    output:
        mem = "sseq_alignment_counts/{sample}_sseq_mem_counts.csv",
        fastmap = "sseq_alignment_counts/{sample}_sseq_fastmap_counts.csv",
    params:
        script = get_script_path('R', 'count_alignments.snakemake.R')
    log:
        "log/count_sseq_alignments_{sample}.log"
    benchmark:
        "benchmark/count_sseq_alignments_{sample}.benchmark"
    conda:
        'envs/env_Renv.yaml'
    threads:
        1
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem(32)
    shell:
        '''
        (Rscript --vanilla {params.script} \\
        --mem-alignment-bams {input.sam} \\
        --fastmap-alignments {input.fastmap} \\
        --threads {threads} \\
        --aggregate-alignments TRUE \\
        --output-mem {output.mem} \\
        --output-fastmap {output.fastmap}) > {log} 2>&1
        '''

def if_file_exists(file_name):
    """Return ``file_name`` as a normalised path string if it is an existing
    regular file, otherwise the empty string."""
    candidate = pathlib.Path(file_name)
    return str(candidate) if candidate.is_file() else ""

def get_optional_phasing_input(wildcards, what):
    """Return ``phasing_input/<sample>/<what>`` if that file exists, else ""."""
    candidate = '/'.join(('phasing_input', wildcards.sample, what))
    return if_file_exists(candidate)

rule count_haplotype_markers:
    """
    Run clustering_orient_strandstate.snakemake.R on the alignment counts
    and connected components of a sample. Produces haplotype marker counts,
    library weights and a directory of intermediate output. The optional
    files under phasing_input/<sample>/ are passed when present; otherwise
    the corresponding option receives an empty string.
    """
    input:
        mem = "sseq_alignment_counts/{sample}_sseq_mem_counts.csv",
        fastmap="sseq_alignment_counts/{sample}_sseq_fastmap_counts.csv",
        connected_components='gfa/ccs/{sample}_exploded_ccs.tsv'
    output:
        mc="haplotype_marker_counts/{sample}_haplotype_marker_counts.csv",
        lib="library_weights/{sample}_library_weights.csv",
        imo=directory("intermediate_output/{sample}/"),
    params:
        script = get_script_path('R','clustering_orient_strandstate.snakemake.R'),
        segment_length_threshold=segment_length_threshold,
        il = lambda wildcards: get_optional_phasing_input(wildcards,'included_libraries.tsv'),
        pc = lambda wildcards: get_optional_phasing_input(wildcards,'prior_clusters.tsv'),
        rc= lambda wildcards: get_optional_phasing_input(wildcards,'refined_clusters.tsv'),
        fc = lambda wildcards: get_optional_phasing_input(wildcards,'final_clusters.tsv'),
        uo = lambda wildcards: get_optional_phasing_input(wildcards,'unitig_orientation.tsv'),
        cm = lambda wildcards: get_optional_phasing_input(wildcards,'counting_methods.tsv')
    log:
        "log/count_haplotype_markers_{sample}.log"
    benchmark:
        "benchmark/count_haplotype_markers_{sample}.benchmark"
    # singularity: singularity_r_env
    conda:
        'envs/env_R-COS.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem(32)
    shell:
        '''
        mkdir -p {output.imo} &&
        (Rscript --vanilla {params.script} \\
        --mem-counts {input.mem} \\
        --fastmap-counts {input.fastmap} \\
        --connected-components {input.connected_components} \\
        --segment-length-threshold {params.segment_length_threshold} \\
        --intermediate-output-dir {output.imo} \\
        --included-libraries {params.il} \\
        --prior-clusters {params.pc} \\
        --refined-clusters {params.rc} \\
        --final-clusters {params.fc} \\
        --unitig-orientation {params.uo} \\
        --counting-methods {params.cm} \\
        --output-marker-counts {output.mc} \\
        --output-lib {output.lib}) > {log} 2>&1
        '''

rule fudge_haplotype_markers:
    """
    Run fudge_marker_counts.snakemake.R on the haplotype marker counts and
    write the adjusted ("fudged") counts to a new CSV.
    """
    input:
        hmc = "haplotype_marker_counts/{sample}_haplotype_marker_counts.csv"
    output:
        "haplotype_marker_counts/{sample}_fudged_haplotype_marker_counts.csv"
    params:
        script = get_script_path('R','fudge_marker_counts.snakemake.R')
    log:
        "log/fudge_haplotype_markers_{sample}.log"
    benchmark:
        "benchmark/fudge_haplotype_markers_{sample}.benchmark"
    conda:
        'envs/env_Renv.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem(4)
    shell:
        '''
        (Rscript --vanilla {params.script} \\
        --hmc {input.hmc} \\
        --output {output}) > {log} 2>&1
        '''
################################################################
############ HiFiasm  Graph Hifi Coverage for Rukki ############
################################################################

rule extract_coverage_from_hifiasm_gfa:
    """
    Build a three-column TSV from the segment lines of the sample's GFA:
    column 2 (segment name), column 5 with a leading "rd:i:" removed, and
    column 4 with a leading "LN:i:" removed.
    """
    input:
        gfa=lambda wildcards: MAP_SAMPLE_TO_INPUT[wildcards.sample]['gfa'],
    output:
        'hifiasm_hifi_coverage/{sample}_hifiasm_hifi_coverage.tsv'
    shell:
        '''
        grep "^S" {input} | awk -F'\t' '{{gsub(/^LN:i:/, "", $4); gsub(/^rd:i:/, "", $5); print $2 "\t" $5 "\t" $4}}' > {output}
        '''


#######################################################
############ Run Rukki Graph Hifi Coverage ############
#######################################################

rule add_coverage_to_gfa:
    """
    Inject the sample's coverage table into its GFA with verkko's
    inject_coverage.py (taken from the conda environment), writing a
    temporary coverage-annotated graph.
    """
    input:
        gfa=lambda wildcards: MAP_SAMPLE_TO_INPUT[wildcards.sample]['gfa'],
        coverage=lambda wildcards: MAP_SAMPLE_TO_INPUT[wildcards.sample]['coverage']
    output:
        temp('temp/{sample}_gfa_with_coverage.gfa')
    log:
        "log/add_coverage_to_gfa_{sample}.log"
    benchmark:
        "benchmark/add_coverage_to_gfa_{sample}.benchmark"
    conda:
        'envs/env_verkko.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        ($CONDA_PREFIX/lib/verkko/scripts/inject_coverage.py --allow-absent {input.coverage} {input.gfa} > {output}) > {log} 2>&1
        '''

rule convert_markers_to_tsv:
    """
    Drop the header line of the fudged marker-count CSV and rewrite the
    remaining comma-separated rows as tab-separated.
    """
    input:
        "haplotype_marker_counts/{sample}_fudged_haplotype_marker_counts.csv"
    output:
        temp("haplotype_marker_counts/{sample}_fudged_haplotype_marker_counts.tsv")
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
           tail -n +2 {input} | awk -F',' '{{OFS="\t"; $1=$1; print}}' > {output}
        '''

#  hifiasm parameters are very experimental. Just trying to manage the presence of long, homozygous nodes, which aren't found in verkko graphs.
def rukki_paramaterizer(wildcards, parameter):
    """Return the assembler-specific value of a rukki command-line parameter.

    Args:
        wildcards: Snakemake wildcards; ``wildcards.sample`` selects the
            sample whose ``'assembler'`` entry is read from
            MAP_SAMPLE_TO_INPUT.
        parameter: ``'solid-homozygous-cov-coeff'`` or
            ``'max-homozygous-len'``.

    Returns:
        The value to pass to rukki for that parameter and assembler.

    Raises:
        ValueError: If ``parameter`` is unknown, or if no value is defined
            for the sample's assembler.
    """
    assembler = MAP_SAMPLE_TO_INPUT[wildcards.sample]['assembler']

    if parameter == 'solid-homozygous-cov-coeff':
        if assembler == 'verkko':
            return 1.1
        if assembler == 'hifiasm':
            return 0
    # `elif`, not a second `if`: otherwise an unsupported assembler in the
    # branch above falls into the `else` below and is misreported as an
    # unknown parameter.
    elif parameter == 'max-homozygous-len':
        if assembler == 'verkko':
            return 10000000
        if assembler == 'hifiasm':
            return 10000000
    else:
        raise ValueError(f'No function for parameter: {parameter}')

    # Known parameter, but no value defined for this assembler.
    raise ValueError(f'No value of {parameter!r} defined for assembler: {assembler!r}')

rule run_rukki:
    # Runs `rukki trio` twice on the coverage-annotated graph with the fudged
    # haplotype marker counts: first writing the paths as TSV, then again with
    # --gaf-format writing them as GAF. Both runs share the same option string,
    # so the three annotation outputs are written by each invocation.
    input:
        graph='temp/{sample}_gfa_with_coverage.gfa',
        markers="haplotype_marker_counts/{sample}_fudged_haplotype_marker_counts.tsv"
    output:
        tsv="rukki/{sample}/{sample}_rukki_paths.tsv",
        gaf="rukki/{sample}/{sample}_rukki_paths.gaf",
        initial_annotation="rukki/{sample}/{sample}_out_init_ann.csv",
        refined_annotation='rukki/{sample}/{sample}_out_refined_ann.csv',
        final_annotation='rukki/{sample}/{sample}_out_final_ann.csv'
    conda: 'envs/env_verkko.yaml'
    resources:
        mem_mb=calc_mem(6),
        walltime=calc_walltime()
    # Assembler-dependent rukki settings, resolved per sample.
    params:
        shcc =lambda wildcards: rukki_paramaterizer(wildcards, 'solid-homozygous-cov-coeff'),
        mhl= lambda wildcards: rukki_paramaterizer(wildcards, 'max-homozygous-len')

    log: "log/run_rukki_{sample}.log"
    benchmark: "benchmark/run_rukki_{sample}.benchmark"
    priority: 1
    # The option string is accumulated in the shell variable `params`; the
    # second rukki call appends (>>) to the log written by the first.
    shell:
        '''
        params=""
        params="$params --init-assign {output.initial_annotation}"
        params="$params --refined-assign {output.refined_annotation}"
        params="$params --final-assign {output.final_annotation}"
        params="$params --hap-names haplotype1,haplotype2"

        # Minimal number of parent-specific markers required for assigning parental group to a node [default: 10]
        params="$params --marker-cnt 10" 

        # Require at least (node_length / <value>) markers within the node for parental group assignment [default: 10000]
        params="$params --marker-sparsity 5000000" 

        # Sets minimal marker excess for assigning a parental group to <value>:1 [default: 5]
        params="$params --marker-ratio 4" 

        # Longer nodes are unlikely to be spurious and likely to be reliably assigned based on markers (used in HOMOZYGOUS node labeling) [default: 200000]
        params="$params --trusted-len 250000"
        
        # Require at least (node_length / <value>) markers for assigning ISSUE label (by default == marker_sparsity, will typically be set to a value >= marker_sparsity)
        # params="$params --issue-sparsity 10000" 

        # Require primary marker excess BELOW <value>:1 for assigning ISSUE label. Must be <= marker_ratio (by default == marker_ratio)
        # Apparently, Rukki is very bad at distinguishing between large HOM and ISSUE nodes, seemingly preferring to 
        # assign them to ISSUE. My guess is that setting this to 1 (essentially disabling ISSUE nodes) is an attempt to force it to consider more nodes as HOM 
        params="$params --issue-ratio 1."

        # Try to fill in small ambiguous bubbles
        params="$params --try-fill-bubbles"

        # Bubbles including a longer alternative sequence will not be filled [default: 50000]
        params="$params --fillable-bubble-len 500000"

        # Bubbles with bigger difference between alternatives' lengths will not be filled [default: 200]
        params="$params --fillable-bubble-diff 1000"

        # Longer nodes are unlikely to represent repeats, polymorphic variants, etc (used to seed and guide the path search) [default: 500000]
        params="$params --solid-len 500000" 
        
        # Solid nodes with coverage below <coeff> * <weighted mean coverage of 'solid' nodes> can not be classified as homozygous. 0. disables check [default: 1.5]
        params="$params --solid-homozygous-cov-coeff {params.shcc}"

        # Sets minimal marker excess for assigning a parental group of solid nodes to <value>:1. Must be <= marker_ratio (by default == marker_ratio)
        params="$params --solid-ratio 3"
        
        # Longer nodes can not be classified as homozygous [default: 2000000]
        params="$params --max-homozygous-len {params.mhl}"
        
        # Assign tangles flanked by solid nodes from the same class, # Allow dead-end nodes in the tangles
        params="$params --assign-tangles --tangle-allow-deadend"

        ($CONDA_PREFIX/lib/verkko/bin/rukki trio -g {input.graph} -m {input.markers}              -p {output.tsv} $params) > {log} 2>&1
        ($CONDA_PREFIX/lib/verkko/bin/rukki trio -g {input.graph} -m {input.markers} --gaf-format -p {output.gaf} $params) >> {log} 2>&1
        '''


##############################################
############ Hifiasm Yak Counting ############
##############################################

# Peter mentioned that this should be very quick to count, and I am therefore keeping it
# simple and running both haplotypes in a single rule

# FIXME This doesn't work, as the attempts to force rukki to assign HOM labels completely turn off the assignment of HOM
#  labels (--issue-ratio), and therefore the paths need to be processed instead of the annotations.
rule split_haplotype_annotations:
    """
    Split rukki's final annotation into two node lists: nodes labelled
    HAPLOTYPE1 or HOM go to hap1, nodes labelled HAPLOTYPE2 or HOM to hap2
    (HOM nodes therefore appear in both lists).
    """
    input:
        'rukki/{sample}/{sample}_out_final_ann.csv'
    output:
        hap1=temp('rukki/{sample}/{sample}_hap1_nodes.txt'),
        hap2=temp('rukki/{sample}/{sample}_hap2_nodes.txt')
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        awk -F'\t' '$2 == "HAPLOTYPE1" || $2 == "HOM" {{print $1}}' "{input}" > {output.hap1}
        awk -F'\t' '$2 == "HAPLOTYPE2" || $2 == "HOM" {{print $1}}' "{input}" > {output.hap2}
        '''

rule split_assembly_for_yak:
    """
    Extract the unitig sequences listed in each haplotype's node list from
    the unitig FASTA with `seqtk subseq`, giving one temporary FASTA per
    haplotype.
    """
    input:
        fasta= 'fasta/{sample}/{sample}_unitigs-hpc.fasta',
        hap1_nodes = 'rukki/{sample}/{sample}_hap1_nodes.txt',
        hap2_nodes = 'rukki/{sample}/{sample}_hap2_nodes.txt'
    output:
        hap1=temp('fasta/{sample}/{sample}_unitigs_hap1.fasta'),
        hap2=temp('fasta/{sample}/{sample}_unitigs_hap2.fasta')
    benchmark:
        "benchmark/split_assembly_for_yak_{sample}.benchmark"
    conda:
        'envs/env_yak.yaml'
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        seqtk subseq {input.fasta} {input.hap1_nodes} > {output.hap1}
        seqtk subseq {input.fasta} {input.hap2_nodes} > {output.hap2}
        '''

rule duplicate_sequences_for_yak:
    """
    Append each haplotype FASTA to its "_dup" output nine times (the loop
    runs for i = 1..9), so every sequence occurs nine times in the result.
    """
    input:
        hap1 = 'fasta/{sample}/{sample}_unitigs_hap1.fasta',
        hap2 = 'fasta/{sample}/{sample}_unitigs_hap2.fasta'
    output:
        hap1=temp('fasta/{sample}/{sample}_unitigs_hap1_dup.fasta'),
        hap2=temp('fasta/{sample}/{sample}_unitigs_hap2_dup.fasta')
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem()
    shell:
        '''
        for ((i=1; i < 10; i++)); do
            cat {input.hap1} >> {output.hap1}
            cat {input.hap2} >> {output.hap2}
        done
        '''

rule yak_count_split_fasta:
    """
    Run `yak count -b37` on the duplicated FASTA of each haplotype, one
    after the other within the same job; each call writes its own log.
    """
    input:
        hap1='fasta/{sample}/{sample}_unitigs_hap1_dup.fasta',
        hap2='fasta/{sample}/{sample}_unitigs_hap2_dup.fasta'
    output:
        hap1='yak/{sample}_hap1.yak',
        hap2='yak/{sample}_hap2.yak'
    log:
        hap1="log/yak_count_fasta_hap1_{sample}.log",
        hap2="log/yak_count_fasta_hap2_{sample}.log"
    benchmark:
        "benchmark/yak_count_fasta_{sample}.benchmark"
    conda:
        'envs/env_yak.yaml'
    threads:
        24
    resources:
        walltime=calc_walltime(),
        mem_mb=calc_mem(96)
    shell:
        '''
        yak count -b37 -t{threads} -o {output.hap1} {input.hap1} > {log.hap1} 2>&1
        yak count -b37 -t{threads} -o {output.hap2} {input.hap2} > {log.hap2} 2>&1
        '''
