#!/usr/bin/env python3

# Copyright (c) 2020 Danny van Dyk
#
# This file is part of the EOS project. EOS is free software;
# you can redistribute it and/or modify it under the terms of the GNU General
# Public License version 2, as published by the Free Software Foundation.
#
# EOS is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 59 Temple
# Place, Suite 330, Boston, MA  02111-1307  USA

import argparse
import eos
import logging as log
from logging import debug, info, warn
import numpy as _np
import os
import pypmc
import scipy
import sys
import yaml

# return the value of the environment variable, or a default value if the variable is unset.
def get_from_env(envvar, default):
    """Look up ``envvar`` in the process environment.

    Returns the variable's value when it is set (even if it is the empty
    string), and ``default`` otherwise.
    """
    return os.environ.get(envvar, default)


def main():
    """Parse the command line and dispatch to the selected sub-command.

    Builds one sub-parser per command (``list-*``, ``sample-mcmc``,
    ``sample-pmc``, ``plot-samples``, ``find-mode``, ``find-clusters``,
    ``predict-observables``), each of which stores its handler function in
    the ``cmd`` attribute of the parsed namespace. Every parser accepts
    ``-v/--verbose`` (repeatable; mapped onto a logging level) and
    ``-f/--analysis-file``. If no command was selected, the help text is
    printed instead.

    The arguments are taken from ``sys.argv`` (the ``parse_args()`` default).
    """
    parser = argparse.ArgumentParser(description='Carry out a Bayesian analysis using EOS')
    subparsers = parser.add_subparsers(title = 'commands')

    ## begin of commands

    # list-priors
    parser_list_priors = subparsers.add_parser('list-priors',
        description = 'Lists the prior PDFs defined within the scope of this analysis.',
        help = 'list the known priors'
    )
    parser_list_priors.set_defaults(cmd = cmd_list_priors)

    # list-likelihoods
    parser_list_likelihoods = subparsers.add_parser('list-likelihoods',
        description = 'Lists the likelihoods defined within the scope of this analysis.',
        help = 'list the known likelihoods'
    )
    parser_list_likelihoods.set_defaults(cmd = cmd_list_likelihoods)

    # list-posteriors
    parser_list_posteriors = subparsers.add_parser('list-posteriors',
        description = 'Lists the posteriors defined within the scope of this analysis.',
        help = 'list the known posteriors'
    )
    parser_list_posteriors.set_defaults(cmd = cmd_list_posteriors)

    # list-predictions
    parser_list_predictions = subparsers.add_parser('list-predictions',
        description = 'Lists the predictions defined within the scope of this analysis.',
        help = 'list the known predictions'
    )
    parser_list_predictions.set_defaults(cmd = cmd_list_predictions)

    # sample-mcmc
    parser_sample_mcmc = subparsers.add_parser('sample-mcmc',
        description = 'Samples from a posterior using Markov Chain Monte Carlo (MCMC) methods.',
        help = 'sample from a posterior using Markov chains'
    )
    parser_sample_mcmc.add_argument('posterior', metavar = 'POSTERIOR',
        help = 'posterior PDF from which to draw samples'
    )
    # Validate the index while parsing, so that a non-integer value yields a
    # usage error rather than a traceback in the command handler.
    parser_sample_mcmc.add_argument('chain', metavar = 'CHAIN-IDX',
        help = 'index of the Markov chain (used to seed the RNG)',
        type = int
    )
    parser_sample_mcmc.add_argument('-N', '--number-of-samples',
        help = 'number of samples to be drawn',
        dest = 'N', action = 'store', type = int, default = 1000
    )
    parser_sample_mcmc.add_argument('-S', '--stride',
        help = 'stride (i.e. number of samples to skip) to obtain final samples',
        dest = 'stride', action = 'store', type = int, default = 5
    )
    parser_sample_mcmc.add_argument('-p', '--number-of-preruns',
        help = 'number of times a prerun adapts to the posterior',
        dest = 'preruns', action = 'store', type = int, default = 3
    )
    parser_sample_mcmc.add_argument('-n', '--number-of-prerun-samples',
        help = 'number of samples to be drawn in each prerun (to be discarded)',
        dest = 'pre_N', action = 'store', type = int, default = 150
    )
    parser_sample_mcmc.add_argument('-b', '--base-directory',
        help = 'base directory where to store samples',
        dest = 'base_directory', action = 'store', default = get_from_env('EOS_BASE_DIRECTORY', './')
    )
    parser_sample_mcmc.set_defaults(cmd = cmd_sample_mcmc)

    # sample-pmc
    parser_sample_pmc = subparsers.add_parser('sample-pmc',
        description = 'Samples from a posterior using Population Monte Carlo (PMC) methods.',
        help = 'sample from a posterior using adaptive importance sampling'
    )
    parser_sample_pmc.add_argument('posterior', metavar = 'POSTERIOR',
        help = 'posterior PDF from which to draw samples'
    )
    parser_sample_pmc.add_argument('path', metavar = 'PATH',
        help = 'path to the output of \'find-clusters\''
    )
    parser_sample_pmc.add_argument('-n', '--number-of-adaptation-samples',
        help = 'number of samples to be drawn in each adaptation step',
        dest = 'step_N', action = 'store', type = int, default = 500
    )
    parser_sample_pmc.add_argument('-s', '--number-of-adaptation-steps',
        help = 'number of adaptation steps',
        dest = 'steps', action = 'store', type = int, default = 10
    )
    parser_sample_pmc.add_argument('-N', '--number-of-final-samples',
        help = 'number of final samples to be drawn',
        dest = 'final_N', action = 'store', type = int, default = 5000
    )
    parser_sample_pmc.add_argument('-b', '--base-directory',
        help = 'base directory where to store samples',
        dest = 'base_directory', action = 'store', default = get_from_env('EOS_BASE_DIRECTORY', './')
    )
    parser_sample_pmc.set_defaults(cmd = cmd_sample_pmc)

    # plot-samples
    parser_plot_samples = subparsers.add_parser('plot-samples',
        description = 'Plots samples from a posterior.',
        help = 'plot samples'
    )
    parser_plot_samples.add_argument('path', metavar = 'PATH',
        help = 'path to the sampling data'
    )
    parser_plot_samples.add_argument('-b', '--bins',
        help = 'number of bins in the histogram',
        dest = 'bins', action = 'store', type = int, default = 50
    )
    parser_plot_samples.set_defaults(cmd = cmd_plot_samples)

    # find-mode
    parser_find_mode = subparsers.add_parser('find-mode',
        description = 'Finds the mode of a posterior.',
        help = 'find mode of posterior'
    )
    parser_find_mode.add_argument('posterior', metavar = 'POSTERIOR',
        help = 'posterior PDF whose mode is to be determined'
    )
    parser_find_mode.add_argument('-p', '--starting-points',
        help = 'number of points from which minimization is started',
        dest = 'points', action = 'store', type = int, default = 10
    )
    parser_find_mode.set_defaults(cmd = cmd_find_mode)

    # find-clusters
    parser_find_clusters = subparsers.add_parser('find-clusters',
        description = 'Find clusters among the posterior samples by Gelman-Rubin R value.',
        help = 'find clusters'
    )
    parser_find_clusters.add_argument('output_path', metavar = 'OUTPUT_PATH',
        help = 'path to the storage location for the clusters',
        action = 'store', type = str
    )
    parser_find_clusters.add_argument('input_paths', metavar = 'INPUT_PATH',
        help = 'path to the sampling data',
        action = 'store', type = str, nargs='+',
    )
    parser_find_clusters.add_argument('-t', '--threshold',
        help = 'R value threshold',
        dest = 'threshold', action = 'store', type = float, default = 2.0
    )
    parser_find_clusters.set_defaults(cmd = cmd_find_clusters)

    # predict-observables
    parser_predict_observables = subparsers.add_parser('predict-observables',
        description = 'Predict a set of observables based on previously obtained PMC samples.',
        help = 'predict observables'
    )
    parser_predict_observables.add_argument('prediction', metavar = 'PREDICTION',
        help = 'name of the set of observables to predict',
        action = 'store', type = str
    )
    parser_predict_observables.add_argument('path', metavar = 'INPUT_PATH',
        help = 'path to the sampling data',
        action = 'store', type = str
    )
    parser_predict_observables.add_argument('-b', '--base-directory',
        help = 'base directory where to store predictions',
        dest = 'base_directory', action = 'store', default = get_from_env('EOS_BASE_DIRECTORY', './')
    )
    parser_predict_observables.set_defaults(cmd = cmd_predict_observables)

    ## end of commands

    # add verbosity and analysis-file args to the top-level parser ...
    parser.add_argument('-v', '--verbose',
        help = 'increase verbosity',
        dest = 'verbose', action = 'count', default = 0
    )
    parser.add_argument('-f', '--analysis-file',
        help = 'analysis file to be used',
        dest = 'analysis_file', action = 'store', default = '.analysis.yaml'
    )

    # ... and to all commands. argparse copies every attribute of a
    # sub-command's namespace onto the top-level namespace, so a concrete
    # default here would silently overwrite a '-v' or '-f' given *before* the
    # command name. SUPPRESS leaves the attribute unset unless the option is
    # actually used after the command name.
    command_parsers = [
        parser_find_clusters, parser_find_mode,
        parser_list_priors, parser_list_likelihoods, parser_list_posteriors, parser_list_predictions,
        parser_plot_samples,
        parser_sample_mcmc, parser_sample_pmc,
        parser_predict_observables
    ]
    for p in command_parsers:
        p.add_argument('-v', '--verbose',
            help = 'increase verbosity',
            dest = 'verbose', action = 'count', default = argparse.SUPPRESS
        )
        p.add_argument('-f', '--analysis-file',
            help = 'analysis file to be used',
            dest = 'analysis_file', action = 'store', default = argparse.SUPPRESS
        )

    args = parser.parse_args()

    # more than three '-v' flags are treated like exactly three
    args.verbose = min(args.verbose, 3)

    levels = {
        0: log.ERROR,
        1: log.WARNING,
        2: log.INFO,
        3: log.DEBUG
    }

    log.basicConfig(level=levels[args.verbose])

    # 'cmd' is only present if a command was given on the command line
    cmd = getattr(args, 'cmd', None)
    if not callable(cmd):
        parser.print_help()
    else:
        cmd(args)


def make_analysis_file(args):
    """Create the ``eos.AnalysisFile`` for the path given in ``args.analysis_file``."""
    return eos.AnalysisFile(args.analysis_file)


class RNG:
    """Thin wrapper around a seeded ``numpy`` ``RandomState`` that traces its output.

    Each of ``rand``, ``normal`` and ``uniform`` forwards to the wrapped
    generator, prints the drawn value(s) to stdout prefixed with the method
    name, and returns them unchanged.
    """

    def __init__(self, seed):
        self._rng = _np.random.mtrand.RandomState(seed)

    def _traced(self, label, drawn):
        # print the freshly drawn value(s) and hand them back to the caller
        print('{}: {}'.format(label, drawn))
        return drawn

    def rand(self, *args):
        """Forward to ``RandomState.rand(*args)``."""
        return self._traced('rand', self._rng.rand(*args))

    def normal(self, loc, scale=1, size=None):
        """Forward to ``RandomState.normal(loc, scale, size)``."""
        return self._traced('normal', self._rng.normal(loc, scale, size))

    def uniform(self, low, high, size=None):
        """Forward to ``RandomState.uniform(low, high, size)``."""
        return self._traced('uniform', self._rng.uniform(low, high, size))


# Find mode
def cmd_find_mode(args):
    """Minimize the posterior ``args.posterior`` and report the best fit.

    Runs ``args.points`` optimizations, keeps the result with the smallest
    total chi^2, and prints the best-fit point, the p value of the total
    chi^2 for the total number of degrees of freedom, and the chi^2 / dof of
    every individual entry of the goodness-of-fit object.

    Raises:
        RuntimeError: if no optimization yielded a usable chi^2 (e.g. zero
            starting points were requested, or every chi^2 was NaN).
    """
    # Imported locally: the file-level ``import scipy`` alone does not
    # guarantee that the ``scipy.stats`` submodule has been loaded.
    from scipy.stats import chi2 as chi2_distribution

    analysis_file = make_analysis_file(args)
    analysis = analysis_file.analysis(args.posterior)
    min_chi2 = sys.float_info.max
    gof = None
    bfp = None
    info('Starting minimization in {} points'.format(args.points))
    for _ in range(args.points):
        # NOTE(review): the starting point is taken from the current values of
        # the varied parameters; whether these differ between iterations
        # depends on what analysis.optimize() does to them -- confirm.
        starting_point = [float(p) for p in analysis.varied_parameters]
        _bfp = analysis.optimize(starting_point)
        _gof = eos.GoodnessOfFit(analysis.log_posterior)
        _chi2 = _gof.total_chi_square()
        if _chi2 < min_chi2:
            # remember the new minimum, so that only better fits replace it
            min_chi2 = _chi2
            gof = _gof
            bfp = _bfp

    if bfp is None:
        raise RuntimeError('no best-fit point found in {} minimization(s)'.format(args.points))

    # the survival function is 1 - cdf, without the cancellation for large chi^2
    p_value = chi2_distribution(gof.total_degrees_of_freedom()).sf(gof.total_chi_square())
    print('best-fit point = {}'.format(bfp.point))
    print('p value = {:.1f}%'.format(100 * p_value))
    print('individual test statistics:')
    for n, e in gof:
        print('  - {}: chi^2 / dof = {:f} / {}'.format(n, e.chi2, e.dof))


# Plot samples
def cmd_plot_samples(args):
    """Plot one histogram per varied parameter of the data set at ``args.path``.

    The kind of data set is derived from the last path component: names
    starting with ``mcmc-`` are opened as ``eos.data.MarkovChain``, names
    starting with ``pmc`` as ``eos.data.PMCSampler``. Each histogram uses
    ``args.bins`` bins and is written to ``<args.path>/<index>.pdf``.

    Raises:
        RuntimeError: if the path matches neither naming scheme.
    """
    dataset_name = os.path.basename(os.path.normpath(args.path))
    if not dataset_name.startswith(('mcmc-', 'pmc')):
        raise RuntimeError('unsupported data set: {}'.format(args.path))

    # 'mcmc-' takes precedence over 'pmc', as the two prefixes cannot overlap
    loader = eos.data.MarkovChain if dataset_name.startswith('mcmc-') else eos.data.PMCSampler
    data = loader(args.path)

    parameters = eos.Parameters()
    for column, entry in enumerate(data.varied_parameters):
        name = entry['name']
        print('plotting histogram for {}'.format(name))

        # choose the x-axis label according to the kind of data set
        if data.type in ('Prediction',):
            label = eos.Observables()[name]
        elif data.type in ('MarkovChain', 'PMCSampler'):
            label = parameters[name].latex()
        else:
            label = r'\verb+{}+'.format(name)

        x_axis = { 'label': label, 'format':  '${x:.5f}$', 'range': [entry['min'], entry['max']] }
        y_axis = { 'label': 'prob. density' }
        histogram = {
            'type': 'histogram', 'bins': args.bins,
            'data': {
                'samples': data.samples[:, column],
            }
        }
        description = {
            'plot': { 'x': x_axis, 'y': y_axis },
            'contents': [ histogram ]
        }

        output_file = os.path.join(args.path, '{}.pdf'.format(column))
        eos.plot.Plotter(description, output_file).plot()

# Find clusters
def cmd_find_clusters(args):
    """Group Markov chains by R value and store the resulting mixture density.

    Loads the samples of every chain in ``args.input_paths``, groups the
    chains with ``pypmc``'s ``r_group`` using ``args.threshold`` as the
    critical R value, builds a Gaussian mixture from the chains with
    ``make_r_gaussmix``, and writes it to ``args.output_path``.
    """
    chains = []
    means = []
    variances = []
    for path in args.input_paths:
        samples = eos.data.MarkovChain(path).samples
        chains.append(samples)
        # per-column mean and variance of this chain
        means.append(_np.mean(samples, axis=0))
        variances.append(_np.var(samples, axis=0))

    r_value = pypmc.mix_adapt.r_value
    # the length of the first chain is passed as the sample count of all chains
    groups = r_value.r_group(means, variances, len(chains[0]), critical_r=args.threshold)
    info('Found {} groups using an R value threshold of {}'.format(len(groups), args.threshold))

    density = r_value.make_r_gaussmix(chains, K_g=len(groups), critical_r=args.threshold)
    eos.data.MixtureDensity.create(args.output_path, density)


# Sample MCMC
def cmd_sample_mcmc(args):
    """Draw MCMC samples from the posterior ``args.posterior`` and store them.

    The random number generator is seeded with the chain index plus 1701.
    The samples are written to
    ``<args.base_directory>/<args.posterior>/mcmc-<chain index, 4 digits>``.
    """
    analysis = make_analysis_file(args).analysis(args.posterior)

    chain_idx = int(args.chain)
    rng = _np.random.mtrand.RandomState(chain_idx + 1701)

    samples, weights = analysis.sample(
        N=args.N,
        stride=args.stride,
        pre_N=args.pre_N,
        preruns=args.preruns,
        rng=rng,
    )

    chain_directory = 'mcmc-{:04}'.format(chain_idx)
    output_path = os.path.join(args.base_directory, args.posterior, chain_directory)
    eos.data.MarkovChain.create(output_path, analysis.varied_parameters, samples, weights)


# Sample PMC
def cmd_sample_pmc(args):
    """Draw PMC samples from the posterior ``args.posterior`` and store them.

    The initial proposal is the mixture density stored at ``args.path``; the
    random number generator uses the fixed seed 1701. Samples, weights and
    the final proposal are written to
    ``<args.base_directory>/<args.posterior>/pmc``.
    """
    analysis = make_analysis_file(args).analysis(args.posterior)
    rng = _np.random.mtrand.RandomState(1701)

    mixture = eos.data.MixtureDensity(args.path)
    initial_proposal = mixture.density()

    samples, weights, proposal = analysis.sample_pmc(
        initial_proposal,
        step_N=args.step_N,
        steps=args.steps,
        final_N=args.final_N,
        rng=rng,
    )

    output_path = os.path.join(args.base_directory, args.posterior, 'pmc')
    eos.data.PMCSampler.create(output_path, analysis.varied_parameters, samples, weights, proposal)


# Predict observables
def cmd_predict_observables(args):
    """Evaluate the observables of ``args.prediction`` for every stored sample.

    The samples are read from ``args.path`` (an ``mcmc-*`` or ``pmc*`` data
    set, selected by the last path component). For each sample the varied
    parameters are set to the sample's values and all observables are
    evaluated. The results are stored, together with the sample weights,
    under ``<args.base_directory>/<args.prediction>/pred``.

    Raises:
        RuntimeError: if the path matches neither naming scheme.
    """
    _parameters = eos.Parameters()
    observables = make_analysis_file(args).observables(args.prediction, _parameters)

    dataset_name = os.path.basename(os.path.normpath(args.path))
    if not dataset_name.startswith(('mcmc-', 'pmc')):
        raise RuntimeError('unsupported data set: {}'.format(args.path))

    # 'mcmc-' takes precedence over 'pmc', as the two prefixes cannot overlap
    loader = eos.data.MarkovChain if dataset_name.startswith('mcmc-') else eos.data.PMCSampler
    data = loader(args.path)

    # show a progress bar if tqdm is available, otherwise iterate silently
    try:
        from tqdm import tqdm
    except ImportError:
        progressbar = lambda x: x
    else:
        progressbar = tqdm

    parameters = [_parameters[entry['name']] for entry in data.varied_parameters]

    def evaluate_at(sample):
        # move the parameters to this sample, then evaluate every observable
        for parameter, value in zip(parameters, sample):
            parameter.set(value)
        return [observable.evaluate() for observable in observables]

    observable_samples = _np.array([evaluate_at(sample) for sample in progressbar(data.samples)])

    output_path = os.path.join(args.base_directory, args.prediction, 'pred')
    eos.data.Prediction.create(output_path, observables, observable_samples, data.weights)


# List priors
def cmd_list_priors(args):
    """Print the name of every prior in the analysis file, one per line."""
    priors = make_analysis_file(args).priors
    for name, _ in priors.items():
        print(name)


# List likelihoods
def cmd_list_likelihoods(args):
    """Print the name of every likelihood in the analysis file, one per line."""
    likelihoods = make_analysis_file(args).likelihoods
    for name, _ in likelihoods.items():
        print(name)


# List predictions
def cmd_list_predictions(args):
    """Print the name of every prediction in the analysis file, one per line."""
    predictions = make_analysis_file(args).predictions
    for name, _ in predictions.items():
        print(name)


# List posteriors
def cmd_list_posteriors(args):
    """Print the name of every posterior in the analysis file, one per line."""
    posteriors = make_analysis_file(args).posteriors
    for name, _ in posteriors.items():
        print(name)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
