"""
Tuning support code.

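This module holds code pulled out of params_tuning so that it can be
imported, which makes it work better with ipyparallel. It provides sampling
of search-space points, setup of the shared algorithm configuration and
data, and sequential or cluster-based evaluation of sampled points.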
"""

import os
import logging

import pandas as pd
from lenskit import batch, util, Recommender
from lenskit import topn

import algo

_log = logging.getLogger(__name__)
# Shared training/test frames read by evaluate_point. These are bare
# annotations, so the names stay unbound until setup_data assigns them
# (in the driver process and, via cluster_search, on each engine).
# search_config, an (algorithm name, N) tuple, is likewise created by
# setup_algo rather than declared here.
train_df: pd.DataFrame
test_df: pd.DataFrame

def sample(space, state):
    """
    Draw one point from a search space.

    ``space`` is an iterable of ``(name, distribution)`` pairs. Each
    distribution's ``rvs`` method is called with ``random_state=state``.
    Returns a dict mapping every name to its sampled value.
    """
    point = {}
    for param_name, distribution in space:
        point[param_name] = distribution.rvs(random_state=state)
    return point

def sample_points(space, state, npts: int):
    """
    Lazily generate ``npts`` points drawn from ``space`` with :func:`sample`.

    The same ``state`` object is passed to every draw.
    """
    yield from (sample(space, state) for _ in range(npts))

def setup_algo(name, N):
    """
    Record the algorithm name and list length used by evaluate_point.

    Stores the ``(name, N)`` tuple in the module-level ``search_config``.
    """
    global search_config
    config = (name, N)
    _log.info('initializing tuner for algo %s', name)
    search_config = config

def setup_data(train, test):
    """Install the training and test frames that evaluate_point reads."""
    global train_df, test_df
    train_df, test_df = train, test

def sequential_search(space, state, npts):
    """
    Evaluate ``npts`` sampled points one at a time in this process.

    Yields the ``(scores, point)`` result of evaluate_point for each point,
    in sampling order.
    """
    for point in sample_points(space, state, npts):
        # The message needs a placeholder for ``point``; an argument with no
        # matching %-directive makes logging report a formatting error.
        _log.info('Starting new params combination: %s', point)
        yield evaluate_point(point)

def cluster_search(space, state, npts):
    """
    Evaluate ``npts`` sampled points in parallel on an ipyparallel cluster.

    Connects to the cluster loaded by ``Cluster.from_file()``. It then sets up
    logging on every engine and ships the current ``search_config``,
    ``train_df`` and ``test_df`` to each one. Finally it yields
    ``(scores, point)`` results as they complete, not in sampling order.

    Engine namespaces are cleared and the client is closed in every case:
    when the generator finishes, when an evaluation fails, and when the
    consumer closes the generator early.
    """
    from ipyparallel import Cluster, Client, LoadBalancedView, DirectView
    _log.info('connecting to ipcluster')
    cluster = Cluster.from_file()
    client: Client = cluster.connect_client_sync()
    try:
        _log.info('sending data to workers')
        view: DirectView = client[:]
        try:
            view.apply_sync(_init_subprocess_logging)
            view.apply_sync(setup_algo, *search_config)
            view.apply_sync(setup_data, train_df, test_df)
            _log.info('pushing search process to cluster')
            lb: LoadBalancedView = client.load_balanced_view()
            yield from lb.imap(
                evaluate_point,
                sample_points(space, state, npts),
                ordered=False
            )
        finally:
            _log.info('cleaning up workers')
            # block so the clear request completes before the client closes
            view.clear(block=True)
    finally:
        client.close()

def _init_subprocess_logging():
    """
    Route all log records in this process to ``tuning/tune-<pid>.log``.

    Run on each ipyparallel engine by cluster_search. The ``tuning``
    directory is created if it is missing. If a handler for this process's
    log file is already installed, no second one is added, so repeated calls
    neither re-truncate the file nor duplicate every record. The root logger
    is set to DEBUG and the ``numba`` logger to INFO.
    """
    pid = os.getpid()
    log_dir = 'tuning'
    # FileHandler does not create missing parent directories.
    os.makedirs(log_dir, exist_ok=True)
    path = os.path.join(log_dir, f'tune-{pid}.log')
    root = logging.getLogger()

    # FileHandler stores os.path.abspath(filename) as baseFilename.
    target = os.path.abspath(path)
    already_installed = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == target
        for handler in root.handlers
    )
    if not already_installed:
        h = logging.FileHandler(path, 'w')
        h.setLevel(logging.DEBUG)
        root.addHandler(h)

    root.setLevel(logging.DEBUG)
    logging.getLogger('numba').setLevel(logging.INFO)

def evaluate_point(point):
    """
    Train and evaluate one algorithm configuration.

    Builds the algorithm named in the module-level ``search_config`` (set by
    setup_algo) from the parameters in ``point``. It then trains the model on
    ``train_df``, recommends ``N`` items for every user in ``test_df``, and
    scores the recommendations against ``test_df``.

    Returns a ``(scores, point)`` tuple. ``scores`` is a dict with keys
    ``MRR@{N}``, ``nDCG@{N}`` and ``nDCG@10`` holding mean per-user values.
    """
    # search_config, train_df and test_df are only read here, so no
    # ``global`` declaration is needed to reach the module-level values.
    algo_name, N = search_config
    algo_mod = algo.algorithms[algo_name]
    algo_inst = algo_mod.from_params(**point)
    model = Recommender.adapt(util.clone(algo_inst))

    model.fit(train_df)
    _log.info('Finished training')

    users = test_df['user'].unique()
    recs = batch.recommend(model, users, N)
    _log.info('Finished recommending')

    scores = _score_recs(recs, test_df, N)
    _log.info('Computed metrics')
    return (scores, point)

def _score_recs(recs, truth, N):
    """
    Compute mean MRR@N, nDCG@N and nDCG@10 of ``recs`` against ``truth``.

    Metrics are computed with ``include_missing=True``. Any NaN per-user
    value, for example for a user that has no recommendations, counts as 0
    in the means. Returns a dict keyed ``MRR@{N}``, ``nDCG@{N}`` and
    ``nDCG@10``.
    """
    rla = topn.RecListAnalysis()
    rla.add_metric(topn.recip_rank, k=N)
    rla.add_metric(topn.ndcg, k=N)
    rla.add_metric(topn.ndcg, name='ndcg@10', k=10)

    user_scores = rla.compute(recs, truth, include_missing=True)
    user_scores.fillna(0, inplace=True)
    _log.debug("computed rec scores:\n%s", user_scores)
    return user_scores[['recip_rank', 'ndcg', 'ndcg@10']].mean().rename({
        'recip_rank': f'MRR@{N}',
        'ndcg': f'nDCG@{N}',
        'ndcg@10': 'nDCG@10',
    }).to_dict()