import os
import sys
import random
import argparse

from copy import deepcopy

import wandb
import numpy as np
import torch
import torch.cuda
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import timm

from lifelines.utils import concordance_index
import dataloader_TMA as dataloader  # type:ignore

from fastai.vision.all import *  # ranger optimizer


def parse_args():
    """Parse the command-line options of the TMA training script.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(description="TMA Training")
    # Default to "" rather than None: the name is string-concatenated into
    # checkpoint paths, which would raise TypeError on None.
    parser.add_argument(
        "--name", default="", type=str, help="Run name embedded in checkpoint filenames"
    )
    parser.add_argument("--batch_size", default=16, type=int, help="Mini-batch size")
    parser.add_argument(
        "--lr", "--learning_rate", default=0.0002, type=float, help="The learning rate"
    )
    parser.add_argument("--wd", "--weight_decay", default=1e-6, type=float, help="weight decay")
    parser.add_argument("--num_epochs", default=50, type=int)
    parser.add_argument(
        "--data_path", default="/home/user/tmas_small", type=str, help="path to dataset"
    )
    # type=int so a seed given on the command line is an int, like the default,
    # instead of a str (which seeds Python's `random` differently).
    parser.add_argument("--seed", default=123, type=int, help="random seed")
    parser.add_argument("--num_class", default=5, type=int)
    parser.add_argument("--ema", default=0, type=int)
    parser.add_argument("--pretrained", default=1, type=int)
    parser.add_argument("--num_samples", default=-1, type=int)
    parser.add_argument("--drop_path_rate", default=0.5, type=float)
    parser.add_argument("--drop_rate", default=0.5, type=float)
    args = parser.parse_args()
    return args


# Training
def train_epoch(net, optimizer, dataloader, k_net, n_epoch, fold, criterion, num_class=5):
    """Train ``net`` for one pass over ``dataloader`` and log the mean batch loss.

    Each batch is unpacked as ``(inputs, labels, path)``. Only ``labels[1]`` is
    used as the target (cast to float). Outputs and targets are flattened
    before being passed to ``criterion``. ``num_class`` is accepted for call
    compatibility but is not used.

    The mean loss is printed together with ``n_epoch`` and logged to wandb
    under ``"{fold}/warmup/{k_net}/loss"``.
    """
    net.train()
    epoch_losses = []
    for batch_inputs, label_parts, _path in dataloader:
        target = label_parts[1].float().cuda()
        batch_inputs = batch_inputs.cuda()

        optimizer.zero_grad()
        predictions = net(batch_inputs)
        batch_loss = criterion(predictions.flatten(), target.flatten())
        batch_loss.backward()
        optimizer.step()

        epoch_losses.append(batch_loss.item())

    mean_loss = np.average(epoch_losses)
    print(n_epoch, mean_loss)
    wandb.log(  # type:ignore
        {f"{fold}/warmup/{k_net}/loss": mean_loss, f"{fold}/epoch": n_epoch}
    )


def val(
    net,
    val_loader,
    k,
    n_epoch,
    best_cindex,
    criterion,
    num_class=5,
    regression=False,
    name="",
    fold=-1,
):
    """Evaluate ``net`` on ``val_loader``, log metrics, and checkpoint on improvement.

    Args:
        net: model to evaluate. Its outputs are flattened before the loss.
        val_loader: iterable of ``(inputs, targets)`` batches. ``targets[1]`` is
            used as the loss target and ``targets[2]`` as the time fed to the
            C-index (layout taken from the indexing below; TODO confirm with
            the dataloader).
        k: 1-based network index; selects ``best_cindex[k - 1]``.
        n_epoch: epoch number, logged alongside the metrics.
        best_cindex: list of best C-index values per network, updated in place.
        criterion: loss applied to the flattened logits and targets.
        num_class, regression: unused; kept for call compatibility.
        name: run name embedded in the checkpoint filename.
        fold: fold index, used as the metric prefix and in the checkpoint filename.

    Returns:
        ``(cindex, best_cindex)`` for this evaluation.
    """
    net.eval()
    losses = []
    y_logits = []
    y_times = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            times = targets[2]
            labels = targets[1].view(-1, 1).float()

            inputs, labels = inputs.cuda(), labels.cuda()
            logits = net(inputs)
            loss = criterion(logits.flatten(), labels.flatten())
            losses.append(float(loss))

            y_logits.extend(logits.tolist())
            y_times.extend(times.tolist())

    # Samples whose time is -1 are passed as not observed (censored).
    cindex = concordance_index(y_times, y_logits, event_observed=(np.array(y_times) != -1))
    print("C-index", k, cindex)

    if cindex > best_cindex[k - 1]:
        best_cindex[k - 1] = cindex
        wandb.log({f"{fold}/best_cindex": np.max(best_cindex)})
        # torch.save does not create missing directories; without this the
        # first improvement crashes the run when ./checkpoint is absent.
        os.makedirs("checkpoint", exist_ok=True)
        torch.save(
            net.state_dict(),
            "checkpoint/name_" + name + "_fold_" + str(fold) + "_c_" + str(cindex),
        )

    metrics = {
        f"{fold}/val/{k}/loss": np.average(losses),
        f"{fold}/val/{k}/cindex": cindex,
        f"{fold}/epoch": n_epoch,
    }
    wandb.log(metrics)
    return cindex, best_cindex


def create_model(pretrained, num_class, drop_path_rate, drop_rate):
    """Build a timm ``resnet50d`` with ``num_class`` outputs and move it to the GPU.

    ``pretrained``, ``drop_rate`` and ``drop_path_rate`` are forwarded unchanged
    to ``timm.create_model``.
    """
    model_kwargs = dict(
        pretrained=pretrained,
        num_classes=num_class,
        drop_rate=drop_rate,
        drop_path_rate=drop_path_rate,
    )
    return timm.create_model("resnet50d", **model_kwargs).cuda()


def average_param_dicts(param_dicts):
    """Return the element-wise mean of a sequence of state dicts.

    Floating-point tensors are averaged across all dicts. Non-floating-point
    entries (e.g. BatchNorm's integer ``num_batches_tracked`` counter, which
    cannot be divided in place) are copied from the most recent dict rather
    than summed.

    The inputs are never modified. The result is built on a deep copy of
    ``param_dicts[0]``, so stored snapshots stay intact across calls, and any
    attributes that dict carries (such as ``_metadata`` on a ``state_dict()``
    OrderedDict) are preserved.

    Args:
        param_dicts: sequence of mappings from name to tensor, all with the same keys.

    Returns:
        A new mapping with the averaged values, or ``{}`` if ``param_dicts`` is empty.
    """
    if not param_dicts:
        return {}

    count = float(len(param_dicts))
    # Work on a private copy: accumulating into param_dicts[0] itself would
    # overwrite that snapshot with the running average.
    averaged = deepcopy(param_dicts[0])
    with torch.no_grad():
        for name in list(averaged):
            value = averaged[name]
            if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
                for other in param_dicts[1:]:
                    value.add_(other[name])
                value.div_(count)
            else:
                # Integer counters and other non-float entries: keep the latest value.
                averaged[name] = deepcopy(param_dicts[-1][name])
    return averaged


def average_nets(net1, param_dicts=None, window=5):
    """Load into ``net1`` the mean of its most recent ``window`` weight snapshots.

    A deep copy of the current ``net1.state_dict()`` is appended to
    ``param_dicts``. The oldest snapshots are dropped until at most ``window``
    remain, and the element-wise average is loaded into ``net1``.

    Args:
        net1: model whose weights are snapshotted and then replaced by the average.
        param_dicts: list of past snapshots, mutated in place. Defaults to the
            module-level ``net1_param_dicts`` created by the ``__main__`` block,
            for backward compatibility.
        window: maximum number of snapshots to keep and average (must be >= 1).

    Returns:
        The raw (non-averaged) snapshot taken on this call, so the caller can
        restore the un-averaged weights afterwards.

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    if param_dicts is None:
        param_dicts = net1_param_dicts

    current_net1 = deepcopy(net1.state_dict())
    param_dicts.append(current_net1)
    while len(param_dicts) > window:
        param_dicts.pop(0)
    net1.load_state_dict(average_param_dicts(param_dicts))
    return current_net1


if __name__ == "__main__":
    args = parse_args()
    wandb.init()

    print(args)

    # Seed every RNG the run may touch. int() guards against --seed having been
    # parsed as a string; numpy is seeded as well in case the data pipeline uses it.
    seed = int(args.seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # type:ignore
    # NOTE: benchmark mode lets cuDNN pick kernels for speed, so runs are not
    # guaranteed bit-reproducible even with fixed seeds.
    cudnn.benchmark = True

    best_cindex_folds = []
    for fold in range(3):
        loader = dataloader.TMA_dataloader(
            fold=fold,
            root=args.data_path,
            batch_size=args.batch_size,
            num_workers=8,
            return_times=True,
            num_class=args.num_class,
            regression=True,
        )

        # Single-output model trained with a Smooth-L1 loss.
        net1 = create_model(args.pretrained, 1, args.drop_path_rate, args.drop_rate)
        criterion = nn.SmoothL1Loss()
        optimizer1 = ranger(net1.parameters(), lr=args.lr, wd=args.wd)  # from the fastai star import

        # val() indexes this with k - 1; only slot 0 (k=1) is used below.
        best_cindex = [0, 0]

        # Must stay a module-level name: average_nets() reads net1_param_dicts
        # as a global. It is reset per fold so snapshots never mix across folds.
        net1_param_dicts = []
        for epoch in range(args.num_epochs + 1):
            wandb.log({f"{fold}/epoch": epoch})

            train_loader = loader.run(
                "warmup", str(epoch) + "_warmup_1", num_samples=args.num_samples
            )
            train_epoch(net1, optimizer1, train_loader, 1, epoch, fold, criterion)

            # Validation, optionally on weights averaged over recent epochs.
            val_loader = loader.run("val", -1, num_samples=args.num_samples)
            if args.ema:
                current_net1 = average_nets(net1)

            cindex1, best_cindex = val(
                net1, val_loader, 1, epoch, best_cindex, criterion, name=args.name, fold=fold
            )

            # Abort a poor run; the epoch threshold is hard-coded at 50.
            if epoch == 50 and cindex1 < 0.55:
                print("ABORTING TRAINING, performance not good enough after 50 epochs")
                break

            if args.ema:
                # Continue training from the raw (non-averaged) weights.
                net1.load_state_dict(deepcopy(current_net1))

        best_cindex_folds.append(np.max(best_cindex))
        print(best_cindex_folds)

    wandb.log({"best_cindex": np.average(best_cindex_folds)})
