#!/bin/python
#
# histmcmc.py
# created by Y.Sawawa 20190610
#
#20190610 created
#20190730 updated for paper-grade figures
#20230106 updated for hoope-enkf

import sys

from pylab import *
from scipy import stats
from sklearn.neighbors import KernelDensity
import numpy as np
import matplotlib.pyplot as plt

def gaussian_overlay(mu, sigma, lo=1, hi=20, npts=50, peak=50000):
    """Return ``(xx, yy)`` sampling an unnormalised Gaussian bump.

    ``xx`` is ``npts`` evenly spaced points on ``[lo, hi]`` (both ends
    included) and ``yy = peak * exp(-(xx - mu)**2 / (2 * sigma**2))``, so the
    curve reaches ``peak`` at ``xx == mu``.  The curve is NOT scaled to the
    number of samples; ``peak`` is an absolute height on the count axis.
    """
    xx = np.linspace(lo, hi, npts)
    yy = peak * np.exp(-(xx - mu) ** 2 / (2 * sigma ** 2))
    return xx, yy


def histplot(data, truth, mu, sigma, *, lo=1, hi=20, bins=50, npts=50,
             peak=50000):
    """Draw a histogram of ``data`` with a dashed Gaussian overlay.

    Plots into the current matplotlib axes.

    Parameters
    ----------
    data : array-like
        Samples to histogram.
    truth : float
        Currently unused; kept so existing call sites keep working.
    mu, sigma : float
        Centre and width of the overlaid Gaussian (e.g. sample mean/std).
    lo, hi : float, keyword-only
        Histogram range and x-axis limits (default 1..20).
    bins : int, keyword-only
        Number of histogram bins (default 50).
    npts : int, keyword-only
        Number of points used to draw the Gaussian curve (default 50).
    peak : float, keyword-only
        Height of the Gaussian curve and upper y-axis limit (default 50000).
        Counts above this are clipped by the y-limit.
    """
    plt.hist(data, bins=bins, range=(lo, hi))
    xx, yy = gaussian_overlay(mu, sigma, lo=lo, hi=hi, npts=npts, peak=peak)
    plt.plot(xx, yy, 'r--')
    # Explicit pyplot calls instead of the names leaked by `from pylab import *`.
    plt.ylim(0, peak)
    plt.xlim(lo, hi)
    plt.tick_params(labelsize=12)

#
# parameters
# 
# Reference ("true") parameter values, one float64 slot per parameter;
# the fifth slot is left at zero.
truth = np.array([
    0.40,
    (8 / 3) / 15,
    0.40,
    0.60,
    0.0,
])
#
# load MCMC samples
#
#data = loadtxt('./exp0_fixedpara/MCMCsamplesfrom500_yz_fixedpara')
#data = loadtxt('./exp1_abruptshift/MCMCsamplesfrom500_yz_abruptshift')
data = np.loadtxt('./MCMCsamples_3')
#data = loadtxt('./data_20210614/MCMCsamplesfrom1500_yz')

# Number of leading MCMC samples discarded as burn-in before summarising.
BURNIN = 100000

# Map the stored samples linearly from [0, 1] onto [1, 20], the range the
# histogram is drawn on.
# NOTE(review): assumes the chain file stores values normalised to [0, 1].
data = 1 + (20 - 1) * data

posterior = data[BURNIN:]
if posterior.size == 0:
    # Without this check np.mean/np.std of an empty slice silently yield NaN.
    raise ValueError(
        f"MCMC chain has {len(data)} samples; need more than the "
        f"{BURNIN}-sample burn-in"
    )
mu = np.mean(posterior)
sigma = np.std(posterior)
print(mu, sigma)
#
# kernel density estimation (kde)
#
#x = data[:,0]
#y = data[:,1]
#xy = np.vstack([x,y])
#print(xy)
#weights = np.ones((100000))/100000.0
#kde = KernelDensity(kernel='gaussian',bandwidth=0.02).fit(xy.T,sample_weight=weights)
#XX, YY = np.mgrid[0:1:0.02,0:1:0.02]
#positions = np.vstack([XX.ravel(),YY.ravel()])
#score = kde.score_samples(positions.T)
#score = score.reshape(50,50)
#print(shape(score))
#normalize = sum(exp(score))
#plt.imshow(np.rot90(exp(score)/normalize),cmap=plt.cm.viridis)
#cb = plt.colorbar(shrink=0.75)
#plt.show()

#
# draw histogram
#
fig = plt.figure(figsize=(15, 15))

# Single panel: samples after the first 100000 (burn-in) with the Gaussian
# N(mu, sigma) computed above overlaid; truth[0] is the matching true value.
plt.subplot(1, 1, 1)
histplot(data[100000:], truth[0], mu, sigma)
#plt.savefig('./exp0_fixedpara/hist.png')
plt.savefig('./hist_test1000_smlerr.png')
plt.clf()
# Stop here: the joint-distribution section that follows is not executed.
sys.exit()
#
# Joint distribution
#
# For every column pair (i, j) with 0 <= i <= j < 5: draw a 50x50 2-D
# histogram of the samples, fit a least-squares line with
# scipy.stats.linregress, title the panel with the result, and save one PNG
# per pair under ./exp3_hymod/.
#
# NOTE(review): this section is never reached -- the script calls
# sys.exit() right after saving the 1-D histogram above.
# NOTE(review): data[:,i] requires a 2-D array with at least 5 columns,
# while the histogram section treats `data` as a single series (presumably
# 1-D); re-enabling this loop would likely fail -- confirm the expected
# chain file layout first.
# NOTE(review): extent/xlim/ylim assume samples on [0, 1], but `data` is
# rescaled to [1, 20] right after loading.
for i in range (0,5):
 for j in range (i,5):
  # When i == j a column is paired with itself (diagonal panel).
  heatmap, xedges, yedges = np.histogram2d(data[:,i],data[:,j],bins=50)
  # NOTE(review): histogram2d returns heatmap[x_bin, y_bin], but imshow draws
  # the first array axis vertically, so the image appears transposed relative
  # to (column i, column j); heatmap.T with origin='lower' is the usual
  # remedy -- confirm before reuse.
  plt.imshow(heatmap, extent=[0,1,1,0], interpolation='nearest')
  xlim(0,1)
  ylim(0,1) 
  cb = plt.colorbar(shrink=0.75)
  cb.ax.tick_params(labelsize=24)
  plt.tick_params(labelsize=24)
#
# calculating correlation
#
  # Linear regression of column j on column i; r_value is the correlation
  # coefficient and p_value tests for a non-zero slope.
  slope, intercept, r_value, p_value, std_err = stats.linregress(data[:,i],data[:,j])
  # Report slope and R (rounded to 3 decimals) only when p < 0.001.
  if (p_value < 0.001):
   title = 'Slope = '+str(round(slope,3))+', R = '+str(round(r_value,3))+'(p<0.001)'
  else:
   title = 'no significant correlation'
  plt.title(title,fontsize=20)
  # One output file per pair, suffixed with the two column indices.
  figname = './exp3_hymod/jointdistfrom500_test25_'+str(i)+str(j)+'.png'
  plt.savefig(figname)
  # Clear the figure so the next pair starts from an empty canvas.
  clf()
