#! /usr/bin/env python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import matplotlib as mpl; mpl.use("Agg")
import matplotlib.pyplot as plt
import argparse
import corner
import itertools
import logging
import numpy
import sys
from pycbc import distributions
from pycbc import inference
from pycbc import results
from pycbc.workflow import WorkflowConfigParser

def cartesian(arrays):
    """ Returns the cartesian product of a list of iterables.

    The result is a numpy array with one row per combination, in the order
    produced by itertools.product (the last iterable varies fastest).
    """
    rows = map(numpy.array, itertools.product(*arrays))
    return numpy.array(list(rows))

# command line usage
parser = argparse.ArgumentParser(
    usage="pycbc_inference_plot_prior [--options]",
    description="Plots prior distributions.")

# input options: the configuration files to read and which sections
# (optionally "section-subsection") hold the distribution configurations
parser.add_argument(
    "--config-files", type=str, nargs="+", required=True,
    help="A file parsable by pycbc.workflow.WorkflowConfigParser.")
parser.add_argument(
    "--sections", type=str, nargs="+", default=["prior"],
    help="Name of section plus subsection with distribution configurations, eg. prior-mass1.")

# prior options: number of grid points used for each parameter
parser.add_argument(
    "--bins", type=int, default=20,
    help="Number of points to grid a parameter. Need at least 20 bins to fill in the plot.")

# output options
parser.add_argument(
    "--output-file", type=str, required=True,
    help="Path to output plot.")

# verbose option; the help text was previously empty, which left the flag
# undocumented in the --help output
parser.add_argument(
    "--verbose", action="store_true", default=False,
    help="Print logging messages while running.")

# parse the command line
opts = parser.parse_args()

# setup log: show everything from DEBUG upwards when --verbose is given,
# otherwise only warnings and above
log_level = logging.DEBUG if opts.verbose else logging.WARNING
logging.basicConfig(level=log_level, format="%(asctime)s : %(message)s")

# load all of the configuration files given on the command line
logging.info("Reading configuration files")
cp = WorkflowConfigParser(opts.config_files)

# build one distribution for every requested prior subsection
# each --sections entry is either "section" (use every subsection found in
# the configuration) or "section-subsection" (use only that subsection)
logging.info("Constructing prior")
variable_args = []
dists = []
for sec in opts.sections:
    section, dash, tag = sec.partition("-")
    if dash:
        subsections = [tag]
    else:
        subsections = cp.get_subsections(section)
    for subsection in subsections:
        # the "name" option selects the distribution class to construct
        name = cp.get_opt_tag(section, "name", subsection)
        dist_class = distributions.distribs[name]
        dist = dist_class.from_config(cp, section, subsection)
        variable_args.extend(dist.params)
        dists.append(dist)

# keep the parameters in sorted order; one dimension per parameter
variable_args.sort()
ndim = len(variable_args)

# joint distribution over all of the variable parameters
prior = distributions.JointDistribution(variable_args, *dists)

# get all points in space to calculate PDF
# every parameter gets opts.bins evenly spaced samples on the half-open
# interval [lower bound, upper bound); row i of vals holds the samples for
# variable_args[i]
logging.info("Getting grid of points")
if opts.bins < 1:
    raise ValueError("--bins must be a positive integer")
vals = numpy.zeros(shape=(ndim, opts.bins))
for dist in dists:
    for param in dist.params:
        idx = variable_args.index(param)
        lower = dist.bounds[param][0]
        upper = dist.bounds[param][1]
        # linspace always returns exactly opts.bins samples; arange with a
        # floating-point step can return one sample too many because of
        # rounding, which would not fit into the row of vals
        vals[idx, :] = numpy.linspace(lower, upper, opts.bins,
                                      endpoint=False)

# every combination of the per-parameter samples, one point per row
pts = cartesian(vals)

# evaluate PDF between the bounds
logging.info("Calculating PDF")
pdf = []
for pt in pts:
    # map each parameter name to its coordinate at this grid point
    pt_dict = dict(zip(variable_args, pt))
    # every distribution is evaluated at the full point; the density of the
    # combined prior is the product of the individual densities (adding
    # them would give weight to points where one of the priors is zero)
    pdf.append(numpy.prod([dist.pdf(**pt_dict) for dist in dists]))
pdf = numpy.array(pdf)

# a single parameter is drawn as a simple curve of the PDF
logging.info("Plotting")
if ndim == 1:
    x = vals[0, :]
    xmin, xmax = x.min(), x.max()

    # leave 10% of the plotted range as margin on either side
    pad = 0.1 * (xmax - xmin)

    fig = plt.figure()
    plt.plot(x, pdf, "k", label="Prior")
    plt.xlim(xmin - pad, xmax + pad)
    plt.xlabel(variable_args[0])
    plt.ylabel("Probability Density Function")
    plt.legend()

# several parameters are drawn as a corner plot weighted by the PDF
else:
    fig = corner.corner(pts, weights=pdf, labels=variable_args,
                        plot_contours=False, plot_datapoints=False)

    # remove the 1-D histograms
    for hist_ax in fig.axes[::ndim + 1]:
        fig.delaxes(hist_ax)

    # adjust size of remaining plots to cover entire canvas
    stretch = 1 + 1.0 / ndim
    fig.subplots_adjust(top=stretch)
    fig.subplots_adjust(right=stretch)

    # shift the axis labels of the remaining panels
    for ax in fig.axes:
        ax.yaxis.get_label().set_position((-0.2, 0.5))
        ax.xaxis.get_label().set_position((0.5, -0.2))

# title and caption stored with the figure as meta-data
parameters = ", ".join(variable_args)
title = "Prior Distributions for {parameters}".format(parameters=parameters)
caption = ("This plot shows the probability density function (PDF) from the \n"
           "prior distributions.")

# save figure with meta-data, recording the command line that was used
results.save_fig_with_metadata(fig, opts.output_file,
                               cmd=" ".join(sys.argv),
                               title=title,
                               caption=caption)
plt.close()

# exit
logging.info("Done")
