#!/usr/bin/python

# Copyright 2016 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import sys, h5py
import argparse, logging

import copy, numpy as np

from pycbc import io, events
from pycbc.events import trigger_fits as trstats
import pycbc.version

#### DEFINITIONS AND FUNCTIONS ####

def _snr_only(snr, rchisq):
    """Return the SNR unchanged; the reduced chisq argument is ignored."""
    return snr


def _snr_over_root_rchisq(snr, rchisq):
    """Return the SNR divided by the square root of the reduced chisq."""
    return snr / (rchisq ** 0.5)


# Lookup table from a statistic name (the value given to --sngl-stat) to the
# callable that computes it. Every entry is called as f(snr, rchisq); get_stat
# passes a third positional argument for some of them.
stat_dict = {
    "new_snr": events.newsnr,
    "effective_snr": events.effsnr,
    "snr": _snr_only,
    "snronchi": _snr_over_root_rchisq,
    "newsnr_sgveto": events.newsnr_sgveto,
}

def get_stat(statchoice, snr, rchisq, sgchisq, fac):
    """Evaluate the chosen single-detector statistic for a set of triggers.

    Parameters
    ----------
    statchoice : str
        Key of ``stat_dict`` selecting the statistic function.
    snr :
        SNR value(s), passed as the first argument of the statistic function.
    rchisq :
        Reduced chisq value(s), passed as the second argument.
    sgchisq :
        Values passed as the third argument when ``statchoice`` is
        'newsnr_sgveto'; ignored otherwise. Must not be None in that case.
    fac : float or None
        Optional factor passed as the third argument. Only allowed for
        'new_snr' and 'effective_snr'.

    Returns
    -------
    Whatever the selected statistic function returns.

    Raises
    ------
    RuntimeError
        If ``fac`` is given for a statistic that does not take it, or if
        'newsnr_sgveto' is requested without any ``sgchisq`` values.
    KeyError
        If ``statchoice`` is not a key of ``stat_dict``.
    """
    if fac is not None:
        if statchoice not in ['new_snr', 'effective_snr']:
            raise RuntimeError("Can't use --stat-factor with this statistic!")
        return stat_dict[statchoice](snr, rchisq, fac)

    if statchoice == 'newsnr_sgveto':
        # Fail early with a clear message rather than handing None to the
        # statistic function and getting an obscure error from inside it.
        if sgchisq is None:
            raise RuntimeError("Statistic 'newsnr_sgveto' needs sg_chisq "
                               "values but none are available!")
        return stat_dict[statchoice](snr, rchisq, sgchisq)

    return stat_dict[statchoice](snr, rchisq)
#### MAIN ####

# Command-line interface. Options whose help text says "Required" are enforced
# by argparse so that a missing value is reported at parse time, and numeric
# options carry type=float so user-supplied values are not left as strings.
parser = argparse.ArgumentParser(usage="",
    description="Perform maximum-likelihood fits of single inspiral trigger"
                " distributions to various functions")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                    "Required")
parser.add_argument("--bank-file", required=True,
                    help="hdf file containing template parameters. Required")
parser.add_argument("--template-fraction-range", default="0/1",
                    help="Optional, analyze only part of template bank. "
                    "Format is PART/NUM_PARTS")
parser.add_argument("--veto-file", nargs='*', default=[], action='append',
                    help="File(s) in .xml format with veto segments to apply "
                    "to triggers before fitting")
parser.add_argument("--veto-segment-name", nargs='*', default=[],
                    action='append',
                    help="Name(s) of veto segments to apply. Optional, if not "
                    "given all segments for a given ifo will be used")
parser.add_argument("--output", required=True,
                    help="Location for output file containing fit coefficients"
                    ". Required")
parser.add_argument("--ifo", required=True,
                    help="Ifo producing triggers to be fitted. Required")
parser.add_argument("--fit-function",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit")
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=["snr", "snronchi", "effective_snr", "new_snr",
                             "newsnr_sgveto"],
                    help="Function of SNR and chisq to perform fits with")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                    "statistics. Values commonly used: 6 for new_snr, 250 "
                    "or 50 for effective_snr")
parser.add_argument("--stat-threshold", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                    "threshold.  Required.  Typically 6-6.5")
parser.add_argument("--save-trig-param",
                    help="For each template, save a parameter value read from "
                    "its trigger(s). Ex. template_duration")
parser.add_argument("--prune-param",
                    help="Parameter to define bins for 'pruning' loud triggers"
                    " to make the fit insensitive to signals and outliers. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
parser.add_argument("--prune-bins", type=int,
                    help="Number of bins to divide bank into when pruning")
parser.add_argument("--prune-number", type=int,
                    help="Number of loudest events to prune in each bin")
parser.add_argument("--log-prune-param", action='store_true',
                    help="Bin in the log of prune-param")
parser.add_argument("--f-lower", type=float, default=-1.,
                    help="Starting frequency for calculating template "
                    "duration, required if this is the prune parameter")
# FIXME : support using the trigger file duration as prune parameter?
# FIXME : have choice of SEOBNRv2 or PhenD duration formula ?
parser.add_argument("--min-duration", type=float, default=0.,
                    help="Fudge factor for templates with tiny or negative "
                    "values of template_duration: add to duration values "
                    "before pruning. Units seconds")
parser.add_argument("--approximant", default="SEOBNRv4",
                    help="Approximant for template duration. Default SEOBNRv4")

args = parser.parse_args()

# Both veto options use nargs='*' with action='append', so each one arrives as
# a list of lists (one inner list per use of the option); flatten them.
args.veto_segment_name = sum(args.veto_segment_name, [])
args.veto_file = sum(args.veto_file, [])

# Veto files and segment names are paired one-to-one when vetoing below.
# NOTE(review): the --veto-segment-name help text calls the names optional,
# but this check demands one name per veto file -- confirm which is intended.
if len(args.veto_segment_name) != len(args.veto_file):
    raise RuntimeError("Number of veto files must match number of veto "
                       "segment names")

# The three pruning options only make sense together: all set or none set
prune_opts = (args.prune_param, args.prune_bins, args.prune_number)
if any(prune_opts) and not all(prune_opts):
    raise RuntimeError("To prune, need to specify param, number of bins and "
                       "nonzero number to prune in each bin!")

# --verbose lowers the logging threshold so the progress messages are shown
log_level = logging.DEBUG if args.verbose else logging.WARN
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# Report the threshold, then open both input files read-only
logging.info('Fitting above threshold %f' % args.stat_threshold)

logging.info('Opening trigger file: %s' % args.trigger_file)
trigf = h5py.File(args.trigger_file, 'r')
logging.info('Opening template file: %s' % args.bank_file)
templatef = h5py.File(args.bank_file, 'r')

# every trigger dataset read below is addressed as '<ifo>/<name>'
ifo_prefix = args.ifo + '/'

# reduced chisq = chisq / (2 * chisq_dof - 2); the two input arrays are only
# temporaries of this expression, so they are released once it is evaluated
rchisq = trigf[ifo_prefix + 'chisq'][:] / \
    (2 * trigf[ifo_prefix + 'chisq_dof'][:] - 2)

# the 'sg_chisq' dataset may be absent from the trigger file
try:
    sgchisq = trigf[ifo_prefix + 'sg_chisq'][:]
except KeyError:
    sgchisq = None

snr = trigf[ifo_prefix + 'snr'][:]
logging.info('Calculating stat values')
stat = get_stat(args.sngl_stat, snr, rchisq, sgchisq, args.stat_factor)
# the inputs of the statistic are not needed any more
del snr, rchisq, sgchisq

# first cut: keep only triggers at or above the threshold, applying the same
# boolean mask to every per-trigger array that is used later on
abovethresh = stat >= args.stat_threshold
stat = stat[abovethresh]
tid = trigf[ifo_prefix + 'template_id'][:][abovethresh]
time = trigf[ifo_prefix + 'end_time'][:][abovethresh]
if args.save_trig_param:
    tparam = trigf[ifo_prefix + args.save_trig_param][:][abovethresh]
logging.info('%i trigs left after thresholding' % len(stat))

# Vetoing: each veto file is applied in turn, together with the segment name
# given at the same position on the command line
for vfile, vsegname in zip(args.veto_file, args.veto_segment_name):
    # only the first returned value is used: it indexes the triggers to keep
    outside, _unused = events.veto.indices_outside_segments(
        time, [vfile], ifo=args.ifo, segment_name=vsegname)
    # apply the same selection to every per-trigger array
    stat, tid, time = stat[outside], tid[outside], time[outside]
    if args.save_trig_param:
        tparam = tparam[outside]
    logging.info('%i trigs left after vetoing with %s' % (len(stat), vfile))

# do pruning (removal of trigs at N loudest times defined over param bins)
if args.prune_param:
    logging.info('Getting min and max param values')
    pars = trstats.get_param(args.prune_param, args,
                     templatef['mass1'][:], templatef['mass2'][:],
                     templatef['spin1z'][:], templatef['spin2z'][:])
    # numpy reductions instead of the builtins: same extrema, but computed
    # without iterating over the whole bank in Python
    minpar = np.min(pars)
    maxpar = np.max(pars)
    del pars
    logging.info('%f %f' % (minpar, maxpar))

    # hard-coded time window of 0.1s
    args.prune_window = 0.1
    # upper limit on the number of loudest events examined
    max_prune_trials = 1000
    # total number of events to prune over all bins
    prune_target = args.prune_bins * args.prune_number

    # initialize bin storage: bin index -> list of pruned times
    prunedtimes = {i: [] for i in range(args.prune_bins)}

    # keep a record of the triggers if all successive loudest events were to
    # be pruned
    statpruneall = stat.copy()
    tidpruneall = tid.copy()
    timepruneall = time.copy()

    # many trials may be required to prune in 'quieter' bins
    for j in range(max_prune_trials):
        # are all the bins full already?
        numpruned = sum(len(times) for times in prunedtimes.values())
        if numpruned == prune_target:
            logging.info('Finished pruning!')
            break
        if numpruned > prune_target:
            logging.error('Uh-oh, we pruned too many things .. %i, to be '
                          'precise' % numpruned)
            raise RuntimeError('Pruned %i events but only %i were requested'
                               % (numpruned, prune_target))
        # np.argmax raises ValueError on an empty array, so stop cleanly when
        # there are no candidate triggers left to examine
        if len(statpruneall) == 0:
            logging.warning('Ran out of triggers after pruning %i of %i '
                            'events' % (numpruned, prune_target))
            break

        loudest = np.argmax(statpruneall)
        lstat = statpruneall[loudest]
        ltid = tidpruneall[loudest]
        ltime = timepruneall[loudest]
        m1, m2, s1z, s2z = trstats.get_masses(templatef, ltid)
        lbin = trstats.which_bin(trstats.get_param(args.prune_param, args,
                                           m1, m2, s1z, s2z),
                                 minpar, maxpar, args.prune_bins,
                                                      log=args.log_prune_param)

        # is the bin where the loudest trigger lives full already?
        if len(prunedtimes[lbin]) == args.prune_number:
            logging.info('%i - Bin %i full, not pruning event with stat %f at '
                         'time %.3f' % (j, lbin, lstat, ltime))
        else:
            logging.info('Pruning event with stat %f at time %.3f in bin %i' %
                         (lstat, ltime, lbin))
            # now do the pruning: drop everything within the window of ltime
            retain = abs(time - ltime) > args.prune_window
            logging.info('%i trigs before pruning' % len(stat))
            stat = stat[retain]
            logging.info('%i trigs remain' % len(stat))
            tid = tid[retain]
            time = time[retain]
            if args.save_trig_param:
                tparam = tparam[retain]
            # record the time
            prunedtimes[lbin].append(ltime)

        # in both cases remove the event from the reference arrays, so that
        # the next trial looks at the next-loudest time
        retain = abs(timepruneall - ltime) > args.prune_window
        statpruneall = statpruneall[retain]
        tidpruneall = tidpruneall[retain]
        timepruneall = timepruneall[retain]
        del retain
    else:
        # the trial limit was reached without hitting either break above
        numpruned = sum(len(times) for times in prunedtimes.values())
        if numpruned < prune_target:
            logging.warning('Stopped after %i trials with only %i of %i '
                            'events pruned' %
                            (max_prune_trials, numpruned, prune_target))
    del statpruneall
    del tidpruneall
    del timepruneall
    logging.info('%i trigs remain after pruning loop' % len(stat))

# parse template range, given as PART/NUM_PARTS
num_templates = len(templatef['template_hash'])
rangestr = args.template_fraction_range
range_fields = rangestr.split('/')
if len(range_fields) != 2:
    raise ValueError("--template-fraction-range must have the format "
                     "PART/NUM_PARTS, got '%s'" % rangestr)
part = int(range_fields[0])
pieces = int(range_fields[1])
# an out-of-range part would silently select template ids outside the bank
if pieces < 1 or not 0 <= part < pieces:
    raise ValueError("--template-fraction-range needs NUM_PARTS >= 1 and "
                     "0 <= PART < NUM_PARTS, got '%s'" % rangestr)
tmin = int(num_templates / float(pieces) * part)
tmax = int(num_templates / float(pieces) * (part + 1))
trange = range(tmin, tmax)

# initialize result storage, one entry per template in trange
tids = []
counts_above = []
fits = []
tpars = []

for tnum in trange:
    # select this template's triggers once and reuse the mask below
    in_template = tid == tnum
    stat_in_template = stat[in_template]
    count_above = len(stat_in_template)
    if count_above == 0:
        # 'stupid' value to indicate no data, shouldn't hurt if 1/alpha is averaged
        alpha = -100.
    else:
        # the second returned value (named sig_alpha in the original code)
        # is not stored
        alpha, _sig_alpha = trstats.fit_above_thresh(
                      args.fit_function, stat_in_template, args.stat_threshold)
    tids.append(tnum)
    counts_above.append(count_above)
    fits.append(alpha)
    if args.save_trig_param:
        if count_above == 0:
            # no trigger to read the parameter from: store NaN as a missing
            # value instead of failing on an empty array
            tpars.append(np.nan)
        else:
            # save the param value of the first trig in this template
            tpars.append(tparam[in_template][0])
    if tnum % 100 == 0:
        logging.info('Fitted template %i / %i' % (tnum - tmin, tmax - tmin))

# Write the results; the context manager closes the output file even if one
# of the writes below raises
with h5py.File(args.output, 'w') as outfile:
    # store template-dependent fit output
    outfile.create_dataset("template_id", data=trange)
    outfile.create_dataset("count_above_thresh", data=counts_above)
    outfile.create_dataset("fit_coeff", data=fits)
    if args.save_trig_param:
        outfile.create_dataset("template_param", data=tpars)
    # add some metadata
    outfile.attrs.create("ifo", data=args.ifo)
    outfile.attrs.create("fit_function", data=args.fit_function)
    outfile.attrs.create("sngl_stat", data=args.sngl_stat)
    # test against None, as get_stat does, so that a factor of 0 which was
    # applied to the statistic is also recorded
    if args.stat_factor is not None:
        outfile.attrs.create("stat_factor", data=args.stat_factor)
    if args.save_trig_param:
        outfile.attrs.create("save_trig_param", data=args.save_trig_param)
    outfile.attrs.create("stat_threshold", data=args.stat_threshold)

logging.info('Done!')
