import numpy as np
import pandas as pd


def find_high_var_genes(
            sc_ref_df,
            disp_min=0.5,
            n_bins=20,
            ):
    """Find highly variable genes across cell types in ``sc_ref_df``.

    The implementation follows scanpy's ``highly_variable_genes`` procedure
    for flavor ``'seurat'``, applied across cell-type columns: each column is
    normalized to sum to 1, a per-gene log dispersion is computed across the
    columns, and dispersions are z-scored within equal-width bins of log mean
    expression.

    Parameters
    ----------
    sc_ref_df : pandas.DataFrame
        Expression table with genes as rows (index) and cell types as
        columns. At least two columns are needed because the variance uses
        the unbiased (n - 1) correction; a single column raises
        ``ZeroDivisionError``.
    disp_min : float, default 0.5
        Genes whose normalized dispersion is strictly greater than this
        value are reported as highly variable.
    n_bins : int, default 20
        Number of equal-width bins of log mean expression within which
        dispersions are normalized.

    Returns
    -------
    list
        Index labels of the highly variable genes, in the order they appear
        in ``sc_ref_df``.
    """
    # normalize counts within each cell type (column) to sum to 1
    sc_ref_norm = sc_ref_df.div(sc_ref_df.sum(axis=0), axis=1)
    n_types = np.shape(sc_ref_norm)[1]

    # mean and variance of each gene across cell types
    mean = np.mean(sc_ref_norm, axis=1)
    mean_of_sq = np.multiply(sc_ref_norm, sc_ref_norm).mean(axis=1)
    var = mean_of_sq - mean ** 2
    # E[x^2] - E[x]^2 can dip just below zero through rounding; clipping sends
    # such genes down the same "zero dispersion -> NaN" path below instead of
    # taking log() of a negative number (the outcome, NaN, is unchanged)
    var = var.clip(lower=0)
    # enforce R convention (unbiased estimator) for variance
    var *= n_types / (n_types - 1)

    # replace zero means with a tiny value to avoid division by zero
    mean[mean == 0] = 1e-12
    dispersion = var / mean

    # for the 'seurat' flavor, log versions of mean and dispersion are needed;
    # zero dispersion becomes NaN so it is skipped by the per-bin statistics
    dispersion[dispersion == 0] = np.nan
    dispersion = np.log(dispersion)
    mean = np.log1p(mean)

    # collect in a dataframe
    df = pd.DataFrame(index=sc_ref_norm.index)
    df['means'] = mean
    df['dispersions'] = dispersion

    # group genes into equal-width bins of log mean expression
    df['mean_bin'] = pd.cut(df['means'], bins=n_bins)
    # observed=False keeps every bin (including empty ones); passing it
    # explicitly pins this behavior and silences pandas' FutureWarning about
    # the changing default for categorical groupers
    disp_grouped = df.groupby('mean_bin', observed=False)['dispersions']
    # mean and std of dispersion in each bin
    disp_mean_bin = disp_grouped.mean()
    disp_std_bin = disp_grouped.std(ddof=1)

    # a NaN std means the bin held at most one finite dispersion; as in
    # scanpy, such genes are implicitly given a normalized dispersion of 1 by
    # setting std := mean and mean := 0 for that bin
    one_gene_per_bin = disp_std_bin.isnull()
    disp_std_bin[one_gene_per_bin.values] = disp_mean_bin[one_gene_per_bin.values].values
    disp_mean_bin[one_gene_per_bin.values] = 0

    # z-score each gene's dispersion against its own expression bin;
    # .values is used because the bin statistics are indexed by bin, not gene
    gene_bins = df['mean_bin'].values
    df['dispersions_norm'] = (
        (df['dispersions'].values - disp_mean_bin[gene_bins].values)
        / disp_std_bin[gene_bins].values
    )

    # genes with an undefined normalized dispersion count as 0 (similar to Seurat)
    dispersion_norm = df['dispersions_norm'].values.astype('float32')
    dispersion_norm[np.isnan(dispersion_norm)] = 0
    df['highly_variable'] = dispersion_norm > disp_min

    # return the gene names that passed the threshold
    return df.loc[df['highly_variable']].index.tolist()
