#!/usr/bin/env python
import argparse
import json
import os


def main():
    '''
    Derive up(down)stream calorimeter energy correction dependent on the cluster energy.

    This script expects json files with the fit results from "cec_process_events" script.

    Workflow:
      1. Parse the command line and load the "results" array from the input JSON file.
      2. Select the entries matching the requested correction type and theta slice and
         group them by parameter name into one graph (value vs. energy) per parameter.
      3. Fit every graph with the user supplied function of the same index (if one was
         provided), merge the fitted parameters into "corr_params_1d.json" inside the
         plot directory and plot the graph.
      4. Plot chi2/NDF versus energy, taken from the selected entries with index 0.

    The process exits with status 1 if the input file cannot be read, is not valid JSON
    or has no "results" entry, and with status 0 as soon as a parameter graph with fewer
    than three points is encountered.
    '''
    parser = argparse.ArgumentParser(description='Yay, derive up(down)stream correction parameters!')

    parser.add_argument('-f', '--functions', type=str, default=[], nargs='*',
                        help='fit functions (in ROOT notation)')
    parser.add_argument('-n', '--notes', default=[], type=str, nargs='*',
                        help='additional note to be displayed in the plots')
    parser.add_argument('--theta-slice', type=float, default=90.,
                        help='select theta slice for the correction derivation')
    parser.add_argument('--plot-file-format', type=str, default='pdf', help='plot output file format')
    parser.add_argument('--plot-directory', type=str, default='./', help='Directory of the output plots')

    required_arguments = parser.add_argument_group('required arguments')
    required_arguments.add_argument('-i', '--input-file', type=str, required=True, help='Input file path')
    required_arguments.add_argument('-t', '--corr-type', type=str, required=True,
                                    choices=['upstream', 'downstream'], help='select correction type')

    args = parser.parse_args()

    print('*********************************************************')
    print('*  Calorimeter Energy Correction: Derive 1D correction  *')
    print('*********************************************************')

    if not args.notes:
        args.notes.append('FCC-ee, LAr Calo')
        args.notes.append('electron')
        args.notes.append('#theta = %i deg' % (args.theta_slice))
        print('WARNING: Using default plot legend notes.')

    # makedirs (unlike mkdir) also creates missing parent directories, so a nested
    # output path such as "plots/upstream" works on the first run.
    if not os.path.isdir(args.plot_directory):
        os.makedirs(args.plot_directory)

    # ROOT is imported only after the command line has been parsed.
    from ROOT import gROOT
    gROOT.SetBatch(True)

    print('INFO: Loading input file from:')
    print('      ' + args.input_file)

    # SystemExit is raised directly instead of calling the "exit" helper, which is
    # injected by the "site" module and is not guaranteed to exist.
    try:
        with open(args.input_file) as results_file:
            data = json.load(results_file)
    except IOError:
        print('ERROR: Input file not found!')
        print('       ' + args.input_file)
        raise SystemExit(1)
    except ValueError:
        # json.load reports malformed content with ValueError (JSONDecodeError on Python 3).
        print('ERROR: Input file is not a valid JSON file!')
        print('       ' + args.input_file)
        raise SystemExit(1)

    results = data.get('results')
    if not results:
        print('ERROR: Input file does not contain "results" array!')
        print('       ' + args.input_file)
        raise SystemExit(1)

    # Group the selected fit results by parameter name: one graph per parameter.
    graphs = {}
    for result in results:
        if result['type'] != args.corr_type:
            continue
        if args.theta_slice != result['theta']:
            continue
        name = result['name']
        if name not in graphs:
            graphs[name] = {
                'x': [],
                'y': [],
                'y_err': [],
                'x_label': 'E_{cluster} [GeV]',
                'y_label': name,
                'index': result['index'],
            }
        graphs[name]['x'].append(result['energy'])
        graphs[name]['y'].append(result['value'])
        graphs[name]['y_err'].append(result['error'])

    from ROOT import TGraphErrors, TGraph, TF1

    # Process the parameters in the order given by their index.
    graphs = sorted(graphs.items(), key=lambda graph: graph[1]['index'])

    corr_params_path = os.path.join(args.plot_directory, 'corr_params_1d.json')

    # Fit parameters are named with consecutive letters starting at chr(97) == 'a';
    # the counter runs across all graphs so that names are not reused.
    param_counter = 97
    for graph_name, graph_dict in graphs:
        # Sort the points by energy while keeping value and error attached to it.
        graph_dict['x'], graph_dict['y'], graph_dict['y_err'] = zip(*sorted(zip(graph_dict['x'],
                                                                                graph_dict['y'],
                                                                                graph_dict['y_err'])))

        graph = TGraphErrors()
        graph.SetName('graph_' + args.corr_type + '_energy_' +
                      graph_name.replace('#', '').replace('{', '').replace('}', ''))
        for x, y, y_err in zip(graph_dict['x'], graph_dict['y'], graph_dict['y_err']):
            graph.SetPoint(graph.GetN(), x, y)
            graph.SetPointError(graph.GetN() - 1, 0., y_err)

        graph.GetXaxis().SetTitle(graph_dict['x_label'])
        graph.GetYaxis().SetTitle(graph_dict['y_label'])

        if graph.GetN() < 3:
            print('WARNING: Number of points in parameter graph is too small!')
            raise SystemExit(0)

        # Only the lookup of the fit function can legitimately raise IndexError,
        # so nothing else is placed inside the try body.
        try:
            fit_formula = args.functions[graph_dict['index']]
        except IndexError:
            print('WARNING: Fitting function for parameter "' + graph_name + '" not provided.')
            print('         Skipping fit!')
        else:
            # Extend the function range by 10 % of the covered energy range on both sides.
            x_min = min(graph_dict['x'])
            x_max = max(graph_dict['x'])
            overhang = 0.1 * abs(x_max - x_min)
            func = TF1("func", fit_formula, x_min - overhang, x_max + overhang)
            for i in range(func.GetNpar()):
                func.SetParName(i, chr(param_counter))
                param_counter += 1

            fit_result = graph.Fit(func, "SR")

            # Start from the parameters stored by earlier runs, if there are any.
            corr_params = []
            try:
                with open(corr_params_path) as infile:
                    stored = json.load(infile)
                    if stored.get('corr_params'):
                        corr_params = stored['corr_params']
            except IOError:
                # The file could not be opened: start from an empty list.
                pass

            for i in range(func.GetNpar()):
                param_dict = {}
                param_dict['name'] = func.GetParName(i)
                param_dict['value'] = fit_result.Get().Parameter(i)
                param_dict['error'] = fit_result.Get().Error(i)
                param_dict['func'] = fit_formula
                param_dict['mother_param'] = graph_name
                param_dict['type'] = args.corr_type
                param_dict['index'] = i
                param_dict['theta_slice'] = args.theta_slice

                # Replace a stored parameter with the same name and correction type.
                corr_params = [p for p in corr_params if not (p['name'] == param_dict['name'] and
                                                              p['type'] == param_dict['type'])]
                corr_params.append(param_dict)

            with open(corr_params_path, 'w') as outfile:
                json.dump({'corr_params': corr_params}, outfile, indent=4)

        plot(graph, graph.GetName(), args)

    # chi2/NDF versus energy, taken from the selected entries with index 0.
    chi2_points = []
    for result in results:
        if result['type'] != args.corr_type:
            continue
        if args.theta_slice != result['theta']:
            continue
        if result['index'] != 0:
            continue
        if not result['ndf']:
            # chi2/NDF is undefined here; skip the point instead of dividing by zero.
            print('WARNING: Skipping chi2/NDF point with NDF = 0 (energy: %s).' % (result['energy']))
            continue
        chi2_points.append((result['energy'], result['chi2']/result['ndf']))

    if not chi2_points:
        return

    chi2_points.sort()

    graph_chi2 = TGraph()
    graph_chi2.SetName('graph_' + args.corr_type + '_energy_chi2')
    for x, y in chi2_points:
        graph_chi2.SetPoint(graph_chi2.GetN(), x, y)

    graph_chi2.GetXaxis().SetTitle('E_{cluster} [GeV]')
    graph_chi2.GetYaxis().SetTitle('#chi^{2} / NDF')
    plot(graph_chi2, graph_chi2.GetName(), args)


def _format_fit_formula(formula):
    '''
    Convert an expanded ROOT formula string into the text shown in the plot legend.

    The square brackets around the parameter names are dropped, one pair of parentheses
    enclosing the whole expression is removed and the variable "x" is renamed to
    "E_{cluster}".
    '''
    import re

    formula = formula.replace(']', '').replace('[', '')

    # Strip the outer parentheses only if the opening one at the very start is closed
    # by the very last character; an expression like "(a+b)/(x-c)" must stay untouched.
    if formula.startswith('(') and formula.endswith(')'):
        depth = 0
        closing_index = None
        for index, char in enumerate(formula):
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    closing_index = index
                    break
        if closing_index == len(formula) - 1:
            formula = formula[1:-1]

    # Rename only the stand-alone variable "x"; a plain str.replace would also hit the
    # "x" inside function names such as "exp" or "max".
    return re.sub(r'\bx\b', 'E_{cluster}', formula)


def plot(obj, plotname, args=None):
    '''
    Draw a graph on its own canvas, add the legend notes and print it to a file.

    Arguments:
        obj: graph to be drawn; in this file a ROOT TGraph or TGraphErrors. If a
            function is attached to it, its formula is added to the legend.
        plotname: base name of the canvas and of the output file. The theta slice
            ("_<theta>deg") and the file format extension are appended to it.
        args: parsed command line arguments. The attributes "notes", "theta_slice",
            "plot_directory" and "plot_file_format" are read. Despite the default
            value this argument is required.

    Raises:
        ValueError: if "args" is not provided.
    '''
    if args is None:
        raise ValueError('plot() requires the parsed command line arguments ("args").')

    from ROOT import gPad, gStyle
    from ROOT import TCanvas, TPaveText
    from ROOT import kBlue
    canvas = TCanvas('canvas_' + plotname, 'Canvas', 350, 350)
    gPad.SetLeftMargin(.13)
    gPad.SetTopMargin(.05)
    gPad.SetRightMargin(.05)

    gStyle.SetOptStat(11)
    gStyle.SetOptFit(1111)

    obj.SetLineColor(kBlue)
    obj.SetLineWidth(2)
    obj.SetMarkerColor(kBlue)
    obj.SetMarkerStyle(20)
    obj.SetMarkerSize(.5)
    obj.GetXaxis().SetTitleOffset(1.3)
    obj.GetYaxis().SetTitleOffset(1.9)

    # Leave 40 % headroom above the highest point for the legend.
    # NOTE(review): exactly GetN() entries are read instead of iterating the buffer
    # returned by GetY(), whose length may not match the number of points in every
    # PyROOT version - confirm against the ROOT release in use.
    n_points = obj.GetN()
    if n_points > 0:
        y_values = obj.GetY()
        y_max = max(y_values[i] for i in range(n_points))
        # Adding 0.4 * |y_max| equals the factor 1.4 for positive values, but keeps
        # the axis maximum above the data when the highest point is negative.
        obj.SetMaximum(y_max + 0.4 * abs(y_max))

    legend = TPaveText(.2, .7, .5, .9, 'brNDC')
    legend.SetFillStyle(0)
    legend.SetFillColor(0)
    legend.SetBorderSize(0)
    legend.SetTextColor(1)
    legend.SetTextFont(42)
    legend.SetTextAlign(11)
    for note_text in args.notes:
        legend.AddText(note_text)

    # The last function attached to the graph, if any, is shown in the legend.
    func = obj.GetListOfFunctions().Last()
    if func:
        formula = _format_fit_formula(func.GetFormula().GetExpFormula().Data())
        legend.AddText('Fit func.:   %s' % (formula))
        legend.GetListOfLines().Last().SetTextColor(func.GetLineColor())

    obj.Draw('APEL')
    legend.Draw()
    plotname += '_%ideg' % (args.theta_slice)
    canvas.Print(os.path.join(args.plot_directory, plotname + '.' + args.plot_file_format))

# Run the derivation only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
