import os
import pickle
from pprint import pprint
from typing import Dict
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
# from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from torch.utils.data import DataLoader
import random
from src.dataset import GeoDoubleAngleDataset
from src.model import LitGeoDoubleAngleModel

def generate_next_filename(folder_path, filename):
    """Build a versioned output path for ``filename`` under ``folder_path``.

    The version index is derived from the ``version_<N>`` entries of the
    ``lightning_logs`` directory located in ``os.path.dirname(folder_path)``.
    Entries whose suffix is not an integer are ignored, and a missing or
    unreadable ``lightning_logs`` directory is treated as "no versions yet".

    Args:
        folder_path (str): Directory that holds the versioned sub-folders.
        filename (str): File name to place in the version folder,
            e.g. ``'input_cate_encoders.pkl'``.

    Returns:
        str: ``<folder_path>/version_<K>/<filename>``. For ``.pkl`` files
        ``K`` is the highest existing index plus one and the folder is
        created; for any other file ``K`` is the highest existing index
        (``-1`` when none was found) and nothing is created.
    """
    logs_dir = os.path.join(os.path.dirname(folder_path), 'lightning_logs')
    try:
        entries = os.listdir(logs_dir)
    except OSError:
        # No log directory (or not readable): behave as if no version exists.
        entries = []

    max_index = -1
    for entry in entries:
        if not entry.startswith('version_'):
            continue
        try:
            index = int(entry.replace('version_', ''))
        except ValueError:
            # Skip only the malformed entry instead of aborting the scan.
            continue
        max_index = max(max_index, index)

    if filename.endswith('.pkl'):
        # Pickled artefacts open a new version folder.
        ver_num = f'version_{max_index + 1}'
        os.makedirs(os.path.join(folder_path, ver_num), exist_ok=True)
    else:
        # Everything else targets the latest existing version.
        ver_num = f'version_{max_index}'
    return os.path.join(folder_path, ver_num, filename)

def select_random_params_if_list(params):
    """Resolve a parameter mapping by sampling every list-valued entry.

    The ``random`` module is re-seeded without an explicit seed on each call.
    Values that are ``list`` instances are replaced by one element drawn with
    ``random.choice``; every other value (tuples included) is passed through
    unchanged.

    Args:
        params (dict): Parameter name -> fixed value or list of candidates.

    Returns:
        dict: New mapping with the same keys in the same order.
    """
    random.seed()
    return {
        name: random.choice(option) if isinstance(option, list) else option
        for name, option in params.items()
    }

def run_train_pipeline(task_config: Dict):
    """Train a ``LitGeoDoubleAngleModel`` and write predictions for the test file(s).

    Steps performed:
      1. Read the train/validation CSVs and collect the test CSV path(s)
         (a single file, or every ``.csv`` in a directory).
      2. Fit a ``StandardScaler`` on the training continuous columns, apply it
         to validation data and pickle it to ``input_cont_scaler.pkl``.
      3. If categorical columns are configured, fit one ``LabelEncoder`` per
         column on the training data, apply it to validation data and pickle
         the encoders to ``input_cate_encoders.pkl``.
      4. Build datasets/loaders, sample the model parameters with
         ``select_random_params_if_list`` and fit the model with checkpoint
         and early-stopping callbacks; save ``double_angle_model.ckpt``.
      5. For every test file: transform it with the fitted scaler/encoders,
         predict, print RMSE and R2 per target, and write
         ``<test file stem>_pred.csv`` to the output folder.

    Args:
        task_config: Configuration mapping. Keys read here:
            ``output_folder_path``, ``train_data_path``, ``valid_data_path``,
            ``test_data_path``, ``dataset`` (with ``input_cont_cols``,
            ``input_cate_cols``, ``task_target_cols``), ``dataloader``,
            ``model``, ``callbacks`` (``model_checkpoint``,
            ``early_stopping``) and ``trainer``.

    Raises:
        ValueError: Raised by ``LabelEncoder.transform`` when validation or
            test data contain a category that is absent from the training data.
    """
    pprint(task_config)
    output_dir = task_config["output_folder_path"]
    os.makedirs(output_dir, exist_ok=True)

    dataset_cfg = task_config["dataset"]
    cont_cols = dataset_cfg["input_cont_cols"]
    cate_cols = dataset_cfg["input_cate_cols"]
    target_tasks = dataset_cfg["task_target_cols"]

    train_data = pd.read_csv(task_config["train_data_path"])
    valid_data = pd.read_csv(task_config["valid_data_path"])

    test_root = task_config["test_data_path"]
    if os.path.isdir(test_root):
        # Sorted so that files are processed in a reproducible order.
        test_data_path = [
            os.path.join(test_root, p)
            for p in sorted(os.listdir(test_root))
            if p.endswith('.csv')
        ]
    else:
        test_data_path = [test_root]

    # Scaling statistics come from the training split only.
    standard_scaler = StandardScaler()
    train_data[cont_cols] = standard_scaler.fit_transform(train_data[cont_cols])
    valid_data[cont_cols] = standard_scaler.transform(valid_data[cont_cols])
    with open(os.path.join(output_dir, "input_cont_scaler.pkl"), "wb") as f:
        pickle.dump(standard_scaler, f)

    if cate_cols:
        label_encoders = {}
        for c in cate_cols:
            label_encoder = LabelEncoder()
            train_data[c] = label_encoder.fit_transform(train_data[c])
            # Only transform here: re-fitting on the validation split would
            # change the category -> code mapping learned from training data
            # and the pickled encoder would no longer match the model inputs.
            valid_data[c] = label_encoder.transform(valid_data[c])
            label_encoders[c] = label_encoder
        with open(os.path.join(output_dir, "input_cate_encoders.pkl"), "wb") as f:
            pickle.dump(label_encoders, f)

    train_set = GeoDoubleAngleDataset(train_data, **dataset_cfg)
    valid_set = GeoDoubleAngleDataset(valid_data, **dataset_cfg)

    train_loader = DataLoader(train_set, shuffle=True, **task_config["dataloader"])
    valid_loader = DataLoader(valid_set, shuffle=False, **task_config["dataloader"])

    # List-valued model parameters are resolved to one randomly chosen value.
    selected_model_params = select_random_params_if_list(task_config['model'])
    model = LitGeoDoubleAngleModel(
        train_set.num_of_feats,
        target_tasks,
        **selected_model_params
    )

    model_checkpoint = ModelCheckpoint(**task_config["callbacks"]["model_checkpoint"])
    early_stopping = EarlyStopping(**task_config["callbacks"]["early_stopping"])
    trainer = pl.Trainer(
        default_root_dir=output_dir,
        callbacks=[model_checkpoint, early_stopping],
        **task_config["trainer"]
    )
    trainer.fit(model, train_loader, valid_loader)
    trainer.save_checkpoint(os.path.join(output_dir, "double_angle_model.ckpt"))

    # The best checkpoint does not change between test files, so resolve it once.
    ckpt_path = trainer.checkpoint_callback.best_model_path
    if ckpt_path == "":
        print("Best Checkpoint not Found! Using Current Weights for Prediction ...")
        ckpt_path = None

    for data_path in test_data_path:
        test_data = pd.read_csv(data_path)
        test_data[cont_cols] = standard_scaler.transform(test_data[cont_cols])

        if cate_cols:
            for c in cate_cols:
                test_data[c] = label_encoders[c].transform(test_data[c])

        test_name = os.path.basename(data_path).split(".")[0]
        test_set = GeoDoubleAngleDataset(test_data, **dataset_cfg)
        test_loader = DataLoader(test_set, shuffle=False, **task_config["dataloader"])

        predictions = trainer.predict(model, dataloaders=test_loader, ckpt_path=ckpt_path)
        all_preds = torch.cat(predictions, dim=0)

        # Column j of the stacked predictions is paired with target j.
        for j, task_name in enumerate(target_tasks):
            print(f"Task {task_name}:")
            y_pred = all_preds[:, j].cpu().numpy()
            y_true = getattr(test_set, f"task_target_{task_name.lower()}")
            print(f"RMSE: {mean_squared_error(y_true, y_pred) ** 0.5:.3f}")
            print(f"R2: {r2_score(y_true, y_pred):.3f}")
            test_data[f"{task_name}_PRED"] = y_pred

        test_data.to_csv(os.path.join(output_dir, f"{test_name}_pred.csv"), index=False)

def _encode_with_unknown(encoder, values):
    """Encode ``values`` with a fitted label encoder, tolerating unseen labels.

    Labels present in ``encoder.classes_`` keep the code the encoder assigns
    to them; every other label receives the single extra code
    ``len(encoder.classes_)``. The encoder itself is not modified.
    """
    values = np.asarray(values)
    classes = encoder.classes_
    seen = np.isin(values, classes)
    codes = np.full(values.shape, len(classes), dtype=np.int64)
    if seen.any():
        codes[seen] = encoder.transform(values[seen])
    return codes


def run_train_inference_pipeline(task_config: Dict):
    """Run a trained ``LitGeoDoubleAngleModel`` on one or more data files.

    The model is loaded from ``model_checkpoint_path``; the scaler and the
    optional label encoders are loaded from pickles. For every input file
    (a single file, or all files of a directory in sorted order) the data is
    read (``pd.read_csv`` for paths ending in ``csv``, ``pd.read_pickle``
    otherwise), transformed, predicted, and written to
    ``<input file stem>_pred.csv`` in the output folder.

    Target columns missing from an input file are filled with the placeholder
    value ``1`` before the dataset is built; RMSE and R2 are printed only for
    targets that were actually present in the file.

    Args:
        task_config: Configuration mapping. Keys read here:
            ``output_folder_path``, ``test_data_path``,
            ``model_checkpoint_path``, ``standard_scaler_path``,
            ``label_encoders_path`` (only when categorical columns are
            configured), ``dataset`` (with ``input_cont_cols``,
            ``input_cate_cols``, ``task_target_cols``), ``dataloader`` and
            ``trainer``.
    """
    pprint(task_config)
    output_dir = task_config["output_folder_path"]
    os.makedirs(output_dir, exist_ok=True)

    dataset_cfg = task_config["dataset"]
    cont_cols = dataset_cfg["input_cont_cols"]
    cate_cols = dataset_cfg["input_cate_cols"]
    target_variables = dataset_cfg['task_target_cols']

    test_root = task_config["test_data_path"]
    if os.path.isdir(test_root):
        test_data_path = [os.path.join(test_root, p) for p in sorted(os.listdir(test_root))]
    else:
        test_data_path = [test_root]

    model = LitGeoDoubleAngleModel.load_from_checkpoint(task_config["model_checkpoint_path"])
    trainer = pl.Trainer(
        default_root_dir=output_dir,
        logger=False,
        **task_config["trainer"])

    # The fitted preprocessors are the same for every file: load them once.
    # SECURITY: pickle executes arbitrary code on load; these paths (and any
    # non-CSV input read with pd.read_pickle below) must be trusted.
    with open(task_config["standard_scaler_path"], "rb") as f:
        standard_scaler = pickle.load(f)
    label_encoders = None
    if cate_cols:
        with open(task_config["label_encoders_path"], "rb") as f:
            label_encoders = pickle.load(f)

    for data_path in test_data_path:
        print(f"File:{os.path.basename(data_path)}")
        if data_path.endswith('csv'):
            test_data = pd.read_csv(data_path)
        else:
            test_data = pd.read_pickle(data_path)

        # Add missing target columns with a placeholder so the dataset can be
        # built; remember them so no metrics are reported against fake truth.
        placeholder_targets = set()
        for target in target_variables:
            if target not in test_data.columns:
                test_data[target] = 1
                placeholder_targets.add(target)

        test_data[cont_cols] = standard_scaler.transform(test_data[cont_cols])

        if cate_cols:
            for c in cate_cols:
                try:
                    test_data[c] = label_encoders[c].transform(test_data[c])
                except ValueError:
                    # Unseen labels: keep the codes of known categories
                    # unchanged and send unknown ones to one extra code.
                    # NOTE(review): the extra code is len(classes_); confirm
                    # the model accepts an index beyond the trained classes.
                    test_data[c] = _encode_with_unknown(label_encoders[c], test_data[c])

        test_name = os.path.basename(data_path).split(".")[0]
        test_set = GeoDoubleAngleDataset(test_data, **dataset_cfg)
        test_loader = DataLoader(test_set, shuffle=False, **task_config["dataloader"])

        predictions = trainer.predict(
            model, dataloaders=test_loader, ckpt_path=task_config["model_checkpoint_path"])
        all_preds = torch.cat(predictions, dim=0)

        # Column j of the stacked predictions is paired with target j.
        for j, task_name in enumerate(target_variables):
            print(f"Task {task_name}:")
            y_pred = all_preds[:, j].cpu().numpy()
            if task_name in placeholder_targets:
                print("No ground truth in input file; skipping RMSE/R2.")
            else:
                y_true = getattr(test_set, f"task_target_{task_name.lower()}")
                print(f"RMSE: {mean_squared_error(y_true, y_pred) ** 0.5:.3f}")
                print(f"R2: {r2_score(y_true, y_pred):.3f}")
            test_data[f"{task_name}_PRED"] = y_pred

        test_data.to_csv(os.path.join(output_dir, f"{test_name}_pred.csv"), index=False)
