#!/usr/bin/env python
# coding: utf-8

# In[2]:


#%run SpectrumObject.py
#get_ipython().magic(u'run SpectrumObjectNew.py')


import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import os

import matplotlib.transforms as transforms

plt.rcParams.update({"font.size":15})

from tqdm import tqdm_notebook as tqdm

import scipy


# In[3]:


import pickle as pkl


def plotAlLine(plotLabels=True):
    """Draw a dotted grey vertical marker at x = 1.48670 on the current axes.

    plotLabels: when it compares equal to True, the text 'Al' is written
        just above the axes (y = 1.02 in axes coordinates), centred on the
        same x position.
    """
    # presumably the Al K-alpha line energy in keV -- TODO confirm units
    lineX = 1.48670

    plt.axvline(lineX, ls="dotted", color="grey")

    if not (plotLabels == True):
        return

    ax = plt.gca()
    # x is given in data coordinates, y in axes coordinates
    mixedTransform = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    plt.text(lineX, 1.02, 'Al',
             transform=mixedTransform,
             horizontalalignment="center",
             color="grey")

def plotMoLines(plotLabels=True):
    """Draw a dotted grey vertical marker at x = 1.48670 on the current axes.

    plotLabels: when it compares equal to True, the text 'Al' is written
        just above the axes (y = 1.02 in axes coordinates), centred on the
        same x position.

    NOTE(review): despite its name, this function draws exactly the same
    single line and the same 'Al' label as plotAlLine. It looks like an
    unfinished copy; the intended Mo line positions are not given anywhere
    in this file, so the behaviour is left untouched -- confirm and fix.
    """
    lineX = 1.48670

    plt.axvline(lineX, ls="dotted", color="grey")

    if not (plotLabels == True):
        return

    ax = plt.gca()
    # x is given in data coordinates, y in axes coordinates
    mixedTransform = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    plt.text(lineX, 1.02, 'Al',
             transform=mixedTransform,
             horizontalalignment="center",
             color="grey")
        
def plotWFILine(plotLabels=True):
    """Draw a dotted red horizontal marker at y = 5e-3 on the current axes.

    plotLabels: when it compares equal to True, the level is annotated with
        the text '5 x 10^-3' placed to the right of the axes (x = 1.1 in
        axes coordinates) at the height of the line.
    """
    # presumably a WFI background requirement level -- TODO confirm
    levelY = 5e-3

    plt.axhline(levelY, ls="dotted", color="red")

    if not (plotLabels == True):
        return

    ax = plt.gca()
    # x is given in axes coordinates, y in data coordinates
    mixedTransform = transforms.blended_transform_factory(ax.transAxes, ax.transData)
    plt.text(1.1, levelY, r'$5 \times 10^{-3}$',
             transform=mixedTransform,
             horizontalalignment="center",
             verticalalignment="center",
             color="red")


def importSpectraFromDir(dirName, conditions=None, saveFullDF = False, nbins=100):
    """Average the spectra of every run found below ``dirName``.

    Each entry of ``dirName`` is expected to be a sub-directory holding a
    ``spectrumDF.csv`` file. Every file is read, optionally filtered, turned
    into a spectrum with getFullSpectrumFromDF and accumulated.

    dirName: directory to scan (with or without a trailing separator).
    conditions: optional string that is passed to ``eval`` and used as a
        row selector on ``inputDF``. It is executed as Python code, so only
        trusted strings may be passed. When given, each spectrum is scaled
        by the ratio of the second-to-last ' eid' value after filtering to
        the one before filtering.
    saveFullDF, nbins: forwarded to getFullSpectrumFromDF.

    Returns the accumulated spectrum divided by the number of directory
    entries. Entries that fail to load are reported and skipped, but still
    count in that divisor.

    Raises ZeroDivisionError when ``dirName`` is empty.
    """
    fileList = os.listdir(dirName)

    fullspectrum = 0
    firstSpec = True

    fileFails = 0

    for filename in tqdm(fileList, desc=dirName):

        conditionMultFactor = 1

        try:
            inputDF = pd.read_csv(os.path.join(dirName, filename, "spectrumDF.csv"))

            if conditions is not None:
                lastEid = inputDF[" eid"].tail(2).iloc[0]

                # the condition string is evaluated in this scope and
                # normally refers to `inputDF`
                inputDF = inputDF[eval(conditions)]

                conditionMultFactor = (inputDF[" eid"].tail(2).iloc[0])/lastEid

            spectrum = getFullSpectrumFromDF(inputDF,
                                             saveDF=False,
                                             saveFullDF=saveFullDF,
                                             nbins=nbins) * conditionMultFactor

            # the first spectrum replaces the integer placeholder so that
            # later additions use the spectrum type's own `+=`
            if firstSpec:
                fullspectrum = spectrum
                firstSpec = False
            else:
                fullspectrum += spectrum
        except Exception as err:
            # best effort: report the failing entry and carry on; unlike a
            # bare `except`, Ctrl-C and SystemExit are no longer swallowed
            print("WARNING file inclusion failed, skipping...", filename, repr(err))
            fileFails += 1

    # NOTE: skipped entries are deliberately kept in the divisor here; use
    # (len(fileList) - fileFails) to average over loaded files only.
    fullspectrum = fullspectrum / (len(fileList))

    return fullspectrum

def getMeanFlux(inputDF, eMin=0.1,eMax=15):
    """Print and return the mean bin value inside an open energy window.

    inputDF: object exposing ``histEdgeLefts``, ``histValues`` and
        ``sdValues`` as equally long, boolean-indexable sequences.
    eMin, eMax: bins whose left edge lies strictly between these two
        values are selected.

    Returns ``(mean, error)`` where ``mean`` is the arithmetic mean of the
    selected ``histValues`` and ``error`` is the quadrature sum of the
    selected ``sdValues`` divided by the number of selected bins. The same
    two numbers are printed as ``mean +/- error``.

    Raises ZeroDivisionError when no bin falls inside the window.
    """
    leftEdges = inputDF.histEdgeLefts
    inWindow = (leftEdges > eMin) & (leftEdges < eMax)

    selectedValues = inputDF.histValues[inWindow]
    selectedStd = inputDF.sdValues[inWindow]
    nSelected = len(selectedValues)

    meanFlux = sum(selectedValues) / nSelected
    meanFluxErr = np.sqrt(sum(selectedStd**2)) / nSelected

    print(meanFlux, "+/-", meanFluxErr)

    return (meanFlux, meanFluxErr)

def getMaxBinCountRate(inputDF):
    """Locate the highest bin and return its edges and width-integrated value.

    inputDF: object exposing ``histEdgeLefts``, ``histValues`` and
        ``sdValues`` as equally long sequences (NumPy arrays or pandas
        Series).

    Returns ``(leftEdge, nextLeftEdge, value * width, sd * width)`` for the
    first bin holding the maximum of ``histValues``, where ``width`` is the
    distance to the next left edge.

    Raises IndexError when the maximum sits in the last bin, because that
    bin has no following left edge to take a width from.
    """
    # Work on plain arrays: subtracting two slices of a pandas Series aligns
    # them on their index labels, which yields 0/NaN instead of bin widths.
    edges = np.asarray(inputDF.histEdgeLefts)
    values = np.asarray(inputDF.histValues)
    stds = np.asarray(inputDF.sdValues)

    maxPos = int(np.argmax(values))
    leftEdge = edges[maxPos]
    rightEdge = edges[maxPos + 1]
    binWidth = rightEdge - leftEdge

    return (leftEdge, rightEdge, values[maxPos] * binWidth, stds[maxPos] * binWidth)



def importEKinSpectraFromDir(dirName, minVal=0.1,maxVal=10000, nbins=50, conditions=None):
    """Average the ' ecur' spectra of every run found below ``dirName``.

    Each entry of ``dirName`` is expected to be a sub-directory holding a
    ``spectrumDF.csv`` file.

    dirName: directory to scan (with or without a trailing separator).
    minVal, maxVal, nbins: forwarded to getFullSpectrumFromDF.
    conditions: optional string that is passed to ``eval`` and used as a
        row selector on ``inputDF``. It is executed as Python code, so only
        trusted strings may be passed.

    Returns the sum of all spectra divided by the number of directory
    entries. Unlike importSpectraFromDir, a file that cannot be read aborts
    the import with the underlying exception.

    Raises ZeroDivisionError when ``dirName`` is empty.
    """
    fileList = os.listdir(dirName)

    fullspectrum = 0
    firstSpec = True

    for filename in fileList:

        inputDF = pd.read_csv(os.path.join(dirName, filename, "spectrumDF.csv"))

        if conditions is not None:
            # the condition string is evaluated in this scope and normally
            # refers to `inputDF`
            inputDF = inputDF[eval(conditions)]

        spectrum = getFullSpectrumFromDF(inputDF,
                                         cutString=" ecur",
                                         minVal=minVal,
                                         maxVal=maxVal,
                                         nbins=nbins,
                                         saveDF=True)

        # the first spectrum replaces the integer placeholder so that later
        # additions use the spectrum type's own `+=`
        if firstSpec:
            fullspectrum = spectrum
            firstSpec = False
        else:
            fullspectrum += spectrum

    fullspectrum = fullspectrum / len(fileList)

    return fullspectrum

def importEPrimarySpectraFromDir(dirName, minVal=0.1,maxVal=10000, nbins=50, conditions=None):
    """Average the ' eprimary' spectra of every run found below ``dirName``.

    Each entry of ``dirName`` is expected to be a sub-directory holding a
    ``spectrumDF.csv`` file.

    dirName: directory to scan (with or without a trailing separator).
    minVal, maxVal, nbins: forwarded to getFullSpectrumFromDF.
    conditions: optional string that is passed to ``eval`` and used as a
        row selector on ``inputDF``. It is executed as Python code, so only
        trusted strings may be passed.

    Returns the sum of all spectra divided by the number of directory
    entries. A file that cannot be read aborts the import with the
    underlying exception.

    Raises ZeroDivisionError when ``dirName`` is empty.
    """
    fileList = os.listdir(dirName)

    fullspectrum = 0
    firstSpec = True

    for filename in fileList:

        inputDF = pd.read_csv(os.path.join(dirName, filename, "spectrumDF.csv"))

        if conditions is not None:
            # the condition string is evaluated in this scope and normally
            # refers to `inputDF`
            inputDF = inputDF[eval(conditions)]

        spectrum = getFullSpectrumFromDF(inputDF,
                                         cutString=" eprimary",
                                         minVal=minVal,
                                         maxVal=maxVal,
                                         nbins=nbins,
                                         saveDF=True,
                                         saveFullDF=False)

        # the first spectrum replaces the integer placeholder so that later
        # additions use the spectrum type's own `+=`
        if firstSpec:
            fullspectrum = spectrum
            firstSpec = False
        else:
            fullspectrum += spectrum

    fullspectrum = fullspectrum / len(fileList)

    return fullspectrum



def plotPos(inputDF,xString=" vx", yString = " vy", plotCircles=True, 
            xlim=[-0.1,0.1], ylim=[-0.1,0.1],
           conditionList = [],
           ms=0.2,
           recSide = "side"):
    """Scatter two columns of ``inputDF`` over a sketch of the geometry.

    Only rows with 0.1 < ' edep' < 15 are drawn.

    xString, yString: column names plotted on the x and y axis.
    plotCircles: when equal to True, three circles centred on the origin
        are drawn using the names genRadius, outerRadius and innerRadius.
        These are not defined in this function and must exist at module
        level when it is called.
    xlim, ylim: axis limits applied at the end.
    conditionList: strings passed to ``eval``; each result is used as a
        row selector on ``inputDFInERange`` and the selection is
        over-plotted as a further series. Executed as Python code -- pass
        trusted strings only.
    ms: marker size.
    recSide: "side" draws the detector as a detWidth x detDepth rectangle,
        "face" as a detWidth x detWidth square; any other value draws no
        rectangle.
    """
    detWidth = 0.04
    detDepth = 450e-6

    axes = plt.gca()
    origin = (0, 0)

    if plotCircles == True:
        # drawn largest-intent first so later patches sit on top
        axes.add_patch(plt.Circle(origin, genRadius, fc='White', ec="red"))
        axes.add_patch(plt.Circle(origin, outerRadius, fc='darkgrey', ec="red"))
        axes.add_patch(plt.Circle(origin, innerRadius, fc='White', ec="red"))

    if recSide == "side" or recSide == "face":
        recHeight = detDepth if recSide == "side" else detWidth
        axes.add_patch(plt.Rectangle((-detWidth/2, -recHeight/2),
                                     detWidth, recHeight,
                                     fc='Pink', ec="red"))

    edep = inputDF[" edep"]
    inputDFInERange = inputDF[(edep > 0.1) & (edep < 15)]

    markerStyle = dict(ms=ms, marker="o", ls="None")
    plt.plot(inputDFInERange[xString], inputDFInERange[yString], **markerStyle)

    for condition in conditionList:
        # evaluated in this scope; the string may refer to inputDFInERange
        inputDFSpec = inputDFInERange[eval(condition)]
        plt.plot(inputDFSpec[xString], inputDFSpec[yString], **markerStyle)

    plt.grid(True, "both")
    plt.gca().set_aspect(1)
    plt.xlim(xlim)
    plt.ylim(ylim)
    
def plotPosSwivel(inputDF,viewString=" vz",plotCircles=True, 
            xlim=[-0.1,0.1], ylim=[-0.1,0.1],
           conditionList = [],
           ms=0.2,
           recSide = "side"):
    """Scatter positions with the viewing axis folded into the x coordinate.

    Only rows with 0.1 < ' edep' < 15 are drawn. For every row the plotted
    x value is ``sign(first) * sqrt(view**2 + first**2)`` and the y value is
    ``second``, where the (first, second) columns depend on ``viewString``:

        " vz" -> (" vx", " vy")
        " vx" -> (" vy", " vz")
        " vy" -> (" vx", " vz")

    viewString: one of " vz", " vx", " vy" (note the leading space).
    plotCircles: when equal to True, three circles centred on the origin
        are drawn using the names genRadius, outerRadius and innerRadius.
        These are not defined in this function and must exist at module
        level when it is called.
    xlim, ylim: axis limits applied at the end.
    conditionList: strings passed to ``eval``; each result is used as a
        row selector on ``inputDFInERange`` and the selection is
        over-plotted as a further series. Executed as Python code -- pass
        trusted strings only.
    ms: marker size.
    recSide: accepted for signature compatibility with plotPos; it is not
        used, the rectangle shape is chosen from ``viewString`` instead.

    Raises ValueError for an unsupported ``viewString`` (previously this
    surfaced as an UnboundLocalError after part of the figure was drawn).
    """
    detWidth = 0.04
    detDepth = 450e-6

    columnsForView = {
        " vz": (" vx", " vy"),
        " vx": (" vy", " vz"),
        " vy": (" vx", " vz"),
    }
    if viewString not in columnsForView:
        raise ValueError("viewString must be one of ' vz', ' vx', ' vy', got {!r}".format(viewString))
    firstString, secondString = columnsForView[viewString]

    def foldedX(df, viewCol, firstCol):
        # signed distance from the origin in the (view, first) plane
        return np.sign(df[firstCol]) * np.sqrt(df[viewCol]**2 + df[firstCol]**2)

    axes = plt.gca()

    if plotCircles == True:
        axes.add_patch(plt.Circle((0, 0), genRadius, fc='White', ec="red"))
        axes.add_patch(plt.Circle((0, 0), outerRadius, fc='darkgrey', ec="red"))
        axes.add_patch(plt.Circle((0, 0), innerRadius, fc='White', ec="red"))

    # looking along " vz" shows the detector as a square, otherwise as a
    # thin slab
    recHeight = detWidth if viewString == " vz" else detDepth
    axes.add_patch(plt.Rectangle((-detWidth/2, -recHeight/2),
                                 detWidth, recHeight,
                                 fc='Pink', ec="red"))

    inputDFInERange = inputDF[(inputDF[" edep"] > 0.1) & (inputDF[" edep"] < 15)]
    plt.plot(foldedX(inputDFInERange, viewString, firstString),
             inputDFInERange[secondString],
             ms=ms, marker="o", ls="None")

    for condition in conditionList:
        # evaluated in this scope; the string may refer to inputDFInERange
        inputDFSpec = inputDFInERange[eval(condition)]
        plt.plot(foldedX(inputDFSpec, viewString, firstString),
                 inputDFSpec[secondString],
                 ms=ms, marker="o", ls="None")

    plt.grid(True, "both")
    plt.gca().set_aspect(1)
    plt.xlim(xlim)
    plt.ylim(ylim)


# In[77]:



def plotRPosHist(inputDF, conditionList=[], nbins = 50, fluxType="GCR"):
    """Plot a normalised histogram of the radial distance of (vx, vy, vz).

    Only rows with 0.1 < ' edep' < 15 are histogrammed. Counts are scaled
    by ``fluxIntConst * pi * 9**2 / (4**2 * lastEid)`` and divided by the
    bin width; a +/- sqrt(counts) band (scaled the same way) is shaded
    around each step curve.

    inputDF: DataFrame with ' edep', ' vx', ' vy', ' vz' and ' eid' columns.
    conditionList: strings passed to ``eval``; each result is used as a
        row selector on ``inputDFInERange`` and drawn as an extra curve.
        Executed as Python code -- pass trusted strings only.
    nbins: number of bin *edges* between 0 and 0.09, i.e. nbins - 1 bins.
    fluxType: "GCR" (constant 5.1138) or "CXB" (constant 41.117).

    Raises ValueError for any other ``fluxType`` (previously an
    UnboundLocalError).
    """
    knownFluxIntConsts = {"GCR": 5.1138, "CXB": 41.117}
    if fluxType not in knownFluxIntConsts:
        raise ValueError("fluxType must be 'GCR' or 'CXB', got {!r}".format(fluxType))
    fluxIntConst = knownFluxIntConsts[fluxType]

    def drawRadialHist(df, binEdges, norm):
        # histogram |v| and draw the step curve with its error band
        rPos = np.sqrt((df[" vx"]**2) + (df[" vy"]**2) + (df[" vz"]**2))
        (counts, edges) = np.histogram(rPos, binEdges)
        widths = edges[1:] - edges[:-1]
        rate = counts * norm / widths
        rateErr = np.sqrt(counts) * norm / widths

        plt.step(edges[:-1], rate)
        stepColor = plt.gca().lines[-1].get_color()
        plt.fill_between(edges[:-1],
                         rate - rateErr, rate + rateErr,
                         alpha=0.5,
                         color=stepColor,
                         interpolate=False,
                         step='pre')

    inputDFInERange = inputDF[(inputDF[" edep"] > 0.1) & (inputDF[" edep"] < 15)]

    # ' eid' of the second-to-last row of the unfiltered table;
    # presumably the number of simulated primaries -- TODO confirm
    lastEid = inputDF[" eid"].tail(2).iloc[0]

    normConst = (fluxIntConst * np.pi * (9**2)) / ((4**2) * lastEid)
    binEdges = np.linspace(0, 0.09, nbins)

    drawRadialHist(inputDFInERange, binEdges, normConst)

    for condition in conditionList:
        # evaluated in this scope; the string may refer to inputDFInERange
        inputDFSpec = inputDFInERange[eval(condition)]
        drawRadialHist(inputDFSpec, binEdges, normConst)

    plt.grid(True, "both")
    
def getRPosVals(inputDF, conditionList=[], nbins = 50, fluxType="GCR"):
    """Return the normalised radial-distance histogram of (vx, vy, vz).

    Only rows with 0.1 < ' edep' < 15 are histogrammed. Counts are scaled
    by ``fluxIntConst * pi * 9**2 / (4**2 * lastEid)`` and divided by the
    bin width.

    inputDF: DataFrame with ' edep', ' vx', ' vy', ' vz' and ' eid' columns.
    conditionList: accepted for signature compatibility; not used.
    nbins: number of bin *edges* between 0 and 0.09, i.e. nbins - 1 bins.
    fluxType: "GCR" (constant 5.1138) or "CXB" (constant 41.117).

    Returns ``(binCentres, values, errors, leftEdges, rightEdges)`` where
    ``errors`` is sqrt(counts) with the same scaling as ``values``.

    Raises ValueError for any other ``fluxType`` (previously an
    UnboundLocalError).
    """
    knownFluxIntConsts = {"GCR": 5.1138, "CXB": 41.117}
    if fluxType not in knownFluxIntConsts:
        raise ValueError("fluxType must be 'GCR' or 'CXB', got {!r}".format(fluxType))
    fluxIntConst = knownFluxIntConsts[fluxType]

    inputDFInERange = inputDF[(inputDF[" edep"] > 0.1) & (inputDF[" edep"] < 15)]

    rPosList = np.sqrt((inputDFInERange[" vx"]**2) +
                       (inputDFInERange[" vy"]**2) +
                       (inputDFInERange[" vz"]**2))

    # ' eid' of the second-to-last row of the unfiltered table;
    # presumably the number of simulated primaries -- TODO confirm
    lastEid = inputDF[" eid"].tail(2).iloc[0]

    normConst = (fluxIntConst * np.pi * (9**2)) / ((4**2) * lastEid)

    (counts, edges) = np.histogram(rPosList, np.linspace(0, 0.09, nbins))
    leftEdges = edges[:-1]
    rightEdges = edges[1:]
    widths = rightEdges - leftEdges

    countsNormalised = counts * normConst / widths
    stdNormalised = np.sqrt(counts) * normConst / widths

    return ((leftEdges + rightEdges)/2, countsNormalised, stdNormalised, leftEdges, rightEdges)


def plotThetaPosHist(inputDF, conditionList=[], nbins = 50, fluxType="GCR"):
    """Plot a normalised histogram of the polar angle of the vector v - x.

    theta is ``arccos(dz / |d|)`` with ``d = (vx - x, vy - y, vz - z)``.
    Only rows with 0.1 < ' edep' < 15 are histogrammed. Counts are scaled
    by ``fluxIntConst * pi * 9**2 / (4**2 * lastEid)`` and divided by the
    bin width; a +/- sqrt(counts) band (scaled the same way) is shaded
    around each step curve.

    inputDF: DataFrame with ' edep', ' eid', ' vx', ' vy', ' vz', ' x',
        ' y' and ' z' columns.
    conditionList: strings passed to ``eval``; each result is used as a
        row selector on ``inputDFInERange`` and drawn as an extra curve.
        Executed as Python code -- pass trusted strings only.
    nbins: number of bin *edges* between 0 and pi, i.e. nbins - 1 bins.
    fluxType: "GCR" (constant 5.1138) or "CXB" (constant 41.117).

    Raises ValueError for any other ``fluxType`` (previously an
    UnboundLocalError).
    """
    knownFluxIntConsts = {"GCR": 5.1138, "CXB": 41.117}
    if fluxType not in knownFluxIntConsts:
        raise ValueError("fluxType must be 'GCR' or 'CXB', got {!r}".format(fluxType))
    fluxIntConst = knownFluxIntConsts[fluxType]

    def thetaOfDisplacement(df):
        # polar angle of the vector pointing from (x, y, z) to (vx, vy, vz)
        dx = df[" vx"] - df[" x"]
        dy = df[" vy"] - df[" y"]
        dz = df[" vz"] - df[" z"]
        return np.arccos(dz / np.sqrt((dx**2) + (dy**2) + (dz**2)))

    def drawThetaHist(thetaVals, binEdges, norm):
        # histogram theta and draw the step curve with its error band
        (counts, edges) = np.histogram(thetaVals, binEdges)
        widths = edges[1:] - edges[:-1]
        rate = counts * norm/widths
        rateErr = np.sqrt(counts) * norm/widths

        plt.step(edges[:-1], rate, where="post")
        stepColor = plt.gca().lines[-1].get_color()
        plt.fill_between(edges[:-1],
                         rate - rateErr,
                         rate + rateErr,
                         interpolate=False,
                         step='post',
                         color=stepColor,
                         alpha=0.5)

    inputDFInERange = inputDF[(inputDF[" edep"] > 0.1) & (inputDF[" edep"] < 15)]

    # ' eid' of the second-to-last row of the unfiltered table;
    # presumably the number of simulated primaries -- TODO confirm
    lastEid = inputDF[" eid"].tail(2).iloc[0]

    normConst = (fluxIntConst * np.pi * (9**2)) / ((4**2) * lastEid)
    binEdges = np.linspace(0, np.pi, nbins)

    drawThetaHist(thetaOfDisplacement(inputDFInERange), binEdges, normConst)

    for condition in conditionList:
        # evaluated in this scope; the string may refer to inputDFInERange
        inputDFSpec = inputDFInERange[eval(condition)]
        drawThetaHist(thetaOfDisplacement(inputDFSpec), binEdges, normConst)

    plt.grid(True, "both")
    plt.ylabel("cts / $cm^2$ / $s$ / radian")
    plt.xlabel(r"$ \theta $")

    plt.xlim([0,np.pi])
    
def plotSolidAngleDiffThetaPosHist(inputDF, conditionList=[], nbins = 50, fluxType="GCR"):
    """Plot the polar-angle histogram of (vx, vy, vz) per unit solid angle.

    theta is ``arccos(vz / |v|)``. Only rows with 0.1 < ' edep' < 15 are
    histogrammed. Counts are scaled by
    ``fluxIntConst * pi * 9**2 / (4**2 * lastEid)`` and divided by
    ``binWidth * sin(binCentre) * 2 * pi``; a +/- sqrt(counts) band (scaled
    the same way) is shaded around each step curve.

    inputDF: DataFrame with ' edep', ' eid', ' vx', ' vy' and ' vz' columns.
    conditionList: strings passed to ``eval``; each result is used as a
        row selector on ``inputDFInERange`` and drawn as an extra curve.
        Executed as Python code -- pass trusted strings only.
    nbins: number of bin *edges* between 0 and pi, i.e. nbins - 1 bins.
    fluxType: "GCR" (constant 5.1138) or "CXB" (constant 41.117).

    Raises ValueError for any other ``fluxType`` (previously an
    UnboundLocalError).
    """
    knownFluxIntConsts = {"GCR": 5.1138, "CXB": 41.117}
    if fluxType not in knownFluxIntConsts:
        raise ValueError("fluxType must be 'GCR' or 'CXB', got {!r}".format(fluxType))
    fluxIntConst = knownFluxIntConsts[fluxType]

    def thetaOfPosition(df):
        # polar angle of the (vx, vy, vz) vector itself
        return np.arccos((df[" vz"]) /
                         np.sqrt(((df[" vx"])**2) +
                                 ((df[" vy"])**2) +
                                 ((df[" vz"])**2)))

    def drawThetaHist(thetaVals, binEdges, norm):
        # histogram theta, convert to "per solid angle" and draw with band
        (counts, edges) = np.histogram(thetaVals, binEdges)
        edgeMids = (edges[1:] + edges[:-1])/2
        solidAngles = (edges[1:] - edges[:-1]) * np.sin(edgeMids)*2*np.pi
        rate = counts * norm/solidAngles
        rateErr = np.sqrt(counts) * norm/solidAngles

        plt.step(edges[:-1], rate, where="post")
        stepColor = plt.gca().lines[-1].get_color()
        plt.fill_between(edges[:-1],
                         rate - rateErr,
                         rate + rateErr,
                         interpolate=False,
                         step='post',
                         color=stepColor,
                         alpha=0.5)

    inputDFInERange = inputDF[(inputDF[" edep"] > 0.1) &
                              (inputDF[" edep"] < 15)]

    # ' eid' of the second-to-last row of the unfiltered table;
    # presumably the number of simulated primaries -- TODO confirm
    lastEid = inputDF[" eid"].tail(2).iloc[0]

    normConst = (fluxIntConst * np.pi * (9**2)) / ((4**2) * lastEid)
    binEdges = np.linspace(0, np.pi, nbins)

    drawThetaHist(thetaOfPosition(inputDFInERange), binEdges, normConst)

    for condition in conditionList:
        # evaluated in this scope; the string may refer to inputDFInERange
        inputDFSpec = inputDFInERange[eval(condition)]
        drawThetaHist(thetaOfPosition(inputDFSpec), binEdges, normConst)

    plt.grid(True, "both")
    plt.ylabel("cts / $cm^2$ / $s$ / sr")
    plt.xlabel(r"$ \theta $")

    plt.xlim([0,np.pi])
    
def plotSolidAngleDiffThetaPosVals(inputDF, conditionList=[], nbins = 50, fluxType="GCR"):
    """Return the polar-angle histogram of (vx, vy, vz) per unit solid angle.

    Despite the "plot" prefix nothing is drawn. theta is
    ``arccos(vz / |v|)``. Only rows with 0.1 < ' edep' < 15 are
    histogrammed. Counts are scaled by
    ``fluxIntConst * pi * 9**2 / (4**2 * lastEid)`` and divided by
    ``binWidth * sin(binCentre) * 2 * pi``.

    inputDF: DataFrame with ' edep', ' eid', ' vx', ' vy' and ' vz' columns.
    conditionList: accepted for signature compatibility; not used.
    nbins: number of bin *edges* between 0 and pi, i.e. nbins - 1 bins.
    fluxType: "GCR" (constant 5.1138) or "CXB" (constant 41.117).

    Returns ``(binCentres, values, errors)`` where ``errors`` is
    sqrt(counts) with the same scaling as ``values``.

    Raises ValueError for any other ``fluxType`` (previously an
    UnboundLocalError).
    """
    knownFluxIntConsts = {"GCR": 5.1138, "CXB": 41.117}
    if fluxType not in knownFluxIntConsts:
        raise ValueError("fluxType must be 'GCR' or 'CXB', got {!r}".format(fluxType))
    fluxIntConst = knownFluxIntConsts[fluxType]

    inputDFInERange = inputDF[(inputDF[" edep"] > 0.1) &
                              (inputDF[" edep"] < 15)]

    radius = np.sqrt(((inputDFInERange[" vx"])**2) +
                     ((inputDFInERange[" vy"])**2) +
                     ((inputDFInERange[" vz"])**2))
    thetaPosList = np.arccos((inputDFInERange[" vz"]) / radius)

    # ' eid' of the second-to-last row of the unfiltered table;
    # presumably the number of simulated primaries -- TODO confirm
    lastEid = inputDF[" eid"].tail(2).iloc[0]

    normConst = (fluxIntConst * np.pi * (9**2)) / ((4**2) * lastEid)

    (counts, edges) = np.histogram(thetaPosList, np.linspace(0, np.pi, nbins))
    edgeMids = (edges[1:] + edges[:-1])/2

    # solid angle covered by each theta bin (ring on the unit sphere)
    solidAngles = (edges[1:] - edges[:-1]) * np.sin(edgeMids)*2*np.pi

    countsNormalised = counts * normConst/solidAngles
    stdsNormalised = np.sqrt(counts) * normConst/solidAngles

    return (edgeMids, countsNormalised, stdsNormalised)
    
def plotMomSolidAngleDiffThetaPosHist(inputDF, conditionList=[], nbins = 50, fluxType="GCR"):
    """Plot the polar-angle histogram of the vector x - v per unit solid angle.

    For the base curve theta is ``arccos(dz / |d|)`` with
    ``d = (x - vx, y - vy, z - vz)``. Only rows with 0.1 < ' edep' < 15 are
    histogrammed. Counts are scaled by
    ``fluxIntConst * pi * 9**2 / (4**2 * lastEid)`` and divided by
    ``binWidth * sin(binCentre) * 2 * pi``; a +/- sqrt(counts) band (scaled
    the same way) is shaded around each step curve.

    inputDF: DataFrame with ' edep', ' eid', ' vx', ' vy', ' vz', ' x',
        ' y' and ' z' columns.
    conditionList: strings passed to ``eval``; each result is used as a
        row selector on ``inputDFInERange`` and drawn as an extra curve
        (behind the base curve). Executed as Python code -- pass trusted
        strings only.
    nbins: number of bin *edges* between 0 and pi, i.e. nbins - 1 bins.
    fluxType: "GCR" (constant 5.1138) or "CXB" (constant 41.117).

    Raises ValueError for any other ``fluxType`` (previously an
    UnboundLocalError).

    NOTE(review): the condition curves use ``arccos(vz / |v|)`` -- the
    angle of (vx, vy, vz) itself -- rather than the x - v angle used for
    the base curve. This mismatch is kept as found; confirm whether it is
    intended.
    """
    knownFluxIntConsts = {"GCR": 5.1138, "CXB": 41.117}
    if fluxType not in knownFluxIntConsts:
        raise ValueError("fluxType must be 'GCR' or 'CXB', got {!r}".format(fluxType))
    fluxIntConst = knownFluxIntConsts[fluxType]

    def thetaOfReverseDisplacement(df):
        # polar angle of the vector pointing from (vx, vy, vz) to (x, y, z)
        dx = df[" x"] - df[" vx"]
        dy = df[" y"] - df[" vy"]
        dz = df[" z"] - df[" vz"]
        return np.arccos(dz / np.sqrt((dx**2) + (dy**2) + (dz**2)))

    def thetaOfPosition(df):
        # polar angle of the (vx, vy, vz) vector itself
        return np.arccos((df[" vz"]) /
                         np.sqrt(((df[" vx"])**2) +
                                 ((df[" vy"])**2) +
                                 ((df[" vz"])**2)))

    def drawThetaHist(thetaVals, binEdges, norm, layer):
        # histogram theta, convert to "per solid angle" and draw with band
        (counts, edges) = np.histogram(thetaVals, binEdges)
        edgeMids = (edges[1:] + edges[:-1])/2
        solidAngles = (edges[1:] - edges[:-1]) * np.sin(edgeMids) * 2 * np.pi
        rate = counts * norm/solidAngles
        rateErr = np.sqrt(counts) * norm/solidAngles

        plt.step(edges[:-1], rate, where="post", zorder=layer)
        stepColor = plt.gca().lines[-1].get_color()
        plt.fill_between(edges[:-1],
                         rate - rateErr,
                         rate + rateErr,
                         interpolate=False,
                         step='post',
                         color=stepColor,
                         alpha=0.5,
                         zorder=layer)

    inputDFInERange = inputDF[(inputDF[" edep"] > 0.1) &
                              (inputDF[" edep"] < 15)]

    # ' eid' of the second-to-last row of the unfiltered table;
    # presumably the number of simulated primaries -- TODO confirm
    lastEid = inputDF[" eid"].tail(2).iloc[0]

    normConst = (fluxIntConst * np.pi * (9**2)) / ((4**2) * lastEid)
    binEdges = np.linspace(0, np.pi, nbins)

    # base curve is drawn on top (zorder 50) of the condition curves (30)
    drawThetaHist(thetaOfReverseDisplacement(inputDFInERange), binEdges, normConst, 50)

    for condition in conditionList:
        # evaluated in this scope; the string may refer to inputDFInERange
        inputDFSpec = inputDFInERange[eval(condition)]
        drawThetaHist(thetaOfPosition(inputDFSpec), binEdges, normConst, 30)

    plt.grid(True, "both")
    plt.ylabel("cts / $cm^2$ / $s$ / sr")
    plt.xlabel(r"$ \theta $")

    plt.xlim([0,np.pi])
    
def plotMomSolidAngleDiffThetaPosVals(inputDF, conditionList=[], nbins = 50, fluxType="GCR"):
    """Return the polar-angle histogram of the vector x - v per solid angle.

    Despite the "plot" prefix nothing is drawn. theta is
    ``arccos(dz / |d|)`` with ``d = (x - vx, y - vy, z - vz)``. Only rows
    with 0.1 < ' edep' < 15 are histogrammed. Counts are scaled by
    ``fluxIntConst * pi * 9**2 / (4**2 * lastEid)`` and divided by
    ``binWidth * sin(binCentre) * 2 * pi``.

    inputDF: DataFrame with ' edep', ' eid', ' vx', ' vy', ' vz', ' x',
        ' y' and ' z' columns.
    conditionList: accepted for signature compatibility; not used.
    nbins: number of bin *edges* between 0 and pi, i.e. nbins - 1 bins.
    fluxType: "GCR" (constant 5.1138) or "CXB" (constant 41.117).

    Returns ``(binCentres, values, errors)`` where ``errors`` is
    sqrt(counts) with the same scaling as ``values``.

    Raises ValueError for any other ``fluxType`` (previously an
    UnboundLocalError).
    """
    knownFluxIntConsts = {"GCR": 5.1138, "CXB": 41.117}
    if fluxType not in knownFluxIntConsts:
        raise ValueError("fluxType must be 'GCR' or 'CXB', got {!r}".format(fluxType))
    fluxIntConst = knownFluxIntConsts[fluxType]

    inputDFInERange = inputDF[(inputDF[" edep"] > 0.1) &
                              (inputDF[" edep"] < 15)]

    # components of the vector pointing from (vx, vy, vz) to (x, y, z)
    dx = inputDFInERange[" x"] - inputDFInERange[" vx"]
    dy = inputDFInERange[" y"] - inputDFInERange[" vy"]
    dz = inputDFInERange[" z"] - inputDFInERange[" vz"]
    thetaPosList = np.arccos(dz / np.sqrt((dx**2) + (dy**2) + (dz**2)))

    # ' eid' of the second-to-last row of the unfiltered table;
    # presumably the number of simulated primaries -- TODO confirm
    lastEid = inputDF[" eid"].tail(2).iloc[0]

    normConst = (fluxIntConst * np.pi * (9**2)) / ((4**2) * lastEid)

    (counts, edges) = np.histogram(thetaPosList, np.linspace(0, np.pi, nbins))
    edgeMids = (edges[1:] + edges[:-1])/2

    # solid angle covered by each theta bin (ring on the unit sphere)
    solidAngles = (edges[1:] - edges[:-1]) * np.sin(edgeMids) * 2 * np.pi

    countsNormalised = counts * normConst/solidAngles
    stdsNormalised = np.sqrt(counts) * normConst/solidAngles

    return (edgeMids, countsNormalised, stdsNormalised)
    
def plotSinNormedThetaPosHist(inputDF, conditionList=[], nbins = 50, fluxType="GCR"):
    """Plot the polar-angle histogram of (vx, vy, vz) divided by sin(2*theta).

    For the base curve theta is ``arccos(vz / |v|)``. Only rows with
    0.1 < ' edep' < 15 are histogrammed. Counts are scaled by
    ``fluxIntConst * pi * 9**2 / (4**2 * lastEid)``, divided by the bin
    width and then by ``sin(2 * binCentre)``; a band of +/- sqrt(counts)
    (scaled the same way) is shaded around each step curve.

    inputDF: DataFrame with ' edep', ' eid', ' vx', ' vy', ' vz' and, when
        conditions are given, ' x', ' y' and ' z' columns.
    conditionList: strings passed to ``eval``; each result is used as a
        row selector on ``inputDFInERange`` and drawn as an extra curve
        (behind the base curve). Executed as Python code -- pass trusted
        strings only.
    nbins: number of bin *edges* between 0 and pi, i.e. nbins - 1 bins.
    fluxType: "GCR" (constant 5.1138) or "CXB" (constant 41.117).

    Raises ValueError for any other ``fluxType`` (previously an
    UnboundLocalError).

    NOTE(review): the condition curves use the angle of the vector v - x,
    not the angle of (vx, vy, vz) used for the base curve; kept as found,
    confirm whether it is intended.
    NOTE(review): sin(2*theta) is zero at theta = pi/2 and negative above
    it, so values near and beyond pi/2 diverge or flip sign -- confirm the
    intended normalisation.
    """
    knownFluxIntConsts = {"GCR": 5.1138, "CXB": 41.117}
    if fluxType not in knownFluxIntConsts:
        raise ValueError("fluxType must be 'GCR' or 'CXB', got {!r}".format(fluxType))
    fluxIntConst = knownFluxIntConsts[fluxType]

    def thetaOfPosition(df):
        # polar angle of the (vx, vy, vz) vector itself
        return np.arccos(df[" vz"] /
                         np.sqrt((df[" vx"]**2) +
                                 (df[" vy"]**2) +
                                 (df[" vz"]**2)))

    def thetaOfDisplacement(df):
        # polar angle of the vector pointing from (x, y, z) to (vx, vy, vz)
        dx = df[" vx"] - df[" x"]
        dy = df[" vy"] - df[" y"]
        dz = df[" vz"] - df[" z"]
        return np.arccos(dz / np.sqrt((dx**2) + (dy**2) + (dz**2)))

    def drawThetaHist(thetaVals, binEdges, norm, layer):
        # histogram theta, apply the sin(2*theta) weighting, draw with band
        (counts, edges) = np.histogram(thetaVals, binEdges)
        edgeMids = (edges[1:] + edges[:-1])/2
        widths = edges[1:] - edges[:-1]
        rate = (counts * norm/widths)/np.sin(2 * edgeMids)
        rateErr = (np.sqrt(counts) * norm/widths)/np.sin(2 * edgeMids)

        plt.step(edges[:-1], rate, where="post", zorder=layer)
        stepColor = plt.gca().lines[-1].get_color()
        plt.fill_between(edges[:-1],
                         rate - rateErr,
                         rate + rateErr,
                         interpolate=False,
                         step='post',
                         color=stepColor,
                         alpha=0.5,
                         zorder=layer)

    inputDFInERange = inputDF[(inputDF[" edep"] > 0.1) & (inputDF[" edep"] < 15)]

    # ' eid' of the second-to-last row of the unfiltered table;
    # presumably the number of simulated primaries -- TODO confirm
    lastEid = inputDF[" eid"].tail(2).iloc[0]

    normConst = (fluxIntConst * np.pi * (9**2)) / ((4**2) * lastEid)
    binEdges = np.linspace(0, np.pi, nbins)

    # base curve is drawn on top (zorder 50) of the condition curves (30)
    drawThetaHist(thetaOfPosition(inputDFInERange), binEdges, normConst, 50)

    for condition in conditionList:
        # evaluated in this scope; the string may refer to inputDFInERange
        inputDFSpec = inputDFInERange[eval(condition)]
        drawThetaHist(thetaOfDisplacement(inputDFSpec), binEdges, normConst, 30)

    plt.grid(True, "both")
    plt.ylabel("cts / $cm^2$ / $s$ / radian")
    plt.xlabel(r"$ \theta $")

    plt.xlim([0,np.pi])
    
def getThetaPosHist(inputDF, nbins = 50, fluxType="GCR"):
    """Return the normalised polar-angle histogram of the vector v - x.

    theta is ``arccos(dz / |d|)`` with ``d = (vx - x, vy - y, vz - z)``.
    Only rows with 0.1 < ' edep' < 15 are histogrammed. Counts are scaled
    by ``fluxIntConst * pi * 9**2 / (4**2 * lastEid)`` and divided by the
    bin width.

    inputDF: DataFrame with ' edep', ' eid', ' vx', ' vy', ' vz', ' x',
        ' y' and ' z' columns.
    nbins: number of bin *edges* between 0 and pi, i.e. nbins - 1 bins.
    fluxType: "GCR" (constant 5.1138) or "CXB" (constant 41.117).

    Returns ``(binCentres, values, errors)`` where ``errors`` is
    sqrt(counts) with the same scaling as ``values``.

    Raises ValueError for any other ``fluxType`` (previously an
    UnboundLocalError).
    """
    knownFluxIntConsts = {"GCR": 5.1138, "CXB": 41.117}
    if fluxType not in knownFluxIntConsts:
        raise ValueError("fluxType must be 'GCR' or 'CXB', got {!r}".format(fluxType))
    fluxIntConst = knownFluxIntConsts[fluxType]

    inputDFInERange = inputDF[(inputDF[" edep"] > 0.1) & (inputDF[" edep"] < 15)]

    # components of the vector pointing from (x, y, z) to (vx, vy, vz)
    dx = inputDFInERange[" vx"] - inputDFInERange[" x"]
    dy = inputDFInERange[" vy"] - inputDFInERange[" y"]
    dz = inputDFInERange[" vz"] - inputDFInERange[" z"]
    thetaPosList = np.arccos(dz / np.sqrt((dx**2) + (dy**2) + (dz**2)))

    # ' eid' of the second-to-last row of the unfiltered table;
    # presumably the number of simulated primaries -- TODO confirm
    lastEid = inputDF[" eid"].tail(2).iloc[0]

    normConst = (fluxIntConst * np.pi * (9**2)) / ((4**2) * lastEid)

    (counts, edges) = np.histogram(thetaPosList, np.linspace(0, np.pi, nbins))
    widths = edges[1:] - edges[:-1]

    countsNormalised = counts * normConst/widths
    stdsNormalised = np.sqrt(counts) * normConst/widths
    edgeMids = (edges[1:] + edges[:-1])/2

    return (edgeMids, countsNormalised, stdsNormalised)


from scipy.optimize import curve_fit


# Lookup table from short process identifiers to human-readable names.
# presumably the keys are Geant4 creator-process names -- TODO confirm
cprocDict = {
    "hIoni": "hadron ionisation",
    "eBrem": "electron bremsstrahlung",
    "protonInelastic": "proton inelastic scattering",
    "eIoni": "electron Ionisation",
    "neutronInelastic": "neutron inelastic scattering",
    "annihil": "annihilation",
    "compt": "Compton scattering",
}


def expandCProcString(cprocString):
    """Translate every element of ``cprocString`` through cprocDict.

    cprocString: iterable of process identifiers. A plain string is
        iterated character by character, so pass a list (or similar) of
        names rather than a single name.

    Returns a list of the same length holding the readable name for each
    element, or None where the element is not a key of cprocDict.
    """
    return [cprocDict.get(procName) for procName in cprocString]

