import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import argparse
from torch.utils.data import DataLoader, Subset

import numpy as np
from copy import deepcopy

# Command-line options:
#   --model   architecture name handed to train_and_save_model()
#   --device  CUDA device index, combined into "cuda:<index>" when CUDA is available
parser = argparse.ArgumentParser(description='Train.')
parser.add_argument('--model', type=str, default='SimpleCNN')
parser.add_argument('--device', type=str, default='0')
args = parser.parse_args()

# Step 1: data loading and preprocessing.
# Augmentation (RandomHorizontalFlip, RandomCrop(32, padding=4)) is intentionally
# left out; images are only converted to tensors and normalized per channel
# (these mean/std values are the widely used ImageNet statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# CIFAR-10 is downloaded into ./data if it is not already present.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
# NOTE: this loader is not shuffled; in this file training goes through
# fed_avg(), which builds its own shuffled per-client loaders from `trainset`.
trainloader = DataLoader(trainset, batch_size=128, shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
class LeNet(nn.Module):
    """Small sigmoid-activated CNN: three 5x5 convolutions followed by one linear layer.

    Args:
        channel: number of input image channels.
        hidden: flattened feature size produced by ``self.body``; it must equal
            ``12 * H_out * W_out``. The default 768 = 12 * 8 * 8 matches 32x32
            inputs (the two stride-2 convolutions take 32 -> 16 -> 8).
        num_classes: number of output logits.
    """

    def __init__(self, channel=1, hidden=768, num_classes=10):
        super(LeNet, self).__init__()
        act = nn.Sigmoid

        self.body = nn.Sequential(
            nn.Conv2d(channel, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act()
        )

        # Kept as a one-element Sequential so checkpoint keys stay "fc.0.weight"/"fc.0.bias".
        self.fc = nn.Sequential(
            nn.Linear(hidden, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        out = self.body(x)
        # Flatten per sample rather than view(-1, 768): this honours ``hidden`` and
        # fails loudly on a size mismatch instead of silently folding spatial
        # positions into the batch dimension.
        out = torch.flatten(out, 1)
        return self.fc(out)
class MLP(nn.Module):
    """Four-layer fully connected classifier over flattened 3x32x32 images.

    Layer widths: 3072 -> 256 -> 128 -> 64 -> 10, with ReLU between layers and
    raw logits at the output.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3 * 32 * 32, 256)  # flattened RGB image -> 256
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)  # 10 class logits

    def forward(self, x):
        """Flatten ``x`` to rows of 3072 features and return the class logits."""
        hidden = x.view(-1, 3 * 32 * 32)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(layer(hidden))
        return self.fc4(hidden)

    
# Step 2: ResNet18 model factory.
def get_resnet18(pretrained=False):
    """Build torchvision's ResNet-18, optionally loading pretrained weights.

    No ``num_classes`` is passed, so the classification head is torchvision's
    default rather than one sized for CIFAR-10.
    NOTE(review): newer torchvision releases deprecate ``pretrained=`` in favour
    of ``weights=``; confirm against the installed version.
    """
    return torchvision.models.resnet18(pretrained=pretrained)

# Step 3: local training routine.
def train_model(model, dataloader, epochs, criterion, optimizer, device):
    """Train ``model`` in place for ``epochs`` passes over ``dataloader``.

    Args:
        model: network to optimize; it is switched to train mode.
        dataloader: iterable of ``(inputs, targets)`` batches.
        epochs: number of full passes over ``dataloader``.
        criterion: loss applied as ``criterion(outputs, targets)``; assumed to use
            mean reduction, since its value is re-scaled by the batch size below.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.

    Returns:
        ``(avg_loss, avg_acc)`` averaged over every sample seen across all
        epochs; ``avg_acc`` is a fraction in [0, 1]. Both are 0.0 when no sample
        was seen (e.g. ``epochs == 0`` or an empty loader).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for _ in range(epochs):
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Undo the mean reduction so batches of different sizes weigh correctly.
            total_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)

    if total == 0:
        return 0.0, 0.0
    # Loss and accuracy are both accumulated over *all* epochs, so both are
    # normalised by the number of samples seen (dividing by one epoch's dataset
    # size would inflate the loss by a factor of ``epochs``).
    return total_loss / total, correct / total

def fed_avg(global_model, trainset, device, num_clients=5, epochs=10, batch_size=128, lr=0.01):
    """Run one FedAvg round and load the averaged client weights into ``global_model``.

    The training set is split into ``num_clients`` contiguous, near-equal shards
    (``np.array_split`` over unshuffled indices). Each client starts from a copy
    of the current global weights, trains locally for ``epochs`` epochs with
    plain SGD (learning rate ``lr``), and the new global state is the unweighted
    mean of all client state dicts.

    Args:
        global_model: model holding the round's starting weights; updated in place.
        trainset: dataset that is sharded across clients.
        device: device used for local training and for the returned model.
        num_clients: number of simulated clients (shards).
        epochs: local epochs per client.
        batch_size: local mini-batch size.
        lr: SGD learning rate for local training.

    Returns:
        ``(global_model, avg_loss, avg_acc)``: the same model object (moved to
        ``device``) plus the clients' training loss and accuracy, averaged
        uniformly over clients.
    """
    # Contiguous split; index shuffling is intentionally not applied.
    shards = np.array_split(np.arange(len(trainset)), num_clients)
    criterion = nn.CrossEntropyLoss()
    avg_loss, avg_acc = 0, 0

    global_model.to(device)

    # Running sum of client_state / num_clients. Only one client model copy is
    # alive at a time instead of num_clients models plus num_clients state-dict
    # copies; the per-key summation order matches a post-hoc average exactly.
    # NOTE(review): integer buffers (e.g. BatchNorm counters, if present) are
    # averaged in floating point and converted back by load_state_dict.
    avg_state = None
    for shard in shards:
        train_loader = DataLoader(Subset(trainset, shard), batch_size=batch_size, shuffle=True)
        client_model = deepcopy(global_model).to(device)
        optimizer = optim.SGD(client_model.parameters(), lr=lr)

        # Local training.
        loss, acc = train_model(client_model, train_loader, epochs, criterion, optimizer, device)
        avg_loss += loss
        avg_acc += acc

        # Fold this client's weights into the running average. Division creates
        # new tensors, so avg_state never aliases the client model's storage.
        client_state = client_model.state_dict()
        if avg_state is None:
            avg_state = {key: value / num_clients for key, value in client_state.items()}
        else:
            for key, value in client_state.items():
                avg_state[key] += value / num_clients
        # Release the client copy before the next deepcopy to bound peak memory.
        del client_model, optimizer, client_state

    avg_loss /= num_clients
    avg_acc /= num_clients

    # Update the global model with the averaged weights.
    global_model.load_state_dict(avg_state)

    return global_model, avg_loss, avg_acc


# Step 4: 验证函数
def test(net, testloader, device):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total
def modify_fc_shape(model, num_classes):
    """Replace ``model.fc`` with a fresh ``Linear`` layer producing ``num_classes`` logits.

    The new layer keeps the input width of the current ``model.fc``, which must
    therefore expose ``in_features`` (e.g. an ``nn.Linear``). ``model`` is
    modified in place; nothing is returned.
    """
    # Read the width *before* replacing the layer: swapping in nn.Identity first
    # (which has no ``in_features``) made this function always raise AttributeError.
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)
# Step 5: federated training with per-round checkpoints and metric logs.
def _build_model(model, device):
    """Instantiate the network named ``model`` and move it to ``device``.

    Raises:
        ValueError: if ``model`` is not one of the supported names.
    """
    if model == "lenet":
        net = LeNet(channel=3)
    elif model == "mlp":
        net = MLP()
    elif model == "vgg11":
        from cifar10_models.vgg import vgg11_bn
        net = vgg11_bn(pretrained=False)
    elif model == "vgg16":
        from cifar10_models.vgg import vgg16_bn
        net = vgg16_bn(pretrained=False)
    elif model == "vgg19":
        from cifar10_models.vgg import vgg19_bn
        net = vgg19_bn(pretrained=False)
    elif model == "resnet18":
        from cifar10_models.resnet import resnet18
        net = resnet18(pretrained=False)
    elif model == "resnet50":
        from cifar10_models.resnet import resnet50
        net = resnet50(pretrained=False)
    elif model == "resnet152":
        net = torchvision.models.resnet152(num_classes=10)
    elif model == "densenet121":
        from cifar10_models.densenet import densenet121
        net = densenet121(pretrained=False)
    elif model == "inception_v3":
        from cifar10_models.inception import inception_v3
        net = inception_v3(pretrained=False)
    elif model == "googlenet":
        from cifar10_models.googlenet import googlenet
        net = googlenet(pretrained=False)
    elif model == "vit_small":
        from vit_small import ViT
        net = ViT(
            image_size=32,
            patch_size=4,
            num_classes=10,
            dim=512,
            depth=6,
            heads=8,
            mlp_dim=512,
            dropout=0.1,
            emb_dropout=0.1,
        )
    elif model == "swin":
        from swin import swin_t
        net = swin_t(window_size=4,
                     num_classes=10,
                     downscaling_factors=(2, 2, 2, 1))
    else:
        # Previously an unknown name (including the CLI default) left `net`
        # unbound and crashed later with UnboundLocalError.
        raise ValueError(
            f"Unknown model {model!r}; expected one of: lenet, mlp, vgg11, vgg16, "
            "vgg19, resnet18, resnet50, resnet152, densenet121, inception_v3, "
            "googlenet, vit_small, swin")
    return net.to(device)


def _write_values(path, values):
    """Write ``values`` to ``path``, one per line, overwriting the file."""
    with open(path, "w") as f:
        f.writelines(f"{value}\n" for value in values)


def _save_curve(values, label, ylabel, path):
    """Plot ``values`` against their index, save the figure to ``path`` and close it."""
    fig = plt.figure(figsize=(8, 5))
    plt.plot(range(len(values)), values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(path)
    plt.close(fig)  # free the figure; many runs would otherwise accumulate open figures


def train_and_save_model(model, save_dir, device, rounds=100, num_clients=5,
                         local_epochs=10, batch_size=128):
    """Train ``model`` with FedAvg for ``rounds`` rounds, saving checkpoints and metrics.

    Writes into ``save_dir``:
      * ``model_{r}.pth`` - weights entering round ``r`` (``model_{rounds}.pth``
        holds the final aggregated weights);
      * ``test_accuracy.txt`` / ``train_accuracy.txt`` / ``test_loss.txt`` - one
        value per round (``test_loss.txt`` actually holds the *training* loss;
        the file name is kept for compatibility);
      * ``train_loss_curve.png`` / ``test_accuracy_curve.png`` /
        ``train_accuracy_curve.png``.

    Args:
        model: architecture name (see ``_build_model``).
        save_dir: existing output directory.
        device: CUDA device index as a string; CPU is used when CUDA is unavailable.
        rounds, num_clients, local_epochs, batch_size: FedAvg configuration.

    Raises:
        ValueError: if ``model`` is not a supported architecture name.
    """
    device = torch.device("cuda:" + device if torch.cuda.is_available() else "cpu")
    net = _build_model(model, device)

    train_loss_values = []
    test_accuracy_values = []
    train_accuracy_values = []

    for round_idx in range(rounds):
        # Checkpoint and evaluate the weights *entering* this round.
        torch.save(net.state_dict(), f"{save_dir}/model_{round_idx}.pth")
        test_accuracy = test(net, testloader, device)
        net, train_loss, train_accuracy = fed_avg(
            net, trainset, device, num_clients=num_clients,
            epochs=local_epochs, batch_size=batch_size)
        train_loss_values.append(train_loss)
        test_accuracy_values.append(test_accuracy)
        train_accuracy_values.append(train_accuracy)

        # test() returns a percentage, fed_avg() a fraction: scale the latter for display.
        print(f"Epoch [{round_idx + 1}/{rounds}] - Loss: {train_loss:.4f} - "
              f"Test Accuracy: {test_accuracy:.2f}% - "
              f"Train Accuracy: {100 * train_accuracy:.2f}%")

    # The loop only checkpoints weights at the start of each round, so persist
    # and report the final aggregated model explicitly.
    torch.save(net.state_dict(), f"{save_dir}/model_{rounds}.pth")
    print(f"Final model - Test Accuracy: {test(net, testloader, device):.2f}%")

    # Persist per-round metrics.
    _write_values(f"{save_dir}/test_accuracy.txt", test_accuracy_values)
    _write_values(f"{save_dir}/train_accuracy.txt", train_accuracy_values)
    _write_values(f"{save_dir}/test_loss.txt", train_loss_values)

    # Plot curves (the train-accuracy plot previously drew the test accuracy).
    _save_curve(train_loss_values, 'Train Loss', 'Loss',
                f"{save_dir}/train_loss_curve.png")
    _save_curve(test_accuracy_values, 'Test Accuracy', 'Accuracy',
                f"{save_dir}/test_accuracy_curve.png")
    _save_curve(train_accuracy_values, 'Train Accuracy', 'Accuracy',
                f"{save_dir}/train_accuracy_curve.png")
    

if __name__ == "__main__":
    import os

    # One output folder per architecture for checkpoints, metric logs and plots.
    save_dir = f"FL_saved_models_{args.model}"
    # exist_ok replaces the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(save_dir, exist_ok=True)

    # Train with FedAvg and persist checkpoints and metrics.
    train_and_save_model(model=args.model, save_dir=save_dir, device=args.device)
    
    
    
    



