import pandas as p
import math
import numpy as np
import itertools as it
from haversine import *
from haversine_script import *
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import ExtraTreesRegressor
import time
import json
from math import pi
from random import *

from ProxyFAUG import *

# Module-level default random seed.
# NOTE(review): not referenced in this section; fit_predict_positioner_with_ExtraTrees
# declares its own `random_state=42` default rather than reading this name.
random_state = 42


def define_ProxyFAUG_hp(range_th=20, maximum_points_per_cluster=2, p=0.3, crossovers_per_pair=11):
    """Bundle the ProxyFAUG hyperparameters into a dict, echo it, and return it.

    The returned dict has the keys ``range_th``, ``maximum_points_per_cluster``,
    ``p`` and ``crossovers_per_pair`` (in that order), which is the layout
    ``augment_and_merge_ProxyFAUG`` reads back.

    Returns:
        dict mapping each hyperparameter name to the value passed in.
    """
    hyperparameters = dict(
        range_th=range_th,
        maximum_points_per_cluster=maximum_points_per_cluster,
        p=p,
        crossovers_per_pair=crossovers_per_pair,
    )
    # Echo the configuration so experiment logs record which settings were used.
    print(hyperparameters)
    return hyperparameters



def augment_and_merge_ProxyFAUG(x_train, y_train, ProxyFAUG_hp_dict):
    """Generate ProxyFAUG samples and stack them under the original training set.

    Args:
        x_train: training features; must be stackable with ``np.vstack``.
        y_train: training targets; must be stackable with ``np.vstack``.
        ProxyFAUG_hp_dict: dict holding ``range_th``, ``maximum_points_per_cluster``,
            ``p`` and ``crossovers_per_pair`` (as built by ``define_ProxyFAUG_hp``).

    Returns:
        ``(x_merged, y_merged)``: the original rows first, followed by the
        rows returned by ``augment``.

    Raises:
        KeyError: if any hyperparameter key is missing from the dict.
    """
    # Positional order matches augment(x, y, range_th, maximum_points_per_cluster, p, crossovers_per_pair).
    hp_keys = ('range_th', 'maximum_points_per_cluster', 'p', 'crossovers_per_pair')
    hp_values = [ProxyFAUG_hp_dict[key] for key in hp_keys]

    x_synthetic, y_synthetic = augment(x_train, y_train, *hp_values)

    # Originals come first so the leading rows of the merged arrays are the real samples.
    merged_x = np.vstack([x_train, x_synthetic])
    merged_y = np.vstack([y_train, y_synthetic])
    return merged_x, merged_y



def fit_predict_positioner_with_ExtraTrees(features_train, targets_train,
                                           features_val, targets_val,
                                           features_test, targets_test,
                                           random_state=42, n_estimators=100):
    """Fit an ExtraTreesRegressor on the training split and predict every split.

    Args:
        features_train, targets_train: data used to fit the model.
        features_val, features_test: features to predict on.
        targets_val, targets_test: unused here; accepted so the signature
            matches ``fit_predict_positioner_with_knn``.
        random_state: seed forwarded to ``ExtraTreesRegressor``.
        n_estimators: number of trees forwarded to ``ExtraTreesRegressor``.

    Returns:
        ``(predict_in_train, predict_in_val, predict_in_test, model)``, where
        ``model`` is the fitted regressor.
    """
    # Forward n_estimators: it was previously hard-coded to 100, silently
    # ignoring the caller's value (the default is still 100).
    M = ExtraTreesRegressor(n_estimators=n_estimators, random_state=random_state)
    M.fit(features_train, targets_train)
    predict_in_train = M.predict(features_train)
    predict_in_val = M.predict(features_val)
    predict_in_test = M.predict(features_test)

    return predict_in_train, predict_in_val, predict_in_test, M


def fit_predict_positioner_with_knn(features_train, targets_train,
                                    features_val, targets_val,
                                    features_test, targets_test,
                                    n_neighbors=6, weights='distance', metric='braycurtis'):
    """Fit a KNeighborsRegressor on the training split and predict every split.

    Args:
        features_train, targets_train: data used to fit the model.
        features_val, features_test: features to predict on.
        targets_val, targets_test: unused here; accepted so the signature
            matches ``fit_predict_positioner_with_ExtraTrees``.
        n_neighbors: number of neighbours forwarded to ``KNeighborsRegressor``.
        weights: neighbour weighting scheme forwarded to ``KNeighborsRegressor``.
        metric: distance metric forwarded to ``KNeighborsRegressor``.

    Returns:
        ``(predict_in_train, predict_in_val, predict_in_test, model)``, where
        ``model`` is the fitted regressor.
    """
    # Forward the caller's hyperparameters; previously they were re-hard-coded
    # here, so any non-default argument was silently ignored.
    M = KNeighborsRegressor(n_neighbors=n_neighbors, weights=weights, metric=metric)
    M.fit(features_train, targets_train)
    predict_in_train = M.predict(features_train)
    predict_in_val = M.predict(features_val)
    predict_in_test = M.predict(features_test)

    return predict_in_train, predict_in_val, predict_in_test, M


def calculated_positioning_performance_stats(predict_in_train, targets_train,
                                             predict_in_val, targets_val,
                                             predict_in_test, targets_test):
    """Compute per-sample positioning errors for the train, val and test splits.

    Each split's predictions are compared against its ground truth with
    ``calculate_pairwise_error_list``. NOTE(review): that helper comes from a
    star import (presumably ``haversine_script``); confirm its error units there.

    Returns:
        ``(prediction_error_train, prediction_error_val, prediction_error_test)``.
    """
    split_pairs = (
        (predict_in_train, targets_train),
        (predict_in_val, targets_val),
        (predict_in_test, targets_test),
    )
    # Evaluated in train -> val -> test order, same as the tuple returned.
    return tuple(
        calculate_pairwise_error_list(predicted, truth)
        for predicted, truth in split_pairs
    )


def _print_split_stats(split_name, prediction_errors):
    """Print mean and 25th/50th/75th-percentile error for one split, then a blank line."""
    # `statistics` is not imported at the top of this file; importing it here
    # avoids relying on a star import to leak the module into this namespace.
    import statistics

    mean_error = statistics.mean(prediction_errors)
    error_25th = np.percentile(prediction_errors, 25)
    error_median = np.percentile(prediction_errors, 50)
    error_75th = np.percentile(prediction_errors, 75)

    print("{} set mean error: {:.2f}".format(split_name, mean_error))
    print("{} set 25th perc. error: {:.2f}".format(split_name, error_25th))
    print("{} set median error: {:.2f}".format(split_name, error_median))
    print("{} set 75th perc. error: {:.2f}".format(split_name, error_75th))
    print()


def print_performance_stats(prediction_error_train, prediction_error_val, prediction_error_test):
    """Print summary error statistics for the train, val and test splits.

    For each split, prints the mean and the 25th, 50th (median) and 75th
    percentiles, formatted to two decimals, followed by a blank line.

    Raises:
        statistics.StatisticsError: if a split's error list is empty (from
            ``statistics.mean``).
    """
    _print_split_stats("Train", prediction_error_train)
    _print_split_stats("Val", prediction_error_val)
    _print_split_stats("Test", prediction_error_test)