import scanpy as sc
import sys
import numpy as np
import anndata
import scvelo as scv
import matplotlib.pyplot as plt
import pandas as pd
import cellrank as cr
import matplotlib.colors
import random
import statistics

#import palantir



#import plotly.express as px
# ---------------------------------------------------------------------------
# Disabled stage 1 (inert string literal: nothing until the closing triple
# quote executes). scVelo moments / dynamical velocity / latent time, saving
# the velocity AnnData, then CellRank velocity + connectivity kernels and GPCCA
# macrostates with initial and terminal state prediction.
# NOTE(review): `combined_kernel` is built but the GPCCA estimator is created
# from `vk` alone -- confirm which kernel was intended.
# ---------------------------------------------------------------------------
"""
print(np.__version__)

adata = anndata.read_h5ad("Integration_velocity_object/Integration_focus_with_velocity_metadata.h5ad")
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["lightgray", "mediumpurple","mediumslateblue", "blue"])

matplotlib.use('TkAgg')
scv.settings.set_figure_params('scvelo')

scv.pp.moments(adata, use_rep= 'X_pca', n_neighbors= 30, method="umap",mode="connectivities",n_pcs=None)
#Computes moments for velocity estimation (means and uncentered variances). Moments = shape and movements of an object

scv.tl.recover_dynamics(adata, n_top_genes=None, n_jobs= 10, fit_steady_states=True, fit_scaling=True, fit_time= True)
#Recovers the full splicing kinetics of specified genes.

scv.tl.velocity(adata, mode = 'dynamical',filter_genes=False)
scv.tl.rank_velocity_genes(adata, n_genes=100, groupby="labelling_digest")
#Rank genes for velocity characterizing groups.
df = scv.DataFrame(adata.uns['rank_velocity_genes']['names']).head()

scv.tl.velocity_graph(adata)
scv.pl.velocity_embedding_stream(adata, basis='umap', color="labelling_digest", legend_loc= 'right', size=10)
scv.pl.proportions(adata,groupby="labelling_semi_sup")
top_genes = adata.var['fit_likelihood'].sort_values(ascending=False).index
scv.pl.scatter(adata, basis=top_genes[:15], ncols=5, frameon=False)
scv.tl.latent_time(adata)
scv.pl.scatter(adata, color='latent_time', color_map='gnuplot', size=80)

df = adata.var
df = df[(df['fit_likelihood'] > .1) & df['velocity_genes'] == True]

kwargs = dict(xscale='log', fontsize=16)
with scv.GridSpec(ncols=3) as pl:
    pl.hist(df['fit_alpha'], xlabel='transcription rate', **kwargs)
    pl.hist(df['fit_beta'] * df['fit_scaling'], xlabel='splicing rate', xticks=[.1, .4, 1], **kwargs)
    pl.hist(df['fit_gamma'], xlabel='degradation rate', xticks=[.1, .4, 1], **kwargs)

scv.get_df(adata, 'fit*', dropna=True).head()

print("saving adata file")
adata.write("Integration_velocity_object/Integration_focus_with_velocity_metadata_with_hep_adult_with_mat_batch_corrected_velocity_done_with_all_genes.h5ad")

print(anndata.__version__)
random.seed(4321)
cr.logging.print_versions()
print("loading adata file")
adata= anndata.read_h5ad("Integration_velocity_object/Integration_focus_with_velocity_metadata_with_hep_adult_with_mat_batch_corrected_velocity_done_with_all_genes.h5ad")
print(adata)


vk = cr.kernels.VelocityKernel(adata)

vk.compute_transition_matrix( model= 'monte_carlo', similarity='cosine', backward_mode= 'transpose', n_jobs = 10)

### Compute matrix transition
from cellrank.tl.kernels import ConnectivityKernel
ck = cr.kernels.ConnectivityKernel(adata)
ck.compute_transition_matrix()

### Combining two kernels

combined_kernel = 0.87 * vk + 0.13 * ck #spliced unspliced proportions


# CellRank decomposes cellular dynamics into macrostates : Based on Generalized Perron Cluster Cluster Analysis (GPCCA), CellRank identifies macrostates of cellular dynamics which includes initial intermediate and terminal states. 
g = cr.estimators.GPCCA(vk)
print(g)

g.compute_macrostates(n_states = 7, cluster_key="labelling_semi_sup", n_cells=30)

print(scv.__version__)
print(cr.__version__)
random.seed(4321)

g.adata.obsm["macrostates_fwd_memberships"].colors = ["#ff7f0e","#ff0e0e","#0e22ff","#36ff0e","#ffef0e","#ff0ef3","#854c03"]

color_key = {
    'Hep_Adult_1':"#ff7f0e",
    'Hep_Adult_2' : "#854c03",
    'Hep_Adult_3' : "#0e22ff",
    'Hepatoblast_E11_1' : "#36ff0e",
    'BEC_Adult_DDC' : "#ffef0e",
    'Hepatoblast_E11_2' : "#ff0ef3",
    'BEC_Adult' : "#ff0e0e"
}



g.predict_terminal_states()

g.plot_macrostates(which="terminal", legend_loc="right", s=100 )


g.predict_initial_states(allow_overlap=True, n_states= 2)
g.plot_macrostates(which="initial", legend_loc="right", s=100)
sc.tl.score_genes(
    adata, gene_list=["Afp", "Cd44", "Itga6"], score_name="initial_score"  
)

# write macrostates to AnnData
adata.obs["macrostates"] = g.macrostates
#print(g.macrostates_memberships)# 
sc.pl.umap(adata, color = "macrostates")

"""

# ---------------------------------------------------------------------------
# Disabled stage 2 (inert string literal: nothing until the closing triple
# quote executes). Sets three terminal states, computes fate probabilities and
# draws per-cluster violin plots ordered by median fate probability, for both
# the "labelling_digest" and "labelling_semi_sup" groupings.
# NOTE(review): the obsm["lineages_fwd"] columns are renamed positionally with
# `labels`; confirm that lineage order matches the `states` list.
# ---------------------------------------------------------------------------
"""
import seaborn as sn

g.set_terminal_states(states=["BEC_Adult", "Hep_Adult_3","Hepatoblast_E11_1"], allow_overlap=True)

g.compute_fate_probabilities(n_jobs=10)

labels = ["BEC_Adult", "Hep_Adult_3","Hepatoblast_E11_1"]



fate_probs = pd.DataFrame(g.adata.obsm["lineages_fwd"])
fate_probs.columns = labels
fate_probs.index = g.adata.obs["labelling_semi_sup"].index
labelling_digest = pd.DataFrame(adata.obs["labelling_digest"])
labelling_semi_sup = pd.DataFrame(adata.obs["labelling_semi_sup"])
fate_probs_new = pd.concat([fate_probs,labelling_digest,labelling_semi_sup],axis=1)
print(fate_probs_new)

for i in ["BEC_Adult", "Hep_Adult_3","Hepatoblast_E11_1"] :
    order_digest = fate_probs_new.groupby(by = ["labelling_digest"])[i].median().sort_values().index
    cr.pl.aggregate_fate_probabilities(
        adata,
        mode="violin",
        lineages=[i],
        cluster_key="labelling_digest",
        order = order_digest
    )
    order_semi_sup = fate_probs_new.groupby(by = ["labelling_semi_sup"])[i].median().sort_values().index
    cr.pl.aggregate_fate_probabilities(
        adata,
        mode="violin",
        lineages=[i],
        cluster_key="labelling_semi_sup",
        order = order_semi_sup
    )
"""


# Load a previously saved CellRank GPCCA estimator; the active code below works
# on its `g.adata`. (The file name suggests the transition matrix had already
# been computed when it was saved.)
g = cr.estimators.GPCCA.read(
    fname="Integration_velocity_object/Integration_focus_with_velocity_metadata_with_hep_adult_with_mat_batch_corrected_velocity_done_with_all_genes_with_transition_matrix.h5ad"
)

# Light grey -> purple -> blue colormap.
# NOTE(review): not referenced by the active code in this section.
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    "", ["lightgray", "mediumpurple", "mediumslateblue", "blue"]
)

def read_file(file_path):
    """Return the lines of *file_path* with surrounding whitespace stripped.

    Args:
        file_path: path to a UTF-8 text file.

    Returns:
        list of str, one entry per line; blank lines become "".
    """
    # `with` guarantees the handle is closed even if reading fails.
    with open(file_path, encoding="utf-8") as handle:
        return [line.strip() for line in handle]

def concat_expression(gene_list, adata=None):
    """Return the mean expression of the genes in *gene_list* that are present.

    Gene names are matched exactly against ``adata.var_names`` (they are not
    interpreted as regular expressions); missing genes are skipped. A summary
    line ``found k / n genes`` is printed. A gene listed twice is counted twice.

    Args:
        gene_list: iterable of gene names.
        adata: AnnData-like object exposing ``var_names`` and
            ``adata[:, gene].X``. Defaults to the module-level ``g.adata``.

    Returns:
        Element-wise mean of the ``adata[:, gene].X`` matrices of the genes
        that were found.

    Raises:
        ValueError: if none of the requested genes is present.
    """
    if adata is None:
        adata = g.adata
    gene_list = list(gene_list)
    # Exact membership; a regex match could "find" a name that then fails to index.
    available = set(adata.var_names)
    found = [gene for gene in gene_list if gene in available]
    print("found ",len(found),"/",len(gene_list)," genes")
    if not found:
        raise ValueError("none of the requested genes are present in var_names")
    mean_genes = 0
    for gene in found:
        mean_genes += adata[:, gene].X
    return mean_genes / len(found)



def Extract_gene_expression(Integration_object, gene):
    """Return per-cell expression of *gene*, NaN-masking datasets that never express it.

    A dataset whose mean expression of *gene* is exactly 0 is treated as not
    measuring the gene: its zero values are replaced by NaN. Zero values in the
    other datasets stay 0. The per-dataset means and the list of masked datasets
    are printed.

    Args:
        Integration_object: AnnData-like object with an ``obs["dataset"]``
            column and ``Integration_object[:, gene].X`` returning a
            (n_cells, 1) dense or scipy-sparse matrix.
        gene: a single gene name.

    Returns:
        list of float, one value per cell in ``obs`` order.
    """
    expression = Integration_object[:, gene].X
    if hasattr(expression, "toarray"):  # scipy sparse matrix -> dense
        expression = expression.toarray()
    expression = np.asarray(expression)[:, 0]

    datasets = Integration_object.obs["dataset"]
    results_exp_labelling = pd.DataFrame({"Dataset": datasets, "Gene": expression})
    results_mean_exp = pd.DataFrame(
        results_exp_labelling.groupby(by=["Dataset"])["Gene"].mean().sort_values()
    )
    print(results_mean_exp)
    list_labels = results_mean_exp[results_mean_exp["Gene"] == 0.0].index.tolist()
    print(list_labels)

    # Vectorised and label-safe: no positional `Series[i]` lookups.
    undetected = (expression == 0.0) & datasets.isin(list_labels).to_numpy()
    return np.where(undetected, np.nan, expression).tolist()

 

######################################
############ Pseudotime ##############
######################################

# Clustered integration object carrying the Slingshot results.
# (`anndata.read` is a deprecated alias of `read_h5ad`, used elsewhere in this file.)
Integration_object = anndata.read_h5ad("Integration_velocity_object/integration_with_clustering.h5ad")

# Copy Slingshot columns onto the estimator's AnnData. pandas aligns these
# assignments on obs_names, so cells absent from Integration_object get NaN.
for slingshot_column in ("slingshot_lineage1", "slingshot_lineage2", "slingshot_lineage3", "slingshot_lineage4"):
    g.adata.obs[slingshot_column] = Integration_object.obs[slingshot_column]
g.adata.obs["Slingshot_pseudotime"] = Integration_object.obs["Pseudotime_slingshot"]


color_key = {
    'Hep_Adult_1':"#ff7f0e",
    'Hep_Adult_2' : "#854c03",
    'Hep_Adult_3' : "#0e22ff",
    'Hepatoblast_E11_1' : "#36ff0e",
    'BEC_Adult_DDC' : "#ffef0e",
    'Hepatoblast_E11_2' : "#ff0ef3",
    'BEC_Adult' : "#ff0e0e"
}


# MAGIC imputation of genes. The commented recipe below produced the CSV that is
# read next; `magic` itself is not called by the active code.
#import os
#os.environ['R_HOME'] = 'lib/R'
import magic
#magic_operator = magic.MAGIC(n_pca = 30, knn = 5, t = "auto")
#values_imputed_magic_all_dataset = magic_operator.fit_transform(g.adata.X , genes = "all_genes")
#values_imputed_magic_all_dataset = pd.DataFrame(values_imputed_magic_all_dataset, columns = g.adata.var_names, index = g.adata.obs_names)
#values_imputed_magic_all_dataset.to_csv("Integration_velocity_object/magic_imputed_data_entire_dataset.csv")

# Read the imputed table (cells x genes, per the to_csv recipe above) and keep
# only genes that have imputed values. `np.in1d` is deprecated since NumPy 2.0;
# `np.isin` is the 1-D equivalent.
values_imputed_magic_all_dataset = pd.read_csv("Integration_velocity_object/magic_imputed_data_entire_dataset.csv", index_col=0)
mask = np.isin(g.adata.var_names, values_imputed_magic_all_dataset.columns)
g.adata = g.adata[:, mask].copy()
# Reorder to the AnnData's own cell/gene order and store a plain array with the
# same (n_obs, n_vars) layout as X.
g.adata.layers["magic_imputed_data"] = values_imputed_magic_all_dataset.loc[g.adata.obs_names, g.adata.var_names].to_numpy()

# Extract cells from lineage 2: those with a non-missing slingshot_lineage2 value.
lin2 = g.adata.obs["slingshot_lineage2"].isna()
cells_to_keep = lin2.index[~lin2]
lineage2 = g.adata[cells_to_keep]

# Imputed expression of the lineage-2 cells. Indexing `.obs` with the imputed
# DataFrame is not a valid selection (pandas requires a boolean DataFrame key),
# so read the imputed layer instead.
# NOTE(review): `matrix` is not used below -- the histogram covers every cell of
# the imputed table; confirm whether it should be restricted to lineage 2.
matrix = np.asarray(lineage2.layers["magic_imputed_data"])

# Histogram (1000 bins) of all imputed values, drawn from precomputed counts.
counts_test, bins_test = np.histogram(values_imputed_magic_all_dataset, bins=1000)
plt.hist(bins_test[:-1], bins_test, weights=counts_test)
plt.show()


# ---------------------------------------------------------------------------
# Disabled stage 3 (inert string literal). MAGIC setup, relabelling of DDC
# lineage-2 cells by time point, and CellRank GAM gene-trend plots for cells
# whose slingshot_lineage2 value lies in [14, 17].
# NOTE(review): no closing triple quote for this string is visible here; if the
# file ends at this point Python reports an unterminated string literal
# (SyntaxError) -- verify the closing delimiter exists.
# NOTE(review): in both relabelling loops the second "Hep_Adult_DDC2wks" test is
# unreachable (it duplicates the first); it was probably meant to be
# "Hep_YFP_Adult_DDC2wks", which currently falls through to "BEC".
# ---------------------------------------------------------------------------
"""

import os
os.environ['R_HOME'] = '/usr/lib/R'

import magic
print(magic.__version__)
#import scprep
print(cr.__version__)
magic_operator = magic.MAGIC(n_pca = 30, knn = 5, t = "auto")
#magic_operator.set_params(n_pca = 30, knn = 5, t = 6)
#print(magic_operator)
ddc_slingshot2 = anndata.read("Integration_velocity_object/DDC_cells_slingshot_lineage2.h5ad")# read integration velocity slingshot lineage 2 ddc cells dims(2337 cells, 35589 genes)
cells_ddc_SL2 = ddc_slingshot2.obs["dataset"].index # extract cells name from this object
SL2_ddc = g.adata[cells_ddc_SL2] #extract only cells that are contained in lineage 2 and that are labellized as DDC cells (GSE157698)

gene_velo = np.array(SL2_ddc.var_names) # extract genes from SL2_ddc (11972 genes)

ddc_slingshot2_only_velo_genes = ddc_slingshot2[:, gene_velo] #remove genes that are not contained in this 11972 genes


new_labelling =[]
for i in SL2_ddc.obs["labelling_semi_sup"]:
    if i.startswith("Hep_Adult_DDC2wks"):
        new_labelling.append("Hep_DDC2wks")
    elif i.startswith("Hep_Adult_DDC3wks"):
        new_labelling.append("Hep_DDC3wks")
    elif i.startswith("Hep_Adult_DDC4wks"):
        new_labelling.append("Hep_DDC4wks")
    elif i.startswith("Hep_Adult_DDC6wks"):
        new_labelling.append("Hep_DDC6wks")
    elif i.startswith("Hep_Adult_DDC2wks"):
        new_labelling.append("Hep_DDC2wks")
    elif i.startswith("Hep_YFP_Adult_DDC3wks"):
        new_labelling.append("Hep_DDC3wks")
    elif i.startswith("Hep_YFP_Adult_DDC4wks"):
        new_labelling.append("Hep_DDC4wks")
    elif i.startswith("Hep_YFP_Adult_DDC6wks"):
        new_labelling.append("Hep_DDC6wks")
    else :
        new_labelling.append("BEC")

new_labelling = pd.DataFrame(new_labelling, index = SL2_ddc.obs_names)

SL2_ddc.obs["new_labelling"] = np.array(new_labelling)

sc.pl.umap(SL2_ddc, color = "new_labelling")

####To put at the end #####

cells_with_pseudotime_sup_14 = []
for i in range(len(ddc_slingshot2.obs["slingshot_lineage2"])):
    if ddc_slingshot2.obs["slingshot_lineage2"].values[i] >= 14 and ddc_slingshot2.obs["slingshot_lineage2"].values[i] <= 17:
        cells_with_pseudotime_sup_14.append(ddc_slingshot2.obs["slingshot_lineage2"].index[i])

object_from_lineage_2_only_14_17_pseudotime_values = g.adata[cells_with_pseudotime_sup_14]
values_imputed_magic = pd.read_csv("Integration_velocity_object/magic_imputed_data_pseudotime_14_17.csv", index_col=0)
mask = np.in1d(object_from_lineage_2_only_14_17_pseudotime_values.var_names, values_imputed_magic.columns)
object_from_lineage_2_only_14_17_pseudotime_values = object_from_lineage_2_only_14_17_pseudotime_values[:, mask].copy()
object_from_lineage_2_only_14_17_pseudotime_values.layers["magic_imputed_data"] = values_imputed_magic[object_from_lineage_2_only_14_17_pseudotime_values.var_names].loc[object_from_lineage_2_only_14_17_pseudotime_values.obs_names]
new_labelling =[]
for i in object_from_lineage_2_only_14_17_pseudotime_values.obs["labelling_semi_sup"]:
    if i.startswith("Hep_Adult_DDC2wks"):
        new_labelling.append("Hep_DDC2wks")
    elif i.startswith("Hep_Adult_DDC3wks"):
        new_labelling.append("Hep_DDC3wks")
    elif i.startswith("Hep_Adult_DDC4wks"):
        new_labelling.append("Hep_DDC4wks")
    elif i.startswith("Hep_Adult_DDC6wks"):
        new_labelling.append("Hep_DDC6wks")
    elif i.startswith("Hep_Adult_DDC2wks"):
        new_labelling.append("Hep_DDC2wks")
    elif i.startswith("Hep_YFP_Adult_DDC3wks"):
        new_labelling.append("Hep_DDC3wks")
    elif i.startswith("Hep_YFP_Adult_DDC4wks"):
        new_labelling.append("Hep_DDC4wks")
    elif i.startswith("Hep_YFP_Adult_DDC6wks"):
        new_labelling.append("Hep_DDC6wks")
    else :
        new_labelling.append("BEC")

new_labelling = pd.DataFrame(new_labelling, index = object_from_lineage_2_only_14_17_pseudotime_values.obs_names)

object_from_lineage_2_only_14_17_pseudotime_values.obs["new_labelling"] = np.array(new_labelling)

sc.pl.umap(object_from_lineage_2_only_14_17_pseudotime_values, color = "new_labelling")
modelSL2_only_some_pseudotime_values = cr.models.GAM(object_from_lineage_2_only_14_17_pseudotime_values, distribution = "gamma", link = "inverse",n_knots = 12)
cr.pl.gene_trends(
    object_from_lineage_2_only_14_17_pseudotime_values,
    model=modelSL2_only_some_pseudotime_values,
    lineages="BEC_Adult",
    data_key="magic_imputed_data",
    genes=["Sox9","Epcam","Hnf4a","Vim","Apof","Apob","Cd24a"],
    same_plot=True,
    ncols=2,
    time_key="slingshot_lineage2",
    cell_color="new_labelling",
    hide_cells=False,
    weight_threshold=(1e-3, 1e-3),
    size =5
)
plt.show()
##### Until there######
