from typing import List, Dict

import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import torch


class GeoDoubleAngleDataset(Dataset):
    """Map-style dataset that slices a DataFrame into grouped feature arrays.

    Angle group ``i`` (1-based) pairs a list of "angle feature" columns with a
    list of "angle data" columns. Both, the continuous columns and every target
    column are converted to float32 arrays once in ``__init__``, so
    ``__getitem__`` only indexes rows.

    Attributes that callers may read directly:
        num_of_feats: column counts keyed ``'self.num_angle_feats{i}'``,
            ``'self.num_angle_data{i}'``, ``'self.num_cont_feats'`` and, when
            categorical columns are given, ``'self.num_cate_types'``.
        angle_feats_{i} / angle_data_{i}: float32 arrays for angle group ``i``.
        cont_feats_data: float32 array of the continuous columns.
        cate_feats_data: raw ``.values`` of the categorical columns, or ``None``.
        num_cate_types: distinct-value count per categorical column, or ``None``.
        task_target_{name.lower()}: float32 array for each target column. This
            alias is kept for compatibility; names differing only in case share
            one attribute, but ``__getitem__`` keys on the exact column name.
    """

    def __init__(self, csv_data: pd.DataFrame, input_angle_cont_cols: List[List[str]],
                 input_angle_data_cols: List[List[str]], input_cont_cols: List[str],
                 input_cate_cols: List[str], task_target_cols: List):
        """Extract all feature groups and targets from ``csv_data``.

        Args:
            csv_data: source table; every referenced column must exist in it.
            input_angle_cont_cols: one list of angle-feature columns per group.
            input_angle_data_cols: one list of angle-data columns per group;
                must have the same number of groups as ``input_angle_cont_cols``.
            input_cont_cols: continuous feature columns.
            input_cate_cols: categorical columns; falsy (empty/None) disables them.
            task_target_cols: target column names, returned under the same keys.

        Raises:
            ValueError: if the two angle column lists differ in length.
        """
        if len(input_angle_cont_cols) != len(input_angle_data_cols):
            # Explicit check rather than ``assert`` so it still runs under ``python -O``.
            raise ValueError("the length of angle list is not same as data list")

        self.csv_data = csv_data
        self.input_angle_cont_cols = input_angle_cont_cols
        self.task_target_cols = task_target_cols
        self.num_of_feats = {}

        # Plain lists let __getitem__ index each group directly instead of
        # evaluating generated source code per sample.
        self._angle_feats = []
        self._angle_data = []
        for group, (feat_cols, data_cols) in enumerate(
                zip(input_angle_cont_cols, input_angle_data_cols), start=1):
            feats = csv_data[feat_cols].values.astype(np.float32)
            data = csv_data[data_cols].values.astype(np.float32)
            self.num_of_feats[f"self.num_angle_feats{group}"] = len(feat_cols)
            self.num_of_feats[f"self.num_angle_data{group}"] = len(data_cols)
            # Numbered attributes preserved for code that reads them directly.
            setattr(self, f"angle_feats_{group}", feats)
            setattr(self, f"angle_data_{group}", data)
            self._angle_feats.append(feats)
            self._angle_data.append(data)
        self.num_of_feats["self.num_cont_feats"] = len(input_cont_cols)

        if input_cate_cols:
            # Number of distinct values (as seen by a Python set) per column.
            self.num_cate_types = [len(set(csv_data[c])) for c in input_cate_cols]
            self.num_of_feats["self.num_cate_types"] = self.num_cate_types
            self.cate_feats_data = self.csv_data[input_cate_cols].values
        else:
            self.num_cate_types = None
            self.cate_feats_data = None

        self.cont_feats_data = self.csv_data[input_cont_cols].values.astype(np.float32)

        # Keyed by the exact column name: names need not be valid identifiers,
        # and lower-casing alone would let e.g. "Y" and "y" overwrite each other.
        self._task_targets = {}
        for task_target in task_target_cols:
            values = self.csv_data[task_target].values.astype(np.float32)
            self._task_targets[task_target] = values
            setattr(self, f"task_target_{task_target.lower()}", values)

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.csv_data)

    def __getitem__(self, idx):
        """Return row ``idx`` as a dict.

        Keys: ``"ANGLE_FEAT"`` and ``"ANGLE_DATA"`` (lists holding one float32
        row per angle group), ``"CONT_FEAT"`` (float32 row), ``"CATE_FEAT"``
        (categorical row, or ``0`` when no categorical columns were given), and
        one float32 value per target column, keyed by the column name.
        """
        sample = {
            "ANGLE_FEAT": [feats[idx] for feats in self._angle_feats],
            "ANGLE_DATA": [data[idx] for data in self._angle_data],
            "CONT_FEAT": self.cont_feats_data[idx],
            "CATE_FEAT": self.cate_feats_data[idx] if self.cate_feats_data is not None else 0,
        }
        for task_target in self.task_target_cols:
            sample[task_target] = self._task_targets[task_target][idx]
        return sample
