import re
import argparse
from argparse import RawTextHelpFormatter

import os

from collections import OrderedDict
import torch
import torchvision.models as models
import torchvision.models.detection as models_detection

from yaml import dump
import yamlordereddictloader

class ordereddict_dumper(yamlordereddictloader.SafeDumper):
    """YAML dumper used for workload files.

    Built on ``yamlordereddictloader.SafeDumper`` so OrderedDict content can be
    dumped, and with anchors/aliases disabled so every repeated object is
    written out in full at each place it occurs.
    """

    def ignore_aliases(self, _data):
        """Never emit an ``&anchor`` / ``*alias`` pair, whatever the node."""
        always_inline = True
        return always_inline

def write_yaml_file(filepath, content):
    """
    Serialize ``content`` to YAML with ``ordereddict_dumper`` and write it to ``filepath``.

    Missing parent directories are created first.

    :param filepath: string that specifies the destination file path
    :param content: object to serialize (e.g. an OrderedDict); it is dumped in
        block style (``default_flow_style=False``)
    :return: None
    """
    directory = os.path.dirname(filepath)
    # dirname() is "" for a bare file name such as "workload.yaml";
    # os.makedirs("") would raise FileNotFoundError, so only create real dirs.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filepath, "w", encoding="utf-8") as file:
        file.write(dump(content, default_flow_style=False, Dumper=ordereddict_dumper))

import torch.nn as nn
from torch.autograd import Variable

import numpy as np
import re

# Forward hooks only record modules whose defining Python module path ends in
# "conv" or "linear" (e.g. torch.nn.modules.conv / torch.nn.modules.linear).
# Compiled once here instead of on every hook invocation.
_CONV_MODULE_RE = re.compile(".*conv$")
_LINEAR_MODULE_RE = re.compile(".*linear$")


def _print_summary_table(summary, input_size, batch_size):
    """Print a per-layer table for ``summary`` followed by parameter/size totals."""
    print("----------------------------------------------------------------")
    line_new = "{:>20}  {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer_name, layer in summary.items():
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer_name,
            str(layer["output_shape"]),
            "{0:,}".format(layer["nb_params"]),
        )
        total_params += layer["nb_params"]
        total_output += np.prod(layer["output_shape"])
        if layer.get("trainable", False):
            trainable_params += layer["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    # abs() because the batch slot defaults to -1.
    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")


def my_summary(model, input_size, batch_size=-1, device="cuda", silent=False,
    lS_o=None, lS_i=None, inputs=None,
    targets_list=None, input_sizes=None, target_sizes=None):
    """
    Run one forward pass through ``model`` and record its conv / linear layers.

    Default-to-None arguments can be ignored for torchvision models. They are
    only consumed by the special-cased custom models:

    * a model whose class name contains "dlrm" is called with
      ``inputs``/``lS_o``/``lS_i``;
    * when ``input_size`` is None, the model is called with
      ``inputs``/``targets_list``/``input_sizes``/``target_sizes``.

    :param model: torch.nn.Module to trace.
    :param input_size: per-sample input shape as a tuple (a batch of 2 random
        samples is generated), a list of such tuples for multi-input models, or
        None to derive it from ``inputs.size()``.
    :param batch_size: value stored in the batch slot of the recorded shapes and
        used as N in the dimension tuples (-1 by default).
    :param device: "cuda" or "cpu" (case-insensitive). "cuda" falls back to CPU
        when CUDA is unavailable, for both random inputs and the SSD/BERT inputs.
    :param silent: suppress all printing when True.
    :return: OrderedDict keyed "<LayerClass>-<n>". Each value holds:
        "type" ("CONV" or "DSCONV"), "input_shape", "output_shape",
        "trainable", "nb_params", "stride" (None for 2-D weights),
        "dimension_ic" as (N, K, C, R, S, Y, X) and
        "dimension_oc" as (N, K, C, R, S, Yo, Xo).
    :raises AssertionError: if ``device`` is neither "cuda" nor "cpu".
    """
    def register_hook(module):

        def hook(mod, layer_args, output):
            key = mod.__module__
            if not (_CONV_MODULE_RE.match(key) or _LINEAR_MODULE_RE.match(key)):
                return
            weight = getattr(mod, "weight", None)
            # Only 2-D (Linear-style) and 4-D (Conv2d-style) weights are handled
            # below. Skip e.g. nn.Identity (no weight), nn.Bilinear (3-D weight)
            # and nn.Conv1d / nn.Conv3d instead of crashing mid-forward.
            if not isinstance(weight, torch.Tensor) or weight.dim() not in (2, 4):
                return

            class_name = str(mod.__class__).split(".")[-1].split("'")[0]
            m_key = "%s-%i" % (class_name, len(summary) + 1)
            entry = OrderedDict()
            summary[m_key] = entry
            entry["type"] = "CONV"
            entry["input_shape"] = list(layer_args[0].size())
            entry["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                entry["output_shape"] = [[-1] + list(o.size())[1:] for o in output]
            else:
                entry["output_shape"] = list(output.size())
                entry["output_shape"][0] = batch_size

            params = torch.prod(torch.LongTensor(list(weight.size())))
            entry["trainable"] = weight.requires_grad
            bias = getattr(mod, "bias", None)
            if hasattr(bias, "size"):  # bias may be None
                params += torch.prod(torch.LongTensor(list(bias.size())))
            entry["nb_params"] = params

            N = batch_size
            if weight.dim() == 4:
                _, C, Y, X = layer_args[0].size()
                _, K, Yo, Xo = output.size()
                _, _, R, S = weight.size()
                entry["stride"] = mod.stride
                if mod.groups == C:
                    # One group per input channel: record as depthwise.
                    entry["type"] = "DSCONV"
                    K = 1
            else:
                K, C = weight.size()
                X, Xo, Y, Yo, R, S = 1, 1, 1, 1, 1, 1
                entry["stride"] = None
            entry["dimension_ic"] = (N, K, C, R, S, Y, X)
            entry["dimension_oc"] = (N, K, C, R, S, Yo, Xo)

        # Containers and the root model itself are not hooked.
        if not isinstance(module, (nn.Sequential, nn.ModuleList)) and module is not model:
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    use_cuda = device == "cuda" and torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    # Device for the hand-built SSD/BERT inputs. It must honour the same CPU
    # fallback as dtype, otherwise the default "cuda" crashes on CPU-only hosts.
    run_device = "cuda" if use_cuda else "cpu"

    no_input_size_original = input_size is None
    if no_input_size_original:
        input_size = inputs.size()
        if not silent:
            print(input_size)
    # multiple inputs to the network
    input_size_tuple = input_size
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # Default input batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    # create properties
    summary = OrderedDict()
    hooks = []

    # Hooks must be removed even if the forward pass fails; otherwise they stay
    # attached to the caller's model and fire on every later forward.
    try:
        model.apply(register_hook)

        if model.__module__ == 'torchvision.models.detection.ssd':
            # Detection models take a list of CHW images, not a batched tensor.
            x = [torch.rand(3, 300, 300).to(run_device), torch.rand(3, 500, 400).to(run_device)]
            model.eval()
            model.forward(x)
        elif 'bert' in model.__module__:
            from transformers import BertTokenizer
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            text = "Replace me by any text you'd like."
            x = tokenizer(text, return_tensors='pt').to(run_device)
            model(**x)
        elif 'dlrm' in type(model).__name__.lower():
            model(inputs.to("cpu"), lS_o.to("cpu"), lS_i.to("cpu"))
        elif 'unet' in type(model).__name__.lower():
            model(torch.rand(input_size_tuple).type(dtype))
        elif no_input_size_original:
            model(inputs, targets_list, input_sizes, target_sizes)
        else:
            model(*x)
    finally:
        for h in hooks:
            h.remove()

    if not silent:
        _print_summary_table(summary, input_size, batch_size)
    return summary

# Note: inception_v3 expects a 3x299x299 input; run it with --input_size 3,299,299.

def write(my_summary, model_, outfile_='workload.yaml', path=None):
    '''
    Take summary and dump as workload.yaml.

    Every summary entry becomes one layer named "<model_>_<key>" under the
    top-level "workload" mapping. Entries whose key starts with "Conv" or
    "Linear" are filled in with:

    * cycle, type, dilations and N (all hard-coded);
    * strides;
    * K/C/R/S/Y/X taken from ``dimension_ic``.

    Any other entry is written as an empty mapping.

    :param my_summary: OrderedDict as returned by ``my_summary()``.
    :param model_: model name, used as the layer-name prefix.
    :param outfile_: output file name.
    :param path: output directory; the current directory when None.
    :return: path of the YAML file that was written.
    '''
    hier_dict = OrderedDict()
    wl_dict = OrderedDict()
    for key, val in my_summary.items():
        layer_dict = OrderedDict()
        layer_name = model_ + '_' + key
        is_linear = key.startswith("Linear")
        if is_linear or key.startswith("Conv"):
            layer_dict['cycle'] = 1 # FIXME: hardcode to 1
            layer_dict['type'] = 'conv' #val["type"].lower()
            layer_dict['Hdilation'] = 1 # FIXME: hardcode dilation to 1
            layer_dict['Wdilation'] = 1
            if is_linear:
                layer_dict['Hstride'] = 1
                layer_dict['Wstride'] = 1
            else:
                # PyTorch conv strides are (height, width). Y and R below come
                # from the height axis of an NCHW input / KCRS weight, so the
                # H stride is index 0 and the W stride is index 1.
                layer_dict['Hstride'] = val["stride"][0]
                layer_dict['Wstride'] = val["stride"][1]
            layer_dict['N'] = 1 # NOTE: determine N
            _, K, C, R, S, Y, X = val["dimension_ic"]
            layer_dict['K'] = K
            layer_dict['C'] = C
            layer_dict['R'] = R
            layer_dict['S'] = S
            layer_dict['Y'] = Y
            layer_dict['X'] = X

        wl_dict[layer_name] = layer_dict

    hier_dict['workload'] = wl_dict

    filename = os.path.join("." if path is None else path, outfile_)
    write_yaml_file(filename, hier_dict)

    return filename


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter)
    parser.add_argument('--input_size', type=str, default="3,224,224", help='input size')
    parser.add_argument('--model', type=str, default="resnet50",
                        help='-------\nmodel from torchvision choices: \n'
                             'resnet18, resnet34, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, wide_resnet101_2, \n'
                             '-------\nother supported model examples: \n'
                             'alexnet, vgg16, squeezenet, densenet, \n'
                             'inception_v3, googlenet, shufflenet, \n'
                             'mobilenet_v2, mnasnet,\n'
                             'bert_base_uncased, \n'
                             '-------\nmodel from torchvision.detection choices: \n'
                             'ssd300_vgg16, \n'
                             '-------\ncustom model supported with outside importing:\n'
                             'rnnt, unet, dlrm\n')
    parser.add_argument('--outfile', type=str, default="workload.yaml", help='output file name')
    opt = parser.parse_args()
    INPUT_SIZE = tuple(int(d) for d in opt.input_size.split(","))

    print('Begin processing')
    print('Model name: ' + str(opt.model))
    if 'bert' in opt.model:
        print('Ignoring input size')
    if 'ssd300' in opt.model:
        # Force 3x300x300 for ssd300 models regardless of --input_size.
        opt.input_size = '3,300,300'
        INPUT_SIZE = tuple(int(d) for d in opt.input_size.split(","))
        print(f'using {opt.input_size} for {opt.model}')
    if 'rnnt' in opt.model or 'unet' in opt.model or 'dlrm' in opt.model:
        print(f'{opt.model} is not supported with direct exe. clone original repo and import module.\n')
        # SystemExit is always available; the bare exit() helper comes from the
        # site module and is missing under `python -S` / embedded interpreters.
        raise SystemExit
    print('Input size: ' + str(INPUT_SIZE))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    if 'ssd' in opt.model:
        model = getattr(models_detection, opt.model)()
    elif 'bert' in opt.model:
        # my_summary() builds its own tokenizer for the BERT forward pass,
        # so only the model is loaded here.
        from transformers import BertModel
        model = BertModel.from_pretrained(opt.model.replace('_', '-'))
    else:
        model = getattr(models, opt.model)()
    model = model.to(device)

    # Pass the resolved device type so the inputs my_summary builds live on the
    # same device as the model (it otherwise defaults to "cuda").
    my_summary_ = my_summary(model, INPUT_SIZE, device=device.type)
    filename = write(my_summary_, opt.model, opt.outfile, f'./workload/{opt.model}')
    print("Done converting to the DNN MODEL file at "+ filename)