#! /usr/bin/env python
""" Plots the recovered versus injected parameter values from a population
of injections.
"""

import argparse
import logging
import matplotlib as mpl; mpl.use("Agg")
import matplotlib.colorbar as cbar
import matplotlib.pyplot as plt
import numpy
import pycbc
import pycbc.version
from matplotlib import cm
from pycbc import inject
from pycbc import transforms
from pycbc.inference import option_utils

# build the command line interface and parse the arguments
parser = argparse.ArgumentParser(
    description=__doc__,
    usage="{} [--options]".format(__file__))
parser.add_argument(
    "--version",
    action="version",
    version=pycbc.version.git_verbose_msg,
    help="Prints version information.")
parser.add_argument(
    "--output-file",
    type=str,
    required=True,
    help="Path to save output plot.")
parser.add_argument(
    "--verbose",
    action="store_true",
    help="Allows print statements.")
# two fractions in [0, 1]; they are converted to percentiles when plotting
parser.add_argument(
    "--quantiles",
    type=float,
    nargs=2,
    default=[0.05, 0.95],
    help="Quantiles to use as limits.")
parser.add_argument(
    "--injection-hdf-group",
    default="H1/injections",
    help="HDF group that contains injection values.")
# option groups supplied by pycbc.inference.option_utils
option_utils.add_inference_results_option_group(parser)
option_utils.add_scatter_option_group(parser)
opts = parser.parse_args()

# set logging
pycbc.init_logging(opts.verbose)

# read the results selected on the command line
fp, parameters, labels, samples = option_utils.results_from_cli(opts)

# read injections from input files
inj_parameters = option_utils.injections_from_cli(opts)

# only plot one parameter; report a violation as a command line error
# instead of an assert, which is stripped when Python runs with -O
num_parameters = len(opts.parameters)
if num_parameters != 1:
    parser.error("exactly one parameter can be plotted, got "
                 "{}".format(num_parameters))

# NOTE(review): the list branch assumes a nested (per-file) list of names;
# confirm what results_from_cli returns when a single input file is given
parameter = parameters[0][0] if isinstance(parameters, list) else parameters
label = labels[0][0] if isinstance(labels, list) else labels

# create a figure holding a single set of axes
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# wrap single results in a list so they can be iterated over uniformly
if not isinstance(samples, list):
    samples = [samples]
if not isinstance(fp, list):
    fp = [fp]

# if user wants a colorbar
if opts.z_arg:

    # one median z-value per input file, plus the z-axis label
    zvals = []
    zlabel = None

    # smallest and largest z-value seen in each input file
    file_mins = []
    file_maxs = []

    # loop over input files
    logging.info("Reading %s values", opts.z_arg)
    for input_fp in fp:

        # get z-axis values and label
        likelihood_stats = input_fp.read_likelihood_stats(
             thin_start=opts.thin_start, thin_end=opts.thin_end,
             thin_interval=opts.thin_interval, iteration=opts.iteration)
        vals, zlabel = option_utils.get_zvalues(input_fp, opts.z_arg,
                                                likelihood_stats)
        zvals.append(numpy.median(vals))

        # remember this file's extremes for the range of the colorbar
        file_mins.append(vals.min())
        file_maxs.append(vals.max())

    # create colormap; an explicit limit of 0 is a valid user choice, so
    # compare against None rather than relying on truthiness
    # (assumes an unset --vmin/--vmax is None -- TODO confirm in option_utils)
    cmap = cm.get_cmap(opts.scatter_cmap)
    vmin = opts.vmin if opts.vmin is not None else min(file_mins)
    vmax = opts.vmax if opts.vmax is not None else max(file_maxs)
    norm = mpl.colors.Normalize(vmin, vmax)

# draw one error bar per input file
logging.info("Plotting")
for idx, (_, _, file_samples) in enumerate(zip(opts.input_file, fp, samples)):

    # recovered samples and the injected value for this file
    recovered = file_samples[parameter]
    injected = inj_parameters[parameter][idx]

    # requested quantiles of the recovered samples, as percentiles
    bounds = numpy.percentile(recovered, [100 * q for q in opts.quantiles])

    # median plus the lowest and highest quantile for plotting
    median = numpy.median(recovered)
    lower = bounds.min()
    upper = bounds.max()

    # color by z-value when a colorbar was requested
    bar_color = cmap(norm(zvals[idx])) if opts.z_arg else "black"

    # residual of the median against the injected value, with the
    # quantile range as an asymmetric error bar
    ax.errorbar([injected],
                [median - injected],
                yerr=[[median - lower], [upper - median]],
                ecolor=bar_color, linestyle="None", zorder=10)

# create a colorbar
if opts.z_arg:
    cax, _ = cbar.make_axes(ax)
    cb2 = cbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    cb2.set_label(r"Recovered Median " + zlabel)

# set labels; the y-axis shows the recovered-minus-injected residual
ax.set_ylabel(r"Recovered " + label + r" - Injected " + label)
ax.set_xlabel(r"Injected " + label)

# add grid to plot
ax.grid()

# add a zero line, where the recovered median equals the injected value
ax.axhline(0, linestyle="dashed", color="gray", zorder=9)

# save plot
plt.savefig(opts.output_file)

# done
logging.info("Done")

