"""This is code based on https://sudomake.ai/inception-score-explained/."""
import torch
import torchvision

from collections import defaultdict
from timm.models.layers import PatchEmbed

class InceptionScore(torch.nn.Module):
    """Score image batches by the entropy of Inception-v3 class predictions.

    Images are bilinearly resized to 299x299, fed through a pretrained
    torchvision Inception-v3 in eval mode, and the entropies of the resulting
    softmax distributions are summed over all images.
    """

    def __init__(self, batch_size=32, setup=None):
        """Build the resizing layer and load the pretrained network.

        Args:
            batch_size: Maximum number of images per forward pass of the network.
            setup: Keyword arguments forwarded to ``Module.to`` (e.g. ``device``
                and ``dtype``). Defaults to CPU and ``torch.float``.
        """
        super().__init__()
        if setup is None:
            # Built per call so that no dict object is shared between instances.
            setup = dict(device=torch.device('cpu'), dtype=torch.float)
        self.preprocessing = torch.nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False)
        self.model = torchvision.models.inception_v3(pretrained=True).to(**setup)
        self.model.eval()
        self.batch_size = batch_size

    def forward(self, image_batch):
        """Return the summed prediction entropy of a BCHW image batch.

        The images should already be normalized. The batch is processed in
        chunks of at most ``self.batch_size`` images; B does not have to be a
        multiple of the chunk size, the last chunk is simply smaller.

        Returns:
            Scalar tensor: sum over images and classes of ``-p * log(p)``.
        """
        logits = [
            self.model(self.preprocessing(chunk))
            for chunk in image_batch.split(self.batch_size)
        ]
        prob_yx = torch.nn.functional.softmax(torch.cat(logits, 0), dim=1)
        # 0 * log(0) evaluates to NaN; such entries contribute zero entropy instead.
        entropy = torch.where(prob_yx > 0, -prob_yx * prob_yx.log(), torch.zeros_like(prob_yx))
        return entropy.sum()


def psnr(img_batch, ref_batch, batched=False, factor=1.0):
    """Peak signal-to-noise ratio between two image batches, as a Python float.

    With ``batched=True`` a single PSNR is computed from the MSE over the whole
    batch. Otherwise one PSNR is computed per sample of a 4-D batch and the
    per-sample values are averaged. ``factor`` is the peak signal value.
    A non-finite MSE yields NaN; a zero MSE yields infinity.
    """
    def single_psnr(candidate, reference):
        mse = ((candidate - reference)**2).mean()
        if not torch.isfinite(mse):
            return img_batch.new_tensor(float('nan'))
        if mse > 0:
            return 10 * torch.log10(factor**2 / mse)
        return img_batch.new_tensor(float('inf'))

    prediction = img_batch.detach()
    if batched:
        return single_psnr(prediction, ref_batch).item()

    num_samples, _, _, _ = img_batch.shape
    per_sample = [
        single_psnr(prediction[idx, :, :, :], ref_batch[idx, :, :, :])
        for idx in range(num_samples)
    ]
    return torch.stack(per_sample, dim=0).mean().item()

def group_r(images, mean_std, xm):
    """Register every restart's images onto a reference batch and average them.

    For each sample index ``i`` every ``images[j][i]`` is de-normalized,
    aligned to ``xm[i]`` with a SIFT + RANSAC homography, re-normalized, and
    the aligned images are averaged over the restarts ``j``.

    When no homography can be estimated (no descriptors, fewer than four
    matches passing the ratio test, or RANSAC failure) the restart image is
    used unaligned instead of raising.

    Args:
        images: Sequence of image batches, one per restart, indexed as
            ``images[j][i]``. Assumes all batches share size and device, and
            that the images have 3 channels -- TODO confirm against the caller.
        mean_std: Indexable whose first two entries are the normalization mean
            and standard deviation (applied as ``img * std + mean``).
        xm: Reference batch, indexed as ``xm[i]``, that restarts are aligned to.

    Returns:
        Tensor stacking, per sample, the mean of the aligned normalized images.
        It lives on the device of the input images.
    """
    # Third-party imports stay local so the module imports without OpenCV/PIL.
    import cv2
    import numpy as np
    from PIL import Image
    from torchvision import transforms

    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    dm = mean_std[0]
    ds = mean_std[1]

    # Created once: both objects are stateless across images.
    sift = cv2.SIFT_create()
    matcher = cv2.BFMatcher()

    def denormalize_to_uint8(tensor):
        """Undo the normalization and convert to a uint8 array via PIL."""
        return np.array(to_pil(torch.clamp(tensor * ds + dm, 0, 1)))

    def describe(rgb):
        """Return (gray image, SIFT keypoints, SIFT descriptors) for an RGB array."""
        # The arrays come from PIL and are therefore RGB, not OpenCV's BGR.
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        keypoints, descriptors = sift.detectAndCompute(gray, None)
        return gray, keypoints, descriptors

    def estimate_homography(kp_src, des_src, kp_dst, des_dst):
        """Return the 3x3 homography mapping source onto destination, or None."""
        if des_src is None or des_dst is None:
            return None
        good_matches = []
        for pair in matcher.knnMatch(des_src, des_dst, k=2):
            # Lowe's ratio test; it needs both nearest neighbours to be present.
            if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance:
                good_matches.append(pair[0])
        if len(good_matches) < 4:
            # A homography has 8 degrees of freedom: 4 correspondences minimum.
            return None
        src_pts = np.float32([kp_src[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp_dst[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        homography, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        return homography

    images = [image.detach().clone() for image in images]
    batch_size = images[0].size()[0]

    averaged = []
    for i in range(batch_size):
        # Reference features do not depend on the restart: compute them once.
        _, kp_ref, des_ref = describe(denormalize_to_uint8(xm[i].detach()))

        aligned_trials = []
        for restart in images:
            candidate = restart[i]
            rgb = denormalize_to_uint8(candidate)
            gray, kp, des = describe(rgb)
            # The homography maps the restart image onto the reference image.
            homography = estimate_homography(kp, des, kp_ref, des_ref)
            if homography is None:
                aligned = rgb
            else:
                height, width = gray.shape
                aligned = cv2.warpPerspective(rgb, homography, (width, height))
            restored = to_tensor(Image.fromarray(aligned)).to(candidate.device)
            aligned_trials.append((restored - dm) / ds)
        averaged.append(torch.mean(torch.stack(aligned_trials, 0), dim=0))
    return torch.stack(averaged)


'''
def group_r(images,mean_std):
    #just for two restarts
    import cv2
    import numpy as np
    from torchvision import transforms
    import copy
    from PIL import Image
    top=transforms.ToPILImage()
    tot=transforms.ToTensor()
    from matplotlib import pyplot as plt
    dm=mean_std[0]
    ds=mean_std[1]
    '''
'''
    dm=torch.tensor([[[0.4915]],

        [[0.4823]],

        [[0.4468]]])
    ds=torch.tensor([[[0.2470]],

        [[0.2435]],

        [[0.2616]]])
    '''
'''
    images=[image.detach().clone() for image in images]
    #images=[copy.deepcopy(image).cpu() for image in images]
    batch_size=images[0].size()[0]
    rt=list()
    for i in range(batch_size):
        #t=top(np.uint8(images[0][i].detach().numpy()))
        #img1=np.array(top(images[0][i].detach()))
        #img2=np.array(top(images[1][i].detach()))
        img1=np.array(top(torch.clamp(images[0][i].detach().clone() * ds + dm, 0, 1)))
        img2=np.array(top(torch.clamp(images[1][i].detach().clone() * ds + dm, 0, 1)))
        #img2=np.array(top(images[1][i].detach()))
        p1=Image.fromarray(img1)
        p2=Image.fromarray(img2)
        p1.save('p1.png')
        p2.save('p2.png')
        
        #exit()
        
        gray1 =cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        gray2 =cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
        sift = cv2.SIFT_create()
        kp1, des1 = sift.detectAndCompute(gray1, None)
        kp2, des2 = sift.detectAndCompute(gray2, None)

        bf = cv2.BFMatcher()
        matches = bf.knnMatch(des1, des2, k=2)


        good_matches = []
        for m,n in matches:
            if m.distance < 0.75 * n.distance:
                good_matches.append(m)
        
        src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        
        M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        h, w = gray1.shape
        aligned_img = cv2.warpPerspective(img2, M, (w, h))
        rt.append((tot(Image.fromarray(aligned_img)).cuda()-dm)/ds)
        #rt.append(tot(Image.fromarray(aligned_img)))
    rt=torch.stack(rt,0)
        #rt=torch.tensor([tot(img) for img in rt])
    return rt
'''        


def total_variation(x):
    """Anisotropic total variation of an image batch (assumed BCHW).

    Adds the mean absolute difference between neighbours along dim 3 to the
    mean absolute difference between neighbours along dim 2.
    """
    along_width = (x[:, :, :, :-1] - x[:, :, :, 1:]).abs().mean()
    along_height = (x[:, :, :-1, :] - x[:, :, 1:, :]).abs().mean()
    return along_width + along_height

def april_loss(o_gradients, o_input_gradient,idx):
    """Sum, over all trials, the cosine similarity of the ``idx``-th gradients.

    Each element of ``o_gradients`` is one trial's gradient collection; its
    ``idx``-th entry is compared with ``o_input_gradient[idx]`` after both are
    flattened. Returns the integer 0 when ``o_gradients`` is empty.
    """
    reference = o_input_gradient[idx].flatten()
    reference_norm = torch.norm(reference)

    total = 0
    for trial in o_gradients:
        candidate = trial[idx].flatten()
        similarity = torch.dot(reference, candidate) / (reference_norm * torch.norm(candidate))
        total += similarity
    return total

def aux_patch_loss(inp_noise, img_size=224, patch_size=16, in_chans=3):
    """Penalize mismatch between facing borders of neighbouring patches.

    The input is passed through a ``PatchEmbed`` layer, the embeddings are
    reshaped to ``(B, grid, grid, in_chans, patch_size, patch_size)`` with
    ``grid = img_size // patch_size``, and the MSE between the first row/column
    of each patch and the last row/column of its predecessor is accumulated.

    Args:
        inp_noise: Input batch handed to ``PatchEmbed``.
        img_size: Side length the patch embedder is configured for.
        patch_size: Side length of one patch.
        in_chans: Number of input channels.
            The three size arguments must match the dataset; the defaults
            reproduce the previous fixed 224 / 16 / 3 configuration.

    Returns:
        The accumulated loss (the float ``0.`` if the grid has a single patch
        per side).
    """
    mse = torch.nn.MSELoss(reduction="mean")
    grid = img_size // patch_size
    last = patch_size - 1

    # NOTE(review): a new PatchEmbed is constructed on every call, so its
    # projection weights are whatever timm initializes them to (presumably
    # random) rather than raw pixel values -- confirm this is intended.
    patch_embedder = PatchEmbed(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=in_chans * patch_size * patch_size,
    ).to(inp_noise.device)  # follow the input instead of assuming CUDA

    patches = patch_embedder(inp_noise)
    B, _, _ = patches.shape
    patches = patches.reshape((B, grid, grid, in_chans, patch_size, patch_size))

    loss_patch = 0.
    for i in range(grid - 1):
        loss_vertical = mse(patches[:, i + 1, :, :, 0, :], patches[:, i, :, :, last, :])
        loss_horizontal = mse(patches[:, :, i + 1, :, :, 0], patches[:, :, i, :, :, last])
        loss_patch += (loss_vertical + loss_horizontal)

    return loss_patch


def activation_errors(model, x1, x2):
    """Compute activation-level error metrics for every module in the network.

    ``x1`` and ``x2`` are concatenated along dim 0 and pushed through the model
    in one forward pass. A forward hook on every module splits the module's
    first positional input back into the two halves and records how far apart
    they are.

    Side effect: the model is switched to eval mode.

    Args:
        model: Module to inspect; must have at least one parameter (its device
            is used for the forward pass).
        x1: First input batch.
        x2: Second input batch, compared against ``x1``.

    Returns:
        ``defaultdict`` with the keys ``'se'`` (sum of squared differences),
        ``'mse'`` (mean squared difference) and ``'sim'`` (cosine similarity of
        the flattened halves), each mapping module names as given by
        ``model.named_modules()`` (the root module is ``''``) to a float.
        Modules whose first positional input is missing or not a tensor are
        skipped.
    """
    model.eval()

    device = next(model.parameters()).device

    data = defaultdict(dict)
    inputs = torch.cat((x1, x2), dim=0)
    separator = x1.shape[0]
    # Resolve names once instead of scanning named_modules() in every hook call.
    module_names = {module: name for name, module in model.named_modules()}

    def check_activations(module, layer_args, output):
        if not layer_args or not torch.is_tensor(layer_args[0]):
            return
        module_name = module_names.get(module, '')
        layer_inputs = layer_args[0].detach()
        first, second = layer_inputs[:separator], layer_inputs[separator:]
        residual = (first - second).pow(2)
        sim = torch.nn.functional.cosine_similarity(first.flatten(), second.flatten(),
                                                    dim=0, eps=1e-8)
        data['se'][module_name] = residual.sum().item()
        data['mse'][module_name] = residual.mean().item()
        data['sim'][module_name] = sim.item()

    hooks = []
    try:
        for module in module_names:
            hooks.append(module.register_forward_hook(check_activations))
        model(inputs.to(device))
    finally:
        # Always detach the hooks, including when the forward pass raises.
        for hook in hooks:
            hook.remove()

    return data


def ssim_gaussian(window_size, sigma):
    """Return a 1-D Gaussian window of length ``window_size`` that sums to one.

    The Gaussian has standard deviation ``sigma`` and is centred on index
    ``window_size // 2``.
    """
    # Local import: the file does not import ``exp`` at module level.
    from math import exp

    center = window_size // 2
    gauss = torch.Tensor([exp(-(x - center)**2 / float(2 * sigma**2)) for x in range(window_size)])
    return gauss / gauss.sum()

def ssim_create_window(window_size, channel):
    """Build the 2-D Gaussian SSIM kernel, one copy per channel.

    The kernel is the outer product of the 1-D Gaussian (sigma 1.5) with
    itself.

    Returns:
        Contiguous float tensor of shape
        ``(channel, 1, window_size, window_size)``.
    """
    _1D_window = ssim_gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # No ``Variable`` wrapper: the name was never imported, and plain tensors
    # have replaced it.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

def ssim_ssim(img1, img2, window, window_size, channel, size_average = True):
    """Compute SSIM between two image batches with a given smoothing window.

    Local means, variances and the covariance are obtained by a grouped
    convolution with ``window`` (padding ``window_size // 2``). The
    stabilizing constants are ``0.01**2`` and ``0.03**2``.

    Args:
        img1, img2: Image batches accepted by ``conv2d``.
        window: Convolution weight used with ``groups=channel``.
        window_size: Side length of the window; determines the padding.
        channel: Number of image channels (convolution groups).
        size_average: If true return the mean over the whole SSIM map,
            otherwise one mean per sample.
    """
    def smooth(tensor):
        # ``torch.nn.functional`` is spelled out: the file has no ``F`` alias.
        return torch.nn.functional.conv2d(tensor, window, padding=window_size // 2, groups=channel)

    mu1 = smooth(img1)
    mu2 = smooth(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = smooth(img1 * img1) - mu1_sq
    sigma2_sq = smooth(img2 * img2) - mu2_sq
    sigma12 = smooth(img1 * img2) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1).mean(1).mean(1)

def ssim(img1, img2, window_size = 11, size_average = True):
    """SSIM between two 4-D image batches using a Gaussian window.

    The window is built for the channel count of ``img1`` and converted to the
    tensor type of ``img1`` before the SSIM map is evaluated.
    """
    _, num_channels, _, _ = img1.size()
    kernel = ssim_create_window(window_size, num_channels)

    if img1.is_cuda:
        # Place the kernel on the same GPU as the images.
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)

    return ssim_ssim(img1, img2, kernel, window_size, num_channels, size_average)

def ssim_batch(ref_batch, img_batch, batched=False, factor=1.0):
    """Per-sample SSIM between two 4-D batches.

    ``batched`` and ``factor`` are accepted but not used.

    Returns:
        Tuple ``(mean, per_sample)``: the mean SSIM as a Python float and the
        list of per-sample SSIM tensors.
    """
    num_samples, _, _, _ = img_batch.shape
    candidates = img_batch.detach()
    ssims = [
        ssim(candidates[idx, :, :, :].unsqueeze(0), ref_batch[idx, :, :, :].unsqueeze(0))
        for idx in range(num_samples)
    ]
    return torch.stack(ssims, dim=0).mean().item(), ssims

def ssim_permute(ref_batch, img_batch, batched=False, factor=1.0):
    """SSIM that ignores the ordering of images within the batch.

    Every image of ``img_batch`` is tiled to the batch size and compared with
    ``ref_batch`` sample by sample; the best (maximum) SSIM is kept for that
    image. ``batched`` and ``factor`` are accepted but not used.

    Returns:
        Tuple ``(mean, best)``: the mean of the best scores as a Python float
        and the list of best-score tensors, one per image.
    """
    best_scores = []
    for idx in range(img_batch.shape[0]):
        tiled = img_batch[idx].unsqueeze(0).repeat(img_batch.shape[0], 1, 1, 1)
        candidate_ssims = ssim_batch(ref_batch, tiled)[1]
        best_scores.append(torch.stack(candidate_ssims).view(1, -1).max())

    return torch.stack(best_scores).mean().item(), best_scores
import lpips
from copy import deepcopy
def lpips_loss(img_batch, ref_batch, net='alex'):
    """Mean LPIPS distance between corresponding samples of two 4-D batches.

    Both batches are detached and moved to the CPU, a new ``lpips.LPIPS``
    model with backbone ``net`` is created, and the per-sample distances are
    averaged.

    NOTE(review): the inputs are passed to LPIPS unchanged; presumably callers
    provide them in the value range LPIPS expects -- confirm.

    Returns:
        The mean distance as a Python float.
    """
    # detach() instead of deepcopy(): deepcopy raises for non-leaf tensors and
    # would duplicate the data before the device transfer.
    img_batch = img_batch.detach().to("cpu")
    ref_batch = ref_batch.detach().to("cpu")
    loss_fn = lpips.LPIPS(net=net)
    [B, C, m, n] = img_batch.shape
    # Only a float is returned, so no autograd graph is needed.
    with torch.no_grad():
        per_sample = [
            loss_fn(img_batch[sample, :, :, :], ref_batch[sample, :, :, :])
            for sample in range(B)
        ]
    return torch.stack(per_sample, dim=0).mean().item()