#! /usr/bin/env python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Runs a sampler to find the posterior distributions.
"""

import os
import sys
import argparse
import logging
import numpy
import shutil
import pycbc
import pycbc.opt
import pycbc.weave
import pycbc.version
from pycbc import distributions
from pycbc import transforms
from pycbc import fft
from pycbc import inference
from pycbc import psd
from pycbc import scheme
from pycbc import strain
from pycbc import types
from pycbc.io.inference_hdf import InferenceFile, check_integrity
from pycbc.waveform import generator
from pycbc.workflow import WorkflowConfigParser
from pycbc.inference import option_utils
from pycbc.inference.option_utils import validate_checkpoint_files
from pycbc.inference import burn_in

# build the command line parser; the module docstring is the description
parser = argparse.ArgumentParser(
    usage=__file__ + " [--options]",
    description=__doc__)

# version option
parser.add_argument(
    "--version", action="version",
    version=pycbc.version.git_verbose_msg,
    help="Prints version information.")

# data options
parser.add_argument(
    "--instruments", type=str, nargs="+",
    help="IFOs, eg. H1 L1.")
option_utils.add_low_frequency_cutoff_opt(parser)
parser.add_argument(
    "--psd-start-time", type=float, default=None,
    help="Start time to use for PSD estimation if different from analysis.")
parser.add_argument(
    "--psd-end-time", type=float, default=None,
    help="End time to use for PSD estimation if different from analysis.")

# inference options
parser.add_argument(
    "--likelihood-evaluator", required=True,
    choices=inference.likelihood_evaluators.keys(),
    help="Evaluator class to use to calculate the likelihood.")
parser.add_argument(
    "--seed", type=int, default=0,
    help="Seed to use for the random number generator that initially "
         "distributes the walkers. Default is 0.")
parser.add_argument(
    "--samples-file", default=None,
    help="Use an iteration from an InferenceFile as the initial proposal "
         "distribution. The same number of walkers and the same "
         "[variable_args] section in the configuration file should be used. "
         "The priors must allow encompass the initial positions from the "
         "InferenceFile being read.")

# sampler options
option_utils.add_sampler_option_group(parser)

# configuration file options
WorkflowConfigParser.add_config_opts_to_parser(parser)

# output options
parser.add_argument(
    "--output-file", type=str, required=True,
    help="Output file path.")
# these four switches share the same action and default, so they are
# registered from a table; the order here is the order shown in --help
for flag, description in (
        ("--force",
         "If the output-file already exists, overwrite it. Otherwise, an "
         "OSError is raised."),
        ("--save-strain",
         "Save the conditioned strain time series to the output file. If "
         "gate-overwhitened, this is done before all gates have been "
         "applied."),
        ("--save-stilde",
         "Save the conditioned strain frequency series to the output file. "
         "This is done after all gates have been applied."),
        ("--save-psd",
         "Save the psd of each ifo to the output file.")):
    parser.add_argument(flag, action="store_true", default=False,
                        help=description)
parser.add_argument(
    "--checkpoint-interval", type=int, default=None,
    help="Number of iterations to take before saving new samples to file, "
         "calculating ACL, and updating burn-in estimate.")
parser.add_argument(
    "--resume-from-checkpoint", action="store_true", default=False,
    help="Automatically load results from checkpoint/backup file.")
parser.add_argument(
    "--save-backup", action="store_true", default=False,
    help="Don't delete the backup file after the run has completed.")
parser.add_argument(
    "--checkpoint-fast", action="store_true",
    help="Do not calculate ACL after each checkpoint, only at the end. Not "
         "applicable if n-independent-samples have been specified.")
parser.add_argument(
    "--acl-interval", type=int, default=None,
    help="Specify the interval on which to calculate ACL and burn-in. If "
         "not provided, will use the checkpoint-interval. This is doubled "
         "after 50000 iterations; doubled again after 100000 iterations; and "
         "doubled again after 200000 iterations.")

# verbose option
parser.add_argument(
    "--verbose", action="store_true", default=False,
    help="Print logging messages.")

# option groups pre-defined by other pycbc modules, added in this order
for insert_options in (fft.insert_fft_option_group,
                       pycbc.opt.insert_optimization_option_group,
                       psd.insert_psd_option_group_multi_ifo,
                       scheme.insert_processing_option_group,
                       strain.insert_strain_option_group_multi_ifo,
                       pycbc.weave.insert_weave_option_group,
                       strain.add_gate_option_group):
    insert_options(parser)

# parse command line
opts = parser.parse_args()

# configure logging according to --verbose
pycbc.init_logging(opts.verbose)

# let each module sanity-check its own options; the psd and strain verifiers
# are left disabled
for verify_options in (fft.verify_fft_options,
                       pycbc.opt.verify_optimization_options,
                       # psd.verify_psd_options,
                       scheme.verify_processing_options,
                       # strain.verify_strain_options,
                       pycbc.weave.verify_weave_options):
    verify_options(opts, parser)

# refuse to clobber an existing output file unless --force was given
if not opts.force and os.path.exists(opts.output_file):
    raise OSError("output-file already exists; use --force if you wish to "
                  "overwrite it.")

# work out the stopping target: exactly one of niterations and
# n-independent-samples has to be provided, and the latter also needs a
# checkpoint interval
max_iterations = opts.niterations
have_niterations = opts.niterations is not None
have_nindependent = opts.n_independent_samples is not None
if have_niterations and have_nindependent:
    raise ValueError("Must specify either niterations or n-independent-"
                     "samples, not both")
if not (have_niterations or have_nindependent):
    raise ValueError("Must specify niterations or n-independent-samples; "
                     "see --help")
if have_niterations:
    get_nsamples = opts.niterations
else:
    if opts.checkpoint_interval is None:
        raise ValueError("n-independent-samples requires a checkpoint-"
                         "interval; see help")
    get_nsamples = opts.n_independent_samples

# look up the likelihood class chosen on the command line; likelihood_args
# collects the extra keyword arguments it will be constructed with
likelihood_class = inference.likelihood_evaluators[opts.likelihood_evaluator]
likelihood_args = {}

# warn user about options that will be removed
if opts.skip_burn_in:
    logging.warning(
        "DEPRECATION WARNING: Option --skip-burn-in has been deprecated and "
        "does nothing. If you wish to use no burn in, simply do not provide "
        "a burn-in-function. This option will be removed in future versions.")

# seed numpy's random number generator
numpy.random.seed(opts.seed)
logging.info("Using seed %i", opts.seed)

# numpy divide/invalid warnings are benign here and only clutter the log,
# so they are switched off
numpy.seterr(divide='ignore', invalid='ignore')

# processing scheme and FFT backend from the command line
ctx = scheme.from_cli(opts)
fft.from_cli(opts)

# the data and psds are only loaded when the likelihood class asks for data
if 'data' not in likelihood_class.required_kwargs:
    strain_dict = stilde_dict = psd_dict = low_frequency_cutoff_dict = None
else:
    logging.info("Loading data")
    strain_dict, stilde_dict, psd_dict = option_utils.data_from_cli(opts)
    low_frequency_cutoff_dict = \
        option_utils.low_frequency_cutoff_from_cli(opts)
    likelihood_args['f_lower'] = low_frequency_cutoff_dict
    likelihood_args['psds'] = psd_dict

with ctx:

    # load the configuration file(s) named on the command line
    cp = WorkflowConfigParser.from_cli(opts)

    # every transform that is read gets collected here so that it can be
    # handed to the constraints
    all_transforms = []

    # sampling transformations: only read when the config file has a
    # [sampling_parameters] section
    sampling_parameters = replace_parameters = sampling_transforms = None
    if cp.has_section('sampling_parameters'):
        sampling_parameters, replace_parameters = \
            option_utils.read_sampling_args_from_config(cp)
        sampling_transforms = transforms.read_transforms_from_config(
            cp, 'sampling_transforms')
        sampled_names = ', '.join(sampling_parameters)
        replaced_names = ', '.join(replace_parameters)
        logging.info("Sampling in {} in place of {}".format(
            sampled_names, replaced_names))
        all_transforms += sampling_transforms

    # waveform transformations: only read when there is at least one
    # waveform_transforms subsection
    waveform_transforms = None
    if any(cp.get_subsections('waveform_transforms')):
        logging.info("Loading waveform transforms")
        waveform_transforms = transforms.read_transforms_from_config(
            cp, 'waveform_transforms')
        all_transforms += waveform_transforms

    # get the variable and static arguments from the config file, along with
    # any constraints
    variable_args, static_args = distributions.read_params_from_config(cp)
    constraints = distributions.read_constraints_from_config(
        cp, transforms=all_transforms)

    # read the prior distributions and join them, with the constraints, into
    # the object used to evaluate the prior
    logging.info("Setting up priors for each parameter")
    dists = distributions.read_distributions_from_config(cp, "prior")
    prior_eval = distributions.JointDistribution(
        variable_args, *dists, constraints=constraints)

    # one calibration model per instrument, if a [calibration] section exists
    recalib = None
    if cp.has_section('calibration'):
        logging.info("Initializing calibration model")
        recalib = {ifo: strain.read_model_from_config(cp, ifo)
                   for ifo in opts.instruments}

    # set up the waveform generator (if the likelihood class needs one), the
    # likelihood evaluator, the burn-in evaluator and the sampler
    logging.info("Setting up sampler")

    # get gates for templates
    gates = strain.gates_from_cli(opts)

    if 'waveform_generator' in likelihood_class.required_kwargs:
        # select generator that will generate waveform for a single IFO
        generator_function = generator.select_waveform_generator(
            static_args["approximant"])

        # the first entry of each data dictionary supplies the epoch and the
        # frequency/time resolution. list() is used because on Python 3
        # dict.values() returns a view that cannot be indexed; on Python 2
        # it yields the same first element as values()[0]
        ref_stilde = list(stilde_dict.values())[0]
        ref_strain = list(strain_dict.values())[0]

        # construct class that will generate waveforms
        waveform_generator = generator.FDomainDetFrameGenerator(
            generator_function, epoch=ref_stilde.epoch,
            variable_args=variable_args, detectors=opts.instruments,
            delta_f=ref_stilde.delta_f,
            delta_t=ref_strain.delta_t,
            recalib=recalib, gates=gates,
            **static_args)

    else:
        waveform_generator = None

    # construct class that will return the natural logarithm of likelihood
    likelihood = likelihood_class(variable_args,
                                  waveform_generator=waveform_generator,
                                  data=stilde_dict, prior=prior_eval,
                                  sampling_parameters=sampling_parameters,
                                  replace_parameters=replace_parameters,
                                  sampling_transforms=sampling_transforms,
                                  waveform_transforms=waveform_transforms,
                                  **likelihood_args)

    # burn-in evaluator built from the --burn-in-function and --min-burn-in
    # options
    burn_in_eval = burn_in.BurnIn(opts.burn_in_function,
                                  min_iterations=opts.min_burn_in)

    # create sampler that will run
    sampler = option_utils.sampler_from_cli(opts, likelihood)

    # results are written to a checkpoint file and a backup copy while the
    # sampler runs; both are derived from the output file name
    checkpoint_file = opts.output_file + '.checkpoint'
    backup_file = opts.output_file + '.bkup'
    checkpoint_valid = validate_checkpoint_files(checkpoint_file, backup_file)

    # determine what to do with checkpoints:
    #  * a valid checkpoint without --resume-from-checkpoint or --force is an
    #    error, so that previous results are never silently discarded;
    #  * --force without --resume-from-checkpoint starts over from scratch
    if checkpoint_valid and not opts.resume_from_checkpoint and not opts.force:
        raise OSError("valid checkpoint file {} found, but "
                      "resume-from-checkpoint not on. If you wish to overwrite "
                      "use --force; otherwise, use --resume-from-checkpoint"
                      .format(checkpoint_file))
    if not opts.resume_from_checkpoint and opts.force:
        checkpoint_valid = False

    # save information about this data and settings
    if not checkpoint_valid:
        with InferenceFile(checkpoint_file, "w") as fp:
            # save command line and data
            logging.info("Creating and writing data to output file")
            fp.write_data(
                strain_dict=strain_dict if opts.save_strain else None,
                stilde_dict=stilde_dict if opts.save_stilde else None,
                psd_dict=psd_dict if opts.save_psd else None,
                low_frequency_cutoff_dict=low_frequency_cutoff_dict)

            # save injection parameters
            if opts.injection_file:
                for ifo in opts.instruments:
                    # use the injection file given for this ifo; if only a
                    # single file was given, use it for every ifo
                    if ifo in opts.injection_file:
                        inj_file = opts.injection_file[ifo]
                    elif len(opts.injection_file) == 1:
                        inj_file = list(opts.injection_file.values())[0]
                    else:
                        logging.warning("Could not find injections for %s",
                                        ifo)
                        continue
                    logging.info("Writing %s injections to output file", ifo)
                    fp.write_injections(inj_file, ifo)
        # copy to backup
        shutil.copy(checkpoint_file, backup_file)

    # write the command line to both files; when resuming from a valid
    # checkpoint also record the resume point
    for fn in [checkpoint_file, backup_file]:
        with InferenceFile(fn, "a") as fp:
            fp.write_command_line()
            if checkpoint_valid:
                fp.write_resume_point()

    # choose where the walkers start, in order of preference:
    #   1. the checkpoint file, when resuming from a valid checkpoint;
    #   2. the file given with --samples-file;
    #   3. the [initial-*] distributions in the configuration file;
    #   4. otherwise no file and no initial prior are passed to the sampler
    logging.info("Setting walkers initial conditions for varying parameters")
    if opts.resume_from_checkpoint and checkpoint_valid:
        samples_file = checkpoint_file
    else:
        samples_file = opts.samples_file

    init_prior = None
    if samples_file is not None:
        logging.info("Initial positions taken from last iteration in %s",
                     samples_file)
        # from here on samples_file is the open file, not its name
        samples_file = InferenceFile(samples_file, "r")
    elif len(cp.get_subsections("initial")) > 0:
        initial_dists = distributions.read_distributions_from_config(
            cp, section="initial")
        init_prior = distributions.JointDistribution(
            variable_args, *initial_dists, constraints=constraints)
    sampler.set_p0(samples_file=samples_file, prior=init_prior)

    # when starting from a file, also restore the sampler and random number
    # generator state stored in it, then close the file
    if samples_file is not None:
        sampler.set_state_from_file(samples_file)
        samples_file.close()

    # the sampler's own burn in is handled here, once, when "use_sampler" is
    # among the requested burn in functions
    if "use_sampler" in burn_in_eval.burn_in_functions:
        # take it out of the evaluator so it is not run a second time
        burn_in_eval.burn_in_functions.pop("use_sampler")
        # a valid checkpoint is written after the sampler's burn in, so if
        # one exists the burn in has already been done and is skipped
        if not checkpoint_valid:
            results_kwargs = {"static_args": static_args,
                              "ifos": opts.instruments}
            # run the burn in and record it in the checkpoint file
            with InferenceFile(checkpoint_file, "a") as fp:
                logging.info("Running sampler's burn in function")
                burnidx, is_burned_in = burn_in.use_sampler(sampler, fp)
                sampler.write_burn_in_iterations(fp, burnidx, is_burned_in)
                logging.info("Writing burn in samples to file")
                sampler.write_results(fp, **results_kwargs)
            # mirror the same information into the backup file
            with InferenceFile(backup_file, "a") as fp:
                sampler.write_burn_in_iterations(fp, burnidx, is_burned_in)
                sampler.write_results(fp, **results_kwargs)


    # establish the starting counters:
    #   start    - iterations already stored in the checkpoint file (from
    #              the sampler's burn in or an earlier checkpoint); 0 if the
    #              file has none
    #   nsamples - progress towards get_nsamples: independent samples when
    #              --n-independent-samples is used, iterations otherwise
    try:
        with InferenceFile(checkpoint_file, "r") as fp:
            start = fp.niterations
    except KeyError:
        start = 0

    if opts.n_independent_samples is None:
        nsamples = start
    else:
        try:
            with InferenceFile(checkpoint_file, "r") as fp:
                nsamples = fp.n_independent_samples
        except AttributeError:
            # no independent-sample count available; fall back to start
            nsamples = start

    # to ensure iterations are counted properly, the sampler's lastclear
    # should be the same as start
    sampler.lastclear = start

    # without a checkpoint interval, run everything in a single pass
    if opts.checkpoint_interval is not None:
        interval = opts.checkpoint_interval
    else:
        interval = get_nsamples

    # the ACL interval defaults to the checkpoint interval; the value is
    # stored back on opts as well as in acl_interval
    if opts.acl_interval is None:
        opts.acl_interval = interval
    acl_interval = opts.acl_interval

    # run sampler until we have the desired number of samples. Each pass
    # advances the sampler by `interval` iterations, appends the results to
    # the checkpoint file and then to the backup file, validates both files,
    # and clears the in-memory chain.
    # acl_interval_count tracks the iterations run since burn-in/ACLs were
    # last evaluated
    acl_interval_count = 0
    while nsamples < get_nsamples:

        # iteration number this pass will reach
        end = start + interval

        # adjust the interval if we would go past the number of iterations
        # (only when running for a fixed number of iterations)
        if opts.n_independent_samples is None and end > get_nsamples:
            interval = get_nsamples - start
            end = start + interval

        # advance the sampler by `interval` iterations
        # NOTE(review): an earlier comment here said initial values are set
        # to None so the sampler picks up where it left off; that is not
        # visible in this script and is presumably handled inside
        # sampler.run -- confirm
        logging.info("Running sampler for {} to {} iterations".format(start,
                                                                      end))
        sampler.run(interval)

        # write new samples
        with InferenceFile(checkpoint_file, "a") as fp:

            logging.info("Writing results to file")
            sampler.write_results(fp,static_args=static_args,
                                  ifos=opts.instruments)
            acl_interval_count += interval
            # burn-in (and possibly ACLs) are only re-evaluated once at least
            # acl_interval iterations have accumulated
            if acl_interval_count >= acl_interval:
                logging.info("Calculating burn in")
                burnidx, is_burned_in = burn_in_eval.update(sampler, fp)

                # compute the acls and write; this is skipped only when
                # --checkpoint-fast is on, a fixed number of iterations was
                # requested, and this is not the final pass
                acls = None
                if opts.n_independent_samples is not None \
                        or end >= get_nsamples \
                        or not opts.checkpoint_fast:
                    logging.info("Computing acls")
                    acls = sampler.compute_acls(fp)
                    sampler.write_acls(fp, acls)

        # write to backup, reusing the burn-in and ACL values obtained from
        # the checkpoint file above instead of recomputing them
        with InferenceFile(backup_file, "a") as fp:

            logging.info("Writing to backup file")
            sampler.write_results(fp,static_args=static_args,
                                  ifos=opts.instruments)
            if acl_interval_count >= acl_interval:
                sampler.write_burn_in_iterations(fp, burnidx, is_burned_in)
                if acls is not None:
                    sampler.write_acls(fp, acls)

        # check validity of the checkpoint/backup pair; stop if it failed
        checkpoint_valid = validate_checkpoint_files(checkpoint_file,
                                                     backup_file)
        if not checkpoint_valid:
            raise IOError("error writing to checkpoint file")

        # update nsamples for next loop: read the independent-sample count
        # back from the checkpoint file, or count iterations
        if opts.n_independent_samples is not None:
            with InferenceFile(checkpoint_file, 'r') as fp:
                nsamples = fp.n_independent_samples
            logging.info("Have {} independent samples".format(nsamples))
        else:
            nsamples += interval

        # restart the count once burn-in/ACLs were evaluated on this pass
        if acl_interval_count >= acl_interval:
            acl_interval_count = 0

        # lengthen the ACL interval as the run grows: 2x, 4x and 8x the
        # requested interval beyond 50000, 100000 and 200000 stored
        # iterations, respectively
        with InferenceFile(checkpoint_file, 'r') as fp:
            niters = fp.niterations
            if niters > 200000:
                acl_interval = 8 * opts.acl_interval
            elif niters > 100000:
                acl_interval = 4 * opts.acl_interval
            elif niters > 50000:
                acl_interval = 2 * opts.acl_interval

        # clear the in-memory chain to save memory
        logging.info("Clearing chain")
        sampler.clear_chain()

        # the next pass starts where this one ended
        start = end

    # samplers that can estimate the evidence get it stored in the checkpoint
    # file; a sampler signals that it cannot by raising NotImplementedError,
    # in which case nothing is written
    with InferenceFile(checkpoint_file, "a") as fp:
        try:
            log_evidence, log_evidence_err = sampler.calculate_logevidence(fp)
            logging.info("Saving evidence")
            sampler.write_logevidence(fp, log_evidence, log_evidence_err)
        except NotImplementedError:
            pass

# the run is complete: promote the checkpoint file to the requested output
# file name
logging.info("Moving checkpoint to output")
os.rename(checkpoint_file, opts.output_file)

# the backup copy is removed unless --save-backup was given
keep_backup = opts.save_backup
if not keep_backup:
    logging.info("Deleting backup file")
    os.remove(backup_file)

# exit with a success status
logging.info("Done")
sys.exit(0)
