#!/usr/bin/env python
import argparse
import json


def main():
    '''
    Plot test results of the up(down)stream energy correction for FCC-ee LAr calorimeter.

    This script expects json file from "process_events" script. The file has to
    contain a "test_results" array whose elements provide the "name", "energy",
    "theta", "mean", "sigma" and "resolution" values.

    Results lying on the selected slice (fixed theta for the energy dependence,
    fixed energy for the theta dependence) are grouped by their name. For every
    group two graphs are built: mean / energy with sigma / energy as its error,
    and the resolution. The graphs are handed over to plot_mean and plot_res.

    The script exits with status 1 if the input file can not be used or no
    result lies on the selected slice, and with status 0 (after a warning) if
    any graph ends up with less than 3 points.
    '''
    parser = argparse.ArgumentParser(description='Yay, plot test results up(down)stream correction parameters!')
    parser.add_argument('--energy-slice', type=float, default=50.,
                        help='select energy slice for the correction derivation')
    parser.add_argument('--theta-slice', type=float, default=90.,
                        help='select theta slice for the correction derivation')

    required_arguments = parser.add_argument_group('required arguments')
    required_arguments.add_argument('-i', '--input-file', type=str, required=True, help='Input file path')
    required_arguments.add_argument('-d', '--dependence', type=str, required=True,
                                    choices=['energy', 'theta'],
                                    help='select variable on which the correction is dependent')

    args = parser.parse_args()

    from ROOT import gROOT
    gROOT.SetBatch(True)

    test_results = _load_test_results(args.input_file)

    # The variable which is NOT plotted on the x axis is fixed to its slice value.
    # argparse "choices" guarantees that dependence is either 'energy' or 'theta'.
    if args.dependence == 'energy':
        slice_key, slice_value = 'theta', args.theta_slice
        x_label = 'E_{cluster} [GeV]'
        shim_size = 1.
    else:
        slice_key, slice_value = 'energy', args.energy_slice
        x_label = '#theta [deg]'
        shim_size = 0.3

    # Both dictionaries are filled in a single pass, so they share keys and ordering.
    graph_dict_mean = {}
    graph_dict_res = {}
    for result in test_results:
        if result[slice_key] != slice_value:
            continue

        name = result['name']
        if name not in graph_dict_mean:
            graph_dict_mean[name] = {'x': [], 'y': [], 'y_err': [],
                                     'x_label': x_label,
                                     'y_label': 'E/E_{cluster}'}
            graph_dict_res[name] = {'x': [], 'y': [], 'y_err': [],
                                    'x_label': x_label,
                                    'y_label': '#sigma/E [%]'}

        graph_dict_mean[name]['x'].append(result[args.dependence])
        graph_dict_mean[name]['y'].append(result['mean'] / result['energy'])
        graph_dict_mean[name]['y_err'].append(result['sigma'] / result['energy'])

        graph_dict_res[name]['x'].append(result[args.dependence])
        graph_dict_res[name]['y'].append(result['resolution'])
        graph_dict_res[name]['y_err'].append(0)

    if not graph_dict_mean:
        # Without this check the plotting functions would be called with no graphs at all.
        print('ERROR: No test results found for the selected slice!')
        print('       %s = %s' % (slice_key, slice_value))
        raise SystemExit(1)

    graphs_energy = []
    # Every following graph is shifted by one more shim along x, starting at -shim_size,
    # so that points of different graphs with the same x do not overlap.
    shim = -shim_size
    for graph_name, graph_dict in graph_dict_mean.items():
        graph = _make_graph(graph_name, graph_dict, shim)
        shim += shim_size

        graph.Print()
        if graph.GetN() < 3:
            print('WARNING: Number of points in parameter graph is too small!')
            raise SystemExit(0)

        graphs_energy.append(graph)

    if args.dependence == 'energy':
        plot_mean(graphs_energy, 'graph_test_summary_mean_energy_%ideg' % (args.theta_slice),
                  ['#theta slice: %i deg' % (args.theta_slice)])
    else:
        plot_mean(graphs_energy, 'graph_test_summary_mean_theta_%igev' % (args.energy_slice),
                  ['E_{cluster} slice: %i GeV' % (args.energy_slice)])

    graphs_res = []
    for graph_name, graph_dict in graph_dict_res.items():
        graph = _make_graph(graph_name, graph_dict)

        if graph.GetN() < 3:
            print('WARNING: Number of points in parameter graph is too small!')
            raise SystemExit(0)

        graphs_res.append(graph)

    plot_res(graphs_res, 'graph_validate_resolution')


def _load_test_results(input_file_path):
    '''
    Return the non-empty "test_results" array stored in the json input file.

    Prints an error message and exits with status 1 if the file can not be
    opened, is not a valid json or does not contain the array.
    '''
    try:
        with open(input_file_path) as infile:
            data = json.load(infile)
    except IOError:
        print('ERROR: Input file not found!')
        print('       ' + input_file_path)
        raise SystemExit(1)
    except ValueError:
        # json reports malformed content with ValueError (or its subclass)
        print('ERROR: Input file is not a valid json file!')
        print('       ' + input_file_path)
        raise SystemExit(1)

    # The top level json value does not have to be an object
    test_results = data.get('test_results') if isinstance(data, dict) else None
    if not test_results:
        print('ERROR: Input file does not contain "test_results" array!')
        print('       ' + input_file_path)
        raise SystemExit(1)

    return test_results


def _make_graph(graph_name, graph_dict, x_shift=0.):
    '''
    Convert collected points into TGraphErrors named "graph_<graph_name>".

    graph_dict: dictionary with the "x", "y" and "y_err" lists of equal length
                and the "x_label" and "y_label" axis titles.
    x_shift:    constant added to the x coordinate of every point.
    '''
    from ROOT import TGraphErrors

    graph = TGraphErrors()
    graph.SetName('graph_' + graph_name)
    points = zip(graph_dict['x'], graph_dict['y'], graph_dict['y_err'])
    for i, (x, y, y_err) in enumerate(points):
        graph.SetPoint(i, x + x_shift, y)
        graph.SetPointError(i, 0., y_err)

    graph.GetXaxis().SetTitle(graph_dict['x_label'])
    graph.GetYaxis().SetTitle(graph_dict['y_label'])

    return graph


def plot_mean(graphs, plotname, plot_notes=()):
    '''
    Draw the given graphs on a common canvas and save it as "<plotname>.pdf".

    graphs:     sequence of graphs; the first one is drawn with axes, all the
                others are overlaid on top of it.
    plotname:   used in the canvas name and as the output file name stem.
    plot_notes: additional lines of text appended to the note box.

    Only graphs named "graph_calo_energy", "graph_calo_and_cryo_energy" or
    "graph_calo_and_corr_energy" get dedicated color, marker and legend entry.
    A dashed horizontal line is drawn at y = 1 over the x axis range of the
    first graph. Nothing is plotted if there are no graphs.
    '''
    if not graphs:
        print('WARNING: No graphs provided, plot will not be created!')
        print('         ' + plotname + '.pdf')
        return

    from ROOT import gPad, gStyle
    from ROOT import TCanvas, TPaveText, TLegend, TLine
    from ROOT import kBlue, kGreen, kRed

    # graph name -> (color, marker style, marker size, legend label)
    styles = {
        'graph_calo_energy': (kRed + 2, 20, .4, 'E_{calo}'),
        'graph_calo_and_cryo_energy': (kBlue + 2, 34, .5, 'E_{calo+cryo}'),
        'graph_calo_and_corr_energy': (kGreen + 3, 47, .5, 'E_{calo+corr}'),
    }

    canvas = TCanvas('canvas_' + plotname, 'Canvas', 350, 350)
    gPad.SetLeftMargin(.12)
    gPad.SetTopMargin(.05)
    gPad.SetRightMargin(.05)
    gStyle.SetOptStat(0)

    legend = TLegend(.6, .8, .95, .94)
    legend.SetBorderSize(0)
    legend.SetFillStyle(0)
    legend.SetFillColor(0)

    for graph in graphs:
        graph.SetLineWidth(2)
        graph.SetMarkerSize(.5)

        style = styles.get(graph.GetName())
        if style is not None:
            color, marker_style, marker_size, legend_label = style
            graph.SetLineColor(color)
            graph.SetMarkerColor(color)
            graph.SetMarkerStyle(marker_style)
            graph.SetMarkerSize(marker_size)
            legend.AddEntry(graph, legend_label, 'LEP')

        graph.GetXaxis().SetTitleOffset(1.3)
        graph.SetMaximum(1.12)
        graph.SetMinimum(0.89)

    note = TPaveText(.6, .65, .95, .78, "brNDC")
    note.SetFillStyle(0)
    note.SetFillColor(0)
    note.SetBorderSize(0)
    note.SetTextColor(1)
    note.SetTextFont(42)
    note.SetTextAlign(11)
    note.AddText('FCC-ee')
    note.AddText('LAr Calo (12 layers)')
    for note_text in plot_notes:
        note.AddText(note_text)

    # Draw whatever number of graphs was provided: the first one creates
    # the axes, the remaining ones are added into the same frame.
    for i, graph in enumerate(graphs):
        graph.Draw('APE' if i == 0 else 'PE SAME')
    legend.Draw()
    note.Draw()

    # Reference line for the ideal response
    line = TLine(graphs[0].GetXaxis().GetXmin(), 1.,
                 graphs[0].GetXaxis().GetXmax(), 1.)
    line.SetLineStyle(7)
    line.Draw()

    canvas.Print(plotname + '.pdf')


def plot_res(graphs, plotname, plot_notes=()):
    '''
    Fit and draw the given resolution graphs and save them as "<plotname>.pdf".

    graphs:     sequence of graphs; the first one is drawn with axes, all the
                others are overlaid on top of it.
    plotname:   used in the canvas name and as the output file name stem.
    plot_notes: additional lines of text appended to the note box.

    Every graph is fitted with sqrt(p0^2 + (p1 / sqrt(x))^2) in the x range
    spanned by its points and the fitted parameters are printed into the plot,
    one text box per graph. Only graphs named "graph_calo_energy",
    "graph_calo_and_cryo_energy" or "graph_calo_and_corr_energy" get dedicated
    color, marker and legend entry. Nothing is plotted if there are no graphs.
    '''
    if not graphs:
        print('WARNING: No graphs provided, plot will not be created!')
        print('         ' + plotname + '.pdf')
        return

    from ROOT import gPad, gStyle
    from ROOT import TCanvas, TPaveText, TLegend, TF1
    from ROOT import kBlue, kGreen, kRed

    # graph name -> (graph color, marker style, marker size, fit color, legend label)
    styles = {
        'graph_calo_energy': (kRed + 2, 20, .4, kRed + 2, 'E_{calo}'),
        'graph_calo_and_cryo_energy': (kBlue + 2, 34, .5, kBlue + 3, 'E_{calo+cryo}'),
        'graph_calo_and_corr_energy': (kGreen + 3, 47, .5, kGreen + 3, 'E_{calo+corr}'),
    }

    canvas = TCanvas('canvas_' + plotname, 'Canvas', 350, 350)
    gPad.SetLeftMargin(.12)
    gPad.SetTopMargin(.05)
    gPad.SetRightMargin(.05)
    gStyle.SetOptStat(0)

    legend = TLegend(.6, .8, .95, .94)
    legend.SetBorderSize(0)
    legend.SetFillStyle(0)
    legend.SetFillColor(0)

    for graph in graphs:
        graph.SetLineWidth(2)
        graph.SetMarkerSize(.5)

        style = styles.get(graph.GetName())
        if style is not None:
            color, marker_style, marker_size, _, legend_label = style
            graph.SetLineColor(color)
            graph.SetMarkerColor(color)
            graph.SetMarkerStyle(marker_style)
            graph.SetMarkerSize(marker_size)
            legend.AddEntry(graph, legend_label, 'LEP')

        graph.SetTitle(';E_{cluster} [GeV];#sigma / E [%]')
        graph.GetXaxis().SetTitleOffset(1.3)
        graph.SetMaximum(3.5)
        graph.SetMinimum(0.2)

    fit_notes = []
    # Vertical NDC extent of the fit parameter box, moved down for every next graph
    y_note1 = .82
    y_note2 = .93
    for graph in graphs:
        func = TF1('func_' + graph.GetName(),
                   'sqrt([0]*[0] + pow([1]/sqrt(x), 2))',
                   min(graph.GetX()),
                   max(graph.GetX()))
        # 'R' restricts the fit to the function range, 'S' returns the fit result
        result = graph.Fit(func, 'RS')

        fit_note = TPaveText(.3, y_note1, .56, y_note2, "brNDC")
        fit_note.SetFillStyle(0)
        fit_note.SetFillColor(0)
        fit_note.SetBorderSize(0)
        fit_note.SetTextFont(42)
        fit_note.SetTextAlign(11)
        fit_note.AddText('p_{0} = %.2f #pm  %.2f' % (result.Parameter(0),
                                                     result.ParError(0)))
        fit_note.AddText('p_{1} = %.2f #pm  %.2f' % (result.Parameter(1),
                                                     result.ParError(1)))
        fit_notes.append(fit_note)

        style = styles.get(graph.GetName())
        if style is not None:
            fit_color = style[3]
            # Style the function attached to the graph by the fit
            fit_func = graph.GetFunction('func_' + graph.GetName())
            fit_func.SetLineColor(fit_color)
            fit_func.SetLineStyle(7)
            fit_note.SetTextColor(fit_color)

        y_note1 -= 0.12
        y_note2 -= 0.12

    note = TPaveText(.6, .58, .95, .78, "brNDC")
    note.SetFillStyle(0)
    note.SetFillColor(0)
    note.SetBorderSize(0)
    note.SetTextColor(1)
    note.SetTextFont(42)
    note.SetTextAlign(11)
    note.AddText('FCC-ee')
    note.AddText('LAr Calo (12 layers)')
    for note_text in plot_notes:
        note.AddText(note_text)
    note.AddText('#sqrt{p_{0}^{2} + (p_{1} / #sqrt{E_{cluster}} )^{2}}')

    # Draw whatever number of graphs was provided: the first one creates
    # the axes, the remaining ones are added into the same frame.
    for i, graph in enumerate(graphs):
        graph.Draw('APE' if i == 0 else 'PE SAME')
    legend.Draw()
    note.Draw()
    for fit_note in fit_notes:
        fit_note.Draw()

    canvas.Print(plotname + '.pdf')


# Run the command line interface only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
