
# coding: utf-8

# In[143]:


import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sys
import copy

import matplotlib.transforms as transforms

# Set the default font size (in points) to 15 for every matplotlib figure
# drawn by this script.
plt.rcParams.update({"font.size":15})


# In[41]:


################### input parameters ########################################################
# WFI readout settings.
WFIPixelLength = 130 * 10**-6        # WFI pixel length, meters
WFIIntegrationTime = 1.3 * 10**-3    # frame integration time, seconds

# Integrated input flux used to normalise the simulated spectrum.
SpectrumNormalisationFactor = 4.6858  # particles / cm2 / s

# Sphere on whose surface the primaries are generated.
generationSphereRadius = 9 * 10**-2                                   # meters
generationSphereArea = 4.0 * np.pi * generationSphereRadius**2        # m^2
generationSphereAreaIncm = generationSphereArea * 10**4               # m^2 -> cm^2

# particlesPerFrame = (normalisation / 4 pi) x sphere area [cm^2] x integration time [s]
particlesPerFrame = (SpectrumNormalisationFactor / (4.0 * np.pi)) * generationSphereAreaIncm * WFIIntegrationTime
print("particlesPerFrame =", particlesPerFrame)

# x / y bounds in meters (presumably the detector footprint -- TODO confirm).
xmin, xmax = -0.02, 0.02
ymin, ymax = -0.02, 0.02

###############################################################################################

# In[4]:


#inputDirName = "GCRonly/_VaryingDetectorThickness_29-07-19-Monday-16_27_20_MaterialG4_Al_PShieldThickness4_detectorThickness15_batch_outputs"
#inputDirName = sys.argv[1]


# In[70]:

class spectrum:
    """Flux-normalised histogram of one column of a per-hit DataFrame.

    The input DataFrame must provide the columns ``" edep"`` and ``" eid"``
    plus the column named by ``cutString`` (column names carry a leading
    space). Bin contents are ``counts * normConst / bin width``; ``plot``
    labels them ``cts / cm^2 / keV / s``. When ``saveDF`` is set, the
    selected raw values are kept in ``self.spectraDF`` (a pandas Series)
    for the quantile and integral helpers.

    Spectra support ``a + b`` (bin-wise sum, errors added in quadrature),
    ``a * k`` and ``a / k`` (scaling by a number).
    """

    # Integrated flux (particles / cm2 / s) for each supported ``fluxType``.
    _FLUX_PARTICLES_PER_CM2_PER_S = {"GCR": 4.6858, "CXB": 41.117}

    def __init__(self,
                 spectraDF,
                 totTimeOfIntegration = WFIIntegrationTime ,
                 eMin = 0.1,
                 eMax=15,
                 nbins = 50,
                 cutString = " edep",
                 minVal = None,
                 maxVal = None,
                 fluxType = "GCR",
                 saveDF = True,
                 saveFullDF = False):
        """Histogram ``spectraDF[cutString]`` and normalise it.

        Parameters
        ----------
        spectraDF : pandas.DataFrame
            Per-hit table containing ``" edep"``, ``" eid"`` and ``cutString``.
        totTimeOfIntegration : float
            Accepted for backward compatibility; the normalisation below does
            not use it.
        eMin, eMax : float
            For ``cutString == " edep"``: range of the log-spaced bins (all
            rows are histogrammed; values outside the bins are dropped by
            ``np.histogram``). For other columns: the open ``" edep"`` window
            ``eMin < edep < eMax`` selecting which rows are histogrammed.
        nbins : int
            Number of bin EDGES; the histogram has ``nbins - 1`` bins.
        cutString : str
            Column to histogram. ``" edep"``, ``" eprimary"`` and ``" ecur"``
            use log-spaced bins; any other column uses linear bins.
        minVal, maxVal : float or None
            Bin range for columns other than ``" edep"`` (required for them).
        fluxType : {"GCR", "CXB"}
            Selects the integrated flux used by the normalisation.
        saveDF : bool
            Keep the selected values in ``self.spectraDF``. Also keep the
            ``cutString`` column of rows with ``0.1 < edep < 15`` in
            ``self.RawSpectraDF``.
        saveFullDF : bool
            Keep all columns of rows with ``0.1 < edep < 15`` in
            ``self.RawSpectraDFFull``.

        Raises
        ------
        ValueError
            If ``fluxType`` is not one of the supported types.
        """
        self.eMin = eMin
        self.eMax = eMax
        self.minVal = minVal
        self.maxVal = maxVal
        self.nbins = nbins
        self.cutString = cutString
        self.saveDF = saveDF
        self.saveFullDF = saveFullDF

        # The raw copies use a fixed 0.1-15 edep window, independent of eMin/eMax.
        if saveDF or saveFullDF:
            rawWindow = (spectraDF[" edep"] < 15) & (spectraDF[" edep"] > 0.1)
            if saveFullDF:
                self.RawSpectraDFFull = spectraDF[rawWindow]
            if saveDF:
                self.RawSpectraDF = spectraDF[rawWindow][cutString]

        # NOTE(review): tail(2).iloc[0] is the second-to-last row's " eid" (the
        # last row's when there is a single row); the intent is not visible
        # here -- confirm.
        self.lastEid = spectraDF[" eid"].tail(2).iloc[0]

        if cutString == " edep":
            selected = spectraDF[cutString]
            bins = np.logspace(np.log10(eMin), np.log10(eMax), nbins)
        else:
            inEdepWindow = (spectraDF[" edep"] > eMin) & (spectraDF[" edep"] < eMax)
            selected = spectraDF[inEdepWindow][cutString]
            if cutString == " eprimary" or cutString == " ecur":
                bins = np.logspace(np.log10(minVal), np.log10(maxVal), nbins)
            else:
                bins = np.linspace(minVal, maxVal, nbins)
                print("creating the following spectrum :", selected.head())

        (histCounts, histEdges) = np.histogram(selected, bins=bins)
        # Number of rows that went into the histogram call (the same edep
        # window as the histogram for non-edep columns).
        hitNumber = len(selected)
        if saveDF:
            self.spectraDF = selected

        self.histEdgeLefts = histEdges[:-1]
        self.histEdgeDiff = np.diff(histEdges)

        if fluxType not in self._FLUX_PARTICLES_PER_CM2_PER_S:
            raise ValueError("unknown fluxType {!r}; expected 'GCR' or 'CXB'".format(fluxType))
        fluxIntConst_particles_per_cm2_per_s = self._FLUX_PARTICLES_PER_CM2_PER_S[fluxType]

        generationSphereRadius_cm = 9
        generationSphereArea_cm2 = 4 * np.pi * (generationSphereRadius_cm**2)
        fluxIntConst_particles_per_s = fluxIntConst_particles_per_cm2_per_s * generationSphereArea_cm2

        detectorWidth_cm = 4
        detectorArea_cm2 = detectorWidth_cm**2

        # normConst = (flux x generation-sphere area) / (lastEid x detector area).
        # NOTE(review): presumably lastEid is the number of simulated primaries,
        # making bin contents a rate per cm2 of detector -- confirm.
        self.normConst = fluxIntConst_particles_per_s / (self.lastEid * detectorArea_cm2)

        self.histValues = histCounts * self.normConst / self.histEdgeDiff
        # Poisson (sqrt N) uncertainty per bin, normalised like the values.
        self.sdValues = np.sqrt(histCounts) * self.normConst / self.histEdgeDiff

        self.hitNumberNorm = hitNumber * self.normConst

    def _edepValues(self):
        """Return the stored ``" edep"`` values as a pandas Series.

        ``self.spectraDF`` is a Series of the ``cutString`` column as built by
        ``__init__`` (a DataFrame is also accepted, in which case its
        ``" edep"`` column is used).

        Raises
        ------
        ValueError
            If the stored Series holds a column other than ``" edep"``.
        """
        if isinstance(self.spectraDF, pd.DataFrame):
            return self.spectraDF[" edep"]
        if self.cutString != " edep":
            raise ValueError("stored values are {!r}, not ' edep'".format(self.cutString))
        return self.spectraDF

    def regenerateSpectrum(self, eMin = 0.1, eMax = 15, nbins = 50):
        """Rebin the stored ``" edep"`` values on new log-spaced bins.

        The bin values are rebuilt from raw counts with the ``normConst``
        computed in ``__init__``. Any scaling previously applied through
        ``+``, ``*`` or ``/`` is therefore lost. ``hitNumberNorm`` is left
        unchanged. This requires ``saveDF=True`` and an ``" edep"`` spectrum.
        """
        self.eMin = eMin
        self.eMax = eMax
        self.nbins = nbins

        logBins = np.logspace(np.log10(eMin), np.log10(eMax), nbins)
        (histCounts, histEdges) = np.histogram(self._edepValues(), bins=logBins)
        self.histEdgeLefts = histEdges[:-1]
        self.histEdgeDiff = np.diff(histEdges)

        self.histValues = histCounts * self.normConst / self.histEdgeDiff
        self.sdValues = np.sqrt(histCounts) * self.normConst / self.histEdgeDiff

    def plot(self, includeStd=True, quantiles=None,
             quantileColorList=("cornflowerblue", "blue", "navy", "darkorange"),
             zorder=49, **xargs):
        """Draw the spectrum as a log-log step line on the current axes.

        Parameters
        ----------
        includeStd : bool
            Also shade the +/- 1 sigma band in the line's colour.
        quantiles : sequence of float or None
            Quantiles of ``self.spectraDF`` to mark with dashed vertical lines
            and percentage labels above the axes.
        quantileColorList : sequence of str
            Unused; kept for backward compatibility.
        zorder : float
            Drawing order for the line and band.
        **xargs
            Forwarded to ``plt.step``.
        """
        maxPlottingValue = self.eMax if self.cutString == " edep" else self.maxVal
        # The last bin's value is repeated so the 'post' step reaches the right edge.
        xEdges = np.append(self.histEdgeLefts, maxPlottingValue)

        stepLine, = plt.step(xEdges, np.append(self.histValues, self.histValues[-1]),
                             where='post', zorder=zorder, **xargs)
        stepColor = stepLine.get_color()

        if includeStd:
            lower = self.histValues - self.sdValues
            upper = self.histValues + self.sdValues
            plt.fill_between(xEdges,
                             np.append(lower, lower[-1]),
                             np.append(upper, upper[-1]),
                             interpolate=False,
                             step='post',
                             color=stepColor,
                             alpha=0.5,
                             zorder=zorder,
                             label="_hiddenLabel")

        plt.grid(True)
        plt.yscale("log")
        plt.xscale("log")

        plt.xlabel("Energy ($keV$)")
        plt.ylabel("cts / $cm^2$ / $keV$ / $s$")

        if self.cutString == " edep":
            plt.xlim([self.eMin, self.eMax])
        else:
            plt.xlim([self.minVal, self.maxVal])

        if quantiles is not None:
            quantVals = np.quantile(self.spectraDF, quantiles)
            print("quantVals are", quantVals)

            ax = plt.gca()
            # x in data coordinates, y in axes coordinates (labels sit just above the plot).
            trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
            for quantile, quant in zip(quantiles, quantVals):
                plt.axvline(quant, ls='--', color=stepColor, zorder=50, label='_nolegend_')
                plt.text(quant, 1.02, str(int(quantile * 100)) + "%",
                         transform=trans,
                         horizontalalignment="center",
                         fontsize=10,
                         color=stepColor)

    def getQuantiles(self, quantiles=(0.25, 0.5, 0.75, 0.9)):
        """Print and return ``np.quantile`` of the stored values."""
        quantVals = np.quantile(self.spectraDF, quantiles)
        print("quantVals are", quantVals)
        return quantVals

    def plotCumQuantile(self, xPos, yPos, ls="--", color="skyblue"):
        """Draw guide lines from the axes' bottom and left edges to (xPos, yPos) in data coordinates."""

        def toAxesCoords(x, y):
            ax = plt.gca()
            return ax.transAxes.inverted().transform(ax.transData.transform((x, y)))

        # The point is converted separately for each line, as each call may update the axes.
        plt.axvline(xPos, ymax=toAxesCoords(xPos, yPos)[1], ls=ls, color=color, label='_nolegend_')
        plt.axhline(yPos, xmax=toAxesCoords(xPos, yPos)[0], ls=ls, color=color, label='_nolegend_')

    def cumPlot(self, includeStd=True, quantiles=(0.25, 0.5, 0.75, 0.9),
                colorList=("cornflowerblue", "blue", "navy", "darkorange"),
                xmin=10**1, xmax=10**4, xscale="log", **xargs):
        """Plot the cumulative spectrum divided by ``hitNumberNorm``.

        Parameters
        ----------
        includeStd : bool
            Also shade the +/- 1 sigma band (bin errors accumulated in quadrature).
        quantiles : sequence of float or None
            For each quantile q, draw guide lines to (quantile value, q).
        colorList : sequence of str
            Unused; kept for backward compatibility.
        xmin, xmax : float
            x-axis limits.
        xscale : str
            x-axis scale passed to ``plt.xscale``.
        **xargs
            Forwarded to ``plt.step``.
        """
        xvalsToPlot = self.histEdgeLefts

        yvalsRaw = np.cumsum(self.histValues * self.histEdgeDiff)
        yvalsToPlot = yvalsRaw / self.hitNumberNorm

        yvalsRawSD = np.sqrt(np.cumsum((self.sdValues * self.histEdgeDiff)**2))
        yvalsToPlotSD = yvalsRawSD / self.hitNumberNorm

        maxPlottingValue = self.eMax if self.cutString == " edep" else self.maxVal
        xEdges = np.append(xvalsToPlot, maxPlottingValue)

        stepLine, = plt.step(xEdges, np.append(yvalsToPlot, yvalsToPlot[-1]),
                             where='post', **xargs)
        stepColor = stepLine.get_color()

        if includeStd:
            lower = yvalsToPlot - yvalsToPlotSD
            upper = yvalsToPlot + yvalsToPlotSD
            plt.fill_between(xEdges,
                             np.append(lower, lower[-1]),
                             np.append(upper, upper[-1]),
                             interpolate=False,
                             step='post',
                             color=stepColor,
                             alpha=0.5)

        plt.grid(True)
        plt.xscale(xscale)

        plt.xlabel("Energy ($keV$)")
        # NOTE(review): the curve is divided by hitNumberNorm and limited to
        # [0, 1], so it reads as a fraction; the label is kept as before.
        plt.ylabel("cts / $cm^2$ / $s$")

        plt.xlim(xmin, xmax)
        plt.ylim([0, 1])

        if quantiles is not None:
            for quantile, quantVal in zip(quantiles, np.quantile(self.spectraDF, quantiles)):
                self.plotCumQuantile(quantVal, quantile, color=stepColor)

    def __add__(self, right):
        """Return a new spectrum holding the bin-wise sum of both operands.

        Errors are combined in quadrature and ``hitNumberNorm`` values are
        added. All other attributes (binning, ``normConst``) come from a
        shallow copy of the left operand, so both operands are assumed to
        share the same binning.
        """
        newSpectrum = copy.copy(self)

        newSpectrum.histValues = self.histValues + right.histValues
        newSpectrum.sdValues = np.sqrt((self.sdValues**2) + (right.sdValues**2))
        if self.saveDF:
            newSpectrum.spectraDF = pd.concat([self.spectraDF, right.spectraDF])
        if self.saveFullDF:
            newSpectrum.RawSpectraDFFull = pd.concat([self.RawSpectraDFFull, right.RawSpectraDFFull])

        newSpectrum.hitNumberNorm += right.hitNumberNorm

        return newSpectrum

    def __mul__(self, right):
        """Return a copy with values, errors and ``hitNumberNorm`` multiplied by ``right``."""
        newSpectrum = copy.copy(self)

        newSpectrum.histValues = self.histValues * right
        newSpectrum.sdValues = self.sdValues * right
        newSpectrum.hitNumberNorm = self.hitNumberNorm * right

        return newSpectrum

    def __truediv__(self, right):
        """Return a copy with values, errors and ``hitNumberNorm`` divided by ``right``."""
        newSpectrum = copy.copy(self)

        newSpectrum.histValues = self.histValues / right
        newSpectrum.sdValues = self.sdValues / right
        newSpectrum.hitNumberNorm = self.hitNumberNorm / right

        return newSpectrum

    def getCountsBetweenRange(self, eMin, eMax):
        """Return ``(count, sqrt(count))`` of stored hits with eMin < edep < eMax (unnormalised)."""
        edep = self._edepValues()
        countsWithinRange = int(((edep > eMin) & (edep < eMax)).sum())
        sdCounts = np.sqrt(countsWithinRange)

        return (countsWithinRange, sdCounts)

    def getIntegralBetweenRange(self, eMin, eMax):
        """Return the normalised number of hits with eMin < edep < eMax, as ``(value, sigma)``.

        The scale factor applied to this spectrum (e.g. via ``*`` or ``/``)
        is recovered from the first finite ratio between ``histValues`` and a
        raw-count histogram on this spectrum's own bins. That factor is then
        applied to ``count * normConst``. The method returns ``(0.0, 0.0)``
        when no bin yields a finite ratio.
        """
        edep = self._edepValues()

        # Same bins that produced histValues (see __init__ / regenerateSpectrum).
        bins = np.logspace(np.log10(self.eMin), np.log10(self.eMax), self.nbins)
        (histCounts, _) = np.histogram(edep, bins=bins)

        if len(histCounts) == 0:
            return (0.0, 0.0)

        # Empty bins give 0/0 or x/0; they are filtered out by isfinite below.
        with np.errstate(divide="ignore", invalid="ignore"):
            FNNormArray = self.histValues / (histCounts * self.normConst / self.histEdgeDiff)

        finiteRatios = FNNormArray[np.isfinite(FNNormArray)]
        if len(finiteRatios) == 0:
            return (0.0, 0.0)

        FNNormConst = finiteRatios[0]

        countsWithinRange = int(((edep > eMin) & (edep < eMax)).sum())
        sdCounts = np.sqrt(countsWithinRange)

        valueWithinRange = FNNormConst * countsWithinRange * self.normConst
        sdValueWithinRange = FNNormConst * sdCounts * self.normConst

        return (valueWithinRange, sdValueWithinRange)

    def getIntegralBetweenRangeV2(self, eMin, eMax):
        """Integrate the binned spectrum over bins whose LEFT edge lies strictly inside (eMin, eMax).

        Returns ``(integral, sigma)``, with bin errors combined in quadrature.
        NOTE(review): because the test is strict, a bin whose left edge
        equals ``eMin`` (e.g. the first bin when ``eMin == self.eMin``) is
        excluded -- confirm this is intended.
        """
        relevantPositions = (self.histEdgeLefts > eMin) & (self.histEdgeLefts < eMax)
        binAreas = self.histValues[relevantPositions] * self.histEdgeDiff[relevantPositions]
        outputIntegral = sum(binAreas)
        outputSD = np.linalg.norm(self.sdValues[relevantPositions] * self.histEdgeDiff[relevantPositions])

        return (outputIntegral, outputSD)


#def getFullSpectrumFromDF(edepInRangeDF):
#    minFrameNumber = np.unique(edepInRangeDF[" frameNumber"])[0]
#    
#    for frameNumber in np.unique(edepInRangeDF[" frameNumber"]):
#        if frameNumber == minFrameNumber:
#            fullSpectrum = spectrum(edepInRangeDF[edepInRangeDF[" frameNumber"] == minFrameNumber])
#        else:
#            fullSpectrum = fullSpectrum + spectrum(edepInRangeDF[edepInRangeDF[" frameNumber"] == frameNumber])
#
#    fullSpectrum = fullSpectrum / (max(np.unique(edepInRangeDF[" frameNumber"]))+1)
#
#   return fullSpectrum

def getFullSpectrumFromDF(edepInRangeDF, nbins=10, cutString = " edep", minVal = None, maxVal = None, saveDF=True, saveFullDF = False, includeXIFUFlag=False):
    """Build a single ``spectrum`` from a per-hit DataFrame.

    Parameters
    ----------
    edepInRangeDF : pandas.DataFrame
        Per-hit table. Besides the columns ``spectrum`` needs, it must contain
        ``" frameNumber"``, ``" inputSpectrum"`` (the flux type, read from the
        first row) and ``" vertexvname"``.
    nbins, cutString, minVal, maxVal, saveDF, saveFullDF
        Forwarded unchanged to ``spectrum``.
    includeXIFUFlag : bool
        When false, rows whose ``" vertexvname"`` is ``"XIFU sphere"`` are
        dropped before histogramming.

    Returns
    -------
    spectrum
    """
    # NOTE(review): if frame numbers start at 0 this is one frame short (a
    # retired version divided by max + 1). spectrum.__init__ does not use
    # totTimeOfIntegration, so this value has no effect on the result.
    totIntTime = edepInRangeDF[" frameNumber"].max() * WFIIntegrationTime

    # Read before the XIFU filter, i.e. from the first row of the input.
    fluxType = edepInRangeDF[" inputSpectrum"].iloc[0]

    if not includeXIFUFlag:
        edepInRangeDF = edepInRangeDF[edepInRangeDF[" vertexvname"] != "XIFU sphere"]

    return spectrum(edepInRangeDF,
                    totTimeOfIntegration = totIntTime,
                    nbins = nbins,
                    cutString = cutString,
                    minVal = minVal,
                    maxVal = maxVal,
                    fluxType = fluxType,
                    saveDF = saveDF,
                    saveFullDF = saveFullDF)
