from egnn_clean_dr01 import *
import argparse
import os
import sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from dgl.dataloading import GraphDataLoader
import dgl
import math
import numpy as np
import torch
import wandb
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from train_dataloader import buildGraph


def to_np(x):
    """Return the values of tensor ``x`` as a NumPy array.

    The tensor is cut from the autograd graph and brought to host memory
    before the conversion, so it may require grad and live on any device.
    """
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

def train_epoch(epoch, model, loss_fnc, dataloader, optimizer, scheduler, FLAGS):
    """Run one training epoch, then advance the learning-rate schedule.

    Args:
        epoch: Epoch index; used only for the progress message.
        model: Module called as ``model(nodeFeats, xyz_feats, edges, edge_att)``
            and returning a ``(pred, xyz)`` pair. ``pred`` is passed through a
            sigmoid before the loss is computed.
        loss_fnc: Callable ``loss_fnc(pred, y)`` returning a scalar tensor that
            supports ``backward()``.
        dataloader: Iterable of 5-tuples
            ``(nodeFeats, xyz_feats, edges, edge_att, y)``, where ``edges`` is
            a mutable sequence whose first two items are tensors.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per epoch, after every optimizer
            step of that epoch.
        FLAGS: Namespace; only ``FLAGS.device`` is read.
    """
    print('epoch ' + str(epoch))
    device = FLAGS.device
    model.train()

    for nodeFeats, xyz_feats, edges, edge_att, y in dataloader:
        # squeeze() with no argument drops every size-1 axis.
        # NOTE(review): presumably this is only meant to remove the leading
        # batch axis of a batch_size=1 loader; a sample with a single node or
        # edge would lose that axis too -- confirm the data never hits this.
        nodeFeats = nodeFeats.to(device).squeeze()
        xyz_feats = xyz_feats.to(device).squeeze()
        # The edge container is updated in place and handed to the model as-is.
        edges[0] = edges[0].to(device).squeeze()
        edges[1] = edges[1].to(device).squeeze()
        # Edge attributes and targets get an explicit trailing axis of size 1.
        edge_att = edge_att.to(device).squeeze().unsqueeze(dim=1)
        y = y.to(device).squeeze().unsqueeze(dim=1)

        optimizer.zero_grad()
        pred, _ = model(nodeFeats, xyz_feats, edges, edge_att)
        pred = torch.sigmoid(pred)
        loss = loss_fnc(pred, y)
        loss.backward()
        optimizer.step()

    # Step the schedule only after this epoch's optimizer updates. Stepping
    # first skips the initial learning rate and, when the schedule length
    # equals the number of epochs, makes the last epoch train at the
    # schedule's final (minimum) value.
    scheduler.step()



def collate(samples):
    """Combine ``(graph, label)`` pairs into one batched graph and a label tensor.

    ``samples`` is an iterable of two-item pairs. The first items are merged
    with ``dgl.batch``; the second items are packed into a ``torch.tensor``.
    """
    graph_list, label_list = (list(column) for column in zip(*samples))
    merged = dgl.batch(graph_list)
    labels = torch.tensor(label_list)
    return merged, labels

def main(FLAGS, UNPARSED_ARGV):
    """Build the dataset and EGNN model, train for 40 epochs, and checkpoint.

    Reads ``indir``, ``hidden_nf``, ``num_layers``, ``device``, ``save_dir``
    and ``name`` from ``FLAGS`` and stores the number of training batches in
    ``FLAGS.train_size``. The model's ``state_dict`` is written to
    ``<save_dir>/<name>.pt`` after every epoch. ``UNPARSED_ARGV`` is accepted
    but not used.
    """
    num_epochs = 40  # also the period handed to the cosine LR schedule

    # Data: one sample per batch, reshuffled every epoch.
    train_loader = GraphDataLoader(buildGraph(FLAGS.indir), batch_size=1, shuffle=True)
    FLAGS.train_size = len(train_loader)

    # Model
    model = EGNN(
        in_node_nf=5461,
        hidden_nf=FLAGS.hidden_nf,
        out_node_nf=1,
        in_edge_nf=1,
        n_layers=FLAGS.num_layers,
        attention=True,
    )
    model.to(FLAGS.device)

    # Optimisation
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-16)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
    checkpoint_path = os.path.join(FLAGS.save_dir, FLAGS.name + '.pt')
    bce_loss = torch.nn.BCELoss()

    # Train, overwriting the checkpoint once each epoch finishes.
    print('Begin training')
    for epoch in range(num_epochs):
        train_epoch(epoch, model, bce_loss, train_loader, optimizer, scheduler, FLAGS)
        torch.save(model.state_dict(), checkpoint_path)
    print(f"Training done. Model saved in: {checkpoint_path}")


def resolve_device(gpu_index):
    """Return the torch device to train on.

    Uses ``cuda:<gpu_index>`` when CUDA is available and that index exists,
    ``cuda:0`` (with a printed notice) when CUDA is available but the index
    does not exist, and the CPU when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if 0 <= gpu_index < torch.cuda.device_count():
        return torch.device(f'cuda:{gpu_index}')
    # Requesting a non-existent GPU would otherwise only fail later, when the
    # model is moved to the device.
    print(f'GPU index {gpu_index} is not available; falling back to cuda:0')
    return torch.device('cuda:0')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Model parameters
    parser.add_argument('--num_layers', type=int, default=12,
            help="Number of equivariant layers")
    parser.add_argument('--hidden_nf', type=int, default=768,
            help="Number of hidden nf")
    parser.add_argument('--indir', type=str, default="DNA_train_data/",
            help="Input data directory")
    parser.add_argument('--save_dir', type=str, default="model/DNA",
            help="Directory name to save models")
    parser.add_argument('--seed', type=int, default=1992)
    parser.add_argument('--gpu', type=int, default=2,
            help="Index of the CUDA device to use when CUDA is available")

    FLAGS, UNPARSED_ARGV = parser.parse_known_args()

    # Name used for the checkpoint file.
    FLAGS.name = f'E-l{FLAGS.num_layers}-{FLAGS.hidden_nf}'

    # Create the model directory; exist_ok avoids the check-then-create race.
    os.makedirs(FLAGS.save_dir, exist_ok=True)

    torch.manual_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    FLAGS.device = resolve_device(FLAGS.gpu)

    main(FLAGS, UNPARSED_ARGV)
