#!/usr/bin/env python
""" Plot found and missed injections.
"""
import h5py, numpy, logging, os.path, argparse, sys, matplotlib
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plot
import pycbc.results.followup, pycbc.pnutils, pycbc.results, pycbc.version
import pycbc.pnutils

# Command-line interface. The module docstring doubles as the program
# description shown by --help.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--injection-file',
                    help="The hdf injection file to plot", required=True)
parser.add_argument('--axis-type', default='mchirp', choices=['mchirp',
                    'effective_spin', 'time', 'mtotal', 'mass_ratio'],
                    help="Injection parameter to plot on the x axis. "
                         "Default = 'mchirp'")
parser.add_argument('--log-x', action='store_true', default=False,
                    help="Use a logarithmic scale for the x axis")
parser.add_argument('--distance-type', default='decisive_distance',
                    choices=['decisive_distance', 'dec_chirp_distance',
                             'chirp_distance', 'comb_optimal_snr',
                             'decisive_optimal_snr', 'redshift'],
                    help="Distance, optimal SNR or redshift measure to plot "
                         "on the y axis. Default = 'decisive_distance'")
parser.add_argument('--plot-all-distance', action='store_true', default=False,
                    help="Plot all values of distance or SNR. If not given, "
                         "the plot will be truncated below 1 (below 0 for "
                         "redshift)")
# 'count' leaves the value as None when the flag is absent, so any number of
# --verbose flags simply switches informational logging on.
parser.add_argument('--verbose', action='count',
                    help="Print informational log messages")
parser.add_argument('--log-distance', action='store_true', default=False,
                    help="Use a logarithmic scale for the y axis")
# NOTE(review): this flag is parsed but never read anywhere in this script.
parser.add_argument('--dynamic', action='store_true', default=False,
                    help="Currently has no effect in this script")
parser.add_argument('--gradient-far', action='store_true',
                    help="Show far of found injections as a gradient")
parser.add_argument('--output-file', required=True,
                    help="Location to save the plot. '.png' and '.html' "
                         "outputs receive format-specific handling")
parser.add_argument('--far-type', choices=('inclusive', 'exclusive'), default='inclusive',
                    help="Type of far to plot for the color. Choices are 'inclusive' or "
                         "'exclusive'. Default = 'inclusive'")
parser.add_argument('--missed-on-top', action='store_true',
                    help="Plot missed injections on top of found ones")
parser.add_argument('--version', action='version', version=pycbc.version.git_verbose_msg)
args = parser.parse_args()

# Logging keeps its default configuration unless --verbose was given at
# least once; in that case INFO-level messages are shown with a timestamp.
if args.verbose:
    logging.basicConfig(format='%(asctime)s : %(message)s',
                        level=logging.INFO)

logging.info('Read in the data')
# The file handle is left open here because later statements in this script
# still read optional datasets from it.
f = h5py.File(args.injection_file, 'r')
injections = f['injections']
found_after_vetoes = f['found_after_vetoes']

time = injections['end_time'][:]
# `found` and `missed` are used further down to index the per-injection arrays
found = found_after_vetoes['injection_index'][:]
missed = f['missed/after_vetoes'][:]

# Dataset holding the IFAR of the found injections, and the matching word
# used in the colorbar label, for each --far-type choice.
far_lookup = {
    'inclusive': ('ifar', 'Inclusive'),
    'exclusive': ('ifar_exc', 'Exclusive'),
}
ifar_dataset, far_title = far_lookup[args.far_type]
ifar_found = found_after_vetoes[ifar_dataset][:]

# Only the eff_dist_h / eff_dist_l datasets are read, so this hardcodes an
# HL search. TODO: replace with a function that takes or/sky/det
# (presumably orientation / sky location / detector -- confirm).
hdist = injections['eff_dist_h'][:]
ldist = injections['eff_dist_l'][:]
s1z = injections['spin1z'][:]
s2z = injections['spin2z'][:]
dist = injections['distance'][:]
m1 = injections['mass1'][:]
m2 = injections['mass2'][:]

# Candidate x-axis quantities, keyed by the --axis-type choices.
vals = {}
# The second return value (eta) is not needed by this script.
vals['mchirp'], _ = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
# Mass-weighted aligned spin, computed inline.
# TODO: use the corresponding pnutils formula instead.
vals['effective_spin'] = (m1 * s1z + m2 * s2z) / (m1 + m2)
vals['time'] = time
vals['mtotal'] = m1 + m2
# use convention that q>1
vals['mass_ratio'] = numpy.maximum(m1/m2, m2/m1)

# Candidate y-axis quantities, keyed by the --distance-type choices.
dvals = {}
# Element-wise larger of the two effective distances
dvals['decisive_distance'] = numpy.maximum(hdist, ldist)
dvals['dec_chirp_distance'] = pycbc.pnutils.chirp_distance(dvals['decisive_distance'], vals['mchirp'])
dvals['chirp_distance'] = pycbc.pnutils.chirp_distance(dist, vals['mchirp'])
# The redshift dataset is only read when it is actually requested
if args.distance_type == 'redshift':
    dvals['redshift'] = f['injections/redshift'][:]

# Optimal SNRs are optional in the file; the SNR-based choices only exist
# when they are present.
if 'optimal_snr_1' in f['injections']:
    opt_snr_1 = f['injections/optimal_snr_1'][:]
    opt_snr_2 = f['injections/optimal_snr_2'][:]
    dvals['comb_optimal_snr'] = numpy.sqrt(opt_snr_1 ** 2 + opt_snr_2 ** 2)
    # Element-wise smaller of the two optimal SNRs
    dvals['decisive_optimal_snr'] = numpy.where(opt_snr_1 < opt_snr_2, opt_snr_1, opt_snr_2)

# Every dataset has been copied into memory with [:], and nothing below reads
# from the file again, so release the handle instead of leaking it.
f.close()

# The only way a valid --distance-type can be absent here is an SNR-based
# choice on a file without optimal SNRs; fail with an explanation rather
# than a bare KeyError.
if args.distance_type not in dvals:
    raise ValueError("Cannot plot --distance-type %s: the injection file %s "
                     "has no 'injections/optimal_snr_1' dataset"
                     % (args.distance_type, args.injection_file))

# Split the chosen y-axis quantity into found and missed injections
fdvals = dvals[args.distance_type][found]
mdvals = dvals[args.distance_type][missed]

# Axis labels, looked up by the --axis-type (x) and --distance-type (y)
# choices. They are grouped by axis here and merged into a single table.
x_axis_labels = {
    'mchirp': 'Chirp Mass',
    'mtotal': 'Total Mass',
    'mass_ratio': 'Mass Ratio',
    'time': 'Time (s)',
    'effective_spin': 'Weighted Aligned Spin',
}
y_axis_labels = {
    'decisive_distance': 'Injected Decisive Distance (Mpc)',
    'dec_chirp_distance': 'Injected Decisive Chirp Distance (Mpc)',
    'chirp_distance': 'Injected Chirp Distance (Mpc)',
    'comb_optimal_snr': 'Combined Optimal SNR',
    'decisive_optimal_snr': 'Decisive Optimal SNR',
    'redshift': 'Redshift',
}
labels = dict(x_axis_labels)
labels.update(y_axis_labels)

# The title names first whichever population is drawn on top
fig_title = ('Missed and Found Injections' if args.missed_on_top
             else 'Found and Missed Injections')

fig = plot.figure()

# These booleans are passed as zorder values to the scatter calls: the
# population given True is drawn above the one given False.
zfound = not args.missed_on_top
zmissed = args.missed_on_top

# Draw the missed and found injections. Found injections are coloured either
# in three discrete IFAR bands (default) or by a continuous false-alarm-rate
# gradient (--gradient-far).
if not args.gradient_far:
    # Colour value per found injection:
    # 0 (IFAR <= 100), 0.5 (100 < IFAR <= 1000), 1 (IFAR > 1000)
    color = numpy.zeros(len(found))
    color[ifar_found > 100] = 0.5
    color[ifar_found > 1000] = 1.0

    mpoints = plot.scatter(vals[args.axis_type][missed], mdvals,
                           marker='x', color='black', label='missed', zorder=zmissed)
    # The caption below names blue / green / red circles, which is what the
    # 'jet' colormap gives at 0, 0.5 and 1. Pin the colormap so the caption
    # stays true whatever the default colormap of the installed Matplotlib is.
    points = plot.scatter(vals[args.axis_type][found], fdvals,
                          c=color, linewidth=0, vmin=0, vmax=1, cmap='jet',
                          marker='o', label='found', zorder=zfound)
    caption = (fig_title + ": Black x's are missed injections. "
              "Blue circles are found with IFAR < 100 years, green are < 1000 years, and "
              "red are found with IFAR >=1000 years. ")
else:
    fvals = vals[args.axis_type][found]
    # Colour by false alarm rate, i.e. the inverse of the IFAR
    color = 1.0 / ifar_found

    # sort so quiet found is on top
    csort = color.argsort()
    fvals = fvals[csort]
    fdval = fdvals[csort]
    color = color[csort]
    # With fewer than two found injections no colour array is passed on
    if len(color) < 2:
        color=None

    mpoints = plot.scatter(vals[args.axis_type][missed], mdvals,
                           marker='x', color='red', label='missed', zorder=zmissed)
    points = plot.scatter(fvals, fdval, c=color, linewidth=0,
                          norm=matplotlib.colors.LogNorm(),
                          marker='o', label='found', zorder=zfound)
    caption = (fig_title + ": Red x's are missed injections. "
               "Circles are found injections. The color indicates the value of "
               "the false alarm rate. " )
    plot.subplots_adjust(right=0.99)
    try:
        c = plot.colorbar()
        c.set_label('False Alarm Rate $(yr^{-1})$, %s' % far_title)
    except (TypeError, ZeroDivisionError):
        # Can't make colorbar if no quiet found injections; any failure while
        # found injections do exist is a genuine error and is re-raised.
        if len(fvals):
            raise

# Finish the caption by stating which population is drawn on top
if args.missed_on_top:
    caption += "Missed injections are shown on top of found injections."
else:
    caption += "Found injections are shown on top of missed injections."

# Label the axes and choose their scales and limits.
ax = plot.gca()
plot.xlabel(labels[args.axis_type])
plot.ylabel(labels[args.distance_type])
plot.grid()

if args.log_x:
    # log x axis may fail for some choices, eg effective spin
    ax.set_xscale('log')
    tmpxvals = list(vals[args.axis_type][missed])
    tmpxvals += list(vals[args.axis_type][found])
    # Pad the data range a little on both sides; with no injections at all
    # there is no range to pad, so the limits are left untouched.
    if tmpxvals:
        xmax = 1.4 * max(tmpxvals)
        xmin = 0.7 * min(tmpxvals)
        plot.xlim(xmin, xmax)
if args.log_distance:
    ax.set_yscale('log')

tmpyvals = list(fdvals) + list(mdvals)
if tmpyvals:
    ymax = 1.2 * max(tmpyvals)
    ymin = 0.9 * min(tmpyvals)
else:
    # No found and no missed injections: a limit of None leaves the
    # corresponding axis bound unchanged instead of max()/min() raising
    # ValueError on an empty sequence.
    ymax = None
    ymin = None
if args.plot_all_distance:
    # note: ymin=0 will clash with args.log_distance
    # in that case it *should* throw an error!
    plot.ylim(ymin, ymax)
elif args.distance_type == 'redshift':
    # default y limit: min redshift 0
    plot.ylim(ymin=0, ymax=ymax)
else:
    # arbitrary limit of 1 distance-unit or snr-unit
    plot.ylim(ymin=1, ymax=ymax)

# Decide on format-specific handling from the real file extension. A plain
# substring test would also match '.png' or '.html' appearing in a directory
# name or in the middle of the file name.
output_ext = os.path.splitext(args.output_file)[1]

fig_kwds = {}
if output_ext == '.png':
    fig_kwds['dpi'] = 200

if output_ext == '.html':
    plot.subplots_adjust(left=0.1, right=0.8, top=0.9, bottom=0.1)
    # mpld3 is imported only when interactive HTML output is requested
    import mpld3, mpld3.plugins, mpld3.utils
    mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fmt='.5g'))
    # Interactive legend covering the missed and found scatter collections
    legend =  mpld3.plugins.InteractiveLegendPlugin([mpoints, points],
                                                    ['missed', 'found'],
                                                    alpha_unsel=0.1)
    mpld3.plugins.connect(fig, legend)

# Save the figure, passing along the title, the command line used and the
# caption built above.
pycbc.results.save_fig_with_metadata\
    (fig, args.output_file, fig_kwds=fig_kwds,
     title='%s: %s vs %s' % (fig_title, args.axis_type, args.distance_type),
     cmd=' '.join(sys.argv), caption=caption)

