#!/usr/bin/env python
import os
import argparse
import math
import json


def main():
    '''
    Test up(down)stream energy calorimeter correction.

    This script expects a ROOT file containing TTree "events" with three branches:
     * std::vector<double> energyInLayer -- energy deposited in calorimeter layers
                                         -- layer_id 0 at position 0
                                         -- layer_id 1 at position 1
                                         -- etc.
     * std::vector<double> energyInCryo -- energy deposited in the whole cryostat and also in its parts
                                        -- energy deposited in whole cryostat at position 0
                                        -- energy deposited in front cryostat at position 1
                                        -- energy deposited in back cryostat at position 2
                                        -- energy deposited in cryostat sides at position 3
                                        -- energy deposited in front LAr bath at position 4
                                        -- energy deposited in back LAr bath at position 5
     * std::vector<double> particleVec -- four-momentum of initial particle
          -- mass at position 0
          -- px at position 1
          -- py at position 2
          -- pz at position 3
    and a JSON file with parameters of the calorimeter energy corrections.

    The script fills three energy histograms (calorimeter only, calorimeter + cryostat,
    calorimeter + correction), fits each of them, draws them into one plot and merges the
    fit results into "test_results.json" in the current working directory.

    Output of this script is input for the cec_test_summary script.

    The process exits with status 1 when an input file is missing, or when the "events"
    tree is missing or empty.
    '''
    parser = argparse.ArgumentParser(description='Yay, test up(down)stream energy corrections!')

    parser.add_argument('-n', '--note', action='append', default=[], type=str,
                        help='additional note to be displayed in the plots')
    parser.add_argument('-o', '--output-file-format', type=str, default='pdf', help='output file format')

    required_arguments = parser.add_argument_group('required arguments')
    required_arguments.add_argument('-i', '--input-file', type=str, required=True, help='input file path')
    required_arguments.add_argument('-c', '--corr-file', type=str, required=True,
                                    help='file with correction parameters')

    args = parser.parse_args()

    # SystemExit is raised directly: the exit() helper is only installed by the
    # "site" module and is not guaranteed to exist in every interpreter setup.
    if not os.path.isfile(args.input_file):
        print('WARNING: Input ROOT file not found!')
        print('         ' + args.input_file)
        raise SystemExit(1)

    if not os.path.isfile(args.corr_file):
        print('WARNING: Input JSON file not found!')
        print('         ' + args.corr_file)
        raise SystemExit(1)

    ec = EnergyCorr(args)

    from ROOT import gROOT, TFile, TH1D
    from ROOT import kRed, kGreen, kBlue
    gROOT.SetBatch(True)

    infile = TFile(args.input_file, 'READ')

    # Get() yields a null (falsy) object for a missing key, whereas attribute
    # access on the file raises before the error message can be printed.
    events = infile.Get('events')
    if not events:
        print('ERROR: TTree "events" not found in the provided ROOT file!')
        infile.Close()
        raise SystemExit(1)

    n_events = events.GetEntries()
    if n_events == 0:
        # The averages below would otherwise divide by zero.
        print('ERROR: TTree "events" in the provided ROOT file is empty!')
        infile.Close()
        raise SystemExit(1)

    # First pass: average momentum and polar angle of the initial particle,
    # rounded to integers; they define the histogram range and the plot name.
    particle_momentum = 0.
    particle_theta = 0.
    for event in events:
        particle_momentum += get_momentum(event)
        particle_theta += get_theta(event)
    particle_momentum = int(round(particle_momentum / n_events))
    particle_theta = int(round(particle_theta / n_events))
    print('INFO: Initial particle momentum: %i GeV' % (particle_momentum))
    print('                       theta:    %i deg' % (particle_theta))

    def book_energy_hist(name):
        # All three histograms share binning and range. SetDirectory(0)
        # detaches the histogram from the input file so that it is still
        # usable after the file has been closed.
        hist = TH1D(name, ';E [GeV];Number of events',
                    100, 0.8 * particle_momentum, 1.1 * particle_momentum)
        hist.Sumw2()
        hist.SetDirectory(0)
        return hist

    hist_calo_energy = book_energy_hist('hist_calo_energy')
    hist_calo_and_cryo_energy = book_energy_hist('hist_calo_and_cryo_energy')
    hist_calo_and_corr_energy = book_energy_hist('hist_calo_and_corr_energy')

    # Second pass: fill the histograms event by event.
    for event in events:
        cluster_theta = get_theta(event)
        energy_in_calo = sum(event.energyInLayer)

        energy_corr = ec.get_corr(event.energyInLayer[0], event.energyInLayer[-1], energy_in_calo, cluster_theta)

        hist_calo_energy.Fill(energy_in_calo)
        hist_calo_and_cryo_energy.Fill(energy_in_calo + event.energyInCryo[0])
        hist_calo_and_corr_energy.Fill(energy_in_calo + energy_corr)

    infile.Close()
    hist_calo_energy.BufferEmpty()
    hist_calo_and_cryo_energy.BufferEmpty()
    hist_calo_and_corr_energy.BufferEmpty()

    fit_results = {}
    fit_results['hist_calo_energy'] = fit(hist_calo_energy, 'E_{calo}', kRed-4)
    fit_results['hist_calo_and_cryo_energy'] = fit(hist_calo_and_cryo_energy, 'E_{calo+cryo}', kBlue-4)
    fit_results['hist_calo_and_corr_energy'] = fit(hist_calo_and_corr_energy, 'E_{calo+corr}', kGreen-3)

    plot_name = 'hist_energy_corr_test_%ideg_%iGeV' % (particle_theta, particle_momentum)
    plot_notes = []
    plot_notes.append('FCC-ee, LAr Calo (12 layers)')
    plot_notes.append('e^{-}, %i GeV, %i deg' % (particle_momentum, particle_theta))
    # Notes given on the command line with -n/--note are shown below the built-in ones.
    plot_notes.extend(args.note)
    # NOTE(review): --output-file-format is parsed but not forwarded here;
    # plot() is called with its default output format.
    plot(hist_calo_energy, hist_calo_and_cryo_energy, hist_calo_and_corr_energy, plot_name, plot_notes, fit_results)

    # Merge with results of previous runs; a missing results file simply
    # means there is nothing to merge yet.
    test_results = []
    try:
        with open('test_results.json') as results_file:
            data = json.load(results_file)
            if data.get('test_results'):
                test_results = data['test_results']
    except IOError:
        pass

    for name, result in fit_results.items():
        result['name'] = name.replace('hist_', '')
        result['energy'] = particle_momentum
        result['theta'] = particle_theta
        # The legend text is only needed for plotting, not in the stored results.
        result.pop('legend_text', None)

        # Drop an older entry for the same (name, energy, theta) so that
        # re-running the test replaces it instead of duplicating it.
        test_results = [t for t in test_results if not (t['name'] == result['name'] and
                                                        t['energy'] == result['energy'] and
                                                        t['theta'] == result['theta'])]

        test_results.append(result)

    with open('test_results.json', 'w') as outfile:
        json.dump({'test_results': test_results}, outfile, indent=4)


def get_momentum(event):
    '''
    Return the magnitude of the initial particle's three-momentum.

    The components are read from positions 1, 2 and 3 of event.particleVec
    (px, py, pz according to the input format described in main()).
    '''
    vec = event.particleVec
    sum_of_squares = math.pow(vec[1], 2) + math.pow(vec[2], 2) + math.pow(vec[3], 2)

    return math.sqrt(sum_of_squares)


def get_theta(event):
    '''
    Return the polar angle of the initial particle in degrees.

    The angle is measured between the momentum vector and the axis of the
    component at position 3 of event.particleVec; the transverse part is
    built from positions 1 and 2.
    '''
    vec = event.particleVec
    transverse = math.sqrt(math.pow(vec[1], 2) + math.pow(vec[2], 2))
    polar_rad = abs(math.atan2(transverse, vec[3]))

    # Convert from radians to degrees.
    return 180 * polar_rad / math.pi


def plot(hist1, hist2, hist3, plotname, plot_notes=None, result_dict=None, file_format='pdf'):
    '''
    Draw three histograms on one canvas and print the canvas to a file.

    Arguments:
     * hist1, hist2, hist3 -- histograms to draw; drawn in red, blue and green, in this order
     * plotname -- used for the canvas name and as the output file name without extension
     * plot_notes -- optional list of text lines placed in a text box on the plot
     * result_dict -- optional dictionary with the keys 'hist_calo_energy',
                      'hist_calo_and_cryo_energy' and 'hist_calo_and_corr_energy' (one per
                      histogram, in this order), each holding a dictionary with 'legend_text';
                      when an entry is missing, the histogram name is used in the legend
     * file_format -- extension of the output file, 'pdf' by default

    The output file "<plotname>.<file_format>" is written to the current working directory.
    The Y-axis maximum and the drawing attributes of the histograms are modified in place.
    '''
    from ROOT import gPad, gStyle
    from ROOT import TCanvas, TPaveText, TLegend
    from ROOT import kRed, kBlue, kGreen

    # None instead of mutable default arguments, which are shared between calls.
    if plot_notes is None:
        plot_notes = []
    if result_dict is None:
        result_dict = {}

    canvas = TCanvas('canvas_' + plotname, 'Canvas', 450, 450)
    gPad.SetLeftMargin(.13)
    gPad.SetTopMargin(.05)
    gPad.SetRightMargin(.05)

    gStyle.SetOptStat(0)

    for hist, color in ((hist1, kRed+2), (hist2, kBlue+2), (hist3, kGreen+3)):
        hist.SetLineColor(color)
        hist.SetMarkerColor(color)
        hist.SetLineWidth(2)

    note = TPaveText(.13, .55, .6, .7, "brNDC")
    note.SetFillStyle(0)
    note.SetFillColor(0)
    note.SetBorderSize(0)
    note.SetTextColor(1)
    note.SetTextFont(42)
    note.SetTextAlign(11)
    for note_text in plot_notes:
        note.AddText(note_text)

    legend = TLegend(.13, .7, .92, .94)
    legend.SetBorderSize(0)
    legend.SetFillStyle(0)
    legend.SetFillColor(0)
    legend_entries = ((hist1, 'hist_calo_energy'),
                      (hist2, 'hist_calo_and_cryo_energy'),
                      (hist3, 'hist_calo_and_corr_energy'))
    for hist, key in legend_entries:
        # Fall back to the histogram name when no legend text is provided.
        legend_text = result_dict.get(key, {}).get('legend_text', hist.GetName())
        legend.AddEntry(hist, legend_text, 'LEP')

    # Common Y range with headroom, so that legend and notes do not overlap the histograms.
    set_maximum(hist1, hist2, hist3)

    hist1.Draw('')
    hist2.Draw('SAME')
    hist3.Draw('SAME')
    note.Draw()
    legend.Draw()
    # Tolerate a format given with a leading dot, e.g. '.png'.
    canvas.Print(plotname + '.' + file_format.lstrip('.'))


def count_bins_with_error(hist):
    '''
    Count the bins of a histogram that have a non-zero bin error.

    Only bins between the first and the last bin reported by the X axis
    (both included) are inspected.

    Arguments:
     * hist -- histogram providing GetXaxis() and GetBinError(bin)

    Returns the number of inspected bins whose error differs from zero.
    '''
    # Both range limits come from the X axis; taking the lower limit from the
    # Y axis would ignore a range set on the X axis.
    x_axis = hist.GetXaxis()
    bins = range(x_axis.GetFirst(), x_axis.GetLast() + 1)

    return sum(1 for i in bins if hist.GetBinError(i) != 0.)


def set_maximum(hist1, hist2, hist3):
    '''
    Give all three histograms the same Y-axis maximum.

    The common maximum is 1.7 times the largest of the three individual
    histogram maxima. The histograms are modified in place.
    '''
    hists = (hist1, hist2, hist3)
    common_max = 1.7 * max(hist.GetMaximum() for hist in hists)

    for hist in hists:
        hist.SetMaximum(common_max)


class EnergyCorr:
    '''
    Energy correction built from parametrised functions stored in a JSON file.

    The JSON file must contain a "corr_params" array. Each of its elements is read
    for the following keys:
     * mother_param -- name of the correction function the parameter belongs to;
                       the digits in the name give the power of the layer energy
                       the function is multiplied with
     * func -- formula of the correction function (taken from the first element
               seen for each mother_param)
     * index -- index of the parameter within the function
     * name -- name of the parameter
     * value -- value of the parameter

    Attributes:
     * functions -- list of (power, function) pairs
     * is_2d -- True if at least one formula contains "y", i.e. is built as a
                two-dimensional function
    '''
    def __init__(self, args):
        '''
        Load the correction parameters from args.corr_file and build the functions.

        Exits with status 1 if the file can not be read or has no "corr_params" array.
        '''
        from ROOT import TF1, TF2

        try:
            with open(args.corr_file) as infile:
                data = json.load(infile)
                if data.get('corr_params'):
                    corr_params = data['corr_params']
                else:
                    print('ERROR: Input file does not contain "corr_params" array!')
                    print('       ' + args.corr_file)
                    raise SystemExit(1)
        except IOError:
            print('ERROR: Input JSON file not found!')
            print('       ' + args.corr_file)
            raise SystemExit(1)

        # One formula per mother parameter; the first occurrence wins.
        func_defs = {}
        for param in corr_params:
            if param['mother_param'] not in func_defs:
                func_defs[param['mother_param']] = param['func']

        self.is_2d = False
        functions = []
        for name, func_def in func_defs.items():
            # A formula mentioning "y" is treated as a function of two variables.
            if 'y' in func_def:
                functions.append(TF2('func_' + name, func_def, 0., 500., 0., 95.))
                self.is_2d = True
            else:
                functions.append(TF1('func_' + name, func_def, 0., 500.))

        func_index = []
        for func in functions:
            mother_param = func.GetName().replace('func_', '')
            for i in range(func.GetNpar()):
                matches = [p for p in corr_params if p['mother_param'] == mother_param and p['index'] == i]
                if not matches:
                    raise IndexError('No parameter with index %i found for "%s"!' % (i, mother_param))
                param_dict = matches[0]
                func.SetParName(i, param_dict['name'])
                func.SetParameter(i, param_dict['value'])

            func_index.append(self._order_from_name(mother_param))

        # Stored as a list: the pairs are iterated once per event in get_corr(),
        # a bare zip object would be exhausted after the first event on Python 3.
        self.functions = list(zip(func_index, functions))

    @staticmethod
    def _order_from_name(mother_param):
        '''Return the integer formed by all digits in the mother parameter name.'''
        return int(''.join(char for char in mother_param if char.isdigit()))

    def get_corr(self, energy_in_first_layer, energy_in_last_layer, cluster_energy, cluster_theta=None):
        '''
        Return the energy correction for one cluster.

        Every function with "upsilon" in its name contributes its value multiplied
        by a power of the energy in the first layer, every function with "delta" in
        its name contributes its value multiplied by a power of the energy in the
        last layer.

        Arguments:
         * energy_in_first_layer -- energy deposited in the first calorimeter layer
         * energy_in_last_layer -- energy deposited in the last calorimeter layer
         * cluster_energy -- first argument the functions are evaluated at
         * cluster_theta -- second argument the functions are evaluated at;
                            may be omitted only if no function is two-dimensional

        Raises TypeError if cluster_theta is missing while a two-dimensional
        function is in use.
        '''
        if cluster_theta is None:
            if self.is_2d:
                raise TypeError('cluster_theta is required by two-dimensional correction functions')
            eval_args = (cluster_energy,)
        else:
            eval_args = (cluster_energy, cluster_theta)

        corr = 0
        for func_index, func in self.functions:
            func_name = func.GetName()
            if 'upsilon' in func_name:
                corr += func.Eval(*eval_args) * pow(energy_in_first_layer, func_index)
            if 'delta' in func_name:
                corr += func.Eval(*eval_args) * pow(energy_in_last_layer, func_index)

        return corr


def fit(hist, text, color):
    '''
    Fit a Gaussian to the histogram over its whole X-axis range.

    Arguments:
     * hist -- histogram to fit; the fitted function stays attached to it
     * text -- label put in front of the fit results in the legend text
     * color -- line color of the fitted function

    Returns a dictionary with:
     * mean -- fit parameter 1
     * sigma -- fit parameter 2
     * resolution -- 100 * sigma / mean
     * legend_text -- text with the label, mean, sigma and resolution
    '''
    from ROOT import TF1

    x_min = hist.GetBinLowEdge(1)
    # The low edge of the bin following the last one is the upper edge of the
    # last bin; the low edge of the last bin itself would drop it from the fit.
    x_max = hist.GetBinLowEdge(hist.GetNbinsX() + 1)
    func = TF1('func', 'gaus', x_min, x_max)
    # 'S' returns the fit result, 'R' restricts the fit to the function range.
    result = hist.Fit(func, 'SR')

    result_dict = {}
    result_dict['mean'] = result.Parameter(1)
    result_dict['sigma'] = result.Parameter(2)
    result_dict['resolution'] = 100 * result.Parameter(2) / result.Parameter(1)
    result_dict['legend_text'] = text
    result_dict['legend_text'] += ' = %.2f #pm %.2f GeV,   #frac{#sigma}{E} = %.2f %%' % (result_dict['mean'],
                                                                                          result_dict['sigma'],
                                                                                          result_dict['resolution'])

    # Style the function attached to the histogram, as that is the one drawn with it.
    fit_func = hist.GetFunction('func')
    fit_func.SetLineColor(color)

    return result_dict


# Run the test only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
