"""
This script counts the number of unique pathways covered by each
individual and ensemble ADAGE model after crosstalk removal.
Output:
    1 .txt file containing tab-delimited rows formatted as follows:
    model type, # models (N), counts for the N models separated by semicolons
    Example:
      eADAGE same   9   63;65;61;62;61;60;63;65;65

Usage:
    pathway_counts_no_crosstalk.py <sigpathway-ens-dir> <sigpathway-indiv-dir>
    pathway_counts_no_crosstalk.py -h | --help

Options:
    -h --help                       Show this screen.
    <sigpathway-ens-dir>            Path to directory of significant pathway
                                    files, generated by
                                    `pathway_coverage_no_crosstalk.R`,
                                    for ensemble ADAGE models.
    <sigpathway-indiv-dir>          Path to directory of significant pathway
                                    files for individual ADAGE models.

"""
from docopt import docopt
import os
import csv
import pandas as pd

# NOTE(review): not referenced by the code in this file; presumably the
# number of nodes per ADAGE model -- confirm before relying on it.
NUM_NODES = 300

# Destination of the summary table, relative to the current working dir.
OUTPUT_FILE = os.path.join(
    "models_sigpathways_no_crosstalk",
    "MODEL_PATHWAY_COVERAGE_no_crosstalk.txt",
)

# Model-type label (first column of the output file) -> list of per-model
# pathway counts collected for that type. Each label gets its own fresh
# list; the tuple order fixes the dict's key order.
MODEL_PATHWAY_COVERAGE = {
    model_label: []
    for model_label in (
        "corADAGE same",
        "corADAGE different",
        "eADAGE same",
        "eADAGE different",
        "individual ADAGE",
    )
}


def _get_pathway_coverage(model_sigpathway_file):
    """Count the distinct pathways listed in one model's significant
       pathways file.

    The file is read as tab-separated text whose first line is the header;
    only its "pathway" column is loaded. The file is expected to have had
    crosstalk removed already by the upstream step that produced it.

    Returns:
        int: number of distinct values in the "pathway" column.
    """
    pathway_column = pd.read_csv(model_sigpathway_file,
                                 sep="\t",
                                 header=0,
                                 usecols=["pathway"])["pathway"]
    distinct_pathways = set(pathway_column)
    return len(distinct_pathways)


def _iter_model_files(directory):
    """Yield (filename, path) for each non-directory entry directly inside
       `directory`, in sorted filename order so results are reproducible
       (os.listdir returns entries in arbitrary order).
    """
    for filename in sorted(os.listdir(directory)):
        path_to_file = os.path.join(directory, filename)
        if not os.path.isdir(path_to_file):
            yield filename, path_to_file


def _ensemble_model_keys(filename):
    """Return the MODEL_PATHWAY_COVERAGE keys that an ensemble model's
       significant pathway file contributes to, based on its file name.
    """
    # "_ClusterByweight_" in the name selects corADAGE; anything else is
    # treated as eADAGE.
    if "_ClusterByweight_" in filename:
        prefix = "corADAGE "
    else:
        prefix = "eADAGE "
    if "_1_100_" in filename:
        if "_seed=1_" in filename:
            # This one model belongs to both the "same" and the
            # "different" groups.
            return [prefix + "same", prefix + "different"]
        return [prefix + "same"]
    return [prefix + "different"]


def update_pathway_coverage_data(sigpathway_ensemble_dir,
                                 sigpathway_individ_dir):
    """Populate the module-level MODEL_PATHWAY_COVERAGE dict, which is later
       written to the OUTPUT_FILE specified in this script.

    Every regular file in each directory is treated as one model's
    significant pathways file; its unique-pathway count is appended (as a
    string) to the list for that model's type. Subdirectories are skipped.
    Files are processed in sorted filename order so the output is
    reproducible across runs and filesystems.

    Args:
        sigpathway_ensemble_dir: directory of files for ensemble models
            (corADAGE and eADAGE); the model type is derived from each
            file name.
        sigpathway_individ_dir: directory of files for individual ADAGE
            models.
    """
    # ensemble models: both corADAGE and eADAGE
    for filename, path_to_file in _iter_model_files(sigpathway_ensemble_dir):
        num_pws = str(_get_pathway_coverage(path_to_file))
        for key in _ensemble_model_keys(filename):
            MODEL_PATHWAY_COVERAGE[key].append(num_pws)

    # individual ADAGE models
    for _, path_to_file in _iter_model_files(sigpathway_individ_dir):
        num_pws = str(_get_pathway_coverage(path_to_file))
        MODEL_PATHWAY_COVERAGE["individual ADAGE"].append(num_pws)

if __name__ == "__main__":
    arguments = docopt(__doc__, version=None)
    sigpathway_ensemble_dir = arguments["<sigpathway-ens-dir>"]
    sigpathway_individ_dir = arguments["<sigpathway-indiv-dir>"]

    update_pathway_coverage_data(
        sigpathway_ensemble_dir, sigpathway_individ_dir)

    # OUTPUT_FILE lives in a directory relative to the current working
    # directory; create that directory if it is missing so open() below
    # does not fail. (No exist_ok, to stay compatible with Python 2.)
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir and not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # One tab-delimited row per model type:
    #   label, number of models, semicolon-joined per-model counts
    with open(OUTPUT_FILE, "w") as fp:
        writer = csv.writer(fp, delimiter="\t")
        # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
        for label, counts_by_model in MODEL_PATHWAY_COVERAGE.items():
            N = len(counts_by_model)
            counts = ";".join(counts_by_model)
            writer.writerow((label, N, counts))
