import math
from typing import List

import torch
import torch.nn as nn


class NumericalFeatureTokenizer(nn.Module):
    """Turn each scalar feature into a ``d_model``-sized token.

    Every feature ``i`` owns a learned ``weight[i]`` and ``bias[i]`` vector;
    its token is ``x[..., i] * weight[i] + bias[i]``.
    """

    def __init__(self, n_feats: int, d_model: int):
        """Create the per-feature parameters.

        Args:
            n_feats: Number of numerical features (rows of the parameters).
            d_model: Size of each produced token (columns of the parameters).
        """
        super().__init__()

        init_std = 1 / math.sqrt(d_model)

        self.weight = nn.Parameter(torch.Tensor(n_feats, d_model))
        self.bias = nn.Parameter(torch.Tensor(n_feats, d_model))

        # Weight first, then bias, both drawn from N(0, 1 / d_model).
        for param in (self.weight, self.bias):
            nn.init.normal_(param, std=init_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale and shift every feature value into token space.

        Args:
            x: Feature values whose last dimension must broadcast against
                ``n_feats`` (typically ``(batch, n_feats)`` -- TODO confirm
                with callers).

        Returns:
            Tensor with a trailing ``d_model`` dimension appended to ``x``.
        """
        scaled = self.weight[None] * x[..., None]
        return scaled + self.bias[None]


class CategoricalFeatureTokenizer(nn.Module):
    """Embed categorical features into one ``d_model`` token per feature.

    All features share a single embedding table with ``sum(n_cates)`` rows.
    Feature ``i`` owns a contiguous slice of ``n_cates[i]`` rows that starts
    at ``category_offsets[i]``. A learned per-feature bias is added to every
    embedded token.
    """

    def __init__(self, n_cates: List[int], d_model: int):
        """Build the shared embedding table, the offsets and the bias.

        Args:
            n_cates: Number of categories of each categorical feature.
            d_model: Size of each produced token.
        """
        super().__init__()

        # Offset of feature i = total number of categories of features < i.
        # ``list(...)`` lets callers pass a tuple as well as a list.
        category_offsets = torch.tensor([0] + list(n_cates[:-1])).cumsum(0)
        # Derived from ``n_cates``, so it is kept out of the state dict.
        self.register_buffer("category_offsets", category_offsets, persistent=False)

        self.embedding = nn.Embedding(sum(n_cates), d_model)
        self.bias = nn.Parameter(torch.Tensor(len(n_cates), d_model))

        nn.init.normal_(self.bias, std=1 / math.sqrt(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up the embedding of each category index and add the bias.

        Args:
            x: Integer category indices, local to each feature, whose last
                dimension has ``len(n_cates)`` entries (typically
                ``(batch, len(n_cates))`` -- TODO confirm with callers).

        Returns:
            Tensor with a trailing ``d_model`` dimension appended to ``x``.
        """
        # Shift per-feature indices into the shared embedding table.
        x = self.embedding(x + self.category_offsets.unsqueeze(0))
        # The sum must be returned: previously it was computed and discarded,
        # so the bias never affected the output nor received a gradient.
        return x + self.bias.unsqueeze(0)


class FeatureTokenizer(nn.Module):
    """Tokenize numerical and (optionally) categorical features together.

    Numerical tokens come first; categorical tokens, when a categorical
    tokenizer exists, are concatenated after them along dimension 1.
    """

    def __init__(self, n_feats: int, n_cates: List[int], d_model: int):
        """Create the sub-tokenizers.

        Args:
            n_feats: Number of numerical features.
            n_cates: Category count of each categorical feature; when empty,
                no categorical tokenizer is created.
            d_model: Size of each produced token.
        """
        super().__init__()

        self.numerical_tokenizer = NumericalFeatureTokenizer(n_feats, d_model)
        self.categorical_tokenizer = (
            CategoricalFeatureTokenizer(n_cates, d_model) if n_cates else None
        )

    def forward(self, x_cont: torch.Tensor, x_cate: torch.Tensor) -> torch.Tensor:
        """Return the feature tokens.

        Args:
            x_cont: Numerical feature values.
            x_cate: Categorical indices; ignored when there is no
                categorical tokenizer.

        Returns:
            Numerical tokens, followed by categorical tokens if any.
        """
        # Without categorical features only the numerical branch runs.
        if self.categorical_tokenizer is None:
            return self.numerical_tokenizer(x_cont)

        numerical_tokens = self.numerical_tokenizer(x_cont)
        categorical_tokens = self.categorical_tokenizer(x_cate)
        return torch.cat([numerical_tokens, categorical_tokens], dim=1)


class GlobalTokenizer(nn.Module):
    """Prepend one learned ``d_model`` token to every sequence in a batch."""

    def __init__(self, d_model: int):
        """Create the learned token, initialised from N(0, 1 / d_model).

        Args:
            d_model: Size of the token.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.Tensor(d_model))
        nn.init.normal_(self.weight, std=1 / math.sqrt(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Insert the learned token at position 0 of dimension 1.

        Args:
            x: Token tensor; the code indexes ``x.shape[0]`` as the batch
                size and concatenates on dimension 1, so a
                ``(batch, n_tokens, d_model)`` layout is expected.

        Returns:
            ``x`` with one extra leading token along dimension 1.
        """
        n_rows = x.size(0)
        # Broadcast the single vector to one (1, d_model) token per row.
        global_token = self.weight.expand(n_rows, 1, -1)
        return torch.cat((global_token, x), dim=1)


class GeoTransformerEncoder(nn.Module):
    """Transformer encoder over feature tokens with a learned global token.

    Features are tokenized, a global token is prepended, the sequence runs
    through a ``nn.TransformerEncoder`` and the output at the global token's
    position (index 0) is returned.
    """

    def __init__(self, n_feats: int, n_cates: List[int], d_model: int, n_head: int, n_layer: int, p_drop: float):
        """Build the tokenizers and the transformer stack.

        Args:
            n_feats: Number of numerical features.
            n_cates: Category count of each categorical feature.
            d_model: Token / model width.
            n_head: Number of attention heads per layer.
            n_layer: Number of encoder layers.
            p_drop: Dropout probability inside the encoder layers.
        """
        super().__init__()

        self.feature_tokenizer = FeatureTokenizer(n_feats, n_cates, d_model)
        self.global_tokenizer = GlobalTokenizer(d_model)

        # The feed-forward width is fixed at twice the model width.
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=d_model * 2,
            dropout=p_drop,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer=layer, num_layers=n_layer)

    def forward(self, x_cont: torch.Tensor, x_cate: torch.Tensor) -> torch.Tensor:
        """Encode one batch and return the global-token representation.

        Args:
            x_cont: Numerical feature values.
            x_cate: Categorical feature indices.

        Returns:
            The encoder output at sequence position 0.
        """
        tokens = self.global_tokenizer(self.feature_tokenizer(x_cont, x_cate))
        encoded = self.transformer_encoder(tokens)
        # Position 0 holds the token prepended by ``global_tokenizer``.
        return encoded[:, 0, :]


class MLPBlock(nn.Module):
    """Width-preserving block: linear layer, then ReLU, then dropout."""

    def __init__(self, d_model: int, p_drop: float):
        """Create the layers.

        Args:
            d_model: Input and output width of the linear layer.
            p_drop: Dropout probability applied after the activation.
        """
        super().__init__()

        self.linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p_drop)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear -> ReLU -> dropout to ``x``."""
        hidden = self.linear(x)
        hidden = self.activation(hidden)
        return self.dropout(hidden)

class MLPDecoder(nn.Module):
    """Stack of ``MLPBlock`` layers followed by a ReLU-activated output head.

    NOTE(review): ``fmf_head`` is constructed (so its parameters appear in
    the state dict) but ``forward`` never calls it -- confirm whether that
    is intentional.
    """

    def __init__(self, d_model: int, n_layer: int, p_drop: float, n_target_taks: int):
        """Build the block stack and both heads.

        Args:
            d_model: Width of every block and of the head inputs.
            n_layer: Number of ``MLPBlock`` layers.
            p_drop: Dropout probability of each block.
            n_target_taks: Output width of ``aod_head``.
        """
        super().__init__()
        self.blocks = nn.Sequential(*(MLPBlock(d_model, p_drop) for _ in range(n_layer)))
        # Non-negative outputs, one per target.
        self.aod_head = nn.Sequential(
            nn.Linear(d_model, n_target_taks),
            nn.ReLU(),
        )
        # Single output squashed into (0, 1); not used by ``forward``.
        self.fmf_head = nn.Sequential(
            nn.Linear(d_model, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block stack and return the ``aod_head`` output."""
        return self.aod_head(self.blocks(x))


class AngleDataTokenizer(nn.Module):
    """Project a group of angles (in degrees) to one ``d_model`` vector.

    Each angle is expanded to its sine and cosine, giving ``2 * n_angle``
    inputs to a single linear layer.
    """

    def __init__(self, n_angle: int, d_model: int):
        """Create the projection layer.

        Args:
            n_angle: Number of angles per sample.
            d_model: Size of the produced vector.
        """
        super().__init__()
        self.linear = nn.Linear(n_angle * 2, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the angles and project them.

        Args:
            x: Angles in degrees; sines and cosines are concatenated on
                dimension 1, so a ``(batch, n_angle)`` layout is expected.

        Returns:
            The linear projection of ``[sin(x), cos(x)]``.
        """
        radians = x / 180 * math.pi
        encoded = torch.cat((torch.sin(radians), torch.cos(radians)), dim=1)
        return self.linear(encoded)


class GeoDoubleAngleEncoder(nn.Module):
    """Transformer encoder over angle-conditioned and ordinary feature tokens.

    ``num_of_feats`` is read positionally. Apart from the two entries keyed
    ``'self.num_cont_feats'`` and ``'self.num_cate_types'``, its values are
    consumed in pairs: for group ``i`` the value at position ``2 * i`` sizes
    a ``NumericalFeatureTokenizer`` and the value at position ``2 * i + 1``
    sizes an ``AngleDataTokenizer``. The number of groups is
    ``(len(num_of_feats) - 2) // 2``.

    The sub-modules are registered as ``angle_feature_tokenizer_<k>`` and
    ``angle_data_tokenizer_<k>`` with ``k`` starting at 1.
    """

    def __init__(self, num_of_feats: dict, d_model: int, n_head: int, n_layer: int, p_drop: float, double_angle: bool):
        """Build the tokenizers and the transformer stack.

        Args:
            num_of_feats: Ordered mapping of feature-group sizes (see the
                class docstring for the expected layout).
            d_model: Token / model width.
            n_head: Number of attention heads per layer.
            n_layer: Number of encoder layers.
            p_drop: Dropout probability inside the encoder layers.
            double_angle: Accepted for interface compatibility; not used by
                this class.
        """
        super().__init__()
        self.num_of_feats = num_of_feats
        self.feats_value = list(self.num_of_feats.values())

        # ``setattr`` registers sub-modules exactly like ``self.name = ...``
        # does, without building and executing source strings.
        for i in range(self._num_angle_groups()):
            setattr(
                self,
                f"angle_feature_tokenizer_{i + 1}",
                NumericalFeatureTokenizer(self.feats_value[2 * i], d_model),
            )
            setattr(
                self,
                f"angle_data_tokenizer_{i + 1}",
                AngleDataTokenizer(self.feats_value[2 * i + 1], d_model),
            )

        self.normal_tokenizer = FeatureTokenizer(self.num_of_feats['self.num_cont_feats'], self.num_of_feats['self.num_cate_types'], d_model)
        self.global_tokenizer = GlobalTokenizer(d_model)

        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=d_model,
                                                     nhead=n_head,
                                                     dim_feedforward=d_model * 2,
                                                     dropout=p_drop,
                                                     batch_first=True),
            num_layers=n_layer,
        )

    def _num_angle_groups(self) -> int:
        """Return how many (feature, angle) tokenizer pairs the config defines."""
        return (len(self.num_of_feats) - 2) // 2

    def forward(self, x_angle_feats: list[torch.Tensor], x_angle_datas: list[torch.Tensor], x_cont: torch.Tensor, x_cate: torch.Tensor) -> torch.Tensor:
        """Encode one batch and return the global-token representation.

        Args:
            x_angle_feats: One tensor of feature values per angle group.
            x_angle_datas: One tensor of angles per angle group, in the same
                order as ``x_angle_feats``.
            x_cont: Ordinary numerical feature values.
            x_cate: Ordinary categorical feature indices.

        Returns:
            The encoder output at sequence position 0.
        """
        # Original author's note: e.g. torch.Size([2, 25, 256]) for batch size 2.
        x = self.normal_tokenizer(x_cont, x_cate)

        x_feats_datas = []
        for i in range(self._num_angle_groups()):
            feature_tokenizer = getattr(self, f"angle_feature_tokenizer_{i + 1}")
            data_tokenizer = getattr(self, f"angle_data_tokenizer_{i + 1}")
            feature_tokens = feature_tokenizer(x_angle_feats[i])
            angle_embedding = data_tokenizer(x_angle_datas[i])
            # One angle vector per sample is broadcast over that group's tokens.
            x_feats_datas.append(feature_tokens + angle_embedding.unsqueeze(1))

        # Angle-group tokens come first, ordinary feature tokens last.
        x_feats_datas.append(x)
        x = torch.cat(x_feats_datas, dim=1)
        x = self.global_tokenizer(x)
        x = self.transformer_encoder(x)
        # Position 0 holds the token prepended by ``global_tokenizer``.
        return x[:, 0, :]
