# -*- coding: utf-8 -*-
"""
Created on Tue Jul  5 17:24:42 2022


@author: Administrator
"""
import os
import bz2
import pickle

import scipy.stats
import pandas as pd
pd.set_option('use_inf_as_na', True)

import numpy as np
import numpy.ma as ma
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.use('Agg')
from mpl_toolkits.mplot3d import Axes3D

import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.api import anova_lm

from sklearn.ensemble import RandomForestRegressor
import joblib


from sklearn.inspection import plot_partial_dependence
from sklearn.inspection import partial_dependence




from netCDF4 import Dataset

import matplotlib as mplt
import cartopy.crs as ccrs
import cartopy.feature as cfea
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.basemap import Basemap
from matplotlib import pyplot

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Reduced major axis (RMA) regression, for when both x and y carry variability (std).
from pylr2 import regress2


# Global plotting style shared by every figure produced by this script.
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['axes.titlesize'] = 12
plt.rcParams['xtick.labelsize'] = 14
plt.rcParams['ytick.labelsize'] = 14
plt.rcParams['legend.fontsize'] = 14
plt.rcParams['font.family'] = 'Times New Roman'

# Relative paths used below ('figure/...', 'Statistical_Analysis/...') resolve
# against this working directory.
os.chdir("/n/holylfs04/LABS/moorcroft_lab/Users/sqliu/GEDI_regiondata_Compactness_Reprocess")

# Load the Simard height table and drop every row with a missing value.
# NOTE(review): the module-level re-load of Hite_grouped further down replaces
# this frame before the analysis calls at the bottom of the file run, so this
# load only matters when the functions are driven interactively.
Hite_grouped = pd.read_csv('/n/holylfs04/LABS/moorcroft_lab/Users/sqliu/GEDI_regiondata_Compactness/Simard_heigt.csv')
Hite_grouped.dropna(inplace=True)
# Disabled variant: replace Height by the Simard-minus-GEDI height difference.
#Hite_grouped['Height_diff'] = Hite_grouped['Simard_Height'] - Hite_grouped['Height']
#Hite_grouped = Hite_grouped[['Lat','Lon','Height_diff','Region']]
#Hite_grouped = Hite_grouped.rename(columns = {'Height_diff':'Height'})



def get_from_action_data(fname, chunk_size=100000):
    """Read the per-shot structure/climate/soil columns of a large CSV in chunks.

    Parameters
    ----------
    fname : str or path-like
        CSV file whose first line is the header.
    chunk_size : int
        Number of rows pulled from the reader at each step.

    Returns
    -------
    pandas.DataFrame
        Only the columns listed in ``wanted`` (in that order), with a fresh
        0..n-1 index.
    """
    wanted = ['Lat', 'Lon', 'TotalPAI', 'MaxPAI', 'Height', 'Beta_a', 'Beta_b',
              'Region', 'Elev', 'Prec', 'Temp', 'Prec_sea', 'Temp_sea', 'Tmax',
              'Vapr', 'Wind', 'Cloud', 'Fapar', 'Light', 'DSL', 'Soil_Bulk',
              'Soil_Sand', 'Soil_SOC', 'Ssrd']
    reader = pd.read_csv(fname, header=0, iterator=True)
    pieces = []
    while True:
        try:
            piece = reader.get_chunk(chunk_size)
        except StopIteration:
            print("Iteration is stopped")
            break
        # Trim each chunk right away so only the wanted columns are retained.
        pieces.append(piece[wanted])
    return pd.concat(pieces, ignore_index=True)

def get_from_action_data2(fname, chunk_size=100000): # vertical pai_###_cm
    """Read the cumulative vertical PAI profile columns of a large CSV in chunks.

    The selected columns are ``'pai z_000_cm'``, ``'pai z_005_cm'``, ...,
    ``'pai z_145_cm'`` followed by ``'Canopy Height (rh100)'`` and ``'Region'``.

    Parameters
    ----------
    fname : str or path-like
        CSV file whose first line is the header.
    chunk_size : int
        Number of rows pulled from the reader at each step.

    Returns
    -------
    pandas.DataFrame
        The selected columns in the order above, with a fresh 0..n-1 index.
    """
    wanted = ["pai z_%03d_cm" % depth for depth in range(0, 150, 5)]
    wanted += ['Canopy Height (rh100)', 'Region']
    reader = pd.read_csv(fname, header=0, iterator=True)
    pieces = []
    while True:
        try:
            piece = reader.get_chunk(chunk_size)
        except StopIteration:
            print("Iteration is stopped")
            break
        pieces.append(piece[wanted])
    return pd.concat(pieces, ignore_index=True)

def get_from_action_height(fname, chunk_size=100000):
    """Read the Beta-fit, structure and climate/soil columns of a large CSV in chunks.

    Parameters
    ----------
    fname : str or path-like
        CSV file whose first line is the header.
    chunk_size : int
        Number of rows pulled from the reader at each step.

    Returns
    -------
    pandas.DataFrame
        Only the columns listed in ``wanted`` (in that order), with a fresh
        0..n-1 index.
    """
    wanted = ['Lat', 'Lon', 'Elev', 'Region', 'Beta_a', 'Beta_b', 'R2', 'RMSE',
              'TotalPAI', 'MaxPAI', 'Height', 'Prec', 'Temp', 'Prec_sea',
              'Temp_sea', 'Tmax', 'Vapr', 'Wind', 'Cloud', 'Fapar', 'Light',
              'DSL', 'Soil_Bulk', 'Soil_Sand', 'Soil_SOC', 'Ssrd']
    reader = pd.read_csv(fname, header=0, iterator=True)
    pieces = []
    while True:
        try:
            piece = reader.get_chunk(chunk_size)
        except StopIteration:
            print("Iteration is stopped")
            break
        pieces.append(piece[wanted])
    return pd.concat(pieces, ignore_index=True)

def vertical_PAI_plot(df):
    """Plot mean (+/- 1 std) vertical PAI profiles, pooled and per region.

    The cumulative columns ``'pai z_XXX_cm'`` are differenced into 5-unit layer
    increments ``'pai_XXX_YYY'`` (XXX = 0, 5, ..., 135).  Panel 0 shows the
    pooled profile; panels 1-3 show Region 0, 1 and 2.  The figure is written to
    ``Statistical_Analysis/Vertical_PAI.png``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``'pai z_000_cm'`` ... ``'pai z_140_cm'`` and ``'Region'``.
        The frame is not modified.

    Returns
    -------
    int
        Always 0.
    """
    depths = list(range(0, 140, 5))

    # Layer increment = cumulative PAI at z minus cumulative PAI at z + 5.
    # Built in a new frame so the caller's DataFrame is left untouched.
    layers = {}
    for z in depths:
        layers['pai_%03d_%03d' % (z, z + 5)] = (
            df['pai z_%03d_cm' % z] - df['pai z_%03d_cm' % (z + 5)])
    layer_df = pd.DataFrame(layers)
    layer_cols = list(layers)

    def _draw(ax, mean, std):
        # Mean profile with a +/- 1 std band against layer bottom height.
        ax.plot(mean, depths)
        ax.fill_betweenx(depths, mean - std, mean + std, alpha=0.2)
        ax.set_ylim([0, 60])

    fig, axs = plt.subplots(1, 4, figsize=(16, 6))
    _draw(axs[0], layer_df.mean(axis=0).values, layer_df.std(axis=0).values)

    by_region = layer_df.groupby(df['Region'])
    means = by_region.mean()
    stds = by_region.std()
    for ax, region in zip(axs[1:], (0, 1, 2)):
        if region not in means.index:
            continue  # no samples for this region: leave the panel empty
        _draw(ax, means.loc[region, layer_cols].values,
              stds.loc[region, layer_cols].values)

    plt.savefig('Statistical_Analysis/Vertical_PAI.png', dpi=300)
    plt.close(fig)
    return 0
# plot the comparison between RF prediction and observation
def RF_plot(df, pred_file='Statistical_Analysis/Hite_predict_Ssrd.txt',
            column='Height', out_file='figure/RF_HITE_Ssrd_plot.png',
            lims=(15, 60), bins=1000):
    """Density-coloured scatter of model predictions against observations.

    Points are coloured by the count of the 2-D histogram bin they fall in and
    drawn densest-last.  An OLS line (observation ~ prediction) is overlaid, and
    the RMSE between prediction and observation is printed.

    The defaults reproduce the Height / Ssrd figure.  The PAI variant that used
    to live here as disabled code is
    ``RF_plot(df, pred_file='Statistical_Analysis/PAI_predict_NoTreeTypes.txt',
    column='TotalPAI', out_file='Statistical_Analysis/RF_PAI_plot.png',
    lims=(0, 8), bins=10000)``.

    Parameters
    ----------
    df : pandas.DataFrame
        Holds the observed values in ``column``, row-aligned with ``pred_file``.
    pred_file : str
        Text file of predictions, read with ``np.loadtxt``.
    column : str
        Observation column in ``df``.
    out_file : str
        Output image path.
    lims : (float, float)
        Shared x/y axis limits.
    bins : int
        Histogram bins per axis used for the density colouring.

    Returns
    -------
    int
        Always 0.
    """
    x = np.loadtxt(pred_file)
    y = np.asarray(df[column].values)

    hh, locx, locy = np.histogram2d(x, y, bins=[bins, bins])
    # searchsorted(edges[1:], v) is the first bin whose right edge is >= v, i.e.
    # the same index as np.argmax(v <= edges[1:]), but vectorised instead of a
    # Python loop doing O(n_points * bins) work.
    z = hh[np.searchsorted(locx[1:], x), np.searchsorted(locy[1:], y)]
    order = z.argsort()  # densest points are drawn last

    fig = plt.figure(figsize=(8, 8))
    plt.scatter(x[order], y[order], c=z[order], cmap='YlOrBr', marker='.')

    design = np.vstack([x, np.ones(len(x))]).T
    m, c = np.linalg.lstsq(design, y, rcond=None)[0]
    line_x = np.array([x.min(), x.max()])
    plt.plot(line_x, line_x * m + c, 'r-')

    plt.xlabel('Prediction')
    plt.ylabel('Observation')
    plt.xlim(list(lims))
    plt.ylim(list(lims))
    plt.savefig(out_file, dpi=300)
    plt.close(fig)

    # RMSE between prediction and observation.
    print(np.linalg.norm(x - y) / np.sqrt(len(x)))
    print('\n')
    return 0




# Plot histograms of grouped Hite for the pan-tropics and each region (Africa, Amazon, SE Asia).
def hite_hist_plot(df, out_path='D:\\Postdoc\\Harvard\\GEDI_regional\\Figures\\Tropical_Simard_minus_GEDI_Hite_Hist.png'):
    """Plot ``Height`` histograms (pan-tropic + 3 regions) on a +/-20 m axis.

    The axis range, 500 bins and the 'Height_diff(m)' label are laid out for a
    height *difference*.  Each panel is annotated with the mean and population
    std (ddof=0) of its values.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``Height`` and ``Region`` (0 = Africa, 1 = Amazon, 2 = SEAsia).
        This argument is now actually used; the old body read the module-level
        ``Hite_grouped`` instead.
    out_path : str
        Output image path (default: the original hard-coded Windows path).

    Returns
    -------
    int
        Always 0.
    """
    # (subset, title, x position of the title text)
    panels = [
        (df, 'Pan Tropic', 6),
        (df.loc[df.Region == 0], 'Africa', 10),
        (df.loc[df.Region == 1], 'Amazon', 8.5),
        (df.loc[df.Region == 2], 'SEAsia', 9.5),
    ]
    fig, axs = plt.subplots(1, 4, figsize=(12, 3))
    for k, (ax, (subset, title, title_x)) in enumerate(zip(axs, panels)):
        heights = subset.Height
        ax.hist(heights, bins=500, density=True)
        ax.set_xlim([-20, 20])
        ax.set_ylim([0, 0.1])
        if k == 0:
            ax.set_ylabel('Frequency')
        else:
            ax.tick_params(labelleft=False)  # share the first panel's y labels
        ax.set_xlabel('Height_diff(m)')
        ax.text(title_x, 0.09, title, fontsize=12)
        ax.text(7, 0.07, 'mean: ' + "{:.1f}".format(np.mean(heights)))
        ax.text(10, 0.06, 'std: ' + "{:.1f}".format(np.std(heights)))

    plt.tight_layout()
    plt.savefig(out_path, dpi=600)
    plt.close(fig)
    return 0


def hite_hist_plot2(df, out_path='D:\\Postdoc\\Harvard\\GEDI_regional\\Figures\\Tropical_Hite_Hist.png'):
    """Plot ``Height`` histograms (pan-tropic + 3 regions) on a 15-60 m axis.

    Each panel uses 100 bins and is annotated with the mean and population std
    (ddof=0) of its values.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``Height`` and ``Region`` (0 = Africa, 1 = Amazon, 2 = SEAsia).
        This argument is now actually used; the old body read the module-level
        ``Hite_grouped`` instead.
    out_path : str
        Output image path (default: the original hard-coded Windows path).

    Returns
    -------
    int
        Always 0.
    """
    # (subset, title, x position of the title text)
    panels = [
        (df, 'Pan Tropic', 44),
        (df.loc[df.Region == 0], 'Africa', 45),
        (df.loc[df.Region == 1], 'Amazon', 45),
        (df.loc[df.Region == 2], 'SEAsia', 45),
    ]
    fig, axs = plt.subplots(1, 4, figsize=(12, 3))
    for k, (ax, (subset, title, title_x)) in enumerate(zip(axs, panels)):
        heights = subset.Height
        ax.hist(heights, bins=100, density=True)
        ax.set_xlim([15, 60])
        ax.set_ylim([0, 0.085])
        if k == 0:
            ax.set_ylabel('Frequency')
        else:
            ax.tick_params(labelleft=False)  # share the first panel's y labels
        ax.set_xlabel('Height(m)')
        ax.text(title_x, 0.075, title, fontsize=12)
        ax.text(45, 0.010, 'mean: ' + "{:.1f}".format(np.mean(heights)))
        ax.text(45, 0.005, 'std: ' + "{:.1f}".format(np.std(heights)))

    plt.tight_layout()
    plt.savefig(out_path, dpi=600)
    plt.close(fig)
    return 0


def hite_cov_hist_plot2(df, out_path='D:\\Postdoc\\Harvard\\GEDI_regional\\Figures\\Tropical_Hite_cov_Hist.png',
                        xlim=None, ylim=None):
    """Plot ``Height_cov`` histograms for the pan-tropics and each region.

    The old body histogrammed ``df.Height_cov`` only in the first panel.  Its
    statistics and the three regional panels plotted the global
    ``Hite_grouped.Height`` instead, and it saved to the same file as
    ``hite_hist_plot2``, overwriting that figure.  All four panels now use
    ``df.Height_cov`` and the output file has its own name.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``Height_cov`` and ``Region`` (0 = Africa, 1 = Amazon, 2 = SEAsia).
    out_path : str
        Output image path.
    xlim, ylim : sequence of two floats or None
        Axis limits; ``None`` lets matplotlib autoscale.  The height layout
        (15-60, 0-0.085) no longer applies to this variable.

    Returns
    -------
    int
        Always 0.
    """
    # (subset, title, title x position as a fraction of the axes width)
    panels = [
        (df, 'Pan Tropic', 0.644),
        (df.loc[df.Region == 0], 'Africa', 0.667),
        (df.loc[df.Region == 1], 'Amazon', 0.667),
        (df.loc[df.Region == 2], 'SEAsia', 0.667),
    ]
    fig, axs = plt.subplots(1, 4, figsize=(12, 3))
    for k, (ax, (subset, title, title_x)) in enumerate(zip(axs, panels)):
        values = subset.Height_cov.dropna()
        ax.hist(values, bins=100, density=True)
        if xlim is not None:
            ax.set_xlim(list(xlim))
        if ylim is not None:
            ax.set_ylim(list(ylim))
        if k == 0:
            ax.set_ylabel('Frequency')
        else:
            ax.tick_params(labelleft=False)
        ax.set_xlabel('Height_cov')
        # Axes-fraction coordinates keep the annotations on-screen whatever the
        # data range (same relative placement as the height figures).
        ax.text(title_x, 0.882, title, fontsize=12, transform=ax.transAxes)
        ax.text(0.667, 0.118, 'mean: ' + "{:.1f}".format(np.mean(values)), transform=ax.transAxes)
        ax.text(0.667, 0.059, 'std: ' + "{:.1f}".format(np.std(values)), transform=ax.transAxes)

    plt.tight_layout()
    plt.savefig(out_path, dpi=600)
    plt.close(fig)
    return 0

# Correlate relative Hite spread (Height_std / Height) with climatic and soil variables.
def hite_variable_plot(df, out_path='/n/holylfs04/LABS/moorcroft_lab/Users/sqliu/GEDI_regiondata_Compactness/Statistical_Analysis/Hite_std_correlation.txt'):
    """Write the Pearson correlation of Height_std / Height with each driver.

    One line per driver is written as ``"<r>,<feature>\\n"``, in the order of
    ``features`` below.  The old body also created a 2x5 figure that was never
    used or closed, because its scatter-plot code had been disabled; that
    figure is no longer created.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``Height``, ``Height_std`` and every column in ``features``.
    out_path : str
        Destination text file (default: the original hard-coded path).

    Returns
    -------
    int
        Always 0.
    """
    features = ['Elev', 'Prec', 'Temp_sea', 'Wind', 'Cloud', 'Ssrd', 'Light',
                'DSL', 'Soil_Sand', 'Soil_SOC']
    # Select up front so a missing column fails before the file is created.
    X = df[features]
    rel_spread = df['Height_std'] / df['Height']
    with open(out_path, 'w') as fid:
        for name in features:
            fid.write(str(np.corrcoef(rel_spread, X[name])[0, 1]))
            fid.write(',' + name + '\n')
    return 0

# Scatter Height and TotalPAI against three selected climate variables, each with an OLS fit line.
def variable_plot(df, var):
    """Scatter three selected drivers against Height and against TotalPAI.

    Each figure shows the raw points plus an ordinary-least-squares line and is
    saved as ``Statistical_Analysis/<variable>_vs_Hite.png`` or
    ``..._vs_PAI.png``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``Height``, ``TotalPAI`` and the driver columns named in ``var``.
    var : dict
        ``{'Hite': [v1, v2, v3], 'PAI': [v4, v5, v6]}``; the first three entries
        of each list are plotted.

    Returns
    -------
    int
        Always 0.
    """
    # (key in var, response column, y-axis label, file suffix)
    targets = (('Hite', 'Height', 'Hite', '_vs_Hite.png'),
               ('PAI', 'TotalPAI', 'PAI', '_vs_PAI.png'))
    for key, response, y_label, suffix in targets:
        for idx in range(3):
            name = var[key][idx]

            fig = plt.figure()
            plt.plot(df[name], df[response], 'b.', markersize=1)
            plt.xlabel(name)
            plt.ylabel(y_label)

            xs = df[name].values
            design = np.vstack([xs, np.ones(len(xs))]).T
            slope, icept = np.linalg.lstsq(design, df[response].values)[0]
            plt.plot(xs, xs * slope + icept, 'r-')

            plt.savefig('Statistical_Analysis/' + name + suffix, dpi=300)
            plt.close(fig)
    return 0


def marginal_plot(df):
    """Plot one-at-a-time marginal responses of the Height and PAI random forests.

    Every predictor is held at its median except one, which keeps its observed
    values.  Both saved models predict on that frame, and the predictions are
    plotted against the varied predictor.  The files are written to
    ``Statistical_Analysis/MarginalPlot_{Hite,PAI}_vs_<predictor>.png``.
    The PAI figures used to be labelled 'Hite' on the y axis; they now read 'PAI'.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the ten predictor columns listed in ``label``.

    Returns
    -------
    int
        Always 0.
    """
    label = ['Elev', 'Prec', 'Temp_sea', 'Wind', 'Cloud', 'Fapar', 'Light',
             'DSL', 'Soil_Sand', 'Soil_SOC']
    hite_model = joblib.load('Statistical_Analysis/Hite_RF.joblib')
    pai_model = joblib.load('Statistical_Analysis/PAI_RF.joblib')
    X = df[label]

    # Baseline frame with every predictor pinned at its median.
    baseline = X.copy()
    for name in label:
        baseline[name] = X[name].median()

    for name in label:
        X0 = baseline.copy()
        X0[name] = X[name]  # release only this predictor
        xvals = X0[name].values
        for model, target in ((hite_model, 'Hite'), (pai_model, 'PAI')):
            y_pre = model.predict(X0.values)
            plt.figure(1, figsize=(8, 8)).clf()
            plt.plot(xvals, y_pre, 'r.', linestyle="None")
            plt.xlabel(name)
            plt.ylabel(target)
            plt.savefig('Statistical_Analysis/MarginalPlot_' + target + '_vs_' + name + '.png', dpi=300)
    plt.close(1)
    return 0

def Partial_Dependence_plot(df):
    """Draw 2-way partial-dependence plots for the saved Height gradient-boosting model.

    Loads ``Statistical_Analysis/Hite_GB_Ssrd.joblib`` and evaluates partial
    dependence over the pairs (Ssrd, DSL), (Ssrd, Elev) and (DSL, Elev) with a
    100-point grid, averaging over a 95% subsample (random_state=0).  Saves
    ``figure/Hite_Partial_Dependence_plot_Ssrd_GB_2way.png``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the ten predictor columns listed in ``predictors``.

    Returns
    -------
    int
        Always 0.
    """
    gb_model = joblib.load('Statistical_Analysis/Hite_GB_Ssrd.joblib')
    predictors = ['Elev', 'Prec', 'Temp_sea', 'Wind', 'Cloud', 'Ssrd', 'Light',
                  'DSL', 'Soil_Sand', 'Soil_SOC']
    design = df[predictors]
    pairs = [('Ssrd', 'DSL'), ('Ssrd', 'Elev'), ('DSL', 'Elev')]

    fig, ax = plt.subplots(figsize=(14, 6))
    plot_partial_dependence(gb_model, design, pairs, n_jobs=-1,
                            grid_resolution=100, ax=ax, kind='average',
                            subsample=0.95, random_state=0)
    plt.savefig('figure/Hite_Partial_Dependence_plot_Ssrd_GB_2way.png',
                format='png', dpi=300)
    # Earlier variants (not executed):
    #   * 1-way PD of every predictor on a 14x14 figure, saved as
    #     figure/Hite_Partial_Dependence_plot_Ssrd_GB_1way.png
    #   * 2-way PD that additionally pairs Temp_sea with DSL, Elev and Ssrd
    return 0

def _skewness_points(df, feature, min_height=15, alpha=0.01):
    """Return (PAI, Hite, covariation) lists for the significant cells of one feature.

    A cell is kept when ``var == feature``, ``p_value <= alpha``,
    ``Height >= min_height`` and ``TotalPAI >= 0``.  ``covariation`` is 1 when
    ``slope > 0`` and 0 otherwise.  Points are ordered by Height, then TotalPAI.
    """
    rows = df.loc[(df['var'] == feature)
                  & (df.p_value <= alpha)
                  & (df.Height >= min_height)
                  & (df.TotalPAI >= 0)]
    rows = rows.sort_values(['Height', 'TotalPAI'], kind='mergesort')
    covariation = (rows.slope > 0).astype(int).tolist()
    return rows.TotalPAI.tolist(), rows.Height.tolist(), covariation


def Skewness_Climate_plot(df):
    """Map where skewness co-varies positively/negatively with each climate driver.

    ``df`` is a long table with columns [TotalPAI, Height, var, p_value, slope]
    on a grid of PAI resolution 0.1 and height resolution 1.0.  For every driver,
    the significant (p <= 0.01) cells are plotted in the PAI-Height plane,
    coloured by the sign of the slope.  Each figure is saved as
    ``D:/Postdoc/Harvard/GEDI_regional/Figures/Skewness_<driver>.png``.

    Fixes relative to the earlier grid-scan version:
      * cells were looked up by exact float equality against
        ``np.arange(0, max, 0.1)``, whose values (e.g. 0.30000000000000004)
        do not match values read from file, so many cells were silently missed;
      * ``np.arange`` excluded the maximum Height/PAI cell;
      * ``if array <= 0.01`` was ambiguous for empty or duplicated cells;
      * scatters accumulated on one figure, so every saved image also showed
        the previous drivers.

    Returns
    -------
    int
        Always 0.
    """
    feature_names = ['Elev', 'Prec', 'Temp_sea', 'Wind', 'Cloud', 'Fapar',
                     'Light', 'DSL', 'Soil_Sand', 'Soil_SOC']
    for feature in feature_names:
        pai, hite, covariation = _skewness_points(df, feature)
        fig = plt.figure()
        plt.scatter(pai, hite, c=covariation, cmap='seismic', marker='s', s=2)
        plt.xlabel('PAI')
        plt.ylabel('Hite')
        plt.grid()
        plt.savefig('D:/Postdoc/Harvard/GEDI_regional/Figures/Skewness_' + feature + '.png', format='png', dpi=300)
        plt.close(fig)
    return 0

def Lime_Spatial_plot(df):
    """Encode the top-three LIME feature names in ``var1``..``var3`` as integers.

    Mapping: 'Ssrd' -> 1, 'Elev' -> 2, 'DSL' -> 3.  Anything that is not
    numeric after that (other feature names, missing values) becomes 4.
    Numeric entries are truncated to int64, and a literal 999 also becomes 4.

    Only ``var1``..``var3`` are touched.  The previous whole-frame
    ``df.replace`` also rewrote matching strings and every value equal to 999
    in unrelated columns (e.g. an elevation of 999 m became 4).

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``var1``, ``var2``, ``var3``.  It is not modified.

    Returns
    -------
    pandas.DataFrame
        A copy with the three columns encoded.
    """
    codes = {"Ssrd": 1, "Elev": 2, "DSL": 3}
    df = df.copy()
    for col in ('var1', 'var2', 'var3'):
        encoded = df[col].replace(codes)
        # 999 is a temporary sentinel for "not one of the three features".
        encoded = pd.to_numeric(encoded, errors='coerce').fillna(999).astype(np.int64)
        df[col] = encoded.replace(999, 4)
    return df

def _lime_grid(df, res=0.01):
    """Rasterise ``var1``..``var3`` onto a regular lat/lon grid with spacing ``res``.

    Returns ``(LAT, LON, layers)``: the cell-centre coordinates (from the data
    min to max) and a list of three ``(len(LAT), len(LON))`` float arrays, with
    0 where no sample falls.  Indices are rounded, not truncated, because e.g.
    ``0.03 / 0.01 == 2.9999999999999996`` would otherwise drop a point into the
    neighbouring cell.
    """
    minlon, maxlon = df.Lon.min(), df.Lon.max()
    minlat, maxlat = df.Lat.min(), df.Lat.max()
    loncount = int(np.rint((maxlon - minlon) / res)) + 1
    latcount = int(np.rint((maxlat - minlat) / res)) + 1
    LON = np.linspace(minlon, maxlon, loncount)
    LAT = np.linspace(minlat, maxlat, latcount)

    lat_idx = np.rint((df.Lat.to_numpy() - minlat) / res).astype(int)
    lon_idx = np.rint((df.Lon.to_numpy() - minlon) / res).astype(int)
    layers = []
    for col in ('var1', 'var2', 'var3'):
        grid = np.zeros((latcount, loncount))
        grid[lat_idx, lon_idx] = df[col].to_numpy()
        layers.append(grid)
    return LAT, LON, layers


def Lime_2_Netcdf(df, out_path='D:/Postdoc/Harvard/GEDI_regional/Results/New_Lime/Spatial_TreeType_ED2.nc', res=0.01):
    """Write the gridded ``var1``..``var3`` LIME codes of ``df`` to a NetCDF file.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``Lat``, ``Lon``, ``var1``, ``var2``, ``var3``.  It is no longer
        modified (the old body added ``lat_index``/``lon_index`` columns).
    out_path : str
        Destination file (default: the original hard-coded path).
    res : float
        Grid spacing in degrees.

    Returns
    -------
    int
        Always 0.
    """
    LAT, LON, layers = _lime_grid(df, res)
    # The context manager closes the file even if a write fails.
    with Dataset(out_path, mode='w', format='NETCDF4_CLASSIC') as ncfile:
        ncfile.createDimension('lat', len(LAT))
        ncfile.createDimension('lon', len(LON))
        ncfile.title = 'Hite Lime first three features'
        # NOTE(review): Lime_Spatial_plot encodes Ssrd (not Fapar) as 1; this
        # subtitle may be stale — confirm before distributing the file.
        ncfile.subtitle = "Output contains: Fapar, Elev, DSL, and others"

        lat = ncfile.createVariable('lat', np.float32, ('lat',))
        lat.units = 'degrees_north'
        lat.long_name = 'latitude'
        lat[:] = LAT
        lon = ncfile.createVariable('lon', np.float32, ('lon',))
        lon.units = 'degrees_east'
        lon.long_name = 'longitude'
        lon[:] = LON

        names = ('VAR1', 'VAR2', 'VAR3')
        titles = ('1st feature', '2nd feature', '3rd feature')
        for name, title, grid in zip(names, titles, layers):
            nc_var = ncfile.createVariable(name, np.float64, ('lat', 'lon'))
            nc_var.standard_name = title
            nc_var[:] = grid
    return 0
    
# Plot, per region, the share of samples whose top LIME variable is each feature.
def _lime_region_freqs(df):
    """Share of samples per top-LIME code (var1) in each lat/lon region box.

    Regions: 0 = Africa, 1 = Amazon, 2 = SE Asia (open lat/lon boxes below).
    Codes greater than 3 are lumped into 4.  Returns ``{region: [f1, f2, f3, f4]}``
    for codes 1..4.  An empty region yields all zeros instead of raising
    ZeroDivisionError.
    """
    lat, lon = df.Lat, df.Lon
    boxes = {
        0: (lat < 10) & (lat > -10) & (lon > -15) & (lon < 33),
        1: (lat < 9) & (lat > -19) & (lon > -80) & (lon < -44),
        2: (lat < 29) & (lat > -11) & (lon > 92) & (lon < 142),
    }
    top = df.var1.mask(df.var1 > 3, 4)  # NaN stays NaN, as with the old .loc clamp
    freqs = {}
    for region, inside in boxes.items():
        vals = top[inside].to_numpy()
        n = vals.size
        freqs[region] = [np.count_nonzero(vals == code) / n if n else 0.0
                         for code in (1, 2, 3, 4)]
    return freqs


def Hite_Lime_hist(df):
    """Bar charts of the top LIME feature share for Africa, Amazon and SE Asia.

    Bars are Radiation (code 1), Elevation (2), DSL (3) and Other (>= 4).  The
    figure is saved to the hard-coded New_Figures_Ssrd_RF path.  ``df`` (needs
    ``Lat``, ``Lon``, ``var1``) is no longer modified; the old body overwrote
    its ``Region`` column and clamped ``var1`` in place.

    Returns
    -------
    int
        Always 0.
    """
    labels = ['Radiation', 'Elevation', 'DSL', 'Other']
    x_pos = np.arange(len(labels))
    freqs = _lime_region_freqs(df)

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for i, ax in enumerate(axs):
        ax.set_xticks(x_pos)
        ax.set_xticklabels(labels)
        ax.set_ylim([0, 0.9])
        if i == 0:
            ax.set_ylabel('Percentage')
        else:
            ax.tick_params(labelleft=False)
        ax.bar(x_pos, freqs[i], color=['tab:blue', 'red', 'm', 'c'])
    plt.tight_layout()
    plt.savefig('D:\\Postdoc\\Harvard\\GEDI_regional\\Figures\\New_Figures_Ssrd_RF\\Tropical_Hite_Lime_hist.png', dpi=600)
    plt.close(fig)
    return 0

# plot the grouped spatial variable over Pan-Tropic
def Spatial_plot_grouped(var):
    """Map one gridded ``GEDI_grouped_<var>.nc`` field over the pan-tropics.

    Reads ``lat``, ``lon`` and the variable named ``var``, masks the -9999 fill
    value and treats exact zeros as missing.  The field is drawn on a
    cylindrical Basemap (90W-150E, 19S-29N) with the Amazon, African and SE Asian
    forest shapefiles, and the map is written to
    ``.../GEDI_regiondata_Compactness_Reprocess/figure/GEDI_grouped_<var>.png``.

    Colour scaling by ``var``:
        'Ssrd' -> jet, 10000..20000
        'Elev' -> reversed BrBG, -50..800
        'DSL'  -> reversed RdYlGn, 0..10
        other ('Hite', 'Height_cov', ...) -> reversed autumn, 0..0.4

    The old ``if var == 'Hite' or 'Height_cov':`` was always true, so the autumn
    layer was drawn for every variable and the Ssrd/Elev/DSL layer was painted
    on top of it.

    Returns
    -------
    int
        Always 0.
    """
    mplt.rc('xtick', labelsize=9)
    mplt.rc('ytick', labelsize=9)

    nc_path = '/n/holylfs04/LABS/moorcroft_lab/Users/sqliu/GEDI_regiondata_Compactness_Reprocess/GEDI_grouped_' + var + '.nc'
    with Dataset(nc_path, 'r') as nc:
        lat = nc.variables['lat'][:]
        lon = nc.variables['lon'][:]
        kopi = nc.variables[var][:]
    kopi = ma.masked_equal(kopi, -9999).astype("float")
    kopi[kopi == 0.0] = np.nan

    plt.figure(figsize=(12, 8))
    m = Basemap(width=5000000, height=3500000,
                resolution='c', projection='cyl',
                lat_0=0, lon_0=0, llcrnrlon=-90., llcrnrlat=-19., urcrnrlon=150., urcrnrlat=29)

    lons, lats = np.meshgrid(lon, lat)
    xi, yi = m(lons, lats)

    # var -> (colormap, vmin, vmax); plt.get_cmap replaces the deprecated plt.cm.get_cmap.
    styles = {
        'Ssrd': (plt.cm.jet, 10000, 20000),
        'Elev': (plt.get_cmap('BrBG').reversed(), -50, 800),
        'DSL': (plt.get_cmap('RdYlGn').reversed(), 0, 10),
    }
    cmap, vmin, vmax = styles.get(var, (plt.get_cmap('autumn').reversed(), 0, 0.4))
    cs = m.pcolor(xi, yi, np.squeeze(kopi), vmin=vmin, vmax=vmax, cmap=cmap)

    # Grid lines, regional forest outlines and coastlines.
    m.drawparallels(np.arange(-80., 81., 10.), linewidth=0.1, labels=[1, 0, 0, 0], fontsize=10)
    m.drawmeridians(np.arange(-180., 181., 30.), linewidth=0.1, labels=[0, 0, 0, 1], fontsize=10)
    m.readshapefile('/n/holylfs04/LABS/moorcroft_lab/Users/sqliu/GEDI_regiondata_Compactness/shapefile/Amazon_Basin_WGS84', 'Amazon_Basin_WGS84.shp')
    m.readshapefile('/n/holylfs04/LABS/moorcroft_lab/Users/sqliu/GEDI_regiondata_Compactness/shapefile/Africa_Tropical_Forest', 'Africa_Tropical_Forest.shp')
    m.readshapefile('/n/holylfs04/LABS/moorcroft_lab/Users/sqliu/GEDI_regiondata_Compactness/shapefile/SEAsia_Dissolved', 'SEAsia_Dissolved.shp')
    m.drawcoastlines(color='grey', linewidth=0.3)
    m.drawmapboundary(fill_color='#FAFAFA')

    m.colorbar(cs, location='right', pad="5%")
    plt.savefig('/n/holylfs04/LABS/moorcroft_lab/Users/sqliu/GEDI_regiondata_Compactness_Reprocess/figure/GEDI_grouped_' + var + '.png', dpi=600)
    plt.close()
    return 0

    
def Simard_vs_GEDI_plot(df):
    """Density-coloured scatter of Simard vs GEDI canopy height with an RMA fit.

    Points are coloured by the count of their 100x100 2-D histogram bin and
    drawn densest-last.  A reduced-major-axis line (``regress2``) and the 1:1
    line are overlaid, and the RMA correlation ``r`` is printed.  The figure is
    saved to ``figure/Simard_vs_GEDI_plot.png``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``Simard_Height`` (x) and ``Height`` (y).

    Returns
    -------
    int
        Always 0.
    """
    x = df['Simard_Height'].values
    y = df['Height'].values

    hh, locx, locy = np.histogram2d(x, y, bins=[100, 100])
    # Bin of every point: vectorised equivalent of np.argmax(v <= edges[1:]).
    z = hh[np.searchsorted(locx[1:], x), np.searchsorted(locy[1:], y)]
    order = z.argsort()  # densest points are drawn last

    fig = plt.figure(figsize=(8, 8))
    plt.scatter(x[order], y[order], c=z[order], cmap='YlOrBr', marker='.')

    # Reduced major axis regression: both axes carry measurement error.
    results = regress2(x, y, _method_type_2="reduced major axis")
    plt.plot(x, x * results['slope'] + results['intercept'], 'r-')
    print(results['r'])

    plt.xlabel('Simard_Height')
    plt.ylabel('GEDI_Height')
    plt.xlim([0, 60])
    plt.ylim([0, 60])
    plt.plot([0, 60], [0, 60], 'r--')  # 1:1 line
    plt.savefig('figure/Simard_vs_GEDI_plot.png', dpi=300)
    plt.close(fig)
    return 0

def PCA_plot(df, components=(1, 2)):
    """Biplot of two principal components of the standardised drivers.

    Runs a full PCA (all components) on the z-scored drivers.  The chosen pair
    of component scores is scattered and coloured by Height, which is capped at
    60 for the colour scale.  The figure is saved to ``figure/PCA_3rd_2nd.png``.
    The cumulative explained-variance ratio and the loadings
    (components scaled by sqrt(explained variance)) are printed.

    Fixes relative to the previous version:
      * arrow labels said 'Fapar' although the data column is 'Ssrd';
      * scores showed PC2/PC3 while the arrows used PC1/PC2 loadings;
      * ``Y[Y > 60] = 60`` on ``df['Height'].values`` could write through to the
        caller's frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``Height`` and the ten driver columns in ``feature_names``.
    components : (int, int)
        0-based indices of the components to plot; the default (1, 2) is PC2
        vs PC3, matching the output file name.

    Returns
    -------
    int
        Always 0.
    """
    feature_names = ['Elev', 'Prec', 'Temp_sea', 'Wind', 'Cloud', 'Ssrd',
                     'Light', 'DSL', 'Soil_Sand', 'Soil_SOC']
    X = df[feature_names].values
    Y = np.minimum(df['Height'].values, 60)  # new array; df is left untouched

    scaler = StandardScaler()
    scaler.fit(X)
    X = scaler.transform(X)
    pca = PCA()  # keep every component
    pca.fit(X)
    scores = pca.transform(X)  # shape (n_samples, n_features)

    i, j = components
    coeff = pca.components_.T  # rows = features, columns = components
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].scatter(scores[:, i], scores[:, j], s=0.5, c=Y, cmap='plasma')
    for k, name in enumerate(feature_names):
        # Loading arrows scaled x10 so they are visible against the scores.
        ax[0].arrow(0, 0, 10 * coeff[k, i], 10 * coeff[k, j], color='k', alpha=0.5)
        ax[0].text(10 * coeff[k, i], 10 * coeff[k, j], s=name, color='g', ha='center', va='center')
    plt.savefig('figure/PCA_3rd_2nd.png', format='png', dpi=300)
    plt.close(fig)

    print(np.cumsum(pca.explained_variance_ratio_))
    loadings = pca.components_.T * np.sqrt(pca.explained_variance_)
    print(loadings)
    return 0



# Driver columns used by PCA_plot, Partial_Dependence_plot and hite_variable_plot.
_hite_features = ['Elev', 'Prec', 'Temp_sea', 'Wind', 'Cloud', 'Ssrd', 'Light',
                  'DSL', 'Soil_Sand', 'Soil_SOC']
Hite_grouped = pd.read_csv('/n/holylfs04/LABS/moorcroft_lab/Users/sqliu/GEDI_regiondata_Compactness/Beta_fitting_Climate_Soil_Tropic_Grouped.csv')
# Region must be kept (the histograms split on it), and so must the drivers
# (the later analyses need them).  Selecting only Height + drivers and then
# ['Height', 'Region'] raised KeyError and would have dropped the drivers anyway.
_base_cols = ['Height', 'Region'] + _hite_features
_extra_cols = ['Height_std'] if 'Height_std' in Hite_grouped.columns else []
Hite_grouped = Hite_grouped[_base_cols + _extra_cols].dropna(subset=_base_cols)
hite_hist_plot(Hite_grouped)
if _extra_cols:
    hite_variable_plot(Hite_grouped.dropna(subset=_extra_cols))
else:
    print('Height_std not in the grouped table; skipping hite_variable_plot')
PCA_plot(Hite_grouped)


Partial_Dependence_plot(Hite_grouped)

label = ['Hite']
for var in label:
    Spatial_plot_grouped(var)