from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from collections import Counter
from tqdm import tqdm
import numpy as np
import pandas as pd
import esm

class Net(nn.Module):
    """Small MLP classifier head over 1280-feature embedding vectors.

    Layers: 1280 -> 64 -> 16 -> 2, with ReLU between layers, followed by a
    softmax over the 2 outputs.
    """

    def __init__(self):
        super().__init__()
        # state_dict keys are derived from these attribute names (fc1, fc2,
        # fc4), so they are kept unchanged to stay compatible with saved
        # checkpoints of this model.
        self.fc1 = nn.Linear(1280, 64)
        self.fc2 = nn.Linear(64, 16)
        self.fc4 = nn.Linear(16, 2)

    def forward(self, x):
        """Return softmax probabilities for ``x``.

        Args:
            x: tensor whose last dimension has size 1280; any leading
                (batch) dimensions are allowed.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size 2, normalized with softmax along that last dimension.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc4(x)
        # Explicit dim: softmax without ``dim`` is deprecated and, for 3-D
        # inputs, normalizes over dim 0 (the batch) instead of the classes.
        return F.softmax(x, dim=-1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint is a whole pickled module (not a state_dict), so unpickling
# requires the ``Net`` class defined above. ``map_location`` plus
# ``.to(device)`` put the weights on the same device as the input batch,
# wherever the checkpoint was originally saved.
# NOTE(review): torch>=2.6 defaults to ``torch.load(weights_only=True)``,
# which refuses full pickled modules; pass ``weights_only=False`` there, and
# only for a trusted checkpoint file.
net = torch.load('un50big.pt', map_location=device).to(device)
net.eval()

FASTA_PATH = "./bglb.fasta"
EMB_PATH = "./bglb/" # Path to directory of embeddings
EMB_LAYER = 33

ys = []  # sequence IDs, in the same order as the rows of Xs
Xs = []  # per-sequence 'mean_representations' at layer EMB_LAYER
skipped = []  # (sequence ID, error) pairs for embeddings that could not be used
for header, _seq in esm.data.read_fasta(FASTA_PATH):
    # Depending on the esm version, read_fasta may or may not keep the
    # leading '>' on the header; strip it only when present so the ID
    # matches the embedding file name.
    seq_id = header[1:] if header.startswith('>') else header
    fn = f'{EMB_PATH}/{seq_id}.pt'
    try:
        embs = torch.load(fn, map_location='cpu')
        emb = embs['mean_representations'][EMB_LAYER]
    except (FileNotFoundError, KeyError) as err:
        # Best-effort: a sequence without a usable embedding is skipped, but
        # recorded rather than silently dropped.
        skipped.append((seq_id, err))
        continue
    Xs.append(emb)
    ys.append(seq_id)

if skipped:
    print(f'Skipped {len(skipped)} sequence(s) without a usable embedding, '
          f'e.g. {skipped[0][0]!r}: {skipped[0][1]!r}')
if not Xs:
    raise SystemExit(f'No usable embeddings found in {EMB_PATH} for {FASTA_PATH}')
Xs = torch.stack(Xs, dim=0).to(device)

# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    out = net(Xs)
print(out)
