#! /usr/bin/env python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Plots samples from inference sampler.
"""

import argparse
import h5py
import logging
import matplotlib as mpl; mpl.use("Agg")
import matplotlib.pyplot as plt
import pycbc
from pycbc import results
from pycbc import transforms
from pycbc.inference import option_utils
import sys

# build the command-line interface; the module docstring is reused as the
# description text
parser = argparse.ArgumentParser(
    usage="{} [--options]".format(__file__),
    description=__doc__)

# optional flag, off unless given on the command line
parser.add_argument(
    "--verbose",
    action="store_true",
    default=False,
    help="Print logging info.")

# required destination of the figure
parser.add_argument(
    "--output-file",
    type=str,
    required=True,
    help="Path to output plot.")

# let pycbc register its inference-results options on this parser
option_utils.add_inference_results_option_group(parser)

# read the options from sys.argv
opts = parser.parse_args()

# configure logging from the --verbose flag
pycbc.init_logging(opts.verbose)

# open the results described by the command-line options; samples are not
# loaded here (load_samples=False), so the last returned item is discarded
loaded_results = option_utils.results_from_cli(opts, load_samples=False)
fp, parameters, labels, _ = loaded_results

# get number of dimensions
ndim = len(parameters)

# plot samples
# plot each parameter as a different subplot
logging.info("Plotting samples")
fig, axs = plt.subplots(ndim, sharex=True)
plt.xlabel("Iteration")

# plt.subplots does not return an iterable of axes when there is a single
# subplot, so wrap it to allow uniform indexing below
axs = [axs] if not hasattr(axs, "__iter__") else axs

# y-axis label; set once per subplot (independent of the number of walkers)
for i, ax in enumerate(axs):
    ax.set_ylabel(labels[i])

# work out which parameters have to be read from the file, and which
# transforms turn them into the requested parameters; this depends only on
# the requested parameters and the file's variable args, so do it once
# instead of once per (parameter, walker) pair
# NOTE(review): assumes get_common_cbc_transforms has no side effects and
# returns the same result on every call -- confirm against pycbc.transforms
file_parameters, cs = transforms.get_common_cbc_transforms(
                                         parameters, fp.variable_args)

# loop over walkers; each walker is read and transformed a single time and
# then drawn on every subplot, rather than being re-read for each parameter
for j in range(fp.nwalkers):

    y = fp.read_samples(file_parameters, walkers=j,
                        thin_start=opts.thin_start,
                        thin_interval=opts.thin_interval,
                        thin_end=opts.thin_end)
    y = transforms.apply_transforms(y, cs)

    # plot each walker as a different line on each parameter's subplot;
    # lines are still added to every subplot in walker order
    for i, arg in enumerate(parameters):
        axs[i].plot(y[arg], alpha=0.25)

# text stored alongside the figure as meta-data
title = "Samples for {parameters}".format(parameters=", ".join(labels))
caption = ("All samples from all the walker chains for the parameters. Each\n"
           "line is a different chain of walker samples.")

# write the figure, recording the command line that produced it
results.save_fig_with_metadata(
    fig,
    opts.output_file,
    cmd=" ".join(sys.argv),
    title=title,
    caption=caption)
plt.close()

# close the results file and finish
fp.close()
logging.info("Done")
