from icecream import ic
import utils.options as uo
from graph.sai_graph_generator import SAIGraph
import utils.sampling as usamp
from torchvision import datasets, transforms
from models.Nets import MLP
from paiv import simplePaiv
import random
from clock import Clock
from message import Message
import torch
import utils.utils as uutils
from training_styles.cent_train import CentralisedLearning
from training_styles.fed_train import FederatedLearning
from training_styles.dec_train import DecentralisedLearning

import json
import os


def write_args(args):
    """Write the arguments of the current run to ``stats/<outfolder>/config.json``.

    Every attribute of ``args`` is dumped as indented JSON so that the
    configuration of a run is stored alongside its statistics.

    Args:
        args: An argparse-style namespace. ``args.outfolder`` names the
            sub-directory of ``stats/`` to write into; all attributes in
            ``vars(args)`` are serialized.

    Raises:
        TypeError: If an attribute value is not JSON-serializable (e.g. a
            ``torch.device``). Serialization happens before the file is
            opened, so no partially written ``config.json`` is left behind.
    """
    filename = "stats/" + args.outfolder + "/config.json"

    # Serialize up front: json.dump writes incrementally, so a failure on a
    # non-serializable value would otherwise leave a truncated file on disk.
    payload = json.dumps(vars(args), indent=2)

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(payload)


def _load_mnist(root='../data/mnist/'):
    """Load the MNIST train and test splits, downloading them into *root* if needed.

    Both splits share one transform pipeline: ``ToTensor`` followed by
    ``Normalize`` with mean ``(0.1307,)`` and std ``(0.3081,)``.

    Returns:
        A ``(dataset_train, dataset_test)`` tuple of torchvision MNIST datasets.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset_train = datasets.MNIST(root, train=True, download=True,
                                   transform=transform)
    dataset_test = datasets.MNIST(root, train=False, download=True,
                                  transform=transform)
    return dataset_train, dataset_test


def _select_training_styles(args):
    """Return the ``(run_dec, run_fed, run_cent)`` flags requested in *args*.

    ``args.run_all`` enables every style; otherwise each style follows its own
    ``args.run_dec`` / ``args.run_fed`` / ``args.run_cent`` flag.
    """
    if args.run_all:
        return True, True, True
    return args.run_dec, args.run_fed, args.run_cent


def main():
    """Run the experiment configured on the command line.

    Parses the arguments and logs them to ``stats/<outfolder>/config.json``.
    It then seeds the RNGs, builds the social graph and loads MNIST. The
    training set is split across the graph's nodes (IID or non-IID according
    to ``args.noniid``). Finally it runs each selected training style:
    decentralised, federated and/or centralised.
    """
    args = uo.args_parser()  # read args
    # Log the configuration before args.device is attached: a torch.device
    # is not JSON-serializable.
    write_args(args)

    args.device = torch.device(
        f'cuda:{args.gpu}'
        if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # Seed before anything that may draw random numbers (graph generation,
    # data partitioning, training) so that args.seed covers the whole run.
    # NOTE(review): if utils.sampling draws from numpy's global RNG, numpy
    # must be seeded as well for reproducible partitions. TODO confirm.
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # setting up the social graph based on the input parameters
    g = SAIGraph(args)
    num_nodes = g.sai_graph.number_of_nodes()

    dataset_train, dataset_test = _load_mnist()

    # dataset split: one partition per node of the social graph
    partition = usamp.mnist_noniid if args.noniid else usamp.mnist_iid
    data_partitions = partition(dataset_train, num_nodes)

    run_dec, run_fed, run_cent = _select_training_styles(args)
    if not (run_cent or run_dec or run_fed):
        print('WARNING: No training style selected. Please choose at least one.')

    # running the SAI decentralised algo
    if run_dec:
        dec_learn = DecentralisedLearning(
            args, g, dataset_train, data_partitions, dataset_test)
        dec_learn.run()

    # running federated learning
    if run_fed:
        fed_learn = FederatedLearning(
            args, g, dataset_train, data_partitions, dataset_test=dataset_test)
        fed_learn.run()

    # running centralised learning
    if run_cent:
        centr_learn = CentralisedLearning(
            args, dataset_train=dataset_train, dataset_test=dataset_test)
        centr_learn.run()


if __name__ == "__main__":
    # Run the experiment only when this file is executed as a script,
    # not when it is imported as a module.
    main()
