import os
import json
import random

import torch
import numpy as np
import pandas as pd
import PIL.Image
import torchvision.transforms as transforms
from fastai.vision.all import *

from PIL import Image
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

import albumentations


class TMA_dataset(Dataset):
    """Image dataset backed by three ``fold_<i>.csv`` split files.

    Every CSV row is read as ``<relative image path>,<int label>,<int time>``.
    Rows whose first field is ``slide`` are treated as headers and skipped.
    The CSV files are opened relative to the current working directory, while
    image paths are built as ``<root>/<first column>``.

    Three targets are stored per image path:

    * label: the second column as ``int``.
    * time: ``min(num_class, time) - 1`` when the label is non-zero, otherwise
      ``num_class - 1``.
    * time_ori: ``time - 1``, except that it is ``-1`` when the label is
      non-zero and ``time > 5``, or when the label is zero and ``time < 6``.

    Args:
        root: Prefix joined with ``/`` to the first CSV column to form image paths.
        fold: 0, 1 or 2; selects which two fold files are used for training and
            which one is held out. Any other value leaves the training set
            empty and tries to open ``fold_-1.csv``.
        transform: Sequence of callables applied in order to the loaded image.
        mode: ``"all"`` serves the (shuffled) training images, ``"val"`` serves
            the held-out images.
        epoch: ``-1`` disables the cache lookup in :meth:`cache_img`.
        num_samples: When greater than ``-1``, keep only the first
            ``num_samples`` paths of the training and held-out lists.
        pred, probability, paths, regression: Accepted for interface
            compatibility; not used by this class.
        num_class: Number of time classes used to clip / encode ``time``.
        return_time: When true, ``__getitem__`` returns
            ``(label, time, time_ori)`` instead of just ``label``.
    """

    # fold -> (training fold indices, held-out fold index)
    _FOLD_SPLITS = {0: ((0, 1), 2), 1: ((1, 2), 0), 2: ((0, 2), 1)}
    # Threshold on the raw time column used when building the ``*_times_ori`` targets.
    _ORI_THRESHOLD = 5

    def __init__(
        self,
        root,
        fold,
        transform,
        mode,
        epoch=-1,
        num_samples=-1,
        pred=None,
        probability=None,
        paths=None,
        num_class=14,
        return_time=False,
        regression=True,
    ):
        self.root = root
        self.transform = transform
        self.mode = mode
        self.epoch = epoch
        self.return_time = return_time
        self.num_class = num_class
        self.val_labels = {}

        # Unknown folds fall back to "no training folds, held-out fold -1",
        # which fails with FileNotFoundError when fold_-1.csv is opened.
        train_folds, valid_fold = self._FOLD_SPLITS.get(fold, ((), -1))

        train_paths = []
        self.train_labels, self.train_times, self.train_times_ori = {}, {}, {}
        for fold_idx in train_folds:
            fold_paths, labels, times, times_ori = self._read_fold(fold_idx)
            train_paths.extend(fold_paths)
            self.train_labels.update(labels)
            self.train_times.update(times)
            self.train_times_ori.update(times_ori)

        (
            test_paths,
            self.test_labels,
            self.test_times,
            self.test_times_ori,
        ) = self._read_fold(valid_fold)

        if num_samples > -1:
            train_keep = train_paths[:num_samples]
            test_keep = test_paths[:num_samples]
            self.train_labels = {key: self.train_labels[key] for key in train_keep}
            self.train_times = {key: self.train_times[key] for key in train_keep}
            self.train_times_ori = {key: self.train_times_ori[key] for key in train_keep}
            self.test_labels = {key: self.test_labels[key] for key in test_keep}
            self.test_times = {key: self.test_times[key] for key in test_keep}
            self.test_times_ori = {key: self.test_times_ori[key] for key in test_keep}

        if mode == "all":
            self.train_imgs = list(self.train_labels)
            random.shuffle(self.train_imgs)
        elif mode == "val":
            self.val_imgs = list(self.test_labels)

    def _read_fold(self, fold_idx):
        """Parse ``fold_<fold_idx>.csv`` into ``(paths, labels, times, times_ori)``.

        ``paths`` lists image paths in file order; the other three are dicts
        keyed by image path. The same encoding is used for training and
        held-out folds, so ``times_ori`` is always ``time - 1`` (or ``-1``).
        NOTE(review): label-0 training rows with time >= 6 previously stored
        ``time`` rather than ``time - 1``, unlike every other branch; this
        helper aligns them - confirm against downstream consumers.
        """
        fold_paths, labels, times, times_ori = [], {}, {}, {}
        with open(f"fold_{fold_idx}.csv", "r") as f:
            for line in f.read().splitlines():
                if not line.strip():
                    # A blank line has no label/time columns to parse.
                    continue
                entry = line.split(",")
                if entry[0] == "slide":
                    continue
                img_path = f"{self.root}/{entry[0]}"
                label = int(entry[1])
                time = int(entry[2])

                fold_paths.append(img_path)
                labels[img_path] = label
                if label:
                    times[img_path] = min(self.num_class, time) - 1
                    unknown = time > self._ORI_THRESHOLD
                else:
                    times[img_path] = self.num_class - 1
                    unknown = time <= self._ORI_THRESHOLD
                times_ori[img_path] = -1 if unknown else time - 1
        return fold_paths, labels, times, times_ori

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return it converted to RGB."""
        with Image.open(path) as handle:
            return handle.convert("RGB")

    def _apply_transforms(self, img):
        """Run ``img`` through every callable in ``self.transform``, in order."""
        for tfm in self.transform:
            img = tfm(img)
        return img

    def cache_img(self, cache_path, img_path):
        """Return the transformed image, preferring a cached copy when allowed.

        When ``self.epoch != -1`` and ``cache_path`` exists, the cached file is
        loaded and only the last transform is applied. If that fails, the
        cached file is deleted and the image is rebuilt from ``img_path`` with
        the full transform chain.
        """
        if self.epoch != -1 and os.path.exists(cache_path):
            try:
                return self.transform[-1](self._load_rgb(cache_path))
            except Exception:
                # Unusable cache entry: drop it and fall back to the source image.
                os.remove(cache_path)
        return self._apply_transforms(self._load_rgb(img_path))

    def __getitem__(self, index):
        """Return one sample.

        ``"all"`` mode yields ``(img, target, img_path)`` and ``"val"`` mode
        yields ``(img, target)``; ``target`` is ``(label, time, time_ori)``
        when ``return_time`` is set, otherwise the label alone. Any other mode
        returns ``None``.
        """
        if self.mode == "all":
            img_path = self.train_imgs[index]
            target = self.train_labels[img_path]
            target_time = self.train_times[img_path]
            target_time_ori = self.train_times_ori[img_path]

            img = self._apply_transforms(self._load_rgb(img_path))
            if self.return_time:
                return img, (target, target_time, target_time_ori), img_path
            return img, target, img_path

        elif self.mode == "val":
            img_path = self.val_imgs[index]
            target = self.test_labels[img_path]
            target_time = self.test_times[img_path]
            target_time_ori = self.test_times_ori[img_path]

            img = self._apply_transforms(self._load_rgb(img_path))
            if self.return_time:
                return img, (target, target_time, target_time_ori)
            return img, target

    def __len__(self):
        """Number of held-out images in ``"val"`` mode, training images otherwise."""
        if self.mode == "val":
            return len(self.val_imgs)
        return len(self.train_imgs)


class TMA_dataloader:
    """Factory for the training ("warmup") and validation ``DataLoader`` objects.

    Args:
        root: Image root directory forwarded to ``TMA_dataset``.
        batch_size: Training batch size; the validation loader uses twice this.
        fold: Cross-validation fold forwarded to ``TMA_dataset``.
        num_workers: Worker count for both loaders.
        return_times: Forwarded as ``return_time``; also selects whether the
            sampler is balanced on time classes or on labels.
        num_class: Number of classes forwarded to ``TMA_dataset`` and used to
            decide whether class-balanced sampling applies.
        regression: Stored on the instance; not used by this class.
    """

    def __init__(self, root, batch_size, fold, num_workers, return_times, num_class, regression):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.return_times = return_times
        self.num_class = num_class
        self.root = root
        self.fold = fold
        self.regression = regression

        # NOTE(review): several of these transforms / keyword names
        # (RandomContrast, RandomBrightness, Cutout, alpha_affine,
        # quality_lower) depend on the installed albumentations version -
        # confirm against the pinned release before upgrading.
        self.transform_train = albumentations.Compose(
            [
                albumentations.Rotate(limit=90, border_mode=0, value=[255, 255, 255]),
                albumentations.RandomCrop(768, 768),
                albumentations.HueSaturationValue(
                    hue_shift_limit=20, sat_shift_limit=10, val_shift_limit=25, p=0.3
                ),
                albumentations.HorizontalFlip(),
                albumentations.VerticalFlip(),
                albumentations.ImageCompression(quality_lower=90, p=0.5),
                albumentations.ElasticTransform(
                    p=0.5, alpha_affine=15, border_mode=0, value=[255, 255, 255], approximate=True
                ),
                albumentations.Blur(blur_limit=1.0, p=0.2),
                albumentations.RandomContrast(limit=0.05),
                albumentations.RandomGamma(),
                albumentations.RandomBrightness(limit=0.05),
                albumentations.Emboss(p=0.2),  # strength=(0.2, 2.0)
                albumentations.Sharpen(alpha=(0.2, 0.8), p=0.2),
                albumentations.GaussNoise(),
                albumentations.Cutout(num_holes=1, max_h_size=256, max_w_size=256, fill_value=255),
            ]
        )

        # Final tensor conversion + normalisation applied after the augmentations.
        self.transform_train_2 = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        self.transform_test = transforms.Compose(
            [
                transforms.CenterCrop(768),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        self.fastai_transforms_train = aug_transforms(mult=3, size=768)

    def fastai_transforms(self, img):
        """Apply the fastai augmentations to ``img`` and return a PIL image.

        The input is wrapped in a ``TensorImage``, moved to channel-first
        layout with a leading batch axis and scaled by 1/255; the result is
        converted back to a channel-last byte image.
        Assumes ``img`` is an HxWxC image with values in 0..255 - TODO confirm.
        """
        img = TensorImage(img).permute(2, 0, 1)[None].float() / 255
        for tfm in self.fastai_transforms_train:
            img = tfm(img, split_idx=0)  # type:ignore
        img = img[0].permute(1, 2, 0) * 255
        img = PIL.Image.fromarray(img.byte().numpy())
        return img

    def albumentations_transforms(self, img):
        """Run ``img`` through ``self.transform_train`` and return a PIL image."""
        aug_img = self.transform_train(image=np.array(img))["image"]
        return PIL.Image.fromarray(aug_img)

    @staticmethod
    def _class_balanced_weights(labels, num_class):
        """Return one sampling weight per entry of ``labels``.

        Each sample is weighted by the inverse frequency of its class, but
        only when exactly ``num_class`` (and more than one) distinct classes
        are present; otherwise every sample gets weight ``1.0``.
        """
        _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
        if len(counts) == 1 or len(counts) != num_class:
            return [1.0] * len(labels)
        # Index through ``inverse`` so the lookup does not rely on the label
        # values themselves being valid positions into ``counts``.
        return (1.0 / counts)[inverse].tolist()

    def run(self, mode, epoch, pred=None, prob=None, paths=None, num_samples=-1):
        """Build the ``DataLoader`` for ``mode``.

        Args:
            mode: ``"warmup"`` for the augmented, weighted-sampled training
                loader; ``"val"`` for the unshuffled validation loader.
            epoch: Forwarded to ``TMA_dataset``.
            pred, prob, paths: Accepted for interface compatibility; unused.
            num_samples: Forwarded to ``TMA_dataset``.

        Returns:
            The requested ``DataLoader``, or ``None`` for any other ``mode``.
        """
        if mode == "warmup":
            warmup_dataset = TMA_dataset(
                self.root,
                epoch=epoch,
                fold=self.fold,
                transform=[
                    self.albumentations_transforms,
                    self.fastai_transforms,
                    self.transform_train_2,
                ],
                mode="all",
                num_samples=num_samples,
                return_time=self.return_times,
                num_class=self.num_class,
            )

            # Balance on time classes when they are the target, else on labels.
            if self.return_times:
                label_of = warmup_dataset.train_times
            else:
                label_of = warmup_dataset.train_labels
            labels = [label_of[img_path] for img_path in warmup_dataset.train_imgs]
            weights = self._class_balanced_weights(labels, self.num_class)

            return DataLoader(
                dataset=warmup_dataset,
                batch_size=self.batch_size,
                sampler=WeightedRandomSampler(
                    weights=weights, num_samples=len(warmup_dataset.train_labels)
                ),
                num_workers=self.num_workers,
            )

        if mode == "val":
            val_dataset = TMA_dataset(
                self.root,
                epoch=epoch,
                fold=self.fold,
                transform=[self.transform_test],
                mode="val",
                return_time=self.return_times,
                num_class=self.num_class,
                num_samples=num_samples,
            )
            return DataLoader(
                dataset=val_dataset,
                batch_size=self.batch_size * 2,
                shuffle=False,
                num_workers=self.num_workers,
            )

        return None
