"""
    __author__ = H. v. Wahl
    __date__ = 20.08.2020
    __update__ = 21.09.2020

    Non-stationary test case considered in Wahl, Richter, Frei. The 
    motion of the ball is prescribed analytically.
"""
from netgen.geom2d import SplineGeometry
from ngsolve import *
from xfem import *
from xfem.lsetcurv import *

from CutFEM_utilities import CheckElementHistory, UpdateMarkers, \
    AddIntegratorsToForm
from Solvers import CutFEM_QuasiNewton

from math import ceil, pi
import os
import argparse

import time
start_time = time.time()  # wall-clock reference for the total-runtime report

# NGSolve runtime configuration: verbosity, heap size and number of threads
ngsglobals.msg_level = 1
SetHeapSize(10 ** 7)
SetNumThreads(24)


# -------------------------------- PARAMETERS ---------------------------------
# Both discretisation parameters are mandatory: the mesh size and the time
# step (and all output names) are derived from them. Without required=True a
# missing option becomes None and only fails much later with a TypeError
# (e.g. in `h_max / bottom_factor` or `1 / dt_inv`).
parser = argparse.ArgumentParser()
parser.add_argument("--h", help="h_max", type=float, required=True)
parser.add_argument("--dti", help="dt_inv", type=float, required=True)
args = parser.parse_args()
print(args)

# Fluid problem
# --- Fluid problem ---
# density [kg/m^3] and dynamic viscosity [kg/(m s)]
density_fl = 1141
viscosity_dyn = 0.008

# --- Solid problem ---
# ball diameter [m]
diam_ball = 0.022


# --- Mesh parameters ---
# 2D (r >= 0, z) section of the axisymmetric domain [m]
lowerleft, upperright = (0, 0), (0.055, 0.2)
# maximal mesh size [m], taken from the command line
h_max = args.h
# the inner (near-axis) region is meshed with h_max / inner_factor
inner_factor = 4
# the bottom edge of the inner region is meshed with h_max / bottom_factor
bottom_factor = 1

# --- Temporal parameters ---
# final time and inverse time-step size (dt = 1 / dt_inv)
t_end = 20
dt_inv = args.dti

# --- Discretisation parameters ---
# velocity order (the pressure space uses k - 1)
k = 2
# ghost-strip width parameter
c_delta = 4
# penalty parameters: Nitsche, ghost-penalty stability, ghost-penalty extension
gamma_n = 100
gamma_s = 0.1
gamma_e = 0.01

# --- Solver parameters ---
# pressure regularisation and sparse direct solver
pReg = 1e-8
inverse = "pardiso"
# Newton: iteration limit, residual tolerance, Jacobian-update tolerance
maxit_newt = 15
tol_newt = 1e-10
jacobi_tol_newt = 0.1
# compile the symbolic forms, and wait until compilation has finished
compile_flag = True
wait_compile = True

# --- Text output ---
out_dir = "output/"
out_file = "Output_TimeDepTest_diam{}mu{}rho{}_iso-hmax{}dtinv{}BDF2.txt".format(
    diam_ball, viscosity_dyn, density_fl, h_max, dt_inv)

# --- VTK output ---
# switch, subdivision level and number of VTK files per time unit
vtk_out = False
vtk_subdiv = 1
vtk_freq = dt_inv
vtk_dir = out_dir + "vtk_TimeDepTest/iso-hmax{}dtinv{}BDF2/".format(
    h_max, dt_inv)
vtk_file = "BallTest-diam{}mu{}rho{}".format(
    diam_ball, viscosity_dyn, density_fl)


# ----------------------------------- DATA ------------------------------------
# Upper bound for the ball speed |v1(t)| (its exact maximum is
# 0.05 * 0.1 * pi, roughly 0.0157); used to size the ghost strip.
velmax = 1.6e-2


def d1(t):
    """Return the height of the ball centre at time ``t``.

    The centre oscillates about z = 0.1 with amplitude 0.05 and angular
    frequency 0.1 * pi. ``t`` may also be a symbolic time parameter, in
    which case the expression is built symbolically.
    """
    omega = 0.1 * pi
    return 0.1 + 0.05 * cos(omega * t)


def levelset_func(h):
    """Level set of the ball whose centre sits on the axis at height ``h``.

    Equals radius minus the distance to (0, h): positive inside the ball,
    zero on its boundary and negative in the fluid.
    """
    radius = diam_ball / 2
    distance = sqrt(x**2 + (y - h)**2)
    return radius - distance


def v1(t):
    """Return the vertical velocity of the ball at time ``t``.

    This is the time derivative of ``d1``.
    """
    amplitude = 0.05 * 0.1 * pi
    return -amplitude * sin(0.1 * pi * t)


def circ_speed(t):
    """Return the rigid-body ball velocity (0, v1(t)) as a vector-valued
    coefficient function (no radial component)."""
    vertical = v1(t)
    return CoefficientFunction((0.0, vertical))


def sign(x):
    """Return +1 if ``x`` is non-negative, otherwise -1 (so sign(0) is 1)."""
    return 1 if x >= 0 else -1


def lset_center(t):
    """Signed level set of the horizontal plane through the ball centre.

    The sign is chosen from the direction of motion: the function is negative
    on the side behind the centre (below it when the ball moves up or rests,
    above it when the ball moves down).
    """
    direction = sign(v1(t))
    return direction * (y - d1(t))


# ------------------------------ BACKGROUND MESH ------------------------------
# Determine inner region
# The refined inner strip is 0 <= r <= 2/3 * diam_ball and spans the full
# height of the domain; ur_inner is its upper-right corner.
ur_inner = (diam_ball * 2 / 3, upperright[1])

# Corner points: lower-left, inner bottom-right, outer bottom-right,
# upper-right, inner top-right, upper-left
pnts = [lowerleft, (ur_inner[0], lowerleft[1]),
        (upperright[0], lowerleft[1]), upperright, ur_inner,
        (lowerleft[0], upperright[1])]

# Generate geometry. Domain 2 is the inner strip, domain 1 the outer region.
# Boundary labels: "wall" = bottom and outer right edge, "slip" = top edge,
# "rot" = the symmetry axis r = 0. The labels select the Dirichlet
# components of the velocity space.
geo = SplineGeometry()
p1, p2, p3, p4, p5, p6 = [geo.AppendPoint(*pnt) for pnt in pnts]
geo.Append(["line", p1, p2], leftdomain=2, rightdomain=0, bc="wall",
           maxh=h_max / bottom_factor)
geo.Append(["line", p2, p3], leftdomain=1, rightdomain=0, bc="wall")
geo.Append(["line", p3, p4], leftdomain=1, rightdomain=0, bc="wall")
geo.Append(["line", p4, p5], leftdomain=1, rightdomain=0, bc="slip")
geo.Append(["line", p5, p6], leftdomain=2, rightdomain=0, bc="slip")
geo.Append(["line", p6, p1], leftdomain=2, rightdomain=0, bc="rot")
# Internal edge separating the inner strip (left) from the outer region
geo.Append(["line", p2, p5], leftdomain=2, rightdomain=1)

# Generate mesh: outer region with h_max, inner strip refined by inner_factor
geo.SetDomainMaxH(1, h_max)
geo.SetDomainMaxH(2, h_max / inner_factor)

with TaskManager():
    mesh = Mesh(geo.GenerateMesh(quad_dominated=False))
print(" Meshing completed")


# --------------------------- FINITE ELEMENT SPACE ----------------------------
# Velocity space of order k with component-wise Dirichlet conditions: radial
# velocity fixed on "wall" and the axis "rot", vertical velocity fixed on
# "wall" and the top "slip" boundary. Pressure space of order k - 1.
V = VectorH1(mesh, order=k, dirichletx="wall|rot", dirichlety="wall|slip")
Q = H1(mesh, order=k - 1)
# dgjumps=True: presumably required so that the facet-patch ghost-penalty
# terms (coupling neighbouring elements) fit the matrix sparsity pattern
X = FESpace([V, Q], dgjumps=True)

# Marker of the free dofs; refilled every time-step in
# UpdateElementInformation
free_dofs = BitArray(X.ndof)

gfu = GridFunction(X)
vel, pre = gfu.components  # velocity and pressure parts of the solution


# ---------------------------- LEVELSET & CUT-INFO ----------------------------
# Main level set
# levelset_func expects the height of the ball centre, not a time. The
# initial geometry therefore uses d1(0.0); passing 0.0 directly would place
# the centre on the bottom wall z = 0 instead of at its initial height.
lset_meshadap = LevelSetMeshAdaptation(mesh, order=k, threshold=0.1,
                                       discontinuous_qn=True)
deformation = lset_meshadap.CalcDeformation(levelset_func(d1(0.0)))
# Deformations of the previous two time levels; the time loop uses them to
# transfer old velocities onto the current deformed mesh (shifted_eval)
deform_last = GridFunction(deformation.space, "deform-1")
deform_last2 = GridFunction(deformation.space, "deform-2")

# P1 level set of the mesh adaptation. It is never re-read in the time loop,
# so the script relies on CalcDeformation updating it in place.
lsetp1 = lset_meshadap.lset_p1

# Integration domains: fluid (negative part) and ball boundary (interface)
lset_neg = {"levelset": lsetp1, "domain_type": NEG, "subdivlvl": 0}
lset_if = {"levelset": lsetp1, "domain_type": IF, "subdivlvl": 0}

ci_main = CutInfo(mesh, lsetp1)

# Extension level sets: auxiliary P1 level set, re-interpolated in
# UpdateElementInformation, with its own cut information
lsetp1_ext = GridFunction(H1(mesh, order=1))
InterpolateToP1(levelset_func(d1(0.0)), lsetp1_ext)
ci_ext = CutInfo(mesh, lsetp1_ext)

# Initially r - diam_ball / 2, i.e. negative in the strip r < diam_ball / 2
# (used to build els["middle"]); later overwritten with lset_center(...)
lsetp1_center = GridFunction(H1(mesh, order=1))
InterpolateToP1(x - diam_ball / 2, lsetp1_center)
ci_center = CutInfo(mesh, lsetp1_center)


# ------------------------------ ELEMENT MARKERS ------------------------------
# Element markers (BitArrays over all mesh elements, all initially empty):
#   hasneg            - elements with a fluid part
#   if                - elements cut by the ball boundary
#   ext               - extension elements added around the ball
#   middle            - elements with a part in r < diam_ball / 2 (fixed)
#   active            - hasneg | ext
#   tmp               - scratch marker
#   act_old, act_old2 - active sets of the previous one / two time-steps
els, facets = {}, {}
for key in ["hasneg", "if", "ext", "middle", "active", "tmp", "act_old",
            "act_old2"]:
    els[key] = BitArray(mesh.ne)
    els[key].Clear()

# Facet markers (BitArrays over mesh.nedge; edges are the facets in 2D):
#   gp_ext, gp_stab - ghost-penalty facets; active - their union
#   none            - stays empty (empty facet restriction in ComputeDrag)
for key in ["active", "gp_stab", "gp_ext", "none"]:
    facets[key] = BitArray(mesh.nedge)
    facets[key].Clear()

# "middle" is computed once from the initial lsetp1_center (r - diam_ball/2)
with TaskManager():
    UpdateMarkers(els["middle"], ci_center.GetElementsOfType(NEG))


def UpdateElementInformation(bdf, it):
    """
    Recompute element, facet and dof markers for the current time level.

    Uses the module-level time parameter ``t`` and step ``dt``. Updates the
    entries of ``els`` and ``facets`` and the bit array ``free_dofs`` in place:

    * ``els["hasneg"]`` / ``els["if"]``: elements with a fluid part /
      elements cut by the ball boundary (from the current ``lsetp1``);
    * ``els["ext"]``: elements touching the ball enlarged by ``delta`` that
      contain fluid at ``t + dt`` or ``t + 2 dt``, restricted to elements
      reaching behind the ball centre (``lset_center``) and to
      ``els["middle"]``;
    * ``els["active"]`` = ``hasneg | ext``;
    * the ghost-penalty facets and the free velocity/pressure dofs.

    Parameters
    ----------
    bdf : int
        Order of the BDF scheme, 1 or 2. Selects how many previous active
        sets are passed to ``CheckElementHistory``.
    it : int
        Current time-step number, forwarded to ``CheckElementHistory``.

    Raises
    ------
    NotImplementedError
        If ``bdf`` is neither 1 nor 2. The check happens before any marker
        is modified.
    """
    if bdf not in (1, 2):
        raise NotImplementedError("Unimplemented BDF scheme requested")

    # Physical elements
    ci_main.Update(lsetp1)
    UpdateMarkers(els["hasneg"], ci_main.GetElementsOfType(HASNEG))
    UpdateMarkers(els["if"], ci_main.GetElementsOfType(IF))

    # Check history (BDF2 also needs the active set from two steps back)
    if bdf == 1:
        CheckElementHistory(it, mesh.ne, els["hasneg"], els["act_old"])
    else:
        CheckElementHistory(it, mesh.ne, els["hasneg"], els["act_old"],
                            els["act_old2"])

    t_now = t.Get()
    t_next, t_next2 = t_now + dt, t_now + 2 * dt

    # Extension strip width: twice the largest ball displacement over one
    # step, using the speed sampled at t, t + dt and t + 2 dt
    delta = 2 * max(abs(v1(t_now)), abs(v1(t_next)), abs(v1(t_next2))) * dt

    # Extension elements: elements touching the ball enlarged by delta ...
    InterpolateToP1(levelset_func(d1(t_now)) + delta, lsetp1_ext)
    ci_ext.Update(lsetp1_ext)
    UpdateMarkers(els["ext"], ci_ext.GetElementsOfType(HASPOS))

    # ... which contain fluid at t + dt or t + 2 dt ...
    InterpolateToP1(levelset_func(d1(t_next)), lsetp1_ext)
    ci_ext.Update(lsetp1_ext)
    UpdateMarkers(els["tmp"], ci_ext.GetElementsOfType(HASNEG))
    InterpolateToP1(levelset_func(d1(t_next2)), lsetp1_ext)
    ci_ext.Update(lsetp1_ext)
    els["tmp"] |= ci_ext.GetElementsOfType(HASNEG)

    els["ext"] &= els["tmp"]

    # ... which reach behind the ball centre w.r.t. the direction of motion
    # (the centre level set uses its own CutInfo object) ...
    InterpolateToP1(lset_center(t_next), lsetp1_center)
    ci_center.Update(lsetp1_center)
    UpdateMarkers(els["tmp"], ci_center.GetElementsOfType(HASNEG))
    InterpolateToP1(lset_center(t_next2), lsetp1_center)
    ci_center.Update(lsetp1_center)
    els["tmp"] |= ci_center.GetElementsOfType(HASNEG)

    els["ext"] &= els["tmp"]
    # ... and lie (partly) in the near-axis strip r < diam_ball / 2
    els["ext"] &= els["middle"]

    UpdateMarkers(els["active"], els["hasneg"] | els["ext"])

    # Ghost-penalty facets: extension facets (neighbours of type active and
    # ext) and stabilisation facets (neighbours of type hasneg and if)
    UpdateMarkers(facets["gp_ext"],
                  GetFacetsWithNeighborTypes(mesh, a=els["active"],
                                             b=els["ext"], use_and=True))
    UpdateMarkers(facets["gp_stab"],
                  GetFacetsWithNeighborTypes(mesh, a=els["hasneg"], b=els["if"],
                                             use_and=True))
    UpdateMarkers(facets["active"], facets["gp_ext"] | facets["gp_stab"])

    # Free dofs: velocity on all active elements, pressure only on elements
    # with a fluid part; the third argument presumably masks the result
    # with the Dirichlet-free dofs of X (see UpdateMarkers)
    UpdateMarkers(free_dofs,
                  CompoundBitArray([GetDofsOfElements(V, els["active"]),
                                    GetDofsOfElements(Q, els["hasneg"])]),
                  X.FreeDofs())

    return None


# --------------------------------- VARIABLES ---------------------------------
(u, p), (v, q) = X.TnT()                            # Trial and Test functions
h = specialcf.mesh_size                             # Mesh size cf.
n_mesh = specialcf.normal(mesh.dim)                 # Mesh normal vector (not used below)
# Unit normal of the level set; it points towards increasing level-set values,
# i.e. into the ball (levelset_func is positive inside), out of the fluid
n_lset = 1.0 / Norm(grad(lsetp1)) * grad(lsetp1)    # Level set normal vector

# Old solutions on their original deformation (gfu_last, gfu_last2) and
# transferred to the current deformed mesh (*_on_new_mesh); only the velocity
# parts vel_last / vel_last2 enter the BDF right-hand sides
gfu_last = GridFunction(X)                          # Gridfunction for u^{n-1}
gfu_last_on_new_mesh = GridFunction(X)              # Gridfunction for u^{n-1}
vel_last = gfu_last_on_new_mesh.components[0]
gfu_last2 = GridFunction(X)                         # Gridfunction for u^{n-2}
gfu_last2_on_new_mesh = GridFunction(X)             # Gridfunction for u^{n-2}
vel_last2 = gfu_last2_on_new_mesh.components[0]

# Symbolic time parameter: forms built from it follow t.Set(...) without
# being rebuilt
t = Parameter(0.0)                                  # Time variable
dt = 1 / dt_inv                                     # Time step
# ceil(dt**(-1/3)) warm-up steps covering [0, dt], so that
# dt_bdf1 = dt / bdf1_steps <= dt**(4/3)
bdf1_steps = ceil(dt / (dt**(4 / 3)))               # Number of BDF1 steps
dt_bdf1 = dt / bdf1_steps                           # BDF1 time step

delta = 2 * dt * velmax                             # Ghost strip-width
K_tilde = int(ceil(delta / h_max))                  # Strip-width in elements (scales ghost_penalty_ext)

mu_fl = viscosity_dyn                               # Dynamic viscosity
rho_fl = density_fl                                 # Fluid density


# Module-level force components; reassigned after each solve and read by
# CollectAndWriteOutput
drag_x, drag_y = 0.0, 0.0                           # Drag/Lift variables


# ----------------------------- (BI)LINEAR FORMS ------------------------------
mass = rho_fl * x * u * v

stokes = mu_fl * x * InnerProduct(Grad(u), Grad(v)) + mu_fl * u[0] * v[0] / x
stokes += - p * (v[0] + x * div(v)) - q * (u[0] + x * div(u))
stokes += - pReg * p * q

convect = rho_fl * x * InnerProduct(Grad(u) * vel, v)
convect_lin = rho_fl * x * InnerProduct(Grad(u) * vel, v)
convect_lin += rho_fl * x * InnerProduct(Grad(vel) * u, v)


nitsche = -mu_fl * x * InnerProduct(Grad(u) * n_lset, v)
nitsche += -mu_fl * x * InnerProduct(Grad(v) * n_lset, u)
nitsche += mu_fl * (gamma_n * k * k / h) * x * InnerProduct(u, v)
nitsche += p * x * InnerProduct(v, n_lset)
nitsche += q * x * InnerProduct(u, n_lset)


ghost_penalty_ext = gamma_e * K_tilde * (mu_fl + 1 / mu_fl) / h**2 \
                        * x * (u - u.Other()) * (v - v.Other())
ghost_penalty_stab = gamma_s * mu_fl / h**2 * x * (u - u.Other()) * (v - v.Other())
ghost_penalty_stab += -gamma_s / mu_fl * x * (p - p.Other()) * (q - q.Other())


bdf1_rhs = rho_fl * x * InnerProduct(vel_last, v)
bdf2_rhs = rho_fl * x * InnerProduct(2 * vel_last, v)
bdf2_rhs += - rho_fl * x * InnerProduct((1 / 2) * vel_last2, v)


nitsche_rhs = -mu_fl * x * InnerProduct(Grad(v) * n_lset, circ_speed(t))
nitsche_rhs += mu_fl * (gamma_n * k**2 / h) * x * InnerProduct(circ_speed(t), v)
nitsche_rhs += q * x * InnerProduct(circ_speed(t), n_lset)


# -------------------------------- INTEGRATORS --------------------------------
def InnerBFI(form, **kwargs):
    """Wrap ``form`` as a bilinear-form integrator over the fluid domain.

    The form is compiled according to ``compile_flag``/``wait_compile`` and
    integrated over the negative part of the level set (``lset_neg``). Extra
    keyword arguments are passed on to ``SymbolicBFI``.

    Returns ``(integrator, "inner")``; the string is the domain key
    (see ``integrator_markers``).
    """
    compiled = form.Compile(compile_flag, wait=wait_compile)
    integrator = SymbolicBFI(lset_neg, form=compiled, **kwargs)
    return integrator, "inner"


def BoundaryBFI(form, **kwargs):
    """Wrap ``form`` as a bilinear-form integrator on the ball boundary.

    The form is compiled according to ``compile_flag``/``wait_compile`` and
    integrated over the zero level set (``lset_if``). Extra keyword
    arguments are passed on to ``SymbolicBFI``.

    Returns ``(integrator, "boundary")``; the string is the domain key
    (see ``integrator_markers``).
    """
    compiled = form.Compile(compile_flag, wait=wait_compile)
    integrator = SymbolicBFI(lset_if, form=compiled, **kwargs)
    return integrator, "boundary"


def GhostPenaltyBFI(form, **kwargs):
    """Wrap ``form`` as a facet-patch (ghost-penalty) integrator.

    The form is compiled according to ``compile_flag``/``wait_compile``.
    Only the ``domain`` keyword is used: it names the facet marker key and is
    returned next to the integrator. All other keyword arguments are
    ignored; a missing ``domain`` raises ``KeyError``.
    """
    compiled = form.Compile(compile_flag, wait=wait_compile)
    integrator = SymbolicFacetPatchBFI(form=compiled, skeleton=False)
    return integrator, kwargs["domain"]


def InnerLFI(form, **kwargs):
    """Wrap ``form`` as a linear-form integrator over the fluid domain.

    The form is compiled according to ``compile_flag``/``wait_compile`` and
    integrated over the negative part of the level set (``lset_neg``). Extra
    keyword arguments are passed on to ``SymbolicLFI``.

    Returns ``(integrator, "inner")``.
    """
    compiled = form.Compile(compile_flag, wait=wait_compile)
    integrator = SymbolicLFI(lset_neg, form=compiled, **kwargs)
    return integrator, "inner"


def BoundaryLFI(form, **kwargs):
    """Wrap ``form`` as a linear-form integrator on the ball boundary.

    The form is compiled according to ``compile_flag``/``wait_compile`` and
    integrated over the zero level set (``lset_if``). Extra keyword
    arguments are passed on to ``SymbolicLFI``.

    Returns ``(integrator, "boundary")``.
    """
    compiled = form.Compile(compile_flag, wait=wait_compile)
    integrator = SymbolicLFI(lset_if, form=compiled, **kwargs)
    return integrator, "boundary"


def BuildNewtonSystem(integrators, integrators_lin, element_map, els_restr):
    """
    Set up the forms for one quasi-Newton solve.

    Parameters
    ----------
    integrators : list of (integrator, key) tuples
        Integrators of the nonlinear residual form and its right-hand side.
    integrators_lin : list of (integrator, key) tuples
        Integrators of the linearised (Jacobian) form.
    element_map : dict
        Maps each integrator key (e.g. "inner", "facets_ext") to the element
        or facet marker the integrator is restricted to.
    els_restr : dict
        ``{"elements": ..., "facet": ...}`` markers restricting both bilinear
        forms.

    Returns
    -------
    mStar, mStar_lin, f
        Residual bilinear form, Jacobian bilinear form and the linear form
        (not yet assembled).
    """
    mStar = RestrictedBilinearForm(
        X, element_restriction=els_restr["elements"],
        facet_restriction=els_restr["facet"], check_unused=False)
    mStar_lin = RestrictedBilinearForm(
        X, element_restriction=els_restr["elements"],
        facet_restriction=els_restr["facet"], check_unused=False)
    f = LinearForm(X)

    # Use the caller-supplied map instead of the module-level
    # integrator_markers, so the element_map argument is actually honoured.
    AddIntegratorsToForm(integrators=integrators, a=mStar, f=f,
                         element_map=element_map)
    AddIntegratorsToForm(integrators=integrators_lin, a=mStar_lin,
                         f=None, element_map=element_map)

    return mStar, mStar_lin, f


# Map from the domain keys returned by the *BFI/*LFI helpers to the element /
# facet markers the integrators are restricted to ("bottom" is not used by
# any integrator defined below)
integrator_markers = {"inner": els["hasneg"],
                      "boundary": els["if"],
                      "facets_ext": facets["gp_ext"],
                      "facets_if": facets["gp_stab"],
                      "bottom": None}
# Both bilinear forms are restricted to active elements / ghost-penalty facets
els_restr = {"elements": els["active"], "facet": facets["active"]}

integrators_bdf1, integrators_bdf1_lin = [], []
integrators_bdf2, integrators_bdf2_lin = [], []

# BDF1 (implicit Euler, step dt_bdf1): residual form and right-hand side
integrators_bdf1.append(InnerBFI(mass + dt_bdf1 * (stokes + convect)))
integrators_bdf1.append(BoundaryBFI(dt_bdf1 * nitsche))
integrators_bdf1.append(GhostPenaltyBFI(dt_bdf1 * ghost_penalty_ext,
                                        domain="facets_ext"))
integrators_bdf1.append(GhostPenaltyBFI(dt_bdf1 * ghost_penalty_stab,
                                        domain="facets_if"))
integrators_bdf1.append(InnerLFI(bdf1_rhs))
integrators_bdf1.append(BoundaryLFI(dt_bdf1 * nitsche_rhs))

# BDF1 Jacobian: same terms with the linearised convection, no linear form
integrators_bdf1_lin.append(InnerBFI(mass + dt_bdf1 * (stokes + convect_lin)))
integrators_bdf1_lin.append(BoundaryBFI(dt_bdf1 * nitsche))
integrators_bdf1_lin.append(GhostPenaltyBFI(dt_bdf1 * ghost_penalty_ext,
                                            domain="facets_ext"))
integrators_bdf1_lin.append(GhostPenaltyBFI(dt_bdf1 * ghost_penalty_stab,
                                            domain="facets_if"))

# BDF2 (step dt, mass scaled by 3/2): residual form and right-hand side
integrators_bdf2.append(InnerBFI(3 / 2 * mass + dt * (stokes + convect)))
integrators_bdf2.append(BoundaryBFI(dt * nitsche))
integrators_bdf2.append(GhostPenaltyBFI(dt * ghost_penalty_ext,
                                        domain="facets_ext"))
integrators_bdf2.append(GhostPenaltyBFI(dt * ghost_penalty_stab,
                                        domain="facets_if"))
integrators_bdf2.append(InnerLFI(bdf2_rhs))
integrators_bdf2.append(BoundaryLFI(dt * nitsche_rhs))

# BDF2 Jacobian
integrators_bdf2_lin.append(InnerBFI(3 / 2 * mass + dt * (stokes + convect_lin)))
integrators_bdf2_lin.append(BoundaryBFI(dt * nitsche))
integrators_bdf2_lin.append(GhostPenaltyBFI(dt * ghost_penalty_ext,
                                            domain="facets_ext"))
integrators_bdf2_lin.append(GhostPenaltyBFI(dt * ghost_penalty_stab,
                                            domain="facets_if"))


# -------------------------------- FUNCTIONALS --------------------------------
# r-weighted boundary traction mu * Grad(u) * n - p * n (built from the trial
# functions; ComputeDrag applies it to the current solution)
stress = x * mu_fl * Grad(u) * n_lset - x * p * n_lset

# Discrete test vectors with constant velocity (1, 0) and (0, 1); pairing them
# with the applied boundary form gives the r- and z-force components
drag_x_test, drag_y_test = GridFunction(X), GridFunction(X)
drag_x_test.components[0].Set(CoefficientFunction((1.0, 0.0)))
drag_y_test.components[0].Set(CoefficientFunction((0.0, 1.0)))
res = gfu.vec.CreateVector()  # work vector filled by ComputeDrag


def ComputeDrag():
    """
    Evaluate the force components acting on the ball for the current ``gfu``.

    The boundary traction form (``stress`` tested with ``v``) is applied to
    ``gfu`` on the cut elements only, which fills ``res``. It is then paired
    with the constant unit test vectors in r- and z-direction. The factor
    -2*pi presumably accounts for the azimuthal integration of the
    axisymmetric problem and for the orientation of ``n_lset``.

    Returns
    -------
    (drag_x, drag_y) : tuple of float
    """
    stress_form = RestrictedBilinearForm(X, element_restriction=els["if"],
                                         facet_restriction=facets["none"],
                                         check_unused=False)
    stress_form += SymbolicBFI(lset_if, form=InnerProduct(stress, v),
                               definedonelements=els["if"])
    stress_form.Apply(gfu.vec, res)

    scale = -2 * pi
    force = (scale * InnerProduct(res, drag_x_test.vec),
             scale * InnerProduct(res, drag_y_test.vec))

    del stress_form

    return force


# ---------------------------------- OUTPUT -----------------------------------
# Time series of all recorded quantities (appended by CollectAndWriteOutput)
functionals = {"time": [], "height": [], "vel": [], "drag_x": [],
               "drag_y": []}

# exist_ok avoids the check-then-create race of a separate isdir() test;
# a non-directory at out_dir still raises FileExistsError as before
os.makedirs(out_dir, exist_ok=True)

# Start a fresh output file containing only the header line. The with
# statement closes the file even if the write fails.
with open(out_dir + out_file, "w") as fid:
    fid.write("time\theight\tvel_ball\tdrag_r\tdrag_z\n")


def WriteToFile(str_out):
    """Append ``str_out`` verbatim to the text output file.

    The file ``out_dir + out_file`` is opened in append mode for each call.
    The ``with`` statement guarantees it is closed even if the write raises.
    """
    with open(out_dir + out_file, "a") as fid:
        fid.write(str_out)

    return None


def CollectAndWriteOutput():
    """
    Record the quantities of the current time level and append them to file.

    Appends time, ball height, ball velocity and the module-level force
    components ``drag_x``/``drag_y`` to ``functionals``. Then writes one
    tab-separated line: the time as ``6.4f`` and the other values (height,
    vel, drag_x, drag_y) as ``9.7e``.
    """
    now = t.Get()
    current = {"time": now,
               "height": d1(now),
               "vel": v1(now),
               "drag_y": drag_y,
               "drag_x": drag_x}
    for name, value in current.items():
        functionals[name].append(value)

    line = "{:6.4f}".format(current["time"])
    line += "".join("\t{:9.7e}".format(current[name])
                    for name in ("height", "vel", "drag_x", "drag_y"))
    line += "\n"

    WriteToFile(line)

    return None


if vtk_out:
    if not os.path.isdir(vtk_dir):
        os.makedirs(vtk_dir)

    # Velocity, pressure, P1 level set and mesh deformation. The initial state
    # is written here; the time loop calls vtk.Do() again.
    vtk = VTKOutput(ma=mesh, coefs=[vel, pre, lsetp1, deformation],
                    names=["velocity", "pressure", "lset", "deformation"],
                    filename=vtk_dir + vtk_file, subdivision=vtk_subdiv)
    vtk.Do()

    # One-off export of the background mesh (constant dummy field)
    vtk_mesh = VTKOutput(ma=mesh, coefs=[1], names=["const"],
                         filename=vtk_dir + vtk_file + "Mesh",
                         subdivision=0)
    vtk_mesh.Do()


# ------------------------------- TIME STEPPING -------------------------------
with TaskManager():

    # Homogeneous initial data u^0 = 0. gfu_last2 is not touched during the
    # BDF1 warm-up, so it still holds u^0 when the BDF2 phase starts.
    gfu.vec[:], gfu_last.vec[:], gfu_last2.vec[:] = 0, 0, 0

    # BDF1 warm up: bdf1_steps implicit Euler steps of size dt_bdf1, taking
    # the solution from t = 0 to t = bdf1_steps * dt_bdf1 = dt
    for it in range(1, bdf1_steps + 1):
        t.Set(it * dt_bdf1)

        # Store data from previous time-step
        gfu_last.vec.data = gfu.vec
        deform_last.vec.data = deformation.vec
        UpdateMarkers(els["act_old"], els["active"])

        # Mesh deformation for the new ball position; the old velocity is
        # transferred component-wise (r, z) from the old to the new deformed
        # mesh with shifted_eval
        deformation = lset_meshadap.CalcDeformation(levelset_func(d1(t)))
        for i in range(2):
            gfu_last_on_new_mesh.components[0].components[i].Set(
                shifted_eval(gfu_last.components[0].components[i],
                             deform_last, deformation))

        UpdateElementInformation(bdf=1, it=it)

        # Solve linearised system
        mStar, mStar_lin, f = BuildNewtonSystem(integrators_bdf1,
                                                integrators_bdf1_lin,
                                                integrator_markers, els_restr)
        f.Assemble()
        CutFEM_QuasiNewton(a=mStar, alin=mStar_lin, u=gfu, f=f.vec,
                           freedofs=free_dofs, maxit=maxit_newt,
                           maxerr=tol_newt, inverse=inverse,
                           jacobi_update_tol=jacobi_tol_newt, reuse=False)

        # Module-level values; read by CollectAndWriteOutput
        drag_x, drag_y = ComputeDrag()

        # Do output
        CollectAndWriteOutput()
        Redraw(blocking=True)

        print("t = {:10.6f}, height = {:6.4f}, vel_ball = {:5.3f} active_els"
              " = {:} - 1".format(t.Get(), d1(t.Get()), v1(t.Get()),
                                  sum(els["active"])))


    # VTK snapshot after the warm-up (t = dt). The projection keeps only the
    # free dofs of gfu (Projector with range=True). Note that it modifies
    # the solution itself, which is the starting point of the next step.
    if vtk_out and (1 * vtk_freq) % dt_inv == 0:
        projector = Projector(free_dofs, True)
        gfu.vec.data = projector * gfu.vec
        vtk.Do()

    # BDF2. The first BDF2 step shifts gfu_last into gfu_last2, so resetting
    # gfu_last to the (still zero) gfu_last2 makes u^{n-2} = u^0 = 0.
    # NOTE(review): at it = 2, deform_last2 receives the deformation stored at
    # the previous warm-up level rather than one belonging to u^0; this only
    # affects the transfer of the zero field u^0.
    gfu_last.vec.data = gfu_last2.vec

    # Steps t = 2 dt, 3 dt, ..., up to t_end (rounded to whole steps)
    for it in range(2, int(t_end * dt_inv + 0.5) + 1):
        t.Set(it * dt)

        # Store data from previous time-step (n-1 -> n-2, n -> n-1)
        gfu_last2.vec.data = gfu_last.vec
        gfu_last.vec.data = gfu.vec
        deform_last2.vec.data = deform_last.vec
        deform_last.vec.data = deformation.vec
        UpdateMarkers(els["act_old2"], els["act_old"])
        UpdateMarkers(els["act_old"], els["active"])

        # New deformation; transfer both old velocities onto it
        deformation = lset_meshadap.CalcDeformation(levelset_func(d1(t)))
        for i in range(2):
            gfu_last_on_new_mesh.components[0].components[i].Set(
                shifted_eval(gfu_last.components[0].components[i],
                             deform_last, deformation))
            gfu_last2_on_new_mesh.components[0].components[i].Set(
                shifted_eval(gfu_last2.components[0].components[i],
                             deform_last2, deformation))

        UpdateElementInformation(bdf=2, it=it)

        # Solve linearised system
        mStar, mStar_lin, f = BuildNewtonSystem(integrators_bdf2,
                                                integrators_bdf2_lin,
                                                integrator_markers, els_restr)
        f.Assemble()
        CutFEM_QuasiNewton(a=mStar, alin=mStar_lin, u=gfu, f=f.vec,
                           freedofs=free_dofs, maxit=maxit_newt,
                           maxerr=tol_newt, inverse=inverse,
                           jacobi_update_tol=jacobi_tol_newt, reuse=False)

        drag_x, drag_y = ComputeDrag()

        # Do output (with vtk_freq == dt_inv the VTK condition holds every step)
        CollectAndWriteOutput()
        Redraw(blocking=True)
        if vtk_out and (it * vtk_freq) % dt_inv == 0:
            projector = Projector(free_dofs, True)
            gfu.vec.data = projector * gfu.vec
            vtk.Do()

        print("t = {:10.6f}, height = {:6.4f}, vel_ball = {:5.3f} active_els"
              " = {:}".format(t.Get(), d1(t.Get()), v1(t.Get()),
                              sum(els["active"])))


# ------------------------------ POST-PROCESSING ------------------------------
end_time = time.time() - start_time

# Split the elapsed wall-clock seconds into days, hours, minutes and seconds
_days, _rest = divmod(end_time, 24 * 60 * 60)
_hours, _rest = divmod(_rest, 60 * 60)
_minutes, _seconds = divmod(_rest, 60)

print("\n----------- Total time: {:02.0f}:{:02.0f}:{:02.0f}:{:06.3f}"
      " ----------".format(_days, _hours, _minutes, _seconds))

