#!/usr/bin/env python

####################################################################################
# Script written to produce convex mass hulls, as part of the LH2019 project       #
# 'Determination of Independent Signal Regions in LHC Searches for New Physics'    #
# by A. Buckley, B. Fuks, H. Reyes-González, W. Waltenberger and S. Williamson.    #
####################################################################################

# from __future__ import print_function

#~~~~~~~
# The script is split into 4 stand-alone sections (delimited by #-------%-------# markers); run only the ones needed, depending on which information has already been produced and which packages/files are installed.
#~~~~~~

import re
from numpy import *
import os
import pandas as pd
import urllib2

#~~~~~~~
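# If SModelS is not installed, try e.g. `pip install --user smodels` (using the pip that belongs to the
# Python 2 interpreter running this script, since it relies on urllib2), which fetches the official
# database from a CERN server automatically. If this does not work, install the SModelS database manually.
# Then add the following imports (Database is required by Section 1 below):
# from smodels.experiment.databaseObj import Database
# from smodels.tools.physicsUnits import GeV
# If the database is already installed, continue with: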
#~~~~~~~

#-------%-------# (Section 1): finds the analyses common to SModelS and MA5.

# NOTE: `Database` is not imported by this file; it must be imported from
# smodels.experiment.databaseObj before this point (see the notes above).
db = Database("./official122.pcl", force_load="pcl" )

# Bib keys matched below have the form "atlas_X_Y_Z," or "cms_X_Y_Z,".
# Compile the patterns once rather than for every bib fragment.
MA5_ATLAS_KEY = re.compile(r".*atlas_(\S+)_(\S+)_(\S+),.*")
MA5_CMS_KEY = re.compile(r".*cms_(\S+)_(\S+)_(\S+),.*")


def _ma5_key_to_analysis_id(match, experiment):
    """Build an SModelS-style analysis ID (e.g. 'ATLAS-SUSY-2015-06') from a bib-key match."""
    return "-".join([experiment, match.group(1).upper(), match.group(2), match.group(3)])


print('Beginning MA5 bib-file download with urllib2')
# NOTE(review): on Trac sites the raw file is usually served under
# /raw-attachment/...; /attachment/... may return an HTML viewer page -- confirm.
filedata = urllib2.urlopen('http://madanalysis.irmp.ucl.ac.be/attachment/wiki/MA5SandBox/bib_pad.dat')
try:
    datatowrite = filedata.read()
finally:
    filedata.close()

with open(os.path.join(os.getcwd(), 'AnalysisListMA5.txt'), 'w+') as f:
    f.write(datatowrite)

AnalysesInMA5 = []
print("Analyses in MA5: \n")
with open("AnalysisListMA5.txt", "r") as MA5Readin:
    for lines in MA5Readin:
        # Split on '@' so that each fragment holds at most one bib-entry header.
        for lcompel in lines.split("@"):
            # ATLAS is checked before CMS, matching the order IDs are collected in.
            for pattern, experiment in ((MA5_ATLAS_KEY, "ATLAS"), (MA5_CMS_KEY, "CMS")):
                match = pattern.search(lcompel)
                if match:
                    AnalysesInMA5.append(_ma5_key_to_analysis_id(match, experiment))
print("Analyses in MA5\n{0}\n".format(AnalysesInMA5))

resultsSModelS = db.getExpResults ( analysisIDs = ["all"])
ListofAnalysesSModelS = [result.globalInfo.id for result in resultsSModelS]
print("Analyses in SModelS\n{0}\n".format(ListofAnalysesSModelS))

# Sorted so the shared list (and the order it is printed/queried in) is reproducible.
SharedAnalyses = sorted(set(ListofAnalysesSModelS) & set(AnalysesInMA5))
print("Shared Analyses\n{0}".format(SharedAnalyses))

results = db.getExpResults ( analysisIDs = SharedAnalyses )

#-------%-------#

#-------%-------# (Section 2)

#~~~~~~~
# The analyses below are those common to SModelS and MA5 as of 10 November 2019. If neither MA5 nor SModelS
# has been updated since, Section 1 above can be commented out (but keep its `db = Database(...)` line,
# which this section also uses).
#~~~~~~~

results = db.getExpResults (analysisIDs=['ATLAS-SUSY-2015-06', 'CMS-SUS-16-039', 'CMS-SUS-16-033', 'CMS-SUS-17-001'])

# Mass grid scanned for every topology: 0, 5, ..., 2495.
m1range = range(0,2500,5)
m2range = range(0,2500,5)
m3range = range(0,2500,5)

#~~~~~~~
# Disclaimer: this section assumes no knowledge of the mass ranges probed by the analyses.
# It works as is, but could be optimised: (a) a coarser preliminary outline of the convex mass
# hull could be found before the precise hull (here using mass steps of 5 GeV), to avoid
# computing points that are not needed; (b) some simplified-model implementations in SModelS
# only analyse three-body decays assuming a third mass, m3, that depends on m1 and m2, and
# this relation could be used to limit the m3 masses probed.
#~~~~~~~


def _has_positive_efficiency(txname, massvec):
    """Return True if *txname* reports an efficiency > 0 at *massvec*.

    Ordinary errors raised by SModelS (e.g. a 2-mass vector given to a topology
    whose branches have three masses) count as "no efficiency". Only Exception
    subclasses are caught, so KeyboardInterrupt/SystemExit still abort the very
    long scan instead of being silently skipped point by point.
    """
    try:
        eff = txname.getEfficiencyFor(massvec)
        # eff = txname.getValueFor(massvec)
        return eff > 0.
    except Exception:
        return False


print("Finding convex mass hulls\n")
print("(This can take a very long time with large mass ranges (esp looking at 3-body decays))\n")
with open("Lists.txt", "w+") as Text_File:
    for result in results:
        for SR in result.datasets:
            for txname in SR.txnameList:
                SR_Topo = txname.txName
                # Topology header line; the mass points found for it follow.
                Text_File.write('{0}\n'.format(str(SR_Topo)))
                print(SR_Topo)
                # Two masses per branch, identical on both branches.
                for m1 in m1range:
                    for m2 in m2range:
                        if _has_positive_efficiency(txname, [[m1, m2], [m1, m2]]):
                            Text_File.write('{0} {1}\n'.format(m1, m2))
                # Three masses per branch, identical on both branches.
                for m1 in m1range:
                    for m2 in m2range:
                        for m3 in m3range:
                            if _has_positive_efficiency(txname, [[m1, m2, m3], [m1, m2, m3]]):
                                Text_File.write('{0} {1} {2}\n'.format(m1, m2, m3))

#~~~~~~~
#Lists.txt then contains the topology information from each individual analysis, a lot of which will be overlapping.
#~~~~~~~

#-------%-------#

#-------%-------# (Section 3) Separates up the topologies from Lists.txt into one txt file per topology.


# Lists.txt is only read (iterated line by line) in this section, so open it read-only.
Readfrom = open("Lists.txt", "r")
def file_len(fname):
    """Return the number of lines in the text file *fname*.

    An empty file yields 0. A final line without a trailing newline still counts.
    """
    count = 0
    with open(fname) as f:
        # Enumerate from 1 so that `count` ends up as the line count; it stays 0
        # for an empty file (the old version hit an unbound local there).
        for count, _ in enumerate(f, 1):
            pass
    return count
nLists = file_len("Lists.txt")
print(nLists)

cwd = os.getcwd()

if not os.path.exists('Topos'):
    os.mkdir('Topos')

#~~~~~~~
# Here, the results for the individual topologies are moved to separate text files
# (Topos/<topology>.txt).
#~~~~~~~

print("Separating up by topology (can also take a bit of time)\n")

filehandles = {}
Topo = None
try:
    for ctr, lines in enumerate(Readfrom):
        if ctr % 100 == 0:
            print("line {0} {1}".format(ctr, nLists))
        # Lines starting with 'T' are the topology headers; every other line is a
        # mass point belonging to the most recent header.
        if lines.startswith('T'):
            Topo = lines.rstrip('\n')
            continue
        if Topo is None:
            raise ValueError("Lists.txt: mass point before any topology header: %r" % lines)
        if Topo not in filehandles:
            # NOTE: append mode, so re-running accumulates into existing
            # Topos/*.txt files; remove Topos/ first for a fresh run.
            filehandles[Topo] = open(os.path.join("Topos", Topo + ".txt"), 'a')
        # Lines already end in '\n'; appending another one would insert a blank
        # line after every mass point.
        filehandles[Topo].write(lines if lines.endswith('\n') else lines + '\n')
finally:
    for handle in filehandles.values():
        handle.close()
    Readfrom.close()

#-------%-------#

#-------%-------# (Section 4) For each topology, finds the minimum and maximum m2 for every m1 (and, for topologies with three masses, the minimum and maximum m3 for every (m1, m2) pair), and writes them to MinMax_by_Topo.txt.

def _write_mass_ranges(out, df, group_cols, label):
    """Write the per-group minimum and maximum of *df* to the open file *out*.

    *df* is grouped by the columns *group_cols*; *label* names the remaining
    column in the "----Min/Max ...----" section headers.
    """
    groupmin = df.groupby(group_cols, as_index=False).min()
    groupmax = df.groupby(group_cols, as_index=False).max()
    out.write("\n----Min " + label + "----\n")
    out.write(str(groupmin))
    out.write("\n----Max " + label + "----\n")
    out.write(str(groupmax))
    out.write("\n")


# Print the whole tables instead of pandas' truncated summaries.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

print("Finding combined mass ranges\n")
# The with-block guarantees the output file is flushed and closed.
with open("MinMax_by_Topo.txt", "w+") as txt:
    # Sorted so the output file lists topologies in a reproducible order.
    for fname in sorted(os.listdir("Topos")):
        if not fname.endswith(".txt"):
            continue
        print(fname)
        txt.write("\n")
        txt.write(fname)
        txt.write("\n%%%%%%%%%%%\n ")
        try:
            df = pd.read_csv(os.path.join("Topos", fname), sep=" ", header=None)
        except (IOError, ValueError):
            # Unreadable or unparsable file, e.g. an empty topology file
            # (pandas' EmptyDataError is a ValueError subclass): skip it.
            continue
        if len(df.columns) == 2:
            df.columns = ['m1', 'm2']
            _write_mass_ranges(txt, df, ['m1'], 'm2')
        elif len(df.columns) == 3:
            df.columns = ['m1', 'm2', 'm3']
            _write_mass_ranges(txt, df, ['m1', 'm2'], 'm3')

# Remove intermediary files. Either may be absent (e.g. AnalysisListMA5.txt is never
# downloaded when Section 1 is commented out), so only delete the ones that exist.
for _tmpname in ('AnalysisListMA5.txt', 'Lists.txt'):
    _tmppath = os.path.join(os.getcwd(), _tmpname)
    if os.path.exists(_tmppath):
        os.remove(_tmppath)

#-------%-------#
