# -*- coding: utf-8 -*-
"""
Created on Thu Jun 15 10:55:17 2023
Script to create heatmaps of significant correlations from microbiome studies.
Automated in Sept to reduce user errors. 
@author: dunca
"""
#Dependencies------------------------------------------------------------------
import pandas as pd
import re
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os
#------------------------------------------------------------------------------

#inputs------------------------------------------------------------------------
# Bug key workbook. NOTE(review): not referenced anywhere else in this script - confirm whether it is still needed.
bugKeyF=r"/Volumes/TOSHI_EXT/Projects/Mycobiome/Output/CorrelationStudies/bugKey.xlsx"
# Folder where the heatmaps are saved (one subfolder per x-axis domain: Archaea / Eukaryota).
outDir=r"/Volumes/TOSHI_EXT/Projects/Mycobiome/Output/tOut_Rev"
# CSV flagging with "T" which species are significant in each genotype / sex division.
# It must contain 'Bug' and 'Group' columns; the flag columns are read by position
# (columns 2-4: both sexes, 5-7: female, 8-10: male).
master_KeyF=r"/Volumes/TOSHI_EXT/Projects/Mycobiome/Output/CorrelationStudies/created_BugKey_Master.csv"
# Input data: a master directory with three subdirectories (bothGenders, female, male),
# each holding correlation workbooks (.xlsx) with a 'Cor' and a 'p' sheet.
masterDirectory=r"/Volumes/TOSHI_EXT/Projects/Mycobiome/Output/CorrelationStudies"
sigThresh=0.05 # p-value threshold: correlations with p > sigThresh are discarded, p <= sigThresh are kept.
# Log-fold-change tables (columns Name, LFC, Gender, KO_Group) used to order the x axis.
lfc_File_Euk=r"/Volumes/TOSHI_EXT/Projects/Mycobiome/Output/CorrelationStudies/Eukaryote_LFC_Key.csv" # LFCs by species for Eukaryota
lfc_File_Arch=r"/Volumes/TOSHI_EXT/Projects/Mycobiome/Output/CorrelationStudies/Archaea_LFC_Key.csv" # LFCs by species for Archaea
#------------------------------------------------------------------------------

#Functions---------------------------------------------------------------------

def getSubDirectoryList(masterDir):
    """
    Build the paths of the per-sex correlation subdirectories.

    Parameters
    ----------
    masterDir : Str
        Path to the master directory.

    Returns
    -------
    List[Str]
        ``masterDir`` joined with "bothGenders", "female" and "male", in that order.
    """
    sexFolders = ("bothGenders", "female", "male")
    return [os.path.join(masterDir, folder) for folder in sexFolders]


def remove_entries(df1, df2,theThresh):
    """
    Blank out correlations whose p-value exceeds a significance threshold.

    Cells are matched by *position*, not by label: cell (i, j) of ``df2`` is the
    p-value for cell (i, j) of ``df1``. Every cell whose p-value is strictly
    greater than ``theThresh`` is set to NaN. A missing (NaN) p-value is not
    greater than the threshold, so its correlation is kept.

    df1 = pandas dataframe = The correlation dataframe (modified in place)
    df2 = pandas dataframe = The significance (p value) dataframe, same shape as df1
    theThresh = Float = the p value threshold you wish to apply

    Returns
        df1 = the same correlation dataframe object, with NaN at non-significant cells.

    Raises
        ValueError if df1 and df2 do not have the same shape (positional
        matching would otherwise silently pair the wrong cells).
    """
    if df1.shape != df2.shape:
        raise ValueError(
            f"Correlation and p-value tables differ in shape: {df1.shape} vs {df2.shape}"
        )
    # One vectorised comparison instead of a Python-level loop over every cell.
    not_significant = df2.to_numpy() > theThresh
    # mask() puts NaN wherever the condition is True; inplace keeps the contract of
    # mutating and returning df1 itself.
    df1.mask(not_significant, inplace=True)
    return df1

def make_symmetric(idf):
    """
    Mirror the lower triangle of a correlation table onto its upper triangle.

    For every position (row r, column c) with c > r, the value at (c, r) is
    copied into (r, c). Cells on and below the diagonal are left unchanged, and
    the input dataframe is not modified.

    Parameters
    ----------
    idf : Pandas dataframe
        A triangular correlation matrix whose lower triangle holds the values.

    Returns
    -------
    Pandas dataframe
        A new, symmetric copy of ``idf``.
    """
    mirrored = idf.copy()
    nRows = len(idf.index)
    nCols = len(idf.columns)
    for rowPos in range(nRows):
        # Only the strictly-upper part of this row is overwritten.
        for colPos in range(rowPos + 1, nCols):
            mirrored.iat[rowPos, colPos] = idf.iat[colPos, rowPos]
    return mirrored

def getKeyDict(keyDF):
    """
    Map each domain of life in the key table to its formatted bug names.

    Parameters
    ----------
    keyDF : Pandas dataframe
        Key table with a 'Group' column (domain, e.g. Archaea) and a 'Bug'
        column (species name).

    Returns
    -------
    Dict[Str] -> List[Str]
        For every distinct 'Group' value (in order of first appearance), the
        distinct 'Bug' names of that group after ``replace_non_alnum`` formatting.
    """
    groupColumn = keyDF['Group']
    return {
        domain: replace_non_alnum(list(keyDF.loc[groupColumn == domain, "Bug"].unique()))
        for domain in groupColumn.unique()
    }


def replace_non_alnum(strings):
    """
    Normalise bug names to the dotted form used in the correlation tables.

    Each run of characters outside [0-9a-zA-Z] becomes a single '.', and the
    first '.' is then widened to '__' (e.g. "s__Foo bar" -> "s__Foo.bar").
    A few names that do not survive this rule are replaced by hard-coded
    spellings; if several override keywords match, the first one listed wins.

    Parameters
    ----------
    strings : List of strings
        List of bugs

    Returns
    -------
    List of strings
        One normalised name per input, in the same order.
    """
    # (keyword, replacement) pairs, checked in order; the first hit wins.
    overrides = (
        ("magnoliae", "s__.Candida..magnoliae"),
        ("Haemophilus", "s__.Haemophilus..ducreyi"),
        ("Microbacterium", "s__Microbacterium.sp..LKL04"),
        ("Amanita", "s__Amanita.sp...trygonion."),
        ("huakuii", "s__Mesorhizobium.huakuii"),
    )
    cleaned = []
    for raw in strings:
        dotted = re.sub('[^0-9a-zA-Z]+', '.', raw).replace('.', '__', 1)
        for keyword, fixed in overrides:
            if keyword in dotted:
                dotted = fixed
                break
        cleaned.append(dotted)
    return cleaned

def apply_replacements_to_dictionary_strings(dictionary):
    """
    Turn the dotted bug names of a domain dictionary into readable labels.

    Every name goes through the same fixed sequence of substring replacements
    (species prefixes removed, '.' -> ' ', '_' removed, one pass collapsing
    double spaces), so the result can be compared with the cleaned-up tick
    labels of a plot.

    Parameters
    ----------
    dictionary : Dict[Str] -> List[Str]
          Bug dictionary (domain -> bug names)

    Returns
    -------
    Dict[Str] -> List[Str]
       A new dictionary with the same keys and the cleaned names.
    """
    substitutions = (
        ("s__.", ""),
        ("s_", ""),
        ("s_.", ""),
        ("s_.", ""),
        (".", " "),
        ("_", ""),
        ("  ", " "),
    )

    def _clean(name):
        for old, new in substitutions:
            name = name.replace(old, new)
        return name

    return {key: [_clean(name) for name in names] for key, names in dictionary.items()}


def _color_tick_labels(tickLabels, groupMembers):
    """
    Colour tick labels by the domain of life their text belongs to.

    Archaea -> red, Eukaryota -> green, Virus -> purple, any other group -> blue.
    A label found in several groups ends up with the colour of the last matching
    group; a label found in none keeps its default colour.
    """
    domainColors = {'Archaea': 'red', 'Eukaryota': 'green', 'Virus': 'purple'}
    for label in tickLabels:
        text = label.get_text()
        for group, members in groupMembers.items():
            if text in members:
                label.set_color(domainColors.get(group, 'blue'))


def create_clustermap(dataframe, bug_dict,outD,theName,botBug,theThresh):
    """
    Draw, save and show the heatmap of significant correlations.

    Parameters
    ----------
    dataframe : Pandas dataframe
        The correlation dataframe with readable labels. Cells equal to 0 are
        masked (left blank) in the heatmap.
    bug_dict : Dict[Str] - List[Str]
        The dictionary which specifies which domain each species belongs to;
        used to colour the tick labels.
    outD : Str
        Folder where the output will be saved.
    theName : Str
        The file name; the genotype tag (Lyz / DEFA / IEC) is extracted from it
        for the title and it is part of the output file name.
    botBug : Str
        Domain on the x axis (Archaea or Eukaryota); part of the output file name.
    theThresh : Float
        The significance threshold used (shown in the title).

    Returns
    -------
    None. Writes "Clustermap_SelectBugs_<theName>_<botBug>.png" into ``outD``.
    """
    # Reformat the key names so they match the cleaned-up tick labels
    bug_dict1 = apply_replacements_to_dictionary_strings(bug_dict)

    # Custom diverging colour map centred on 0
    cmap = sns.diverging_palette(250, 10, s=90, l=50, as_cmap=True)
    mymask = dataframe == 0

    # Clustering is disabled on both axes, so the given row/column order is kept
    clustergrid = sns.clustermap(dataframe, annot=True, mask=mymask, cmap=cmap, center=0,
                                 row_cluster=False, col_cluster=False)
    heatmapAx = clustergrid.ax_heatmap
    fig = heatmapAx.figure  # work on this figure explicitly rather than pyplot's "current" one
    clustergrid.ax_col_dendrogram.set_visible(False)
    heatmapAx.set_xticklabels(heatmapAx.get_xticklabels(), rotation=45, ha='right')

    _color_tick_labels(heatmapAx.get_xticklabels(), bug_dict1)
    _color_tick_labels(heatmapAx.get_yticklabels(), bug_dict1)

    # Title suffix from the genotype tag in the file name (DEFA is displayed as "PC");
    # if several tags appear, the last one in this order wins.
    fTitle = ""
    for tag, shownAs in (("Lyz", "Lyz"), ("DEFA", "PC"), ("IEC", "IEC")):
        if tag in theName:
            fTitle = shownAs
    theTitle = "Significant (p<=" + str(theThresh) + ") Correlation of Differential Taxa: "
    title = heatmapAx.set_title(theTitle + fTitle, fontsize=14)
    title.set_position([0.5, 1.05])  # Nudge the title above the plot area (axes coordinates)
    fig.set_size_inches(10, 8)

    # Save, then show the plot
    tOut = "Clustermap_SelectBugs_" + theName + "_" + botBug + ".png"
    fig.tight_layout()
    fig.savefig(os.path.join(outD, tOut), dpi=300)
    plt.show()
    # This is called once per input file and domain. Under a non-interactive backend
    # show() does not release the figure, so close it to avoid piling up open figures;
    # in interactive mode the window is left open for the user.
    if not plt.isinteractive():
        plt.close(fig)

def filter_and_sort_dataframe(dataframe, curL):
    """
    Keep only the rows for the bugs of interest, in the order they are listed.

    Parameters
    ----------
    dataframe : pandas dataframe
        Correlation dataframe (rows indexed by bug name).
    curL : List of str
        Bugs we want to visualize (differential bugs in this genotype). Names
        absent from the index are ignored; repeated names are kept once, at
        their first position.

    Returns
    -------
    pandas dataframe
        The rows of ``dataframe`` whose label is in ``curL``, ordered as in
        ``curL``. Empty (same columns, no rows) when nothing matches.
    """
    # dict.fromkeys de-duplicates while preserving order; duplicates would otherwise
    # produce repeated rows.
    keep = list(dict.fromkeys(label for label in curL if label in dataframe.index))
    # .loc with a list already returns rows in list order, so no reindex is needed
    # (reindex also fails on an axis with duplicate labels).
    return dataframe.loc[keep]


def modify_dataframe(df):
    """
    Turn the dotted taxon labels of a correlation table into readable names.

    Column and row labels are rewritten in place by a fixed sequence of
    substring replacements: the "s__." / "s_" species markers are removed
    (every occurrence, not only a leading one), '.' becomes a space, remaining
    underscores are dropped, double spaces are collapsed in a single pass and
    leading whitespace is stripped. For example "s__.Candida..magnoliae"
    becomes "Candida magnoliae".

    Parameters
    ----------
    df : pandas DataFrame
        Correlation table whose row and column labels are strings.

    Returns
    -------
    pandas DataFrame
        The same ``df`` object, with its labels replaced.
    """
    def _clean_column(label):
        # Note the order: the "s__." marker is handled before the bare "s_".
        cleaned = (label.replace("s__.", "").replace("s_", "").replace("s_.", "").replace("s_.", "")
                   .replace(".", " ").replace("_", "").replace("  ", " "))
        return cleaned.lstrip()

    def _clean_row(label):
        # Row labels strip "s_" first; the lstrip() absorbs the leading space this can leave.
        cleaned = (label.replace("s_", "").replace("s_.", "").replace("s_.", "").replace("s__.", "")
                   .replace(".", " ").replace("_", "").replace("  ", " "))
        return cleaned.lstrip()

    df.columns = [_clean_column(col) for col in df.columns]
    df.index = [_clean_row(idx) for idx in df.index]
    return df


def getKeepDict2(mstrDF,theSex):
    """
    Collect, per genotype, the bugs flagged "T" for one sex division.

    The master key table stores its flag columns by position: columns 2-4 are
    the "both" division, 5-7 "female" and 8-10 "male". Within each triple the
    columns are read as the Lyz, IEC and DEFA genotypes, in that order.

    Parameters
    ----------
    mstrDF : PD Dataframe
        The master annotation df; must have a 'Bug' column plus the flag columns above.
    theSex : Str
        Sex division: "both", "female" or "male".

    Returns
    -------
    Dict[Str] -> List[Str]
        Maps each genotype ("Lyz", "IEC", "DEFA") to the bugs whose flag is
        exactly "T", in row order. A genotype with no flagged bug is absent.

    Raises
    ------
    ValueError
        If ``theSex`` is not one of the three known divisions.
    """
    genotypes = ["Lyz", "IEC", "DEFA"]
    columnSlices = {"both": slice(2, 5), "female": slice(5, 8), "male": slice(8, 11)}
    if theSex not in columnSlices:
        raise ValueError(
            f"Unknown sex division {theSex!r}; expected one of {sorted(columnSlices)}"
        )
    columnsToUse = mstrDF.columns[columnSlices[theSex]]

    toRet = {}
    # Genotype and flag column are paired by position within the triple
    for genotype, colName in zip(genotypes, columnsToUse):
        # Only an exact "T" marks a microbe of interest
        flagged = mstrDF.loc[mstrDF[colName] == "T", "Bug"].tolist()
        if flagged:
            toRet[genotype] = flagged
    return toRet

def getColumnOrder(colList,lfc_DF):
    """
    Sort column names by their log fold change (ascending).

    Parameters
    ----------
    colList : List[Str]
        Names to sort.
    lfc_DF : Pandas dataframe
        Table with 'Name' and 'LFC' columns. If a name occurs more than once,
        its last row wins.

    Returns
    -------
    List[Str]
        ``colList`` sorted by LFC; names missing from ``lfc_DF`` are treated
        as LFC 0. The sort is stable, so ties keep their input order.
    """
    lfcLookup = {}
    for taxon, lfc in zip(lfc_DF['Name'], lfc_DF['LFC']):
        lfcLookup[taxon] = lfc

    def _sortKey(taxon):
        return lfcLookup.get(taxon, 0)

    return sorted(colList, key=_sortKey)
        
#Run-----------------------------------------------------------------------------
# Load the key file that says which bugs should be considered for which subgroups
master_key_df = pd.read_csv(master_KeyF)
bugDict = getKeyDict(master_key_df)  # Map each domain of life to its (formatted) species names
subDirList = getSubDirectoryList(masterDirectory)  # Per-sex subdirectories with correlation input
# Load the log-fold-change tables used to order the x axis later
lfc_df_Arch = pd.read_csv(lfc_File_Arch)
lfc_df_Euk = pd.read_csv(lfc_File_Euk)

# Sex divisions; must stay in the same order as the folders from getSubDirectoryList
sex_List = ["both", "female", "male"]
# Genotype tags recognised in correlation file names. If a name contains several,
# the last one in this list wins.
genotypeTags = ["DEFA", "Lyz", "IEC"]

for indSex, directory in zip(sex_List, subDirList):
    # Bugs to keep, by genotype, for this sex division (flagged "T" in the master key)
    btk_Dict = getKeepDict2(master_key_df, indSex)

    # Restrict the LFC tables to this sex division; they are split by genotype per file below
    subset_LFC_Euk = lfc_df_Euk.loc[lfc_df_Euk['Gender'] == indSex]
    subset_LFC_Arch = lfc_df_Arch.loc[lfc_df_Arch['Gender'] == indSex]

    print("Running on directory: ", directory)

    for filename in os.listdir(directory):
        # Only correlation workbooks: skip pairwise "_vs_" files, hidden files and Excel lock files
        if not filename.endswith('.xlsx') or '_vs_' in filename or filename.startswith(('.', '~$')):
            continue
        file_path = os.path.join(directory, filename)
        file_name_without_extension = os.path.splitext(filename)[0]

        genotype = None
        for tag in genotypeTags:
            if tag in file_name_without_extension:
                genotype = tag
        if genotype is None:
            # No genotype tag means no bug list or LFC table applies to this file
            print("Skipping file without a DEFA/Lyz/IEC tag: ", filename)
            continue

        # Correlation values between pairs, reflected over the diagonal for display
        df = make_symmetric(pd.read_excel(file_path, 'Cor', index_col=0))
        # p-values for the same pairs; the sheet layout matches the correlation sheet,
        # so non-significant cells can be blanked out position by position
        df_p = make_symmetric(pd.read_excel(file_path, 'p', index_col=0))
        df = remove_entries(df, df_p, sigThresh)

        # Bugs of interest for this genotype, formatted like the correlation labels
        curL = [s.replace(" ", ".").replace("[", ".").replace("]", ".") for s in btk_Dict.get(genotype, [])]
        curSet = set(curL)
        subset_LFC_Euk2 = subset_LFC_Euk.loc[subset_LFC_Euk['KO_Group'] == genotype]
        subset_LFC_Arch2 = subset_LFC_Arch.loc[subset_LFC_Arch['KO_Group'] == genotype]

        # One plot per x-axis domain
        for theCat in ["Archaea", "Eukaryota"]:
            subOutDir = os.path.join(outDir, theCat)
            # Creates outDir too if needed; no-op when the folder already exists
            os.makedirs(subOutDir, exist_ok=True)

            subset_LFC = subset_LFC_Arch2 if theCat == "Archaea" else subset_LFC_Euk2

            # x axis: correlation columns in this domain that are also differential in this genotype
            domainBugs = set(bugDict.get(theCat, []))
            toGet = [val for val in df.columns if val in domainBugs and val in curSet]
            subset_df = df[toGet]

            # y axis: keep only the rows for this genotype's differential bugs
            subset_df = filter_and_sort_dataframe(subset_df, curL)

            # NaN marks non-significant pairs; rows/columns that are entirely NaN have nothing to show
            all_nan_columns = subset_df.columns[subset_df.isna().all()]
            subset_df = subset_df.drop(all_nan_columns, axis=1)
            subset_df = subset_df.dropna(how='all')
            subset_df = subset_df.fillna(0)  # 0 cells are masked out by create_clustermap

            # No significant correlations (common for the IEC group)
            if subset_df.empty:
                continue

            # Order the columns by LFC to match the differential abundance plots
            colList = [s.replace(".", " ") for s in subset_df.columns]
            column_order = getColumnOrder(colList, subset_LFC)
            column_order = [s.replace(" ", ".") for s in column_order]
            subset_df = subset_df[column_order]

            subset_df = modify_dataframe(subset_df)  # Readable labels for plotting

            # Finally, create and save the clustermap for this comparison
            create_clustermap(subset_df, bugDict, subOutDir, file_name_without_extension, theCat, sigThresh)



