#!/usr/bin/env python
import os
import argparse
import math


def main():
    '''
    Plot several troubleshooting histograms for up(down)stream energy calorimeter correction.

    This script expects a ROOT file containing a TTree "events" with the following branches:
     * std::vector<double> energyInLayer -- energy deposited in calorimeter layers
          -- layer_id 0 at position 0
          -- etc.
     * std::vector<double> energyInCryo -- energy deposited in the whole cryostat and in its parts
          -- energy deposited in whole cryostat at position 0
          -- energy deposited in front cryostat at position 1
          -- energy deposited in back cryostat at position 2
          -- energy deposited in cryostat sides at position 3
          -- energy deposited in front LAr bath at position 4
          -- energy deposited in back LAr bath at position 5
     * std::vector<double> particleVec -- four-momentum of initial particle
          -- mass at position 0
          -- px at position 1
          -- py at position 2
          -- pz at position 3

    NOTE(review): this function and the get_*() helpers read particleVec[0], [1] and [2]
    as px, py and pz, which disagrees with the particleVec layout listed above. Confirm
    the real layout against the code that writes the tree and fix whichever side is wrong.

    Command line options:
     * -i/--input-file (required) -- path to the input ROOT file
     * -n/--note (repeatable) -- note displayed in the plots; two default notes are used if none is given
     * --output-file-format -- plot file format, "pdf" by default

    Every histogram is saved as hist_<name>_<theta>deg_<momentum>gev.<format>, where theta and
    momentum are the averages over all events rounded to integers.

    The script exits with status 1 if the input file is missing or cannot be opened, or if
    the "events" tree is missing or has no entries.
    '''
    parser = argparse.ArgumentParser(description='Yay, plot troubleshooting histograms!')

    parser.add_argument('-n', '--note', action='append', default=[], type=str,
                        help='additional note to be displayed in the plots')
    parser.add_argument('--output-file-format', type=str, default='pdf', help='plot output file format')

    required_arguments = parser.add_argument_group('required arguments')
    required_arguments.add_argument('-i', '--input-file', type=str, required=True, help='Input file path')

    args = parser.parse_args()

    if not args.note:
        # Assign a fresh list instead of appending to the parser's shared default list.
        args.note = ['FCC-ee, LAr Calo', 'electron']
        print('WARNING: Using default plot legend notes.')

    if not os.path.isfile(args.input_file):
        print('ERROR: Input ROOT file not found!')
        print('       ' + args.input_file)
        # SystemExit is raised directly: the exit() helper is provided by the site module
        # and is not guaranteed to be available.
        raise SystemExit(1)

    # Deferred import: argument parsing and the file check above run without loading ROOT.
    from ROOT import gROOT, TFile, TH1D, TH2D
    gROOT.SetBatch(True)

    infile = TFile(args.input_file, 'READ')
    if infile.IsZombie():
        print('ERROR: Input ROOT file could not be opened!')
        print('       ' + args.input_file)
        raise SystemExit(1)

    # Get() returns a null object for a missing key, whereas attribute access
    # (infile.events) raises AttributeError before any check can run.
    events = infile.Get('events')
    if not events:
        print('ERROR: TTree "events" not found in the provided ROOT file!')
        infile.Close()
        raise SystemExit(1)

    n_events = events.GetEntries()
    if n_events < 1:
        # The averages below divide by n_events.
        print('ERROR: TTree "events" in the provided ROOT file is empty!')
        infile.Close()
        raise SystemExit(1)

    # First pass: average momentum and polar angle, number of layers from the first event.
    momentum_sum = 0.
    theta_sum = 0.
    n_layers = -1
    for event in events:
        momentum_sum += get_momentum(event)
        theta_sum += get_theta(event)
        if n_layers < 0:
            n_layers = len(event.energyInLayer)
    particle_momentum = int(round(momentum_sum / n_events))
    particle_theta = int(round(theta_sum / n_events))
    print('INFO: Initial particle momentum: %i GeV' % (particle_momentum))
    print('                       theta:    %i deg' % (particle_theta))

    histograms = []  # every booked histogram, in plotting order
    buffered = []    # histograms booked without axis limits (xlow == xup)

    def book(hist, auto_range=False):
        # Detach the histogram from the current directory (the input file), so that it
        # remains usable after infile.Close().
        hist.SetDirectory(0)
        histograms.append(hist)
        if auto_range:
            buffered.append(hist)
        return hist

    # Axis range shared by all histograms containing the calorimeter energy sum
    energy_low = particle_momentum/2.
    energy_high = particle_momentum*1.3

    # Positional histograms
    hist_energy_vs_phi = book(TH1D('energy_vs_phi', 'Energy deposited at angle #phi;#phi;Energy [GeV]',
                                   100, 0., 2*math.pi))
    hist_energy_vs_phi.Sumw2()

    hist_particle_momentum_xy = book(TH2D('particle_momentum_xy',
                                          'Particle momentum in X-Y plane;p_{x};p_{y}',
                                          100, 0., 0., 100, 0., 0.),
                                     auto_range=True)

    hist_particle_momentum_zrxy = book(TH2D('particle_momentum_zrxy',
                                            'Particle momentum in Z-R_XY;p_{z};p_{r_{xy}}',
                                            100, 0., 0., 100, 0., 0.),
                                       auto_range=True)

    # Energy deposited in calorimeter or cryostat
    hist_energy_vs_layer = book(TH1D('energy_vs_layer', 'Energy deposited in layer ;layer ID;Energy [GeV]',
                                     n_layers, 0, n_layers))
    hist_calo_energy = book(TH1D('calo_energy',
                                 'Sum of energy deposited in all calorimeter layers;Energy [GeV];N_{evt}',
                                 100, energy_low, energy_high))

    # One histogram per energyInCryo entry; the list index equals the position in the vector.
    cryo_parts = (
        ('cryo_energy', 'Energy deposited in cryostat'),
        ('cryo_energy_front', 'Energy deposited in front cryostat'),
        ('cryo_energy_back', 'Energy deposited in back cryostat'),
        ('cryo_energy_sides', 'Energy deposited in cryostat sides'),
        ('cryo_energy_lar_bath_front', "Energy deposited in cryostat's front LAr bath"),
        ('cryo_energy_lar_bath_back', "Energy deposited in cryostat's back LAr bath"),
    )
    hists_cryo = [book(TH1D(name, title + ';Energy [GeV];N_{evt}', 100, 0., 0.), auto_range=True)
                  for name, title in cryo_parts]

    # Sums of calorimeter and cryostat energies
    hist_calo_and_cryo_energy = book(TH1D('calo_and_cryo_energy',
                                          'Energy deposited in calorimeter and cryostat;E [GeV];N_{evt}',
                                          100, energy_low, energy_high))
    hist_calo_and_cryo_front_energy = book(TH1D('calo_and_cryo_front_energy',
                                                'Energy deposited in calorimeter and front cryostat;'
                                                'E [GeV];N_{evt}',
                                                100, energy_low, energy_high))
    hist_calo_and_cryo_back_energy = book(TH1D('calo_and_cryo_back_energy',
                                               'Energy deposited in calorimeter and back cryostat;'
                                               'E [GeV];N_{evt}',
                                               100, energy_low, energy_high))
    hist_cryo_front_and_lar_bath_front_energy = book(TH1D('cryo_front_and_lar_bath_front_energy',
                                                          'Energy deposited in cryostat front and LAr bath front;'
                                                          'E [GeV];N_{evt}',
                                                          100, 0., 0.),
                                                     auto_range=True)
    hist_cryo_back_and_lar_bath_back_energy = book(TH1D('cryo_back_and_lar_bath_back_energy',
                                                        'Energy deposited in cryostat back and LAr bath back;'
                                                        'E [GeV];N_{evt}',
                                                        100, 0., 0.),
                                                   auto_range=True)

    # Second pass: fill the histograms.
    for event in events:
        calo_energy = get_energy(event)
        cryo = event.energyInCryo

        hist_energy_vs_phi.Fill(get_phi(event), calo_energy)
        # NOTE(review): indices 0, 1, 2 are used as px, py, pz -- see the docstring.
        hist_particle_momentum_xy.Fill(event.particleVec[0], event.particleVec[1])
        hist_particle_momentum_zrxy.Fill(event.particleVec[2], get_rxy(event))

        # Layer energies are used as weights, so the bin content is the energy summed over all events.
        for layer_id, layer_energy in enumerate(event.energyInLayer):
            hist_energy_vs_layer.Fill(layer_id, layer_energy)

        hist_calo_energy.Fill(calo_energy)
        for position, hist in enumerate(hists_cryo):
            hist.Fill(cryo[position])

        hist_calo_and_cryo_energy.Fill(calo_energy + cryo[0])
        hist_calo_and_cryo_front_energy.Fill(calo_energy + cryo[1] + cryo[4])
        hist_calo_and_cryo_back_energy.Fill(calo_energy + cryo[2] + cryo[5])
        hist_cryo_front_and_lar_bath_front_energy.Fill(cryo[1] + cryo[4])
        hist_cryo_back_and_lar_bath_back_energy.Fill(cryo[2] + cryo[5])

    infile.Close()

    # Histograms booked with xlow == xup buffer their entries until the axis limits can be
    # computed; flush the buffers explicitly before drawing.
    for hist in buffered:
        hist.BufferEmpty()

    args.note.append('%i GeV, %i deg' % (particle_momentum, particle_theta))

    for hist in histograms:
        plot(hist, 'hist_%s_%ideg_%igev' % (hist.GetName(), particle_theta, particle_momentum), args)


def get_momentum(event):
    '''
    Return the magnitude of the initial particle momentum.

    The entries 0, 1 and 2 of event.particleVec are squared, summed and the
    square root of the sum is returned.

    NOTE(review): indices 0, 1, 2 are used as px, py, pz here, while the branch
    description in main() lists mass at index 0 and px, py, pz at 1-3 -- confirm
    the layout against the code that writes the tree.
    '''
    norm_squared = 0.
    for axis in range(3):
        norm_squared += math.pow(event.particleVec[axis], 2)

    return math.sqrt(norm_squared)


def get_theta(event):
    '''
    Return the polar angle of the initial particle momentum in degrees.

    The angle is atan2 of the transverse component (built from entries 0 and 1
    of event.particleVec) and the longitudinal component (entry 2), converted
    from radians to degrees.

    NOTE(review): indices 0, 1, 2 are used as px, py, pz here, while the branch
    description in main() lists mass at index 0 and px, py, pz at 1-3 -- confirm
    the layout against the code that writes the tree.
    '''
    first, second, longitudinal = (event.particleVec[axis] for axis in range(3))
    transverse = math.sqrt(math.pow(first, 2) + math.pow(second, 2))
    angle_rad = abs(math.atan2(transverse, longitudinal))

    return 180 * angle_rad / math.pi


def get_energy(event):
    '''
    Return the sum of all entries of event.energyInLayer.

    math.fsum is used instead of a running += so that the result does not
    accumulate rounding error over the layers. An empty vector gives 0.0.
    '''
    return math.fsum(event.energyInLayer)


def get_phi(event):
    '''
    Return the azimuthal angle of the initial particle shifted by +pi, in radians.

    atan2 of entries 1 and 0 of event.particleVec lies in [-pi, pi]; adding pi
    moves the result into [0, 2*pi]. This is a shift, not a wrap: an atan2
    value of 0 is returned as pi.

    NOTE(review): indices 0 and 1 are used as px and py here, while the branch
    description in main() lists mass at index 0 and px, py at 1-2 -- confirm
    the layout against the code that writes the tree.
    '''
    abscissa = event.particleVec[0]
    ordinate = event.particleVec[1]

    return math.atan2(ordinate, abscissa) + math.pi


def get_rxy(event):
    '''
    Return the transverse component of the initial particle momentum.

    Entries 0 and 1 of event.particleVec are combined in quadrature.

    NOTE(review): indices 0 and 1 are used as px and py here, while the branch
    description in main() lists mass at index 0 and px, py at 1-2 -- confirm
    the layout against the code that writes the tree.
    '''
    squares = [math.pow(event.particleVec[axis], 2) for axis in (0, 1)]

    return math.sqrt(squares[0] + squares[1])


def plot(obj, plotname, args):
    '''
    Draw an object on a new canvas together with the notes and save the canvas to a file.

    obj -- ROOT object to draw; its title is cleared. Objects whose class name contains
           "TH2" are drawn with the "COLZ" option, everything else with default options.
    plotname -- used in the canvas name and as the output file name without extension
    args -- object with attributes "note" (iterable of strings shown in the plot)
            and "output_file_format" (file name extension passed to TCanvas.Print)
    '''
    from ROOT import TCanvas, TPaveText
    from ROOT import gPad

    # The canvas has to exist before gPad is used to adjust the margins.
    canvas = TCanvas('canvas_' + plotname, 'Canvas', 450, 450)
    gPad.SetLeftMargin(.13)
    gPad.SetTopMargin(.05)

    obj.SetTitle('')

    class_name = obj.ClassName()
    if 'TH1' in class_name:
        gPad.SetLeftMargin(.15)

    is_two_dimensional = 'TH2' in class_name
    if is_two_dimensional:
        # Wider right margin for 2D histograms drawn with the COLZ option.
        gPad.SetRightMargin(.13)
    draw_mode = 'COLZ' if is_two_dimensional else ''

    # Borderless text box holding the notes, positioned in NDC coordinates.
    note_box = TPaveText(.2, .7, .5, .9, 'brNDC')
    note_box.SetFillStyle(0)
    note_box.SetFillColor(0)
    note_box.SetBorderSize(0)
    note_box.SetTextColor(1)
    note_box.SetTextFont(42)
    note_box.SetTextAlign(11)
    for text_line in args.note:
        note_box.AddText(text_line)

    obj.Draw(draw_mode)
    note_box.Draw()

    output_name = plotname + '.' + args.output_file_format
    canvas.Print(output_name)


# Run the plotting only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
