from typing import List

import pytorch_lightning as pl
import torch
import torch.nn as nn

from src.loss import RelativeErrorLoss
from src.module import (
    GeoDoubleAngleEncoder,
    MLPDecoder
)


class LitGeoDoubleAngleModel(pl.LightningModule):
    """Lightning module: ``GeoDoubleAngleEncoder`` followed by a multi-task ``MLPDecoder``.

    The decoder is built with ``len(target_task)`` outputs. Output column ``i``
    is trained against ``batch[target_task[i]]``. The training/validation
    objective is the unweighted sum of the criterion over all target tasks.

    Args:
        num_of_feats: Forwarded to ``GeoDoubleAngleEncoder``.
        target_task: Batch keys of the targets, one per decoder output column.
            Must be non-empty.
        d_model: Passed to both encoder and decoder (presumably the shared
            hidden width -- confirm against ``src.module``).
        n_tf_head: Forwarded to the encoder.
        n_tf_layer: Forwarded to the encoder.
        p_tf_drop: Forwarded to the encoder.
        n_mlp_layer: Forwarded to the decoder.
        p_mlp_drop: Forwarded to the decoder.
        lr: Learning rate for AdamW.
        double_angle: Forwarded to the encoder.
        loss: ``"mse"`` selects ``nn.MSELoss()``; ``"rel"`` selects
            ``RelativeErrorLoss(a=0.15, b=0.05)``.

    Raises:
        ValueError: If ``target_task`` is empty or ``loss`` is not supported.
    """

    def __init__(self, num_of_feats: dict, target_task: list, d_model: int, n_tf_head: int, n_tf_layer: int, p_tf_drop: float, n_mlp_layer: int, p_mlp_drop: float, lr: float, double_angle: bool = True, loss: str = "mse"):
        super().__init__()
        self.save_hyperparameters()
        # Fail fast: the loss is anchored on target_task[0], so an empty task
        # list would otherwise only surface as an IndexError mid-training.
        if not target_task:
            raise ValueError("target_task must contain at least one task key")
        self.target_task = target_task

        self.geo_encoder = GeoDoubleAngleEncoder(num_of_feats, d_model, n_tf_head, n_tf_layer, p_tf_drop, double_angle)
        self.task_decoder = MLPDecoder(d_model, n_mlp_layer, p_mlp_drop, len(target_task))

        self.lr = lr
        if loss == "mse":
            self.criterion = nn.MSELoss()
        elif loss == "rel":
            self.criterion = RelativeErrorLoss(a=0.15, b=0.05)
        else:
            raise ValueError("Loss is not Supported!")

    def forward(self, x_angle_feats: list[torch.Tensor], x_angle_datas: list[torch.Tensor], x_cont: torch.Tensor, x_cate: torch.Tensor) -> torch.Tensor:
        """Encode the inputs and decode one prediction per target task."""
        x = self.geo_encoder(x_angle_feats, x_angle_datas, x_cont, x_cate)
        return self.task_decoder(x)

    def _shared_step(self, batch):
        """Run ``forward`` on the feature entries of a batch mapping."""
        return self(batch["ANGLE_FEAT"], batch["ANGLE_DATA"], batch["CONT_FEAT"], batch["CATE_FEAT"])

    def _compute_loss(self, out: torch.Tensor, batch) -> torch.Tensor:
        """Sum the criterion over all target tasks with equal weight.

        ``out[:, i]`` is compared with ``batch[self.target_task[i]]``; ``out``
        is assumed to be (batch, len(target_task)) -- TODO confirm against
        ``MLPDecoder``.
        """
        loss = self.criterion(out[:, 0], batch[self.target_task[0]])
        # Out-of-place accumulation: no exec(), no reliance on in-place
        # tensor mutation inside the autograd graph.
        for col, task in enumerate(self.target_task[1:], start=1):
            loss = loss + self.criterion(out[:, col], batch[task])
        return loss

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the summed multi-task loss."""
        out = self._shared_step(batch)
        loss = self._compute_loss(out, batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log (``val_loss``) the summed multi-task loss."""
        out = self._shared_step(batch)
        loss = self._compute_loss(out, batch)
        self.log("val_loss", loss, prog_bar=True)

    def predict_step(self, batch, batch_idx):
        """Return raw model outputs for a batch."""
        return self._shared_step(batch)

    def configure_optimizers(self):
        """Use AdamW over all parameters with weight decay 1e-3.

        No learning-rate scheduler is configured, so the rate stays at ``self.lr``.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-3)
        return optimizer
