#!/usr/bin/env python
import os
import argparse
import math
import json


def main():
    '''
    Histogram and fit up(down)stream energy vs. energy in first (last) calorimeter layer.
    Developed for FCC-ee LAr Calorimeter.

    This script expects a ROOT file containing a TTree "events" with the following branches:
     * std::vector<double> energyInLayer -- energy deposited in calorimeter layers
          -- layer_id 0 at position 0
          -- layer_id 1 at position 1
          -- etc.
     * std::vector<double> energyInCryo -- energy deposited in the whole cryostat and in its parts
          -- energy deposited in whole cryostat at position 0
          -- energy deposited in front cryostat at position 1
          -- energy deposited in back cryostat at position 2
          -- energy deposited in cryostat sides at position 3
          -- energy deposited in front LAr bath at position 4
          -- energy deposited in back LAr bath at position 5
     * std::vector<double> particleVec -- four-momentum of initial particle
          -- mass at position 0
          -- px at position 1
          -- py at position 2
          -- pz at position 3

    Output of this script serves as input for the `cec_derive1` or `cec_derive2` scripts.

    To run over multiple energies/angles of the initial particle use a simple for loop, e.g.:

      for i in *root; do ./cec_process_events -i $i -t upstream; done

    The fit results are stored in a JSON file (`fit_results.json` unless changed with `-o`).
    Entries already present in that file with the same parameter name, correction type,
    energy and theta are replaced; all other entries are kept.

    The script requires two arguments to be provided:
        -i    input ROOT file
        -t    type of the correction

    The process exits with status 1 when the input file or the "events" tree is missing
    or empty, and with status 0 (without writing fit results) when the profile has fewer
    than three bins with a non-zero error.
    '''
    parser = argparse.ArgumentParser(description='Yay, histogram and fit up(down)stream energy vs. energy in layer!')

    # Default is None (not 0) so that an explicitly requested `-l 0` can be told apart
    # from "option not given".
    parser.add_argument('-l', '--layer-id', type=int, default=None, help='ID of the calorimeter layer')
    parser.add_argument('-f', '--function', type=str, help='fit function (in ROOT notation)')
    parser.add_argument('--func-from', type=float, help='fit function starting value')
    parser.add_argument('--func-to', type=float, help='fit function ending value')
    parser.add_argument('-r', '--rebin-factor', type=int, help='number of bins to be merged')
    parser.add_argument('-b', '--n-bins', type=int, help='number of bins of the histogram(s)')
    parser.add_argument('-n', '--note', action='append', default=[], type=str,
                        help='additional note to be displayed in the plots')
    parser.add_argument('-o', '--output', type=str, default='fit_results.json', help='output JSON file name')
    parser.add_argument('--plot-file-format', type=str, default='pdf', help='plot output file format')
    parser.add_argument('--plot-directory', type=str, default='./', help='Directory of the output plots')

    required_arguments = parser.add_argument_group('required arguments')
    required_arguments.add_argument('-i', '--input-file', type=str, required=True, help='input file path')
    required_arguments.add_argument('-t', '--corr-type', type=str, required=True,
                                    choices=['upstream', 'downstream'], help='select correction type')

    args = parser.parse_args()
    is_upstream = args.corr_type == 'upstream'

    print('****************************************************************')
    print('*  Calorimeter Energy Correction: Processing generated events  *')
    print('****************************************************************')

    if args.layer_id is None:
        # First layer for the upstream correction, last layer for the downstream one.
        args.layer_id = 0 if is_upstream else -1
        print('WARNING: Using default calorimeter layer id %i.' % args.layer_id)

    if not args.function:
        args.function = 'pol1'
        print('WARNING: Using default fitting function "pol1".')

    if not args.note:
        args.note.append('FCC-ee, LAr Calo')
        args.note.append('electron')
        print('WARNING: Using default plot legend notes.')

    if not os.path.isfile(args.input_file):
        print('WARNING: Input ROOT file not found!')
        print('         ' + args.input_file)
        raise SystemExit(1)

    if not os.path.isdir(args.plot_directory):
        # makedirs also creates missing parent directories.
        os.makedirs(args.plot_directory)

    # ROOT is imported late so that argument parsing and `--help` work without it.
    from ROOT import gROOT, TFile, TH2D, TF1
    gROOT.SetBatch(True)

    infile = TFile(args.input_file, 'READ')

    # Get() yields a null object for a missing key instead of raising on attribute lookup.
    tree = infile.Get('events')
    if not tree:
        print('ERROR: TTree "events" not found in the provided ROOT file!')
        infile.Close()
        raise SystemExit(1)

    n_events = tree.GetEntries()
    if n_events < 1:
        # The averages below divide by the number of events.
        print('ERROR: TTree "events" contains no entries!')
        infile.Close()
        raise SystemExit(1)

    if not args.n_bins:
        if args.rebin_factor:
            args.n_bins = int(math.ceil(math.sqrt(n_events))/args.rebin_factor)
        else:
            args.n_bins = int(math.ceil(math.sqrt(n_events)))
    print('INFO: Number of histogram bins: %i' % args.n_bins)

    hist_name = '%sEnergy__energyInLayer_%i' % (args.corr_type, args.layer_id)
    hist_title = ';Energy in layer id: %i [GeV];%s Energy [GeV]' % (args.layer_id, args.corr_type.capitalize())
    hist = TH2D(hist_name, hist_title, int(args.n_bins), 0., 0., int(args.n_bins), 0., 0.)
    hist.Sumw2()
    # Detach from the input file so the histogram survives infile.Close().
    hist.SetDirectory(0)

    # Positions in energyInCryo: front cryostat + front LAr bath for upstream,
    # back cryostat + back LAr bath for downstream.
    if is_upstream:
        cryo_idx, bath_idx = 1, 4
    else:
        cryo_idx, bath_idx = 2, 5

    momentum_sum = 0.
    theta_sum = 0.
    for event in tree:
        momentum_sum += get_momentum(event)
        theta_sum += get_theta(event)

        cryo_energy = event.energyInCryo[cryo_idx] + event.energyInCryo[bath_idx]
        hist.Fill(event.energyInLayer[args.layer_id], cryo_energy)

    infile.Close()
    hist.BufferEmpty()

    # Event-averaged kinematics of the initial particle, rounded to integers.
    particle_momentum = int(round(momentum_sum / n_events))
    particle_theta = int(round(theta_sum / n_events))
    print('INFO: Initial particle momentum: %i GeV' % (particle_momentum))
    print('                       theta:    %i deg' % (particle_theta))

    args.note.append('%i GeV, %i deg' % (particle_momentum, particle_theta))

    if args.rebin_factor:
        hist.Rebin(args.rebin_factor)

    plot_name = '%s_vs_layer_%i_%ideg_%igev' % (args.corr_type, args.layer_id, particle_theta, particle_momentum)
    plot(hist, 'hist_' + plot_name, args)

    profile = hist.ProfileX('profile_x')
    profile.GetYaxis().SetTitle('Mean ' + hist.GetYaxis().GetTitle())

    if count_bins_with_error(profile) < 3:
        print('WARNING: Number of non empty bins too small, ignoring histogram!')
        raise SystemExit(0)

    # `is None` keeps an explicitly requested limit of 0.0.
    if args.func_from is None:
        args.func_from = hist.GetXaxis().GetBinLowEdge(hist.GetXaxis().GetFirst())
    if args.func_to is None:
        args.func_to = hist.GetXaxis().GetBinUpEdge(hist.GetXaxis().GetLast())
    func = TF1('func', args.function, args.func_from, args.func_to)

    n_params = func.GetNpar()
    par_name_template = '#upsilon_{%i}' if is_upstream else '#delta_{%i}'
    for i in range(n_params):
        func.SetParName(i, par_name_template % i)

    result = profile.Fit(func, 'SR')

    plot(profile, 'profile_' + plot_name, args)

    # Load previously stored results, if any; a missing file just means starting fresh.
    results = []
    try:
        with open(args.output) as results_file:
            data = json.load(results_file)
            if data.get('results'):
                results = data['results']
    except IOError:
        pass

    fit_result = result.Get()
    formula = func.GetFormula().GetExpFormula().Data()
    for i in range(n_params):
        param_dict = {
            'name': func.GetParName(i),
            'energy': particle_momentum,
            'theta': particle_theta,
            'value': fit_result.Parameter(i),
            'error': fit_result.Error(i),
            'chi2': fit_result.Chi2(),
            'ndf': fit_result.Ndf(),
            'index': i,
            'n_params': n_params,
            'func': formula,
            'type': args.corr_type,
        }

        # Drop any stale entry describing the same parameter before adding the new one.
        results = [r for r in results if not (r['name'] == param_dict['name'] and
                                              r['type'] == param_dict['type'] and
                                              r['energy'] == param_dict['energy'] and
                                              r['theta'] == param_dict['theta'])]
        results.append(param_dict)

    with open(args.output, 'w') as results_file:
        json.dump({'results': results}, results_file, indent=4)


def get_momentum(event):
    '''
    Return the magnitude of the three-momentum of the initial particle.

    The components px, py and pz are read from positions 1, 2 and 3 of
    `event.particleVec`; position 0 is not used here.
    '''
    px, py, pz = (event.particleVec[idx] for idx in (1, 2, 3))

    return math.sqrt(math.pow(px, 2) + math.pow(py, 2) + math.pow(pz, 2))


def get_theta(event):
    '''
    Return the polar angle of the initial particle in degrees.

    The angle is computed with atan2 from the transverse momentum
    (built from positions 1 and 2 of `event.particleVec`) and the
    longitudinal component (position 3).
    '''
    px = event.particleVec[1]
    py = event.particleVec[2]
    pz = event.particleVec[3]

    transverse = math.sqrt(math.pow(px, 2) + math.pow(py, 2))
    theta_rad = abs(math.atan2(transverse, pz))

    return 180 * theta_rad / math.pi


def plot(obj, plotname, args):
    '''
    Draw a histogram or profile on a fresh canvas and print it to a file.

    Arguments:
        obj       ROOT object to draw; 2D histograms are drawn with the
                  "COLZ" option, everything else with default options
        plotname  base name used for the canvas and for the output file
        args      parsed command line arguments; `note`, `plot_directory`
                  and `plot_file_format` are used

    Every entry of `args.note` becomes a line of the legend box. If the
    object has an entry in its list of functions, the expanded formula of
    the last entry is added to the legend in that function's line color.
    '''
    from ROOT import gPad, gStyle
    from ROOT import TCanvas, TPaveText

    # The canvas has to exist before gPad is configured.
    canvas = TCanvas('canvas_' + plotname, 'Canvas', 450, 450)
    gPad.SetLeftMargin(.13)
    gPad.SetTopMargin(.05)

    gStyle.SetOptStat(11)
    gStyle.SetOptFit(1111)

    obj.SetTitle('')

    class_name = obj.ClassName()
    draw_options = ''

    if 'TH2' in class_name:
        # Wider right margin when drawing with the "COLZ" option.
        gPad.SetRightMargin(.13)
        draw_options = 'COLZ'

    if 'TProfile' in class_name:
        gPad.SetRightMargin(.05)
        obj.SetLineWidth(2)

    # Transparent, borderless text box that holds the notes.
    legend = TPaveText(.2, .7, .5, .9, 'brNDC')
    legend.SetFillStyle(0)
    legend.SetFillColor(0)
    legend.SetBorderSize(0)
    legend.SetTextColor(1)
    legend.SetTextFont(42)
    legend.SetTextAlign(11)
    for note_text in args.note:
        legend.AddText(note_text)

    fit_func = obj.GetListOfFunctions().Last()
    if fit_func:
        # Strip the square brackets and one pair of enclosing parentheses
        # from the expanded formula before showing it.
        expression = fit_func.GetFormula().GetExpFormula().Data()
        expression = expression.replace(']', '').replace('[', '')
        if expression[0] == '(' and expression[-1] == ')':
            expression = expression[1:-1]
        legend.AddText('Fit func.:   %s' % (expression))
        legend.GetListOfLines().Last().SetTextColor(fit_func.GetLineColor())

    obj.Draw(draw_options)
    legend.Draw()

    output_file = plotname + '.' + args.plot_file_format
    canvas.Print(os.path.join(args.plot_directory, output_file))


def count_bins_with_error(hist):
    '''
    Count the bins of a 1D histogram/profile that have a non-zero error.

    Only bins inside the currently selected X axis range (from
    `GetXaxis().GetFirst()` to `GetXaxis().GetLast()`, both inclusive) are
    inspected.

    Arguments:
        hist    object providing `GetXaxis()` and `GetBinError(bin)`

    Returns the number of inspected bins whose error is different from zero.
    '''
    x_axis = hist.GetXaxis()
    # Both range limits must come from the X axis, the one being iterated over.
    first_bin = x_axis.GetFirst()
    last_bin = x_axis.GetLast()

    return sum(1 for i in range(first_bin, last_bin + 1) if hist.GetBinError(i) != 0.)


# Run the event processing only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
