#!/usr/bin/env python3

import os
os.environ['OPENBLAS_NUM_THREADS'] = '1'
import numpy as np
import pandas as pd
from scipy import stats
from scipy import sparse
from scipy.special import expit, logit
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False

import keras 
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.layers import Layer
from keras.layers import Input, Dense, Dropout, LeakyReLU, Activation, Lambda, BatchNormalization
from keras.models import Model, load_model
import keras.backend as K
from keras.callbacks import TensorBoard
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import warnings
import argparse as argp
import sys
import subprocess as subp

# Module-level configuration.
# Silence warnings whose message starts with 'Discrepancy between'.
warnings.filterwarnings('ignore', message='Discrepancy between')
# Default size of the generator's input noise vector (gGAN `latent_dim` default).
LATENT_DIM = 128
# Default value of gGAN's `noise_rate` argument.
DEFAULT_NOISE_RATE = 0.035
# When True, model summaries and per-epoch/sample progress printing are suppressed.
quiet = False
# When True, performsigtesting leaves its temporary CSV files on disk instead of deleting them.
keepTmp = False

__version__ = "1.0"


class confusionMat:
    """
    Plain container for the rates derived from a binary confusion matrix.

    Attributes (all stored exactly as passed in):
        TPR: true positive rate (sensitivity)
        TNR: true negative rate (specificity)
        PPV: positive predictive value (precision)
        NPV: negative predictive value
        FPR: false positive rate
        FNR: false negative rate
        FDR: false discovery rate
        ACC: overall accuracy
    """
    def __init__(self, TPR, TNR, PPV, NPV, FPR, FNR, FDR, ACC):
        # Tuple assignment binds left to right, so attribute order is unchanged.
        self.TPR, self.TNR, self.PPV, self.NPV = TPR, TNR, PPV, NPV
        self.FPR, self.FNR, self.FDR, self.ACC = FPR, FNR, FDR, ACC

class MinibatchDiscrimination(Layer):
    """
    Keras layer that summarises how similar each sample is to the rest of its batch.

    The input is projected with a trainable matrix, reshaped to
    (batch_size, units, units_out), and for every pair of samples the absolute
    differences are summed over the `units` axis. The output feature for a
    sample is the sum over the batch of exp(-distance), giving a tensor of
    shape (batch_size, units_out).

    :param units: size of the axis the absolute differences are summed over
    :param units_out: number of output features per sample
    """
    def __init__(self, units=5, units_out=10, **kwargs):
        self._w = None  # projection weights; created in build()
        self._units = units
        self._units_out = units_out
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the projection matrix of shape (input_dim, units * units_out)."""
        kernel_shape = (input_shape[1], self._units * self._units_out)
        self._w = self.add_weight(name='w',
                                  shape=kernel_shape,
                                  initializer='uniform',
                                  trainable=True)
        super().build(input_shape)

    def call(self, x, **kwargs):
        """Return the minibatch similarity features. Shape=(batch_size, units_out)"""
        # Shape=(batch_size, units, units_out)
        projected = K.reshape(K.dot(x, self._w), (-1, self._units, self._units_out))
        # Shape=(units, units_out, batch_size)
        transposed = K.permute_dimensions(projected, [1, 2, 0])
        # Broadcasting yields every sample-vs-sample difference.
        # Shape=(batch_size, units, units_out, batch_size)
        pairwise = projected[..., None] - transposed[None, ...]
        # Shape=(batch_size, units_out, batch_size)
        l1_distance = K.sum(K.abs(pairwise), axis=1)
        # Shape=(batch_size, units_out)
        return K.sum(K.exp(-l1_distance), axis=-1)

    def compute_output_shape(self, input_shape):
        """Output keeps the batch dimension and has `units_out` features."""
        return (input_shape[0], self._units_out)

    def get_config(self):
        """Return the layer config including the constructor arguments `units` and `units_out`."""
        config = super().get_config().copy()
        config['units'] = self._units
        config['units_out'] = self._units_out
        return config


class gGAN:
    """
    GAN trained with a Wasserstein loss, a weight-clipped critic and RMSprop.

    The critic ("discriminator") ends in a linear Dense(1) layer and is trained
    with labels -1 for real samples and +1 for generated samples; its weights
    are clipped to [-clip_value, clip_value] after every critic update.
    """
    def __init__(self, data, network_num, latent_dim=LATENT_DIM, noise_rate=DEFAULT_NOISE_RATE,
                 discriminate_batch=False,
                 max_replay_len=None):
        """
        Initialize GAN
        :param data: expression matrix as a pandas DataFrame (its `.columns`, `.shape` and `.iloc`
               are used). Shape=(nb_samples, nb_genes)
        :param network_num: index of the network in the ensemble; hidden layer widths are
               multiples of (10 + network_num * 100)
        :param latent_dim: number input noise units for the generator
        :param noise_rate: stored on the instance; not otherwise used by this class
        :param discriminate_batch: whether to discriminate a whole batch of samples rather than discriminating samples
               individually. NOTE: Not implemented (raises NotImplementedError when True)
        :param max_replay_len: size of the replay buffer; stored on the instance, not otherwise used by this class
        """
        self._latent_dim = latent_dim
        self._noise_rate = noise_rate
        self._data = data
        self._gene_symbols = data.columns
        self._nb_samples, self._nb_genes = data.shape
        self._discriminate_batch = discriminate_batch
        self._max_replay_len = max_replay_len
        self._network_num = network_num

        # Following parameter and optimizer set as recommended in paper
        self.n_critic = 5
        self.clip_value = 0.01
        optimizer = RMSprop(lr=0.00005)

        # Build and compile the discriminator
        self.discriminator = self._build_discriminator()
        if self._discriminate_batch:
            # loss = self._minibatch_binary_crossentropy
            raise NotImplementedError

        self.discriminator.compile(loss=self.wasserstein_loss,
                                   optimizer=optimizer)
        self._gradients_discr = self._gradients_norm(self.discriminator)

        # Build the generator
        self.generator = self._build_generator()
        z = Input(shape=(self._latent_dim,))
        gen_out = self.generator(z)

        # Build the combined model; the critic is frozen inside it so that
        # training `combined` only updates the generator.
        self.discriminator.trainable = False
        valid = self.discriminator(gen_out)
        self.combined = Model(z, valid)
        self.combined.compile(loss=self.wasserstein_loss,
                              optimizer= optimizer)
        self._gradients_gen = self._gradients_norm(self.combined)

        # Initialize replay buffer with one batch of 32 generated samples
        noise = np.random.normal(0, 1, (32, self._latent_dim))
        x_fake = self.generator.predict(noise)
        self._replay_buffer = np.array(x_fake)

    def wasserstein_loss(self, y_true, y_pred):
        """
        Wasserstein loss: mean of label * critic output.
        :param y_true: labels (-1 or +1 in train())
        :param y_pred: critic outputs
        :return: scalar loss tensor
        """
        return K.mean(y_true * y_pred)

    def _build_generator(self):
        """
        Build the generator: three Dense/LeakyReLU/Dropout stages of increasing width
        followed by a linear Dense layer with one output per gene.
        :return: Keras model mapping (latent_dim,) noise to (nb_genes,) expressions
        """
        i = self._network_num
        noise = Input(shape=(self._latent_dim,))
        h = noise
        h = Dense(10 + i*100 )(h)
        h = LeakyReLU(0.3)(h)
        h = Dropout(0.1)(h)
        h = Dense((10 + i*100)*2)(h)
        h = LeakyReLU(0.3)(h)
        h = Dropout(0.1)(h)
        h = Dense((10 + i*100)*3)(h)
        h = LeakyReLU(0.3)(h)
        h = Dropout(0.1)(h)
        h = Dense(self._nb_genes)(h)
        model = Model(inputs=noise, outputs=h)
        if not quiet:
            model.summary()
        return model

    def _build_discriminator(self):
        """
        Build the discriminator (critic): Dropout, MinibatchDiscrimination, three
        Dense/LeakyReLU stages of decreasing width, Dropout and a linear Dense(1) output.
        :return: Keras model mapping (nb_genes,) expressions to a single score
        """
        i = self._network_num
        expressions_input = Input(shape=(self._nb_genes,))
        h = expressions_input
        h = Dropout(0.3)(h)
        h = MinibatchDiscrimination()(h)
        h = Dense((10 + i*100)*3)(h)
        h = LeakyReLU(0.3)(h)
        h = Dense((10 + i*100)*2)(h)
        h = LeakyReLU(0.3)(h)
        h = Dense(10 + i*100)(h)
        h = LeakyReLU(0.3)(h)
        h = Dropout(0.3)(h)
        h = Dense(1)(h)
        model = Model(inputs=expressions_input, outputs=h)
        if not quiet:
            model.summary()
        return model

    @staticmethod
    def _write_log(callback, names, logs, epoch):
        """
        Write log to TensorBoard callback
        :param callback: TensorBoard callback
        :param names: list of names for each log. Shape=(nb_logs,)
        :param logs: list of scalars. Shape=(nb_logs,)
        :param epoch: epoch number
        """
        for name, value in zip(names, logs):
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value
            summary_value.tag = name
            callback.writer.add_summary(summary, epoch)
            callback.writer.flush()

    @staticmethod
    def _gradients_norm(model):
        """
        Create Keras function to compute Euclidean norm of gradient vector
        :param model: Keras model for which the norm will be computed
        :return: Keras function to compute Euclidean norm of gradient vector
        """
        grads = K.gradients(model.total_loss, model.trainable_weights)
        summed_squares = tf.stack([K.sum(K.square(g)) for g in grads])
        norm = K.sqrt(K.sum(summed_squares))

        input_tensors = [model.inputs[0],  # input data
                         model.sample_weights[0],  # how much to weight each sample by
                         model._targets[0],  # labels
                         K.learning_phase(),  # train or test mode
                         ]

        return K.function(inputs=input_tensors, outputs=[norm])

    def train(self, epochs, file_name=None, batch_size=32):
        """
        Trains the GAN
        :param epochs: Number of epochs; each epoch runs n_critic critic updates and one generator update
        :param file_name: Currently unused; the model is not saved here (callers use save_model)
        :param batch_size: Batch size
        """
        # Adversarial ground truths for the Wasserstein loss: real -> -1, generated -> +1
        valid = -np.ones((batch_size, 1))
        fake = np.ones((batch_size, 1))

        for epoch in range(epochs):

            for _ in range(self.n_critic):
                # ----------------------
                #  Train Discriminator
                # ----------------------

                # Select random sample (with replacement)
                idxs = np.random.randint(0, self._nb_samples, batch_size)
                x_real = self._data.iloc[idxs, :]

                # This noise is NOT what produces x_fake (generate_batch draws its own);
                # the draw from the last critic iteration is reused for the generator step below.
                noise = np.random.normal(0, 1, (batch_size, self._latent_dim))
                x_fake = self.generate_batch(batch_size)

                # Train the critic (real samples labelled -1, generated samples labelled +1)
                d_loss_real = self.discriminator.train_on_batch(x_real, valid)
                d_loss_fake = self.discriminator.train_on_batch(x_fake, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # Clip critic weights
                for l in self.discriminator.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
                    l.set_weights(weights)

            # ----------------------
            #  Train Generator
            # ----------------------

            # Train the generator through the combined model using the "real" label
            g_loss = self.combined.train_on_batch(noise, valid)

            # ----------------------
            #  Plot the progress
            # ----------------------
            if not quiet:
                print('{} [D loss: {:.4f}] [G loss: {:.4f}]'.format(epoch, d_loss, g_loss))

    def generate_batch(self, batch_size=32):
        """
        Generate a batch of samples using the generator
        :param batch_size: Batch size
        :return: Artificial samples generated by the generator. Shape=(batch_size, nb_genes)
        """
        noise = np.random.normal(0, 1, (batch_size, self._latent_dim))
        pred = self.generator.predict(noise)

        return pred

    def discriminate(self, expr):
        """
        Scores a batch of samples with the discriminator
        :param expr: expressions matrix. Shape=(nb_samples, nb_genes)
        :return: for each sample, the raw discriminator output (the final layer is a linear
                 Dense(1), so this is an unbounded score, not a probability)
        """
        return self.discriminator.predict(expr)

    def save_model(self, outdir, name):
        """
        Saves the discriminator, generator and combined model as
        <outdir>/discr/<name>.h5, <outdir>/gen/<name>.h5 and <outdir>/gan/<name>.h5
        :param outdir: output directory; the three sub-directories are created if missing
        :param name: model id
        """
        # Create each directory independently so that one pre-existing directory
        # does not prevent the others from being created.
        for sub in ('discr', 'gen', 'gan'):
            os.makedirs('{}/{}'.format(outdir, sub), exist_ok=True)
        self.discriminator.trainable = True
        self.discriminator.save('{}/discr/{}.h5'.format(outdir, name))
        self.discriminator.trainable = False
        for layer in self.discriminator.layers:  # https://github.com/keras-team/keras/issues/9589
            layer.trainable = False
        self.generator.save('{}/gen/{}.h5'.format(outdir, name), include_optimizer=False)
        self.combined.save('{}/gan/{}.h5'.format(outdir, name))

    def load_model(self, inDir, name):
        """
        Loads the discriminator, generator and combined model written by save_model
        :param inDir: directory containing the discr/, gen/ and gan/ sub-directories
        :param name: model id
        """
        custom_objects = {'MinibatchDiscrimination': MinibatchDiscrimination,
                          'wasserstein_loss': self.wasserstein_loss}
        # 'GeneWiseNoise' is only registered when such a class actually exists in this
        # module; referencing it unconditionally raises NameError when it is not defined.
        legacy_noise_layer = globals().get('GeneWiseNoise')
        if legacy_noise_layer is not None:
            custom_objects['GeneWiseNoise'] = legacy_noise_layer

        self.discriminator = load_model('{}/discr/{}.h5'.format(inDir, name),
                                        custom_objects=custom_objects)
        self.generator = load_model('{}/gen/{}.h5'.format(inDir, name),
                                    custom_objects=custom_objects,
                                    compile=False)
        self.combined = load_model('{}/gan/{}.h5'.format(inDir, name),
                                   custom_objects=custom_objects)
        self._latent_dim = self.generator.input_shape[-1]


def normalize(expr, kappa=1):
    """
    Centre every gene (column) on 0 and scale it to a standard deviation of 1/kappa.

    Genes whose standard deviation is exactly 0 come out as NaN, because their
    std is replaced by NaN before the division.

    :param expr: matrix of gene expressions. Shape=(nb_samples, nb_genes)
    :param kappa: inverse of the target standard deviation
    :return: normalized expressions, same shape as expr
    """
    gene_mean = np.mean(expr, axis=0)
    gene_std = np.std(expr, axis=0)
    # Constant genes would divide by zero; mark them as NaN instead.
    gene_std[gene_std == 0] = np.nan
    centred = expr - gene_mean
    return centred / (kappa * gene_std)


def restore_scale(expr, mean, std):
    """
    Map standardized expressions back to the given per-gene scale: expr * std + mean.

    :param expr: matrix of gene expressions. Shape=(nb_samples, nb_genes)
    :param mean: vector of gene means. Shape=(nb_genes,)
    :param std: vector of gene stds. Shape=(nb_genes,)
    :return: Rescaled gene expressions
    """
    stretched = expr * std
    shifted = stretched + mean
    return shifted


def clip_outliers(expr, gene_means=None, gene_stds=None, std_clip=2):
    """
    Clips gene expressions of samples in which the gene deviates more than std_clip standard deviations from the gene mean.

    Two steps are applied, in this order:
      1. values above gene_mean + std_clip * gene_std are set to that upper bound, values
         below gene_mean - std_clip * gene_std are set to that lower bound;
      2. every position whose ORIGINAL value was negative is set to 0.
    Comparisons against NaN bounds are False, so such entries are left unchanged by step 1.
    The input is not modified; the result has the dtype of the input array.

    :param expr: np.array of expression data with Shape=(nb_samples, nb_genes)
    :param gene_means: np.array with the mean of each gene. Shape=(nb_genes,). If None, it is computed from expr
    :param gene_stds: np.array with the std of each gene. Shape=(nb_genes,). If None, it is computed from expr
    :param std_clip: Number of standard deviations for which the expression of a gene will be clipped
    :return: clipped expression matrix
    :raises ValueError: if expr is not 2-dimensional
    """
    expr = np.asarray(expr)
    if expr.ndim != 2:
        raise ValueError('expr must be 2-dimensional (nb_samples, nb_genes), got shape {}'.format(expr.shape))

    # Find gene means and stds
    if gene_means is None:
        gene_means = np.mean(expr, axis=0)
    if gene_stds is None:
        gene_stds = np.std(expr, axis=0)

    # Compute the per-gene bounds once and broadcast them over all samples.
    spread = std_clip * np.asarray(gene_stds)
    upper = np.broadcast_to(np.asarray(gene_means) + spread, expr.shape)
    lower = np.broadcast_to(np.asarray(gene_means) - spread, expr.shape)

    # Mirror the if/elif order: the lower bound only applies where the upper one did not.
    too_high = expr > upper
    too_low = ~too_high & (expr < lower)

    clipped = np.array(expr)  # copy, keeps the input dtype
    clipped[too_high] = upper[too_high]
    clipped[too_low] = lower[too_low]

    # Clip samples below zero to zero (decided on the original values)
    clipped[expr < 0] = 0
    return clipped


def nb_sig_test(expr=None, mean=None, std=None):
    """
    Makes each gene j have mean_j and std_j (computes expr * std + mean).

    The previous zero-argument form read `expr`, `std` and `mean` without
    defining them and therefore could not run; they are now optional
    parameters so the zero-argument call still fails, but with a clear message.

    :param expr: matrix of gene expressions. Shape=(nb_samples, nb_genes)
    :param mean: vector of gene means. Shape=(nb_genes,)
    :param std: vector of gene stds. Shape=(nb_genes,)
    :return: Rescaled gene expressions
    :raises TypeError: if any of expr, mean or std is not supplied
    """
    if expr is None or mean is None or std is None:
        raise TypeError('nb_sig_test requires expr, mean and std')
    return expr * std + mean


def formatdataforgen(df, condition, num_samples_to_build_on, isLogNorm, seed=None):
    """
    Select the training samples of one condition and list the genes that must be dropped.

    :param df: DataFrame with genes as rows and samples as columns; its first row holds
               the condition label of every sample
    :param condition: label value selecting which samples to keep
    :param num_samples_to_build_on: number of samples (columns) to draw at random
    :param isLogNorm: if falsy, the NaN check is done on the output of normalize(..., kappa=2);
                      otherwise it is done on the raw values
    :param seed: random_state passed to DataFrame.sample
    :return: (expr_train, todrop) where expr_train is the UN-normalized sampled data with
             Shape=(nb_samples, nb_genes) and todrop is the list of gene names having at
             least one NaN in the checked data
    """
    # Keep the samples labelled `condition`, then remove the label row itself.
    df = df.loc[:, df.iloc[0] == condition]
    df = df.drop(df.index[0])
    # Random sample of patients to build GAN on
    expr_train = df.sample(n=num_samples_to_build_on, axis=1, random_state=seed)
    # NOTE: all-zero samples are NOT removed. The statement that used to sit here
    # selected the non-zero columns but discarded the result, so it had no effect;
    # it has been removed rather than activated because callers rely on the sample count.
    expr_train = expr_train.T
    if not isLogNorm:
        data = normalize(expr_train, kappa=2)
    else:
        data = expr_train
    todrop = list(data.columns[data.isnull().any()]) #Drop NAs
    if not quiet:
        print('Samples for condition {}:\n'.format(condition)+
                '\n'.join(expr_train.index.to_list()))
    return (expr_train, todrop)


def gensynthdata(ggan, sizeofset, popMean, popStd, geneNames, isLogNorm):
    """
    Generate a synthetic cohort from a trained GAN.

    With isLogNorm falsy the generated batch is standardized with normalize(),
    rescaled with the population std/mean and passed through clip_outliers();
    with isLogNorm truthy the generated batch is returned as is.

    :param ggan: object exposing generate_batch(n)
    :param sizeofset: number of synthetic samples to generate
    :param popMean: per-gene population means (its `.values` are used); only read when isLogNorm is falsy
    :param popStd: per-gene population stds (its `.values` are used); only read when isLogNorm is falsy
    :param geneNames: iterable of gene names, returned as a list
    :param isLogNorm: selects the branch described above
    :return: (synthetic expressions transposed to genes x samples, list of gene names)
    """
    generated = ggan.generate_batch(sizeofset)
    if isLogNorm:
        synth_expression = generated
    else:
        standardized = normalize(generated) #TO DO: ask Mike about this
        rescaled = standardized * popStd.values + popMean.values
        synth_expression = clip_outliers(rescaled, gene_means=None, gene_stds=None, std_clip=2)
    return (synth_expression.T, list(geneNames))


def formatforsigtestingbycondition(condition_n, genes_condition_n):
    """
    Wrap generated expressions in a DataFrame indexed by gene name and sorted by that index.

    :param condition_n: expression matrix with one row per gene
    :param genes_condition_n: gene names, one per row of condition_n
    :return: DataFrame sorted by gene name
    """
    frame = pd.DataFrame(condition_n, index=genes_condition_n)
    return frame.sort_index()


# Perform significance testing
def performsigtesting(sorted_condition_0, sorted_condition_1, outname, useDESeq, geneMins, geneScale, isLogNorm):
    """
    Run the external R significance test (sigtest.R, located next to this file) on two cohorts.

    Both cohorts are written to temporary CSV files, Rscript is invoked, and the
    resulting <outname>_tmp_pval.csv is read back. Temporary files are removed
    afterwards unless the module-level `keepTmp` flag is set.

    :param sorted_condition_0: DataFrame of condition-0 expressions (genes as rows)
    :param sorted_condition_1: DataFrame of condition-1 expressions (genes as rows)
    :param outname: prefix for all temporary files; also passed to the R script
    :param useDESeq: if truthy the R script is given '1' and DESeq2-style columns are read,
                     otherwise '0' and edgeR-style columns
    :param geneMins: per-gene offsets added back when isLogNorm is truthy
    :param geneScale: per-gene scale factors multiplied back when isLogNorm is truthy
    :param isLogNorm: if truthy, each cohort is transformed to 2 ** (x * geneScale + geneMins) first
    :return: DataFrame indexed by the first CSV column with columns
             ['pvalue', 'padj', 'lfc', 'delta', 'rank']; NaN pvalue/padj are replaced by 1
    """
    def _undo_log_scaling(frame):
        # Reverse the 0-1 scaling and the log2 step.
        rescaled = frame.mul(geneScale, axis='rows')
        rescaled = rescaled.add(geneMins, axis='rows')
        return rescaled.rpow(2)

    if isLogNorm:
        sorted_condition_0 = _undo_log_scaling(sorted_condition_0)
        sorted_condition_1 = _undo_log_scaling(sorted_condition_1)

    cond0_csv = outname+"_tmp_cond0.csv"
    cond1_csv = outname+"_tmp_cond1.csv"
    pval_csv = outname+"_tmp_pval.csv"
    sorted_condition_0.to_csv(cond0_csv, index = True, header=True)
    sorted_condition_1.to_csv(cond1_csv, index = True, header=True)

    rscript = os.path.dirname(os.path.abspath(__file__))+'/sigtest.R'
    DESeq = '1' if useDESeq else '0'
    args = ['Rscript', rscript, outname+'_tmp_cond0.csv', outname+'_tmp_cond1.csv', '0', outname, DESeq]
    subp.check_call(args, stdout=sys.stdout)

    pvalues = pd.read_csv(pval_csv)
    pvalues = pvalues.set_index(pvalues.columns[0])
    # The two R back ends name their result columns differently.
    if useDESeq:
        wanted = ['pvalue','padj','log2FoldChange','absDiff','deRank']
    else:
        wanted = ['PValue','FDR','logFC','absDiff','deRank']
    pvalues = pvalues.loc[:, wanted]
    pvalues.columns = ['pvalue','padj','lfc','delta','rank']
    pvalues[['pvalue','padj']] = pvalues[['pvalue','padj']].fillna(value=1)

    if not keepTmp:
        tmpfiles = [cond0_csv, cond1_csv, pval_csv, outname+"_tmp_true_pval.csv", outname+'_real_expr.csv']
        for fname in tmpfiles:
            if os.path.exists(fname):
                os.remove(fname)
    return(pvalues)



def Cleanupsiggenelist(pvalues, alpha, minGE, minfold):
    """
    Keep the rows that pass all three significance filters.

    :param pvalues: DataFrame with columns 'delta', 'lfc' and 'padj'
    :param alpha: rows need padj strictly below this value
    :param minGE: rows need |delta| strictly above this value
    :param minfold: rows need |lfc| strictly above this value
    :return: filtered DataFrame
    """
    passes_delta = abs(pvalues['delta']) > minGE
    passes_lfc = abs(pvalues['lfc']) > minfold
    passes_padj = pvalues['padj'] < alpha
    keep = passes_delta & passes_lfc & passes_padj
    return pvalues[keep]


def Getrealsiggenelist(pReal, alpha, minGE, minfold):
    """
    Keep the rows of the real-data results that pass all three significance filters.

    :param pReal: DataFrame with columns 'delta', 'lfc' and 'padj'
    :param alpha: rows need padj strictly below this value
    :param minGE: rows need |delta| strictly above this value
    :param minfold: rows need |lfc| strictly above this value
    :return: filtered DataFrame
    """
    big_enough_delta = abs(pReal['delta']) > minGE
    big_enough_lfc = abs(pReal['lfc']) > minfold
    significant = pReal['padj'] < alpha
    selected = big_enough_delta & big_enough_lfc & significant
    return pReal[selected]


def getfinaloutput(outName, final_df_sig, final_test_df_sig_real, final_test_df):
    """
    Compare the significant genes found on generated data with those found on real data.

    A gene is a true positive if it is significant in both lists, a false positive if
    only in the generated list, a false negative if only in the real list, and a true
    negative otherwise. True negatives are total - TP - FP - FN (this assumes every
    significant gene is among the genes counted in final_test_df).

    Rates whose denominator is zero come out as nan/inf (numpy float division).

    :param outName: unused; kept for interface compatibility
    :param final_df_sig: DataFrame whose 'gene' column lists genes significant in generated data
    :param final_test_df_sig_real: DataFrame whose index lists genes significant in real data
    :param final_test_df: DataFrame with one row per tested gene (only its row count is used)
    :return: tuple (TPR, TNR, PPV, NPV, FPR, FNR, FDR, ACC)
    """
    # Sets give exact membership counts regardless of duplicates or falsy gene ids.
    generative_sig = set(final_df_sig['gene'])
    real_sig = set(final_test_df_sig_real.index)

    TP = len(generative_sig & real_sig)
    FP = len(generative_sig - real_sig)
    FN = len(real_sig - generative_sig)
    total_genes = final_test_df.shape[0]
    # Genes not called significant by either analysis.
    TN = total_genes - TP - FP - FN

    # Sensitivity, hit rate, recall, or true positive rate
    TPR = np.float64(TP)/(TP+FN)
    # Specificity or true negative rate
    TNR = np.float64(TN)/(TN+FP)
    # Precision or positive predictive value
    PPV = np.float64(TP)/(TP+FP)
    # Negative predictive value
    NPV = np.float64(TN)/(TN+FN)
    # Fall out or false positive rate
    FPR = np.float64(FP)/(FP+TN)
    # False negative rate
    FNR = np.float64(FN)/(TP+FN)
    # False discovery rate
    FDR = np.float64(FP)/(TP+FP)
    # Overall accuracy
    ACC = np.float64(TP+TN)/(TP+FP+FN+TN)

    print('TP={}\nTN={}\nFP={}\nFN={}'.format(TP,TN,FP,FN))
    print('Sensitivity=',TPR)
    print('Specificity=',TNR)
    print('PPV=',PPV)
    print('NPV=',NPV)
    print('False Positive Rate=',FPR)
    print('False Negative Rate=',FNR)
    print('False Discovery Rate=',FDR)
    print('Overall Accuracy=',ACC)
    return(TPR,TNR,PPV,NPV,FPR,FNR,FDR,ACC)
    

def getfinaloutputSepMultiGAN(outName, generative_sig_list, real_sig_list, totNumGenes):
    """
    Compare the significant genes found on generated data with those found on real data.

    A gene is a true positive if it is in both lists, a false positive if only in the
    generated list, a false negative if only in the real list, and a true negative
    otherwise. True negatives are totNumGenes - TP - FP - FN (this assumes every
    listed gene is among the totNumGenes tested genes).

    Rates whose denominator is zero come out as nan/inf (numpy float division).

    :param outName: unused; kept for interface compatibility
    :param generative_sig_list: iterable of genes significant in generated data
    :param real_sig_list: iterable of genes significant in real data
    :param totNumGenes: total number of genes tested
    :return: confusionMat holding TPR, TNR, PPV, NPV, FPR, FNR, FDR and ACC
    """
    genSet = set(generative_sig_list)
    realSet = set(real_sig_list)

    TP = len(genSet & realSet)
    FP = len(genSet - realSet)
    FN = len(realSet - genSet)
    # Genes not called significant by either analysis.
    TN = totNumGenes - TP - FP - FN

    # Sensitivity, hit rate, recall, or true positive rate
    TPR = np.float64(TP)/(TP+FN)
    # Specificity or true negative rate
    TNR = np.float64(TN)/(TN+FP)
    # Precision or positive predictive value
    PPV = np.float64(TP)/(TP+FP)
    # Negative predictive value
    NPV = np.float64(TN)/(TN+FN)
    # Fall out or false positive rate
    FPR = np.float64(FP)/(FP+TN)
    # False negative rate
    FNR = np.float64(FN)/(TP+FN)
    # False discovery rate
    FDR = np.float64(FP)/(TP+FP)
    # Overall accuracy
    ACC = np.float64(TP+TN)/(TP+FP+FN+TN)

    print('TP={}\nTN={}\nFP={}\nFN={}'.format(TP,TN,FP,FN))
    print('Sensitivity=',TPR)
    print('Specificity=',TNR)
    print('PPV=',PPV)
    print('NPV=',NPV)
    print('False Positive Rate=',FPR)
    print('False Negative Rate=',FNR)
    print('False Discovery Rate=',FDR)
    print('Overall Accuracy=',ACC)
    return(confusionMat(TPR,TNR,PPV,NPV,FPR,FNR,FDR,ACC))


def SepMultirunGAN(inputFileName, popFileName, outName,
                   epochs, batchsize0, batchsize1, minExprMean, minExprDev,
                   numbOfNetworks, isLogNorm, useDESeq, saveTrained, saveSynthExpr, 
                   seed=None):
    """
    Train an ensemble of GAN pairs (one per condition), generate synthetic cohorts and
    run differential expression testing between them.

    :param inputFileName: CSV with gene names in the first column, one column per sample and
                          the condition label (0/1) of every sample in the first data row
    :param popFileName: CSV of population expression data (gene names in the first column);
                        if falsy, the training data itself is used as the population
    :param outName: prefix for every output file
    :param epochs: Number of epochs to train GAN
    :param batchsize0: Number of synthetic samples to generate for condition 0
    :param batchsize1: Number of synthetic samples to generate for condition 1
    :param minExprMean: Remove genes from input data with mean expression less than value
    :param minExprDev: Remove genes from input data with average absolute deviation less than value
    :param numbOfNetworks: number of GAN pairs to train (must be at least 1)
    :param isLogNorm: if truthy, every gene is shifted/scaled to the 0-1 range before training,
                      the random forest step is skipped, and the scaling is undone before testing
    :param useDESeq: passed to performsigtesting to select the R back end
    :param saveTrained: if truthy, every trained model is saved as '<outName>.trained.cond<c>.<i>'
    :param saveSynthExpr: if truthy, the synthetic cohorts of the LAST network are written to
                          '<outName>_synthExpr.csv'
    :param seed: optional seed for numpy, keras, sampling and the random forest
    :return: DataFrame indexed by gene with columns 'count' (number of networks reporting
             the gene) and 'rank' (mean rank over the networks)
    :raises ValueError: if numbOfNetworks is smaller than 1
    """
    if numbOfNetworks < 1:
        raise ValueError('numbOfNetworks must be at least 1')

    if seed is not None:
        np.random.seed(seed)
        keras.utils.set_random_seed(seed)

    # Load Data
    df = pd.read_csv(inputFileName)
    if df.isnull().values.any():
        print('WARNING: NaN expression values detected in input, removing automatically')
        df = df.dropna()

    # Merge duplicate gene names (the label row is set aside while summing)
    df = df.set_index(df.columns[0])
    cancerID = df.iloc[0]
    df = df.drop(index = df.index[0])
    df = df.groupby([df.index]).sum()
    df = pd.concat([pd.DataFrame(cancerID).T, df])

    #Set first index name to "Group"
    df = df.rename(index={df.index[0]: 'Group'})

    #Filter input by mean and avg absolute deviation of expression:
    df = df[((df.mean(axis=1) > minExprMean) & 
             (abs(df.sub(df.mean(axis=1), axis=0)).mean(axis=1) > minExprDev)) | 
            (df.index =='Group')]

    if isLogNorm:
        # set min of all genes to 0, saving offsets
        # normalize all gene to 0-1, saving max values
        geneMins = df.iloc[1:,:].min(axis=1)
        df.iloc[1:,:] = df.iloc[1:,:].sub(geneMins, axis='rows')
        geneScale = df.iloc[1:,:].max(axis='columns')
        df.iloc[1:,:] = df.iloc[1:,:].div(geneScale, axis='rows')

    # All samples of each condition are used for training
    numberC0 = sum(df.iloc[0]==0)
    numberC1 = sum(df.iloc[0]==1)

    # Format input data for training GAN.
    # isLogNorm and seed are passed explicitly: passing seed as the 4th positional
    # argument would land in the isLogNorm slot and leave the sampling unseeded.
    expr_train0, todrop0 = formatdataforgen(df, 0, numberC0, isLogNorm, seed=seed)
    expr_train1, todrop1 = formatdataforgen(df, 1, numberC1, isLogNorm, seed=seed)
    expr_train0 = expr_train0.drop(todrop0 + todrop1, axis=1, errors='ignore')
    expr_train1 = expr_train1.drop(todrop0 + todrop1, axis=1, errors='ignore')

    #Load population data
    if not popFileName: # Use training data as population for RF classifier
        popDf = df.drop('Group')
    else:
        popDf = pd.read_csv(popFileName)
        popDf = popDf.set_index(popDf.columns[0])
    if popDf.isnull().values.any():
        print('WARNING: NaN expression values detected in population data, removing automatically')
        popDf = popDf.dropna()

    # Merge duplicate gene names and restrict everything to the shared genes
    popDf = popDf.groupby([popDf.index]).sum()
    commonGenes = popDf.index.intersection(expr_train0.columns)
    popDf = popDf.loc[commonGenes]
    expr_train0 = expr_train0[commonGenes]
    expr_train1 = expr_train1[commonGenes]
    
    if not isLogNorm:
        # Train RF classifier
        labels = np.array([0]*numberC0 + [1]*numberC1)
        features = np.array(pd.concat([expr_train0, expr_train1], ignore_index=True))

        # Instantiate model with 1000 decision trees
        rf = RandomForestClassifier(n_estimators = 1000, random_state = seed)
        # Train the model on training data
        rf = rf.fit(features, labels)
        # Using Random Forest classifier, predict population sample class for all population samples
        rfResults = pd.DataFrame(rf.predict(np.array(popDf.T)))
        # Ensure some pop samples were assigned to both conditions:
        if sum(rfResults[0]==0) == 0:
           print('No population samples assigned to condition 0, aborting.')
           sys.exit(1)
        if sum(rfResults[0]==1) == 0:
           print('No population samples assigned to condition 1, aborting.')
           sys.exit(1)

        rfResults = rfResults.set_index(popDf.columns)
        rfResults.columns = ['class']
        # Put the predicted class in the first row, above the population expressions
        popClass = pd.concat([rfResults.T, popDf], axis=0)
        genesMean0 = popClass.loc[:, popClass.iloc[0] == 0].iloc[1:,:].mean(axis=1)
        genesStd0 = popClass.loc[:, popClass.iloc[0] == 0].iloc[1:,:].std(axis=1)
        genesMean1 = popClass.loc[:, popClass.iloc[0] == 1].iloc[1:,:].mean(axis=1)
        genesStd1 = popClass.loc[:, popClass.iloc[0] == 1].iloc[1:,:].std(axis=1)
        geneMins = None
        geneScale = None
    else:
        genesMean0 = None
        genesStd0 = None
        genesMean1 = None
        genesStd1 = None

    def _train_and_generate(expr_train, cond, netIdx, nSynth, popMean, popStd):
        # Train one GAN on one condition, optionally save it, and sample a synthetic cohort.
        if saveTrained:
            saveName = '{}.trained.cond{}.{}'.format(outName, cond, netIdx)
        else:
            saveName = None
        ggan = gGAN(expr_train, network_num=netIdx, latent_dim=LATENT_DIM, noise_rate=DEFAULT_NOISE_RATE)
        ggan.train(epochs=epochs, file_name=saveName, batch_size=32)
        if saveName is not None:
            # os.path.split also copes with an outName that has no directory part.
            outDir, fName = os.path.split(saveName)
            ggan.save_model(outDir or '.', fName)
        return gensynthdata(ggan, nSynth, popMean, popStd, popDf.index, isLogNorm)

    # Train ensemble of GANs
    perNetworkResults = []
    for i in range(numbOfNetworks):
        condition_0, genes_condition_0 = _train_and_generate(expr_train0, 0, i, batchsize0, genesMean0, genesStd0)
        condition_1, genes_condition_1 = _train_and_generate(expr_train1, 1, i, batchsize1, genesMean1, genesStd1)

        #Perform significance testing
        sorted_condition_0 = formatforsigtestingbycondition(condition_0, genes_condition_0)
        sorted_condition_1 = formatforsigtestingbycondition(condition_1, genes_condition_1)
        pvalues = performsigtesting(sorted_condition_0, sorted_condition_1, '{}_{}'.format(outName,i), useDESeq, geneMins, geneScale, isLogNorm) 
        # No filtering by pval, etc. here; the average rank is used at the end
        perNetworkResults.append(pvalues)
        print('Network Number', i )

    # Concatenate once: growing an initially empty frame degrades the column dtypes.
    df_combined = pd.concat(perNetworkResults)

    if saveSynthExpr:
        exprDf = pd.concat([sorted_condition_0.add_prefix('cond0_'), sorted_condition_1.add_prefix('cond1_')],axis=1)
        exprDf.to_csv(outName+'_synthExpr.csv',index=True,header=True)

    # Name the columns explicitly instead of relying on the Series names pandas assigns.
    dfCounts = df_combined.index.value_counts().rename('count').to_frame()
    dfRank = df_combined.groupby(level=0)['rank'].mean().to_frame('rank')
    gainResult = pd.merge(dfCounts, dfRank, left_index=True, right_index=True)
    
    return (gainResult)


def main():
    """Command-line entry point for GAiN.

    Parses the command-line options, configures logging verbosity, checks
    that the output folder exists, runs ``SepMultirunGAN`` with the parsed
    options and writes the returned table, sorted by its ``rank`` column, to
    ``<outname>_DE_gene_list.csv``.

    Side effects:
        Sets the module-level ``quiet`` flag from ``--quiet``.
        Exits with status 1 if the output folder does not exist.
    """
    global quiet

    mapParse = argp.ArgumentParser(description=(
        'GAiN generates large synthetic cohorts using sparse training sets of gene expression data '
        'from samples of two different phenotypes.  It then performs differential expression (DE) '
        'analysis on the synthetic cohorts, returning a list of candidate DE genes.'))
    mapParse.add_argument('-b','--batchsizes', type=int, nargs=2, default=[500, 500], metavar=('C0', 'C1'), help='Size of synthetic cohort to generate for each condition [500 500]')
    mapParse.add_argument('-e','--epochs', type=int, default=500, help='Number of epochs for training [500]')
#    mapParse.add_argument('-a','--alphaGAN', type=float, default=-1.0, metavar='ALPHAGAN', help='Adjusted significance threshold for DE genes between synthetic cohorts [.05]')
#    mapParse.add_argument('--alphaReal', type=float, default=-1.0, metavar='ALPHAREAL', help='Adjusted significance threshold for DE genes between real cohorts [.05]')
#    mapParse.add_argument('--minGE', type=float, metavar='GE', default=10, help='Minimum absolute difference in average expression between the conditions for a gene to be reported [10]')
#    mapParse.add_argument('--minLFC', type=float, metavar='LFC', default=1, help='Minimum log2 fold change in expression between the conditions for a gene to be reported [1]')
    mapParse.add_argument('--minExprMean', type=int, default=10, help='Minimum mean expression for gene to be modeled [10]')
    mapParse.add_argument('--minExprMAD', type=int, default=10, help='Minimum mean absolute deviation of expression for gene to be modeled [10]')
    mapParse.add_argument('--numbOfNetworks', type=int, metavar='NON', default=5, help='Number of networks [5]')
#    mapParse.add_argument('--numNetworkCutoff', type=int, metavar='NNC', default=20, help='Number of networks cutoff [20]')
    mapParse.add_argument('--deseq', action='store_true', help="Use DESeq2 method for DE significance calculations (edgeR is default)")
    mapParse.add_argument('--save', action='store_true', help="Saved trained models for later use")
    mapParse.add_argument('--seed', type=int, help='Optional random seed for sampling training cases')
    mapParse.add_argument('-q', '--quiet', action='store_true', help="Limit program output")
    mapParse.add_argument('--synth', action='store_true', help="Save synthetic expression tables")
    mapParse.add_argument('-o','--outname', type=str, default='./GAiN', help='Prefix for output filenames [./GAiN]')
    mapParse.add_argument('-p', '--popCSV', type=str, default='', help="Population expression table in CSV format")
    mapParse.add_argument('inputCSV', metavar='input.csv', help="Training expression table in CSV format")
    args = mapParse.parse_args()

    quiet = args.quiet
    if quiet:
        tf.get_logger().setLevel('ERROR')
    else:
        deprecation._PRINT_DEPRECATION_WARNINGS = True
        # Echo the parsed options only when output is not being limited,
        # so that --quiet is honoured.
        print(args)

    outName = args.outname
    outDir = os.path.dirname(outName)
    if not outDir:
        outDir = '.'
        # SepMultirunGAN splits the saved-model name on '/' to obtain a
        # folder and a file name; a bare prefix such as "GAiN" would make
        # that split fail, so make the current folder explicit. The files
        # written are the same ones a bare prefix would have produced.
        outName = './' + outName

    if not os.path.isdir(outDir):
        print('Output folder {} does not exist, please create it before running {}'.format(outDir,os.path.basename(__file__)))
        sys.exit(1)

    # Positional arguments follow SepMultirunGAN's parameter order. The
    # literal False fills the parameter after the network count; its meaning
    # is defined by SepMultirunGAN's signature, which is not visible here.
    final_test_df_sig_real = SepMultirunGAN(args.inputCSV,
                                            args.popCSV,
                                            outName,
                                            args.epochs,
                                            args.batchsizes[0],
                                            args.batchsizes[1],
                                            args.minExprMean,
                                            args.minExprMAD,
                                            args.numbOfNetworks,
                                            False,
                                            args.deseq,
                                            args.save,
                                            args.synth,
                                            seed=args.seed)

    # Write the per-gene result table ordered by its 'rank' column.
    # counts_df.to_csv(args.outname+"_counts_df.csv", index = True, header=False)
    final_test_df_sig_real.sort_values('rank').to_csv(outName+"_DE_gene_list.csv", index = True, header=True)

    
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__=="__main__":
    main()
