#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 17 08:14:58 2022

@author: pdavid

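TESTING MODULE

Helpers to run the FV reference solver and the coupled ("Multi") model,
reconstruct their concentration fields, and exchange data with COMSOL.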

"""

import copy

import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
# The project root path is stored in ``current_directory.txt`` one level
# above this script; its ``src`` sub-directory holds the modules imported
# right after the ``sys.path`` update below.
directory_script = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(directory_script, "..", "current_directory.txt"), "r") as file:
    file_contents = file.read()

# Strip surrounding whitespace instead of blindly dropping the last character:
# ``file_contents[:-1]`` corrupted the path when the file had no trailing
# newline and left a stray "\r" behind with Windows line endings.
source_directory = os.path.join(file_contents.strip(), "src")
import sys

sys.path.append(source_directory)
from FV_reference import FV_validation
from Module_Coupling_sparse import non_linear_metab_sparse
from reconst_and_test_module import reconstruction_sans_flux
from Reconstruction_extended_space import reconstruction_extended_space

# Module-wide matplotlib look: square 6x6 figures and "x-large" text for
# legends, axis labels, titles and tick labels.
params = dict.fromkeys(
    (
        "legend.fontsize",
        "axes.labelsize",
        "axes.titlesize",
        "xtick.labelsize",
        "ytick.labelsize",
    ),
    "x-large",
)
params["figure.figsize"] = (6, 6)
pylab.rcParams.update(params)

# Render imshow() images with nearest-neighbour pixels (no smoothing).
plt.rcParams["image.interpolation"] = "nearest"


class Testing:
    """Run the FV reference solver and the coupled ("Multi") model on one
    configuration and keep their results for comparison.

    The solution of each case is stored as the field on the FV grid together
    with an array estimating the vessel-tissue exchanges (a ``q`` array).
    Some methods can also store the concentration field along horizontal and
    vertical lines passing through the centre of each circular source.
    """

    def __init__(
        self,
        pos_s,
        Rv,
        cells,
        L,
        K_eff,
        D,
        directness,
        ratio,
        C_v_array,
        BC_type,
        BC_value,
    ):
        """Store the problem parameters and build the Cartesian grids.

        Parameters
        ----------
        pos_s : array
            Source centre coordinates (``len(pos_s)`` sources; ``i[0]`` / ``i[1]``
            are used as the x / y coordinate of source ``i``).
        Rv : float
            Source radius; ``K0 = pi * Rv**2 * K_eff``.
        cells : int
            Number of coarse cells per side (``h_coarse = L / cells``).
        L : float
            Side length of the square domain.
        K_eff, D, directness, C_v_array, BC_type, BC_value :
            Forwarded to the solver / reconstruction objects.
        ratio : int
            Refinement factor of the fine grid relative to the coarse grid.
        """
        self.h_coarse = L / cells
        self.L = L
        self.ratio = ratio
        self.directness = directness
        self.Rv = Rv
        self.pos_s = pos_s
        self.K0 = np.pi * Rv**2 * K_eff
        self.D = D
        self.C_v_array = C_v_array
        self.BC_type = BC_type
        self.BC_value = BC_value
        self.cells = cells
        self.K_eff = K_eff

        # Default parameters of the non-linear (metabolism) solvers
        self.conver_residual = 5e-5
        self.stabilization = 0.5

        # Cell-centred coordinates of the coarse Cartesian grid
        self.x_coarse = np.linspace(
            self.h_coarse / 2, L - self.h_coarse / 2, int(np.around(L / self.h_coarse))
        )
        self.y_coarse = self.x_coarse.copy()

        # Cell-centred coordinates of the grid refined by ``ratio``
        self.x_fine = np.linspace(
            self.h_coarse / (2 * ratio),
            L - self.h_coarse / (2 * ratio),
            int(np.around(L * ratio / self.h_coarse)),
        )
        self.y_fine = self.x_fine.copy()

        # When truthy, Multi() sets ``no_interpolation = 1`` on its solver
        self.no_interpolation = 0

    def Linear_FV_Peaceman(self, Peaceman):
        """Solve the linear problem with the FV reference on the refined grid.

        The FV mesh has ``cells * ratio`` cells per side. The solution is
        plotted, stored on the instance and returned.

        Parameters
        ----------
        Peaceman : bool or int
            Forwarded to ``FV_validation``. When truthy the results are stored
            as ``q_Peaceman_linear_Peaceman`` and ``phi_Peaceman_linear_Peaceman``
            (the latter with ``FV.get_corr_array()`` added); otherwise as
            ``q_FV_linear`` and ``phi_FV_linear``.

        Returns
        -------
        tuple
            ``(phi_FV_linear, q_FV_linear)`` as returned by the FV solver.
        """
        L = self.L
        pos_s = self.pos_s
        cells = self.cells
        C_v_array = self.C_v_array
        Rv = self.Rv
        BC_type = self.BC_type
        BC_value = self.BC_value
        K_eff = self.K0 / (np.pi * self.Rv**2)

        FV = FV_validation(
            L,
            cells * self.ratio,
            pos_s,
            C_v_array,
            self.D,
            K_eff,
            Rv,
            BC_type,
            BC_value,
            Peaceman,
        )
        #####################################
        #####  CORR ARRAY!!
        #####################################
        phi_FV_linear = FV.solve_linear_system()
        q_FV_linear = FV.get_q_linear()
        phi_FV_linear_matrix = phi_FV_linear.reshape(
            cells * self.ratio, cells * self.ratio
        )

        plt.imshow(phi_FV_linear_matrix, origin="lower")
        plt.colorbar()
        plt.title(
            "FV Peaceman solution, linear system\n mesh:{}x{}".format(
                self.ratio * cells, self.ratio * cells
            )
        )
        plt.show()

        array_phi_field_x_linear = np.zeros((len(pos_s), len(self.x_fine)))
        array_phi_field_y_linear = np.zeros((len(pos_s), len(self.y_fine)))

        # The following chunk of code is to plot the vertical line that goes through
        # the center of each source. I commented since it takes quite a bit of time to
        # plot when there are many sources. It is functional on 22nd-Nov-2022
        # =============================================================================
        #         c=0
        #         for i in pos_s:
        #             pos=coord_to_pos(FV.x, FV.y, i)
        #
        #             plt.plot(self.x_fine, mat_linear[pos//len(FV.x),:], label="FV")
        #             plt.xlabel("x $\mu m$")
        #             plt.legend()
        #             plt.title("Linear Peaceman solution")
        #             plt.show()
        #
        #             array_phi_field_x_linear[c]=mat_linear[int(pos//len(FV.x)),:]
        #             array_phi_field_y_linear[c]=mat_linear[:,int(pos%len(FV.x))]
        #             c+=1
        # =============================================================================
        self.array_phi_field_x_linear_Peaceman = array_phi_field_x_linear
        self.array_phi_field_y_linear_Peaceman = array_phi_field_y_linear
        if Peaceman:
            self.q_Peaceman_linear_Peaceman = q_FV_linear
            self.phi_Peaceman_linear_Peaceman = phi_FV_linear + FV.get_corr_array()
        else:
            self.q_FV_linear = q_FV_linear
            self.phi_FV_linear = phi_FV_linear

        return (phi_FV_linear, q_FV_linear)

    def Metab_FV_Peaceman(self, M, phi_0, Peaceman):
        """Solve the non-linear (metabolism) problem with the FV reference.

        Parameters
        ----------
        M, phi_0 :
            Forwarded to ``FV.solve_non_linear_system``.
        Peaceman : int
            1 for the Peaceman coupling model, 0 for no coupling.

        Returns
        -------
        tuple
            ``(phi_FV_metab, q_FV_metab)``; ``phi_FV_metab`` is the last
            non-linear iterate plus ``FV.Corr_array``, reshaped to
            ``(cells * ratio, cells * ratio)``.
        """
        L = self.L
        pos_s = self.pos_s
        cells = self.cells
        C_v_array = self.C_v_array
        Rv = self.Rv
        K_eff = self.K0 / (np.pi * self.Rv**2)
        BC_type = self.BC_type
        BC_value = self.BC_value
        # Values of the refined mesh; only referenced by the disabled
        # diagnostic block further down.
        x_fine, y_fine = self.x_fine, self.y_fine

        # Standard FV model object:
        FV = FV_validation(
            L,
            cells * self.ratio,
            pos_s,
            C_v_array,
            self.D,
            K_eff,
            Rv,
            BC_type,
            BC_value,
            Peaceman,
        )

        # The stabilization for the FV method generally needs to be much lower
        # than for the hybrid model, hence the ad-hoc division by 5. The
        # returned value is not needed: the iterates are read from FV.phi_metab.
        FV.solve_non_linear_system(phi_0, M, self.stabilization / 5)
        # We reverse the Peaceman correction so the phi value represents the
        # average value in each cell.
        phi_FV_metab = (FV.phi_metab[-1] + FV.Corr_array).reshape(
            cells * self.ratio, cells * self.ratio
        )
        q_FV_metab = FV.get_q_metab()

        plt.imshow(phi_FV_metab, origin="lower", vmax=np.max(phi_FV_metab))
        plt.title("FV metab reference")
        plt.colorbar()
        plt.show()

        array_phi_field_x_metab = np.zeros((len(pos_s), len(self.x_fine)))
        array_phi_field_y_metab = np.zeros((len(pos_s), len(self.y_fine)))

        # The following chunk of code is to plot the vertical line that goes through
        # the center of each source. I commented since it takes quite a bit of time to
        # plot when there are many sources. It is functional on 22nd-Nov-2022
        # =============================================================================
        #         for i in pos_s:
        #             pos=coord_to_pos(x_fine, y_fine, i)
        #             plt.plot(phi_FV_metab_matrix[pos//len(FV.x),:], label="Peac metab")
        #             plt.legend()
        #             plt.show()
        #             array_phi_field_x_metab[c]=mat_metab[pos//len(FV.x),:]
        #             array_phi_field_y_metab[c]=mat_metab[:,int(pos%len(FV.x))]
        #             c+=1
        # =============================================================================

        if Peaceman:
            self.array_phi_field_x_metab_Peaceman = array_phi_field_x_metab
            self.array_phi_field_y_metab_Peaceman = array_phi_field_y_metab
            self.q_Peaceman_metab = q_FV_metab
            # NOTE(review): FV.Corr_array was already added to phi_FV_metab
            # above; confirm that adding get_corr_array(1) on top is intended.
            self.phi_Peaceman_metab = np.ndarray.flatten(
                phi_FV_metab
            ) + FV.get_corr_array(1)

        else:
            self.array_phi_field_x_metab_noPeaceman = array_phi_field_x_metab
            self.array_phi_field_y_metab_noPeaceman = array_phi_field_y_metab
            self.q_FV_metab = q_FV_metab
            self.phi_FV_metab = phi_FV_metab

        return (phi_FV_metab, q_FV_metab)

    def get_slow_rapid(self, C_v_array):
        """Intended to return the slow potential (fine and coarse grids) and
        the rapid potential; this is not implemented.

        The previous body only opened an interactive ``pdb`` breakpoint, which
        hangs any non-interactive run. It now fails explicitly instead.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Testing.get_slow_rapid is not implemented")

    def Multi(self, *Metab):
        """Solve the coupled (Multi) model on the coarse grid.

        The linear problem is always solved first. If ``Metab`` is given it
        must be ``(M, phi_0)``: the non-linear problem is then solved with
        ``Full_Newton`` starting from the linear solution, and the reactive
        field is plotted.

        Returns
        -------
        tuple
            ``(s_FV, q)`` of the non-linear solution when ``Metab`` is given,
            otherwise of the linear solution.
        """
        cells = self.cells
        n = non_linear_metab_sparse(
            self.pos_s,
            self.Rv,
            self.h_coarse,
            self.L,
            self.K_eff,
            self.D,
            self.directness,
        )
        if self.no_interpolation:
            n.no_interpolation = 1

        n.solve_linear_prob(self.BC_type, self.BC_value, self.C_v_array)

        s_Multi_cart_linear = n.s_FV_linear
        q_Multi_linear = n.q_linear

        # n.get_slow_rapid(self.C_v_array)
        self.q_Multi_linear = q_Multi_linear
        self.s_Multi_cart_linear = s_Multi_cart_linear
        self.phi_bar = n.phi_bar
        self.phi_bar2 = n.phi_bar2

        self.s_blocks = n.s_blocks
        n.phi_0, n.M = 1, 1
        # =============================================================================
        #
        #         n.assemble_it_matrices_Sampson(np.ndarray.flatten(s_Multi_cart_coarse_linear), q_Multi_linear)
        #         plt.imshow(n.rec_sing.reshape(cells,cells)+s_Multi_cart_coarse_linear, origin='lower', extent=[0,self.L, 0, self.L])
        #         plt.title("Average value reconstruction Multi model")
        #         plt.colorbar(); plt.show()
        #
        #         #self.Multi_linear_object.rec_sing for the potentials averaged per FV cell
        # =============================================================================
        # Snapshot of the solver holding the linear solution (used by
        # Reconstruct_Multi when non_linear is falsy).
        self.Multi_linear_object = copy.deepcopy(n)
        # self.Multi_linear_object.rec_sing for the potentials averaged per FV cell
        self.n = n
        # =============================================================================
        # This is a portion of code to test the total flux going out. Useful to test the inf boundary condition
        #         import pdb
        #         pdb.set_trace()
        #         array_opposite = np.array([1, 0, 3, 2])
        #         flux_out=np.array([])
        #         for j in range(4):
        #             c=0
        #             for i in n.boundary[j,:]:
        #                 normal = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])[j]
        #                 m = n.boundary[array_opposite[j], c]
        #                 # The division by h is because the kernel calculates the integral, what we
        #                 # need is an average value per full cell
        #                 (
        #                     _,
        #                     r_k_grad_face_kernel,
        #                     _,
        #                     _,
        #                 ) = n.get_interface_kernels(i, normal, m)
        #
        #                 flux_out_t=n.A_matrix_virgin[i].dot(np.ndarray.flatten(self.s_Multi_cart_linear)) + n.b_matrix_virgin[i].dot(self.q_Multi_linear) + r_k_grad_face_kernel.dot(self.q_Multi_linear)
        #                 flux_out=np.append(flux_out, flux_out_t)
        #                 c+=1
        # =============================================================================
        if not Metab:
            return (s_Multi_cart_linear, q_Multi_linear)

        M, phi_0 = Metab
        n.Full_Newton(
            np.ndarray.flatten(s_Multi_cart_linear),
            np.ndarray.flatten(n.q_linear),
            self.conver_residual,
            M,
            phi_0,
        )

        # Snapshot of the solver holding the non-linear solution.
        self.Multi_metab_object = copy.deepcopy(n)
        s_Multi_cart_metab = n.s_FV_metab
        q_Multi_metab = n.q_metab

        n.assemble_it_matrices_Sampson(n.s_FV_metab, n.q_metab)
        plt.imshow(
            (n.rec_sing + s_Multi_cart_metab).reshape(cells, cells),
            origin="lower",
            extent=[0, self.L, 0, self.L],
            interpolation=None,
        )
        plt.title(r"Coarse grid, reactive model $\phi$ field")
        plt.colorbar()
        plt.show()

        self.rec_sing = n.rec_sing
        self.q_Multi_metab = q_Multi_metab
        self.s_Multi_cart_metab = s_Multi_cart_metab
        self.residual = n.residual
        return (s_Multi_cart_metab, q_Multi_metab)

    def Reconstruct_Multi(self, non_linear, plot_sources, *FEM_args):
        """Reconstruct the concentration field of the latest Multi solution.

        Parameters
        ----------
        non_linear : bool
            If truthy, use the solution stored in ``self.Multi_metab_object``
            (``s_FV_metab``, ``q_metab``); otherwise the one in
            ``self.Multi_linear_object`` (``s_FV_linear``, ``q_linear``).
        plot_sources : bool
            If truthy, plot the reconstructed fields and the profiles along a
            vertical and a horizontal line through each source centre. The
            profiles are stored in ``array_phi_field_x_Multi`` and
            ``array_phi_field_y_Multi`` (one row per source).
        *FEM_args :
            Optional ``(FEM_x, FEM_y)`` arrays of points where the field is
            reconstructed. Without them a Cartesian reconstruction is done
            using ``self.ratio`` -- IMPORTANT: make sure ``self.ratio`` holds
            the intended value before calling.

        Returns
        -------
        tuple
            Cartesian case: ``(rec_final, rec_potentials, rec_s_FV)``.
            FEM case: ``(s + SL + DL, SL, s)``.
        """
        if non_linear:
            obj = self.Multi_metab_object
            s_FV = obj.s_FV_metab
            q = obj.q_metab
        else:
            obj = self.Multi_linear_object
            s_FV = np.ndarray.flatten(obj.s_FV_linear)
            q = obj.q_linear

        if not FEM_args:  # Cartesian reconstruction
            print(
                "Have you updated the value of the ratio?? \n right now -> ratio={}".format(
                    self.ratio
                )
            )
            a = reconstruction_sans_flux(
                np.concatenate((s_FV, q)), obj, obj.L, self.ratio, obj.directness
            )
            a.reconstruction()
            a.reconstruction_boundaries_short(self.BC_type, self.BC_value)
            a.rec_corners()
            plt.imshow(a.rec_final, origin="lower", vmax=np.max(a.rec_final))
            plt.title("Reconstructed $\\phi$ field \n reactive case")
            plt.colorbar()
            plt.show()
            self.Multi_rec = a.rec_final
            self.cart_rec_fine_x = a.x
            toreturn = a.rec_final, a.rec_potentials, a.rec_s_FV
        else:  # Reconstruction on the provided FEM points
            FEM_x = FEM_args[0]
            FEM_y = FEM_args[1]
            b = reconstruction_extended_space(
                self.pos_s,
                self.Rv,
                self.h_coarse,
                self.L,
                self.K_eff,
                self.D,
                self.directness,
            )
            b.s_FV = s_FV
            b.q = q
            b.set_up_manual_reconstruction_space(FEM_x, FEM_y)
            b.full_rec(self.C_v_array, self.BC_value, self.BC_type)
            if plot_sources:
                for field, title in (
                    (b.s, "s FEM rec"),
                    (b.SL, "SL FEM rec"),
                    (b.DL, "DL FEM rec"),
                    (b.s + b.SL + b.DL, "Full FEM rec"),
                ):
                    plt.tricontourf(b.FEM_x, b.FEM_y, field, levels=100)
                    plt.colorbar()
                    plt.title(title)
                    plt.show()

            self.SL = b.SL
            self.DL = b.DL
            self.s = b.s

            # We return the phi-field, the single layer field and the smooth field
            toreturn = b.s + b.SL + b.DL, b.SL, b.s

        if plot_sources:
            array_phi_field_x = np.zeros((len(self.pos_s), len(self.x_fine)))
            array_phi_field_y = np.zeros((len(self.pos_s), len(self.y_fine)))
            infinite_bc = np.any(np.asarray(self.BC_type) == "Infinite")
            # enumerate() gives every source its own row. The former manual
            # counter was incremented after the loop, so every profile
            # overwrote row 0.
            for c, i in enumerate(self.pos_s):
                r = reconstruction_extended_space(
                    self.pos_s,
                    self.Rv,
                    self.h_coarse,
                    self.L,
                    self.K_eff,
                    self.D,
                    self.directness,
                )
                r.s_FV = s_FV
                r.q = q
                # Vertical line x = i[0]
                r.set_up_manual_reconstruction_space(
                    i[0] + np.zeros(len(self.x_fine)), self.y_fine
                )
                if infinite_bc:
                    r.b_prime = obj.b_prime
                r.full_rec(self.C_v_array, self.BC_value, self.BC_type)
                array_phi_field_y[c] = r.s + r.SL + r.DL

                # Horizontal line y = i[1]
                r.set_up_manual_reconstruction_space(
                    self.x_fine, i[1] + np.zeros(len(self.y_fine))
                )
                r.full_rec(self.C_v_array, self.BC_value, self.BC_type)
                array_phi_field_x[c] = r.s + r.SL + r.DL

                for coord, profile, label, axis in (
                    (self.x_fine, array_phi_field_x[c], "Multi_x", "x"),
                    (self.y_fine, array_phi_field_y[c], "Multi_y", "y"),
                ):
                    plt.plot(coord, profile, label=label)
                    plt.legend()
                    plt.title("Plot through source center")
                    plt.xlabel(axis)
                    plt.ylabel(r"$\phi$")
                    plt.show()

            self.array_phi_field_x_Multi = array_phi_field_x
            self.array_phi_field_y_Multi = array_phi_field_y
        return toreturn

    def get_reaction_map(self, s_FV, q):
        """Evaluate the reaction (consumption) terms per FV cell and per source.

        To also test the conservative properties of the model, conservation is
        evaluated through the FV flux balance of the linear numerical model:
        ``D * (LIN_MAT @ [s_FV, q] + H0) / h**2`` using the solver stored in
        ``self.n`` by :meth:`Multi`.

        Returns
        -------
        tuple
            ``(cell_values, source_values)``: the first entries reshaped to
            ``(cells, cells)`` and the last ``n.S`` entries.
        """
        D_operator = self.n.LIN_MAT
        conservation = (
            self.D
            * (D_operator.dot(np.concatenate((np.ndarray.flatten(s_FV), q))) + self.n.H0)
            / self.n.h**2
        )

        return (
            conservation[: -self.n.S].reshape(self.cells, self.cells),
            conservation[-self.n.S :],
        )
        


def extract_COMSOL_data(directory_COMSOL, args):
    """Load COMSOL reference results exported as fixed-width text files.

    Parameters
    ----------
    directory_COMSOL : str or path-like
        Directory containing the exported files.
    args : sequence of three flags
        Select which files are read:

        * ``args[0]`` -> ``q.txt``: vessel-tissue exchanges, taken from the
          first line of the file.
        * ``args[1]`` -> ``contour.txt``: columns x, y, phi of the 2D field;
          x and y are scaled by 10**6 (to micrometers).
        * ``args[2]`` -> ``plot_x.txt`` / ``plot_y.txt``: 1D profiles; column 0
          is the coordinate (scaled by 10**6), column 1 the value.

    Returns
    -------
    list
        The requested arrays in this order, only for the enabled flags:
        ``q``; ``FEM_phi, FEM_x, FEM_y``; ``FEM_x_1D, FEM_y_1D, x_1D, y_1D``.
    """
    toreturn = []
    if args[0]:
        # Vessel-tissue exchanges. The values sit on the first line; reading
        # them as data (header=None) instead of as column names avoids
        # pandas renaming repeated values (e.g. "0.5.1"), which then failed
        # to convert to float.
        q_file = os.path.join(directory_COMSOL, "q.txt")
        q_df = pd.read_fwf(q_file, header=None, infer_nrows=500)
        toreturn.append(q_df.iloc[0].to_numpy(dtype=float))

    if args[1]:
        # Concentration field data for the linear problem
        field_file = os.path.join(directory_COMSOL, "contour.txt")
        ref_data = np.array(pd.read_fwf(field_file, infer_nrows=500)).T
        FEM_x = ref_data[0] * 10**6  # in micrometers
        FEM_y = ref_data[1] * 10**6
        FEM_phi = ref_data[2]
        toreturn.extend((FEM_phi, FEM_x, FEM_y))

    if args[2]:
        # Profiles of the concentration field along a horizontal and a
        # vertical line through the centre of the source. Each file is read
        # once (previously twice) with the correct ``infer_nrows`` keyword
        # (previously misspelled ``infer_rows``).
        x_data = np.array(
            pd.read_fwf(os.path.join(directory_COMSOL, "plot_x.txt"), infer_nrows=500)
        ).T
        y_data = np.array(
            pd.read_fwf(os.path.join(directory_COMSOL, "plot_y.txt"), infer_nrows=500)
        ).T

        FEM_x_1D = x_data[1]
        x_1D = x_data[0] * 10**6
        FEM_y_1D = y_data[1]
        y_1D = y_data[0] * 10**6

        toreturn.extend((FEM_x_1D, FEM_y_1D, x_1D, y_1D))

    return toreturn


def FEM_to_Cartesian(FEM_x, FEM_y, FEM_phi, x_c, y_c):
    """Sample scattered FEM values onto a Cartesian grid by nearest neighbour.

    Returns an array of shape ``(len(y_c), len(x_c))`` whose entry
    ``[row, col]`` is the ``FEM_phi`` value of the FEM node closest (in
    squared Euclidean distance; first node on ties) to ``(x_c[col], y_c[row])``.
    """
    grid = np.zeros((len(y_c), len(x_c)))
    for row, y_val in enumerate(y_c):
        # The y contribution is the same for every point of this grid row.
        dy2 = (FEM_y - y_val) ** 2
        for col, x_val in enumerate(x_c):
            nearest = np.argmin((FEM_x - x_val) ** 2 + dy2)
            grid[row, col] = FEM_phi[nearest]
    return grid


import csv


def write_parameters_COMSOL(pos_s, L, alpha, K_0, M, phi_0=0.4, filename="Parameters.txt"):
    """Write a COMSOL parameter file with one space-separated ``name value``
    pair per line.

    Parameters
    ----------
    pos_s : array
        Source positions; row ``i`` is written as ``x_i = pos_s[i][0]`` and
        ``y_i = pos_s[i][1]``, rounded to 4 decimals.
    L, alpha, M :
        Written verbatim.
    K_0 :
        Not written; ``K_com`` is emitted as the expression ``K_0/(2*pi*R)``.
        NOTE(review): confirm ``K_0`` is defined on the COMSOL side.
    phi_0 : float, default 0.4
        Value written for ``phi_0`` (previously hard-coded to 0.4).
    filename : str or path-like, default "Parameters.txt"
        Output file (relative paths resolve against the working directory).
    """
    rows = [
        ["L", L],
        ["alpha", alpha],
        ["R", "L/alpha"],
        ["K_com", "K_0/(2*pi*R)"],
        ["M", M],
        ["phi_0", phi_0],
    ]

    for i, pos in enumerate(pos_s):
        rows.append(["x_{}".format(i), np.around(pos[0], decimals=4)])
        rows.append(["y_{}".format(i), np.around(pos[1], decimals=4)])

    # newline="" as required by the csv module; otherwise each "\r\n" row
    # terminator becomes "\r\r\n" on Windows.
    with open(filename, "w", newline="") as f:
        writer = csv.writer(f, delimiter=" ")
        writer.writerows(rows)


def save_csv(path, name_columns, data_phi):
    """Save ``data_phi`` to a comma-separated file without an index column.

    Parameters
    ----------
    path : str or path-like
        Output file.
    name_columns : list of str
        Column names; there must be one per row of ``data_phi`` (the data is
        transposed so that each row of ``data_phi`` becomes a column). A 1D
        input produces a single column.
    data_phi : array-like

    Returns
    -------
    pandas.DataFrame
        The frame that was written.
    """
    frame = pd.DataFrame(np.asarray(data_phi).T)
    frame.columns = name_columns
    # ``index`` is a boolean flag in pandas; the previous ``None`` only worked
    # because it is falsy.
    frame.to_csv(path, sep=",", index=False)
    return frame


def save_var_name_csv(directory, variable, name=None):
    """Save ``variable`` as ``<directory>/<name>.csv`` with one column ``name``.

    The previous implementation tried to recover the caller's variable name
    with ``f"{variable=}"``. That f-string always renders the text inside the
    braces -- the parameter name ``variable`` -- so the file was always
    ``variable.csv``. Pass ``name`` to choose the file/column name; when
    omitted, the historical name ``"variable"`` is kept for compatibility.

    Returns
    -------
    str
        Path of the written file.
    """
    variable_name = "variable" if name is None else name
    path = os.path.join(directory, variable_name + ".csv")
    save_csv(path, [variable_name], variable)
    return path


def load_var_name(directory, variable, name=None):
    """Load ``<directory>/<name>.csv`` (one header row) as a 2D NumPy array.

    ``variable`` is not read. The previous ``f"{variable=}"`` trick always
    produced the literal name ``"variable"``, never the caller's variable
    name. Pass ``name`` to select the file; when omitted, ``"variable.csv"``
    is loaded as before.
    """
    variable_name = "variable" if name is None else name
    path = os.path.join(directory, variable_name + ".csv")
    return pd.read_csv(path).to_numpy()


def array_cartesian_positions(x, y):
    """Flatten the tensor grid ``x`` by ``y`` into two float64 coordinate arrays.

    Points are ordered row by row: ``x`` varies fastest and ``y`` stays
    constant over each block of ``len(x)`` entries, so entry ``k`` is the
    point ``(x[k % len(x)], y[k // len(x)])``.
    """
    x_vals = np.asarray(x, dtype=float)
    y_vals = np.asarray(y, dtype=float)
    x_flat = np.tile(x_vals, len(y_vals))
    y_flat = np.repeat(y_vals, len(x_vals))
    return (x_flat, y_flat)
