from typing import List, Dict
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import torch

class GeoDoubleAngleDatasetTest(Dataset):
    """Map-style dataset serving pre-extracted feature/target arrays from a DataFrame.

    All column groups are converted to NumPy arrays once in ``__init__`` so
    that ``__getitem__`` only performs row indexing. The "Test" suffix suggests
    this class is used for an evaluation split — NOTE(review): confirm with callers.
    """

    def __init__(self, csv_data: pd.DataFrame, input_angle_cont_cols: List[List[str]],
                 input_angle_data_cols: List[List[str]], input_cont_cols: List[str],
                 input_cate_cols: List[str], task_target_cols: List[str]):
        """Extract and cache the feature and target arrays.

        Args:
            csv_data: Source table; one row per sample.
            input_angle_cont_cols: One list of column names per angle-feature
                group; each group is cast to float32.
            input_angle_data_cols: One list of column names per angle-data
                group, paired positionally with ``input_angle_cont_cols``;
                each group is cast to float32.
            input_cont_cols: Continuous feature columns, cast to float32.
            input_cate_cols: Categorical feature columns, kept with their
                original dtype. May be empty.
            task_target_cols: Target columns, each cast to float32.

        Raises:
            ValueError: If the two angle column lists differ in length.
        """
        # Explicit check rather than ``assert`` so the validation still runs
        # under ``python -O`` (asserts are stripped there).
        if len(input_angle_cont_cols) != len(input_angle_data_cols):
            raise ValueError("The length of angle list is not the same as data list.")

        # Kept only so __len__ can report the number of rows.
        self.csv_data = csv_data
        self.angle_feats = [csv_data[cols].values.astype(np.float32) for cols in input_angle_cont_cols]
        self.angle_data = [csv_data[cols].values.astype(np.float32) for cols in input_angle_data_cols]
        self.cont_feats_data = csv_data[input_cont_cols].values.astype(np.float32)
        # No dtype cast: categorical values are passed through as stored.
        self.cate_feats_data = csv_data[input_cate_cols].values if input_cate_cols else None
        self.targets = {col: csv_data[col].values.astype(np.float32) for col in task_target_cols}

        # Count features for reporting or model initialization
        self.num_of_feats = {
            f'num_angle_feats_{i+1}': len(cols) for i, cols in enumerate(input_angle_cont_cols)
        }
        self.num_of_feats.update({
            f'num_angle_data_{i+1}': len(cols) for i, cols in enumerate(input_angle_data_cols)
        })
        self.num_of_feats['num_cont_feats'] = len(input_cont_cols)
        if input_cate_cols:
            # Distinct values per categorical column, with all missing values
            # counted as ONE category. ``len(set(series))`` would count every
            # NaN separately because NaN != NaN.
            # NOTE(review): this only counts values present in *this* frame; if
            # categories are label-encoded against a larger vocabulary, an
            # embedding sized from this may be too small — confirm.
            self.num_cate_types = [csv_data[c].nunique(dropna=False) for c in input_cate_cols]
            self.num_of_feats['num_cate_types'] = self.num_cate_types

    def __len__(self):
        """Return the number of rows in the source DataFrame."""
        return len(self.csv_data)

    def __getitem__(self, idx):
        """Return the sample at row position ``idx`` as a dict.

        Keys: ``"ANGLE_FEAT"`` and ``"ANGLE_DATA"`` (lists with one row array
        per column group), ``"CONT_FEAT"`` (row array), ``"CATE_FEAT"`` (row
        array, or ``0`` when no categorical columns were given), followed by
        one entry per target column name holding that row's float32 value.
        """
        temp_dict = {
            "ANGLE_FEAT": [feat[idx] for feat in self.angle_feats],
            "ANGLE_DATA": [data[idx] for data in self.angle_data],
            "CONT_FEAT": self.cont_feats_data[idx],
            # Placeholder scalar keeps the dict schema fixed when there are no categorical columns.
            "CATE_FEAT": self.cate_feats_data[idx] if self.cate_feats_data is not None else 0
        }
        # Targets are added last; a target named like a feature key would overwrite it.
        temp_dict.update({name: values[idx] for name, values in self.targets.items()})
        return temp_dict
