"""Define basic models and translate some torchvision stuff."""
"""Stuff from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py."""
import torch
import torchvision
import torch.nn as nn

from torchvision.models.resnet import Bottleneck
from .revnet import iRevNet
from .densenet import _DenseNet, _Bottleneck
from vit_timm import vit_small_patch16_224
from collections import OrderedDict
import numpy as np
from ..utils import set_random_seed
import sys
# NOTE(review): machine-specific absolute path. It is appended only after every import
# above has already executed, so it cannot help resolve any of them (e.g. `vit_timm`);
# presumably it lets absolute imports of modules in that directory resolve later, at call
# time, on that particular Windows machine. Consider removing it in favour of
# package-relative imports.
sys.path.append("C:\\Users\\92865\\Desktop\\gradient-inversion-bench\\inversefed\\nn")
# Seeds torch's global RNG as an import-time side effect. construct_model additionally
# calls set_random_seed(...) before building each model.
torch.manual_seed(42)



def set_bn_random(module):
    """Re-initialise a ``BatchNorm2d`` layer in place; every other module is left untouched.

    Designed to be passed to ``nn.Module.apply``. The affine scale (``weight``) is
    redrawn from a standard normal distribution and the shift (``bias``) is reset
    to zero.
    """
    if not isinstance(module, nn.BatchNorm2d):
        return
    module.weight.data.normal_(mean=0.0, std=1.0)
    module.bias.data.zero_()
    
# Root directory of the checkpoints that ``construct_model`` loads for the
# "LeNet", "MLP", "ResNet152", "Vit_small" and "Swin" architectures.
_CHECKPOINT_ROOT = "/disk/Jiahui/47_2080ti/djc/workspace-ng/trainedvsuntrained"

# Settings of the "ConCNN" architecture, keyed by ``config_id``. Each value is
# (use_bias, use_relu, use_dropout, use_maxpool, kernel_size, padding) and is
# passed positionally to ``ConvNet_Config``. Config 0 is the baseline; every
# other entry differs from it in exactly one field.
_CONCNN_CONFIGS = {
    0: (True, True, True, True, 3, 1),
    1: (True, False, True, True, 3, 1),   # use_relu=False
    2: (True, True, False, True, 3, 1),   # use_dropout=False
    3: (True, True, True, False, 3, 1),   # use_maxpool=False
    4: (True, True, True, True, 4, 1),    # kernel_size=4
    5: (False, True, True, True, 3, 1),   # use_bias=False
    6: (True, True, True, True, 2, 1),    # kernel_size=2
    7: (True, True, True, True, 1, 1),    # kernel_size=1
    8: (True, True, True, True, 3, 2),    # padding=2
    9: (True, True, True, True, 3, 3),    # padding=3
}


def _load_checkpoint(model, subdir, pth):
    """Load ``<_CHECKPOINT_ROOT>/<subdir>/model_<pth>.pth`` into ``model`` in place.

    ``map_location="cpu"`` lets checkpoints that were saved from GPU tensors be
    loaded on machines without CUDA; ``load_state_dict`` copies the values into
    the parameters of the freshly built model either way.
    """
    model_path = _CHECKPOINT_ROOT + "/" + subdir + "/model_" + str(pth) + ".pth"
    model.load_state_dict(torch.load(model_path, map_location="cpu"))


def _beyond_inferring_net(in_channels):
    """Build the "BeyondInferring*" network: 4 LeakyReLU convs, 2 linear layers, softmax.

    The flattened size 12544 = 256 * 7 * 7 is what a 28x28 input yields after the
    two stride-2 convolutions.
    """
    return torch.nn.Sequential(OrderedDict([
        ('conv1', torch.nn.Conv2d(in_channels, 32, 3, stride=2, padding=1)),
        ('relu0', torch.nn.LeakyReLU()),
        ('conv2', torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)),
        ('relu1', torch.nn.LeakyReLU()),
        ('conv3', torch.nn.Conv2d(64, 128, 3, stride=2, padding=1)),
        ('relu2', torch.nn.LeakyReLU()),
        ('conv4', torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)),
        ('relu3', torch.nn.LeakyReLU()),
        ('flatt', torch.nn.Flatten()),
        ('linear0', torch.nn.Linear(12544, 12544)),
        ('relu4', torch.nn.LeakyReLU()),
        ('linear1', torch.nn.Linear(12544, 10)),
        ('softmax', torch.nn.Softmax(dim=1))
    ]))


def construct_model(model, num_classes=10, seed=None, num_channels=3, modelkey=None, pth=0, skip=-1, config_id=0):
    """Build the architecture named ``model`` and return it together with its init seed.

    ``set_random_seed`` is called with the chosen seed before the network is
    constructed, so the returned seed identifies the random initialization.

    Args:
        model: architecture name, e.g. ``'ConvNet'``, ``'ResNet18'``, ``'ConCNN'``.
        num_classes: number of outputs. Several branches below ignore it and
            hard-code their head size.
        seed: seed used when ``modelkey`` is ``None``; if both are ``None`` a
            random seed is drawn with numpy.
        num_channels: input channels, for the architectures that accept it.
        modelkey: explicit seed that takes precedence over ``seed``.
        pth: checkpoint number for the architectures whose weights are loaded
            from disk ("LeNet", "MLP", "ResNet152", "Vit_small", "Swin").
        skip: for "SkipCNN-18"/"SkipCNN-34", index of the residual block whose
            skip connection is disabled; ``-1`` keeps all of them.
        config_id: "ConCNN" configuration, a key of ``_CONCNN_CONFIGS``.

    Returns:
        ``(model, model_init_seed)``.

    Raises:
        NotImplementedError: if ``model`` is not a known architecture name.
        ValueError: if ``model == 'ConCNN'`` and ``config_id`` is not a known config.
    """
    if modelkey is None:
        model_init_seed = np.random.randint(0, 2**31 - 10) if seed is None else seed
    else:
        model_init_seed = modelkey
    set_random_seed(model_init_seed)

    name = model
    if name in ['ConvNet', 'ConvNet64']:
        model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes)
    elif name == "LeNet":
        model = LeNet(channel=3, num_classes=num_classes)
        _load_checkpoint(model, "FL_saved_models_lenet", pth)
    elif name == "LeNetMnist":
        model = LeNet(channel=1)
    elif name == 'SimpleCNN':
        model = SimpleCNN()
    elif name == "Alexnet":
        model = torchvision.models.alexnet()
    elif name == "Vgg16":
        from .cifar10_models.vgg import vgg16_bn
        model = vgg16_bn(pretrained=False)
    elif name == "Vgg11":
        from .cifar10_models.vgg import vgg11_bn
        model = vgg11_bn()
    elif name == "Vgg19":
        from .cifar10_models.vgg import vgg19_bn
        model = vgg19_bn(pretrained=False)
    elif name == "Simple_CNN":
        model = Simple_CNN()
    elif name == "ResNetLikeCNN":
        model = ResNetLikeCNN()
    elif name == "DenseNetLikeCNN":
        model = DenseNetLikeCNN()
    elif name == "InceptionLikeCNN":
        model = InceptionLikeCNN()
    elif name == "SkipCNN-18":
        # One flag per residual block (2 + 2 + 2 + 2 = 8).
        residual_usage = [True] * 8
        if skip != -1:
            residual_usage[skip] = False
        model = ResNet_([2, 2, 2, 2], 10, residual_usage)
    elif name == "SkipCNN-34":
        # One flag per residual block (3 + 4 + 6 + 3 = 16).
        residual_usage = [True] * 16
        if skip != -1:
            residual_usage[skip] = False
        model = ResNet_([3, 4, 6, 3], 10, residual_usage)
    elif name == "convnext_tiny":
        model = torchvision.models.convnext_tiny(num_classes=10)
    elif name == "convnext_small":
        model = torchvision.models.convnext_small(num_classes=10)
    elif name == "convnext_base":
        model = torchvision.models.convnext_base(num_classes=10)
    elif name == "convnext_large":
        model = torchvision.models.convnext_large(num_classes=10)
    elif name == "DenseNet43":
        from .densenets.DenseNet43 import DenseNet
        model = DenseNet(10)
    elif name == "DenseNet43_sht1":
        from .densenets.DenseNet43_sht1 import DenseNet
        model = DenseNet(10)
    elif name == "DenseNet43_sht2":
        from .densenets.DenseNet43_sht2 import DenseNet
        model = DenseNet(10)
    elif name == "DenseNet53":
        from .densenets.DenseNet53 import DenseNet
        model = DenseNet(10)
    elif name == "DenseNet53_sht1":
        from .densenets.DenseNet53_sht1 import DenseNet
        model = DenseNet(10)
    elif name == "DenseNet53_sht2":
        from .densenets.DenseNet53_sht2 import DenseNet
        model = DenseNet(10)
    elif name == "ConCNN":
        try:
            config = _CONCNN_CONFIGS[config_id]
        except KeyError:
            # Previously an unknown id fell through every branch and crashed
            # with UnboundLocalError on the first unset setting.
            raise ValueError(
                f"Unknown ConCNN config_id {config_id!r}; "
                f"expected one of {sorted(_CONCNN_CONFIGS)}."
            ) from None
        model = ConvNet_Config(*config)
    elif name == "SimpleCNNImageNet":
        model = SimpleCNNImageNet()
    elif name == 'SimpleCNNMnist':
        model = SimpleCNNMnist()
    elif name in ('ConvNet8', 'ConvNet16', 'ConvNet32'):
        # NOTE(review): despite their names all three variants build the same
        # width=64 network as 'ConvNet'; confirm whether width=8/16/32 was meant.
        model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes)
    elif name == 'BeyondInferringMNIST':
        model = _beyond_inferring_net(in_channels=1)
    elif name == 'BeyondInferringCifar':
        # NOTE(review): a 3x32x32 input leaves 256 * 8 * 8 = 16384 features after
        # the two stride-2 convs, but linear0 expects 12544 (= 256 * 7 * 7, the
        # 28x28 size). Confirm the intended input resolution.
        model = _beyond_inferring_net(in_channels=3)
    elif name == 'MLP':
        model = MLP()
        _load_checkpoint(model, "FL_saved_models_mlp", pth)
    elif name == 'TwoLP':
        width = 2048
        model = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, width)),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(width, num_classes))]))
    elif name == 'ResNet20':
        model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16)
    elif name == 'ResNet20-nostride':
        model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16,
                       strides=[1, 1, 1, 1])
    elif name == 'ResNet20-10':
        model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 10)
    elif name == 'ResNet20-4':
        model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 4)
    elif name == 'ResNet20-4-unpooled':
        model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 4,
                       pool='max')
    elif name == 'ResNet28-10':
        model = ResNet(torchvision.models.resnet.BasicBlock, [4, 4, 4], num_classes=num_classes, base_width=16 * 10)
    elif name == 'ResNet32':
        model = ResNet(torchvision.models.resnet.BasicBlock, [5, 5, 5], num_classes=num_classes, base_width=16)
    elif name == 'ResNet32-10':
        model = ResNet(torchvision.models.resnet.BasicBlock, [5, 5, 5], num_classes=num_classes, base_width=16 * 10)
    elif name == 'ResNet44':
        model = ResNet(torchvision.models.resnet.BasicBlock, [7, 7, 7], num_classes=num_classes, base_width=16)
    elif name == 'ResNet56':
        model = ResNet(torchvision.models.resnet.BasicBlock, [9, 9, 9], num_classes=num_classes, base_width=16)
    elif name == 'ResNet110':
        model = ResNet(torchvision.models.resnet.BasicBlock, [18, 18, 18], num_classes=num_classes, base_width=16)
    elif name == 'ResNet18':
        from .cifar10_models.resnet import resnet18
        model = resnet18(pretrained=False)
    elif name == 'ResNet34':
        from .cifar10_models.resnet import resnet34
        model = resnet34(pretrained=False)
    elif name == 'ResNet50':
        from .cifar10_models.resnet import resnet50
        model = resnet50(pretrained=False)
    elif name == 'ResNet50-2':
        model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes, base_width=64 * 2)
    elif name == 'ResNet101':
        model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], num_classes=num_classes, base_width=64)
    elif name == 'ResNet152':
        model = torchvision.models.resnet152(num_classes=num_classes)
        _load_checkpoint(model, "FL_saved_models_resnet152", pth)
    elif name == 'MobileNet':
        # Package-relative like every other cifar10_models import in this function;
        # the absolute form only resolved through a machine-specific sys.path entry.
        from .cifar10_models.mobilenetv2 import mobilenet_v2
        model = mobilenet_v2(pretrained=True)
    elif name == 'MNASNet':
        model = torchvision.models.MNASNet(1.0, num_classes=num_classes, dropout=0.2)
    elif name == 'DenseNet121':
        from .cifar10_models.densenet import densenet121
        model = densenet121(pretrained=False)
    elif name == 'DenseNet40':
        model = _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12, num_classes=num_classes)
    elif name == 'GoogleNet':
        from .cifar10_models.googlenet import googlenet
        model = googlenet(pretrained=False)
    elif name == "Inception_v3":
        from .cifar10_models.inception import inception_v3
        model = inception_v3(pretrained=False)
    elif name == 'DenseNet40-4':
        model = _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12 * 4, num_classes=num_classes)
    elif name == 'SRNet3':
        model = SRNet(upscale_factor=3, num_channels=num_channels)
    elif name == 'SRNet1':
        model = SRNet(upscale_factor=1, num_channels=num_channels)
    elif name == 'iRevNet':
        if num_classes <= 100:
            in_shape = [num_channels, 32, 32]  # only for cifar right now
            model = iRevNet(nBlocks=[18, 18, 18], nStrides=[1, 2, 2],
                            nChannels=[16, 64, 256], nClasses=num_classes,
                            init_ds=0, dropout_rate=0.1, affineBN=True,
                            in_shape=in_shape, mult=4)
        else:
            in_shape = [3, 224, 224]  # only for imagenet
            model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2],
                            nChannels=[24, 96, 384, 1536], nClasses=num_classes,
                            init_ds=2, dropout_rate=0.1, affineBN=True,
                            in_shape=in_shape, mult=4)
    elif name == 'LeNetZhu':
        model = LeNetZhu(num_channels=num_channels, num_classes=num_classes)
    elif name == 'vit':
        from vit_pytorch import SimpleViT
        model = SimpleViT(
            image_size=224,
            patch_size=16,
            num_classes=1000,
            dim=1024,
            depth=6,
            heads=12,
            mlp_dim=2048,
        )
    elif name == "Vit_small":
        from .vit_small import ViT
        model = ViT(
            image_size=32,
            patch_size=4,
            num_classes=10,
            dim=512,
            depth=6,
            heads=8,
            mlp_dim=512,
            dropout=0.1,
            emb_dropout=0.1,
        )
        _load_checkpoint(model, "FL_saved_models_vit_small", pth)
    elif name == "Swin":
        from .swin import swin_t
        model = swin_t(window_size=4,
                       num_classes=10,
                       downscaling_factors=(2, 2, 2, 1))
        _load_checkpoint(model, "FL_saved_models_swin", pth)
    else:
        raise NotImplementedError('Model not implemented.')

    print(f'Model initialized with random key {model_init_seed}.')
    return model, model_init_seed

class ConvBNReLU(nn.Sequential):
    """Bias-free Conv2d -> BatchNorm2d -> ReLU6 stack.

    Padding is ``(kernel_size - 1) // 2``, which keeps the spatial size unchanged
    for odd kernel sizes at stride 1.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            (kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_planes)
        activation = nn.ReLU6(inplace=True)
        super(ConvBNReLU, self).__init__(conv, norm, activation)
class ResNet(torchvision.models.ResNet):
    """ResNet generalization for CIFAR-sized inputs.

    Compared with ``torchvision.models.ResNet``, the stem is a single 3x3, stride-1
    convolution on a 3-channel input (no max-pool). The stages live in
    ``self.layers``, one per entry of ``layers``; stage ``i`` has
    ``base_width * 2**i`` planes. Global pooling can be average or max.
    """

    def __init__(self, block, layers, num_classes=10, zero_init_residual=False,
                 groups=1, base_width=64, replace_stride_with_dilation=None,
                 norm_layer=None, strides=(1, 2, 2, 2), pool='avg'):
        """Initialize as usual; the number of stages, their depths and strides are configurable.

        Args:
            block: residual block class, e.g. ``torchvision.models.resnet.BasicBlock``
                or ``Bottleneck``.
            layers: number of blocks per stage, one stage per entry (at most 4,
                since ``strides`` and ``replace_stride_with_dilation`` are indexed
                per stage).
            num_classes: output size of the final linear layer.
            zero_init_residual: zero the scale of the last BN in each residual branch.
            groups: stored as ``self.groups`` for torchvision's ``_make_layer``.
            base_width: channel count of the stem and the first stage.
            replace_stride_with_dilation: 4 booleans; ``None`` means all ``False``.
            norm_layer: normalization layer factory, ``nn.BatchNorm2d`` by default.
            strides: stride of each stage. A tuple by default to avoid a shared
                mutable default; any indexable sequence works.
            pool: ``'avg'`` for global average pooling, anything else for global
                max pooling.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` does not have 4 entries.
        """
        # Skip torchvision.models.ResNet.__init__ (which would build the ImageNet
        # layout) and only run nn.Module.__init__.
        super(torchvision.models.ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False, False]
        if len(replace_stride_with_dilation) != 4:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 4-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups

        self.inplanes = base_width
        self.base_width = 64  # Do this to circumvent BasicBlock errors. The value is not actually used.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.layers = torch.nn.ModuleList()
        width = self.inplanes
        for idx, num_blocks in enumerate(layers):
            self.layers.append(self._make_layer(block, width, num_blocks, stride=strides[idx],
                                                dilate=replace_stride_with_dilation[idx]))
            width *= 2

        self.pool = nn.AdaptiveAvgPool2d((1, 1)) if pool == 'avg' else nn.AdaptiveMaxPool2d((1, 1))
        # ``width`` was doubled once more after the last stage, hence ``// 2``.
        self.fc = nn.Linear(width // 2 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, torchvision.models.resnet.BasicBlock):
                    # ``BasicBlock`` is not among this module's imports (only
                    # ``Bottleneck`` is), so the bare name raised NameError here;
                    # refer to torchvision's class explicitly.
                    nn.init.constant_(m.bn2.weight, 0)

    def _forward_impl(self, x):
        """Run stem -> stages -> global pool -> linear head and return the logits."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        for layer in self.layers:
            x = layer(x)

        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
import torch.nn.functional as F
class SimpleCNNImageNet(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a two-layer classifier with 1000 outputs.

    ``fc1`` expects 128 x 13 x 13 features. With these unpadded 3x3 convolutions
    and 2x2 pools that corresponds to a 3 x 118 x 118 (or 119 x 119) input.
    NOTE(review): this is not the usual 224x224 ImageNet size; confirm the input
    resolution used by callers.
    """

    def __init__(self):
        super(SimpleCNNImageNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(128 * 13 * 13, 512)
        self.fc2 = nn.Linear(512, 1000)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(batch, 1000)`` logits."""
        x = self.pool1(self.relu(self.conv1(x)))
        x = self.pool2(self.relu(self.conv2(x)))
        x = self.pool3(self.relu(self.conv3(x)))
        # Flatten per sample. ``view(-1, 128 * 13 * 13)`` silently folded any
        # extra spatial positions into the batch dimension for mis-sized inputs
        # (e.g. 224x224 -> 4x as many rows); now fc1 rejects them with an error.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
class SimpleCNN(nn.Module):
    """Three conv/ReLU/max-pool stages + three linear layers; returns log-probabilities over 10 classes.

    With unpadded 3x3 convolutions and 2x2 pools, a 3 x 32 x 32 input is reduced
    to 256 x 2 x 2 = 1024 features before ``fc1``.
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3,   64,  3)
        self.conv2 = nn.Conv2d(64,  128, 3)
        self.conv3 = nn.Conv2d(128, 256, 3)
        self.pool = nn.MaxPool2d(2, 2)
        # 256 channels x 2 x 2 spatial for a 32x32 input (== 64 * 4 * 4 == 1024).
        self.fc1 = nn.Linear(256 * 2 * 2, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return ``(batch, 10)`` log-softmax scores."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample. ``view(-1, 1024)`` silently turned larger inputs
        # (e.g. 64x64 -> 9x as many rows) into extra batch rows; fc1 now rejects them.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)



class SimpleCNNMnist(nn.Module):
    """Two conv/ReLU/max-pool stages and a linear head for single-channel images (10 outputs).

    ``fc`` expects 64 x 6 x 6 features, which inputs of 30-33 pixels per side
    produce (e.g. 32 -> 30 -> 15 -> 13 -> 6). A raw 28x28 image yields 64 x 5 x 5
    and is rejected by ``fc``. NOTE(review): confirm callers pad/resize MNIST to 32x32.
    """

    def __init__(self):
        super(SimpleCNNMnist, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.fc = nn.Linear(64 * 6 * 6, 10)

    def forward(self, x):
        """Return ``(batch, 10)`` logits."""
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Flatten per sample. ``view(-1, 64 * 6 * 6)`` could silently regroup a
        # batch of mis-sized maps into a different number of rows.
        x = torch.flatten(x, 1)
        return self.fc(x)


class ConvNet(torch.nn.Module):
    """ConvNetBN: stacked 3x3 Conv-BatchNorm-ReLU layers, two 3x3 max-pools and a linear head."""

    def __init__(self, width=64, num_classes=10, num_channels=3):
        """Init with width and num classes.

        Args:
            width: base channel count; the conv layers use 1x, 2x and 4x this width.
            num_classes: number of outputs of the final linear layer.
            num_channels: number of input channels.
        """
        super().__init__()
        self.model = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
            ('bn0', torch.nn.BatchNorm2d(1 * width)),
            ('relu0', torch.nn.ReLU()),

            ('conv1', torch.nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
            ('bn1', torch.nn.BatchNorm2d(2 * width)),
            ('relu1', torch.nn.ReLU()),

            ('conv2', torch.nn.Conv2d(2 * width, 2 * width, kernel_size=3, padding=1)),
            ('bn2', torch.nn.BatchNorm2d(2 * width)),
            ('relu2', torch.nn.ReLU()),

            ('conv3', torch.nn.Conv2d(2 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn3', torch.nn.BatchNorm2d(4 * width)),
            ('relu3', torch.nn.ReLU()),

            ('conv4', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn4', torch.nn.BatchNorm2d(4 * width)),
            ('relu4', torch.nn.ReLU()),

            ('conv5', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn5', torch.nn.BatchNorm2d(4 * width)),
            ('relu5', torch.nn.ReLU()),

            ('pool0', torch.nn.MaxPool2d(3)),

            # NOTE(review): the keys 'conv6', 'bn6' and 'relu6' appear twice. An
            # OrderedDict keeps only the later value for a repeated key (at the first
            # key's position), so the network has a single conv6/bn6/relu6 stage. The
            # first conv6 below is still constructed (consuming RNG) and then dropped.
            # Renaming the duplicates would add a layer and change both the state_dict
            # keys and the seeded initialization, so confirm intent before changing.
            ('conv6', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn6', torch.nn.BatchNorm2d(4 * width)),
            ('relu6', torch.nn.ReLU()),

            ('conv6', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn6', torch.nn.BatchNorm2d(4 * width)),
            ('relu6', torch.nn.ReLU()),

            ('conv7', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('bn7', torch.nn.BatchNorm2d(4 * width)),
            ('relu7', torch.nn.ReLU()),

            ('pool1', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            # 36 * width = (4 * width) channels x 3 x 3: for a 32x32 input the two
            # MaxPool2d(3) layers reduce 32 -> 10 -> 3.
            ('linear', torch.nn.Linear(36 * width, num_classes))
        ]))

    def forward(self, input):
        """Apply the sequential network and return the logits."""
        return self.model(input)
class MLP(nn.Module):
    """Four fully connected layers (ReLU between them) mapping a flattened 3x32x32 image to 10 logits."""

    def __init__(self):
        super(MLP, self).__init__()
        # Input is a flattened 32*32*3 image; hidden widths shrink 256 -> 128 -> 64.
        # Attribute names are kept stable because saved state_dicts are loaded by key.
        self.fc1 = nn.Linear(32 * 32 * 3, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)

    def forward(self, x):
        hidden = x.view(-1, 32 * 32 * 3)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(layer(hidden))
        return self.fc4(hidden)
class LeNet(nn.Module):
    """Three 5x5 sigmoid convolutions (12 channels each; strides 2, 2, 1) followed by one linear layer.

    Args:
        channel: number of input channels.
        hidden: number of flattened features fed to the linear layer. The default
            768 = 12 * 8 * 8 matches a 32x32 input; a 28x28 input needs 588 = 12 * 7 * 7.
        num_classes: number of outputs.
    """

    def __init__(self, channel=1, hidden=768, num_classes=10):
        super(LeNet, self).__init__()
        act = nn.Sigmoid

        self.body = nn.Sequential(
            nn.Conv2d(channel, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act()
        )

        self.fc = nn.Sequential(
            nn.Linear(hidden, num_classes)
        )

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits."""
        out = self.body(x)
        # Flatten per sample (as LeNetZhu does) instead of the hard-coded
        # ``view(-1, 768)``, which ignored ``hidden`` and broke any non-default value.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

class LeNetZhu(nn.Module):
    """LeNet variant from https://github.com/mit-han-lab/dlg/blob/master/models/vision.py."""

    def __init__(self, num_classes=10, num_channels=3):
        """3-Layer sigmoid Conv with large linear layer.

        The linear layer takes 768 = 12 * 8 * 8 features, i.e. a 32x32 input.
        All weights and biases are drawn uniformly from [-0.5, 0.5].
        """
        super().__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(num_channels, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act(),
        )
        self.fc = nn.Sequential(
            nn.Linear(768, num_classes)
        )
        for module in self.modules():
            self.weights_init(module)

    @staticmethod
    def weights_init(m):
        """Fill ``m.weight`` then ``m.bias`` uniformly from [-0.5, 0.5].

        Attributes that are missing or ``None`` (e.g. ``bias=False`` layers) are
        skipped instead of raising ``AttributeError`` on ``None.data``.
        """
        for attr in ("weight", "bias"):
            param = getattr(m, attr, None)
            if param is not None:
                param.data.uniform_(-0.5, 0.5)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits."""
        out = self.body(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
    


    
    




class BasicBlock_(nn.Module):
    """Two 3x3 conv/BN layers with a skip connection that can be switched off.

    Args:
        inplanes: input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input on the skip path.
        use_residual: when ``False`` the block is a plain
            conv-BN-ReLU-conv-BN-ReLU stack and ``downsample`` is never called.
    """

    def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample=None, use_residual: bool = True) -> None:
        super(BasicBlock_, self).__init__()
        self.use_residual = use_residual
        # Submodules are registered in this exact order (it fixes both the RNG
        # draws for initialization and the state_dict key order).
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        if self.use_residual:
            branch += x if self.downsample is None else self.downsample(x)
        return self.relu(branch)


class ResNet_(nn.Module):
    """Four-stage ResNet built from ``BasicBlock_`` with per-block skip control.

    Args:
        layers: number of blocks in each of the four stages (e.g. ``[3, 4, 6, 3]``).
            Only the first four entries are used.
        num_classes: output size of the final linear layer.
        residual_usage: flat sequence of booleans, one per block in stage order.
            It must have at least ``sum(layers[:4])`` entries; extra entries are
            ignored. ``None`` enables the residual connection in every block.
        zero_init_residual: if True, set ``bn2.weight`` of every block to zero.

    Raises:
        ValueError: if ``layers`` has fewer than four entries, a stage has no
            blocks, or ``residual_usage`` is too short.
    """

    _STAGE_PLANES = (64, 128, 256, 512)
    _STAGE_STRIDES = (1, 2, 2, 2)

    def __init__(self, layers: list, num_classes: int, residual_usage: list, zero_init_residual: bool = False) -> None:
        super(ResNet_, self).__init__()
        if len(layers) < 4:
            raise ValueError(f"ResNet_ needs block counts for 4 stages, got {len(layers)}")
        stage_blocks = list(layers[:4])
        total_blocks = sum(stage_blocks)
        if residual_usage is None:
            residual_usage = [True] * total_blocks
        elif len(residual_usage) < total_blocks:
            raise ValueError(
                f"residual_usage has {len(residual_usage)} entries but the network has {total_blocks} blocks"
            )

        self.inplanes = 64
        self.dilation = 1  # kept for attribute compatibility; not used here
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # stride=1 keeps spatial resolution after the stem pooling.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        # Carve the flat residual_usage sequence into consecutive per-stage chunks.
        # Stages are built in order (layer1..layer4), so seeded initialisation
        # and state_dict keys are the same as before.
        stages = []
        start = 0
        for planes, stride, n_blocks in zip(self._STAGE_PLANES, self._STAGE_STRIDES, stage_blocks):
            stages.append(self._make_layer(planes, n_blocks, stride=stride,
                                           residual_usage=residual_usage[start:start + n_blocks]))
            start += n_blocks
        self.layer1, self.layer2, self.layer3, self.layer4 = stages

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock_):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, planes: int, blocks: int, stride: int = 1, residual_usage: list = None) -> nn.Sequential:
        """Build one stage of ``blocks`` BasicBlock_ modules with ``planes`` channels.

        Only the first block uses ``stride`` and, when shapes change, a 1x1
        conv + BN downsample shortcut. ``residual_usage[i]`` selects the skip
        connection of block ``i``; ``None`` enables it for all blocks.
        Updates ``self.inplanes`` to ``planes``.
        """
        if blocks < 1:
            raise ValueError(f"each stage needs at least one block, got {blocks}")
        if residual_usage is None:
            residual_usage = [True] * blocks
        elif len(residual_usage) < blocks:
            raise ValueError(f"residual_usage has {len(residual_usage)} entries for a stage of {blocks} blocks")

        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes, 1, stride=stride, bias=False),
                                       nn.BatchNorm2d(planes))
        layers = [BasicBlock_(self.inplanes, planes, stride, downsample, use_residual=residual_usage[0])]
        self.inplanes = planes
        layers.extend(BasicBlock_(self.inplanes, planes, use_residual=residual_usage[i]) for i in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate=0):
        super(_DenseLayer, self).__init__()
        self.drop_rate = drop_rate
        self.dense_layer = nn.Sequential(
            nn.BatchNorm2d(num_input_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=num_input_features, out_channels=bn_size * growth_rate, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(bn_size * growth_rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=bn_size * growth_rate, out_channels=growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.dropout = nn.Dropout(p=self.drop_rate)

    def forward(self, x):
        y = self.dense_layer(x)
        if self.drop_rate > 0:
            y = self.dropout(y)

        return torch.cat([x, y], dim=1)


class _DenseBlock(nn.Module):
    """Chain of ``num_layers`` ``_DenseLayer`` modules.

    Layer ``i`` is sized for ``num_input_features + i * growth_rate`` input
    channels, because every preceding ``_DenseLayer`` concatenates its new
    features onto its input.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate=0):
        super().__init__()
        self.layers = nn.Sequential(*[
            _DenseLayer(num_input_features + depth * growth_rate, growth_rate, bn_size, drop_rate)
            for depth in range(num_layers)
        ])

    def forward(self, x):
        return self.layers(x)


class _TransitionLayer(nn.Module):
    def __init__(self, num_input_features, num_output_features):
        super(_TransitionLayer, self).__init__()
        self.transition_layer = nn.Sequential(
            nn.BatchNorm2d(num_input_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=num_input_features, out_channels=num_output_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        return self.transition_layer(x)

class _SequentialLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate=0):
        super(_SequentialLayer, self).__init__()
        self.drop_rate = drop_rate
        self.sequential_layer = nn.Sequential(
            nn.BatchNorm2d(num_input_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=num_input_features, out_channels=bn_size * growth_rate, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(bn_size * growth_rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=bn_size * growth_rate, out_channels=growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.dropout = nn.Dropout(p=self.drop_rate)

    def forward(self, x):
        y = self.sequential_layer(x)
        if self.drop_rate > 0:
            y = self.dropout(y)

        return y

class _SequentialBlock(nn.Module):
    """Non-concatenating counterpart of ``_DenseBlock``.

    ``_SequentialLayer`` returns only its ``growth_rate`` new channels, so the
    layers are chained: the first consumes ``num_input_features`` channels and
    every later layer consumes ``growth_rate``. The block's output channel
    count is stored in ``out_channels``: ``growth_rate``, or
    ``num_input_features`` when ``num_layers == 0``.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate=0):
        super(_SequentialBlock, self).__init__()
        layers = []
        in_features = num_input_features
        for _ in range(num_layers):
            layers.append(_SequentialLayer(in_features, growth_rate, bn_size, drop_rate))
            # There is no concatenation, so the next layer sees exactly the
            # growth_rate channels this layer produced.
            in_features = growth_rate
        self.out_channels = in_features
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class DenseNet(nn.Module):
    """DenseNet-style classifier: stem, four blocks separated by transitions, pooling, linear head.

    Args:
        num_init_features: channels produced by the 7x7 stem convolution.
        growth_rate: channels produced by each layer inside a block.
        blocks: number of layers in each of the four blocks.
        bn_size: bottleneck multiplier (1x1 conv width = ``bn_size * growth_rate``).
        drop_rate: dropout probability after each layer (0 disables dropout).
        num_classes: output size of the final linear layer.
        block_type: ``'dense'`` uses ``_DenseBlock`` (concatenating layers); any
            other value uses ``_SequentialBlock`` (plain chained layers).

    Each transition halves the channel count (floor) and the spatial size.
    NOTE(review): there is no BatchNorm/ReLU between the last block and the
    pooling layer. Confirm this is intended.
    """

    def __init__(self, num_init_features=64, growth_rate=32, blocks=(6, 12, 24, 16), bn_size=4, drop_rate=0, num_classes=1000, block_type='dense'):
        super(DenseNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        num_features = num_init_features
        Block = _DenseBlock if block_type == 'dense' else _SequentialBlock

        # Attribute names (including the 'transtion' spelling) are kept
        # unchanged: they determine state_dict keys, so renaming them would
        # break loading existing weights. Modules are also created in the
        # original order so seeded initialisation is unchanged.
        self.layer1 = Block(num_layers=blocks[0], num_input_features=num_features, growth_rate=growth_rate, bn_size=bn_size, drop_rate=drop_rate)
        num_features = self._block_out_features(self.layer1, num_features, blocks[0], growth_rate)

        self.transtion1 = _TransitionLayer(num_input_features=num_features, num_output_features=num_features // 2)
        num_features = num_features // 2

        self.layer2 = Block(num_layers=blocks[1], num_input_features=num_features, growth_rate=growth_rate, bn_size=bn_size, drop_rate=drop_rate)
        num_features = self._block_out_features(self.layer2, num_features, blocks[1], growth_rate)

        self.transtion2 = _TransitionLayer(num_input_features=num_features, num_output_features=num_features // 2)
        num_features = num_features // 2

        self.layer3 = Block(num_layers=blocks[2], num_input_features=num_features, growth_rate=growth_rate, bn_size=bn_size, drop_rate=drop_rate)
        num_features = self._block_out_features(self.layer3, num_features, blocks[2], growth_rate)

        self.transtion3 = _TransitionLayer(num_input_features=num_features, num_output_features=num_features // 2)
        num_features = num_features // 2

        self.layer4 = Block(num_layers=blocks[3], num_input_features=num_features, growth_rate=growth_rate, bn_size=bn_size, drop_rate=drop_rate)
        num_features = self._block_out_features(self.layer4, num_features, blocks[3], growth_rate)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(num_features, num_classes)

    @staticmethod
    def _block_out_features(block, num_input_features, num_layers, growth_rate):
        """Return the number of channels ``block`` produces."""
        if isinstance(block, _SequentialBlock):
            return block.out_channels
        # _DenseBlock: every layer concatenates growth_rate new channels onto its input.
        return num_input_features + num_layers * growth_rate

    def forward(self, x):
        x = self.features(x)

        x = self.layer1(x)
        x = self.transtion1(x)
        x = self.layer2(x)
        x = self.transtion2(x)
        x = self.layer3(x)
        x = self.transtion3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)

        return x

class NormalBlock(nn.Module):
    """3x3 convolution (padding 1, with bias) followed by an in-place ReLU.

    Preserves spatial size and maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv(x))

class NormalCNN(nn.Module):
    """Plain (non-residual, non-dense) CNN classifier.

    Five stages, each made of ``num_blocks`` ``NormalBlock`` layers followed by
    a 2x2 max-pool. The first stage has ``num_channels`` channels and each later
    stage doubles it. Global average pooling feeds a linear classifier.

    The five stride-2 pools presumably require inputs of at least 32x32;
    confirm with the callers.
    """

    def __init__(self, num_blocks=5, num_channels=64, num_classes=1000):
        super(NormalCNN, self).__init__()
        layers = []
        in_channels = 3
        width = num_channels
        for _ in range(5):
            for _ in range(num_blocks):
                layers.append(NormalBlock(in_channels, width))
                in_channels = width
            width *= 2
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        self.features = nn.Sequential(*layers)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # After global average pooling and flattening, there is one feature per
        # channel of the last conv. That count is `in_channels`, not the
        # over-doubled `width` multiplied by a spatial factor.
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = self.avg_pool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

class DenseBlock(nn.Module):
    """Densely connected stack of ``NormalBlock`` layers.

    Layer ``i`` receives the concatenation of the block input and all previous
    layers' outputs (``in_channels + i * growth_rate`` channels) and adds
    ``growth_rate`` new channels. The block returns the full concatenation:
    ``in_channels + num_layers * growth_rate`` channels.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super(DenseBlock, self).__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(NormalBlock(in_channels, growth_rate))
            in_channels += growth_rate
        # Stored as nn.Sequential so state_dict keys stay "layers.<i>.*".
        # forward() iterates it instead of calling it as a plain chain,
        # because the layers are sized for concatenated inputs.
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        features = x
        for layer in self.layers:
            # Each layer sees everything produced so far; its output is appended.
            features = torch.cat([features, layer(features)], 1)
        return features

class DenseCNN(nn.Module):
    """Five ``DenseBlock`` stages, each followed by 2x2 max-pooling, then a linear head.

    Each ``DenseBlock`` has ``num_blocks`` layers. The running channel count
    assumes every block returns ``channels + num_blocks * growth_rate``
    channels. Global average pooling feeds the classifier.
    """

    def __init__(self, num_blocks=5, growth_rate=32, num_classes=1000):
        super().__init__()
        modules = []
        channels = 3
        for _stage in range(5):
            modules += [DenseBlock(channels, growth_rate, num_blocks), nn.MaxPool2d(kernel_size=2, stride=2)]
            channels += num_blocks * growth_rate

        self.features = nn.Sequential(*modules)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x):
        pooled = self.avg_pool(self.features(x))
        return self.fc(pooled.flatten(1))



class ConvNet_Config(nn.Module):
    """Configurable eight-layer conv net used for architecture ablations.

    Args:
        use_bias: add bias terms to every conv and to the final linear layer.
        use_relu: ReLU after each conv; otherwise ``nn.Identity``.
        use_dropout: ``Dropout(p=0.5)`` before the classifier; otherwise identity.
        use_maxpool: insert ``MaxPool2d(3)`` after conv5 and conv7.
        kernel_size: kernel size shared by all conv layers.
        padding: padding shared by all conv layers.
        width: base channel count; conv ``i`` has ``width * (1, 2, 2, 4, 4, 4, 4, 4)[i]`` channels.
        num_classes: output size of the classifier.
        num_channels: channels of the input image.

    The classifier's input size is inferred from a dummy 1 x num_channels x 32 x 32
    forward pass, so inputs are expected to be 32x32.
    """

    # Channel multiplier of each of the eight conv layers, relative to ``width``.
    _WIDTH_MULTIPLIERS = (1, 2, 2, 4, 4, 4, 4, 4)

    def __init__(self, use_bias, use_relu, use_dropout, use_maxpool, kernel_size, padding, width=64, num_classes=10, num_channels=3):
        super(ConvNet_Config, self).__init__()

        layers = []
        in_channels = num_channels
        for i, mult in enumerate(self._WIDTH_MULTIPLIERS):
            out_channels = width * mult
            layers.append(('conv' + str(i), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=use_bias)))
            layers.append(('relu' + str(i), nn.ReLU() if use_relu else nn.Identity()))
            in_channels = out_channels

            if use_maxpool and i in (5, 7):
                layers.append(('pool' + str(i), nn.MaxPool2d(3)))

        self.conv_layers = nn.Sequential(OrderedDict(layers))

        self.dropout1 = nn.Dropout(p=0.5) if use_dropout else nn.Identity()

        # Output size follows num_classes (it was previously fixed at 10).
        self.linear1 = nn.Linear(self._calculate_conv_output_size(num_channels, width), num_classes, bias=use_bias)

    def _calculate_conv_output_size(self, num_channels, width):
        """Return the flattened size of ``conv_layers``' output for one 32x32 input.

        ``width`` is unused; it is kept for signature compatibility.
        """
        # Shape probe only, so autograd does not need to record a graph.
        with torch.no_grad():
            dummy_output = self.conv_layers(torch.zeros(1, num_channels, 32, 32))
        return dummy_output.view(dummy_output.size(0), -1).size(1)

    def forward(self, input):
        x = self.conv_layers(input)
        x = x.view(x.size(0), -1)
        x = self.dropout1(x)
        x = self.linear1(x)
        return x

