import pandas as pd
import numpy as np
import pickle
import os
import random

from scipy import sparse
from scipy.stats import binom
from gbsim.eval import show_eval
from lenskit import batch, util, Recommender

# Root directory for simulation output: each simulation writes its initial
# inputs to OUTPUT_DIR/<name>/ and per-step results to OUTPUT_DIR/<name>/<iteration>/.
OUTPUT_DIR = "./output"


class SimulationBase():
    """Iterated recommend -> post-process -> consume -> retrain simulation.

    At every step the current model recommends items to all users, a
    subclass-specific :meth:`post_process` re-ranks them, simulated users
    consume part of the re-ranked lists (:meth:`increase_count`) and the model
    is retrained on the enlarged interaction data. Results are written below
    ``OUTPUT_DIR/<name>``.

    Parameters
    ----------
    name : str
        Simulation name, also used as its output directory name.
    algo
        LensKit algorithm; a fresh clone is trained at every step.
    fan_train_data
        Sparse interaction matrix (rows = users, columns = items).
    fan_test_data, fan_items_dict
        Stored on the instance and pickled by :meth:`save_initial`.
    items_gender
        Indexable by item id, giving each item's gender label
        (e.g. ``"Male"`` / ``"Female"``).
    **kwargs
        ``choice_model`` (one of :attr:`CHOICE_MODELS`, default
        ``'InspectionAbandon'``), ``lambda_val`` (re-ranking strength used by
        some subclasses, default ``None``) and ``seed`` (seed for the random
        generator of the stochastic choice models, default ``None``).

    Raises
    ------
    ValueError
        If ``choice_model`` is not one of :attr:`CHOICE_MODELS`.
    """

    # User choice models understood by increase_count().
    CHOICE_MODELS = ('DeterministicTopN', 'RandomTopN', 'InspectionAbandon',
                     'InspectionAbandonBiasedProMale', 'InspectionAbandonBiasedProFemale')
    # Base probability that a user consumes an inspected item.
    CONSUME_PROB = 0.5
    # Probability of abandoning the rest of the ranking after consuming an item.
    ABANDON_PROB = 0.3
    # Multiplier applied to CONSUME_PROB for the favoured gender (i.e. 10% higher).
    GENDER_BIAS = 1.1
    # Favoured gender label of each biased choice model.
    FAVOURED_GENDER = {'InspectionAbandonBiasedProMale': 'Male',
                       'InspectionAbandonBiasedProFemale': 'Female'}

    def __init__(self, name, algo, fan_train_data, fan_test_data, fan_items_dict, items_gender, **kwargs):
        self.algo = algo
        self.name = name
        print('Init Simulation', self.name)
        self.fan_test_data = fan_test_data
        self.fan_train_data = fan_train_data.tocoo().copy()
        self.sum_listen = fan_train_data.sum(axis=0)
        self.fan_items_dict = fan_items_dict
        self.items_gender = items_gender
        self.choice_model = kwargs.get('choice_model', 'InspectionAbandon')
        if self.choice_model not in self.CHOICE_MODELS:
            # Fail fast, before writing any output or training the first model.
            raise ValueError(f"Unknown choice model {self.choice_model!r}; "
                             f"expected one of {self.CHOICE_MODELS}")
        self.lambda_val = kwargs.get('lambda_val', None)
        # Random generator of the stochastic choice models; pass ``seed`` for reproducible runs.
        self.rng = np.random.default_rng(kwargs.get('seed'))

        self.save_initial()
        self.curr_iter = 0
        self.model = self.train(self.fan_train_data, self.algo)

    def step(self):
        """Run one step of the simulation.

        Step 0 only predicts with the initially trained model and saves the
        predictions. Later steps post-process the predictions, simulate user
        consumption of them (:meth:`increase_count`), retrain the model on the
        enlarged data, and save the post-processed predictions together with
        the post-processing statistics.
        """
        # Generate predictions with the last trained model
        predicted = self.predict()

        if self.curr_iter > 0:
            predicted, ret_all, ret_nonz, ret_zero = self.post_process(predicted, self.items_gender)
            changes = [ret_all, ret_nonz, ret_zero, self.curr_iter, self.name]

            # Add the simulated interactions to the training data and retrain
            self.fan_train_data = self.increase_count(predicted)
            self.model = self.train(self.fan_train_data, self.algo)
        else:
            # Step 0 just records the initial model's predictions
            changes = [0, 0, 0, 0, self.name]

        self.save_iteration(predicted, changes)
        self.curr_iter += 1

    def post_process(self, predicted, items_gender, N=100):
        """Re-rank ``predicted``; must be implemented by subclasses.

        Implementations return ``(reranked, stat_all, stat_nonzero, zero_users)``;
        :meth:`step` stores the three statistics for the iteration.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement post_process()")

    def train(self, impl_train_data, algo):
        """Fit a fresh clone of ``algo`` on ``impl_train_data`` and return it.

        ``impl_train_data`` must expose ``row``, ``col`` and ``data`` (e.g. a
        scipy COO matrix with rows = users, columns = items).
        """
        model = Recommender.adapt(util.clone(algo))
        # NOTE(review): LensKit conventionally reads the value column as 'rating';
        # with 'ratings' the counts may be ignored (interactions treated as
        # implicit). Confirm this is intended before renaming.
        data = pd.DataFrame({'user': impl_train_data.row,
                             'item': impl_train_data.col, 'ratings': impl_train_data.data})
        model.fit(data)
        print('Trained step', self.curr_iter)
        return model

    def predict(self, N=100):
        """Return the current model's top-``N`` recommendations for every user.

        Returns
        -------
        numpy.ndarray
            ``(n_users, N)`` uint32 matrix of item ids in rank order. Slots for
            which the recommender returned no item keep the value 0, which is
            indistinguishable from item id 0.
        """
        n_user = self.fan_train_data.shape[0]
        preds = batch.recommend(self.model, np.arange(n_user), N, n_jobs=8)
        mtrx = np.zeros((n_user, N), dtype=np.uint32)
        # ranks are treated as 1-based, hence the -1
        mtrx[preds['user'], preds['rank']-1] = preds['item']
        print('Finished predictions step', self.curr_iter)
        return mtrx

    def _consumed_mask(self, topn):
        """Return a boolean array shaped like ``topn`` marking the items users consume.

        ``topn`` holds one ranked row of item ids per user. Choice models:

        * ``DeterministicTopN``: every item is consumed.
        * ``RandomTopN``: each item is consumed independently with
          probability ``CONSUME_PROB``.
        * ``InspectionAbandon``: items are inspected in rank order. Each is
          consumed with probability ``CONSUME_PROB``. After each consumed
          item, the user stops inspecting the ranking with probability
          ``ABANDON_PROB``.
        * ``InspectionAbandonBiasedProMale`` / ``...ProFemale``: like
          ``InspectionAbandon``, but the consumption probability is
          multiplied by ``GENDER_BIAS`` for items of the favoured gender.

        Raises
        ------
        ValueError
            If ``self.choice_model`` is not one of :attr:`CHOICE_MODELS`.
        """
        model = self.choice_model
        if model not in self.CHOICE_MODELS:
            raise ValueError(f"Unknown choice model {model!r}; expected one of {self.CHOICE_MODELS}")
        if model == 'DeterministicTopN':
            return np.ones(topn.shape, dtype=bool)

        consume_prob = np.full(topn.shape, self.CONSUME_PROB)
        favoured = self.FAVOURED_GENDER.get(model)
        if favoured is not None:
            is_favoured = np.array([self.items_gender[i] == favoured for i in topn.ravel()],
                                   dtype=bool).reshape(topn.shape)
            consume_prob[is_favoured] *= self.GENDER_BIAS
        consumed = self.rng.random(topn.shape) < consume_prob
        if model == 'RandomTopN':
            return consumed

        # Inspection models: after a consumed item the user abandons the rest of
        # the ranking with probability ABANDON_PROB. An item is only inspected if
        # no abandonment happened at an earlier position.
        abandon = consumed & (self.rng.random(topn.shape) < self.ABANDON_PROB)
        abandoned_before = np.cumsum(abandon, axis=1) - abandon
        return consumed & (abandoned_before == 0)

    def increase_count(self, predicted, items_ids=None, step=1000, M=10):
        """Simulate user consumption of ``predicted`` and add it to the training data.

        Each user inspects the first ``M`` items of their row in ``predicted``.
        A subset of them is consumed according to ``self.choice_model`` (see
        :meth:`_consumed_mask`). Every consumed item adds ``M`` to the user's
        count for that item.

        Parameters
        ----------
        predicted : numpy.ndarray
            Ranked item ids, one row per user (row index = user index).
        items_ids : sequence, optional
            When given, item ids in ``predicted`` are translated to their index
            in this sequence before being used as column indices.
        step : int
            Number of users processed per chunk. This bounds the size of the
            temporary sparse matrices.
        M : int
            Number of top-ranked items inspected per user; also the count added
            for each consumed item.

        Returns
        -------
        scipy.sparse.coo_matrix
            ``self.fan_train_data`` plus the simulated interactions of all users.

        Raises
        ------
        ValueError
            If ``self.choice_model`` is unknown.
        """
        # TODO: Try to add the mean of user's playcounts instead of a fixed M.
        n_users = self.fan_train_data.shape[0]
        items_ids_dict = None
        if items_ids is not None:
            items_ids_dict = {v: i for i, v in enumerate(items_ids)}

        fan_train_data = self.fan_train_data
        for start in range(0, n_users, step):
            stop = min(start + step, n_users)
            topn = np.asarray(predicted[start:stop, :M])
            user_idx, rank_idx = np.nonzero(self._consumed_mask(topn))
            rows = user_idx + start
            cols = topn[user_idx, rank_idx]
            if items_ids_dict is not None:
                cols = np.array([items_ids_dict[i] for i in cols], dtype=np.int64)
            mtrx_sum = sparse.csr_matrix(
                (np.full(rows.shape[0], M, dtype=np.float32), (rows, cols)),
                shape=self.fan_train_data.shape, dtype=np.float32)
            # Accumulate onto the running total so every chunk's interactions are kept.
            fan_train_data = fan_train_data + mtrx_sum
        return fan_train_data.tocoo()

    def save_initial(self):
        """Pickle the simulation's static inputs to ``OUTPUT_DIR/<name>/<key>.pkl``."""
        location = os.path.join(OUTPUT_DIR, self.name)
        os.makedirs(location, exist_ok=True)

        data_to_save = {"fan_test_data": self.fan_test_data, "sum_listen": self.sum_listen.tolist(),
                        "fan_items_dict": self.fan_items_dict, "items_gender": self.items_gender}
        for file_name, obj in data_to_save.items():
            with open(os.path.join(location, f"{file_name}.pkl"), 'wb') as fh:
                pickle.dump(obj, fh)

    def save_iteration(self, predicted, changes):
        """Save this iteration's predictions (``predicted.npz``, key ``data``) and statistics (``changes.pkl``)."""
        location = os.path.join(OUTPUT_DIR, self.name, str(self.curr_iter))
        os.makedirs(location, exist_ok=True)

        np.savez_compressed(os.path.join(location, "predicted.npz"), data=predicted)
        with open(os.path.join(location, "changes.pkl"), 'wb') as fh:
            pickle.dump(changes, fh)


class SimulationMoveUp(SimulationBase):
    """Simulation whose post-processing promotes each user's top-ranked female item to rank 1."""

    def post_process(self, list1, items_gender, N=100):
        """Move the most highly ranked female recommendation to the first position.

        Users without a female item in their first ``N`` recommendations, and
        users whose first recommendation is already female, are left unchanged.

        Parameters
        ----------
        list1 : numpy.ndarray
            Ranked item ids, one row per user, with at least ``N`` columns.
        items_gender
            Indexable by item id; an item is female when its label is ``"Female"``.
        N : int
            Length of the returned rankings; only the first ``N`` items are searched.

        Returns
        -------
        tuple
            ``(res, mean_all, mean_nonzero, zero_users)``:

            * ``res`` is the ``(n_users, N)`` uint32 re-ranked matrix.
            * ``mean_all`` is the fraction of users whose ranking changed.
            * ``mean_nonzero`` is the mean over the changed users (1.0, or NaN
              when no ranking changed).
            * ``zero_users`` is the number of unchanged users.
        """
        n_users = list1.shape[0]
        res = np.zeros((n_users, N), dtype=np.uint32)
        zero_users = 0
        ret_all = []
        ret_nonzero = []
        for u in range(n_users):
            ranking = list1[u, :N]
            # Position of the first female item, or None if there is none.
            # An explicit None sentinel keeps position 0 distinguishable from "not found".
            first_female = next(
                (i for i, item in enumerate(ranking) if items_gender[item] == "Female"), None)
            if first_female is None or first_female == 0:
                # nothing to move: no female item, or one is already ranked first
                res[u] = ranking
                zero_users += 1
                ret_all.append(0)
            else:
                # promote the female item and shift the items above it down by one
                res[u] = np.concatenate((ranking[first_female:first_female + 1],
                                         ranking[:first_female],
                                         ranking[first_female + 1:]))
                ret_all.append(1)
                ret_nonzero.append(1)

        return res, np.mean(ret_all), np.mean(ret_nonzero), zero_users


class SimulationRerank(SimulationBase):
    """Simulation whose post-processing demotes male items by ``lambda_val`` positions."""

    def post_process(self, list1, items_gender, N=100):
        """Move every male recommendation down the ranking by ``self.lambda_val`` positions.

        Each item is scored by its original position, plus ``self.lambda_val``
        when its gender label is ``"Male"``. The row is then stably sorted by
        score. Ties therefore keep the original order: a demoted male item
        lands just before the non-male item it ties with.

        Parameters
        ----------
        list1 : numpy.ndarray
            Ranked item ids, one row per user, with at least ``N`` columns; the
            whole row is re-ranked and the first ``N`` items are kept.
        items_gender
            Indexable by item id, giving each item's gender label.
        N : int
            Length of the returned rankings.

        Returns
        -------
        tuple
            ``(res, mean_all, mean_nonzero, zero_users)``:

            * ``res`` is the ``(n_users, N)`` uint32 re-ranked matrix.
            * ``mean_all`` is the mean number of male items in the users'
              original top 10.
            * ``mean_nonzero`` is the same mean restricted to users with at
              least one such item.
            * ``zero_users`` counts the users with none.

        Raises
        ------
        ValueError
            If ``self.lambda_val`` is ``None``.
        """
        if self.lambda_val is None:
            raise ValueError("SimulationRerank requires the 'lambda_val' keyword argument")

        n_users = list1.shape[0]
        res = np.zeros((n_users, N), dtype=np.uint32)
        ret_all = []
        ret_non_zero = []
        zero_users = 0
        for u in range(n_users):
            ranking = list1[u, :]
            is_male = np.array([items_gender[track] == "Male" for track in ranking], dtype=bool)
            scores = np.arange(ranking.shape[0]) + self.lambda_val * is_male
            # stable sort: equal scores keep their original relative order.
            # Works per position, so duplicate item ids in a row are harmless.
            order = np.argsort(scores, kind="stable")
            res[u] = ranking[order][:N]

            counter = int(np.count_nonzero(is_male[:10]))
            ret_all.append(counter)
            if counter == 0:
                zero_users += 1
            else:
                ret_non_zero.append(counter)

        return res, np.mean(ret_all), np.mean(ret_non_zero), zero_users


class SimulationFAIR(SimulationBase):
    """Simulation whose post-processing re-ranks recommendations with FA*IR targeting gender parity."""

    # Significance level of the ranked group fairness test. 0.0207 is the
    # corrected alpha for 100-item lists targeting protected proportion 0.5.
    FAIR_ALPHA = 0.0207
    # Target proportion of protected (non-male) items.
    FAIR_P = 0.5
    # Minimum number of candidates requested from the model, so the re-ranker
    # has protected items from beyond the top N to promote.
    N_CANDIDATES = 1000

    def predict(self, N=100):
        """Predict a longer candidate list (at least ``N_CANDIDATES`` items per user)."""
        return super().predict(max(N, self.N_CANDIDATES))

    def post_process(self, list1, items_gender, N=100):
        """Re-rank each user's candidates with FA*IR, targeting gender equality.

        Items whose gender label is not ``"Male"`` are protected. Let ``k`` run
        over ``1..N``. Whenever the number of protected items in the top ``k``
        falls below ``binom.ppf(FAIR_ALPHA, k, FAIR_P)``, the next protected
        candidate further down the list is moved up to position ``k``. The
        items in between shift down by one. When no protected candidate
        remains, the rest of the ranking is left as-is.

        Parameters
        ----------
        list1 : numpy.ndarray
            Ranked candidate item ids, one row per user, with at least ``N`` columns.
        items_gender
            Indexable by item id, giving each item's gender label.
        N : int
            Length of the returned rankings.

        Returns
        -------
        tuple
            ``(res, mean_moves, mean_moves_nonzero, zero_users)``:

            * ``res`` is the ``(n_users, N)`` uint32 re-ranked matrix.
            * ``mean_moves`` is the mean number of promotions per user.
            * ``mean_moves_nonzero`` is the same mean restricted to users with
              at least one promotion.
            * ``zero_users`` is the number of users that needed none.
        """
        n_users = list1.shape[0]
        res = np.zeros((n_users, N), dtype=np.uint32)
        ret_all = []
        ret_non_zero = []
        zero_users = 0

        # minimum number of protected items required in each top-k prefix, k = 1..N
        thresholds = binom.ppf(self.FAIR_ALPHA, np.arange(1, N + 1), self.FAIR_P)

        for u in range(n_users):
            outrank = list1[u, :].copy()
            protected = np.array([items_gender[track] != "Male" for track in outrank], dtype=bool)

            moves = 0
            start = 0  # every prefix shorter than start + 1 is known to pass
            while start < N:
                cum_prot = np.cumsum(protected[:N])
                fails, = np.nonzero(cum_prot[start:] < thresholds[start:])
                if len(fails) == 0:
                    break
                bad = start + fails[0]

                # The threshold grows by at most one per position, so the item at
                # `bad` is unprotected. Take the next protected candidate below it.
                candidates, = np.nonzero(protected[bad + 1:])
                if len(candidates) == 0:
                    # no more protected items, leave the rest of the ranking as-is
                    break
                pos = bad + 1 + candidates[0]

                # Read from the current (already shifted) ranking, then move the
                # items bad..pos-1 down one place and put the protected item at `bad`.
                item = outrank[pos]
                outrank[bad + 1:pos + 1] = outrank[bad:pos]
                protected[bad + 1:pos + 1] = protected[bad:pos]
                outrank[bad] = item
                protected[bad] = True
                moves += 1

                # Prefix `bad` now passes and later moves never lower a prefix
                # count. Advancing past it guarantees termination.
                start = bad + 1

            res[u] = outrank[:N]
            ret_all.append(moves)
            if moves:
                ret_non_zero.append(moves)
            else:
                zero_users += 1

        return res, np.mean(ret_all), np.mean(ret_non_zero), zero_users
