import csv
import json
import datetime
import os
import pickle
import pandas as pd
import numpy as np
from collections import defaultdict
from scipy import sparse

# Root folder of the data on this machine; edit it to relocate every path below.
DATA_PATH = './data/'
# Input files, all expected under DATA_PATH/lfm-2b/.
dataset_location = f'{DATA_PATH}lfm-2b/listening-events.tsv'
gender_location = f'{DATA_PATH}lfm-2b/album_reduced_data.json'
albums_data_file = f'{DATA_PATH}lfm-2b/albums.tsv'
artists_data_file = f'{DATA_PATH}lfm-2b/artists.tsv'

_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"


def _iter_events(interactions_file):
    """Yield every data row of the tab-separated events file, skipping its header."""
    # The with-block guarantees the file is closed once the rows are consumed
    # (the original opened the file twice and never closed it).
    with open(interactions_file, newline='') as tsv_file:
        reader = csv.reader(tsv_file, delimiter="\t")
        next(reader, None)  # skip the headers
        yield from reader


def _find_split_date(interactions_file, album_data, years, test_size):
    """Return the datetime at fraction test_size of the matching events' time span."""
    first_date = None
    last_date = None
    include_line = 0
    exclude_line = 0
    for row in _iter_events(interactions_file):
        if row[3][:4] not in years:
            continue
        if row[2] not in album_data:
            # Inside the requested years, but the album has no label: ignored.
            exclude_line += 1
            continue
        include_line += 1
        listened_at = datetime.datetime.strptime(row[3], _TIMESTAMP_FORMAT)
        if first_date is None or listened_at < first_date:
            first_date = listened_at
        if last_date is None or listened_at > last_date:
            last_date = listened_at

    print("lines included", include_line)
    print("excluded lines in period", exclude_line)
    if first_date is None:
        raise ValueError(
            f"no events in {interactions_file!r} fall in years {list(years)} "
            "with an album present in album_data")
    # Plain timedelta arithmetic on the naive timestamps: converting through
    # .timestamp()/fromtimestamp() would reinterpret them as local time and
    # can shift the split date across DST transitions.
    split_date = first_date + (last_date - first_date) * test_size
    print("Split date:", split_date)
    return split_date


def _count_plays(interactions_file, album_data, years, split_date):
    """Count plays per (user, track) before/after split_date and label each track."""
    train_dict = {}
    test_dict = {}
    # Track id -> label of the album on the last matching event read for it.
    map_track_gender = defaultdict(int)
    for row in _iter_events(interactions_file):
        user_id, track_id, album_id, timestamp = row[:4]
        if timestamp[:4] not in years or album_id not in album_data:
            continue
        map_track_gender[track_id] = album_data[album_id]
        listened_at = datetime.datetime.strptime(timestamp, _TIMESTAMP_FORMAT)
        counts = train_dict if listened_at < split_date else test_dict
        user_counts = counts.setdefault(user_id, {})
        user_counts[track_id] = user_counts.get(track_id, 0) + 1
    return train_dict, test_dict, map_track_gender


def split(test_size, interactions_file, album_data, years=("2020",), minimum_user_listens=2, minimum_track_listens=2):
    """Split listening events chronologically into train and test interactions.

    Only events whose year (first four characters of the timestamp column) is
    in ``years`` and whose album id is a key of ``album_data`` are used. The
    split date lies at fraction ``test_size`` of the way from the earliest to
    the latest such event; earlier events are training data, the rest test
    data. Despite its name, ``test_size`` is therefore the *training* share of
    the time span (0.9 -> the first 90% of the span is training).

    Columns are read positionally as user id, track id, album id, timestamp
    (``%Y-%m-%d %H:%M:%S``); the first line of the file is skipped as a header.

    Args:
        test_size: fraction of the time span placed before the split date.
        interactions_file: path to the tab-separated events file.
        album_data: mapping album id -> label; also the source of the labels
            in ``map_track_gender``.
        years: collection of four-character year strings to keep.
        minimum_user_listens: minimum number of distinct training tracks (after
            track filtering) a user needs to be kept.
        minimum_track_listens: minimum number of distinct training users
            (counted before user filtering) a track needs to be kept.

    Returns:
        Tuple ``(play_train, row_train, col_train, test_data, tracks_dict,
        users_dict, map_track_gender)``. The first three are parallel lists of
        play counts / user rows / track columns of the training matrix.
        ``test_data[r]`` is a list of ``(track column, play count)`` pairs for
        user row ``r``. ``tracks_dict`` / ``users_dict`` map raw ids to column /
        row ids. ``map_track_gender`` maps every track id seen on a matching
        event to the label of its last such event's album. A test track that
        passes the popularity filter can receive a column id even when no
        retained user played it in training.

    Raises:
        ValueError: if no event matches ``years`` and ``album_data``.
    """
    split_date = _find_split_date(interactions_file, album_data, years, test_size)
    train_dict, test_dict, map_track_gender = _count_plays(
        interactions_file, album_data, years, split_date)

    # Number of distinct training users per track (not the total play count).
    tracks_total_counts = defaultdict(int)
    for user_tracks in train_dict.values():
        for track_id in user_tracks:
            tracks_total_counts[track_id] += 1

    play_train = []
    row_train = []
    col_train = []
    test_data = []
    tracks_dict = {}
    users_dict = {}
    for u_id, user_tracks in train_dict.items():
        kept = {track_id: plays for track_id, plays in user_tracks.items()
                if tracks_total_counts[track_id] >= minimum_track_listens}
        if len(kept) < minimum_user_listens:
            continue

        user_row = len(users_dict)
        users_dict[u_id] = user_row
        for track_id, plays in kept.items():
            # setdefault hands unseen tracks the next free column id.
            col_train.append(tracks_dict.setdefault(track_id, len(tracks_dict)))
            row_train.append(user_row)
            play_train.append(plays)

        test_u = []
        for track_id, plays in test_dict.get(u_id, {}).items():
            # Only tracks seen in training and passing the popularity filter
            # are kept for evaluation.
            if track_id in tracks_total_counts and tracks_total_counts[track_id] >= minimum_track_listens:
                test_u.append((tracks_dict.setdefault(track_id, len(tracks_dict)), plays))
        test_data.append(test_u)
    return play_train, row_train, col_train, test_data, tracks_dict, users_dict, map_track_gender


def get_gender_data(albums_data, album_map):
    """Assign one gender label to as many albums of ``album_map`` as possible.

    Args:
        albums_data: mapping album id (string) -> list of credit dicts; each
            dict may carry a ``'type'`` and/or a ``'gender'`` key.
        album_map: mapping album id (int) -> dict with an ``'artist_id'`` key.

    Returns:
        ``defaultdict(int)`` mapping album id (as ``str``) to a gender. An album
        is labelled with its own gender when its credits contain exactly one
        distinct gender other than ``'Not applicable'``. Otherwise it falls back
        to the gender of the last album (in ``albums_data`` order) by the same
        artist whose credits list exactly one gender and have types exactly
        ``['Person']``. Albums matching neither rule are absent. Albums of
        ``albums_data`` missing from ``album_map`` are ignored.
    """
    # artist id -> [(album id, genders, types), ...] in albums_data order.
    artist_data_gender = defaultdict(list)
    for album_key, credits in albums_data.items():
        album_info = album_map.get(int(album_key))
        if album_info is None:
            # No artist mapping for this album (e.g. dropped by the caller's
            # merge); previously this raised KeyError and aborted the run.
            continue
        genders = [credit['gender'] for credit in credits if 'gender' in credit]
        types = [credit['type'] for credit in credits if 'type' in credit]
        artist_data_gender[album_info['artist_id']].append((album_key, genders, types))

    # Computed once per album / per artist instead of once per album of the
    # same artist on every album_map entry (the old loop was quadratic).
    album_genders = {}
    artist_single_person_gender = {}
    for artist_id, artist_albums in artist_data_gender.items():
        for album_key, genders, types in artist_albums:
            album_genders[album_key] = [g for g in set(genders) if g != 'Not applicable']
            if len(genders) == 1 and types == ['Person']:
                # Last qualifying album wins, as in the original loop.
                artist_single_person_gender[artist_id] = genders[0]

    final_album_data_gender = defaultdict(int)
    for album_id, album_info in album_map.items():
        str_album_id = str(album_id)
        own_genders = album_genders.get(str_album_id)
        if own_genders is not None and len(own_genders) == 1:
            final_album_data_gender[str_album_id] = own_genders[0]
            continue
        fallback = artist_single_person_gender.get(album_info['artist_id'])
        if fallback is not None:
            final_album_data_gender[str_album_id] = fallback
    return final_album_data_gender

def save_files(fan_data_play, fan_row_train, fan_col_train, gt, track_gender, output_folder):
    """Write the train/test interactions and track labels into ``output_folder``.

    Files written (tab-separated where applicable):
        train.tsv      ``user, item, plays`` for each triple zipped from
                       ``fan_row_train``, ``fan_col_train``, ``fan_data_play``
        test.tsv       ``user, item, plays`` where ``user`` is the index into
                       ``gt`` and ``gt[user]`` is a list of ``(item, plays)``
        features.tsv   ``track, label`` for each entry of ``track_gender``
        features.json  ``track_gender`` serialized as a JSON object

    The folder (and any missing parents) is created if needed.
    """
    os.makedirs(output_folder, exist_ok=True)
    # The csv module expects newline='' so it controls line endings itself
    # (otherwise Windows gets an extra blank line after every row).
    with open(os.path.join(output_folder, 'train.tsv'), 'w', newline='') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t')
        for p, u, i in zip(fan_data_play, fan_row_train, fan_col_train):
            writer.writerow([u, i, p])
    with open(os.path.join(output_folder, 'test.tsv'), 'w', newline='') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t')
        for u, user_items in enumerate(gt):
            for i, p in user_items:
                writer.writerow([u, i, p])
    with open(os.path.join(output_folder, 'features.tsv'), 'w', newline='') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t')
        for t, g in track_gender.items():
            writer.writerow([t, g])
    # Use a with-block so the JSON is flushed and the handle closed deterministically.
    with open(os.path.join(output_folder, 'features.json'), 'w') as jsonfile:
        json.dump(track_gender, jsonfile)


if __name__ == "__main__":
    with open(gender_location) as gender_file:
        albums_data = json.load(gender_file)
    albums = pd.read_csv(albums_data_file, sep='\t', names=['album_id', 'album_name', 'artist_name'], on_bad_lines='warn')
    artists = pd.read_csv(artists_data_file, sep='\t', names=['artist_id', 'artist_name'], on_bad_lines='warn')
    # DataFrame.merge on a column ignores the frames' indexes, so album_id must
    # survive the merge as a column and only become the index afterwards
    # (setting it as the index first left album_map keyed by row position).
    # Passing names= makes read_csv treat every line as data; coercing the ids
    # drops a header line if the files have one, and gives the int keys that
    # get_gender_data looks up. to_dict('index') needs a unique index, so the
    # first artist match per album is kept.
    album_merged = (
        albums.merge(artists, on='artist_name')
        .assign(album_id=lambda df: pd.to_numeric(df['album_id'], errors='coerce'))
        .dropna(subset=['album_id'])
        .astype({'album_id': 'int64'})
        .drop_duplicates(subset='album_id')
        .set_index('album_id')
    )
    album_map = album_merged[['artist_id']].to_dict('index')
    albums_gender = get_gender_data(albums_data, album_map)

    play_train, row_train, col_train, test_data, tracks_dict, users_dict, map_track_gender = split(
        0.9, dataset_location, albums_gender,
        years=["2013", "2014", "2015", "2016", "2017", "2018", "2019", "2020"])

    # Keep labels only for tracks that received a column id.
    final_track_gender = {t: g for t, g in map_track_gender.items() if t in tracks_dict}

    # Single output folder for every artifact (previously the TSVs ignored DATA_PATH).
    output_folder = os.path.join(DATA_PATH, 'lfm2b')
    save_files(play_train, row_train, col_train, test_data, final_track_gender, output_folder)

    # Explicit shape keeps the matrix aligned with users_dict/items_dict even when
    # the highest ids only occur in the test data.
    train_play = sparse.coo_matrix(
        (play_train, (row_train, col_train)),
        shape=(len(users_dict), len(tracks_dict)),
        dtype=np.float32)
    sparse.save_npz(os.path.join(output_folder, 'train_data_playcount.npz'), train_play)
    with open(os.path.join(output_folder, 'test_data.pkl'), 'wb') as fh:
        pickle.dump(test_data, fh)
    with open(os.path.join(output_folder, 'items_dict.pkl'), 'wb') as fh:
        pickle.dump(tracks_dict, fh)
    with open(os.path.join(output_folder, 'users_dict.pkl'), 'wb') as fh:
        pickle.dump(users_dict, fh)

