#! /usr/bin/env python
"""
zort-initialize :
Create an objects file and rcid map for each lightcurve file in a directory.
These initialization files speed up reading for zort.
"""

import glob
import argparse
import pickle
import os
import sys
import subprocess


def gather_lightcurve_files(dataDir):
    """Return the sorted lightcurve files in dataDir that still need indexing.

    A lightcurve file is any path matching ``<dataDir>/field*.txt``. Files
    whose ``.objects`` counterpart (the same path with '.txt' replaced by
    '.objects') already exists are left out.
    """
    pending = []
    for candidate in glob.glob('%s/field*.txt' % dataDir):
        # An existing objects file means this lightcurve file is done.
        if os.path.exists(candidate.replace('.txt', '.objects')):
            continue
        pending.append(candidate)
    return sorted(pending)


def generate_objects_file(lightcurve_file):
    """Write the ``.objects`` index for one lightcurve file.

    Every line of ``lightcurve_file`` that starts with '#' is treated as an
    object line. Its whitespace-separated fields (minus the leading '#'
    token) are written as one comma-separated row, followed by the byte
    offset at which that line starts in the lightcurve file.

    Parameters
    ----------
    lightcurve_file : str
        Path to the lightcurve file. The output path is this path with
        '.txt' replaced by '.objects'.

    Returns
    -------
    str
        Path of the objects file that was written.
    """
    object_keys = ['id', 'nepochs', 'filterid',
                   'fieldid', 'rcid', 'ra', 'dec', 'buffer_position']

    objects_file = lightcurve_file.replace('.txt', '.objects')

    # The input is read as bytes so that the recorded offsets are true byte
    # positions. Subtracting a decoded line's character count from tell()
    # is off whenever newline translation ('\r\n' -> '\n') or a multi-byte
    # character makes the character count differ from the byte count.
    with open(lightcurve_file, 'rb') as f_in, \
            open(objects_file, 'w') as f_out:
        f_out.write('%s\n' % ','.join(object_keys))

        next_line_start = 0
        for raw_line in f_in:
            buffer_position = next_line_start
            next_line_start += len(raw_line)

            # Only want to look at object lines
            if not raw_line.startswith(b'#'):
                continue

            # Undecodable bytes are replaced rather than aborting the index.
            fields = raw_line.decode('utf-8', errors='replace').split()[1:]
            f_out.write('%s,%i\n' % (','.join(fields), buffer_position))

    return objects_file


def save_rcid_map(DR1_object_file, rcid_map):
    """Pickle ``rcid_map`` to the ``.rcid_map`` file for an objects file.

    The destination is ``DR1_object_file`` with '.objects' replaced by
    '.rcid_map'; any existing file at that path is overwritten.
    """
    destination = DR1_object_file.replace('.objects', '.rcid_map')
    with open(destination, 'wb') as out_handle:
        pickle.dump(rcid_map, out_handle)


def return_rcid_map_size(rcid_map):
    """Return the total number of entries across all inner mappings.

    ``rcid_map`` maps a key to a sized container; the result is the sum of
    the lengths of those containers.
    """
    return sum(len(entries) for entries in rcid_map.values())


def return_rcid_map_filesize(rcid_map):
    """Return the summed extent of every (start, end) pair in ``rcid_map``.

    ``rcid_map`` is a two-level mapping whose innermost values are indexable
    pairs; each pair contributes ``pair[1] - pair[0]`` to the total.
    """
    return sum(span[1] - span[0]
               for blocks in rcid_map.values()
               for span in blocks.values())


def generate_rcid_map(objects_file):
    """Build and save the rcid map for one objects file.

    The objects file is scanned row by row (after its header line). Each run
    of consecutive rows that share the same filterid (third column) and rcid
    (fifth column) forms a block, recorded as
    ``rcid_map[filterid][rcid] = (start, end)`` where ``start`` and ``end``
    are byte offsets into the objects file. The map is then written with
    ``save_rcid_map``.

    Compared with a plain rcid comparison this also:
      * starts a new block when the filterid changes but the rcid does not,
        so rows are not filed under the previous block's filter;
      * accepts filterids other than 1 and 2 (their entry is created on
        demand; 1 and 2 are always present, possibly empty);
      * writes an empty map instead of failing when the file has no rows.

    Parameters
    ----------
    objects_file : str
        Path to an objects file, such as one returned by
        ``generate_objects_file``.
    """
    rcid_map = {1: dict(), 2: dict()}

    block_key = None    # (filterid, rcid) of the block being scanned
    block_start = None  # byte offset of that block's first row
    offset = 0          # byte offset of the next unread line

    # Read as bytes and track the offset by hand: the result is an exact
    # byte position regardless of newline style, and avoids a tell() call
    # on every row.
    with open(objects_file, 'rb') as f_in:
        offset += len(f_in.readline())  # skip past the header

        for raw_line in f_in:
            line_start = offset
            offset += len(raw_line)

            # A blank line carries no object; it stays inside the current
            # block's byte range.
            if not raw_line.strip():
                continue

            columns = raw_line.split(b',')
            key = (int(columns[2]), int(columns[4]))

            # Check to see if the block has switched
            if key != block_key:
                if block_key is not None:
                    rcid_map.setdefault(block_key[0], dict())[
                        block_key[1]] = (block_start, line_start)
                block_key = key
                block_start = line_start

    # Close the final block at end of file (nothing to close if no rows).
    if block_key is not None:
        rcid_map.setdefault(block_key[0], dict())[
            block_key[1]] = (block_start, offset)

    # save_rcid_map derives the '.rcid_map' path from the objects path.
    save_rcid_map(objects_file, rcid_map)


def main():
    """Command-line entry point: index every pending lightcurve file.

    Parses the command line, finds the lightcurve files in the requested
    directory that have no objects file yet, and generates an objects file
    and rcid map for each of them.

    In ``-parallel`` mode the work is split across MPI ranks. Rank 0 alone
    lists the directory and broadcasts the result, so that all ranks stride
    over the same list; if each rank listed the directory itself, a rank
    that starts late could see objects files already created by others and
    end up with a different list, causing files to be skipped or processed
    twice.

    Exits with a non-zero status if ``-parallel`` is requested but mpi4py
    cannot be imported.
    """
    # Get arguments
    parser = argparse.ArgumentParser(description=__doc__)
    arguments = parser.add_argument_group('arguments')
    arguments.add_argument('-lightcurve-file-directory', type=str,
                           help='Directory containing lightcurve files.',
                           required=True)

    parallelgroup = parser.add_mutually_exclusive_group()
    parallelgroup.add_argument('-single', dest='parallelFlag',
                               action='store_false',
                               help='Run in single mode. DEFAULT.')
    parallelgroup.add_argument('-parallel', dest='parallelFlag',
                               action='store_true',
                               help='Run in parallel mode. Requires mpi4py.')
    parser.set_defaults(parallelFlag=False)

    args = parser.parse_args()

    if args.parallelFlag:
        # Try the import directly instead of shelling out to `pip freeze`:
        # it needs no pip, spawns no subprocess per rank, and tests the
        # thing that actually matters.
        try:
            from mpi4py import MPI
        except ImportError:
            # A string argument is printed to stderr and gives exit status 1.
            sys.exit('mpi4py must be installed to use -parallel mode.')

        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        size = comm.Get_size()

        lightcurve_files = None
        if rank == 0:
            lightcurve_files = gather_lightcurve_files(
                args.lightcurve_file_directory)
            print('Generating object files and RCID maps '
                  'for %i lightcurve files' % len(lightcurve_files))
        lightcurve_files = comm.bcast(lightcurve_files, root=0)

        # Rank r handles files r, r + size, r + 2 * size, ...
        for lightcurve_file in lightcurve_files[rank::size]:
            objects_file = generate_objects_file(lightcurve_file)
            generate_rcid_map(objects_file)

    else:
        lightcurve_files = gather_lightcurve_files(
            args.lightcurve_file_directory)
        print('Generating object files and RCID maps '
              'for %i lightcurve files' % len(lightcurve_files))
        for lightcurve_file in lightcurve_files:
            objects_file = generate_objects_file(lightcurve_file)
            generate_rcid_map(objects_file)


# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()
