# Created by xunannancy at 2021/9/25
"""
See the reference code at: https://github.com/rtqichen/torchdiffeq.git
"""
import warnings
warnings.filterwarnings('ignore')
import torch.nn as nn
import torch
from torchdiffeq import odeint
import numpy as np
from utils import Pytorch_DNN_exp, Pytorch_DNN_validation, Pytorch_DNN_testing, merge_parameters, print_network, \
    task_prediction_horizon, run_evaluate_V3
import yaml
from collections import OrderedDict
import os
from FNN import HistoryConcatLoader
from sklearn.model_selection import ParameterGrid
import json
import argparse

class LatentODEfunc(nn.Module):
    """Three-layer ELU network used as the derivative function of the latent ODE.

    ``forward`` accepts a time argument ``t`` but does not use it.  The
    attribute ``nfe`` counts how many times ``forward`` has been called.
    """

    def __init__(self, latent_dim=4, nhidden=20):
        super().__init__()
        # Layers are created in a fixed order so that parameter names and the
        # random initialisation sequence stay stable.
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden)
        self.fc3 = nn.Linear(nhidden, latent_dim)
        self.nfe = 0  # number of forward evaluations so far

    def forward(self, t, x):
        """Return ``fc3(elu(fc2(elu(fc1(x)))))`` and bump the evaluation counter."""
        self.nfe = self.nfe + 1
        hidden = self.elu(self.fc1(x))
        hidden = self.elu(self.fc2(hidden))
        return self.fc3(hidden)

class RecognitionRNN(nn.Module):
    """Minimal recurrent cell that maps a sequence to ``2 * latent_dim`` outputs.

    Each step concatenates the current observation with the previous hidden
    state, passes it through a tanh layer to obtain the new hidden state, and
    projects that state to a vector of size ``latent_dim * 2``.
    """

    def __init__(self, latent_dim=4, obs_dim=2, nhidden=25):
        super().__init__()
        self.nhidden = nhidden
        self.i2h = nn.Linear(obs_dim + nhidden, nhidden)
        self.h2o = nn.Linear(nhidden, latent_dim * 2)

    def forward(self, x, h):
        """Run one recurrent step.

        :param x: observation, concatenated with ``h`` along dim 1
        :param h: previous hidden state
        :return: (output of size ``latent_dim * 2``, new hidden state)
        """
        new_hidden = torch.tanh(self.i2h(torch.cat((x, h), dim=1)))
        return self.h2o(new_hidden), new_hidden

    def initHidden(self, batch_size, device):
        """Return an all-zero hidden state of shape ``[batch_size, nhidden]`` on ``device``."""
        return torch.zeros((batch_size, self.nhidden), device=device)

class Decoder(nn.Module):
    """Two-layer ReLU network mapping latent states to observation space."""

    def __init__(self, latent_dim=4, obs_dim=2, nhidden=20):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, obs_dim)

    def forward(self, z):
        """Return ``fc2(relu(fc1(z)))``; the linear layers act on the last dimension."""
        hidden = self.relu(self.fc1(z))
        return self.fc2(hidden)


def log_normal_pdf(x, mean, logvar):
    """Element-wise log-density of ``x`` under a normal with the given mean and log-variance.

    Computes ``-0.5 * (log(2*pi) + logvar + (x - mean)^2 / exp(logvar))``.
    """
    # Kept as a float32 tensor of shape [1] on x's device, so broadcasting
    # behaves as it does for a one-element tensor.
    log_two_pi = torch.log(torch.tensor([2. * np.pi], dtype=torch.float32, device=x.device))
    squared_error = (x - mean) ** 2.
    return -.5 * (log_two_pi + logvar + squared_error / torch.exp(logvar))

def normal_kl(mu1, lv1, mu2, lv2):
    """Element-wise KL divergence between two diagonal normals.

    The first distribution has mean ``mu1`` and log-variance ``lv1``; the
    second has mean ``mu2`` and log-variance ``lv2``.
    """
    var1, var2 = torch.exp(lv1), torch.exp(lv2)
    # Difference of log standard deviations (log-variance halved).
    log_std_gap = lv2 / 2. - lv1 / 2.
    scaled_quadratic = (var1 + (mu1 - mu2) ** 2.) / (2. * var2)
    return log_std_gap + scaled_quadratic - .5

class LatentODENet(nn.Module):
    """Latent ODE model: recognition RNN -> sampled z0 -> ODE solve -> decoder.

    The history window is fed to the recognition RNN in reverse time order to
    produce the mean and log-variance of the initial latent state.  A sample of
    that state is integrated with ``odeint`` over a normalised time grid that
    covers the history steps followed by the prediction horizons, and every
    latent state on the grid is decoded.
    """

    def __init__(self,
                 sliding_window, external_features, history_column_names, target_val_column_names, normalization,
                 latent_dim, nhidden, rnn_nhidden, noise_std
                 ):
        super().__init__()

        self.sliding_window = sliding_window
        self.external_features = external_features
        self.history_column_names = history_column_names
        self.target_val_column_names = target_val_column_names
        self.normalization = normalization
        self.noise_std = noise_std
        self.latent_dim = latent_dim

        # TODO: mode
        self.prediction_horizon = task_prediction_horizon['wind']

        # Time grid: one point per history step, then one per prediction
        # horizon, rescaled so the largest time equals 1.
        n_history_steps = self.sliding_window + 1
        history_times = np.arange(n_history_steps)
        future_times = n_history_steps + np.array(self.prediction_horizon) - 1
        time_grid = np.concatenate([history_times, future_times])
        self.actual_prediction_horizon = time_grid / max(time_grid)
        # Positions of the horizon points inside the time grid.
        self.selected_prediction_horizon = n_history_steps + np.arange(len(self.prediction_horizon))

        n_history_columns = len(self.history_column_names)
        obs_dim = n_history_columns + len(self.external_features)

        # Creation order is kept (func, rec, dec) so random initialisation is unchanged.
        self.func = LatentODEfunc(latent_dim, nhidden)
        self.rec = RecognitionRNN(latent_dim, obs_dim, rnn_nhidden)
        self.dec = Decoder(latent_dim, n_history_columns, nhidden)

    def forward(self, x):
        """
        Encode the history window, solve the latent ODE and decode the horizon points.

        Side effects: stores ``qz0_mean``, ``qz0_logvar``, ``z0`` and ``pred_x``
        (the decoded trajectory over the whole time grid) on the module.

        :param x: [batch_size, (sliding_window+1)*(len(history_column_names)+len(external_features))]
        :return: [batch_size, len(target_val_column_names)]; passed through a sigmoid
                 when normalization is 'minmax'
        """
        batch_size, cur_device = x.shape[0], x.device
        n_steps = self.sliding_window + 1
        n_features = len(self.history_column_names) + len(self.external_features)
        x = x.reshape([batch_size, n_steps, n_features])

        # Run the recognition RNN backwards in time; the last output
        # parameterises the distribution of the initial latent state.
        hidden = self.rec.initHidden(batch_size, cur_device)
        for step in range(n_steps - 1, -1, -1):
            out, hidden = self.rec.forward(x[:, step, :], hidden)
        self.qz0_mean = out[:, :self.latent_dim]
        self.qz0_logvar = out[:, self.latent_dim:]

        # Reparameterised sample of z0.
        epsilon = torch.randn(self.qz0_mean.size()).to(cur_device)
        self.z0 = epsilon * torch.exp(.5 * self.qz0_logvar) + self.qz0_mean

        time_points = torch.tensor(self.actual_prediction_horizon, dtype=torch.float, device=cur_device)
        # odeint output is permuted to [batch_size, time, latent_dim].
        pred_z = odeint(self.func, self.z0, time_points).permute(1, 0, 2)
        # pred_x: [batch_size, sliding_window+1+len(prediction_horizon), len(history_column_names)]
        self.pred_x = self.dec(pred_z)

        horizon_pred = self.pred_x[:, self.selected_prediction_horizon, :]
        pred = horizon_pred.reshape([batch_size, len(self.target_val_column_names)])
        if self.normalization == 'minmax':
            pred = torch.sigmoid(pred)
        return pred

    def loss_function(self, batch):
        """
        Negative log-likelihood of the decoded trajectory plus the KL term, averaged over the batch.

        The trajectory is compared against the history columns of ``x`` followed
        by ``y``.  With normalization 'none' the targets are log-transformed and
        the returned prediction is exponentiated.

        :param batch: (x, y, _) — the third element is ignored
        :return: (scalar loss, prediction of shape [batch_size, len(target_val_column_names)])
        """
        x, y, _ = batch
        batch_size, cur_device = x.shape[0], x.device
        self.forward(x)

        # Constant observation-noise log-variance, same shape as the decoded trajectory.
        noise_std_ = torch.zeros(self.pred_x.size()).to(cur_device) + self.noise_std
        noise_logvar = 2. * torch.log(noise_std_).to(cur_device)

        n_history_columns = len(self.history_column_names)
        n_features = n_history_columns + len(self.external_features)
        # [batch_size, seqlen, obs_dim]
        gt1 = x.reshape([batch_size, self.sliding_window+1, n_features])[:, :, :n_history_columns]
        gt2 = y.reshape([batch_size, len(self.prediction_horizon), n_history_columns])
        pred = self.pred_x[:, self.selected_prediction_horizon, :].reshape([batch_size, len(self.target_val_column_names)])

        if self.normalization == 'none':
            target = torch.cat([torch.log(gt1), torch.log(gt2)], dim=1)
            pred = torch.exp(pred)
        else:
            target = torch.cat([gt1, gt2], dim=1)
        logpx = log_normal_pdf(target, self.pred_x, noise_logvar).sum(-1).sum(-1)

        # Prior over z0 has zero mean and zero log-variance.
        prior_zeros = torch.zeros(self.z0.size()).to(cur_device)
        analytic_kl = normal_kl(self.qz0_mean, self.qz0_logvar,
                                prior_zeros, prior_zeros).sum(-1)
        loss = torch.mean(-logpx + analytic_kl, dim=0)
        return loss, pred

class LatentODE_exp(Pytorch_DNN_exp):
    """Experiment class that pairs a ``HistoryConcatLoader`` with a ``LatentODENet``."""

    def __init__(self, file, param_dict, config):
        super().__init__(file, param_dict, config)

        # The loader must exist before load_model(), which reads its column names.
        self.dataloader = HistoryConcatLoader(file, param_dict, config)
        self.model = self.load_model()
        print_network(self.model)

    def load_model(self):
        """Build a ``LatentODENet`` from the hyper-parameters, the config and the loader's columns."""
        hyper = self.param_dict
        loader = self.dataloader
        return LatentODENet(
            sliding_window=hyper['sliding_window'],
            external_features=self.config['exp_params']['external_features'],
            history_column_names=loader.history_column_names,
            target_val_column_names=loader.target_val_column_names,
            latent_dim=hyper['latent_dim'],
            nhidden=hyper['nhidden'],
            rnn_nhidden=hyper['rnn_nhidden'],
            noise_std=hyper['noise_std'],
            normalization=hyper['normalization'],
        )

def _resolve_log_dir(config, saved_folder):
    """Return the ``version_<n>`` directory this run reads from or writes to.

    In test mode the directory named by ``exp_params['last_version']`` must
    already exist.  Otherwise a new directory, one past the highest existing
    version, is created and returned.
    """
    exp_params = config['exp_params']
    if exp_params['test_flag']:
        log_dir = os.path.join(saved_folder, f"version_{exp_params['last_version']}")
        assert os.path.exists(log_dir), f'log_dir does not exist: {log_dir}'
        return log_dir

    os.makedirs(saved_folder, exist_ok=True)
    while True:
        existing = [int(name.split('_')[1]) for name in os.listdir(saved_folder) if name.startswith('version_')]
        # default=-1 covers a folder that exists but holds no version yet.
        log_dir = os.path.join(saved_folder, f'version_{max(existing, default=-1) + 1}')
        try:
            os.makedirs(log_dir)
        except FileExistsError:
            # Another run claimed this version in the meantime; rescan and retry.
            # Any other OSError (e.g. permissions) propagates instead of looping forever.
            continue
        return log_dir


def _summarize_validation(cur_log_dir, param_dict_list):
    """Collect the validation metric and std of every grid point.

    :return: (summary, params_by_key) where ``summary`` maps the ``_``-joined
             parameter values to ``[avg_val_metric, std]`` and ``params_by_key``
             maps the same key to the parameter dict it came from.
    """
    summary = OrderedDict()
    params_by_key = {}
    for param_dict in param_dict_list:
        param_dict = OrderedDict(param_dict)
        setting_name = 'param' + ''.join(f'_{key[0].capitalize()}{val}' for key, val in param_dict.items())
        version_dir = os.path.join(cur_log_dir, setting_name, 'version_0')

        model_list = [i for i in os.listdir(version_dir) if i.endswith('.ckpt')]
        assert len(model_list) == 1
        ckpt_name = model_list[0]
        # The checkpoint name carries both the best epoch and its validation metric.
        perf = float(ckpt_name[ckpt_name.find('avg_val_metric=')+len('avg_val_metric='):ckpt_name.find('.ckpt')])
        best_epoch = int(ckpt_name[ckpt_name.find('best-epoch=')+len('best-epoch='):ckpt_name.find('-avg_val_metric')])

        # std.txt: one "<epoch> <v1>_<v2>_..." line per epoch.
        std_dict = {}
        with open(os.path.join(version_dir, 'std.txt'), 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                std_dict[int(fields[0])] = [float(v) for v in fields[1].split('_')]

        key = '_'.join(map(str, param_dict.values()))
        summary[key] = [perf, std_dict[best_epoch]]
        params_by_key[key] = param_dict
    return summary, params_by_key


def _select_best_params(summary, params_by_key):
    """Return the parameter dict (plus its ``std``) with the smallest validation metric."""
    keys = list(summary)
    # Only the scalar metrics go through numpy: each summary row is
    # [float, list], which cannot be packed into a regular array.
    best_key = keys[int(np.argmin([summary[key][0] for key in keys]))]
    chosen = params_by_key[best_key]
    return {
        'batch_size': int(chosen['batch_size']),
        'latent_dim': int(chosen['latent_dim']),
        'learning_rate': float(chosen['learning_rate']),
        'nhidden': int(chosen['nhidden']),
        'noise_std': float(chosen['noise_std']),
        'normalization': str(chosen['normalization']),
        'rnn_nhidden': int(chosen['rnn_nhidden']),
        'sliding_window': int(chosen['sliding_window']),
        'std': summary[best_key][-1],
    }


def grid_search_LatentODE(config, num_files):
    """Grid-search, test and evaluate the Latent ODE model on the first ``num_files`` data files.

    For every selected file (names containing 'zone' and ending in '.csv'):
    unless ``exp_params['test_flag']`` is set, run validation over the whole
    parameter grid, write ``val_summary.json`` and the best setting to
    ``param.json``; then run testing with the parameters stored in
    ``param.json``.  Finally the config is saved next to the results (if not
    already there) and ``run_evaluate_V3`` is called on the log directory.

    :param config: nested configuration dict (``logging_params``, ``exp_params``, ``model_params``)
    :param num_files: number of data files (in sorted order) to process
    :return: None
    """
    # set random seed
    seed = config['logging_params']['manual_seed']
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    saved_folder = os.path.join(config['logging_params']['save_dir'], config['logging_params']['name'])
    log_dir = _resolve_log_dir(config, saved_folder)
    print(f'log_dir: {log_dir}')

    data_folder = config['exp_params']['data_folder']
    file_list = sorted([i for i in os.listdir(data_folder) if 'zone' in i and i.endswith('.csv')])[:num_files]

    param_grid = {
        'sliding_window': config['exp_params']['sliding_window'],
        'batch_size': config['exp_params']['batch_size'],
        'learning_rate': config['exp_params']['learning_rate'],
        'normalization': config['exp_params']['normalization'],

        'latent_dim': config['model_params']['latent_dim'],
        'nhidden': config['model_params']['nhidden'],
        'rnn_nhidden': config['model_params']['rnn_nhidden'],
        'noise_std': config['model_params']['noise_std'],
    }
    param_dict_list = list(ParameterGrid(param_grid))

    for file in file_list:
        cur_log_dir = os.path.join(log_dir, file.split('.')[0])
        data_path = os.path.join(data_folder, file)

        if not config['exp_params']['test_flag']:
            # getting validation results
            os.makedirs(cur_log_dir, exist_ok=True)
            Pytorch_DNN_validation(data_path, param_dict_list, cur_log_dir, config, LatentODE_exp)

            # hyperparameters selection
            summary, params_by_key = _summarize_validation(cur_log_dir, param_dict_list)
            with open(os.path.join(cur_log_dir, 'val_summary.json'), 'w') as f:
                json.dump(summary, f, indent=4)

            # save param
            with open(os.path.join(cur_log_dir, 'param.json'), 'w') as f:
                json.dump(_select_best_params(summary, params_by_key), f, indent=4)

        # prediction on testing
        with open(os.path.join(cur_log_dir, 'param.json'), 'r') as f:
            param_dict = json.load(f)
        Pytorch_DNN_testing(data_path, param_dict, cur_log_dir, config, LatentODE_exp)

    config_path = os.path.join(log_dir, 'config.yaml')
    if not os.path.exists(config_path):
        with open(config_path, 'w') as f:
            yaml.dump(config, f)

    # run evaluate
    evaluate_config = {
        'exp_params': {
            'prediction_path': log_dir,
            'prediction_interval': config['exp_params']['prediction_interval'],
        }
    }
    run_evaluate_V3(config=evaluate_config, verbose=False)
    return

if __name__ == '__main__':
    # Command-line options are merged into the YAML config by merge_parameters.
    # List-valued options are declared as strings.
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument('--manual_seed', '-manual_seed', type=int, help='random seed')
    parser.add_argument('--num_files', '-num_files', type=int, default=3, help='number of files to predict')

    parser.add_argument('--sliding_window', '-sliding_window', type=str, help='list of sliding_window options')
    parser.add_argument('--selection_metric', '-selection_metric', type=str, help='metrics to select hyperparameters, one of [RMSE, MAE, MAPE]',)
    parser.add_argument('--train_valid_ratio', '-train_valid_ratio', type=float, help='select hyperparameters on validation set')
    parser.add_argument('--time_features', '-time_features', type=str, help='list of time feature names')
    parser.add_argument('--external_features', '-external_features', type=str, help='list of external feature names')

    # model-specific features
    parser.add_argument('--batch_size', '-batch_size', type=str, help='list of batch_size')
    parser.add_argument('--max_epochs', '-max_epochs', type=int, help='number of epochs')
    # A list of learning rates cannot be parsed as a single int; declared as
    # str like the other list-valued options.
    parser.add_argument('--learning_rate', '-learning_rate', type=str, help='list of learning rate')
    parser.add_argument('--gpus', '-g', type=str)#, default='[1]')
    parser.add_argument('--dropout', '-dropout', type=str, help='list of dropout rates')

    parser.add_argument('--normalization', '-normalization', type=str, help='list of normalization options')
    parser.add_argument('--latent_dim', '-latent_dim', type=str, help='list of latent_dim options')
    parser.add_argument('--nhidden', '-nhidden', type=str, help='list of nhidden options')
    parser.add_argument('--rnn_nhidden', '-rnn_nhidden', type=str, help='list of rnn_nhidden options')
    parser.add_argument('--noise_std', '-noise_std', type=str, help='list of noise_std options')

    args = vars(parser.parse_args())
    with open('./../configs/NeuralODE.yaml', 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            # Without a parsed config nothing below can run; report the parse
            # error and stop here rather than failing later on an undefined name.
            print(exc)
            raise
    config = merge_parameters(args, config)
    print(f'after merge: config, {config}')

    print('gpus: ', config['trainer_params']['gpus'])
    # A negative sum of the gpu setting is replaced by 0.
    if np.sum(config['trainer_params']['gpus']) < 0:
        config['trainer_params']['gpus'] = 0

    grid_search_LatentODE(config, num_files=args['num_files'])



