#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: pdavid
THIS IS AN EXAMPLE SCRIPT TO RUN A SIMULATION FOR THE 0D-2D MULTISCALE MODEL
"""
# Resolve the folders used by this example and make the project sources
# importable before the local modules are imported further down.
import os
import sys

# Folder that contains this script; the csv outputs folder sits next to it.
directory_script = os.path.dirname(__file__)
csv_directory = os.path.join(directory_script, "csv_outputs")

# ../current_directory.txt stores the path of the directory that contains
# the `src` folder.
with open(os.path.join(directory_script, '../current_directory.txt'), 'r') as file:
    file_contents = file.read()

# Strip surrounding whitespace instead of blindly dropping the last character:
# slicing with [:-1] truncates the path when the file has no trailing newline
# and leaves a stray '\r' when it was written with CRLF line endings.
source_directory = os.path.join(file_contents.strip(), 'src')
sys.path.append(source_directory)


import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np

from Small_functions import plot_sketch
from Testing import Testing

# Global matplotlib style: every text element listed below uses the "x-large"
# font size and figures default to a size of (8, 8).
params = dict.fromkeys(
    (
        "legend.fontsize",
        "axes.labelsize",
        "axes.titlesize",
        "xtick.labelsize",
        "ytick.labelsize",
    ),
    "x-large",
)
params["figure.figsize"] = (8, 8)
pylab.rcParams.update(params)


# ---------------------------------------------------------------------------
# User inputs: sources and computational domain
# ---------------------------------------------------------------------------

# Ratio of the domain length to the source radius (the radius is L / alpha)
alpha = 100

# Maximum metabolic consumption, in mumol cm^-3 min^-1
M_input = 2
# Effective permeability (K_eff in the article)
K0 = 2e-3
# Side length of the computational domain
L = 240
# Number of finite-volume cells (used together with L to build the grid)
cells = 7
h_coarse = L / cells


# ---------------------------------------------------------------------------
# Physical constants and non-dimensional metabolism
# ---------------------------------------------------------------------------
solubility = 1.39e-6  # mumol mm^-3 mmHg^-1
Pmax = 60  # mmHg
D_real = 2e-3  # mm^2 s^-1
alpha_Pmax_D = solubility * Pmax * D_real
# NOTE(review): 6e10 is the unit-conversion factor used by the original
# script to express the product above per cm and per minute -- confirm.
alpha_Pmax_D_mumol_cm_min = alpha_Pmax_D * 6e10
M = M_input / alpha_Pmax_D_mumol_cm_min
phi_0 = 0.1

# Constants of the iterative (reactive) model. They are only defined here;
# nothing in this part of the script passes them on.
conver_residual = 5e-5  # Convergence threshold of the reactive model
stabilization = 0.5  # Stabilization constant

# Cartesian grid: cell centres, identical in x and y
x_coarse = np.linspace(
    h_coarse / 2,
    L - h_coarse / 2,
    int(np.around(L / h_coarse)),
)
y_coarse = x_coarse


# V-chapeau (neighbourhood) definition
directness = 2  # Corresponds to (n-1)/2 in the article
print("directness=", directness)


# A single source placed at the centre of the domain
pos_s = L * np.array([[0.5, 0.5]])
S = pos_s.shape[0]
Rv = np.full(S, L / alpha)
# Refinement ratio of the grid used for the post processing.
# The operators evaluate left to right: floor division by L first, then / 4.
ratio = 2 * int(100 * h_coarse // L / 4)
print("h coarse:", h_coarse)
K_eff = K0 / (np.pi * Rv**2) / D_real


# Sanity check of the metabolism parameters on 100 samples of [0, 1]:
# warn if p - M * (1 - phi_0 / (phi_0 + p)) becomes negative anywhere.
p = np.linspace(0, 1, 100)
if (p - M * (1 - phi_0 / (phi_0 + p))).min() < 0:
    print("There is an error in the metabolism")

# Intravascular concentration of each source (first source set to 1)
C_v_array = np.zeros(S)
C_v_array[0] = 1

# Boundary conditions: one value and one type per side of the domain
BC_value = np.zeros(4, dtype=int)
BC_type = np.array(["Dirichlet"] * 4)
# Alternative: no-flux on every side
#BC_type = np.array(["Neumann", "Neumann", "Neumann", "Neumann"])
# Sketch of the set-up (coarse grid, neighbourhood size and source position).
# NOTE(review): what plot_sketch does with directory_script (e.g. saving the
# figure there) is defined in Small_functions -- confirm.
plot_sketch(x_coarse, y_coarse, directness, h_coarse, pos_s, L, directory_script)


# Build the test object for the multiscale model.
# NOTE(review): the meaning of the literal 1 (6th positional argument) is
# defined by Testing.__init__ -- confirm.
t = Testing(
    pos_s, Rv, cells, L, K_eff, 1, directness, ratio, C_v_array, BC_type, BC_value
)
# Linear problem: Multi() is called without metabolism arguments.
s_Multi_cart_linear, q_Multi_linear = t.Multi()
#%%
# Override the post-processing refinement ratio before reconstructing.
t.ratio = 10
# Only the first two returned values are kept. The third used to be bound to
# `c`, which was immediately overwritten below.
a, b, _ = t.Reconstruct_Multi(0, 1)
#%%
plt.imshow(a)
plt.colorbar()
# Show the image on its own figure. Without this call, running the file as a
# plain script made the next plt.plot draw on top of the image axes.
plt.show()

#%%
# Index of the profile extracted from the reconstructed fields.
c = 0
for fine_coords, phi_profiles, axis_name in (
    (t.x_fine, t.array_phi_field_x_Multi, "x"),
    (t.y_fine, t.array_phi_field_y_Multi, "y"),
):
    plt.plot(fine_coords, phi_profiles[c], label="Multi")
    plt.xlabel(axis_name)
    plt.legend()
    plt.title("linear")
    plt.show()

#%%


# Solve again, this time passing the non-dimensional metabolism M and phi_0,
# then reconstruct the field (only the first returned value is kept).
s_Multi_cart_metab, q_Multi_metab = t.Multi(M, phi_0)
Multi_rec_metab, _, _ = t.Reconstruct_Multi(1, 1)

# One line plot per direction, using the same profile index c as above.
for axis_name in ("x", "y"):
    fine_coords = getattr(t, axis_name + "_fine")
    phi_profiles = getattr(t, "array_phi_field_" + axis_name + "_Multi")
    plt.plot(fine_coords, phi_profiles[c], label="Multi")
    plt.xlabel(axis_name)
    plt.title("Metabolism")
    plt.legend()
    plt.show()


#%% - Check the outgoing flux through every non-corner boundary cell
# For each side of the domain, an outgoing flux is evaluated at every boundary
# cell except the two corners, using the metabolism solution. The sum over all
# of them is printed at the end.

# array_opposite[c] is the row of t.n.boundary that faces side c.
array_opposite = np.array([1, 0, 3, 2])
# Outward normal of each side, in the same order as the rows of t.n.boundary.
out_normals = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
# Flattened slow-term field; it does not change inside the loops.
s_flat = np.ndarray.flatten(s_Multi_cart_metab)

boundary_fluxes = []
for c, side_cells in enumerate(t.n.boundary):
    out_normal = out_normals[c]
    # d is the position along the side; it starts at 1 because the slice
    # [1:-1] skips the corner cells.
    for d, k in enumerate(side_cells[1:-1], start=1):
        # Cell at the same position on the opposite side of the domain.
        m = t.n.boundary[array_opposite[c], d]
        (_, r_k_grad_face_kernel, _, _) = t.n.get_interface_kernels(k, out_normal, m)

        # NOTE(review): the "virgin" matrices are presumably the system rows
        # before the boundary conditions were applied -- confirm.
        out_flux = (
            t.n.A_matrix_virgin[k].dot(s_flat)
            + t.n.b_matrix_virgin[k].dot(q_Multi_metab)
            - r_k_grad_face_kernel.dot(q_Multi_metab)
        )
        boundary_fluxes.append(out_flux)

# A single concatenation replaces the repeated np.append calls. The leading
# empty float array keeps the result 1-D, float-promoted and valid even when
# no flux was collected.
out_flux_array = np.concatenate(
    [np.array([])] + [np.ravel(flux) for flux in boundary_fluxes]
)

print(np.sum(out_flux_array))
