#!/usr/bin/env python3

# Copyright (c) 2017 Danny van Dyk
#
# This file is part of the EOS project. EOS is free software;
# you can redistribute it and/or modify it under the terms of the GNU General
# Public License version 2, as published by the Free Software Foundation.
#
# EOS is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 59 Temple
# Place, Suite 330, Boston, MA  02111-1307  USA

import argparse
from collections import OrderedDict
import eos
from eos.data import *
from logging import error, info, warn
import os
import scipy, scipy.stats
import sys
import yaml

def represent_ordereddict(dumper, data):
    """Represent an OrderedDict as a plain YAML mapping, keeping its order.

    Every key and value of ``data`` is converted with
    ``dumper.represent_data``; the resulting node pairs are collected in
    iteration order and wrapped in a mapping node carrying the standard
    YAML map tag.
    """
    # The key node is built before the value node for each item.
    node_pairs = [
        (dumper.represent_data(key), dumper.represent_data(entry))
        for key, entry in data.items()
    ]

    return yaml.nodes.MappingNode('tag:yaml.org,2002:map', node_pairs)

# Register the representer above so that yaml.dump() emits OrderedDict
# instances through represent_ordereddict.
yaml.add_representer(OrderedDict, represent_ordereddict)

def kinstr_to_dict(s):
    """Parse a kinematics string such as ``'q2_min=1.0, q2_max=6.0'``.

    Parameters
    ----------
    s : str
        Comma-separated list of ``name=value`` pairs. Whitespace around the
        commas, the names and the values is ignored.

    Returns
    -------
    dict
        Maps every name to its value converted to ``float``. An empty or
        whitespace-only string yields an empty dict.

    Notes
    -----
    If an entry does not consist of exactly one name and one value, an error
    is logged and the program terminates with exit status -1. A value that
    cannot be converted raises the ``ValueError`` from ``float``.
    """
    result = {}

    # An empty string means "no kinematic variables", not a malformed pair.
    if not s.strip():
        return result

    for kv in (part.strip() for part in s.split(',')):
        kvarray = kv.split('=')
        if len(kvarray) != 2:
            error('key/value pair {} is invalid'.format(kv))
            # sys.exit instead of the site-provided exit(), which is only
            # meant for interactive use.
            sys.exit(-1)
        result[kvarray[0].strip()] = float(kvarray[1])

    return result


def print_multivariate_gaussian(parameters, data, **kwargs):
    """Print a multivariate Gaussian constraint for ``data`` as YAML.

    The sample means and the sample covariance of ``data`` are computed, an
    optional diagonal systematic contribution is added to the covariance,
    and the result is written to standard output as a YAML mapping with the
    single key ``kwargs['name']``.

    Parameters
    ----------
    parameters : sequence
        One descriptor per column of ``data``. How a descriptor ``p`` is read
        depends on ``datafmt``: for ``'Prediction'`` the entries
        ``p['name']`` and ``p['kinematics']`` are used; for ``'UNC'``
        ``p[0]`` and ``p[1]`` are decoded from ASCII bytes and ``p[1]`` is
        parsed with ``kinstr_to_dict``; otherwise only ``p[0]`` is decoded.
    data : numpy.ndarray
        Samples with one row per sample and one column per descriptor.
    **kwargs
        ``datafmt`` (required)
            Input format, see ``parameters``.
        ``name`` (required)
            Key under which the constraint is printed.
        ``sys_abs`` (optional)
            Value added to every diagonal entry of the covariance.
        ``sys_rel`` (optional)
            Fraction of the mean; its square times the squared mean is added
            to the diagonal. Ignored if ``sys_abs`` is given.

    Raises
    ------
    ValueError
        If a statistical variance is negative or any entry of the statistical
        covariance is not finite.

    Notes
    -----
    A non-positive variance or a NaN entry in the correlation matrix is
    logged as an error and terminates the program with exit status -1.
    """
    dim = len(parameters)
    has_systematic = 'sys_rel' in kwargs or 'sys_abs' in kwargs

    means    = numpy.mean(data, axis=0)
    cov_stat = numpy.cov(data.T)
    var_stat = numpy.diag(cov_stat)

    # Validate the statistical covariance in one pass; the first offending
    # index (in row-major order) is reported.
    negative = numpy.flatnonzero(var_stat < 0)
    if len(negative) > 0:
        raise ValueError('variance < 0 encountered for i = {}'.format(negative[0]))
    non_finite = numpy.argwhere(~numpy.isfinite(cov_stat))
    if len(non_finite) > 0:
        i, j = non_finite[0]
        raise ValueError('variance not finite for i, j = {}, {}'.format(i, j))

    if (var_stat <= 0).any():
        error('negative or zero variance(s) encountered; exiting')
        sys.exit(-1)

    corr = numpy.corrcoef(data.T)
    if numpy.isnan(corr).any():
        error('calculation of the correlation matrix yielded non-finite entries; exiting')
        sys.exit(-1)

    det_corr = numpy.linalg.det(corr)
    if numpy.fabs(det_corr) < 1e-7 and not has_systematic:
        warn('correlation matrix almost singular with det(corr) = {}; consider adding a systematic uncertainty'.format(det_corr))

    var_sys  = numpy.zeros((dim,))
    if 'sys_abs' in kwargs:
        warn('adding absolute systematic uncertainty to covariance')
        # NOTE(review): the value is added to the variances as given (not
        # squared), unlike the relative branch below -- confirm whether
        # sys_abs is meant as a variance or as a standard deviation.
        var_sys.fill(float(kwargs['sys_abs']))
    elif 'sys_rel' in kwargs:
        warn('adding relative systematic uncertainty to covariance as fraction of the mean')
        # square the fraction of the mean since we're dealing with the covariance and not the standard deviation
        var_sys = kwargs['sys_rel'] * kwargs['sys_rel'] * numpy.square(means)

    # The systematic contribution only enters on the diagonal.
    cov_tot = cov_stat + numpy.diag(var_sys)

    constraint = OrderedDict()
    constraint['type']        = 'MultivariateGaussian(Covariance)'
    if kwargs['datafmt'] == 'UNC':
        constraint['observables'] = [p[0].decode('ascii') for p in parameters]
        constraint['kinematics']  = [kinstr_to_dict(p[1].decode('ascii')) for p in parameters]
    elif kwargs['datafmt'] == 'Prediction':
        constraint['observables'] = [p['name'] for p in parameters]
        constraint['kinematics']  = [p['kinematics'] for p in parameters]
    else:
        constraint['observables'] = [p[0].decode('ascii') for p in parameters]
        constraint['kinematics']  = [[] for _ in range(dim)]
    constraint['options']     = [{} for _ in range(dim)]
    constraint['means']       = means.tolist()
    constraint['covariance']  = cov_tot.tolist()
    constraint['dof']         = dim

    yaml.dump({kwargs['name']: constraint}, sys.stdout)

    # Sanity check after printing: is the total correlation matrix still
    # close to singular even with the systematic contribution?
    for i, j in numpy.argwhere(~numpy.isfinite(cov_tot)):
        warn('non-finite covariance entry encountered in i, j = {}, {}'.format(i, j))

    var_tot  = numpy.diag(cov_tot)
    corr_tot = cov_tot / numpy.sqrt(numpy.outer(var_tot, var_tot))

    det_corr_tot = numpy.linalg.det(corr_tot)
    if numpy.fabs(det_corr_tot) < 1e-7 and has_systematic:
        warn('correlation matrix still almost singular with det(corr) = {} after adding systematic uncertainty'.format(det_corr_tot))


def main():
    """Command-line entry point: read samples and print a Gaussian constraint.

    The positional argument names the input. A directory is opened as a
    ``Prediction``; any other path is treated as an HDF5 file whose reader
    class is selected from the prefix of its base name (``mcmc_``,
    ``pmc_monolithic`` or ``unc``). Rows containing NaN or infinite entries
    are dropped before the constraint is printed via
    ``print_multivariate_gaussian``.

    Terminates with exit status 0 on success and -1 if the input does not
    exist or its format cannot be determined.

    Raises
    ------
    NotImplementedError
        If the input provides fewer than two parameters.
    """
    parser = argparse.ArgumentParser(description='Print mean and standard deviation from an uncertainty propagation')
    parser.add_argument('input', metavar='HDF5FILE', type=str, help='Name of the HDF5 input file')
    parser.add_argument('--absolute-systematic', type=float, dest='sys_abs', help='add an absolute value as systematic uncertainty')
    parser.add_argument('--relative-systematic', type=float, dest='sys_rel', help='add a relative of the mean as systematic uncertainty')
    parser.add_argument('--name', type=str, dest='name', default='unnamed', help='name for the new constraint entry')

    args = parser.parse_args()

    # ensure that the input exists; both files and directories are accepted
    if not os.path.exists(args.input):
        error('\'%s\' is neither a file nor a directory' % args.input)
        sys.exit(-1)

    if os.path.isdir(args.input):
        # open EOS numpy/yaml-based data file
        datafile = Prediction(args.input)
        data = datafile.samples
        parameters = datafile.varied_parameters
        datafmt = 'Prediction'
    else:
        # open HDF5 data file; the reader is chosen from the file name prefix
        basename = os.path.basename(args.input)
        if basename.startswith('mcmc_'):
            datafile = MCMCDataFile(args.input)
            datafmt = 'MCMC'
        elif basename.startswith('pmc_monolithic'):
            datafile = PMCDataFile(args.input)
            datafmt = 'PMC'
        elif basename.startswith('unc'):
            datafile = UncertaintyDataFile(args.input)
            datafmt = 'UNC'
        else:
            error('cannot determine the data format of \'%s\'; expected a file name starting with \'mcmc_\', \'pmc_monolithic\' or \'unc\'' % basename)
            sys.exit(-1)

        parameters = datafile.parameters
        # keep only the leading columns that belong to the parameters
        data = datafile.data()[:, 0:len(parameters)]

    # forward options; unset command-line options are omitted so that the
    # callee can test for their presence
    options = {'datafmt': datafmt}
    for key in ('name', 'sys_abs', 'sys_rel'):
        value = getattr(args, key)
        if value is not None:
            options[key] = value

    # strip rows with infinite or NaN entries
    finite_rows = numpy.isfinite(data).all(axis=1)
    len_before = len(data)
    data = data[finite_rows]
    len_after = len(data)
    if len_after < len_before:
        warn('encountered and removed {} data set with either NaN or inf entries; proceeding with filtered data set'.format(len_before - len_after))

    if len(parameters) <= 1:
        raise NotImplementedError('not yet implemented')

    print_multivariate_gaussian(numpy.array(parameters), data, **options)

    sys.exit(0)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
