#!/usr/bin/env python
import h5py, argparse, logging, numpy, numpy.random
from pycbc import events, detector
from pycbc.events import veto, coinc, stat
import pycbc.version

# Command line interface. Options declared with nargs='*' and action='append'
# may be given several times with several values each; argparse then yields a
# list of lists which is flattened further below.
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="count")
parser.add_argument("--version", action="version", version=pycbc.version.git_verbose_msg)
parser.add_argument("--veto-files", nargs='*', action='append', default=[],
                    help="Optional veto file. Triggers within veto segments "
                    "contained in the file are ignored")
parser.add_argument("--segment-name", nargs='*', action='append', default=[],
                    help="Optional, name of veto segment in veto file")
parser.add_argument("--trigger-files", nargs='*', action='append', default=[],
                    help="File containing the single-detector triggers")
parser.add_argument("--template-bank", required=True,
                    help="Template bank file in HDF format")
parser.add_argument("--pivot-ifo", required=True,
                    help="Add the ifo to use as the pivot for multi "
                         "detector coincidence")
parser.add_argument("--fixed-ifo", required=True,
                    help="Add the ifo to use as the fixed ifo for "
                         "multi detector coincidence")
# produces a list of lists to allow multiple invocations and multiple args
parser.add_argument("--statistic-files", nargs='*', action='append', default=[],
                    help="Files containing ranking statistic info")
parser.add_argument("--ranking-statistic", choices=stat.statistic_dict.keys(),
                    help="The ranking statistic to use", default='newsnr')
parser.add_argument("--use-maxalpha", action="store_true")
parser.add_argument("--coinc-threshold", type=float, default=0.0,
                    help="Seconds to add to time-of-flight coincidence window")
parser.add_argument("--timeslide-interval", type=float,
                    help="Interval between timeslides in seconds. Timeslides are "
                         "disabled if the option is omitted.")
parser.add_argument("--decimation-factor", type=int,
                    help="The factor to reduce the background trigger rate.")
# The two ways of selecting the loudest background triggers cannot be combined
loudest = parser.add_mutually_exclusive_group()
loudest.add_argument("--loudest-keep", type=int,
                     help="Keep this number of loudest triggers from each template.")
loudest.add_argument("--loudest-keep-value", type=float,
                     help="Keep all coincident triggers above this value.")
parser.add_argument("--template-fraction-range", default="0/1",
                    help="Optional, analyze only part of template bank. Format"
                    " PART/NUM_PARTS")
parser.add_argument("--cluster-window", type=float,
                    help="Optional, window size in seconds to cluster "
                    "coincidences over the bank")
# The output file is always opened for writing at the end of the script, so
# require it up front instead of failing only after the analysis has run.
parser.add_argument("--output-file", required=True,
                    help="File to store the coincident triggers")
args = parser.parse_args()

# flatten the list of lists of filenames to a single list (may be empty)
args.statistic_files = sum(args.statistic_files, [])
args.segment_name = sum(args.segment_name, [])
args.veto_files = sum(args.veto_files, [])
args.trigger_files = sum(args.trigger_files, [])

# Logging is only configured when --verbose is given at least once; without
# it the logging.info() calls below are not shown.
if args.verbose:
    logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)


def parse_template_range(num_templates, rangestr):
    """ Convert a "PART/NUM_PARTS" string into a range of template indices

    Parameters
    ----------
    num_templates: int
        Total number of templates in the bank
    rangestr: str
        String of the form "PART/NUM_PARTS", e.g. "0/1" for the whole bank

    Returns
    -------
    tmin: int
        Index of the first template of the requested part
    tmax: int
        One past the index of the last template of the requested part
    """
    fields = rangestr.split('/')
    part = int(fields[0])
    pieces = int(fields[1])
    # Use exact integer arithmetic: the floating point form
    # num_templates / float(pieces) * part can land just below an exact
    # boundary (e.g. 1 / 49.0 * 49 == 0.9999999999999999) and int() would
    # then silently drop a template from the range.
    tmin = num_templates * part // pieces
    tmax = num_templates * (part + 1) // pieces
    return tmin, tmax

class ReadByTemplate(object):
    """ Read single-detector triggers from an HDF file one template at a time

    The first top-level key of the trigger file is used as the ifo name.
    The valid times are taken from the '<ifo>/search/start_time' and
    '<ifo>/search/end_time' datasets, minus any requested veto segments.
    After calling set_template, columns of the selected template can be read
    with the item access syntax, e.g. reader['end_time'].
    """
    # No template is selected until set_template has been called. Keeping the
    # default on the class lets __getitem__ report this state cleanly.
    template_num = None

    def __init__(self, filename, bank=None, segment_name=None, veto_files=None):
        """ Open the trigger file and determine the valid segments

        Parameters
        ----------
        filename: str
            Path of the HDF file holding the single-detector triggers
        bank: str, optional
            Path of the HDF template bank; if given, the template parameters
            are loaded whenever a template is selected
        segment_name: list of str, optional
            Names of the veto segments, paired in order with veto_files
        veto_files: list of str, optional
            Files containing the veto segments
        """
        # Avoid mutable default arguments
        segment_name = [] if segment_name is None else segment_name
        veto_files = [] if veto_files is None else veto_files

        self.filename = filename
        self.file = h5py.File(filename, 'r')
        # keys() is a view on Python 3 and cannot be indexed directly
        self.ifo = list(self.file.keys())[0]
        self.valid = None
        self.keep = None
        self.bank = h5py.File(bank, 'r') if bank else None

        # Determine the segments which define the boundaries of valid times
        # to use triggers
        key = '%s/search/' % self.ifo
        s, e = self.file[key + 'start_time'][:], self.file[key + 'end_time'][:]
        self.segs = veto.start_end_to_segments(s, e).coalesce()
        # NOTE: zip pairs the files and names in order and stops at the
        # shorter of the two lists
        for vfile, name in zip(veto_files, segment_name):
            veto_segs = veto.select_segments_by_definer(vfile, ifo=self.ifo,
                                                        segment_name=name)
            self.segs = (self.segs - veto_segs).coalesce()
        self.valid = veto.segments_to_start_end(self.segs)

    def get_data(self, col, num):
        """ Get a column of data for template with id 'num'

        Parameters
        ----------
        col: str
            Name of column to read
        num: int
            The template id to read triggers for

        Returns
        -------
        data: numpy.ndarray
            The requested column of data
        """
        # '<ifo>/<col>_template' holds, per template, the selection used to
        # index into the full '<ifo>/<col>' column
        ref = self.file['%s/%s_template' % (self.ifo, col)][num]
        return self.file['%s/%s' % (self.ifo, col)][ref]

    def set_template(self, num):
        """ Set the active template to read from

        Parameters
        ----------
        num: int
            The template id to read triggers for

        Returns
        -------
        trigger_id: numpy.ndarray
            The indices of this templates triggers
        """
        self.template_num = num
        times = self.get_data('end_time', num)

        # Determine which of these template's triggers are kept after
        # applying vetoes
        if self.valid is not None:
            self.keep = veto.indices_within_times(times, self.valid[0], self.valid[1])
            logging.info('applying vetoes')
        else:
            self.keep = numpy.arange(0, len(times))

        if self.bank is not None:
            # Read every bank parameter of the selected template; use the
            # list in the 'parameters' attribute if the bank provides one,
            # otherwise every top-level entry of the bank file
            if 'parameters' in self.bank.attrs:
                columns = self.bank.attrs['parameters']
            else:
                columns = self.bank
            self.param = {}
            for col in columns:
                self.param[col] = self.bank[col][self.template_num]

        # Calculate the trigger id by adding the relative offset in self.keep
        # to the absolute beginning index of this templates triggers stored
        # in 'template_boundaries'
        trigger_id = self.keep + self.file['%s/template_boundaries' % self.ifo][num]
        return trigger_id

    def __getitem__(self, col):
        """ Return the column of data for current active template after
        applying vetoes

        Parameters
        ----------
        col: str
            Name of column to read

        Returns
        -------
        data: numpy.ndarray
            The requested column of data

        Raises
        ------
        ValueError
            If no template has been selected with set_template yet
        """
        if self.template_num is None:
            raise ValueError('You must call set_template to first pick the '
                             'template to read data from')
        data = self.get_data(col, self.template_num)
        if self.valid is not None:
            data = data[self.keep]
        return data

logging.info('Starting...')

# Only the number of templates is needed here, so close the bank file again
# as soon as it has been read.
with h5py.File(args.template_bank, "r") as bank_file:
    num_templates = len(bank_file['template_hash'])
tmin, tmax = parse_template_range(num_templates, args.template_fraction_range)
logging.info('Analyzing template %s - %s' % (tmin, tmax-1))

# Create dictionary of trigger readers, keyed by the ifo found in each file
trigs = {}
for i, trigger_file in enumerate(args.trigger_files):
    logging.info('Opening trigger file %s: %s' % (i, trigger_file))
    reader = ReadByTemplate(trigger_file,
                            args.template_bank,
                            args.segment_name,
                            args.veto_files)
    trigs[reader.ifo] = reader

# The pivot and fixed ifos are looked up by name later on, so report a clear
# usage error instead of a bare KeyError if no trigger file provides them.
for option, ifo in (('--pivot-ifo', args.pivot_ifo),
                    ('--fixed-ifo', args.fixed_ifo)):
    if ifo not in trigs:
        parser.error("%s %s does not match any of the trigger files given "
                     "(found ifos: %s)" % (option, ifo,
                                           ', '.join(sorted(trigs))))

# The coincident time is the intersection of the valid segments of all ifos
coinc_segs = trigs[args.pivot_ifo].segs
for ifo in trigs:
    coinc_segs = (coinc_segs & trigs[ifo].segs).coalesce()

# Restrict every reader to the common coincident segments
for ifo in trigs:
    trigs[ifo].segs = coinc_segs
    trigs[ifo].valid = veto.segments_to_start_end(trigs[ifo].segs)

# initialize a Stat class instance to calculate the coinc ranking statistic
rank_method = stat.get_statistic(args.ranking_statistic)(args.statistic_files)

detectors = {}
for ifo in trigs:
    detectors[ifo] = detector.Detector(trigs[ifo].ifo)

# Sanity check, time slide interval should be larger than twice the
# Earth crossing time, which is approximately 0.085 seconds.
# The None test has to come first: comparing a float with None raises a
# TypeError on Python 3. parser.error() exits by itself, no raise is needed.
if args.timeslide_interval is not None and args.timeslide_interval <= 0.085:
    parser.error("The time slide interval should be larger "
                 "than twice the Earth crossing time.")

# slide = 0 means don't do timeslides
if args.timeslide_interval is None:
    args.timeslide_interval = 0

# Per-template results are collected as lists of arrays and concatenated once
# all templates have been processed. The '<ifo>/time' and '<ifo>/trigger_id'
# entries are added the first time a template produces results.
data = {'stat':[], 'decimation_factor':[], 'timeslide_id':[], 'template_id':[]}

for tnum in range(tmin, tmax):
    tids = {}
    for ifo in trigs:
        tids[ifo] = trigs[ifo].set_template(tnum)

    # Skip this template entirely if any ifo has no triggers for it
    # (presumably no multi-detector coincidence can be formed in that case)
    if any(len(tids[ifo]) == 0 for ifo in tids):
        logging.info('No triggers in at least one ifo for template %s, '
                     'skipping' % tnum)
        continue

    times = {}
    for ifo in trigs:
        times[ifo] = trigs[ifo]['end_time']
    logging.info('Trigs for template %s, '% (tnum))
    for ifo in times:
        logging.info('%s:%s' % (ifo, len(times[ifo])))

    ids, slide = coinc.time_multi_coincidence(times, args.timeslide_interval,
                                              args.coinc_threshold,
                                              args.pivot_ifo,
                                              args.fixed_ifo)

    logging.info('Coincident Trigs: %s' % (len(ids[args.pivot_ifo])))

    logging.info('Calculating Single Detector Statistic')
    sds = {}
    for ifo in trigs:
        sds[ifo] = rank_method.single(trigs[ifo])

    logging.info('Calculating Multi-Detector Combined Statistic')

    sdswithids = {}
    for ifo in sds:
        sdswithids[ifo] = sds[ifo][ids[ifo]]
    c = rank_method.coinc_multiifo(sdswithids, slide, args.timeslide_interval)

    #index values of the zerolag triggers
    fi = numpy.where(slide == 0)[0]

    #index values of the background triggers
    bi = numpy.where(slide != 0)[0]
    logging.info('%s foreground triggers' % len(fi))
    logging.info('%s background triggers' % len(bi))

    # We split the background triggers into two types which we keep track of
    # in "bh" (triggers which are *not* decimated, stored in full) and
    # "bl" (triggers which may not be stored in full, but we keep track of
    # how many are removed)
    # "bl_int" keeps track of which triggers are *not* in the bh set. Depending
    # on the decimation factor option we may not store any of these or we may
    # keep a fraction of them corresponding to a subset of the timeslides
    if args.loudest_keep:
        sep = max(len(bi) - args.loudest_keep, 0)

        # argpartition places the indices of the `sep` smallest statistic
        # values first, the remaining (loudest) ones after them
        bsort = numpy.argpartition(c[bi], sep)
        bl_int = bi[bsort[0:sep]]
        bh = bi[bsort[sep:]]
        del bsort
    elif args.loudest_keep_value is not None:
        # Compare against None so that a threshold of 0.0 is honoured
        above = c[bi] > args.loudest_keep_value
        bh = bi[above]
        bl_int = bi[~above]
    else:
        # Everything is stored in full, nothing is left to decimate
        bh = bi
        bl_int = bi[:0]
    del bi

    if args.decimation_factor:
        bl = bl_int[slide[bl_int] % args.decimation_factor == 0]
    else:
        bl = bl_int[:0]

    ti = numpy.concatenate([bl, bh, fi]).astype(numpy.uint32)
    logging.info('%s after decimation' % len(ti))

    gs = {}
    for ifo in ids:
        gs[ifo] = ids[ifo][ti]
    del ids

    # The time and trigger_id keys do not exist until the first template
    # with results is stored; setdefault creates them on demand.
    for ifo in gs:
        data.setdefault('%s/time' % ifo, []).append(times[ifo][gs[ifo]])
        data.setdefault('%s/trigger_id' % ifo, []).append(tids[ifo][gs[ifo]])
    data['stat'].append(c[ti])
    # Without a decimation factor "bl" is empty, so the placeholder value 1
    # is never actually written; it only keeps the array numeric.
    dec_fac = numpy.repeat([args.decimation_factor or 1, 1, 1],
                           [len(bl), len(bh), len(fi)]).astype(numpy.uint32)
    data['decimation_factor'].append(dec_fac)
    data['timeslide_id'].append(slide[ti])
    data['template_id'].append(numpy.zeros(len(ti), dtype=numpy.uint32) + tnum)

# Merge the per-template arrays into one array per column
if len(data['stat']) > 0:
    for key in data:
        data[key] = numpy.concatenate(data[key])

# Optionally cluster the coincidences over the bank; cid then holds the
# indices of the coincidences to keep, otherwise it stays None.
cid = None
if args.cluster_window and len(data['stat']) > 0:
    timestring0 = '%s/time' % args.pivot_ifo
    timestring1 = '%s/time' % args.fixed_ifo
    cid = coinc.cluster_coincs(data['stat'], data[timestring0], data[timestring1],
                               data['timeslide_id'], args.timeslide_interval,
                               args.cluster_window)

logging.info('saving coincident triggers')
# The context manager makes sure the output file is flushed and closed, even
# if writing one of the datasets or attributes fails.
with h5py.File(args.output_file, 'w') as f:
    if len(data['stat']) > 0:
        for key in data:
            var = data[key] if cid is None else data[key][cid]
            f.create_dataset(key, data=var,
                             compression='gzip',
                             compression_opts=9,
                             shuffle=True)

    # Segments common to all ifos, followed by the valid segments per ifo
    coinc_start, coinc_end = veto.segments_to_start_end(coinc_segs)
    f['segments/coinc/start'] = coinc_start
    f['segments/coinc/end'] = coinc_end

    for t in trigs:
        ifo_start, ifo_end = trigs[t].valid
        f['segments/%s/start' % trigs[t].ifo] = ifo_start
        f['segments/%s/end' % trigs[t].ifo] = ifo_end

    f.attrs['timeslide_interval'] = args.timeslide_interval
    f.attrs['num_of_ifos'] = len(args.trigger_files)
    f.attrs['pivot'] = args.pivot_ifo
    f.attrs['fixed'] = args.fixed_ifo

    for ifo in detectors:
        f.attrs['%s_foreground_time' % ifo] = abs(trigs[ifo].segs)
    f.attrs['coinc_time'] = abs(coinc_segs)

    # The number of slides is the longest analysed time of any ifo divided
    # by the slide interval; an interval of 0 means timeslides are disabled.
    if args.timeslide_interval:
        maxtrigs = max(abs(trigs[ifo].segs) for ifo in trigs)
        nslides = int(maxtrigs / args.timeslide_interval)
    else:
        nslides = 0

    f.attrs['num_slides'] = nslides

logging.info('Done')