# This script is available under GPL v3 license

import seaborn as sns
from osgeo import gdal
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def getFrequencies(inputList: list) -> tuple:
    """
    Returns a tuple of the pixel count, percentage and model names of each raster / class.

    inputList (list): Sequence of (raster path, label) pairs; only the first two items of
        each element are used.

    Returns a tuple (pixelCountsArray, pixelCountsPercentArray, names):
        pixelCountsArray: 2-D integer array with one row per raster (input order) and one
            column per distinct raster value found in ANY of the rasters, sorted ascending.
            A value that does not occur in a raster is counted as 0 for that raster, so
            rasters with missing classes still line up. Every value is counted, including
            any nodata value, because no nodata masking is applied.
        pixelCountsPercentArray: Same shape; each row expressed as a percentage of that
            raster's own total pixel count (each row sums to 100).
        names: List of labels in input order.

    Raises:
        ValueError: If inputList is empty.
        RuntimeError: If GDAL cannot open one of the rasters.
    """
    if not inputList:
        raise ValueError("inputList must contain at least one (raster path, label) pair")

    perRaster = []  # (sorted unique values, counts) for each raster
    names = []
    for element in inputList:
        rasterPath, label = element[0], element[1]
        dataset = gdal.Open(rasterPath)
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning None.
        if dataset is None:
            raise RuntimeError(f"Could not open raster: {rasterPath}")
        array = dataset.ReadAsArray()
        dataset = None  # release the GDAL dataset handle as soon as the data is read
        values, counts = np.unique(array, return_counts=True)
        perRaster.append((values, counts))
        names.append(label)

    # Align all rasters on the union of their values so that a class missing from one
    # raster does not produce ragged rows (which np.vstack cannot stack).
    allValues = np.unique(np.concatenate([values for values, _ in perRaster]))
    pixelCountsArray = np.zeros((len(perRaster), allValues.size), dtype=np.int64)
    for row, (values, counts) in enumerate(perRaster):
        pixelCountsArray[row, np.searchsorted(allValues, values)] = counts

    # Normalise each raster by its own total rather than by the first raster's total.
    rowTotals = pixelCountsArray.sum(axis=1, keepdims=True)
    pixelCountsPercentArray = pixelCountsArray / rowTotals * 100
    return (pixelCountsArray, pixelCountsPercentArray, names)

def drawBarplotAndTable(inputList: list, usePercent = False, colors = None, zoneLabels = None) -> None:
    """
    Generates a barplot with a table based on raster value distribution.

    The bars always show pixel counts (in thousands) per raster value, grouped by value
    with one bar per raster. A table below the x axis lists the same distribution.

    inputList (list): Sequence of (raster path, label) pairs, passed to getFrequencies.
    usePercent (bool): If true the table shows the percentage distribution, else it shows
        pixel counts. The bars are not affected.
    colors (list): List of matplotlib colors, one per raster; leave empty / None for random
        colors. The list passed in is never modified.
    zoneLabels (list): List of strings with labels (left -> right; highest -> lowest), leave
        empty / None for defaults. The default is the five labels "Very high" ... "Very low"
        when there are exactly five columns, otherwise "Class 1" ... "Class n".
    """
    pixelCountsArray, pixelCountsPercentArray, names = getFrequencies(inputList)
    plt.rcParams.update({'font.size': 12})
    plotdata = pd.DataFrame(data=pixelCountsArray.T/1000, columns=names,)
    sns.set_style("white")
    if not colors:
        # Build a fresh list: appending to a shared default (or the caller's list) would
        # leak colours into later calls.
        colors = [(np.random.random(), np.random.random(), np.random.random())
                  for _ in range(len(inputList))]
    plotdata.plot(kind="bar", color = colors)
    plt.ylabel("Pixel count (x1000)")
    # disable ticks on x axis; the table below replaces the tick labels
    plt.tick_params(
        axis='x',
        which='both',
        bottom=False,
        top=False,
        labelbottom=False)

    # Add a table below the x axis
    if not zoneLabels:
        defaultLabels = ["Very high", "High", "Moderate", "Low", "Very low"]
        nColumns = pixelCountsArray.shape[1]
        # plt.table requires exactly one column label per column.
        if nColumns == len(defaultLabels):
            zoneLabels = defaultLabels
        else:
            zoneLabels = [f"Class {i + 1}" for i in range(nColumns)]
    if usePercent:
        yAxisValues = pixelCountsPercentArray
    else:
        yAxisValues = pixelCountsArray
    table = plt.table(cellText = np.round(yAxisValues, 2),
                      cellLoc = "center",
                      rowLabels = names,
                      rowLoc = "center",
                      colLabels = zoneLabels,
                      colLoc = "center",
                      fontsize = 10,
                      loc = "bottom",
                      edges = "closed")
    plt.tight_layout()
    plt.show()

# Define your input list: a list of tuples, each holding a raster path and a label.
inputList = [
    (r"D:\LSAT_Project\results\susceptibility_maps\woe_test_map.tif", "WoE"),
    (r"D:\LSAT_Project\results\susceptibility_maps\lr_c_test_map.tif", "LR_c"),
    (r"D:\LSAT_Project\results\susceptibility_maps\lr_d_test_map.tif", "LR_d"),
    (r"D:\LSAT_Project\results\susceptibility_maps\ann_c_test_map.tif", "ANN_c"),
    (r"D:\LSAT_Project\results\susceptibility_maps\ann_d_test_map.tif", "ANN_d")
]

# Only read the rasters and open the plot window when run as a script, not on import.
if __name__ == "__main__":
    drawBarplotAndTable(inputList, usePercent=False, colors = ["r", "b", "orange", "g", "y"])