#! /usr/bin/env python
""" Plots the fraction of injections with their parameter value recovered
within a credible interval versus credible interval.
"""

import sys
import argparse
import logging
import matplotlib as mpl; mpl.use("Agg")
import matplotlib.colorbar as cbar
import matplotlib.pyplot as plt
import numpy
import pycbc
from scipy import stats
from matplotlib import cm
from pycbc import inject
from pycbc import transforms
from pycbc.results import save_fig_with_metadata
from pycbc.inference import option_utils
from pycbc.io.record import FieldArray

# build the command-line interface; the module docstring doubles as the
# program description shown in the help output
parser = argparse.ArgumentParser(description=__doc__,
                                 usage=__file__ + " [--options]")
parser.add_argument("--output-file", type=str, required=True,
                    help="Path to save output plot.")
parser.add_argument("--verbose", action="store_true",
                    help="Allows print statements.")
parser.add_argument("--injection-hdf-group", default="H1/injections",
                    help="HDF group that contains injection values.")

# attach the option groups provided by pycbc.inference.option_utils
# (inference-results options first, then scatter options)
for add_option_group in (option_utils.add_inference_results_option_group,
                         option_utils.add_scatter_option_group):
    add_option_group(parser)

# parse the arguments
opts = parser.parse_args()

# set logging
pycbc.init_logging(opts.verbose)

# read results; the first returned value is not needed by this script
_, parameters, labels, samples = option_utils.results_from_cli(opts)

# typecast to list for iteration
# ``samples`` not being a list is used as the sign that the results belong to
# a single input file. In that case all three values are wrapped together so
# that they stay aligned as one-entry-per-file lists.
# NOTE(review): ``parameters`` itself is presumably a list of names even for
# a single file, which would be why ``samples`` is the value tested -- confirm
# against option_utils.results_from_cli.
if not isinstance(samples, list):
    parameters, labels, samples = [parameters], [labels], [samples]

# Use the first file's labels for the plot legend. This is done only after
# the wrapping above: indexing first would, for a single input file whose
# labels are a list, keep just the first label and silently drop the rest.
labels = labels[0]
if not isinstance(labels, list):
    labels = [labels]

# go through every input file together with its parameters and samples
logging.info("Plotting")
measured_percentiles = {}
file_results = zip(opts.input_file, parameters, samples)
for input_file, input_parameters, input_samples in file_results:
    # point the options at this file only, then load its injections
    opts.input_file = input_file
    inj_parameters = option_utils.injections_from_cli(opts)

    # for each parameter, find the percentile of the injected value within
    # this file's samples and collect it under the parameter's name
    for p in input_parameters:
        injected = inj_parameters[p]
        posterior = input_samples[p]
        measured = stats.percentileofscore(posterior, injected, kind='weak')
        measured_percentiles.setdefault(p, []).append(measured)

# create figure for plotting
fig = plt.figure()
ax = fig.add_subplot(111)

# calculate the expected percentile for each injection and plot
# The parameter names are read from the first entry of ``parameters`` so that
# they pair up with ``labels``, which is also taken from element 0 of the
# loaded labels. This avoids depending on a loop variable left over from the
# loop over input files (which would hold the last file's parameters).
for param, label in zip(parameters[0], labels):
    # measured percentiles (0-100) for this parameter, in ascending order
    meas = numpy.array(measured_percentiles[param])
    meas.sort()
    # calculate expected: for each measured percentile, the percentage of
    # measured percentiles that are less than or equal to it
    expected = numpy.array([stats.percentileofscore(meas, x, kind='weak')
                            for x in meas])
    # perform ks test of the measured fractions against a uniform distribution
    ks, pvalue = stats.kstest(meas/100., 'uniform')
    # both axes are rescaled from percent to fractions in [0, 1]
    ax.plot(meas/100., expected/100.,
            label='{} $D_{{KS}}$: {:.3f} p-value: {:.3f}'.format(
                label, ks, pvalue))

# legend for the curves plotted so far
ax.legend()

# axis titles
ax.set_xlabel(r"Credible Interval")
ax.set_ylabel(r"Fraction of Injections Recovered in Credible Interval")

# grid lines
ax.grid()

# dashed gray 1:1 reference line from (0, 0) to (1, 1)
unit_interval = [0, 1]
ax.plot(unit_interval, unit_interval, color="gray", linestyle="dashed",
        zorder=9)

# caption stored with the figure; written one sentence per group of literals
caption = (
    'Percentile-percentile plot. '
    'The value of the KS statistic $D_{KS}$ is given in the legend. '
    'This gives the maximum distance between the observed line and the '
    'expected (dashed) line; i.e., it gives the maximum distance between '
    'the measured CDF and the expected (uniform) CDF. '
    'The associated two-tailed p-value gives the probability of getting a '
    'maximum distance (either above or below the expected line) larger '
    'than the observed $D_{KS}$ assuming that the measured CDF is the same '
    'as the expected. '
    'In other words, the larger (smaller) the p-value ($D_{KS}$), the more '
    'likely the measured distribution is the same as the expected.')

# write the figure to the requested path, recording the command line used
save_fig_with_metadata(fig, opts.output_file, cmd=' '.join(sys.argv),
                       caption=caption)

# finished
logging.info("Done")

