#!/usr/bin/env python

# Copyright (C) 2017 Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
"""Extracts some or all samples from an InferenceFile, writing results to a new
InferenceFile.
"""

import os
import argparse
import logging
import numpy
import pycbc
from pycbc.inference import option_utils
from pycbc.io.inference_hdf import InferenceFile
from pycbc.io.inference_txt import InferenceTXTFile

# Build the command-line interface. The module docstring doubles as the
# program description shown by --help.
parser = argparse.ArgumentParser(description=__doc__)

parser.add_argument("--output-file", type=str, required=True,
                    help="Output file to create.")
parser.add_argument("--force", action="store_true", default=False,
                    help="If the output-file already exists, overwrite it. "
                         "Otherwise, an IOError is raised.")
parser.add_argument("--posterior-only", action="store_true", default=False,
                    help="Write copied parameter samples and likelihood stats "
                         "as flattened arrays. For example, if the sampler "
                         "wrote a parameter's samples to "
                         "{samples_group}/{param}/walker{x}/, then the "
                         "copied file will have all selected samples from "
                         "all walkers in {samples_group}/{param}/. Default is "
                         "for the copied file to have the same structure "
                         "as the input file.")
# Adds the shared inference-results options; the rest of this script reads
# opts.input_file, opts.parameters, opts.thin_start, opts.thin_end,
# opts.thin_interval and opts.iteration, which are not defined here and so
# are expected to come from this option group.
option_utils.add_inference_results_option_group(parser)
# Values are kept as strings here; they are converted to integers (or the
# literal 'all') after parsing.
parser.add_argument("--temperature", nargs="+", default=None,
                    help="For parallel-tempered samplers, which temperature "
                         "chain(s) to extract. Options are 'all', or one or "
                         "more indices indicating the temperature chain "
                         "(0 = coldest). Default is to only extract from "
                         "the coldest chain. ")
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Be verbose")

opts = parser.parse_args()

pycbc.init_logging(opts.verbose)

# Refuse to clobber an existing output file unless --force was given.
if os.path.exists(opts.output_file) and not opts.force:
    raise IOError("output file already exists; use --force if you wish to "
                  "overwrite.")

# Collect sampler-specific keyword arguments that are forwarded to the
# read/copy calls below. 'temps' is only set when --temperature was given.
sampler_kwargs = {}
if opts.temperature is not None:
    if opts.temperature == ['all']:
        # pass the keyword through as a plain string
        temps = opts.temperature[0]
    else:
        # Build a real list of indices: on Python 3 map() returns a one-shot
        # iterator, which is not what downstream code expecting a sequence
        # of temperature indices should receive.
        temps = [int(temp) for temp in opts.temperature]
    sampler_kwargs['temps'] = temps

# Write the selected samples. The output format is chosen from the output
# file name via option_utils.get_file_type.

# save TXT file
if option_utils.get_file_type(opts.output_file) == InferenceTXTFile:
    fp, parameters, labels, samples = option_utils.results_from_cli(
                                                        opts, **sampler_kwargs)
    try:
        likelihood_stats = fp.read_likelihood_stats(
                               thin_start=opts.thin_start,
                               thin_end=opts.thin_end,
                               thin_interval=opts.thin_interval,
                               iteration=opts.iteration)
        # Materialize the names as a list: on Python 3 .keys() returns a
        # view, which cannot be concatenated with the labels list below.
        stat_names = list(fp[fp.stats_group].keys())
    finally:
        # everything needed has been read; release the input file
        fp.close()
    # one column per requested parameter followed by one per likelihood stat
    out = numpy.hstack([samples.to_array(fields=parameters).T,
                        likelihood_stats.to_array(fields=stat_names).T])
    InferenceTXTFile.write(opts.output_file, out, labels + stat_names)

# otherwise save HDF file
else:

    # open the input and parse parameters; only a single input file can be
    # copied into an HDF output
    if len(opts.input_file) > 1:
        raise ValueError("Too many input files.")
    source = InferenceFile(opts.input_file[0])
    try:
        parameters, pnames = option_utils.parse_parameters_opt(
            opts.parameters)

        # copy to output
        target = source.copy(opts.output_file, parameters=parameters,
                             parameter_names=pnames,
                             posterior_only=opts.posterior_only,
                             thin_start=opts.thin_start,
                             thin_end=opts.thin_end,
                             thin_interval=opts.thin_interval,
                             iteration=opts.iteration,
                             **sampler_kwargs)
        target.close()
    finally:
        # close the input file even if the copy fails
        source.close()
