import os
import numpy as np
import time
import csv

from collections import Counter
from scipy.linalg import norm

from lenskit import topn

from gbsim.rbp import rbp, _bulk_rbp
from gbsim.old_eval import coverage, evaluate


def gini(array, eps=1e-7):
    """Return the Gini coefficient of the values in *array*.

    Uses G = sum_i (2*i - n - 1) * x_i / (n * sum_i x_i), where x is sorted in
    ascending order and i runs from 1 to n. Based on the bottom equation at
    http://www.statsdirect.com/help/content/image/stat0206_wmf.gif
    (from: http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm).

    The input is copied into a flat float array, so the caller's data is never
    modified. If any value is negative, all values are shifted so the minimum
    becomes 0. *eps* is then added once so that no value is exactly 0.

    Parameters
    ----------
    array : array-like of numbers
        Values to measure, e.g. per-artist recommendation counts. Any shape is
        accepted; it is flattened.
    eps : float, optional
        Small offset added to every value (default 1e-7).

    Returns
    -------
    float
        The Gini coefficient. Returns NaN for empty input, where the coefficient
        is undefined.
    """
    # np.array (not asarray) always copies, so the in-place ops below cannot
    # touch the caller's array; dtype=float avoids int casting errors.
    values = np.array(array, dtype=float).ravel()
    if values.size == 0:
        return np.nan

    minimum = values.min()
    if minimum < 0:
        values -= minimum  # values cannot be negative
    values += eps  # values cannot be 0
    values.sort()  # formula requires ascending order

    n = values.shape[0]
    index = np.arange(1, n + 1)  # 1-based rank of each sorted value
    return np.sum((2 * index - n - 1) * values) / (n * np.sum(values))


_GENDER_LABELS = ('Female', 'Male', 'Non-binary')


def _female_share(genders):
    """Return the fraction of 'Female' among the Female/Male/Non-binary labels in *genders*.

    Any other label is ignored. Returns None when none of the three labels occur.
    """
    counts = Counter(genders)
    female = counts['Female']
    total = female + counts['Male'] + counts['Non-binary']
    return female / total if total > 0 else None


def _stack_counters(female, male, nonbinary):
    """Pack per-gender Counters and their sum into a (4, 1) object array.

    Rows are, in order: female, male, nonbinary, all (sum of the three).
    """
    return np.array([[female], [male], [nonbinary], [female + male + nonbinary]])


def get_reco_sets(fan_test_data, items_gender, predicted_x):
    """Split every user's recommendations by artist gender and gather gender statistics.

    Parameters
    ----------
    fan_test_data : sequence
        One entry per user. Each entry is an iterable of (artist, playcount) pairs.
    items_gender : mapping / indexable
        Maps an artist id to a gender label. Only 'Female', 'Male' and
        'Non-binary' are counted; other labels are ignored.
    predicted_x : sequence
        Indexed in parallel with fan_test_data. predicted_x[i] is user i's
        ranked recommendation list.

    Returns
    -------
    tuple of 11 items:
        reco_set_count, reco_set_count_10
            (4, 1) object arrays of Counters (artist -> times recommended), over
            the full lists and the top 10 respectively. Rows are female, male,
            nonbinary, all.
        artist_gender_user
            Female share of the test artists. Only users with at least one
            gendered test artist are included.
        artist_gender_user_recommend
            Female share of the recommendations. Only users that are in
            artist_gender_user AND have at least one gendered recommendation
            are included. Neither share list is index-aligned with
            fan_test_data.
        artist_gender_first_female, artist_gender_first_male, artist_gender_first_nonbinary
            Per user, the 0-based rank of the first recommended artist of that
            gender, or len(predicted_x[i]) + 1 if there is none.
        predicted_female, predicted_male, predicted_nonbinary
            Per user, the recommended artists of that gender in rank order.
        fan_test_data_sorted
            Per user, the test artist ids sorted by ascending playcount.
            NOTE(review): ascending order — confirm this is the intended order
            for the downstream ranking metrics.
    """
    fan_test_data_sorted = []
    artist_gender_user = []
    artist_gender_user_recommend = []

    # Per-gender accumulators, keyed by the label stored in items_gender.
    predicted_by_gender = {g: [] for g in _GENDER_LABELS}
    first_by_gender = {g: [] for g in _GENDER_LABELS}
    reco_counts = {g: Counter() for g in _GENDER_LABELS}
    reco_counts_10 = {g: Counter() for g in _GENDER_LABELS}

    # Index-based so that fan_test_data and predicted_x are accessed exactly as
    # before (and a too-short predicted_x still raises instead of being truncated).
    for i in range(len(fan_test_data)):
        test_by_playcount = sorted([(a, p) for a, p in fan_test_data[i]], key=lambda pair: pair[1])
        fan_test_data_sorted.append([a for a, _ in test_by_playcount])
        test_genders = [items_gender[a] for a, _ in test_by_playcount]

        user_predicted = predicted_x[i]
        predicted_genders = [items_gender[a] for a in user_predicted]
        not_found = len(user_predicted) + 1  # sentinel rank when a gender is absent

        for gender in _GENDER_LABELS:
            hits = [a for a, g in zip(user_predicted, predicted_genders) if g == gender]
            predicted_by_gender[gender].append(hits)
            reco_counts[gender].update(hits)
            reco_counts_10[gender].update(
                a for a, g in zip(user_predicted[:10], predicted_genders[:10]) if g == gender)
            first_by_gender[gender].append(
                next((pos for pos, g in enumerate(predicted_genders) if g == gender), not_found))

        listened_share = _female_share(test_genders)
        if listened_share is not None:
            artist_gender_user.append(listened_share)
            # Recommendation share is only recorded for users that have a listened share.
            recommended_share = _female_share(predicted_genders)
            if recommended_share is not None:
                artist_gender_user_recommend.append(recommended_share)

    reco_set_count = _stack_counters(
        reco_counts['Female'], reco_counts['Male'], reco_counts['Non-binary'])
    reco_set_count_10 = _stack_counters(
        reco_counts_10['Female'], reco_counts_10['Male'], reco_counts_10['Non-binary'])

    return (reco_set_count, reco_set_count_10,
            artist_gender_user, artist_gender_user_recommend,
            first_by_gender['Female'], first_by_gender['Male'], first_by_gender['Non-binary'],
            predicted_by_gender['Female'], predicted_by_gender['Male'], predicted_by_gender['Non-binary'],
            fan_test_data_sorted)


def get_user_metrics(predicted_x, fan_test_data):
    """Compute per-list nDCG with LensKit's top-N list analysis.

    Parameters
    ----------
    predicted_x, fan_test_data
        Passed unchanged to ``RecListAnalysis.compute`` as the recommendations
        and the truth data. Presumably these are the frames LensKit expects —
        TODO confirm against the caller.

    Returns
    -------
    The object returned by ``RecListAnalysis.compute`` (a results frame with an
    ``ndcg`` column). Aggregation, e.g. the mean nDCG, is left to the caller.
    """
    rla = topn.RecListAnalysis()
    rla.add_metric(topn.ndcg)
    return rla.compute(predicted_x, fan_test_data)


def _add_coverage_and_gini(res, reco_counts, cutoff):
    """Add Coverage@<cutoff>_<group> and GINI@<cutoff>_<group> entries to *res*.

    *reco_counts* is a (4, 1) array of Counters whose rows are female, male,
    nonbinary and all artists, in that order. Coverage is the number of distinct
    artists in the group; GINI is computed over their recommendation counts.
    Keys are inserted group by group, in the same order every call.
    """
    for group, row in zip(('female', 'male', 'nonbinary', 'all'), reco_counts):
        counts = [count for counter in row for count in counter.values()]
        res[f'Coverage@{cutoff}_{group}'] = len(counts)
        res[f'GINI@{cutoff}_{group}'] = gini(counts)


def show_eval(predicted_x, fan_test_data, item_ids, items_gender, sum_listen, changes):
    """Compute, print and return one row of evaluation results.

    Parameters
    ----------
    predicted_x : sequence
        Per-user ranked recommendation lists (parallel to fan_test_data).
    fan_test_data : sequence
        Per-user iterables of (artist, playcount) pairs.
    item_ids, sum_listen
        Unused here; kept for interface compatibility.
    items_gender : mapping / indexable
        Artist id -> gender label.
    changes : sequence
        Run metadata. Elements 0-4 are stored as all_changes, non_zero_changes,
        zero_users, iter and method.

    Returns
    -------
    dict
        Metric name -> value. The insertion order is stable; save_eval relies on
        it for the CSV header. The '@100' keys are computed over the full
        predicted lists — presumably these hold 100 items; confirm with the
        caller. Ranking metrics from ``evaluate`` are stored as 4-decimal strings.
    """
    (reco_set_count, reco_set_count_10,
     artist_gender_user, artist_gender_user_recommend,
     artist_gender_first_female, artist_gender_first_male, artist_gender_first_nonbinary,
     _predicted_female, _predicted_male, _predicted_nonbinary, fan_test_data_sorted) \
        = get_reco_sets(fan_test_data, items_gender, predicted_x)

    res = {}
    res['Female_listened'] = np.mean(artist_gender_user)
    res['Female_recommended'] = np.mean(artist_gender_user_recommend)
    res['First_female'] = np.mean(artist_gender_first_female)
    res['First_male'] = np.mean(artist_gender_first_male)
    res['First_nonbinary'] = np.mean(artist_gender_first_nonbinary)

    # Coverage and Gini for female, male, nonbinary and all artists:
    # first over the full lists, then over the top 10.
    _add_coverage_and_gini(res, reco_set_count, 100)
    _add_coverage_and_gini(res, reco_set_count_10, 10)

    res['Coverage_on_FAN_test_set@100'] = coverage(fan_test_data_sorted, 100)
    res['all_changes'] = changes[0]
    res['non_zero_changes'] = changes[1]
    res['zero_users'] = changes[2]
    res['iter'] = changes[3]
    res['method'] = changes[4]

    # TODO: add per-user nDCG (get_user_metrics) and RBP (_bulk_rbp), in total
    # and per gender (female, male, nonbinary).

    metrics = ['map@10', 'precision@1', 'precision@3', 'precision@5', 'precision@10', 'r-precision', 'ndcg@10']
    results = evaluate(metrics, fan_test_data_sorted, predicted_x)
    for metric in metrics:
        res[metric] = '{:.4f}'.format(results[metric])

    print(res)
    return res


def save_eval(all_res, folder, name):
    """Write result dicts to ``<folder>/<name>_<YYYYmmdd-HHMMSS>.csv``.

    The header is taken from the keys of the first dict. Every row is written by
    key, so dicts whose keys come in a different order still line up with the
    header. A key missing from a row is written as an empty cell, and a key not
    in the header raises ValueError (from csv.DictWriter).

    Parameters
    ----------
    all_res : list of dict
        One dict per evaluation run (e.g. the output of show_eval).
    folder : str
        Output directory; created if it does not exist.
    name : str
        File-name prefix; a timestamp is appended.

    Returns
    -------
    str
        Path of the written CSV file.

    Raises
    ------
    ValueError
        If *all_res* is empty, since there is no header to write.
    """
    if not all_res:
        raise ValueError('all_res must contain at least one result dict')

    os.makedirs(folder, exist_ok=True)  # no race between exists() and makedirs()

    des_file = f"{name}_{time.strftime('%Y%m%d-%H%M%S')}.csv"
    path = os.path.join(folder, des_file)
    fieldnames = list(all_res[0].keys())

    # newline='' lets the csv module control line endings; without it rows are
    # separated by blank lines on Windows.
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(all_res)
    return path
