#!/usr/bin/env python3
'''Merge Markov chains from multiple input files created by eos-sample-mcmc into one output file.'''

# Copyright (c) 2018 Frederik Beaujean
#
# This file is part of the EOS project. EOS is free software;
# you can redistribute it and/or modify it under the terms of the GNU General
# Public License version 2, as published by the Free Software Foundation.
#
# EOS is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 59 Temple
# Place, Suite 330, Boston, MA  02111-1307  USA

import argparse
import h5py
import numpy as np
import os

def invalid_chain(file, chain, cut_off=None):
    """Tell whether a prerun chain should be left out because its mode is too low.

    The mode is read from ``/prerun/chain #<chain>/stats/mode`` in ``file``;
    the last entry of the last row is used, which (per the original author's
    note) is the best value found over all update sequences.

    :param file: mapping-like object (e.g. an open HDF5 file) indexed by dataset path.
    :param chain: index of the chain within ``file``.
    :param cut_off: threshold for the mode, or ``None`` to accept every chain.
    :return: ``True`` if the chain must be skipped, i.e. its mode is not
        strictly larger than ``cut_off``; ``False`` otherwise.
    """
    dataset_path = '/prerun/chain #%d/stats/mode' % chain
    best_mode = file[dataset_path][-1][-1]

    # Without a threshold nothing is rejected; just report what was found.
    if cut_off is None:
        print('chain %d has mode %g in prerun' % (chain, best_mode))
        return False

    # A mode equal to the cut-off is rejected as well.
    rejected = not (best_mode > cut_off)
    if rejected:
        print('skipping chain %d with mode %g < %g' % (chain, best_mode, cut_off))
    return rejected


def merge_preruns(output_file_name, input_files, cut_off=None):
    """
    Merge a prerun, stored in different files into one common file.
    Before:
        file0:/prerun/chain #0/ file1:/prerun/chain #0 ...
    After:
        file0:/prerun/chain #0/ file0:/prerun/chain #1 ...

    Chains are renumbered consecutively in the output, following the order of
    ``input_files``; chains rejected by ``invalid_chain`` are skipped and do
    not leave a gap in the numbering. If an input file contains a
    ``/main run`` group, the matching main-run chains and their descriptions
    are copied as well.

    :param output_file_name: path of the HDF5 file to write; it is opened in
        mode "w".
    :param input_files: sequence of paths of HDF5 input files; each must
        contain a ``/prerun`` group with chains named ``chain #0``,
        ``chain #1``, ...
    :param cut_off: passed on to ``invalid_chain``; ``None`` keeps all chains.
    """

    # hdf5 groups copied from every input file
    groups = ['/prerun', '/descriptions/prerun']

    # count chains copied; also serves as the next chain index in the output
    nchains_copied = 0

    # Context managers guarantee that both the output and the current input
    # file are closed even if reading or copying raises.
    with h5py.File(output_file_name, "w") as output_file:
        for g in groups:
            output_file.create_group(g)

        for file_name in input_files:
            print("merging %s" % file_name)
            with h5py.File(file_name, 'r') as input_file:
                nchains_in_file = len(input_file[groups[0]])

                these_groups = list(groups)

                # main-run data is optional and decided per input file
                if '/main run' in input_file['/']:
                    these_groups.extend(['/main run', '/descriptions/main run'])

                # copy data
                for i in range(nchains_in_file):
                    if invalid_chain(input_file, i, cut_off):
                        continue
                    for g in these_groups:
                        input_file.copy(g + "/chain #%d" % i, output_file,
                                        name=g + "/chain #%d" % nchains_copied)
                    nchains_copied += 1

        print("Merged %d chains of %d files" % (nchains_copied, len(input_files)))


def main():
    """Command-line entry point: collect the input files and merge them.

    Input files given directly on the command line come first, followed by
    the files named in ``--input-file-list`` (one per line). If ``--output``
    is not given, the result is written to ``mcmc_pre_merged.hdf5`` in the
    current working directory.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # type=float lets argparse reject a non-numeric cut-off with a usage error
    parser.add_argument('--cut-off', default=None, type=float,
                        help='Skip chains whose maximum log posterior is below the cut-off')
    parser.add_argument('--input-file-list', dest='input_file_list',
                        help='List input files in a text file, one file per line. If other input files are passed directly on the command line, then those take precedence but all files will be merged.')
    parser.add_argument('--output', help='Output file name')
    parser.add_argument('input_files', nargs='*',
                        help='Any number of input file names.')

    args = parser.parse_args()

    if args.output is None:
        args.output = os.path.join(os.getcwd(), 'mcmc_pre_merged.hdf5')

    print("Merging into output file %s" % args.output)

    if args.input_file_list is not None:
        with open(args.input_file_list) as list_file:
            for line in list_file:
                # Strip only the line terminator: the last line may lack one,
                # so blindly dropping the final character would truncate the
                # file name.
                name = line.rstrip('\r\n')
                # Blank lines do not name a file.
                if name.strip():
                    args.input_files.append(name)

    merge_preruns(output_file_name=args.output,
                  input_files=args.input_files,
                  cut_off=args.cut_off)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
