#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pandas as pd
import glob

##### Load config and sample sheets #####
## Printed by Snakemake's onsuccess handler
onsuccess:
    print("buildHIC completed successfully!")

## Workflow settings; keys read further down include samplesheet, groupBy,
## ligation, site_file, chromSizes and javaMem
configfile: "config/config.yaml"

## Read in samplesheet and convert all columns to strings, so the grouping
## columns can be joined with '_' whatever dtype pandas inferred for them
samples = pd.read_table(config["samplesheet"]).astype(str)

## Concatenate columns to identify which groups to run (i.e. Seq_Rep will be run together)
samples['id'] = samples[config['groupBy']].agg('_'.join, axis=1)

## Use IDs as groups. unique() keeps first-appearance order, whereas
## list(set(...)) can order the strings differently on every interpreter run.
groups = samples['id'].unique().tolist()

## Group by id and extract the unique inter, inter30 & merged_nodups files.
## Duplicates are dropped in samplesheet order, so the file lists handed to the
## rules (and therefore the generated shell commands) are reproducible.
int1, int30, mergNoDups = (
    samples.groupby('id')[column].unique().apply(list).to_dict()
    for column in ('inter', 'inter30', 'merged_nodups')
)

##### Define rules #####
## Default target: request the two normalization marker files of every group
rule all:
    input:
        expand(
            "output/{group}/{group}_norm{ext}",
            group=groups,
            ext=["_done.txt", "30_done.txt"],
        )

## Merge all merged_nodups files listed for a group in mergNoDups into a single
## compressed file, then write a marker file once the merge has finished
rule mergedNoDups:
    input:
        lambda wildcards: mergNoDups[wildcards.group]
    output:
        merged_nodups = 'output/{group}/{group}_merged_nodups.txt.gz',
        done = 'output/{group}/{group}_mergedNoDups_done.txt'
    log:
        'output/{group}/logs/{group}_merged_nodups.err'
    benchmark:
        'output/{group}/benchmarks/{group}_merged_nodups.tsv'
    threads: 8
    shell:
        """
        module load pigz
        tmp_dir=$(mktemp -d -p "$PWD")
        # Remove the scratch directory on every exit path, including failures
        trap 'rm -rf "$tmp_dir"' EXIT
        # -m merges the inputs instead of re-sorting them; keys are columns 2 and 6
        cmd="sort --parallel={threads} -T '$tmp_dir' -m -k2,2d -k6,6d "
        # Append one process substitution per input so each file is decompressed on the fly
        for i in {input}; do cmd="$cmd <(gunzip -c '$i')"; done
        # Group the pipeline so stderr from gunzip and sort reaches the log, not only pigz's
        {{ eval "$cmd" | pigz -p {threads} > {output.merged_nodups}; }} 2> {log}
        touch {output.done}
        """

## Create statistics files for MQ > 0
## Step 1 runs makemega_addstats.awk over the group's inter files; step 2 streams
## the merged nodups through statistics.pl with -q 1 and the same -o file.
rule inter:
    input:
        inter = lambda wildcards: int1[wildcards.group],
        merged_nodups = rules.mergedNoDups.output.merged_nodups,
        merged_nodups_done = rules.mergedNoDups.output.done
    output:
        inter = "output/{group}/{group}_inter.txt",
        # NOTE(review): hists is never named in the shell command; presumably
        # statistics.pl derives it from the -o path -- confirm
        hists = "output/{group}/{group}_inter_hists.m",
        done = "output/{group}/{group}_inter_done.txt"
    params:
        # NOTE(review): ligation is declared but not referenced in the shell
        # command below -- confirm whether statistics.pl should receive it
        ligation = config["ligation"],
        site_file = config["site_file"]
    threads: 1
    log:
        err = "output/{group}/logs/{group}_inter.err"
    benchmark:
        "output/{group}/benchmarks/{group}_inter.tsv"
    shell:
        """
        awk -f ./scripts/makemega_addstats.awk {input.inter} > {output.inter} 2> {log.err}
        gunzip -c {input.merged_nodups} | ./scripts/statistics.pl -q 1 -o {output.inter} -s {params.site_file} 2>> {log.err}
        touch {output.done}
        """

## Create statistics files for MQ > 30
## Step 1 runs makemega_addstats.awk over the group's inter30 files; step 2 streams
## the merged nodups through statistics.pl with -q 30 and the same -o file.
rule inter30:
    input:
        inter30 = lambda wildcards: int30[wildcards.group],
        merged_nodups = rules.mergedNoDups.output.merged_nodups,
        merged_nodups_done = rules.mergedNoDups.output.done
    output:
        inter30 = 'output/{group}/{group}_inter_30.txt',
        # NOTE(review): hists30 is never named in the shell command; presumably
        # statistics.pl derives it from the -o path -- confirm
        hists30 = 'output/{group}/{group}_inter_30_hists.m',
        # Marker must be unique to this rule: rule inter already declares
        # {group}_inter_done.txt, and two rules must not claim the same output file
        done = 'output/{group}/{group}_inter30_done.txt'
    log:
        err = 'output/{group}/logs/{group}_inter30.err'
    benchmark:
        'output/{group}/benchmarks/{group}_inter30.tsv'
    params:
        # NOTE(review): ligation is declared but not referenced in the shell
        # command below -- confirm whether statistics.pl should receive it
        ligation = config['ligation'],
        site_file = config['site_file']
    threads: 1
    shell:
        """
        awk -f ./scripts/makemega_addstats.awk {input.inter30} > {output.inter30} 2> {log.err}
        gunzip -c {input.merged_nodups} | ./scripts/statistics.pl -q 30 -o {output.inter30} -s {params.site_file} 2>> {log.err}
        touch {output.done}
        """

## Create Hi-C maps file for MQ > 0
## Runs juicer_tools `pre` with -q 1 on the merged nodups, passing the rule inter
## statistics (-s) and histogram (-g) files; stdout and stderr go to separate logs.
rule hic:
    input:
        inter = rules.inter.output.inter,
        inter_done = rules.inter.output.done,
        hists = rules.inter.output.hists,
        merged_nodups = rules.mergedNoDups.output.merged_nodups,
        merged_nodups_done = rules.mergedNoDups.output.done
    output:
        hic = "output/{group}/{group}_inter.hic"
    params:
        chromSizes = config["chromSizes"],
        # Interpolated into -Xms/-Xmx with an 'm' suffix
        javaMem = config["javaMem"]
    threads: 1
    log:
        err = "output/{group}/logs/{group}_hic.err",
        out = "output/{group}/logs/{group}_hic.out"
    benchmark:
        "output/{group}/benchmarks/{group}_hic.tsv"
    shell:
        """
        export IBM_JAVA_OPTIONS="-Xmx{params.javaMem}m -Xgcthreads1"
        java -Djava.awt.headless=true -Djava.library.path=/lib64 -Xms{params.javaMem}m -Xmx{params.javaMem}m -jar ./scripts/juicer_tools.jar pre -n -s {input.inter} -g {input.hists} -q 1 {input.merged_nodups} {output.hic} {params.chromSizes} 1> {log.out} 2> {log.err}
        """

## Create Hi-C maps file for MQ > 30
## Runs juicer_tools `pre` with -q 30 on the merged nodups, passing the rule inter30
## statistics (-s) and histogram (-g) files; stdout and stderr go to separate logs.
rule hic30:
    input:
        inter30 = rules.inter30.output.inter30,
        inter30_done = rules.inter30.output.done,
        hists30 = rules.inter30.output.hists30,
        merged_nodups = rules.mergedNoDups.output.merged_nodups,
        merged_nodups_done = rules.mergedNoDups.output.done
    output:
        hic30 = "output/{group}/{group}_inter_30.hic"
    params:
        chromSizes = config["chromSizes"],
        # Interpolated into -Xms/-Xmx with an 'm' suffix
        javaMem = config["javaMem"]
    threads: 1
    log:
        err = "output/{group}/logs/{group}_hic30.err",
        out = "output/{group}/logs/{group}_hic30.out"
    benchmark:
        "output/{group}/benchmarks/{group}_hic30.tsv"
    shell:
        """
        export IBM_JAVA_OPTIONS="-Xmx{params.javaMem}m -Xgcthreads1"
        java -Djava.awt.headless=true -Djava.library.path=/lib64 -Xms{params.javaMem}m -Xmx{params.javaMem}m -jar ./scripts/juicer_tools.jar pre -n -s {input.inter30} -g {input.hists30} -q 30 {input.merged_nodups} {output.hic30} {params.chromSizes} 1> {log.out} 2> {log.err}
        """
	
## Create normalizations for MQ > 0
## Runs juicer_tools `addNorm` on the MQ > 0 .hic file. The only declared output
## is the marker file touched after the command succeeds.
rule norm:
    input:
        hic = rules.hic.output.hic
    output:
        done = "output/{group}/{group}_norm_done.txt"
    params:
        # Interpolated into -Xms/-Xmx with an 'm' suffix
        javaMem = config["javaMem"]
    threads: 1
    log:
        err = "output/{group}/logs/{group}_norm.err",
        out = "output/{group}/logs/{group}_norm.out"
    benchmark:
        "output/{group}/benchmarks/{group}_norm.tsv"
    shell:
        """
        export IBM_JAVA_OPTIONS="-Xmx{params.javaMem}m -Xgcthreads1"
        java -Djava.awt.headless=true -Djava.library.path=/lib64 -Xms{params.javaMem}m -Xmx{params.javaMem}m -jar ./scripts/juicer_tools.jar addNorm {input.hic} 1> {log.out} 2> {log.err}
        touch {output.done}
        """
	
## Create normalizations for MQ > 30
## Runs juicer_tools `addNorm` on the MQ > 30 .hic file. The only declared output
## is the marker file touched after the command succeeds.
rule norm30:
    input:
        hic30 = rules.hic30.output.hic30
    output:
        done = 'output/{group}/{group}_norm30_done.txt'
    log:
        err = 'output/{group}/logs/{group}_norm30.err',
        out = 'output/{group}/logs/{group}_norm30.out'
    benchmark:
        # Own benchmark file: rule norm already writes {group}_norm.tsv, and sharing
        # that path would let the two jobs overwrite each other's measurements
        'output/{group}/benchmarks/{group}_norm30.tsv'
    params:
        # Interpolated into -Xms/-Xmx with an 'm' suffix
        javaMem = config['javaMem']
    threads: 1
    shell:
        """
        export IBM_JAVA_OPTIONS="-Xmx{params.javaMem}m -Xgcthreads1"
        java -Djava.awt.headless=true -Djava.library.path=/lib64 -Xms{params.javaMem}m -Xmx{params.javaMem}m -jar ./scripts/juicer_tools.jar addNorm {input.hic30} 1> {log.out} 2> {log.err}
        touch {output.done}
        """
