__author__ = 'Andres'

import os
import pickle
import json
import argparse
import yaml
import logging
import zstandard
import algo

from scipy import sparse
from gbsim.script_util import init_logging
from gbsim.simulation import SimulationFAIR, SimulationMoveUp, SimulationRerank

_log = logging.getLogger('simulate')

# Dataset sub-directory under ``data/`` that call_simulation reads its inputs from.
split_folder = 'lfm-2b-artists'

# Feature/prediction filenames. NOTE(review): not referenced by call_simulation;
# presumably used by related tooling — confirm before removing.
predictions_fidelity_filename = 'predicted_features_{}.npy'
user_features_playcounts_filename = 'out_user_playcounts_als.feats'
item_features_playcounts_filename = 'out_item_playcounts_als.feats'
predictions_playcounts_filename = 'predicted_playcounts_als.npy'
# zstd-compressed JSON with per-item gender features. Built from split_folder
# (instead of a hard-coded 'data/lfm-2b-artists/...' string) so it always
# points into the same dataset directory as the other inputs.
gender_location = os.path.join('data', split_folder, 'features.json.zst')


# Maps ``config['simulation']['type']`` to the simulation class to instantiate.
rerank_class_dict = {
    'moveup': SimulationMoveUp,
    'rerank': SimulationRerank,
    'fair': SimulationFAIR,
}

def _load_pickle_zst(*parts):
    """Unpickle the zstandard-compressed file at ``os.path.join(*parts)``.

    The file handle is closed before returning.
    NOTE(review): pickle can execute arbitrary code while loading; only use this
    on trusted data files produced by the project's own preprocessing.
    """
    with zstandard.open(os.path.join(*parts), 'rb') as fh:
        return pickle.load(fh)


def call_simulation(config):
    """Load the dataset, build the configured algorithm, and run the simulation.

    Parameters
    ----------
    config : dict
        Parsed experiment configuration. Keys read here:
        ``config['algo']['name']`` (key into ``algo.algorithms`` and the
        basename of ``tuning/<name>.json``), ``config['experiment_name']``,
        ``config['simulation']['type']`` (key into ``rerank_class_dict``),
        ``config['simulation']['choicemodel']``,
        ``config['simulation']['steps']`` (number of ``simulation.step()``
        calls) and, optionally, ``config['simulation']['lambda']``
        (passed as ``lambda_val``; ``None`` when absent).

    Raises
    ------
    KeyError
        If a required config key is missing, the algorithm or simulation type
        is unknown, or an item id has no entry in the gender features file.
    """
    data_dir = os.path.join('data', split_folder)

    # Load gender features, train matrix, test data and id->index dictionaries.
    with zstandard.open(gender_location, 'rt') as fh:
        artists_gender = json.load(fh)
    fan_train_data = sparse.load_npz(os.path.join(data_dir, 'train_data_playcount.npz')).tocsr()
    fan_test_data = _load_pickle_zst(data_dir, 'test_data.pkl.zst')
    fan_items_dict = _load_pickle_zst(data_dir, 'items_dict.pkl.zst')
    fan_users_dict = _load_pickle_zst(data_dir, 'users_dict.pkl.zst')

    # Build a list aligned with item indices holding each item's gender value.
    # assumes fan_items_dict values are dense indices 0..len-1 — TODO confirm
    items_gender = [0] * len(fan_items_dict)
    for item_id, item_idx in fan_items_dict.items():
        # JSON object keys are always strings, so the id is stringified for lookup.
        items_gender[item_idx] = artists_gender[str(item_id)]

    _log.info("items: %d", len(fan_items_dict))
    _log.info("users: %d", len(fan_users_dict))
    # Log the matrix's own shape; logging the shape of its column sums
    # reported (1, n_items) and cost a full reduction.
    _log.info("listen matrix: %s", fan_train_data.shape)

    # Create algorithm from its registry entry and tuned hyper-parameters.
    algo_name = config['algo']['name']
    _log.info("initializing algorithm %s", algo_name)
    algo_mod = algo.algorithms[algo_name]
    with open(f'tuning/{algo_name}.json', 'rt') as pf:
        params = json.load(pf)
    instance_algo = algo_mod.from_params(**params)

    # Instantiate the simulation class selected by config['simulation']['type'].
    sim_config = config['simulation']
    rerank_class = rerank_class_dict[sim_config['type']]
    lambda_val = sim_config.get('lambda')

    simulation = rerank_class(config['experiment_name'], instance_algo, fan_train_data, fan_test_data,
                              fan_items_dict, items_gender,
                              choice_model=sim_config['choicemodel'], lambda_val=lambda_val)

    # Advance the simulation the configured number of steps.
    for _ in range(sim_config['steps']):
        simulation.step()

if __name__ == "__main__":
    # Command-line entry point: parse args, set up logging, load the YAML
    # config and hand it to call_simulation.
    parser = argparse.ArgumentParser(description='Run a simulation from a YAML config.')
    parser.add_argument('-v', '--verbose', action='store_true')
    # required=True: without it a missing -c fell through to open(None) and
    # crashed with a TypeError instead of printing a usage error.
    parser.add_argument('-c', "--config", required=True,
                        help='path to the YAML experiment config')
    args = parser.parse_args()
    init_logging(args)
    with open(args.config, "r") as f:
        # NOTE(review): FullLoader can construct some Python objects.
        # call_simulation only indexes the result as nested mappings, so
        # yaml.safe_load should suffice — switch if configs may be untrusted.
        config = yaml.full_load(f)

    call_simulation(config)
