"""
    __author__ = H. v. Wahl
    __date__ = 20.08.2020
    __update__ = 21.09.2020

    Stationary test case considered in Wahl, Richter, Frei.
"""
from netgen.geom2d import SplineGeometry
from ngsolve import *
from xfem import *
from xfem.lsetcurv import *

from CutFEM_utilities import UpdateMarkers, AddIntegratorsToForm
from Solvers import CutFEM_QuasiNewton

from math import ceil, pi
import os
import argparse

import time
start_time = time.time()    # wall-clock reference for the total run time

# NGSolve runtime settings: message verbosity, heap size, worker threads
ngsglobals.msg_level = 1
SetHeapSize(10_000_000)
SetNumThreads(32)


# -------------------------------- PARAMETERS ---------------------------------
parser = argparse.ArgumentParser()
# The mesh size has no sensible default. Without it h_max would be None and
# the geometry setup would fail later with a TypeError, so require it here.
parser.add_argument("--h", help="h_max", type=float, required=True)
args = parser.parse_args()
print(args)

# Fluid problem
density_fl = 1141                                   # Density of fluid (kg/m^3)
viscosity_dyn = 0.008                               # Fluid viscosity (kg/m/s)

# Solid problem
centre_ball = 0.1                                   # z-coord. of ball centre (m)
diam_ball = 0.022                                   # Diameter of ball (m)


# Mesh parameters
lowerleft, upperright = (0, 0), (0.055, 0.2)        # 2D r+/z - Domain (m)
h_max = args.h                                      # Mesh size (m)
inner_factor = 4                                    # Inner region with h/fac
bottom_factor = 1                                   # Inner bottom edge w/ h/fac

# Discretisation parameters
k = 2                                               # Order of velocity space
gamma_n = 100                                       # Nitsche parameter
gamma_s = 0.1                                       # Ghost-penalty extension

# Solver parameters
pReg = 1e-8                                         # Pressure regularisation
inverse = "pardiso"                                 # Sparse direct solver
maxit_newt = 15                                     # Maximum Newton steps
tol_newt = 1e-10                                    # Residual tolerance
jacobi_tol_newt = 0.1                               # Jacobian update tolerance
compile_flag = False                                # Compile forms
wait_compile = False                                # Complete compile first


# ----------------------------------- DATA ------------------------------------
u_in = 0.01                                         # Maximal inflow speed
# Parabolic profile in r (= x): speed u_in on the axis, zero at the outer
# wall r = upperright[0], directed in negative z (= y) direction.
inflow = CoefficientFunction((0, -u_in * (1 - (x / upperright[0])**2)))


def levelset_func(h):
    """
    Level-set function of the ball with centre (0, h) and diameter diam_ball.

    The value is positive inside the ball, zero on its surface and negative
    outside. Note that ``h`` is the z-coordinate of the centre, not the mesh
    size.
    """
    dist_to_centre = sqrt(x**2 + (y - h)**2)
    return diam_ball / 2 - dist_to_centre


# ------------------------------ BACKGROUND MESH ------------------------------
# Inner (refined) region next to the axis: 0 <= r <= 2/3 * diam_ball
ur_inner = (diam_ball * 2 / 3, upperright[1])

# Boundary points in counter-clockwise order; p2 and p5 also bound the
# interface between the inner and the outer region.
pnts = [lowerleft, (ur_inner[0], lowerleft[1]),
        (upperright[0], lowerleft[1]), upperright, ur_inner,
        (lowerleft[0], upperright[1])]

geo = SplineGeometry()
p1, p2, p3, p4, p5, p6 = (geo.AppendPoint(*pnt) for pnt in pnts)

# Segments as (start, end, leftdomain, rightdomain, extra Append kwargs).
# Domain 2 is the inner region, domain 1 the outer region.
segment_specs = [
    (p1, p2, 2, 0, {"bc": "bot", "maxh": h_max / bottom_factor}),
    (p2, p3, 1, 0, {"bc": "bot"}),
    (p3, p4, 1, 0, {"bc": "wall"}),
    (p4, p5, 1, 0, {"bc": "top"}),
    (p5, p6, 2, 0, {"bc": "top"}),
    (p6, p1, 2, 0, {"bc": "rot"}),
    (p2, p5, 2, 1, {}),                 # interior interface inner | outer
]
for seg_start, seg_end, dom_left, dom_right, seg_kwargs in segment_specs:
    geo.Append(["line", seg_start, seg_end], leftdomain=dom_left,
               rightdomain=dom_right, **seg_kwargs)

# Local mesh sizes: h_max outside, h_max / inner_factor in the inner region
geo.SetDomainMaxH(1, h_max)
geo.SetDomainMaxH(2, h_max / inner_factor)

with TaskManager():
    mesh = Mesh(geo.GenerateMesh(quad_dominated=False))
print(" Meshing completed")


# --------------------------- FINITE ELEMENT SPACE ----------------------------
# Velocity of order k and pressure of order k - 1 (Taylor-Hood type pair).
# Component-wise Dirichlet conditions: u_r on "wall", on the symmetry axis
# "rot" (x = 0) and on "top"; u_z on "wall" and "top". No Dirichlet
# condition is imposed on "bot".
V = VectorH1(mesh, order=k, dirichletx="wall|rot|top", dirichlety="wall|top")
Q = H1(mesh, order=k - 1)
# dgjumps=True lets the matrix graph include couplings across facets, which
# the facet-patch ghost-penalty terms require.
X = FESpace([V, Q], dgjumps=True)

# Free dofs of the active mesh part; filled in UpdateElementInformation().
# Cleared explicitly so it starts from a defined state, like the element and
# facet markers.
free_dofs = BitArray(X.ndof)
free_dofs.Clear()

gfu = GridFunction(X)
vel, pre = gfu.components


# ---------------------------- LEVELSET & CUT-INFO ----------------------------
# Compute a mesh deformation from the order-k ball level set and apply it.
# The adapter's P1 level set (lsetp1) is used for all cut information below.
lset_meshadap = LevelSetMeshAdaptation(mesh, order=k, threshold=0.1,
                                       discontinuous_qn=True)
deformation = lset_meshadap.CalcDeformation(levelset_func(centre_ball))
mesh.SetDeformation(deformation)

lsetp1 = lset_meshadap.lset_p1

# Integration-domain descriptors: negative part (fluid) and the interface
lset_neg, lset_if = ({"levelset": lsetp1, "domain_type": dom_type,
                      "subdivlvl": 0}
                     for dom_type in (NEG, IF))

ci_main = CutInfo(mesh, lsetp1)


# ------------------------------ ELEMENT MARKERS ------------------------------
# Element markers (size mesh.ne) and facet markers (size mesh.nedge).
# "none" is never filled and serves as an empty facet restriction.
els = {name: BitArray(mesh.ne) for name in ("hasneg", "if")}
facets = {name: BitArray(mesh.nedge) for name in ("active", "gp_stab", "none")}

# Start every marker from a cleared state
for bits in list(els.values()) + list(facets.values()):
    bits.Clear()


def UpdateElementInformation():
    """
    Refresh every marker that depends on the level set.

    Updates ``ci_main`` from ``lsetp1`` and then, in place, the element
    markers ``els["hasneg"]`` and ``els["if"]``, the ghost-penalty facets
    ``facets["gp_stab"]`` and the dof marker ``free_dofs``. Returns None.
    """
    ci_main.Update(lsetp1)

    # Elements with a negative (fluid) part, and cut elements
    for marker_key, element_type in (("hasneg", HASNEG), ("if", IF)):
        UpdateMarkers(els[marker_key], ci_main.GetElementsOfType(element_type))

    # Ghost-penalty facets, selected by neighbour types hasneg / if
    gp_facets = GetFacetsWithNeighborTypes(mesh, a=els["hasneg"], b=els["if"],
                                           use_and=True)
    UpdateMarkers(facets["gp_stab"], gp_facets)

    # Dofs of active elements, combined by UpdateMarkers with X.FreeDofs()
    # (presumably an intersection -- see CutFEM_utilities)
    active_dofs = GetDofsOfElements(X, els["hasneg"])
    UpdateMarkers(free_dofs, active_dofs, X.FreeDofs())


# --------------------------------- VARIABLES ---------------------------------
# Throughout, x plays the role of the radial coordinate r and y of the axial
# coordinate z (axisymmetric r/z formulation).
(u, p), (v, q) = X.TnT()                            # Trial and Test functions
h = specialcf.mesh_size                             # Mesh size cf.
# Unit normal of the P1 level set. The level set is positive inside the ball,
# so this normal points from the fluid (NEG) region into the ball.
n_lset = 1.0 / Norm(grad(lsetp1)) * grad(lsetp1)    # Level set normal vector


mu_fl = viscosity_dyn                               # Dynamic viscosity
rho_fl = density_fl                                 # Fluid density


# ----------------------------- (BI)LINEAR FORMS ------------------------------
# Axisymmetric Stokes operator, every term weighted with r (= x):
#   mu r grad(u):grad(v) + mu u_r v_r / r
#   - p (v_r + r div v) - q (u_r + r div u)
# where r * (cylindrical divergence) = u_r + r * div(u). The u_r v_r / r term
# is singular at r = 0; u_r is a Dirichlet dof on the axis boundary "rot".
# The last term is a small pressure regularisation.
stokes = mu_fl * x * InnerProduct(Grad(u), Grad(v)) + mu_fl * u[0] * v[0] / x
stokes += - p * (v[0] + x * div(v)) - q * (u[0] + x * div(u))
stokes += - pReg * p * q

# Convection with the current iterate `vel` (= gfu.components[0]) frozen as
# transport velocity; Grad(u) * vel = (vel . grad) u. Applied to gfu this
# gives the full nonlinear term. convect_lin adds the derivative w.r.t. the
# transport velocity, i.e. the Newton linearisation.
convect = rho_fl * x * InnerProduct(Grad(u) * vel, v)
convect_lin = rho_fl * x * InnerProduct(Grad(u) * vel, v)
convect_lin += rho_fl * x * InnerProduct(Grad(vel) * u, v)


# Symmetric Nitsche terms on the interface (ball surface) with penalty
# mu * gamma_n * k^2 / h. No boundary-data term appears, so u = 0 is imposed.
# Pressure terms match the sign of the -p div(v) term in the Stokes operator.
nitsche = -mu_fl * x * InnerProduct(Grad(u) * n_lset, v)
nitsche += -mu_fl * x * InnerProduct(Grad(v) * n_lset, u)
nitsche += mu_fl * (gamma_n * k * k / h) * x * InnerProduct(u, v)
nitsche += p * x * InnerProduct(v, n_lset)
nitsche += q * x * InnerProduct(u, n_lset)


# Facet-patch ghost penalty (used via SymbolicFacetPatchBFI): velocity scaled
# with mu / h^2, pressure with 1 / mu and a negative sign consistent with the
# negative pressure block.
ghost_penalty_stab = gamma_s * mu_fl / h**2 * x * (u - u.Other()) * (v - v.Other())
ghost_penalty_stab += -gamma_s / mu_fl * x * (p - p.Other()) * (q - q.Other())


# -------------------------------- INTEGRATORS --------------------------------
def _compiled(form):
    """Return ``form`` compiled with the global compile settings."""
    return form.Compile(compile_flag, wait=wait_compile)


def InnerBFI(form, **kwargs):
    """Volume bilinear integrator on the negative level-set part ("inner")."""
    integrator = SymbolicBFI(lset_neg, form=_compiled(form), **kwargs)
    return integrator, "inner"


def BoundaryBFI(form, **kwargs):
    """Bilinear integrator on the level-set interface ("boundary")."""
    integrator = SymbolicBFI(lset_if, form=_compiled(form), **kwargs)
    return integrator, "boundary"


def GhostPenaltyBFI(form, **kwargs):
    """
    Facet-patch ghost-penalty integrator.

    The ``domain`` keyword is required and becomes the marker key of the
    returned pair; any other keyword arguments are ignored.
    """
    integrator = SymbolicFacetPatchBFI(form=_compiled(form), skeleton=False)
    return integrator, kwargs["domain"]


def InnerLFI(form, **kwargs):
    """Volume linear integrator on the negative level-set part ("inner")."""
    integrator = SymbolicLFI(lset_neg, form=_compiled(form), **kwargs)
    return integrator, "inner"


def BoundaryLFI(form, **kwargs):
    """Linear integrator on the level-set interface ("boundary")."""
    integrator = SymbolicLFI(lset_if, form=_compiled(form), **kwargs)
    return integrator, "boundary"


def BuildNewtonSystem(integrators, integrators_lin, element_map, els_restr):
    """
    Set up the forms used by the quasi-Newton solver.

    Parameters
    ----------
    integrators : list of (integrator, marker_key)
        Integrators of the nonlinear form ``mStar``.
    integrators_lin : list of (integrator, marker_key)
        Integrators of the linearised form ``mStar_lin``.
    element_map : dict
        Maps each marker key (e.g. "inner", "boundary", "facets_if") to the
        element/facet BitArray the integrator is restricted to.
    els_restr : dict
        Keys "elements" and "facet": element and facet restriction of both
        restricted bilinear forms.

    Returns
    -------
    mStar, mStar_lin, f
        The two restricted bilinear forms and the linear form ``f``.
    """
    def _restricted_form():
        # Both forms share the same element/facet restriction
        return RestrictedBilinearForm(
            X, element_restriction=els_restr["elements"],
            facet_restriction=els_restr["facet"], check_unused=False)

    mStar = _restricted_form()
    mStar_lin = _restricted_form()
    f = LinearForm(X)

    # Use the element_map argument (previously the global integrator_markers
    # was used and this parameter was silently ignored).
    AddIntegratorsToForm(integrators=integrators, a=mStar, f=f,
                         element_map=element_map)
    AddIntegratorsToForm(integrators=integrators_lin, a=mStar_lin,
                         f=None, element_map=element_map)

    return mStar, mStar_lin, f


# Marker key of an integrator -> element/facet set it is restricted to
integrator_markers = {"inner": els["hasneg"],
                      "boundary": els["if"],
                      "facets_if": facets["gp_stab"]}
# Element and facet restriction of the restricted bilinear forms
els_restr = {"elements": els["hasneg"], "facet": facets["gp_stab"]}

# Nonlinear form: Stokes + convection (frozen `vel`), Nitsche, ghost penalty
integrators = [
    InnerBFI(stokes + convect),
    BoundaryBFI(nitsche),
    GhostPenaltyBFI(ghost_penalty_stab, domain="facets_if"),
]

# Linearised form: identical except for the Newton-linearised convection
integrators_lin = [
    InnerBFI(stokes + convect_lin),
    BoundaryBFI(nitsche),
    GhostPenaltyBFI(ghost_penalty_stab, domain="facets_if"),
]



# -------------------------------- FUNCTIONALS --------------------------------
# r-weighted interface traction r * (mu grad(u) n - p n) with n = n_lset
stress = x * mu_fl * Grad(u) * n_lset - x * p * n_lset

# Unit fields in r- and z-direction. Testing the traction form with them
# yields the integrated traction components.
drag_r_test = GridFunction(X)
drag_z_test = GridFunction(X)
for test_gf, direction in ((drag_r_test, (1.0, 0.0)),
                           (drag_z_test, (0.0, 1.0))):
    test_gf.components[0].Set(CoefficientFunction(direction))

# Work vector receiving the traction form applied to the solution
res = gfu.vec.CreateVector()


def ComputeDrag():
    """
    Evaluate the drag on the ball from the current solution ``gfu``.

    The interface form int_Gamma stress . v is applied to ``gfu`` (written
    into ``res``). The result is tested with the unit fields
    ``drag_r_test`` / ``drag_z_test`` and scaled by -2*pi.

    Returns
    -------
    drag_r, drag_z : float
        Radial and axial drag components.
    """
    traction_form = RestrictedBilinearForm(X, element_restriction=els["if"],
                                           facet_restriction=facets["none"],
                                           check_unused=False)
    traction_form += SymbolicBFI(lset_if, form=InnerProduct(stress, v),
                                 definedonelements=els["if"])
    traction_form.Apply(gfu.vec, res)
    del traction_form

    drag_r, drag_z = (-2 * pi * InnerProduct(res, test_gf.vec)
                      for test_gf in (drag_r_test, drag_z_test))
    return drag_r, drag_z


# ------------------------------ NONLINEAR SOLVE ------------------------------
with TaskManager():
    # Mark fluid/cut elements, ghost-penalty facets and free dofs
    UpdateElementInformation()

    # Nonlinear and linearised forms; f is not passed to the solver below
    mStar, mStar_lin, f = BuildNewtonSystem(integrators, integrators_lin,
                                            integrator_markers, els_restr)

    # Impose the parabolic inflow profile on the top boundary
    vel.Set(inflow, definedon=mesh.Boundaries("top"))

    # Solve the nonlinear stationary problem
    CutFEM_QuasiNewton(a=mStar, alin=mStar_lin, u=gfu, f=None,
                       freedofs=free_dofs, maxit=maxit_newt,
                       maxerr=tol_newt, inverse=inverse,
                       jacobi_update_tol=jacobi_tol_newt, reuse=False)

    drag_r, drag_z = ComputeDrag()

    # Table row: h_max & #free dofs & #non-zeros & axial drag
    n_free_dofs = int(sum(free_dofs))
    n_nonzeros = int(mStar.mat.nze)
    print("& {:5.3f} & {:d} & {:d} & \\num{{{:10.8e}}}".format(
        h_max, n_free_dofs, n_nonzeros, drag_z))


# ------------------------------ POST-PROCESSING ------------------------------
end_time = time.time() - start_time

Draw(vel, mesh, "vel")
Draw(pre, mesh, "pre")

# Split the elapsed wall-clock time into days, hours, minutes and seconds
days, remainder = divmod(end_time, 24 * 60 * 60)
hours, remainder = divmod(remainder, 60 * 60)
minutes, seconds = divmod(remainder, 60)

print("\n----------- Total time: {:02.0f}:{:02.0f}:{:02.0f}:{:06.3f}"
      " ----------".format(days, hours, minutes, seconds))
