#! /usr/bin/env python

# Copyright (C) 2016 Miriam Cabero Mueller
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

import argparse
import logging
import numpy
import subprocess
import os
import glob
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot
from multiprocessing import Pool
import pycbc.results
from pycbc import transforms
from pycbc.inference import option_utils
from pycbc.io.inference_hdf import InferenceFile
from pycbc.results.scatter_histograms import create_multidim_plot, \
    get_scale_fac

def integer_logspace(start, end, num):
    """Generates an array of integers that are spaced approximately uniformly
    in log10 space between `start` and `end`. This is done such that the
    output array is guaranteed to have length equal to `num`, with no
    repeated values.

    Parameters
    ----------
    start : int
        Integer to start with; must be >= 0.
    end : int
        Integer to end with; must be > start.
    num : int
        The number of integers to generate. Cannot be larger than the number
        of integers in ``[start, end]``.

    Returns
    -------
    array
        The output array of integers.

    Raises
    ------
    ValueError
        If `num` is larger than the number of integers between `start` and
        `end` (inclusive), in which case no set of unique values exists.
    """
    if num > end - start + 1:
        raise ValueError("cannot generate {} unique integers between {} and "
                         "{} (inclusive)".format(num, start, end))
    # shift by one so that a start of 0 has a finite logarithm; the shift is
    # undone when the log-spaced values are rounded to integers below
    start += 1
    end += 1

    def _spaced(first, count):
        """Rounds `count` log-spaced points between `first` and `end`."""
        return numpy.round(numpy.logspace(numpy.log10(first),
                                          numpy.log10(end),
                                          num=count)).astype(int) - 1

    out = numpy.zeros(num, dtype=int)
    x = _spaced(start, num)
    start_idx = 0
    # rounding produces repeated integers where the log-spaced points are
    # less than one apart; keep resolving until no repeats are left
    while (numpy.diff(x) == 0).any():
        # collect the unique values up to the point that their
        # difference becomes > 1
        x = numpy.unique(x)
        stop_idx = numpy.where(numpy.diff(x) > 1)[0][0]
        keep = x[:stop_idx]
        stop_idx += start_idx
        out[start_idx:stop_idx] = keep
        start_idx = stop_idx
        # regenerate the remaining points, starting from the integer that
        # follows the last one kept (+1 for the next integer, +1 for the
        # log shift)
        num -= len(keep)
        start = keep[-1] + 2
        x = _spaced(start, num)
    out[start_idx:len(x)+start_idx] = x
    return out

# Command-line interface: options specific to this script are defined first;
# the plotting option groups shared with the other inference plotting
# programs are added afterwards by option_utils.
parser = argparse.ArgumentParser()

parser.add_argument("--input-file", type=str, required=True,
                    help="Results file path.")
parser.add_argument("--start-sample", type=int, default=1,
                    help="Start sample for the first frame. Note: sample "
                         "counting starts from 1. Default is 1.")
parser.add_argument("--end-sample", type=int, default=None,
                    help="End sample for the last frame. If None, will "
                         "default to the last sample.")
parser.add_argument("--frame-number", type=int,
                    help="Number of frames for the movie.")
parser.add_argument("--frame-step", type=int,
                    help="Step in the sample between frames for the movie. "
                         "Only provide if --frame-number is not given.")
parser.add_argument("--log-steps", action="store_true", default=False,
                    help="If frame-number is specified, make the number of "
                         "samples between frames uniform in log10. This "
                         "provides more detail of the early iterations, when "
                         "the sampler is changing most rapidly. An error will "
                         "be raised if frame-number is not provided.")
parser.add_argument("--output-prefix", type=str, required=True,
                    help="Output path and prefix for the frame files "
                         "(without extension).")
parser.add_argument("--parameters", type=str, nargs="+",
                    metavar="PARAM[:LABEL]",
                    help="Name of parameters to plot in same format "
                         "as for pycbc_inference_plot_posterior.")
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--dpi', type=int, default=200,
                    help='Set the dpi for each frame; default is 200')
parser.add_argument("--nprocesses", type=int, default=None,
                    help="Number of processes to use. If not given then "
                         "use maximum.")
parser.add_argument("--movie-file", type=str,
                    help="Path for creating the movie automatically after "
                         "generating all the frames. If a movie with the same "
                         "name already exists, it will remove it. Format: mp4. "
                         "Installation of FFMPEG required.")
parser.add_argument("--cleanup", action='store_true',
                    help="Delete all plots generated after creating the movie."
                         " Only works together with option --movie-file.")
# add options for what plots to create
option_utils.add_plot_posterior_option_group(parser)
# add scatter and density configuration options
option_utils.add_scatter_option_group(parser)
option_utils.add_density_option_group(parser)

opts = parser.parse_args()
pycbc.init_logging(opts.verbose)

# Get data
logging.info('Loading parameters')
fp, parameters, labels, _ = option_utils.results_from_cli(opts,
                            load_samples=False)

# if no end sample was given, run through the last iteration in the file
if opts.end_sample is None:
    opts.end_sample = fp.niterations

if opts.log_steps and not opts.frame_number:
    raise ValueError("log-steps requires a non-zero frame-number to be "
                     "provided; see help for details.")

# sample numbers are 1-based; anything outside of the file's iterations
# would otherwise wrap around (negative index) or fail later with an
# unhelpful IndexError
if not 1 <= opts.start_sample <= opts.end_sample <= fp.niterations:
    raise ValueError("start-sample and end-sample must satisfy 1 <= "
                     "start-sample <= end-sample <= {}; got start-sample={} "
                     "and end-sample={}".format(fp.niterations,
                                                opts.start_sample,
                                                opts.end_sample))

# convert sample numbers to index
start_index = opts.start_sample - 1
end_index = opts.end_sample - 1

# Define thin interval based on number of frames or frame step
if opts.frame_number and opts.frame_step:
    raise ValueError("Only one of frame-number or frame-step must be "
                     "provided.")
elif opts.frame_number:
    if opts.frame_number > fp.niterations:
        raise ValueError("frame number is > the number of iterations "
                         "({})".format(fp.niterations))
    if opts.log_steps:
        iterations = integer_logspace(start_index, end_index,
                                      opts.frame_number)
    else:
        # numpy.unique drops the repeats that truncating to int produces
        # when more frames are requested than there are samples in the range
        iterations = numpy.unique(numpy.linspace(start_index, end_index,
                                    num=opts.frame_number).astype(int))
    nframes = len(iterations)
    # create a mask to pull out the desired values
    itermask = numpy.zeros(fp.niterations, dtype=bool)
    itermask[iterations] = True
    thinint = thin_start = thin_end = None
elif opts.frame_step:
    thinint = opts.frame_step
    thin_start = start_index
    thin_end = end_index + thinint
    itermask = None
    # floor division keeps nframes, and hence the iteration indices used in
    # the frame file names, integer under both Python 2 and Python 3
    nframes = 1 + (thin_end - thin_start) // thinint
    iterations = numpy.arange(nframes-1) * thinint + thin_start
else:
    raise ValueError("At least one of frame-number or frame-step must be "
                     "provided.")

# Read the samples for the selected iterations from the InferenceFile and
# apply the transforms needed to obtain the requested parameters
file_parameters, trans = transforms.get_common_cbc_transforms(
    parameters, fp.variable_args)
raw_samples = fp.read_samples(file_parameters, thin_start=thin_start,
                              thin_interval=thinint, thin_end=thin_end,
                              iteration=itermask, flatten=False)
samples = transforms.apply_transforms(raw_samples, trans)
if samples.ndim > 2:
    # multi-tempered samplers return 3 dimensions; drop the leading one
    _, nrows, ncols = samples.shape
    samples = samples.reshape((nrows, ncols))

# Values used to color the scatter points; only loaded when a z-arg is given
show_colorbar = opts.z_arg is not None
if show_colorbar:
    logging.info("Getting likelihood stats")
    likelihood_stats = fp.read_likelihood_stats(
        thin_start=thin_start, thin_end=thin_end, thin_interval=thinint,
        iteration=itermask, flatten=False)
    if likelihood_stats.ndim > 2:
        # same flattening as is done for the samples
        _, nrows, ncols = likelihood_stats.shape
        likelihood_stats = likelihood_stats.reshape((nrows, ncols))
    zvals, zlbl = option_utils.get_zvalues(fp, opts.z_arg, likelihood_stats)
    # use one colorbar range for every frame: the command-line limits if
    # given, otherwise the extremes over all of the loaded z-values
    vmin = zvals.min() if opts.vmin is None else opts.vmin
    vmax = zvals.max() if opts.vmax is None else opts.vmax
else:
    zvals = zlbl = None
    vmin = vmax = None

# everything needed from the file is now in memory
fp.close()

# Collect the expected values to mark on the plots, starting with the
# injected values (if requested)
expected_parameters = {}
if opts.plot_injection_parameters:
    injections = option_utils.injections_from_cli(opts)
    for param in parameters:
        # the injections must all agree on the value of this parameter
        injected_vals = numpy.unique(injections[param])
        if injected_vals.size != 1:
            raise ValueError("More than one injection found! To use "
                "plot-injection-parameters, there must be a single unique "
                "injection in all input files. Use the expected-parameters "
                "option to specify an expected parameter instead.")
        # a single value remains: use it as the expected one
        expected_parameters[param] = injected_vals[0]
# values given explicitly on the command line take precedence
expected_parameters.update(option_utils.expected_parameters_from_cli(opts))
expected_parameters_color = opts.expected_parameters_color

logging.info('Choosing common characteristics for all figures')
# Use the same axis limits in every frame: take what was given on the
# command line, and fall back to the extremes of the loaded samples for any
# parameter without a limit
mins, maxs = option_utils.plot_ranges_from_cli(opts)
for param in parameters:
    if param not in mins:
        mins[param] = samples[param].min()
    if param not in maxs:
        maxs[param] = samples[param].max()

# Make each figure
# So that the frame files sort correctly, the sample number in each file
# name is zero-padded to the number of digits in the largest sample number
max_sample_num = iterations[-1] + 1

def _make_frame(frame):
    """Creates and saves the plot for a single frame of the movie.

    This is a module-level function taking a single argument so that it can
    be mapped over the frame indices, either serially or with a pool.

    Parameters
    ----------
    frame : int
        Index of the frame to make. It selects the entry of ``iterations``
        and the position along the second axis of ``samples`` (and of
        ``zvals``, if they were loaded).
    """
    frame_samples = samples[:,frame]
    frame_z = None if zvals is None else zvals[:,frame]
    # 1-based sample number, zero-padded so that the file names sort
    pad_width = len(str(max_sample_num))
    sample_num = str(iterations[frame] + 1).zfill(pad_width)
    output = opts.output_prefix + '-{}.png'.format(sample_num)

    fig, _ = create_multidim_plot(
        parameters, frame_samples, labels=labels, mins=mins, maxs=maxs,
        # marginal and scatter plots
        plot_marginal=opts.plot_marginal,
        plot_scatter=opts.plot_scatter,
        zvals=frame_z, show_colorbar=show_colorbar, cbar_label=zlbl,
        vmin=vmin, vmax=vmax, scatter_cmap=opts.scatter_cmap,
        # density and contour plots
        plot_density=opts.plot_density,
        plot_contours=opts.plot_contours,
        density_cmap=opts.density_cmap,
        contour_color=opts.contour_color,
        use_kombine=opts.use_kombine_kde,
        # expected values
        expected_parameters=expected_parameters,
        expected_parameters_color=expected_parameters_color)

    # Label the frame with its sample number; the label is placed further
    # left when a colorbar is shown
    xtxt = 0.85 if show_colorbar else 0.9
    ytxt = 0.95
    fontsize = 8*get_scale_fac(fig)
    pyplot.annotate('Sample {}'.format(sample_num), xy=(xtxt, ytxt),
                    xycoords='figure fraction', horizontalalignment='right',
                    verticalalignment='top', fontsize=fontsize)

    fig.savefig(output, bbox_inches='tight', dpi=opts.dpi)
    pyplot.close()

# Choose how the frames are made: with a pool of worker processes, or
# serially in this process if only a single process was requested
make_frame = _make_frame
if opts.nprocesses is None or opts.nprocesses > 1:
    pool = Pool(opts.nprocesses)
    mfunc = pool.map
else:
    pool = None
    mfunc = map

logging.info("Making frames")
try:
    # the built-in map is lazy in Python 3, so the result has to be consumed
    # for the frames to actually be made in the serial case
    list(mfunc(make_frame, range(len(iterations))))
finally:
    # release the worker processes, if any were started
    if pool is not None:
        pool.close()
        pool.join()

if opts.movie_file:
    logging.info("Making movie")
    frame_files = opts.output_prefix + "*.png"
    # remove any existing movie with the same name before writing the new one
    if os.path.isfile(opts.movie_file):
        os.remove(opts.movie_file)
    returncode = subprocess.call(["ffmpeg", "-pix_fmt", "yuv420p", "-s",
                                  "1024x768", "-pattern_type", "glob", "-i",
                                  frame_files, opts.movie_file])
    if returncode != 0:
        # keep the frames: they are the only output if the movie failed
        logging.error("ffmpeg exited with status %d; the frames have not "
                      "been removed", returncode)
    elif opts.cleanup:
        logging.info("Removing frames")
        for frame in glob.glob(frame_files):
            os.remove(frame)

logging.info('Done')
