import os
import sys
from argparse import Namespace
from copy import deepcopy
import numpy as np
import pickle as pkl
import tqdm
import torch
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from getdata import get_uci_data, get_toy_data
from cali_src.cali_args import parse_args
from cali_src.cali_misc_utils import (
    set_seeds,
    get_q_idx,
    discretize_domain)
from cali_src.cali_recal import iso_recal
from cali_src.cali_q_model_ens import QModelEns
from cali_src.cali_losses import get_loss_fn
from getdata import load_reg_result

# Parse the command-line options at import time so that `args` exists as a
# module-level name. The script body under the main guard calls parse_args()
# again and rebinds this name.
args = parse_args()


def build_save_file_name(args):
    """Return the pickle path that identifies this run's configuration.

    The name encodes the dataset, loss, ensemble size, bootstrap flag and
    seed. When the loss name contains "penalty", the sharpness penalty and a
    "sharpall"/"wideonly" tag (chosen by ``args.sharp_all``) are included too.

    Args:
        args: Parsed options; reads ``save_dir``, ``data``, ``loss``,
            ``num_ens``, ``boot``, ``seed`` and, for penalty losses,
            ``sharp_penalty`` and ``sharp_all``.

    Returns:
        The save file path as a string.

    Raises:
        ValueError: If a penalty loss is requested but ``args.sharp_all`` is
            None, in which case no file name can be derived.
    """
    if "penalty" not in args.loss:
        return "{}/{}_loss{}_ens{}_boot{}_seed{}.pkl".format(
            args.save_dir,
            args.data,
            args.loss,
            args.num_ens,
            args.boot,
            args.seed,
        )

    # Penalizing sharpness: the tag records which variant was used.
    if args.sharp_all is None:
        raise ValueError(
            "loss '{}' penalizes sharpness but sharp_all is not set".format(
                args.loss
            )
        )
    sharp_tag = "sharpall" if args.sharp_all else "wideonly"
    return "{}/{}_loss{}_pen{}_{}_ens{}_boot{}_seed{}.pkl".format(
        args.save_dir,
        args.data,
        args.loss,
        args.sharp_penalty,
        sharp_tag,
        args.num_ens,
        args.boot,
        args.seed,
    )


if __name__ == "__main__":
    # DATA_NAMES = ['wine', 'naval', 'kin8nm', 'energy', 'yacht', 'concrete', 'power', 'boston']

    args = parse_args()

    print("DEVICE: {}".format(args.device))

    if args.debug:
        import pudb

        pudb.set_trace()

    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(args.save_dir, exist_ok=True)

    # Per-seed metric accumulators.
    # NOTE(review): none of these are filled in this part of the script.
    per_seed_cali = []
    per_seed_sharp = []
    per_seed_gcali = []
    per_seed_crps = []
    per_seed_nll = []
    per_seed_check = []
    per_seed_int = []
    per_seed_int_cali = []
    per_seed_model = []

    print(
        "Drawing group batches every {}, penalty {}".format(
            args.draw_group_every, args.sharp_penalty
        )
    )

    # Save file name; an existing file means this configuration already ran.
    save_file_name = build_save_file_name(args)
    if os.path.exists(save_file_name):
        print("skipping {}".format(save_file_name))
        sys.exit()

    # Set seeds
    set_seeds(args.seed)

    # Fetching data
    data_args = Namespace(
        data_dir=args.data_dir, dataset=args.data, seed=args.seed
    )

    # NOTE(review): data_out is not read below; every split used for training
    # comes from the stored regression result loaded right after.
    if "uci" in args.data_dir.lower():
        data_out = get_uci_data(args.data)
    elif "toy" in args.data_dir.lower():
        data_out = get_toy_data(args)

    (reg_res, reg_par) = load_reg_result(args.data, args.seed,
                                         reg_type='DNN',
                                         basedir='runs_reg_900',
                                         rng_split=True)

    dataset = reg_res['dataset']
    x_tr, y_tr = dataset['train x'], dataset['train y']
    x_te, y_te = dataset['test x'], dataset['test y']
    x_va, y_va = dataset['stop x'], dataset['stop y']
    x_sup, y_sup = dataset['sup x'], dataset['sup y']

    # Target range is measured on the training targets only.
    y_range = (y_tr.max() - y_tr.min()).item()
    print("y range: {:.3f}".format(y_range))

    # Making models
    num_tr = x_tr.shape[0]
    dim_x = x_tr.shape[1]
    dim_y = y_tr.shape[1]
    model_ens = QModelEns(
        input_size=dim_x + 1,  # one extra input on top of the features
        output_size=dim_y,
        hidden_size=args.hs,
        num_layers=args.nl,
        lr=args.lr,
        wd=args.wd,
        num_ens=args.num_ens,
        device=args.device,
    )

    # Convert every split to float32 tensors.
    x_tr, y_tr = torch.tensor(x_tr).float(), torch.tensor(y_tr).float()
    x_va, y_va = torch.tensor(x_va).float(), torch.tensor(y_va).float()
    x_te, y_te = torch.tensor(x_te).float(), torch.tensor(y_te).float()
    x_sup, y_sup = torch.tensor(x_sup).float(), torch.tensor(y_sup).float()

    # Data loader(s)
    if not args.boot:
        loader = DataLoader(
            TensorDataset(x_tr, y_tr),
            shuffle=True,
            batch_size=args.bs,
        )
    else:
        # One resample (with replacement) of the training set per member.
        rand_idx_list = [
            np.random.choice(num_tr, size=num_tr, replace=True)
            for _ in range(args.num_ens)
        ]
        loader_list = [
            DataLoader(
                TensorDataset(x_tr[idxs], y_tr[idxs]),
                shuffle=True,
                batch_size=args.bs,
            )
            for idxs in rand_idx_list
        ]

    # Loss function
    loss_fn = get_loss_fn(args.loss)
    args.scale = "scale" in args.loss
    batch_loss = "batch" in args.loss

    """ train loop """
    tr_loss_list = []
    va_loss_list = []
    te_loss_list = []

    # setting batch groupings
    group_list = discretize_domain(x_tr.numpy(), args.bs)
    curr_group_idx = 0

    # Loop-invariant work: validation/test tensors only need to be moved to
    # the device once, and the evaluation quantile grid never changes.
    x_va, y_va = x_va.to(args.device), y_va.to(args.device)
    x_te, y_te = x_te.to(args.device), y_te.to(args.device)
    va_te_q_list = torch.linspace(0.01, 0.99, 99)

    for ep in tqdm.tqdm(range(args.num_ep)):
        if model_ens.done_training:
            print("Done training ens at EP {}".format(ep))
            break

        # Take train step
        # list of losses from each batch, for one epoch
        ep_train_loss = []
        if not args.boot:
            if ep % args.draw_group_every == 0:
                # drawing a group batch
                group_idxs = group_list[curr_group_idx]
                # NOTE(review): cycling modulo dim_x assumes group_list holds
                # one grouping per input dimension -- TODO confirm.
                curr_group_idx = (curr_group_idx + 1) % dim_x
                batches = (
                    (x_tr[g_idx.flatten()], y_tr[g_idx.flatten()])
                    for g_idx in group_idxs
                )
            else:
                # just doing ordinary random batch
                batches = loader

            for xi, yi in batches:
                xi, yi = xi.to(args.device), yi.to(args.device)
                # Fresh random quantile levels for every batch.
                q_list = torch.rand(args.num_q)
                loss = model_ens.loss(
                    loss_fn,
                    xi,
                    yi,
                    q_list,
                    batch_q=batch_loss,
                    take_step=True,
                    args=args,
                )
                ep_train_loss.append(loss)
        else:
            # bootstrapped ensemble of models: one batch per member per step
            for xi_yi_samp in zip(*loader_list):
                xi_list = [item[0].to(args.device) for item in xi_yi_samp]
                yi_list = [item[1].to(args.device) for item in xi_yi_samp]
                assert len(xi_list) == len(yi_list) == args.num_ens
                q_list = torch.rand(args.num_q)
                loss = model_ens.loss_boot(
                    loss_fn,
                    xi_list,
                    yi_list,
                    q_list,
                    batch_q=batch_loss,
                    take_step=True,
                    args=args,
                )
                ep_train_loss.append(loss)
        # Average over the batches of this epoch, ignoring NaN losses.
        ep_tr_loss = np.nanmean(np.stack(ep_train_loss, axis=0), axis=0)
        tr_loss_list.append(ep_tr_loss)

        # Validation loss
        ep_va_loss = model_ens.update_va_loss(
            loss_fn,
            x_va,
            y_va,
            va_te_q_list,
            batch_q=batch_loss,
            curr_ep=ep,
            num_wait=args.wait,
            args=args,
        )
        va_loss_list.append(ep_va_loss)

        # Test loss
        with torch.no_grad():
            ep_te_loss = model_ens.loss(
                loss_fn,
                x_te,
                y_te,
                va_te_q_list,
                batch_q=batch_loss,
                take_step=False,
                args=args,
            )
        te_loss_list.append(ep_te_loss)

        # Printing some losses
        if (ep % 200 == 0) or (ep == args.num_ep - 1):
            print("EP:{}".format(ep))
            print("Train loss {}".format(ep_tr_loss))
            print("Val loss {}".format(ep_va_loss))
            print("Test loss {}".format(ep_te_loss))

    # Finished training
    # Move everything to cpu
    x_tr, y_tr = x_tr.cpu(), y_tr.cpu()
    x_va, y_va = x_va.cpu(), y_va.cpu()
    x_te, y_te = x_te.cpu(), y_te.cpu()
    model_ens.use_device(torch.device("cpu"))

    #############################################

    # sample from model_ens, using as input x_val (of course this can be changed any time)

    with torch.no_grad():
        pred_mat = model_ens.predict_q_customized(x_va)
    print("pred_mat shape: {}".format(pred_mat.shape))
