#!/bin/env python
# Copyright (C) 2015 Alexander Harvey Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Followup foreground events
"""
import os, sys, argparse, logging, h5py, pycbc.workflow as wf
from pycbc.results import layout
from pycbc.types import MultiDetOptionAction
from pycbc.events import select_segments_by_definer
import pycbc.workflow.minifollowups as mini
import pycbc.workflow.pegasus_workflow as wdax
import pycbc.version

def to_file(path, ifo=None):
    """ Wrap an existing file on disk in a workflow File object.

    Parameters
    ----------
    path : str
        Location of the file. Its basename becomes the name of the File and
        its absolute form is registered through ``PFN(..., 'local')``.
    ifo : str, optional
        Value stored on the ``ifo`` attribute of the returned File.

    Returns
    -------
    File
        The newly created ``wdax.File`` instance.
    """
    handle = wdax.File(os.path.basename(path))
    handle.ifo = ifo
    absolute_path = os.path.abspath(path)
    handle.PFN(absolute_path, 'local')
    return handle

# Command line definition. The description is the module docstring with its
# leading space removed.
parser = argparse.ArgumentParser(description=__doc__[1:])

# General options
parser.add_argument(
    '--version', action='version',
    version=pycbc.version.git_verbose_msg)
parser.add_argument('--workflow-name', default='my_unamed_run')
parser.add_argument(
    '-d', '--output-dir', default=None,
    help='Path to output directory.')

# Input data products
parser.add_argument(
    '--bank-file',
    help='HDF format template bank file')
parser.add_argument(
    '--statmap-file',
    help='HDF format clustered coincident trigger result file')
parser.add_argument(
    '--single-detector-triggers', nargs='+',
    action=MultiDetOptionAction,
    help='HDF format merged single detector trigger files')

# Segment file and the names of the segment lists to read from it
parser.add_argument(
    '--inspiral-segments',
    help='xml segment files containing the inspiral analysis times')
parser.add_argument(
    '--inspiral-data-read-name',
    help='Name of inspiral segmentlist containing data read in '
         'by each analysis job.')
parser.add_argument(
    '--inspiral-data-analyzed-name',
    help='Name of inspiral segmentlist containing data '
         'analyzed by each analysis job.')

# Outputs of this workflow generator
parser.add_argument('--output-map')
parser.add_argument('--output-file')
parser.add_argument('--tags', nargs='+', default=[])

# Options shared by all workflow generators are added by the workflow module
wf.add_workflow_command_line_group(parser)
args = parser.parse_args()

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s:%(levelname)s : %(message)s')

workflow = wf.Workflow(args, args.workflow_name)

wf.makedir(args.output_dir)

# Rows of result files, handed to the results page layout once every event
# has been processed
layouts = []

# Wrap the command line input files so they can be given to workflow jobs
tmpltbank_file = to_file(args.bank_file)
coinc_file = to_file(args.statmap_file)
insp_segs = to_file(args.inspiral_segments)

# Per-detector inputs: the trigger file as a workflow file, the same file
# opened for reading, and the segments of data read in by the analysis.
single_triggers = []
fsdt = {}
insp_data_seglists = {}
for ifo in args.single_detector_triggers:
    fname = args.single_detector_triggers[ifo]
    single_triggers.append(to_file(fname, ifo=ifo))
    fsdt[ifo] = h5py.File(fname, 'r')

    read_segs = select_segments_by_definer(
        args.inspiral_segments,
        segment_name=args.inspiral_data_read_name,
        ifo=ifo)
    insp_data_seglists[ifo] = read_segs
    # NOTE: make_singles_timefreq needs a coalesced set of segments. If this is
    #       being used to determine command-line options for other codes,
    #       please think if that code requires coalesced, or not, segments.
    read_segs.coalesce()
# Number of events to follow up, as set in the workflow configuration
num_events = int(workflow.cp.get_opt_tags('workflow-minifollowups',
                                          'num-events', ''))

# Ranking statistic of every foreground event and the index order that sorts
# the events from largest to smallest statistic
f = h5py.File(args.statmap_file, 'r')
stat = f['foreground/stat'][:]
sorting = stat.argsort()[::-1]

if not len(stat):
    # There are no triggers, make no-op job and exit
    workflow += mini.create_noop_node()
    workflow.save(filename=args.output_file, output_map_path=args.output_map)
    sys.exit(0)

# Never ask for more events than the file contains
num_events = min(num_events, len(stat))

stat = stat[sorting][:num_events]

# Per-detector trigger times and trigger ids of the selected events, keyed by
# the detector names stored in the file attributes
det1 = f.attrs['detector_1']
det2 = f.attrs['detector_2']
foreground = f['foreground']

times = {
    det1: foreground['time1'][:][sorting][:num_events],
    det2: foreground['time2'][:][sorting][:num_events],
}
tids = {
    det1: foreground['trigger_id1'][:][sorting][:num_events],
    det2: foreground['trigger_id2'][:][sorting][:num_events],
}

# Template bank, used to look up the parameters of each event's template
bank_data = h5py.File(args.bank_file, 'r')

# Template id of every foreground event, ordered like `times` and `tids`.
# Read once here instead of re-reading the whole dataset for every event.
template_ids = f['foreground/template_id'][:][sorting]

# loop over number of loudest events to be followed up
for num_event in range(num_events):
    files = wf.FileList([])
    event_tag = str(num_event)

    # Space separated 'IFO:value' pairs, one per detector. These are built by
    # iterating the dictionaries rather than indexing dict.keys(), which is
    # not subscriptable on Python 3.
    ifo_times = ' '.join('%s:%s' % (ifo, times[ifo][num_event])
                         for ifo in times)
    ifo_tids = ' '.join('%s:%s' % (ifo, tids[ifo][num_event])
                        for ifo in tids)

    bank_id = template_ids[num_event]

    layouts += (mini.make_coinc_info(workflow, single_triggers, tmpltbank_file,
                              coinc_file, args.output_dir, n_loudest=num_event,
                              tags=args.tags + [event_tag]),)
    files += mini.make_trigger_timeseries(workflow, single_triggers,
                              ifo_times, args.output_dir, special_tids=ifo_tids,
                              tags=args.tags + [event_tag])

    # Parameters of the single-template followup: per-detector end times plus
    # the parameters of the template that produced the event
    params = {}
    for ifo in times:
        params['%s_end_time' % ifo] = times[ifo][num_event]
        try:
            # Only present for precessing case
            params['u_vals_%s' % ifo] = \
                                 fsdt[ifo][ifo]['u_vals'][tids[ifo][num_event]]
        except KeyError:
            # No u_vals stored for this detector; anything other than a
            # missing key is a real error and is allowed to propagate.
            pass

    params['mass1'] = bank_data['mass1'][bank_id]
    params['mass2'] = bank_data['mass2'][bank_id]
    params['spin1z'] = bank_data['spin1z'][bank_id]
    params['spin2z'] = bank_data['spin2z'][bank_id]
    params['f_lower'] = bank_data['f_lower'][bank_id]
    # don't require precessing template info if not present
    try:
        params['spin1x'] = bank_data['spin1x'][bank_id]
        params['spin1y'] = bank_data['spin1y'][bank_id]
        params['spin2x'] = bank_data['spin2x'][bank_id]
        params['spin2y'] = bank_data['spin2y'][bank_id]
        params['inclination'] = bank_data['inclination'][bank_id]
    except KeyError:
        pass

    files += mini.make_single_template_plots(workflow, insp_segs,
                                    args.inspiral_data_read_name,
                                    args.inspiral_data_analyzed_name,  params,
                                    args.output_dir,
                                    tags=args.tags + [event_tag])

    # Per-detector plots around the trigger time of this event
    for single in single_triggers:
        time = times[single.ifo][num_event]
        files += mini.make_singles_timefreq(workflow, single, tmpltbank_file,
                                time, args.output_dir,
                                data_segments=insp_data_seglists[single.ifo],
                                tags=args.tags + [event_tag])

        files += mini.make_qscan_plot(
            workflow, single.ifo, time, args.output_dir,
            data_segments=insp_data_seglists[single.ifo],
            tags=args.tags + [event_tag])

    # Arrange this event's files two per row
    layouts += list(layout.grouper(files, 2))

workflow.save(filename=args.output_file, output_map_path=args.output_map)
layout.two_column_layout(args.output_dir, layouts)
