# -*- coding: utf-8 -*-
"""
Created on Thu Jun 22 15:40:23 2023

@author: dunca
"""

#Dependencies
import pandas as pd 
from scipy.stats import entropy
import seaborn as sns
import matplotlib.pyplot as plt
import os

#Inputs
# Locations of the raw OTU table, the sample metadata, and the output folder.
rawOTU = r"/Volumes/TOSHI_EXT/Projects/Mycobiome/Data/OTU Tables_Lowest/Archaea_Raw/Archaea_L8.csv"
metaFile = r"/Volumes/TOSHI_EXT/Projects/Mycobiome/Data/metadata.csv"
outDir = r"/Volumes/TOSHI_EXT/Projects/Mycobiome/Output/tOut_Rev"

# Prefix used in the plot title and the output file name.
abrev = "Archaea"

# Palette handed to the violin plot (hex codes, named in the trailing comments).
theColors = [
    "#00008B",  # darkblue
    "#ADD8E6",  # lightblue
    "#B8860B",  # darkgoldenrod
    "#DAA520",  # goldenrod
    "#006400",  # darkgreen
    "#90EE90",  # lightgreen
    "#8B0000",  # darkred
    "#F08080",  # lightcoral
]
# Second name bound to the very same list object as theColors.
colors = theColors

def intersect_data(df1, df2, group_col="Group2"):
    """Join per-sample Shannon values onto metadata rows that share a sample ID.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Metadata table with a ``"SampleID"`` column and a ``group_col`` column.
    df2 : pandas.DataFrame
        Table indexed by sample ID with a ``"Shannon"`` column.
    group_col : str, optional
        Name of the metadata column copied into the output ``"Group"``
        column. Defaults to ``"Group2"``, the column used previously.

    Returns
    -------
    pandas.DataFrame
        Columns ``["SampleID", "Group", "Shannon"]``, one row per ``df1`` row
        whose ``SampleID`` appears in ``df2``'s index. Rows keep ``df1`` order
        and the index is a fresh 0..n-1 range. IDs are compared as-is, so an
        integer ID will not match a string index label.
    """
    # Series.map needs unique labels; keeping the first duplicate mirrors the
    # previous "take the first matching row" lookup.
    shannon_by_id = df2.loc[~df2.index.duplicated(keep="first"), "Shannon"]

    # Vectorized membership test instead of scanning df2 once per sample.
    matched = df1.loc[df1["SampleID"].isin(shannon_by_id.index), ["SampleID", group_col]]

    # Build from plain arrays so the result gets a clean 0..n-1 index.
    return pd.DataFrame(
        {
            "SampleID": matched["SampleID"].to_numpy(),
            "Group": matched[group_col].to_numpy(),
            "Shannon": matched["SampleID"].map(shannon_by_id).to_numpy(),
        },
        columns=["SampleID", "Group", "Shannon"],
    )

def create_violin_plot(dataframe):
    """Draw, save and display a violin plot of Shannon diversity per group.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Must contain a ``"Group"`` column (x axis) and a ``"Shannon"`` column
        (y axis), e.g. the output of ``intersect_data``.

    Side effects: sets the global seaborn style to "whitegrid", writes a
    300-dpi PNG into the module-level ``outDir`` (created if missing) with a
    name built from the module-level ``abrev``, colours the violins with the
    module-level ``theColors``, shows the figure, then closes it.
    """
    sns.set_style("whitegrid")

    fig, ax = plt.subplots(figsize=(10, 6))
    # NOTE(review): newer seaborn releases warn when `palette` is given
    # without `hue`; left unchanged to stay compatible with older seaborn.
    sns.violinplot(data=dataframe, x="Group", y="Shannon", palette=theColors, ax=ax)

    # Rotate after plotting so the final category labels are the ones rotated.
    ax.tick_params(axis="x", labelrotation=45)
    ax.set_xlabel("Genotype")
    ax.set_ylabel("Shannon Diversity")
    ax.set_title(abrev + " Shannon Diversity by Group")
    # Keep the rotated tick labels from being clipped at the figure edge.
    fig.tight_layout()

    # NOTE(review): the file name says "ByGender" while the x label says
    # "Genotype"; confirm which grouping the "Group" column actually holds.
    os.makedirs(outDir, exist_ok=True)
    out_path = os.path.join(outDir, abrev + "Species_AlphaDiversity_ViolinPlot_ByGender.png")
    fig.savefig(out_path, dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

#Run
# Guarded so that importing this module does not read files or open plots;
# running it as a script (including Spyder's runfile) still defines these
# names as module globals.
if __name__ == "__main__":
    #Load files
    df_otu = pd.read_csv(rawOTU, index_col=0)
    # Transpose so each row is one sample.
    # NOTE(review): assumes the CSV has taxa as rows and samples as columns
    # (all numeric counts); confirm against the exported table.
    df_otu_T = df_otu.T
    df_meta = pd.read_csv(metaFile)

    #Get shannon
    # scipy normalizes each row to proportions; natural log by default.
    # A sample whose counts are all zero yields NaN here.
    shannon_indexes = df_otu_T.apply(entropy, axis=1)
    # to_frame names the column explicitly instead of relying on the Series
    # being unnamed, which DataFrame(series, columns=[...]) silently needs.
    sDF = shannon_indexes.to_frame("Shannon")

    #Intersected
    df2 = intersect_data(df_meta, sDF)

    create_violin_plot(df2)