#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jan  6 10:04:32 2023

@author: cbone
"""
import numpy as np
import random




# Per-row area weights: the cosine of each latitude read from LAT.npy.
# The stored values are treated as degrees (converted to radians before cos).
LAT = np.cos(np.deg2rad(np.load('LAT.npy')))


# Region labels used to build output file names. Index -1 (the default
# "whole grid" cluster value) resolves to the last entry, 'globe'.
clus_index_name = [
    'water', 'land', 'ADO', 'PDO', 'sud', 'nord', 'NINO', 'globe',
]
def get_mean_year_2(data, cluster=-1, weights=None):
    """Return the row-weighted spatial mean of ``data`` for each index of axis 0.

    Only the first 90 entries of axis 1 and the first 180 entries of axis 2
    are used. Each cell ``(j, k)`` gets weight ``weights[j]``, so the result is
    ``sum(w[j] * data[:, j, k]) / sum(w[j])`` over the selected cells.

    Parameters
    ----------
    data : array_like
        Array indexed as ``data[:, j, k]``. Presumably (time, lat, lon) —
        TODO confirm. Any trailing axes are carried through to the result.
    cluster : int, optional
        -1 (default) averages over the whole grid. Any other value indexes
        ``map_index``. The matching ``.npy`` mask is loaded from the current
        directory, and only cells where it equals 1 are averaged. Cells outside
        the mask are ignored entirely, so NaNs there do not propagate.
    weights : array_like, optional
        Per-row weights of length >= 90. Defaults to the module-level ``LAT``.

    Returns
    -------
    numpy.ndarray
        One weighted mean per index of axis 0.

    Raises
    ------
    ZeroDivisionError
        If the cluster mask selects no grid cells.
    """
    map_index = ['map_water.npy', 'map_land.npy', 'map_ADO.npy', 'map_PDO.npy',
                 'map_sud.npy', 'map_nord.npy', 'map_NINO.npy']
    if weights is None:
        weights = LAT

    field = np.asarray(data)[:, :90, :180]
    nlat, nlon = field.shape[1], field.shape[2]
    # Every cell in row j shares the row weight; columns are not weighted.
    cell_w = np.broadcast_to(np.asarray(weights)[:nlat, None], (nlat, nlon))

    if cluster == -1:
        selected = np.ones((nlat, nlon), dtype=bool)
    else:
        map_clus = np.load(map_index[cluster])[:nlat, :nlon]
        selected = map_clus == 1
        if not selected.any():
            raise ZeroDivisionError(
                f"cluster {cluster} ({map_index[cluster]}) selects no grid cells")

    sel_w = cell_w[selected]
    # field[:, selected] has shape (axis0, n_selected_cells, ...); contract the
    # cell axis against the matching weights in one vectorized call.
    return np.tensordot(field[:, selected], sel_w, axes=([1], [0])) / sel_w.sum()



def moyenne_model(data, n, nombre_max=100, cluster=-1, reference=None, save_dir=None):
    """Bootstrap the RMSE of an ``n``-member ensemble mean against a reference.

    The draw is repeated ``nombre_max`` times. Each draw picks ``n`` distinct
    indices along axis 0 of ``data`` with :func:`random.sample` and averages
    those slices. It then computes
    ``sqrt(mean(get_mean_year_2((reference - subset_mean) ** 2, cluster)))``.
    Each RMSE is printed. The array of all RMSEs is saved with ``np.save`` to
    ``<save_dir>result_<n>_<cluster name>.npy``, and its mean is printed.

    Parameters
    ----------
    data : numpy.ndarray
        Ensemble array with members along axis 0. Presumably
        (member, time, lat, lon) — TODO confirm.
    n : int
        Members per draw. ``random.sample`` raises ValueError if it exceeds
        ``data.shape[0]``.
    nombre_max : int, optional
        Number of random draws.
    cluster : int, optional
        Region selector forwarded to ``get_mean_year_2``. It also selects the
        label from ``clus_index_name`` used in the output file name.
    reference : numpy.ndarray, optional
        Field the subset means are compared against. Defaults to the
        module-level ``data_mean`` set by the driver script.
    save_dir : str, optional
        Path prefix for the output file, expected to end with '/'. Defaults to
        the module-level ``data_save``.

    Returns
    -------
    numpy.ndarray
        The ``nombre_max`` RMSE values, which are also saved to disk.
    """
    # Fall back to the driver script's globals so existing calls keep working.
    if reference is None:
        reference = data_mean
    if save_dir is None:
        save_dir = data_save

    result = []
    for _ in range(nombre_max):
        # n distinct members, drawn without replacement.
        index = random.sample(range(data.shape[0]), n)
        data_memb = np.mean(data[index], axis=0)
        rmse = np.sqrt(np.mean(get_mean_year_2(np.square(reference - data_memb), cluster)))
        result.append(rmse)
        print(rmse)
    result = np.array(result)
    np.save(f"{save_dir}result_{n}_{clus_index_name[cluster]}", result)
    print(np.mean(result))
    return result

# Run the bootstrap for each model's large ensemble, using subset sizes 26..39
# over the whole grid (cluster -1). moyenne_model reads the module globals
# `data_mean` and `data_save`, so the loop deliberately (re)binds those names
# at module level before each model's runs.
for data_path, data_save in (
    ('../data_valid/LE_FGOALS.npy', 'figures/FGOALS_mean/'),
    ('../data_valid/LE_MPI-ESM_1950_2020.npy', 'figures/MPI_ESM_mean/'),
):
    data = np.load(data_path)[:, 4:116]
    data_mean = np.mean(data, axis=0)
    for n_members in range(26, 40):
        for clus in [-1]:
            moyenne_model(data, n_members, cluster=clus)

    
    
    
    