#!/usr/bin/python

import getopt
import os
import re
import shutil
import subprocess
import sys

import anndata
import loompy
import pandas as pd
import scvelo as scv

_USAGE = 'plotVelocity -l <loom> -n <name> -u <umap> -c <clusters> -o <outdir> -m <mode>'

# Matches the trailing "-<digits>" suffix that is removed from every cell ID
# before the ID is rewritten as "<name>:<barcode>x".
_BARCODE_SUFFIX = re.compile(r'-[0-9]+$')


def _split_list_arg(value):
    """Split a comma-separated option value into whitespace-stripped items."""
    return [item.strip() for item in value.split(",")]


def _fail_usage(message):
    """Report a command-line error, print the usage line and exit with status 2."""
    print(message, file=sys.stderr)
    print(_USAGE)
    sys.exit(2)


def _load_cell_table(path, sample_name):
    """Read a per-cell CSV and rewrite its IDs as '<sample_name>:<barcode>x'.

    The unnamed first CSV column is renamed to 'Cell ID'; these rewritten IDs
    are what gets matched against the loom's observation index later on.
    """
    table = pd.read_csv(path).rename(columns={'Unnamed: 0': 'Cell ID'})
    table['Cell ID'] = [
        '%s:%sx' % (sample_name, _BARCODE_SUFFIX.sub('', cell))
        for cell in table['Cell ID'].tolist()
    ]
    return table


def _collect_figure(file_name, outdir):
    """Move a figure that scvelo saved as ./figures/scvelo_<file_name> into outdir.

    shutil.move raises if the figure is missing, instead of silently carrying
    on the way an unchecked ``mv`` subprocess did.
    """
    shutil.move("./figures/scvelo_" + file_name, os.path.join(outdir, file_name))


def main(argv):
    """Compute RNA velocity for one or several samples and save three plots.

    Options (all but ``-m`` are required):
        -l/--loom      loom file, or a comma-separated list of loom files
        -n/--name      sample name(s), comma-separated; used as cell ID prefix
        -u/--umap      CSV file(s) with the embedding, one per name
        -c/--clusters  CSV file(s) with 'seurat_clusters' and
                       'cluster_colours' columns, one per name
        -o/--outdir    output directory (created if needed)
        -m/--mode      "merged" combines the loom files into
                       <outdir>/merged.loom and names the outputs "merged"

    Writes <name>_embedding.pdf, <name>_stream.png and <name>_UMAP.pdf to the
    output directory.

    Exits with status 2 on a command-line error and with status 1 when no cell
    of the loom file is found in the embedding and cluster tables.
    """
    flag_targets = {
        "-l": "loom", "--loom": "loom",
        "-n": "name", "--name": "name",
        "-u": "umap", "--umap": "umap",
        "-c": "clusters", "--clusters": "clusters",
        "-o": "outdir", "--outdir": "outdir",
        "-m": "mode", "--mode": "mode",
    }
    values = {target: '' for target in flag_targets.values()}

    try:
        opts, _ = getopt.getopt(
            argv, "hl:n:u:c:o:m:",
            ["loom=", "name=", "umap=", "clusters=", "outdir=", "mode="])
    except getopt.GetoptError:
        print(_USAGE)
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            print(_USAGE)
            sys.exit()
        if opt in flag_targets:
            values[flag_targets[opt]] = arg

    loom_file = values["loom"]
    name = values["name"]
    outdir = values["outdir"]
    mode = values["mode"]

    # Validate everything before touching the filesystem, so that a bad
    # invocation fails fast with a readable message.
    required = (("-l", loom_file), ("-n", name), ("-u", values["umap"]),
                ("-c", values["clusters"]), ("-o", outdir))
    missing = [flag for flag, value in required if not value]
    if missing:
        _fail_usage("Missing required option(s): " + ", ".join(missing))

    loom_files = _split_list_arg(loom_file)
    umap_files = _split_list_arg(values["umap"])
    clusters_files = _split_list_arg(values["clusters"])
    names = _split_list_arg(name)
    if len(umap_files) < len(names) or len(clusters_files) < len(names):
        _fail_usage(
            "Each of the %d name(s) needs its own umap and clusters file "
            "(got %d umap, %d clusters)"
            % (len(names), len(umap_files), len(clusters_files)))

    os.makedirs(outdir, exist_ok=True)

    if mode == "merged":
        name = "merged"
        loom_file = os.path.join(outdir, (name + ".loom"))
        loompy.combine(files=loom_files, output_file=loom_file, key="Accession")

    # NOTE(review): an earlier interactive session recorded here showed
    # read_loom warning that variable names are not unique; the names are
    # left untouched, as before.
    sample = anndata.read_loom(loom_file)

    # One embedding table and one cluster table per sample name, stacked.
    # zip() stops after len(names) entries, so surplus files are ignored.
    umap_tables = []
    cluster_tables = []
    for sample_name, umap_path, clusters_path in zip(names, umap_files, clusters_files):
        umap_tables.append(_load_cell_table(umap_path, sample_name))
        cluster_tables.append(_load_cell_table(clusters_path, sample_name))
    sample_umap = pd.concat(umap_tables)
    cell_clusters = pd.concat(cluster_tables)

    # Keep only cells described by BOTH tables; filtering on the embedding
    # alone lets the cluster merge below drop rows and break the obs shape.
    # Sets make each membership test O(1).
    annotated_ids = set(sample_umap['Cell ID']) & set(cell_clusters['Cell ID'])
    keep_cell = [cell for cell in sample.obs.index if cell in annotated_ids]
    if not keep_cell:
        sys.exit("No cell of %s is present in both the umap and clusters files"
                 % loom_file)
    # Materialise the subset: it is modified below, which a view does not allow
    # without an implicit copy.
    sample = sample[keep_cell].copy()

    # Left merges keep exactly the rows and order of the retained cells.
    ordered_cells = pd.DataFrame({'Cell ID': keep_cell})
    umap_ordered = ordered_cells.merge(sample_umap, on="Cell ID", how="left")
    sample.obsm['X_umap'] = umap_ordered.iloc[:, 1:].values

    sample.obs['Cell ID'] = sample.obs.index
    sample.obs = pd.merge(left=sample.obs, right=cell_clusters, on="Cell ID", how="left")
    # The merge resets the index, so restore the cell IDs as observation names.
    sample.obs.index = sample.obs['Cell ID']

    scv.pp.filter_and_normalize(sample)
    scv.pp.moments(sample)
    scv.tl.velocity(sample, mode="stochastic")
    scv.tl.velocity_graph(sample)

    # Palette keyed by the raw cluster label, by its string form, and by
    # whatever label type ended up in sample.obs.
    colours = cell_clusters[["seurat_clusters", "cluster_colours"]].drop_duplicates()
    colours_dict = dict(zip(colours.seurat_clusters, colours.cluster_colours))
    colours_dict.update(
        {str(cluster): colour
         for cluster, colour in zip(colours.seurat_clusters, colours.cluster_colours)})
    colours_dict.update(
        {key: colours_dict[str(key)] for key in set(sample.obs["seurat_clusters"])})

    embedding_name = name + "_embedding.pdf"
    scv.pl.velocity_embedding(sample, arrow_length=5, arrow_size=3, dpi=120,
                              color='seurat_clusters', palette=colours_dict,
                              save=embedding_name, show=False)
    _collect_figure(embedding_name, outdir)

    stream_name = name + "_stream.png"
    scv.pl.velocity_embedding_stream(sample, basis='umap', color='seurat_clusters',
                                     dpi=320, palette=colours_dict,
                                     save=stream_name, show=False)
    _collect_figure(stream_name, outdir)

    umap_name = name + "_UMAP.pdf"
    scv.pl.umap(sample, color='seurat_clusters', palette=colours_dict,
                save=umap_name, show=False)
    _collect_figure(umap_name, outdir)

# Script entry point: pass every command-line token after the program name to main().
if __name__ == "__main__":
    cli_arguments = sys.argv[1:]
    main(cli_arguments)
