import pdb
from torch.nn.init import xavier_uniform_
from torch.utils.data import TensorDataset
import numpy as np
import logging
import os
import random
import torch
import time
from tqdm import tqdm
from _utils import *

logger = logging.getLogger(__name__)

def load_and_cache_gen_data(args, filename, pool, tokenizer, split_tag, only_src=False, is_sample=False):
    """Read the examples in ``filename`` and build or load their tensor dataset.

    Examples are read with ``read_examples`` and their length statistics are
    logged through ``calc_stats`` (tokenizer-based statistics are included for
    the ``'train'`` and ``'test'`` splits). If the cache file for this split
    exists and ``is_sample`` is false, the dataset is loaded from it with
    ``torch.load``. Otherwise every example is converted with
    ``convert_examples_to_features`` through ``pool.map`` and packed into a
    ``TensorDataset``.

    Args:
        args: Configuration object. This function reads ``data_num``,
            ``cache_path``, ``task`` and ``local_rank`` from it and forwards
            the whole object to the feature converter.
        filename: Path handed to ``read_examples``.
        pool: Object exposing ``map(func, iterable)``; presumably a
            multiprocessing pool -- confirm against callers.
        tokenizer: Forwarded to ``calc_stats`` and to the feature converter.
        split_tag: Split name. It is part of the cache file name, and
            ``'test'`` produces a source-only dataset.
        only_src: When true, the dataset holds source ids only and the cache
            file name gets a ``_src`` suffix.
        is_sample: When true, at most 5000 examples are drawn at random and
            the cache is neither read nor written.

    Returns:
        A tuple ``(examples, data)`` where ``examples`` is the (possibly
        sampled) example list and ``data`` is the dataset built here or the
        object loaded from the cache file.
    """
    max_sample_size = 5000

    data_tag = '_all' if args.data_num == -1 else '_%d' % args.data_num
    cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + ('_src' if only_src else '') + data_tag)

    examples = read_examples(filename, args.data_num, args.task)

    if is_sample:
        examples = random.sample(examples, min(max_sample_size, len(examples)))

    # Statistics are logged on every call, even when the tensors come from the cache.
    if split_tag in ('train', 'test'):
        calc_stats(examples, tokenizer, is_tokenize=True)
    else:
        calc_stats(examples)

    if os.path.exists(cache_fn) and not is_sample:
        logger.info("Load cache data from %s", cache_fn)
        # NOTE(review): torch.load unpickles the cache file. Recent PyTorch
        # releases default to weights_only=True, which may reject a pickled
        # TensorDataset -- confirm against the pinned torch version.
        return examples, torch.load(cache_fn)

    if is_sample:
        # Report the number of examples actually drawn (the old message claimed 10k).
        logger.info("Sample %d data for evaluation from %s", len(examples), filename)
    else:
        logger.info("Create cache data into %s", cache_fn)

    tuple_examples = [(example, idx, tokenizer, args, split_tag) for idx, example in enumerate(examples)]
    features = pool.map(convert_examples_to_features, tqdm(tuple_examples, total=len(tuple_examples)))
    all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
    if split_tag == 'test' or only_src:
        data = TensorDataset(all_source_ids)
    else:
        all_target_ids = torch.tensor([f.target_ids for f in features], dtype=torch.long)
        data = TensorDataset(all_source_ids, all_target_ids)

    # Only the main process (local_rank -1 or 0) writes the cache; sampled data is never cached.
    if args.local_rank in [-1, 0] and not is_sample:
        torch.save(data, cache_fn)
    return examples, data

def read_examples(filename, data_num, task):
    """Load examples from ``filename`` with the reader registered for ``task``.

    Args:
        filename: Passed through to the selected reader.
        data_num: Passed through to the selected reader.
        task: Task name; must be ``'bigfixes_task'`` or ``'bigfixes_base'``.

    Returns:
        Whatever the selected reader returns.

    Raises:
        KeyError: If ``task`` is not a registered task name.
    """
    # Both task variants share one reader (not defined in this file;
    # presumably provided by the ``_utils`` star import).
    readers_by_task = dict.fromkeys(('bigfixes_task', 'bigfixes_base'), read_bigfixes_examples)
    reader = readers_by_task[task]
    return reader(filename, data_num)


def _mean_and_max(lengths):
    """Return ``(mean, max)`` of ``lengths``, or ``(0, 0)`` when it is empty."""
    if not lengths:
        return 0, 0
    return np.mean(lengths), max(lengths)


def calc_stats(examples, tokenizer=None, is_tokenize=False):
    """Log average and maximum source/target lengths of ``examples``.

    Whitespace-split lengths are always logged. When ``is_tokenize`` is true a
    second line with lengths measured by ``tokenizer.tokenize`` is logged too.
    Averages are formatted with ``%d`` and therefore truncated to integers.
    An empty ``examples`` is reported with all values as 0 instead of raising
    from ``max()`` on an empty list.

    Args:
        examples: Sized collection of objects with ``source`` (a string) and
            ``target`` (converted with ``str()`` before measuring).
        tokenizer: Object with a ``tokenize(text)`` method; only used, and
            therefore only required, when ``is_tokenize`` is true.
        is_tokenize: Whether to also compute tokenizer-based lengths.
    """
    src_lens = []
    trg_lens = []
    src_tok_lens = []
    trg_tok_lens = []
    for ex in examples:
        target_text = str(ex.target)
        src_lens.append(len(ex.source.split()))
        trg_lens.append(len(target_text.split()))
        if is_tokenize:
            src_tok_lens.append(len(tokenizer.tokenize(ex.source)))
            trg_tok_lens.append(len(tokenizer.tokenize(target_text)))

    src_mean, src_max = _mean_and_max(src_lens)
    trg_mean, trg_max = _mean_and_max(trg_lens)
    logger.info("Read %d examples, avg src len: %d, avg trg len: %d, max src len: %d, max trg len: %d",
                len(examples), src_mean, trg_mean, src_max, trg_max)

    if is_tokenize:
        src_tok_mean, src_tok_max = _mean_and_max(src_tok_lens)
        trg_tok_mean, trg_tok_max = _mean_and_max(trg_tok_lens)
        logger.info("[TOKENIZE] avg src len: %d, avg trg len: %d, max src len: %d, max trg len: %d",
                    src_tok_mean, trg_tok_mean, src_tok_max, trg_tok_max)
        
def token_length_dis(examples, tokenizer, is_tokenize):
    """Count how many examples fall into each source/target length range.

    Args:
        examples: Iterable of objects with ``source`` (a string) and
            ``target`` (converted with ``str()`` before measuring).
        tokenizer: Object with a ``tokenize(text)`` method; only used when
            ``is_tokenize`` is true.
        is_tokenize: If true, lengths are ``tokenizer.tokenize`` token counts;
            otherwise they are whitespace-split word counts.

    Returns:
        A tuple ``(src_token_ranges, trg_token_ranges)`` of dicts mapping the
        labels ``"[0-500]"``, ``"[501-1000]"``, ``"[1001-1500]"``,
        ``"[1501-2000]"`` and ``"2000+"`` to example counts.
    """
    # Inclusive upper bounds in ascending order; the first one that fits wins.
    bounded_labels = (
        (500, "[0-500]"),
        (1000, "[501-1000]"),
        (1500, "[1001-1500]"),
        (2000, "[1501-2000]"),
    )
    overflow_label = "2000+"
    all_labels = [label for _, label in bounded_labels] + [overflow_label]

    def label_for(count):
        for upper_bound, label in bounded_labels:
            if count <= upper_bound:
                return label
        return overflow_label

    src_histogram = dict.fromkeys(all_labels, 0)
    trg_histogram = dict.fromkeys(all_labels, 0)

    for example in examples:
        source_text = example.source
        target_text = str(example.target)
        if is_tokenize:
            src_count = len(tokenizer.tokenize(source_text))
            trg_count = len(tokenizer.tokenize(target_text))
        else:
            src_count = len(source_text.split())
            trg_count = len(target_text.split())

        src_histogram[label_for(src_count)] += 1
        trg_histogram[label_for(trg_count)] += 1

    return src_histogram, trg_histogram


def get_elapse_time(t0):
    """Format the time elapsed since ``t0`` as ``"<H>h<M>m"`` or ``"<M>m"``.

    Args:
        t0: Start timestamp in seconds, on the same clock as ``time.time()``.

    Returns:
        ``"{hours}h{minutes}m"`` once at least one hour has elapsed, otherwise
        ``"{minutes}m"``. Seconds are dropped (rounded down).
    """
    # Clamp at zero: a t0 in the future (e.g. after a clock adjustment) would
    # otherwise wrap through the modulo below and report a bogus value like "58m".
    elapsed_seconds = max(0.0, time.time() - t0)
    minutes = int((elapsed_seconds % 3600) // 60)
    # ">=" so that exactly one hour is reported as "1h0m" rather than "0m".
    if elapsed_seconds >= 3600:
        hours = int(elapsed_seconds // 3600)
        return "{}h{}m".format(hours, minutes)
    return "{}m".format(minutes)
