#!/usr/bin/env python
import argparse
import json
import re
import sys


def _load_results(input_file):
    '''
    Read the "results" array from the json input file.

    Prints an error message and terminates the script with exit status 1 when the file cannot be opened, cannot be
    parsed as json, or does not hold a non-empty "results" entry.
    '''
    try:
        with open(input_file) as results_file:
            data = json.load(results_file)
    except IOError:
        print('ERROR: Input file not found!')
        print('       ' + input_file)
        sys.exit(1)
    except ValueError:
        # json decoding errors are ValueError (sub)classes in both Python 2 and 3
        print('ERROR: Input file is not a valid json file!')
        print('       ' + input_file)
        sys.exit(1)

    results = data.get('results') if isinstance(data, dict) else None
    if not results:
        print('ERROR: Input file does not contain "results" array!')
        print('       ' + input_file)
        sys.exit(1)

    return results


def _collect_graphs(results, corr_type):
    '''
    Group fit results of the selected correction type by parameter name.

    Returns a list of (parameter name, graph dictionary) pairs ordered by the parameter index. Each graph dictionary
    holds the point coordinates ('x': energy, 'y': theta, 'z': value, 'z_err': error), the axis labels and the
    parameter index.
    '''
    graphs = {}
    for result in results:
        if result['type'] != corr_type:
            continue
        name = result['name']
        if name not in graphs:
            graphs[name] = {
                'x': [],
                'y': [],
                'z': [],
                'z_err': [],
                'x_label': 'E_{cluster} [GeV]',
                'y_label': '#theta [deg]',
                'z_label': name,
                'index': result['index'],
            }
        graph = graphs[name]
        graph['x'].append(result['energy'])
        graph['y'].append(result['theta'])
        graph['z'].append(result['value'])
        graph['z_err'].append(result['error'])

    return sorted(graphs.items(), key=lambda item: item[1]['index'])


def main():
    '''
    Derive up(down)stream calorimeter energy correction dependent on cluster energy and theta.

    This script expects json file with fit results from "cec_process_events" script. Input file needs to contain fit
    results in both variables.

    For every parameter of the selected correction type a 2D graph (energy, theta, value) is built, fitted with the
    function given on the command line at the position of the parameter index, and plotted. The fitted parameters
    are merged into "corr_params_2d.json" in the current directory. Finally the chi2/NDF of the input fits is
    plotted for the entries with index 0.
    '''
    parser = argparse.ArgumentParser(description='Yay, derive up(down)stream correction parameters!')

    parser.add_argument('-f', '--functions', type=str, default=[], nargs='*',
                        help='fit functions (in ROOT notation)')
    parser.add_argument('-n', '--notes', default=[], type=str, nargs='*',
                        help='additional note to be displayed in the plots')

    required_arguments = parser.add_argument_group('required arguments')
    required_arguments.add_argument('-i', '--input-file', type=str, required=True, help='Input file path')
    required_arguments.add_argument('-t', '--corr-type', type=str, required=True,
                                    choices=['upstream', 'downstream'], help='select correction type')

    args = parser.parse_args()

    print('*********************************************************')
    print('*  Calorimeter Energy Correction: Derive 2D correction  *')
    print('*********************************************************')

    if not args.notes:
        args.notes = ['FCC-ee, LAr Calo', 'electron']
        print('WARNING: Using default plot legend notes.')

    # Validate the input before importing ROOT, so that a bad input file is reported immediately.
    results = _load_results(args.input_file)

    from ROOT import gROOT
    gROOT.SetBatch(True)
    from ROOT import TGraph2D, TGraph2DErrors, TF2

    graphs = _collect_graphs(results, args.corr_type)

    # Fit parameters are named with consecutive letters, continuing across all fitted graphs.
    param_counter = ord('a')
    for graph_name, graph_dict in graphs:
        if len(graph_dict['x']) < 2 or len(graph_dict['y']) < 2 or len(graph_dict['z']) < 2:
            print('WARNING: Malformed graph dictionary, skipping it.')
            continue

        # Order the points by (x, y, z, z_err) while keeping the coordinates of each point together.
        points = sorted(zip(graph_dict['x'], graph_dict['y'], graph_dict['z'], graph_dict['z_err']))
        graph_dict['x'], graph_dict['y'], graph_dict['z'], graph_dict['z_err'] = zip(*points)

        graph = TGraph2DErrors()
        graph.SetName('graph_' + args.corr_type + '_' +
                      graph_name.replace('#', '').replace('{', '').replace('}', ''))
        for x, y, z, z_err in points:
            graph.SetPoint(graph.GetN(), x, y, z)
            graph.SetPointError(graph.GetN() - 1, 0., 0., z_err)

        graph.SetTitle('')
        graph.GetHistogram().GetXaxis().SetTitle(graph_dict['x_label'])
        graph.GetHistogram().GetYaxis().SetTitle(graph_dict['y_label'])
        graph.GetHistogram().GetZaxis().SetTitle(graph_dict['z_label'])

        if graph.GetN() < 6:
            # NOTE(review): this ends the whole script (with success status), not only the current graph.
            print('WARNING: Number of points in parameter graph is too small!')
            sys.exit(0)

        try:
            formula = args.functions[graph_dict['index']]
        except IndexError:
            print('WARNING: Fitting function for parameter "' + graph_name + '" not provided.')
            print('         Skipping fit!')
        else:
            x_min, x_max = min(graph_dict['x']), max(graph_dict['x'])
            y_min, y_max = min(graph_dict['y']), max(graph_dict['y'])
            # Extend the function range by 10 % of the covered interval on each side.
            overhang_x = 0.1 * abs(x_max - x_min)
            overhang_y = 0.1 * abs(y_max - y_min)
            func = TF2("func", formula,
                       x_min - overhang_x, x_max + overhang_x,
                       y_min - overhang_y, y_max + overhang_y)

            for i in range(func.GetNpar()):
                func.SetParName(i, chr(param_counter))
                param_counter += 1

            fit_result = graph.Fit(func, "SR")

            # Start from the parameters already stored on disk, if there are any.
            corr_params = []
            try:
                with open('corr_params_2d.json') as infile:
                    stored = json.load(infile)
                    if stored.get('corr_params'):
                        corr_params = stored['corr_params']
            except IOError:
                pass

            new_params = []
            for i in range(func.GetNpar()):
                new_params.append({
                    'name': func.GetParName(i),
                    'value': fit_result.Get().Parameter(i),
                    'error': fit_result.Get().Error(i),
                    'func': formula,
                    'mother_param': graph_name,
                    'type': args.corr_type,
                    'index': i,
                })

            # Replace previously stored parameters of the same name and correction type.
            new_names = set(param['name'] for param in new_params)
            corr_params = [p for p in corr_params
                           if not (p['name'] in new_names and p['type'] == args.corr_type)]
            corr_params.extend(new_params)

            with open('corr_params_2d.json', 'w') as outfile:
                json.dump({'corr_params': corr_params}, outfile, indent=4)

        plot(graph, graph.GetName(), args)

    dict_chi2 = {'x': [], 'y': [], 'z': []}
    for result in results:
        if result['type'] != args.corr_type or result['index'] != 0:
            continue
        if not result['ndf']:
            # chi2 / NDF is undefined without degrees of freedom
            print('WARNING: Fit result with zero NDF, leaving it out of the chi2 plot.')
            continue
        dict_chi2['x'].append(result['energy'])
        dict_chi2['y'].append(result['theta'])
        # float() avoids integer division under Python 2
        dict_chi2['z'].append(float(result['chi2']) / result['ndf'])

    if not dict_chi2['x']:
        return

    graph_chi2 = TGraph2D()
    graph_chi2.SetName('graph_' + args.corr_type + '_chi2')
    for x, y, z in zip(dict_chi2['x'], dict_chi2['y'], dict_chi2['z']):
        graph_chi2.SetPoint(graph_chi2.GetN(), x, y, z)

    graph_chi2.SetTitle('')
    graph_chi2.GetHistogram().GetXaxis().SetTitle('E_{cluster} [GeV]')
    graph_chi2.GetHistogram().GetYaxis().SetTitle('#theta [deg]')
    graph_chi2.GetHistogram().GetZaxis().SetTitle('#chi^{2} / NDF')
    plot(graph_chi2, graph_chi2.GetName(), args)


def _strip_enclosing_parentheses(formula):
    '''Remove one pair of parentheses, but only when that pair encloses the whole formula.'''
    if len(formula) < 2 or formula[0] != '(' or formula[-1] != ')':
        return formula

    depth = 0
    last = len(formula) - 1
    for position, char in enumerate(formula):
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0 and position != last:
                # The opening parenthesis is closed before the end, e.g. "(a)*(b)".
                return formula

    return formula[1:-1] if depth == 0 else formula


def _formula_to_label(formula):
    '''
    Convert a ROOT formula string into the text shown in the plot legend.

    The variables x and y are replaced by their axis symbols, the square brackets around parameter names are
    dropped and a pair of parentheses enclosing the whole formula is removed.
    '''
    # Only stand-alone variables are replaced: letters inside identifiers (e.g. "exp") and bracketed parameter
    # names (e.g. "[x]") are left untouched.
    formula = re.sub(r'(?<![\w\[])x(?![\w\]])', 'E_{cluster}', formula)
    formula = re.sub(r'(?<![\w\[])y(?![\w\]])', '#theta', formula)
    formula = formula.replace('[', '').replace(']', '')
    return _strip_enclosing_parentheses(formula)


def plot(obj, plotname, args=None):
    '''
    Draw a 2D graph and save it as "<plotname>.pdf".

    If the graph has a function in its list of functions, the last one is treated as the fit result: its parameters
    and formula are added to the legend and two more files are written, one with the function itself and one with
    the relative residuals in percent. Their names are derived from plotname by replacing "graph" with "func" and
    then "func" with "residuals".

    obj      -- graph to draw; the ROOT 2D graphs created in main() are what is passed in this file
    plotname -- output file name without extension
    args     -- object with a "notes" list of legend lines; no notes are drawn when it is None
    '''
    from ROOT import gPad, gStyle
    from ROOT import TCanvas, TPaveText
    from ROOT import kBlue

    notes = getattr(args, 'notes', None) or []

    func = obj.GetListOfFunctions().Last()
    if func:
        # Wide canvas: the right half is reserved for the legend with fit parameters.
        canvas = TCanvas('canvas_' + plotname, 'Canvas', 700, 350)
        gPad.SetLeftMargin(.05)
        gPad.SetTopMargin(.05)
        gPad.SetRightMargin(.5)
        obj.GetXaxis().SetTitleOffset(1.4)
        obj.GetYaxis().SetTitleOffset(.7)
        obj.GetZaxis().SetTitleOffset(1.)

        legend = TPaveText(.63, .1, 0.98, .95, 'brNDC')
        legend.SetTextColor(1)

    else:
        canvas = TCanvas('canvas_' + plotname, 'Canvas', 350, 350)
        gPad.SetLeftMargin(.1)
        gPad.SetTopMargin(.05)
        gPad.SetRightMargin(.15)
        obj.GetXaxis().SetTitleOffset(1.4)
        obj.GetYaxis().SetTitleOffset(1.4)
        obj.GetZaxis().SetTitleOffset(1.4)

        gStyle.SetOptStat(11)
        gStyle.SetOptFit(1111)

        obj.SetLineColor(kBlue)
        obj.SetLineWidth(2)
        obj.SetMarkerColor(kBlue)
        obj.SetMarkerStyle(20)
        obj.SetMarkerSize(.5)

        legend = TPaveText(.2, .7, .5, .9, 'brNDC')
        legend.SetTextColor(0)

    legend.SetFillStyle(0)
    legend.SetFillColor(0)
    legend.SetBorderSize(0)
    legend.SetTextFont(42)
    legend.SetTextAlign(11)

    for note_text in notes:
        legend.AddText(note_text)

    if func:
        for i in range(func.GetNpar()):
            note_text = '%s = %.2g #pm %.2g' % (func.GetParName(i), func.GetParameter(i), func.GetParError(i))
            legend.AddText(note_text)
        formula = _formula_to_label(func.GetFormula().GetExpFormula().Data())
        legend.AddText(formula)
        legend.GetListOfLines().Last().SetTextColor(func.GetLineColor())

    obj.Draw('COLZ')
    legend.Draw('same')
    canvas.Print(plotname + '.pdf')

    if not func:
        return

    # Second plot: the fitted function alone.
    func.SetTitle('')
    func.GetHistogram().GetXaxis().SetTitle('E_{cluster} [GeV]')
    func.GetHistogram().GetYaxis().SetTitle('#theta [deg]')
    func.GetHistogram().GetZaxis().SetTitle(obj.GetHistogram().GetZaxis().GetTitle())
    func.GetHistogram().GetXaxis().SetTitleOffset(1.4)
    func.GetHistogram().GetYaxis().SetTitleOffset(0.7)
    func.Draw('COLZ')
    legend.Draw('same')
    plotname = plotname.replace('graph', 'func')
    canvas.Print(plotname + '.pdf')

    # Third plot: relative difference between the graph points and the function.
    from ROOT import TGraph2D
    import ctypes

    res = TGraph2D()
    for i in range(obj.GetN()):
        x, y, graph_val = ctypes.c_double(0.), ctypes.c_double(0.), ctypes.c_double(0.)
        obj.GetPoint(i, x, y, graph_val)
        x = x.value
        y = y.value
        graph_val = graph_val.value
        if graph_val == 0.:
            # The relative residual is undefined for a vanishing graph value.
            continue
        func_val = func.Eval(x, y)
        res.SetPoint(res.GetN(), x, y, 100 * (graph_val - func_val) / graph_val)

    if res.GetN() == 0:
        print('WARNING: No points with non-zero value, skipping residuals plot.')
        return

    res.SetTitle('')
    res.GetHistogram().GetXaxis().SetTitle('E_{cluster} [GeV]')
    res.GetHistogram().GetYaxis().SetTitle('#theta [deg]')
    res.GetHistogram().GetZaxis().SetTitle('Residuals [%]')
    res.GetHistogram().GetXaxis().SetTitleOffset(1.4)
    res.GetHistogram().GetYaxis().SetTitleOffset(0.7)

    canvas.Update()
    res.Draw('COLZ')
    canvas.Update()
    legend.Draw('same')
    plotname = plotname.replace('func', 'residuals')
    canvas.Print(plotname + '.pdf')


# Run the derivation only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
