from pathlib import Path
import csv
import json
import datetime
import os
import pickle
import pandas as pd
import numpy as np
from collections import defaultdict
from scipy import sparse
import bz2
import zstandard

# Directory holding the raw LFM-2b dumps, and the directory the derived
# artist-level dataset is written to.
src_dir = Path('data/lfm-2b')
dst_dir = Path('data/lfm-2b-artists')

# Input files, all located inside ``src_dir``.
dataset_location = src_dir.joinpath('listening-events.tsv.bz2')  # listening events read by split()
gender_location = src_dir.joinpath('album_reduced_data.json.zst')  # per-album metadata fed to get_gender_data()
albums_data_file = src_dir.joinpath('albums.tsv.bz2')  # album table loaded with pandas
artists_data_file = src_dir.joinpath('artists.tsv.bz2')  # artist table loaded with pandas

_EVENT_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"


def _read_events(interactions_file, years):
    """Yield ``(user_id, album_id, timestamp_text)`` for events whose year is in *years*.

    The bz2-compressed TSV is opened on every call and closed as soon as the
    generator is exhausted, so no file handle outlives a pass over the data.
    """
    with bz2.open(interactions_file, "rt") as tsv_file:
        reader = csv.reader(tsv_file, delimiter="\t")
        next(reader, None)  # skip the headers
        for row in reader:
            # The first four characters of the timestamp column are the year.
            if row[3][:4] in years:
                yield row[0], int(row[2]), row[3]


def split(
    test_size,
    interactions_file,
    artist_data,
    album_map,
    years=("2020",),
    minimum_user_listens=15,
    minimum_artist_listens=15,
):
    """Build a temporal train/test split of user-artist play counts.

    The interactions file is read twice: the first pass finds the earliest and
    latest timestamp among the usable events, the second pass aggregates play
    counts on either side of the split date.  An event is usable when its year
    is in ``years``, its album id is a key of ``album_map`` and the mapped
    artist id is contained in ``artist_data``.

    Args:
        test_size: Fraction of the covered time span (measured from the first
            usable event) that ends up in the *training* set; events at or
            after ``first + (last - first) * test_size`` go to the test set.
        interactions_file: Path of the bz2-compressed, tab-separated event file
            with a header row; column 0 is the user id, column 2 the album id
            and column 3 a ``YYYY-MM-DD HH:MM:SS`` timestamp.
        artist_data: Container of artist ids; only membership is checked.
        album_map: Mapping ``album_id -> {'artist_id': ...}``.
        years: Collection of four-digit year strings to keep.
        minimum_user_listens: Minimum number of distinct artists (after artist
            filtering) a user needs in the training set to be kept.
        minimum_artist_listens: Minimum number of distinct training users an
            artist needs to be kept.

    Returns:
        A tuple ``(play_train, row_train, col_train, test_data, artists_dict,
        users_dict)``.  The first three are parallel lists of play counts,
        user indices and artist indices for the training set.  ``test_data``
        holds, per user index, a list of ``(artist_index, play_count)`` pairs.
        ``artists_dict`` and ``users_dict`` map raw ids to the indices used.

    Raises:
        ValueError: If no event passes the filters, so no split date exists.
    """
    first_date = None
    last_date = None
    include_line = 0
    exclude_line = 0
    exclude_line2 = 0

    # Pass 1: time range of the usable events plus some bookkeeping counters.
    for _, album_id, stamp in _read_events(interactions_file, years):
        if album_id not in album_map:
            exclude_line += 1
            continue
        if album_map[album_id]['artist_id'] not in artist_data:
            exclude_line2 += 1
            continue
        include_line += 1
        lt = datetime.datetime.strptime(stamp, _EVENT_TIME_FORMAT)
        if first_date is None or lt < first_date:
            first_date = lt
        if last_date is None or lt > last_date:
            last_date = lt

    print ("lines included", include_line)
    print ("excluded lines in period", exclude_line)
    print ("excluded2 lines in period", exclude_line2)
    if first_date is None:
        raise ValueError(
            "no listening events matched the requested years and album/artist filters"
        )
    # Plain timedelta arithmetic keeps the split date independent of the local
    # timezone and DST rules of the machine running the script.
    split_date = first_date + (last_date - first_date) * test_size
    print ("Split date:", split_date)

    # Pass 2: per-user play counts per artist on each side of the split date.
    train_dict = {}
    test_dict = {}
    for user_id, album_id, stamp in _read_events(interactions_file, years):
        if album_id not in album_map:
            continue
        artist_id = album_map[album_id]['artist_id']
        if artist_id not in artist_data:
            continue
        lt = datetime.datetime.strptime(stamp, _EVENT_TIME_FORMAT)
        target = train_dict if lt < split_date else test_dict
        user_plays = target.setdefault(user_id, {})
        user_plays[artist_id] = user_plays.get(artist_id, 0) + 1

    # Number of distinct training users per artist (not the number of plays).
    artists_total_counts = {}
    for user_plays in train_dict.values():
        for artist in user_plays:
            artists_total_counts[artist] = artists_total_counts.get(artist, 0) + 1

    col_train = []
    row_train = []
    play_train = []
    test_data = []
    artist_ids = []
    artists_dict = {}
    users_dict = {}
    for u_id, user_plays in train_dict.items():
        # Drop artists that too few training users listened to.
        kept_plays = {
            artist: play
            for artist, play in user_plays.items()
            if artists_total_counts[artist] >= minimum_artist_listens
        }
        if len(kept_plays) < minimum_user_listens:
            continue

        user_index = len(users_dict)
        users_dict[u_id] = user_index
        for item, play in kept_plays.items():
            if item not in artists_dict:
                artists_dict[item] = len(artist_ids)
                artist_ids.append(item)
            col_train.append(artists_dict[item])
            row_train.append(user_index)
            play_train.append(play)

        # test_data is appended once per kept user, so its positions line up
        # with the indices stored in users_dict.
        test_u = []
        if u_id in test_dict:
            for item, play in test_dict[u_id].items():
                if item in artists_total_counts and artists_total_counts[item] >= minimum_artist_listens:
                    # A test artist may not have been indexed yet if it only
                    # occurs in the training data of filtered-out users.
                    if item not in artists_dict:
                        artists_dict[item] = len(artist_ids)
                        artist_ids.append(item)
                    test_u.append((artists_dict[item], play))
        test_data.append(test_u)
    return play_train, row_train, col_train, test_data, artists_dict, users_dict


def get_gender_data(albums_data, album_map):
    """Derive one gender label per artist from per-album metadata.

    Args:
        albums_data: Mapping whose keys are album ids (anything ``int()``
            accepts) and whose values are iterables of dicts that may carry a
            ``'type'`` and/or a ``'gender'`` entry.
        album_map: Mapping ``album_id -> {'artist_id': ...}``.

    Returns:
        A ``defaultdict(int)`` mapping artist id to ``'Male'``, ``'Female'`` or
        ``'Non-binary'`` (used for any other gender value).  An artist is only
        included when at least one of its albums lists exactly one gender and
        its types are exactly ``['Person']``; if several albums qualify, the
        last one in ``albums_data`` order wins.

    Raises:
        KeyError: If an album id in ``albums_data`` is missing from
            ``album_map``.
    """
    # Group the (genders, types) found on each album by the album's artist.
    albums_by_artist = defaultdict(list)
    for album_key, entries in albums_data.items():
        artist_id = album_map[int(album_key)]['artist_id']
        genders = [entry['gender'] for entry in entries if 'gender' in entry]
        types = [entry['type'] for entry in entries if 'type' in entry]
        albums_by_artist[artist_id].append((genders, types))

    final_artist_data_gender = defaultdict(int)
    handled_artists = set()
    # Walk album_map so artists are inserted in first-seen album order, but
    # evaluate each artist only once instead of once per album.
    for album_info in album_map.values():
        artist_id = album_info['artist_id']
        if artist_id in handled_artists:
            continue
        handled_artists.add(artist_id)
        if artist_id not in albums_by_artist:
            continue

        single_gender = None
        for genders, types in albums_by_artist[artist_id]:
            if len(genders) == 1 and types == ['Person']:
                single_gender = genders[0]
        if single_gender is None:
            continue
        if single_gender in ('Male', 'Female'):
            final_artist_data_gender[artist_id] = single_gender
        else:
            final_artist_data_gender[artist_id] = 'Non-binary'
    return final_artist_data_gender

def save_files(fan_data_play, fan_row_train, fan_col_train, gt, artist_gender, output_folder):
    """Write the split and the artist features as zstd-compressed text files.

    Files created inside ``output_folder``:

    * ``train.tsv.zst``    - one ``user, item, play_count`` row per training entry
    * ``test.tsv.zst``     - one ``user, item, play_count`` row per test entry
    * ``features.tsv.zst`` - one ``key, value`` row per entry of ``artist_gender``
    * ``features.json.zst`` - ``artist_gender`` dumped as JSON

    Args:
        fan_data_play: Training play counts.
        fan_row_train: Training user indices, parallel to ``fan_data_play``.
        fan_col_train: Training item indices, parallel to ``fan_data_play``.
        gt: Sequence whose position is the user index and whose elements are
            iterables of ``(item, play_count)`` pairs.
        artist_gender: Mapping written both as TSV rows and as JSON.
        output_folder: ``pathlib.Path`` of the destination directory; it is
            created (including missing parents) when it does not exist.
    """
    output_folder.mkdir(parents=True, exist_ok=True)

    with zstandard.open(output_folder / 'train.tsv.zst', 'w') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t')
        for play, user, item in zip(fan_data_play, fan_row_train, fan_col_train):
            writer.writerow([user, item, play])

    with zstandard.open(output_folder / 'test.tsv.zst', 'w') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t')
        for user, user_items in enumerate(gt):
            for item, play in user_items:
                writer.writerow([user, item, play])

    with zstandard.open(output_folder / 'features.tsv.zst', 'w') as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t')
        for artist, gender in artist_gender.items():
            writer.writerow([artist, gender])

    # Close the stream explicitly so the compressed frame is finalised before
    # the function returns instead of whenever the object gets collected.
    with zstandard.open(output_folder / 'features.json.zst', 'w') as jsonfile:
        json.dump(artist_gender, jsonfile)


if __name__ == "__main__":
    # Per-album metadata used to derive the artist gender labels; the context
    # manager makes sure the compressed stream is closed after loading.
    with zstandard.open(gender_location, 'r') as albums_file:
        albums_data = json.load(albums_file)

    albums = pd.read_csv(albums_data_file, sep='\t', names=['album_id', 'album_name', 'artist_name'], on_bad_lines='warn')
    artists = pd.read_csv(artists_data_file, sep='\t', names=['artist_id', 'artist_name'], on_bad_lines='warn')
    # NOTE(review): a column-on-column merge (on='artist_name') ignores the
    # index set by set_index('album_id'), so album_map below appears to be
    # keyed by the merged row position rather than by album id - confirm this
    # is intended before relying on the album -> artist lookup.
    album_merged = albums.set_index('album_id').merge(artists, on='artist_name')
    album_map = album_merged[['artist_id']].to_dict('index')
    artist_gender = get_gender_data(albums_data, album_map)

    # The first argument is the fraction of the covered time span that goes
    # into the training set.
    play_train, row_train, col_train, test_data, artists_dict, users_dict = split(
        0.9,
        dataset_location,
        artist_gender,
        album_map,
        years=["2013", "2014", "2015", "2016", "2017", "2018", "2019", "2020"],
    )

    save_files(play_train, row_train, col_train, test_data, artist_gender, dst_dir)

    # Give the matrix an explicit shape: artists_dict can contain items that
    # only occur in the test data, so inferring the shape from the training
    # indices could yield fewer columns than there are indexed items.
    train_play = sparse.coo_matrix(
        (play_train, (row_train, col_train)),
        shape=(len(users_dict), len(artists_dict)),
        dtype=np.float32,
    )
    sparse.save_npz(dst_dir / 'train_data_playcount.npz', train_play)
    with zstandard.open(dst_dir / 'test_data.pkl.zst', 'wb') as outf:
        pickle.dump(test_data, outf)
    with zstandard.open(dst_dir / 'items_dict.pkl.zst', 'wb') as outf:
        pickle.dump(artists_dict, outf)
    with zstandard.open(dst_dir / 'users_dict.pkl.zst', 'wb') as outf:
        pickle.dump(users_dict, outf)

