"""Upload `kipoi test` predictions to S3 for all the test models

1. create a few major dependency groups
  - python 2 vs 3
  - keras 1 vs 2
2. Install these 4 environments (snakemake)
3. Run predictions in each
4. Upload to S3
5. Do the same in the new model group
6. Put the comparison to py.test
"""
import numpy as np
from tqdm import tqdm
import kipoi
import os
from kipoi.utils import read_txt
from kipoi.cli.source_test import all_models_to_test

# Model source registered in kipoi under the name "kipoi".
source = kipoi.get_source("kipoi")
# get all models to test
all_test_models = all_models_to_test(source)

# Model names that are left out of the `rule all` targets further down.
exclude = ['BassetGM12878_Demo', 'Basenji', 'extended_coda']
# Unused scratch code, kept commented out:
# all_model_groups = [m.split("/")[0] for m in all_test_models]
# env_names = [get_env_name(m) for m in all_test_models]


# Environment name -> model groups (the part of a model name before the
# first "/") whose predictions are generated inside that environment.
model_environments = {
    "kipoi-py27-keras2": [],
    "kipoi-py27-keras1.2": [],
    # keras == 1.2.2
    "kipoi-py3-keras1.2": [
        "CpGenie",
        "DeepCpG_DNA",
        "Divergent421",
    ],
    "kipoi-py3-keras2": [
        "Basenji",  # tf>=1.4
        "Basset",  # pytorch>=0.2
        "DeepBind",  # tf>=1.4
        "DeepSEA",  # pytorch>=0.2
        "FactorNet",  # tensorflow > 1.4.1
        "HAL",
        "KipoiSplice",
        "MaxEntScan",
        "SiSp",
        "labranchor",
        "lsgkm-SVM",
        "pwm_HOCOMOCO",
        "rbp_eclip",
        "CleTimer",
    ],
    "kipoi-MMSplice": ["MMSplice"],
}

# --------------------------------------------
# Functions


def get_env_name(model):
    """Look up the environment in which a model is tested.

    The model group is the part of `model` before the first "/". Returns the
    first key of `model_environments` whose list contains that group, or
    None when no environment lists it.
    """
    group, _, _ = model.partition("/")
    matching_envs = (
        env for env, groups in model_environments.items() if group in groups
    )
    return next(matching_envs, None)

def git_commit(source_folder):
    """Return the output line of `git rev-parse HEAD` run in `source_folder`.

    Prints a "Commit number:" header, runs the git command through kipoi's
    `_call_command` with the working directory switched to `source_folder`,
    and asserts that exactly one line of output came back.
    """
    from kipoi.conda import _call_command
    from kipoi.utils import cd

    print("Commit number:")
    git_args = ['rev-parse', 'HEAD']
    with cd(source_folder):
        # The first element of the returned pair is not used here.
        _, stdout_lines = _call_command(
            "git",
            git_args,
            use_stdout=True,
            return_logs_with_stdout=True,
        )
    assert len(stdout_lines) == 1
    commit = stdout_lines[0]
    return commit


def export_environemtn(model_environments, output_dir=None):
    """Export one environment file per non-empty entry of `model_environments`.

    For every environment name that has at least one model group assigned,
    `kipoi.cli.env.export_env` is called to write `<output_dir>/<env_name>.yaml`.
    Environments with an empty model list are skipped.

    Args:
      model_environments: dict mapping environment name -> list of model
        groups to be covered by that environment.
      output_dir: directory receiving the `.yaml` files; created if missing.
        Defaults to `~/.kipoi/env` of the current user.

    Note: deepTarget (theano 0.8.2, keras=0.3.3) is not covered by any of
    the environments listed in this file.
    """
    from kipoi.cli.env import export_env

    if output_dir is None:
        output_dir = os.path.expanduser("~/.kipoi/env")

    for env_name, models in model_environments.items():
        if not models:
            continue
        os.makedirs(output_dir, exist_ok=True)

        # Positional arguments are passed exactly as before: the model list,
        # None, and the source name 'kipoi'.
        export_env(models,
                   None,
                   'kipoi',
                   env_file=os.path.join(output_dir, f"{env_name}.yaml"),
                   env=env_name,
                   vep=True,
                   gpu=False)


# --------------------------------------------
# Snakemake part

from snakemake.remote.S3 import RemoteProvider as S3RemoteProvider
# Remote provider used to declare the S3 paths in the rules.
S3 = S3RemoteProvider()

# Result of `git rev-parse HEAD` in `source.local_path`; used as a path
# component of every predictions file so results are grouped per commit.
model_source_commit = git_commit(source.local_path)

# Aggregate target: one predictions.h5 on S3 per model in `all_test_models`.
# Models for which get_env_name() returns None (no environment assigned) and
# models listed in `exclude` are skipped.
rule all:
    input:
        [S3.remote(f"kipoi-models/predictions/{model_source_commit}/{model_name}/predictions.h5")
         for model_name in all_test_models if get_env_name(model_name) is not None and model_name not in exclude]


rule capture_predictions:
    """Run kipoi test <model> -o <file>
    for every kipoi model and export it to the cloud
    """
    output:
        # The doubled braces survive the f-string as the Snakemake wildcard
        # `{model_name}`; the commit hash is substituted immediately.
        output_file = S3.remote(f"kipoi-models/predictions/{model_source_commit}/{{model_name}}/predictions.h5")
    params:
        # Environment to activate, looked up from the model name.
        env_name = lambda wildcards: get_env_name(wildcards.model_name),
        # Basenji is run with batch size 2, every other model with 32.
        batch_size = lambda wildcards: 2 if wildcards.model_name == 'Basenji' else 32
    shell:
        """
        source activate {params.env_name}
        mkdir -p `dirname {output.output_file}`
        kipoi test {wildcards.model_name} \
          --batch_size={params.batch_size} \
          --source=kipoi \
          -o {output.output_file}
        """
