__author__ = 'Andres'

from pathlib import Path
import os
import time
import argparse
import pickle
import json
import logging
import sys
import csv

import zstandard
import pandas as pd
import seedbank
from scipy import sparse

from gbsim import tuner
import algo


# Logger shared by the functions in this script.
_log = logging.getLogger('tune-algo')

# Filesystem layout; every path is relative to the current working directory.
data_dir = Path('data')
tune_dir = Path('tuning')
# Directory that holds the train/test split files read by params_tune().
split_folder = data_dir.joinpath('lfm-2b-artists')
# Compressed JSON feature file inside the split directory
# (presumably artist gender features, going by the name -- TODO confirm).
gender_location = split_folder.joinpath('features.json.zst')

def params_tune(args):
    """Search an algorithm's parameter space and report the best point.

    Loads the train/test split from ``split_folder``, hands it to ``tuner``,
    then evaluates ``args.points`` points drawn from the search space of the
    selected algorithm, ranking them by ``MRR@100``.

    Args:
        args: parsed CLI namespace.  The attributes read here are ``algo``
            (key into ``algo.algorithms``), ``cluster`` (use
            ``tuner.cluster_search`` instead of ``tuner.sequential_search``),
            ``points`` (number of points passed to the search),
            ``point_file`` (optional CSV path receiving every tested point)
            and ``output`` (optional JSON path receiving the best point).

    Raises:
        KeyError: if ``args.algo`` is not a key of ``algo.algorithms``.
    """
    algo_name = args.algo
    _log.info('loading algorithm %s', algo_name)
    algo_mod = algo.algorithms[algo_name]

    # Load train and test data
    _log.info('loading input data')
    fan_train_data = sparse.load_npz(split_folder / 'train_data_playcount.npz').tocoo()
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point this at split files you produced yourself.
    # The context manager closes the zstd reader (and the file under it).
    with zstandard.open(split_folder / 'test_data.pkl.zst', 'rb') as test_file:
        fan_test_data = pickle.load(test_file)

    # Long-format (user, item, count) frame from the COO matrix coordinates.
    train_df = pd.DataFrame({
        'user': fan_train_data.row,
        'item': fan_train_data.col,
        'count': fan_train_data.data
    })
    # The test data is iterated as a sequence whose position is the user
    # index and whose entries are iterables of (item, count) pairs.
    test_df = pd.DataFrame.from_records(
        ((u, i, c)
         for (u, items) in enumerate(fan_test_data)
         for (i, c) in items),
        columns=['user', 'item', 'count']
    )
    tuner.setup_data(train_df, test_df)

    _log.info('Finished loading data')

    # N is handed to the tuner and also names the metric used for ranking.
    N = 100
    tuner.setup_algo(algo_name, N)

    metric = f'MRR@{N}'
    state = seedbank.numpy_random_state()
    all_rets = []

    search = tuner.cluster_search if args.cluster else tuner.sequential_search

    recf = None
    record = None
    try:
        if args.point_file:
            # parents=True so a nested, not-yet-existing directory also works.
            args.point_file.parent.mkdir(parents=True, exist_ok=True)
            # newline='' is what the csv module requires of files it writes to.
            recf = open(args.point_file, 'wt', newline='')
            record = csv.DictWriter(recf, [
                n for (n, _d) in algo_mod.space
            ] + [metric])
            record.writeheader()
            recf.flush()

        for res, point in search(algo_mod.space, state, args.points):
            _log.info('%s: %s=%.4f', point, metric, res[metric])
            all_rets.append((res, point))
            if record is not None:
                record.writerow(point | {metric: res[metric]})
                # Flush per row so partial results survive an interrupted run.
                recf.flush()
    finally:
        # Close the record file even when the search raises part-way through.
        if recf is not None:
            _log.info('closing point record file')
            recf.close()

    if not all_rets:
        # Nothing was evaluated (e.g. zero points requested); there is no
        # best point to report or write.
        _log.error('search produced no points; nothing to report')
        return

    # max() returns the first of several equally good points, which is the
    # same element a stable descending sort would put first.
    best_res, best_params = max(all_rets, key=lambda p: p[0][metric])
    _log.info('finished in with %s %.3f', metric, best_res[metric])
    # NOTE(review): if the result is a dict this logs only its keys --
    # confirm whether the values were meant to be logged as well.
    for v in best_res:
        _log.info('best: %s', v)
    if args.output:
        _log.info('writing best point to %s', args.output)
        args.output.parent.mkdir(parents=True, exist_ok=True)
        with open(args.output, 'wt') as out:
            # Parameters first, then results; results win on a key clash.
            json.dump(best_params | best_res, out)


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments from ``sys.argv``.

    Returns:
        The namespace with ``algo``, ``points``, ``verbose``, ``seed``,
        ``cluster``, ``output`` and ``point_file`` attributes.  ``seed`` is an
        ``int`` when given as plain ASCII digits, the raw string otherwise,
        and ``None`` when omitted.

    Raises:
        SystemExit: raised by argparse on invalid arguments, including a
            missing ``--algo``.
    """
    def seed_value(text: str):
        """Convert an all-digit seed to ``int``; keep anything else as text."""
        # str.isdigit() alone also accepts characters such as superscript
        # digits that int() cannot parse, hence the ASCII guard.
        if text.isascii() and text.isdigit():
            return int(text)
        return text

    parser = argparse.ArgumentParser(
        prog='Sim',
        description='Run simulation of iterative recommendations')
    # The algorithm name is used as a lookup key right after parsing, so
    # reject a missing value here with a usage error.
    parser.add_argument('-a', '--algo', metavar='NAME', required=True,
                        help='use algorithm NAME')
    parser.add_argument('-p', '--points', default=10, type=int,
                        help='test on N points')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='enable verbose logging')
    parser.add_argument('-s', '--seed', type=seed_value,
                        help='RNG seed to use')
    parser.add_argument('--cluster', action='store_true',
                        help='run on ipyparallel cluster')
    parser.add_argument('-o', '--output', type=Path,
                        help='output file for tuned parameters')
    parser.add_argument('-P', '--point-file', type=Path,
                        help='output file for all tested points')
    return parser.parse_args()


def setup(args: argparse.Namespace) -> None:
    """Configure logging and the RNG, and apply environment overrides.

    Args:
        args: parsed CLI namespace.  Reads ``verbose``, ``seed`` and ``algo``;
            sets ``cluster`` to ``True`` when the ``USE_CLUSTER`` environment
            variable is present (whatever its value).
    """
    level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(stream=sys.stderr, level=level)
    # numba is noisy, turn off its debug messages even in verbose mode
    logging.getLogger('numba').setLevel(logging.INFO)
    _log.debug('CLI arguments: %s', args)

    # Test for "no seed given" explicitly: a plain truthiness check would
    # treat a requested seed of 0 as missing and fall back to the clock.
    if args.seed is None or args.seed == '':
        _log.warning('initializing RNG from time')
        seedbank.initialize(int(time.time()))
    else:
        _log.info('initializing RNG seed %s', args.seed)
        # include the algo name so two algos with the same seed get different points,
        # deterministically
        seedbank.initialize(args.seed, args.algo)

    if 'USE_CLUSTER' in os.environ:
        args.cluster = True

# Script entry point: parse the CLI, configure logging and the RNG, then tune.
if __name__ == "__main__":
    cli_args = parse_args()
    setup(cli_args)
    params_tune(cli_args)
