#! /usr/bin/env python

# Copyright (C) 2016 Miriam Cabero Mueller, Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

import argparse
import itertools
import logging
import numpy
import matplotlib; matplotlib.use("agg")
import pycbc
import pycbc.version
import sys
from matplotlib import patches
from matplotlib import pyplot
from pycbc.inference import option_utils, likelihood
from pycbc.io.inference_hdf import InferenceFile
from pycbc.results import metadata
from pycbc.results.scatter_histograms import create_multidim_plot

# build the command-line parser, starting with the options that are specific
# to this script
parser = argparse.ArgumentParser()
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg,
                    help="Prints version information.")
parser.add_argument("--output-file", type=str, required=True,
                    help="Output plot path.")
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Be verbose")
# adjacent string literals are concatenated verbatim, so the trailing space
# after "one" is required to keep the words apart in the rendered help text
parser.add_argument("--input-file-labels", nargs="+", default=None,
                    help="Labels to add to plot if using more than one "
                         "input file.")

# add options for what plots to create
option_utils.add_plot_posterior_option_group(parser)

# scatter configuration
option_utils.add_scatter_option_group(parser)

# density configuration
option_utils.add_density_option_group(parser)

# add standard option utils
option_utils.add_inference_results_option_group(parser)

# parse command line
opts = parser.parse_args()

# configure logging according to the --verbose flag
pycbc.init_logging(opts.verbose)

# read the file handle(s), parameter names, labels and samples selected on
# the command line
logging.info("Loading parameters")
fp, parameters, labels, samples = option_utils.results_from_cli(opts)

# the rest of the script iterates over input files, so wrap each result in a
# list whenever it is not already one list entry per file
if not isinstance(fp, list):
    fp = [fp]
if not isinstance(parameters[0], list):
    parameters = [parameters]
if not isinstance(labels[0], list):
    labels = [labels]
if not isinstance(samples, list):
    samples = [samples]

# collect the z-axis (color) values, one entry per input file, and close
# every input file once it is no longer needed
if opts.z_arg is None:
    # no z-axis requested: nothing to read, just release the files
    zvals = None
    zlbl = None
    for infile in fp:
        infile.close()
else:
    logging.info("Getting likelihood stats")

    # per-file z-axis values and their colorbar labels
    zvals, zlbl = [], []

    for infile in fp:
        stats = infile.read_likelihood_stats(
            thin_start=opts.thin_start, thin_end=opts.thin_end,
            thin_interval=opts.thin_interval, iteration=opts.iteration)
        file_zvals, file_zlbl = option_utils.get_zvalues(infile, opts.z_arg,
                                                         stats)
        zvals.append(file_zvals)
        zlbl.append(file_zlbl)
        infile.close()

# a colorbar is only shown when a (non-empty) z-axis argument was given
show_colorbar = bool(opts.z_arg)

# when the user selected none of the plot types, choose defaults from the
# number of parameters of the first input file
plot_options = [opts.plot_marginal, opts.plot_scatter, opts.plot_density]
if not numpy.any(plot_options):
    single_parameter = len(parameters[0]) == 1
    # FIXME: right now if there are two parameters it wants
    # both plot_scatter and plot_marginal. One should have the option
    # of giving only plot_scatter and that should be the default for
    # two or more parameters
    # the marginal plot is therefore enabled in every default case
    opts.plot_marginal = True
    if not single_parameter:
        opts.plot_scatter = True

# plot ranges explicitly requested on the command line
mins, maxs = option_utils.plot_ranges_from_cli(opts)

# for any parameter without a requested bound, fall back to the extreme
# sample value taken across all of the input files
for param in parameters[0]:
    if param not in mins:
        lowest_per_file = [file_samples[param].min()
                           for file_samples in samples]
        mins[param] = numpy.array(lowest_per_file).min()
    if param not in maxs:
        highest_per_file = [file_samples[param].max()
                            for file_samples in samples]
        maxs[param] = numpy.array(highest_per_file).max()

# expected values to mark on the plots, keyed by parameter name
expected_parameters = {}

# optionally seed the expected values from the injections
if opts.plot_injection_parameters:
    multiple_injections_msg = (
        "More than one injection found! To use "
        "plot-injection-parameters, there must be a single unique "
        "injection in all input files. Use the expected-parameters "
        "option to specify an expected parameter instead.")
    injections = option_utils.injections_from_cli(opts)
    for param in parameters[0]:
        # an injection value is only usable when it is the same everywhere
        distinct = numpy.unique(injections[param])
        if distinct.size != 1:
            raise ValueError(multiple_injections_msg)
        expected_parameters[param] = distinct[0]

# values given explicitly on the command line override the injected ones
cli_expected = option_utils.expected_parameters_from_cli(opts)
expected_parameters.update(cli_expected)

# cycle through a fixed palette so that each input file gets its own color
colors = itertools.cycle(["black"] + ["C{}".format(i) for i in range(10)])

# plot each input file
logging.info("Plotting")

# histogram color used for each input file, in plotting order
hist_colors = []

# the first call to create_multidim_plot receives None for both of these;
# each later call is handed the figure and axes returned by the previous
# call so that every input file is passed the same figure
fig = None
axis_dict = None

# with a single input file a fixed gray/black/navy scheme is used; with
# several, everything drawn for a file takes that file's contour color
single_input = len(opts.input_file) == 1

for i, (p, l, s) in enumerate(zip(parameters, labels, samples)):

    # use the user-supplied contour color if given, otherwise take the next
    # color from the palette; the builtin next() is used because iterators
    # have no .next() method on Python 3
    contour_color = opts.contour_color or next(colors)

    # make histograms filled if only one input file to plot
    fill_color = "gray" if single_input else None

    # make histogram black lines if only one
    hist_color = "black" if single_input else contour_color
    hist_colors.append(hist_color)

    # color passed to create_multidim_plot as line_color: navy for a single
    # input file, otherwise the file's contour color
    linecolor = "navy" if single_input else contour_color

    # plot
    fig, axis_dict = create_multidim_plot(
                    p, s, labels=l, fig=fig, axis_dict=axis_dict,
                    plot_marginal=opts.plot_marginal,
                    marginal_percentiles=opts.marginal_percentiles,
                    plot_scatter=opts.plot_scatter,
                    zvals=zvals[i] if zvals is not None else None,
                    show_colorbar=show_colorbar,
                    cbar_label=zlbl[i] if zlbl is not None else None,
                    vmin=opts.vmin, vmax=opts.vmax,
                    scatter_cmap=opts.scatter_cmap,
                    plot_density=opts.plot_density,
                    plot_contours=opts.plot_contours,
                    contour_percentiles=opts.contour_percentiles,
                    density_cmap=opts.density_cmap,
                    contour_color=contour_color,
                    hist_color=hist_color,
                    line_color=linecolor,
                    fill_color=fill_color,
                    use_kombine=opts.use_kombine_kde,
                    mins=mins, maxs=maxs,
                    expected_parameters=expected_parameters,
                    expected_parameters_color=opts.expected_parameters_color)

# when labels were supplied for the input files, add a legend in the upper
# right with one colored patch per (histogram color, label) pair
if opts.input_file_labels:
    handles = [patches.Patch(color=patch_color, label=patch_label)
               for patch_color, patch_label
               in zip(hist_colors, opts.input_file_labels)]
    fig.legend(loc="upper right", handles=handles,
               labels=opts.input_file_labels)

# figure resolution and layout
fig.set_dpi(200)
fig.set_tight_layout(True)

# write the figure, passing along the command line that was used
command_line = " ".join(sys.argv)
metadata.save_fig_with_metadata(
    fig, opts.output_file, {},
    cmd=command_line,
    title="Posteriors",
    caption="Posterior probability density functions.")

# finish
logging.info("Done")
