#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 10 13:15:43 2022

@author: pdavid
"""

import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np

# Locate the project's ``src`` directory so the project-local modules imported
# below can be found. ``../current_directory.txt`` (relative to this script)
# holds the project root path.
directory_script = os.path.dirname(__file__)
with open(
    os.path.join(directory_script, "../current_directory.txt"), "r", encoding="utf-8"
) as file:
    file_contents = file.read()
# Now, file_contents contains the entire content of the file as a string

# Strip only the trailing line terminator(s). Slicing off the last character
# unconditionally would corrupt the path when the file has no trailing newline
# and would leave a stray "\r" behind for CRLF files.
source_directory = os.path.join(file_contents.rstrip("\r\n"), "src")
import sys

sys.path.append(source_directory)

from Green import Green
from Reconstruction_extended_space import reconstruction_extended_space
from Small_functions import plot_sketch
from Testing import Testing


def position_sources(dens, L, cyl_rad):
    """Scatter sources on a jittered square lattice, leaving a central disk empty.

    The square ``[0, L] x [0, L]`` is split into ``cells x cells`` equal cells,
    with ``cells`` chosen so that one source per cell matches the requested
    density. Each cell gets one source, offset from the cell centre by a
    uniform random amount in ``[-0.4 h, 0.4 h)`` per coordinate (``h`` is the
    cell side). Sources closer than ``cyl_rad`` to ``(L/2, L/2)`` are dropped.

    Parameters
    ----------
    dens : float
        Source density in sources per square millimetre. The ``1e-6`` factor
        below implies ``L`` is in micrometres (assumed; confirm with callers).
    L : float
        Side length of the square domain.
    cyl_rad : float
        Radius of the source-free (capillary-free) disk centred at
        ``(L/2, L/2)``.

    Returns
    -------
    numpy.ndarray
        Shape ``(n, 2)`` array of ``(x, y)`` source positions, ordered by cell
        (x index outer, y index inner). Empty ``(0, 2)`` when the domain is too
        small for a single cell at this density.

    Notes
    -----
    Draws ``2 * cells**2`` values from NumPy's global random state in the same
    order as a per-cell nested loop, so results are reproducible with
    ``np.random.seed``.
    """
    elem_square = 1 / (dens * 1e-6)  # area per source, in units of L squared
    cells = int(np.around(L / np.sqrt(elem_square)))
    if cells < 1:
        # Not even one cell fits: no sources (avoids dividing by zero below).
        return np.zeros((0, 2))
    h = L / cells
    grid = np.linspace(h / 2, L - h / 2, cells)

    # Cell centres in (x outer, y inner) order, matching a nested x/y loop,
    # so row k lines up with the k-th pair of random draws below.
    gx, gy = np.meshgrid(grid, grid, indexing="ij")
    centres = np.column_stack((gx.ravel(), gy.ravel()))

    # One vectorised draw instead of growing an array cell by cell (which was
    # quadratic because every np.concatenate copied the whole array).
    jitter = (np.random.rand(len(centres), 2) - 1 / 2) * h * 0.8
    candidates = centres + jitter

    center = np.array([L / 2, L / 2])
    keep = np.linalg.norm(candidates - center, axis=1) > cyl_rad
    return candidates[keep]


class metab_simulation:
    """Set up and post-process a metabolism simulation around a central source.

    The constructor places one source of radius ``R_art`` at the centre of a
    square domain of side ``L`` and surrounds it with randomly jittered
    sources of radius ``R_cap`` (see ``position_sources``). It assigns each
    source a value in ``C_v_array`` and builds the project ``Testing`` object.
    ``do_simulation`` runs ``Testing.Multi`` and optional reconstructions.
    """

    def __init__(
        self,
        mean,
        std_dev,
        density,
        L,
        cyl_rad,
        R_art,
        R_cap,
        directness,
        CMRO2_max,
        phi_0,
        measures,
        BC_type,
        BC_value,
        cells,
        D,
    ):
        """Build the source layout and the ``Testing`` solver object.

        Parameters
        ----------
        mean, std_dev : float
            Mean and standard deviation of the Gaussian from which the
            ``C_v_array`` entries of the non-central sources are drawn (Python's
            ``random`` module). Values are then clipped to ``[0, 1]``.
        density : float
            Source density forwarded to ``position_sources`` (per mm^2).
        L : float
            Side length of the square domain.
        cyl_rad : float
            Radius around the centre excluded from the plateau average in
            ``do_simulation``.
        R_art, R_cap : float
            Radius of the central source and of all other sources.
        directness, CMRO2_max, phi_0, measures, BC_type, BC_value, cells, D
            Stored and/or forwarded to ``Testing`` / the reconstruction code;
            their meaning is defined by those project modules.
        """
        self.mean = mean
        self.std_dev = std_dev
        self.density = density
        self.L = L
        self.cyl_rad = cyl_rad
        self.R_art = R_art
        self.R_cap = R_cap

        self.directness = directness
        self.CMRO2_max = CMRO2_max
        self.phi_0 = phi_0
        self.measures = measures
        self.BC_type = BC_type
        self.BC_value = BC_value
        # do_simulation needs these; they used to be left as locals only.
        self.cells = cells
        self.D = D

        # Source 0 sits at the domain centre; the others surround it.
        # NOTE(review): the exclusion radius passed here is L / 4, not
        # ``cyl_rad``, while do_simulation averages the plateau outside
        # ``cyl_rad``. Confirm the two are meant to differ.
        pos_s = np.concatenate(
            (np.array([[0.5, 0.5]]) * L, position_sources(density, L, L / 4)),
            axis=0,
        )
        S = len(pos_s)
        # One Gaussian draw per non-central source, in source order.
        C_v_array = np.array(
            [1.0] + [random.gauss(mean, std_dev) for _ in range(S - 1)]
        )
        Rv = np.array([R_art] + [R_cap] * (S - 1))
        if np.any(pos_s > L) or np.any(pos_s < 0):
            print("ERROR IN THE POSITIONING")

        # Clip before handing the array to Testing, so the solver never sees
        # values outside [0, 1]. Previously the clip ran after construction.
        np.clip(C_v_array, 0, 1, out=C_v_array)

        h_coarse = L / cells
        self.h_coarse = h_coarse
        x_coarse = np.linspace(
            h_coarse / 2, L - h_coarse / 2, int(np.around(L / h_coarse))
        )
        y_coarse = x_coarse

        self.K_eff = np.full(S, math.inf)

        t = Testing(
            pos_s,
            Rv,
            cells,
            L,
            self.K_eff,
            D,
            directness,
            1,
            C_v_array,
            BC_type,
            BC_value,
        )

        print("CMRO2= ", CMRO2_max)

        plot_sketch(x_coarse, y_coarse, directness, h_coarse, pos_s, L, os.getcwd())
        self.Rv = Rv
        self.pos_s = t.pos_s
        self.C_v_array = t.C_v_array
        self.t = t

        self.x = t.x_coarse
        self.y = t.y_coarse

    def do_simulation(self, plateau, ratio):
        """Run ``Testing.Multi`` and optional post-processing.

        Parameters
        ----------
        plateau : bool
            If truthy, reconstruct radial profiles, plot their average and
            compute the mean reconstructed value outside ``self.cyl_rad``.
            The value is stored in ``self.plateau_value``.
        ratio
            If truthy, assigned to ``t.ratio`` and ``t.Reconstruct_Multi(1, 0)``
            is called; its result is stored in ``self.a`` and returned.

        Returns
        -------
        The result of ``t.Reconstruct_Multi(1, 0)`` when ``ratio`` is truthy,
        otherwise ``None``.
        """
        t = self.t
        self.pos_s = t.pos_s
        self.C_v_array = t.C_v_array

        Multi_FV_metab, Multi_q_metab = t.Multi(self.CMRO2_max, self.phi_0)
        self.q = Multi_q_metab
        self.s = Multi_FV_metab

        toreturn = None  # previously unbound when ``ratio`` was falsy

        if plateau:
            L = self.L
            REC_phi_array, r = self.reconstruct_perivascular_radial_profile(
                t, self.measures, Multi_FV_metab, Multi_q_metab
            )
            self.REC_phi_array = REC_phi_array
            self.r = r

            # Per-radius average over all angles, plus the angular profiles
            # with the largest / smallest total.
            avg_phi_REC = np.sum(REC_phi_array, axis=0) / len(REC_phi_array)
            row_totals = np.sum(REC_phi_array, axis=1)
            self.avg_phi_REC = avg_phi_REC
            self.max_phi_REC = REC_phi_array[np.argmax(row_totals)]
            self.min_phi_REC = REC_phi_array[np.argmin(row_totals)]

            plt.plot(np.linspace(0, L / 2, 100), avg_phi_REC)
            # NOTE(review): label text looks garbled (perhaps "$\mu m$"); value
            # kept unchanged, written raw to avoid the invalid-escape warning.
            plt.xlabel(r"$\   m$")
            plt.title("average of the average")
            plt.show()

            # Mean reconstructed value on a 100 x 100 grid over the domain,
            # restricted to points farther than cyl_rad from the centre.
            grid_1d = np.linspace(L / 200, 199 * L / 200, 100) - L / 2
            X, Y = np.meshgrid(grid_1d, grid_1d)
            outside = (X**2 + Y**2).ravel() > self.cyl_rad**2
            # Reuses the reconstruction object from the last radial profile.
            r.set_up_manual_reconstruction_space(
                X.ravel()[outside] + L / 2, Y.ravel()[outside] + L / 2
            )
            r.full_rec(t.C_v_array, self.BC_value, self.BC_type)
            REC_phi = r.s + r.SL + r.DL
            # Previously appended to the boolean flag and discarded.
            self.plateau_value = np.sum(REC_phi) / len(REC_phi)

        if ratio:  # full_reconstruction
            t.ratio = ratio
            toreturn = t.Reconstruct_Multi(1, 0)
            self.a = toreturn
        return toreturn

    def reconstruct_perivascular_radial_profile(
        self, non_linear_object, measures, Multi_FV_metab=None, Multi_q_metab=None
    ):
        """Reconstruct the field along ``measures`` rays from the domain centre.

        Each ray runs from ``(L/2, L/2)`` outwards over a length ``L/2``,
        sampled at 100 points. The angles are equally spaced over
        ``[0, 2*pi)``.

        Parameters
        ----------
        non_linear_object
            The ``Testing`` object; its ``pos_s``, ``Rv``, ``h_coarse``, ``L``,
            ``K_eff``, ``D``, ``directness``, ``C_v_array``, ``BC_value`` and
            ``BC_type`` attributes are read.
        measures : int
            Number of rays (must be >= 1).
        Multi_FV_metab, Multi_q_metab : optional
            Solution arrays assigned to each reconstruction's ``s_FV`` / ``q``.
            They default to the attributes of the same name on
            ``non_linear_object``.

        Returns
        -------
        (REC_phi_array, r)
            ``REC_phi_array`` has shape ``(measures, 100)``, one row per ray.
            ``r`` is the reconstruction object used for the last ray.
        """
        if measures < 1:
            raise ValueError("measures must be at least 1")
        t = non_linear_object
        if Multi_FV_metab is None:
            Multi_FV_metab = t.Multi_FV_metab
        if Multi_q_metab is None:
            Multi_q_metab = t.Multi_q_metab
        L = t.L

        # endpoint=False: the ray at 2*pi would duplicate the one at 0 and be
        # counted twice in angular averages.
        theta_arr = np.linspace(0, 2 * np.pi, measures, endpoint=False)
        REC_phi_array = np.zeros((0, 100))
        r = None
        for theta in theta_arr:
            REC_x = np.linspace(0, L * np.cos(theta) / 2, 100) + L / 2
            REC_y = np.linspace(0, L * np.sin(theta) / 2, 100) + L / 2
            r = reconstruction_extended_space(
                t.pos_s, t.Rv, t.h_coarse, L, t.K_eff, t.D, t.directness
            )
            r.s_FV = Multi_FV_metab
            r.q = Multi_q_metab
            r.set_up_manual_reconstruction_space(REC_x, REC_y)
            if np.any(t.BC_type == "Infinite"):
                r.b_prime = t.n.b_prime
            r.full_rec(t.C_v_array, t.BC_value, t.BC_type)
            REC_phi = r.s + r.SL + r.DL
            REC_phi_array = np.concatenate((REC_phi_array, [REC_phi]), axis=0)

        return REC_phi_array, r

    def rec_capillary_bed_avg(self, non_linear_object, measures, r, layer, Da, *title):
        """Average the reconstruction on circles around the domain centre.

        For each ``r[i]``, the points sit at ``measures`` angles on a circle of
        radius ``|r[i] - L/2|`` centred at ``(L/2, L/2)``. The method computes:

        - ``slow_fine[i]``: mean of the reconstruction's ``s`` over those points;
        - ``r_art[i]``: ``Green(pos_s[0], (r[i], L/2), Rv[0], 1) * q[0] + C_v[0]``;
        - ``r_cap[i]``: angular mean of ``sum_{j>=1} Green(pos_s[j], x, Rv[j], D)
          * q[j] + C_v[j]``;
        - ``DL_cap[i]``: ``sum_{j>=1} C_v[j]`` (angular mean of a constant).

        It also stores ``slow_coarse``, the average of the two central rows and
        two central columns of ``t.s_Multi_cart_metab`` reshaped to
        ``(cells, cells)``. All results go on ``self``. If a ``title`` is given,
        ``r_cap``, ``r_art`` and ``slow_fine`` are saved with ``np.save``.
        """
        t = non_linear_object
        D = t.D
        s = t.s_Multi_cart_metab
        q = t.q_Multi_metab
        pos_s = t.pos_s
        Rv = t.Rv
        C_v_array = t.C_v_array

        theta_arr = np.linspace((2 * np.pi) / measures, 2 * np.pi, measures)

        r_cap = np.zeros(len(r))
        DL = np.zeros(len(r))
        r_art = np.zeros(len(r))
        slow_fine = np.zeros(len(r))
        for i, r_i in enumerate(r):
            radius = np.abs(r_i - t.L / 2)
            rec_x, rec_y = (
                np.array([np.cos(theta_arr), np.sin(theta_arr)]) * radius + t.L / 2
            )

            p = reconstruction_extended_space(
                t.pos_s, t.Rv, t.h_coarse, t.L, t.K_eff, t.D, t.directness
            )
            p.s_FV = s
            p.q = q
            p.set_up_manual_reconstruction_space(rec_x, rec_y)
            p.full_rec(t.C_v_array, t.BC_value, t.BC_type)

            r_art[i] = (
                Green(pos_s[0], np.array([r_i, t.L / 2]), Rv[0], 1) * q[0]
                + C_v_array[0]
            )
            slow_fine[i] = np.sum(p.s) / len(rec_x)

            for k in theta_arr:
                point = np.array([np.cos(k), np.sin(k)]) * radius + t.L / 2
                for j in range(1, len(q)):  # every source except the central one
                    r_cap[i] += Green(pos_s[j], point, Rv[j], D) * q[j] + C_v_array[j]
                    DL[i] += C_v_array[j]

        slow_coarse = s.reshape(t.cells, t.cells)
        mid = int(t.cells / 2)
        slow_coarse = (
            slow_coarse[mid - 1] / 4
            + slow_coarse[mid] / 4
            + slow_coarse[:, mid - 1] / 4
            + slow_coarse[:, mid] / 4
        )
        r_cap /= measures
        DL /= measures
        self.r_cap = r_cap
        self.DL_cap = DL
        self.r_art = r_art
        self.slow_coarse = slow_coarse
        self.slow_fine = slow_fine

        if title:
            np.save(title[0] + "_cap_{}_{}".format(layer, Da), r_cap)
            np.save(title[0] + "_art_{}_{}".format(layer, Da), r_art)
            np.save(title[0] + "_slow_{}_{}".format(layer, Da), slow_fine)

        return


def get_met_plateau(b, L, cap_free_length):
    """Return the mean of ``b`` beyond the capillary-free region.

    ``b`` is treated as uniform samples along a segment of length ``L`` whose
    first ``cap_free_length`` is the capillary-free region. The plateau is the
    mean of the samples from index ``round(len(b) * cap_free_length / L)`` on.

    Previously the ratio was inverted (``len(b) * L / cap_free_length``),
    which lands past the end whenever ``L > cap_free_length``. The rounded
    float was also used as a slice index, which always raised TypeError.

    Raises
    ------
    ValueError
        If no samples lie beyond the capillary-free region.
    """
    b = np.asarray(b)
    n = len(b)
    pos_0 = max(int(np.around(n * cap_free_length / L)), 0)
    if pos_0 >= n:
        raise ValueError("no samples lie beyond the capillary-free region")
    return np.sum(b[pos_0:]) / (n - pos_0)
