#Package and data imports
import numpy as np
import scipy as sci
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import autogenes as ag

from sklearn.svm import NuSVR
import pickle

from custom_functions import find_high_var_genes

# Input / output locations.
# NOTE(review): the file names look swapped relative to the variable names
# ('lcm_...' feeds sc_ref_path, 'sc_...' feeds bulk_path) -- confirm that each
# path points at the intended table before trusting the results.
sc_ref_path = 'lcm_tbl_input_file.csv'
bulk_path = 'sc_tbl_input_file.csv'
export_path = 'deconv_liver_tumor_AutoGeneS_output_file.xlsx'

# Both tables are read with their first column as the row index; the
# filtering below treats those row labels as gene identifiers.
data_bulk_raw = pd.read_csv(bulk_path, index_col=0, sep=',')
sc_ref_df = pd.read_csv(sc_ref_path, index_col=0)

# keep only genes which appear in both bulk and sc data to avoid problems further down
keep_genes = [gene for gene in data_bulk_raw.index if gene in sc_ref_df.index]
if not keep_genes:
    # Fail here with a clear message instead of letting empty tables flow
    # into the gene selection and deconvolution steps.
    raise ValueError(
        f"No row labels are shared between {bulk_path!r} and {sc_ref_path!r}; "
        "check that both tables use the same gene identifiers."
    )
# subset both tables to the shared genes, in the bulk table's row order
data_bulk_raw = data_bulk_raw.loc[keep_genes]
sc_ref_df = sc_ref_df.loc[keep_genes]

# get a subset of sc_ref with only highly variable genes, aim for approx 4000
hv_genes = find_high_var_genes(sc_ref_df, disp_min=0.3)
# A bare `len(hv_genes)` is discarded when this file runs as a script, so
# report the count explicitly to make the ~4000 target checkable.
print(f"Number of highly variable genes: {len(hv_genes)}")

# Identification of informative genes for deconvolution with AutoGeneS
# choose the subset of the sc reference data restricted to the highly variable genes
sc_ref_sub = sc_ref_df.loc[hv_genes]

# Run AutoGeneS on the highly variable genes. The optimizer is asked for a
# fixed number of 700 features per solution (nfeatures=700, mode='fixed'),
# run for 5000 generations with a fixed seed.
# The table is transposed before being handed to AutoGeneS, so genes are
# passed as columns.
ag.init(sc_ref_sub.T)
ag.optimize(ngen=5000, seed=0, nfeatures=700, mode='fixed', offspring_size=100, verbose=False)
ag.plot(size='large',weights=(1,-1))

# select the solution with the lowest correlation
# NOTE(review): that description is the original author's; which solution
# index=0 refers to should be confirmed against the AutoGeneS documentation.
index = ag.select(index=0)

# subset signature data to selected genes for visualisation
# (presumably `index` is a boolean mask over the rows of sc_ref_sub -- confirm)
centroids_sc_pareto = sc_ref_sub[index]

# Correlation matrix between the columns of the selected signature table
# (np.corrcoef treats each row of its input as one variable, hence the transpose).
corr = pd.DataFrame(data = np.corrcoef(centroids_sc_pareto.T), columns = centroids_sc_pareto.columns, index = centroids_sc_pareto.columns)
# Upper-triangle mask; note that it is built here but not passed to the
# clustermap call below.
mask = np.zeros_like(corr)
mask[np.triu_indices_from(mask)] = True
# Clustered heatmap of the absolute correlations.
with sns.axes_style("white"):
    sns_plot =sns.clustermap(np.abs(corr),cmap=sns.color_palette("GnBu", 1000), robust=True)

# Non-negative least squares regression of the (transposed) bulk table with
# AutoGeneS; the result is labelled, normalised and exported further down.
coef_nnls = ag.deconvolve(data_bulk_raw.T, model='nnls')
def normalize_proportions(data, copy):
    """Clip negative values to zero and rescale every row to sum to one.

    Parameters
    ----------
    data : pandas.DataFrame
        Numeric table; each row is normalised independently.
    copy : bool
        If true, a copy of ``data`` is normalised and ``data`` itself is left
        untouched. Otherwise ``data`` is modified in place.

    Returns
    -------
    pandas.DataFrame
        The normalised table: the copy when ``copy`` is true, otherwise the
        very same object that was passed in.

    Notes
    -----
    * NaN entries are ignored when computing a row total and stay NaN.
    * A row whose clipped values sum to zero cannot be normalised and
      becomes NaN (0 / 0).
    * Rows are processed by position, so duplicated index labels are
      handled correctly (a label-based row loop would normalise duplicated
      rows against each other).
    """
    result = data.copy() if copy else data

    # Work on a plain float array so the arithmetic is positional and
    # vectorised instead of one label lookup per row.
    values = result.to_numpy(dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        # Negative coefficients have no meaning as proportions.
        values = np.where(values < 0, 0.0, values)
        row_totals = np.nansum(values, axis=1, keepdims=True)
        normalized = values / row_totals

    # Write back through iloc so that, with copy=False, the caller's frame
    # is updated in place and the same object is returned.
    result.iloc[:, :] = normalized
    return result

# Wrap the NNLS coefficients in a labelled table (rows take the bulk table's
# column labels, columns take the reference table's column labels) and
# normalise it in place.
proportions_nnls = normalize_proportions(
    pd.DataFrame(
        data=coef_nnls,
        index=data_bulk_raw.columns,
        columns=sc_ref_df.columns,
    ),
    copy=False,
)

# write the normalised proportions to the Excel export file
proportions_nnls.to_excel(export_path)