"""
    __author__ = H. v. Wahl
    __date__ = 21.08.2020
    __update__ = 22.08.2020

    We consider a ball falling through a water-glycerine mixture as
    studied in [1]. To reduce the computational complexity of this
    three-dimensional problem, we consider the problem in cylindrical
    coordinates and assume that the solution is rotationally symmetric,
    i.e., invariant with respect to the angle in cylindrical coordinates.
    This reduces the three-dimensional problem to a two-dimensional one.
    
    We use a CutFEM discretisation with Taylor-Hood elements to 
    discretise the Navier-Stokes problem in space and a second order 
    Eulerian time-stepping scheme based on [2] in order to handle the 
    moving domain problem posed here in a fully Eulerian framework.
    The geometry error of order O(h^2) caused by the P1 level set
    approximation is overcome by using an isoparametric mapping [3].

    Fluid and solid motion are decoupled by using an explicit 
    time-stepping scheme to advance the solid motion, such that the
    domain is given explicitly in every time-step. We iterate within 
    every time-step such that the ODE is fulfilled implicitly at the end
    of every time-step.

    Literature:
    -----------
    [1] T. Hagemeier, D. Thévenin & T. Richter: Settling of spherical,
        non-wetting particles in a high viscous fluid.
    [2] H. von Wahl, T. Richter & C. Lehrenfeld: An unfitted Eulerian
        finite element method for the time-dependent Stokes problem on
        moving domains. arXiv:2002.02352v1 [math.NA].
    [3] C. Lehrenfeld: High order unfitted finite element methods on 
        level set domains using isoparametric mappings, CMAME 300:
        716-733, 2016.
"""
from netgen.geom2d import SplineGeometry
from ngsolve import *
from xfem import *
from xfem.lsetcurv import *

from CutFEM_utilities import CheckElementHistory, UpdateMarkers, \
    AddIntegratorsToForm
from Solvers import CutFEM_QuasiNewton

from math import ceil, pi
from numpy import sign
import os
import pickle
import argparse

import time
# Wall-clock reference; the total run time printed at the very end of the
# script is measured relative to this value.
start_time = time.time()

# Global NGSolve settings (values are hard-coded for the target machine)
ngsglobals.msg_level = 1                            # NGSolve message level
SetHeapSize(10000000)                               # NGSolve heap size
SetNumThreads(32)                                   # Threads used by TaskManager


# -------------------------------- PARAMETERS ---------------------------------
# Command-line interface. Every option is consumed unconditionally further
# down (mesh size, time step, file names, contact force), so a missing option
# would only surface later as an unrelated TypeError on a None value - for
# --gc possibly only once the ball reaches the contact zone. Marking the
# options as required makes argparse report the problem immediately.
parser = argparse.ArgumentParser()
parser.add_argument("--h", help="h_max", type=float, required=True)
parser.add_argument("--dti", help="dt_inv", type=int, required=True)
parser.add_argument("--d0", help="dist_0", type=float, required=True)
parser.add_argument("--gc", help="gamma_c", type=float, required=True)
args = parser.parse_args()
print(args)


# Mesh parameters
lowerleft, upperright = (0, 0), (0.055, 0.2)        # 2D r+/z - Domain (m)
h_max = args.h                                      # Mesh size (m)
inner_factor = 7                                    # Inner region with h/fac
bottom_factor = 1                                   # Inner bottom edge w/ h/fac

# Temporal parameters
t_end = 0.7                                         # Time to run until
dt_inv = args.dti                                   # Inverse time-step

# ODE parameters (sub-iteration coupling the ball velocity to the fluid drag)
tol_ode = 1e-8                                      # Tolerance for ODE Update
relax_ode = 0.5                                     # ODE update Relaxation
max_inner_it = 50                                   # Max nr of sub-iterations

# Contact parameters
dist0 = args.d0                                     # Distance for contact force
gamma_c = args.gc                                   # Contact force parameter

# Discretisation parameters
k = 2                                               # Order of velocity space
c_delta = 4                                         # Ghost-strip width param.
gamma_n = 100                                       # Nitsche parameter
gamma_s = 0.1                                       # Ghost-penalty stability
gamma_e = 0.01                                      # Ghost-penalty extension

# Solver parameters
pReg = 1e-8                                         # Pressure regularisation
inverse = "pardiso"                                 # Sparse direct solver
maxit_newt = 15                                     # Maximum Newton steps
tol_newt = 1e-10                                    # Residual tolerance
jacobi_tol_newt = 0.1                               # Jacobian update tolerance
compile_flag = True                                 # Compile forms
wait_compile = False                                # Complete compile first

# Output
# NOTE: out_dir already ends in "/" and the paths below prepend another "/",
# so the resulting paths contain a double slash.
out_dir = "output/"                                 # Output Directory
out_file = "PTFE6-iso-RotSym_hmax{}inner{}bottom{}dtinv{}BDF2_dist{}gc{}.txt".format(
    h_max, inner_factor, bottom_factor, dt_inv, dist0, gamma_c)  # File name for text output

# Pickle parameters
pickle_out = True                                   # Pickle data for restart
pickle_dir = out_dir + "/ptfe-iso-hmax{}dtinv{}/pickle_dist{:3.1e}gc{}/".format(
    h_max, dt_inv, dist0, gamma_c)                  # Pickle data directory
pickle_file = "iso-hmax{}inner{}bottom{}dtinv{}".format(
    h_max, inner_factor, bottom_factor, dt_inv)     # Pickle data file name
# A checkpoint is written in the BDF2 loop for every step `it` with
# (it * pickle_freq) % dt_inv == 0.
pickle_freq = int(dt_inv / 10)                      # Frequency of pickling

# Restart settings. The restart directory (dist/gc part) and the step number
# in the file name are hard-coded and have to be adapted by hand.
restart = False                                     # Restart from pickle data
restart_dir = out_dir + "/ptfe-iso-hmax{}dtinv{}/pickle_dist5.0e-04gc2.0".format(
    h_max, dt_inv)                                  # Restart data directory
restart_file = "/iso-hmax{}inner{}bottom{}dtinv{}_1180.data".format(
    h_max, inner_factor, bottom_factor, dt_inv)     # Data to restart from


# ----------------------------------- DATA ------------------------------------
# Initial state of the ball; the height refers to the lowest point of the
# ball (see levelset_func, which places the centre at height + radius).
initial_vel = 0.0                                   # Initial speed of ball
initial_height = 0.1616616                          # Initial bottom of ball (m)

diam_ball = 0.006                                   # Diameter of ball (m)
density_ball = 2122                                 # Density of ball (kg/m^3)

density_fl = 1141                                   # Density of fluid (kg/m^3)
viscosity_dyn = 0.008                               # Fluid viscosity (kg/m/s)

# Negative sign: gravity acts in the negative y (z) direction
g = -9.807                                          # Gravity field (m/s^2)


def levelset_func(height):
    """
    Level set function of the ball whose lowest point is at `height`.

    The returned expression is the ball radius minus the distance to the
    ball centre (0, height + radius); it is positive inside the ball and
    negative outside of it.
    """
    radius = diam_ball / 2
    centre_y = height + radius
    dist_to_centre = sqrt(x**2 + (y - centre_y)**2)
    return radius - dist_to_centre


def y_speed_vec(v):
    """
    Return the 2D vector coefficient function (0, v), i.e. a velocity of
    magnitude `v` pointing purely in the y-direction.
    """
    components = (0, v)
    return CoefficientFunction(components)


# ---------------------------------- PICKLE -----------------------------------
# Discretisation parameters stored with every checkpoint. When restarting,
# the stored values are compared against the current ones so that a run is
# never continued with different settings.
param = {"h_max": h_max, "inner_factor": inner_factor,
         "bottom_factor": bottom_factor, "dt_inv": dt_inv,
         "tol_ode": tol_ode, "relax_ode": relax_ode,
         "max_inner_it": max_inner_it, "k": k,
         "c_delta": c_delta, "gamma_n": gamma_n, "gamma_s": gamma_s,
         "gamma_e": gamma_e}


if pickle_out:
    # exist_ok=True replaces the former isdir()-then-makedirs() sequence,
    # which could fail with FileExistsError if another process (e.g. a
    # parallel parameter study) created the directory in between.
    os.makedirs(pickle_dir, exist_ok=True)

def PickleSolution(step):
    """
    Write a restart checkpoint for time-step `step`.

    The checkpoint is written to `pickle_dir + pickle_file + "_<step>.data"`
    and contains, in this order: the NGSolve version, the step number, the
    parameter dictionary `param`, the current and previous solution, the
    current and previous mesh deformation, the element and facet markers,
    the current and previous ball velocity and height, and the collected
    functionals. The order has to match the unpacking in the restart code.
    """
    file_out = pickle_dir + pickle_file + "_{}.data".format(step)

    data = [ngslib.__version__, step, param, gfu, gfu_last, deformation,
            deform_last, els, facets, vel_ball.Get(), vel_ball_last, height,
            height_last, functionals]

    # The context manager guarantees that the file is flushed and closed,
    # also if pickling one of the objects raises.
    with open(file_out, "wb") as fid_out:
        pickle.Pickler(fid_out).dump(data)

    return None

if restart:
    # Load pickle data into local variables
    # NOTE: unpickling can execute arbitrary code - only restart from
    # checkpoint files written by this script on a trusted machine.
    with open(restart_dir + restart_file, "rb") as fid_restart:
        # `unpickler` stays bound because it is deleted explicitly once the
        # data has been copied over in the time-stepping section.
        unpickler = pickle.Unpickler(fid_restart)
        _version_pic, _step_pic, _param_pic, _gfu_pic, _gfu_last_pic, \
            _deformation_pic, _deform_last_pic, _els_pic, _facets_pic, \
            _vel_ball_pic, _vel_ball_last_pic, _height_pic, _height_last_pic, \
            _functionals_pic = unpickler.load()

    if ngslib.__version__ != _version_pic:
        print("WARNING: Unpickled data from different NGSolve version")

    # Refuse to continue a run with different discretisation parameters.
    # An explicit raise is used so that the check also runs under `python -O`.
    for key, val in param.items():
        if val != _param_pic[key]:
            raise ValueError("Loaded data with different parameters")


# ------------------------------ BACKGROUND MESH ------------------------------
if not restart:
    # Determine inner region
    # The inner region is the strip 0 <= x <= 2/3 * diam_ball over the full
    # height of the domain; it is meshed with h_max / inner_factor.
    ur_inner = (diam_ball * 2 / 3, upperright[1])

    # Corner points in counter-clockwise order, starting at the lower left:
    # p1, p2 (bottom of strip boundary), p3 (lower right), p4 (upper right),
    # p5 (top of strip boundary), p6 (upper left).
    pnts = [lowerleft, (ur_inner[0], lowerleft[1]), 
            (upperright[0], lowerleft[1]), upperright, ur_inner,
            (lowerleft[0], upperright[1])]

    # Generate geometry
    # Domain 2 is the inner strip, domain 1 the remaining outer part.
    # Boundary names: "wall" on the bottom and the right side, "slip" on the
    # top and "rot" on the left side x = 0.
    geo = SplineGeometry()
    p1, p2, p3, p4, p5, p6 = [geo.AppendPoint(*pnt) for pnt in pnts]
    geo.Append(["line", p1, p2], leftdomain=2, rightdomain=0, bc="wall",
               maxh=h_max / bottom_factor)
    geo.Append(["line", p2, p3], leftdomain=1, rightdomain=0, bc="wall")
    geo.Append(["line", p3, p4], leftdomain=1, rightdomain=0, bc="wall")
    geo.Append(["line", p4, p5], leftdomain=1, rightdomain=0, bc="slip")
    geo.Append(["line", p5, p6], leftdomain=2, rightdomain=0, bc="slip")
    geo.Append(["line", p6, p1], leftdomain=2, rightdomain=0, bc="rot")
    # Internal interface separating the inner strip from the outer part
    geo.Append(["line", p2, p5], leftdomain=2, rightdomain=1)

    # Generate mesh
    geo.SetDomainMaxH(1, h_max)
    geo.SetDomainMaxH(2, h_max / inner_factor)

    with TaskManager():
        mesh = Mesh(geo.GenerateMesh(quad_dominated=False))
    print(" Meshing completed")

else:
    # Get mesh from pickled data to avoid version issues
    mesh = Mesh(_gfu_pic.space.mesh.ngmesh.Copy())
    print("Mesh loaded via pickle restart file")


# --------------------------- FINITE ELEMENT SPACE ----------------------------
# No-penetration on rotation axis is natural through r-scaling
# Velocity components are discretised separately so that they can carry
# different Dirichlet boundaries: the x-component is fixed on "wall" only,
# the y-component on "wall" and "slip". The pressure space is one order lower
# than the velocity spaces.
Vx = H1(mesh, order=k, dirichlet="wall")
Vy = H1(mesh, order=k, dirichlet="wall|slip")
Q  = H1(mesh, order=k - 1)
# NOTE(review): dgjumps=True is presumably required for the facet-based
# ghost-penalty couplings used below - confirm against the NGSolve docs.
X  = FESpace([Vx, Vy, Q], dgjumps=True)

# Marker of the dofs solved for; filled in UpdateElementInformation
free_dofs = BitArray(X.ndof)

# Current solution and short-hands for its velocity / pressure components
gfu = GridFunction(X)
vel_x, vel_y, pre = gfu.components
vel = CoefficientFunction((vel_x, vel_y))


# ---------------------------- LEVELSET & CUT-INFO ----------------------------
# Mesh deformation
# Isoparametric mesh adaptation for the level set; deform_last / deform_last2
# store the deformations of the two previous time-steps, which are needed to
# evaluate old solutions on the current mesh (see shifted_eval below).
lset_meshadap = LevelSetMeshAdaptation(mesh, order=k, threshold=0.1, 
                                       discontinuous_qn=True)
deformation = lset_meshadap.CalcDeformation(levelset_func(0.0))
deform_last = GridFunction(deformation.space, "deform-1")
deform_last2 = GridFunction(deformation.space, "deform-2")
lsetp1 = lset_meshadap.lset_p1

# Integration domains: NEG is the region where the level set is negative
# (outside the ball, see levelset_func), IF is the interface.
lset_neg = {"levelset": lsetp1, "domain_type": NEG, "subdivlvl": 0}
lset_if = {"levelset": lsetp1, "domain_type": IF, "subdivlvl": 0}

ci_main = CutInfo(mesh, lsetp1)

# Extension level sets
# Auxiliary P1 level sets; they are re-interpolated in every call of
# UpdateElementInformation to select the extension elements.
lsetp1_ext = GridFunction(H1(mesh, order=1))
InterpolateToP1(levelset_func(0.0), lsetp1_ext)
ci_ext = CutInfo(mesh, lsetp1_ext)

# Time-independent level set x - radius: negative for x < diam_ball / 2
lsetsp1_r = GridFunction(H1(mesh, order=1))
InterpolateToP1(x - diam_ball / 2, lsetsp1_r)
ci_centr = CutInfo(mesh, lsetsp1_r)

lsetsp1_inner = tuple(GridFunction(H1(mesh, order=1)) for i in range(2))
mlci_inner = MultiLevelsetCutInfo(mesh, lsetsp1_inner)
extend_both = False                          # Keep extension for sub iteration


# ------------------------------ ELEMENT MARKERS ------------------------------
# Element markers (one bit per mesh element) and facet markers (one bit per
# mesh edge); all of them start out cleared.
els = {}
facets = {}

for key in ("hasneg", "if", "ext", "middle", "active", "tmp1", "tmp2", "tb",
            "act_old", "act_old2"):
    marker = BitArray(mesh.ne)
    marker.Clear()
    els[key] = marker

for key in ("active", "gp_stab", "gp_ext", "none"):
    marker = BitArray(mesh.nedge)
    marker.Clear()
    facets[key] = marker

# The "middle" marker is time-independent: elements that are completely in
# the negative part of the level set x - diam_ball / 2.
with TaskManager():
    middle_elements = ci_centr.GetElementsOfType(NEG)
    UpdateMarkers(els["middle"], middle_elements)


def UpdateElementInformation(height, bdf, it):
    """
    Recompute element, facet and dof markers

    Updates, in place, the module-level marker arrays in `els` and `facets`
    as well as `free_dofs` for a ball whose lowest point is at `height`.
    Reads the module-level `delta`, `dt`, `dist0` and `vel_ball` and may set
    the module-level flag `extend_both`.

    Parameters
    ----------
    height : position of the lowest point of the ball, as expected by
        `levelset_func`.
    bdf : 1 or 2; selects how many sets of previously active elements are
        passed to `CheckElementHistory`.
    it : time-step index, forwarded to `CheckElementHistory`.

    Raises
    ------
    ValueError
        If `bdf` is neither 1 nor 2.
    """
    global extend_both

    # Physical elements
    # NOTE(review): lsetp1 is assumed to have been refreshed by the preceding
    # lset_meshadap.CalcDeformation call - confirm.
    ci_main.Update(lsetp1)
    UpdateMarkers(els["hasneg"], ci_main.GetElementsOfType(HASNEG))
    UpdateMarkers(els["if"], ci_main.GetElementsOfType(IF))

    # Check History
    if bdf == 1:
        CheckElementHistory(it, mesh.ne, els["hasneg"], els["act_old"])
    elif bdf == 2:
        CheckElementHistory(it, mesh.ne, els["hasneg"], els["act_old"],
                            els["act_old2"])
    else:
        # An unsupported argument value is a ValueError, not a SyntaxError
        raise ValueError("Unimplemented BDF scheme requested")

    # Extension elements
    # tmp1: elements touching the ball enlarged by delta in every direction
    InterpolateToP1(levelset_func(height) + delta, lsetp1_ext)
    ci_ext.Update(lsetp1_ext)
    UpdateMarkers(els["tmp1"], ci_ext.GetElementsOfType(HASPOS))

    # tmp2: elements lying inside both the ball shifted up by delta and the
    # ball shifted down by delta
    InterpolateToP1(levelset_func(height + delta), lsetsp1_inner[0])
    InterpolateToP1(levelset_func(height - delta), lsetsp1_inner[1])
    mlci_inner.Update(lsetsp1_inner)
    UpdateMarkers(els["tmp2"], mlci_inner.GetElementsOfType((POS,POS)))

    # tb: restriction of the extension to one side of the ball centre
    if (height + 3 * dt * vel_ball.Get() < dist0 or extend_both):
        # Ball about to enter the contact zone: no restriction
        UpdateMarkers(els["tb"], ci_ext.GetElementsOfType(ANY))
        extend_both = True  # Extend in all directions for sub-iterations
    else:
        # Half-plane through the ball centre, oriented by the sign of the
        # ball velocity. ci_ext is re-used here, so this has to come after
        # the tmp1 update above.
        InterpolateToP1(sign(vel_ball.Get()) * (y - height - diam_ball / 2),
                        lsetp1_ext)
        ci_ext.Update(lsetp1_ext)
        UpdateMarkers(els["tb"], ci_ext.GetElementsOfType(HASNEG))

    # NOTE(review): the three-argument form of UpdateMarkers presumably
    # combines its two source arrays with a logical AND - confirm in
    # CutFEM_utilities.
    UpdateMarkers(els["ext"],
                  els["tmp1"],
                  ~els["tmp2"] & els["middle"] & els["tb"])

    UpdateMarkers(els["active"], els["hasneg"] | els["ext"])

    # Ghost-penalty facets
    UpdateMarkers(facets["gp_ext"],
                  GetFacetsWithNeighborTypes(mesh, a=els["active"],
                                             b=els["ext"], use_and=True))
    UpdateMarkers(facets["gp_stab"],
                  GetFacetsWithNeighborTypes(mesh, a=els["hasneg"], b=els["if"],
                                             use_and=True))
    UpdateMarkers(facets["active"], facets["gp_ext"] | facets["gp_stab"])

    # Update degrees of freedom
    # Velocity dofs live on all active elements, pressure dofs only on the
    # physical (hasneg) elements.
    UpdateMarkers(free_dofs,
                  CompoundBitArray([GetDofsOfElements(Vx, els["active"]),
                                    GetDofsOfElements(Vy, els["active"]),
                                    GetDofsOfElements(Q, els["hasneg"])]),
                  X.FreeDofs())

    return None


# --------------------------------- VARIABLES ---------------------------------
(ux, uy, p), (vx, vy, q) = X.TnT()                  # Trial and Test functions
u = CoefficientFunction((ux, uy))                   # Velocity trial function
v = CoefficientFunction((vx, vy))                   # Velocity test function
# Divergence and gradients assembled by hand from the scalar components,
# since the velocity is built from two scalar spaces
div_u = grad(ux)[0] + grad(uy)[1]
div_v = grad(vx)[0] + grad(vy)[1]
grad_u = CoefficientFunction((grad(ux), grad(uy)), dims=(2, 2))
grad_v = CoefficientFunction((grad(vx), grad(vy)), dims=(2, 2))
grad_vel = CoefficientFunction((grad(vel_x), grad(vel_y)), dims=(2, 2))


h = specialcf.mesh_size                             # Mesh size cf.
n_mesh = specialcf.normal(mesh.dim)                 # Mesh normal vector
n_lset = 1.0 / Norm(grad(lsetp1)) * grad(lsetp1)    # Level set normal vector

# Solutions of the two previous time-steps. The "*_on_new_mesh" functions
# hold the old velocities transferred to the current mesh deformation via
# shifted_eval and are the ones entering the right-hand sides.
gfu_last = GridFunction(X)                          # Gridfunction for un-deformed
gfu_last_on_new_mesh = GridFunction(X)              # Gridfunction for u^{n-1}
vel_x_last = gfu_last_on_new_mesh.components[0]
vel_y_last = gfu_last_on_new_mesh.components[1]
vel_last = CoefficientFunction((vel_x_last, vel_y_last))

gfu_last2 = GridFunction(X)                         # Gridfunction for un-deformed
gfu_last2_on_new_mesh = GridFunction(X)             # Gridfunction for u^{n-2}
vel_x_last2 = gfu_last2_on_new_mesh.components[0]
vel_y_last2 = gfu_last2_on_new_mesh.components[1]
vel_last2 = CoefficientFunction((vel_x_last2, vel_y_last2))

vel_ball = Parameter(initial_vel)                   # Ball velocity parameter
delta = 0.0                                         # Ghost strip-width
K_tilde = Parameter(0)                              # Strip-width in elements

t = Parameter(0.0)                                  # Time variable
dt = 1 / dt_inv                                     # Time step
# ceil(dt / dt**(4/3)) = ceil(dt**(-1/3)), so that dt_bdf1 <= dt**(4/3)
bdf1_steps = ceil(dt / (dt**(4 / 3)))               # Number of BDF1 steps
dt_bdf1 = dt / bdf1_steps                           # BDF1 time step

mu_fl = viscosity_dyn                               # Dynamic viscosity
rho_fl = density_fl                                 # Fluid density
vol_ball = (4 / 3) * pi * (diam_ball / 2)**3        # Volume of ball
mass_ball = density_ball * vol_ball                 # Mass of ball (kg)
force_b = (density_ball - density_fl) * vol_ball * g  # Net buoyancy force

# Mutable state of the coupled problem, updated in the time loop
drag_x, drag_y = 0.0, 0.0                           # Drag/Lift variables
height = initial_height                             # Height variable
height_last = initial_height                        # Variable for last height
vel_ball_last, vel_ball_last2 = 0.0, 0.0              # Variable for last ball vel


# ----------------------------- (BI)LINEAR FORMS ------------------------------
# All forms carry the weight x. NOTE(review): x presumably plays the role of
# the radial coordinate r of the rotationally symmetric formulation described
# in the module docstring, with the 1/x and un-differentiated ux/vx terms
# stemming from the cylindrical-coordinate operators - confirm.
mass = rho_fl * x * u * v

stokes = mu_fl * x * InnerProduct(grad_u, grad_v) + mu_fl * ux * vx / x
stokes += - p * (vx + x * div_v) - q * (ux + x * div_u)
stokes += - pReg * p * q

# Nonlinear convection evaluated at the current iterate `vel` and its
# linearisation around `vel`
convect = rho_fl * x * InnerProduct(grad_u * vel, v)
convect_lin = rho_fl * x * InnerProduct(grad_u * vel, v)
convect_lin += rho_fl * x * InnerProduct(grad_vel * u, v)


# Nitsche terms on the level set interface (integrated via BoundaryBFI)
nitsche = -mu_fl * x * InnerProduct(grad_u * n_lset, v)
nitsche += -mu_fl * x * InnerProduct(grad_v * n_lset, u)
nitsche += mu_fl * (gamma_n * k * k / h) * x * InnerProduct(u, v)
nitsche += p * x * InnerProduct(v, n_lset)
nitsche += q * x * InnerProduct(u, n_lset)


# Ghost-penalty terms acting on velocity (extension and stability) and on
# the pressure (stability only); .Other() is the value on the neighbouring
# element of the facet patch
ghost_penalty_ext = gamma_e * K_tilde * (mu_fl + 1 / mu_fl) / h**2 \
                        * x * ((ux - ux.Other()) * (vx - vx.Other())
                                + (uy - uy.Other()) * (vy - vy.Other()))
ghost_penalty_stab = gamma_s * mu_fl / h**2\
                        * x * ((ux - ux.Other()) * (vx - vx.Other())
                                + (uy - uy.Other()) * (vy - vy.Other()))
ghost_penalty_stab += -gamma_s / mu_fl * x * (p - p.Other()) * (q - q.Other())


# Right-hand sides of the time discretisations; they pair with the mass
# factors 1 (BDF1) and 3/2 (BDF2) used when the integrators are set up
bdf1_rhs = rho_fl * x * InnerProduct(vel_last, v)
bdf2_rhs = rho_fl * x * InnerProduct(2 * vel_last, v)
bdf2_rhs += - rho_fl * x * InnerProduct((1 / 2) * vel_last2, v)


# Nitsche right-hand side for the boundary velocity (0, vel_ball) of the ball
nitsche_rhs = -mu_fl * x * InnerProduct(grad_v * n_lset, y_speed_vec(vel_ball))
nitsche_rhs += mu_fl * (gamma_n * k**2 / h) \
                * x * InnerProduct(y_speed_vec(vel_ball), v)
nitsche_rhs += q * x * InnerProduct(y_speed_vec(vel_ball), n_lset)


# -------------------------------- INTEGRATORS --------------------------------
def InnerBFI(form, **kwargs):
    """
    Wrap `form` into a bilinear form integrator on the negative level set
    domain and return it together with the domain tag "inner".

    Additional keyword arguments are forwarded to SymbolicBFI.
    """
    compiled_form = form.Compile(compile_flag, wait=wait_compile)
    integrator = SymbolicBFI(lset_neg, form=compiled_form, **kwargs)
    return integrator, "inner"


def BoundaryBFI(form, **kwargs):
    """
    Wrap `form` into a bilinear form integrator on the level set interface
    and return it together with the domain tag "boundary".

    Additional keyword arguments are forwarded to SymbolicBFI.
    """
    compiled_form = form.Compile(compile_flag, wait=wait_compile)
    integrator = SymbolicBFI(lset_if, form=compiled_form, **kwargs)
    return integrator, "boundary"


def GhostPenaltyBFI(form, **kwargs):
    """
    Wrap `form` into a facet-patch (ghost-penalty) integrator and return it
    together with the domain tag given by the keyword argument `domain`.

    `domain` is mandatory (KeyError otherwise); any other keyword argument
    is ignored.
    """
    compiled_form = form.Compile(compile_flag, wait=wait_compile)
    integrator = SymbolicFacetPatchBFI(form=compiled_form, skeleton=False)
    return integrator, kwargs["domain"]


def InnerLFI(form, **kwargs):
    """
    Wrap `form` into a linear form integrator on the negative level set
    domain and return it together with the domain tag "inner".

    Additional keyword arguments are forwarded to SymbolicLFI.
    """
    compiled_form = form.Compile(compile_flag, wait=wait_compile)
    integrator = SymbolicLFI(lset_neg, form=compiled_form, **kwargs)
    return integrator, "inner"


def BoundaryLFI(form, **kwargs):
    """
    Wrap `form` into a linear form integrator on the level set interface
    and return it together with the domain tag "boundary".

    Additional keyword arguments are forwarded to SymbolicLFI.
    """
    compiled_form = form.Compile(compile_flag, wait=wait_compile)
    integrator = SymbolicLFI(lset_if, form=compiled_form, **kwargs)
    return integrator, "boundary"


def BuildNewtonSystem(integrators, integrators_lin, element_map, els_restr):
    """
    Set up the forms of one (quasi-)Newton solve on the active mesh.

    Parameters
    ----------
    integrators : list of (integrator, domain tag) pairs of the nonlinear
        problem; added to the nonlinear form and the right-hand side via
        AddIntegratorsToForm.
    integrators_lin : list of (integrator, domain tag) pairs of the
        linearised problem.
    element_map : dict mapping domain tags to the marker arrays the
        integrators are restricted to.
    els_restr : dict with the keys "elements" and "facet" holding the
        element and facet restrictions of the bilinear forms.

    Returns
    -------
    (mStar, mStar_lin, f) : the nonlinear form, the linearised form and the
        (not yet assembled) linear form.
    """
    mStar = RestrictedBilinearForm(
        X, element_restriction=els_restr["elements"],
        facet_restriction=els_restr["facet"], check_unused=False)
    mStar_lin = RestrictedBilinearForm(
        X, element_restriction=els_restr["elements"],
        facet_restriction=els_restr["facet"], check_unused=False)
    f = LinearForm(X)

    # Use the map passed by the caller; previously the argument was ignored
    # and the module-level `integrator_markers` was used instead.
    AddIntegratorsToForm(integrators=integrators, a=mStar, f=f,
                         element_map=element_map)
    AddIntegratorsToForm(integrators=integrators_lin, a=mStar_lin,
                         f=None, element_map=element_map)

    return mStar, mStar_lin, f


# Domain tag -> marker array the corresponding integrators are restricted to
integrator_markers = {
    "inner": els["hasneg"],
    "boundary": els["if"],
    "facets_ext": facets["gp_ext"],
    "facets_if": facets["gp_stab"],
    "bottom": None,
}
# Element / facet restriction of the assembled bilinear forms
els_restr = {
    "elements": els["active"],
    "facet": facets["active"],
}

# BDF1 (warm-up steps with the smaller time-step dt_bdf1): nonlinear form
# and right-hand side
integrators_bdf1 = [
    InnerBFI(mass + dt_bdf1 * (stokes + convect)),
    BoundaryBFI(dt_bdf1 * nitsche),
    GhostPenaltyBFI(dt_bdf1 * ghost_penalty_ext, domain="facets_ext"),
    GhostPenaltyBFI(dt_bdf1 * ghost_penalty_stab, domain="facets_if"),
    InnerLFI(bdf1_rhs),
    BoundaryLFI(dt_bdf1 * nitsche_rhs),
]

# BDF1: linearised form
integrators_bdf1_lin = [
    InnerBFI(mass + dt_bdf1 * (stokes + convect_lin)),
    BoundaryBFI(dt_bdf1 * nitsche),
    GhostPenaltyBFI(dt_bdf1 * ghost_penalty_ext, domain="facets_ext"),
    GhostPenaltyBFI(dt_bdf1 * ghost_penalty_stab, domain="facets_if"),
]

# BDF2: nonlinear form and right-hand side
integrators_bdf2 = [
    InnerBFI(3 / 2 * mass + dt * (stokes + convect)),
    BoundaryBFI(dt * nitsche),
    GhostPenaltyBFI(dt * ghost_penalty_ext, domain="facets_ext"),
    GhostPenaltyBFI(dt * ghost_penalty_stab, domain="facets_if"),
    InnerLFI(bdf2_rhs),
    BoundaryLFI(dt * nitsche_rhs),
]

# BDF2: linearised form
integrators_bdf2_lin = [
    InnerBFI(3 / 2 * mass + dt * (stokes + convect_lin)),
    BoundaryBFI(dt * nitsche),
    GhostPenaltyBFI(dt * ghost_penalty_ext, domain="facets_ext"),
    GhostPenaltyBFI(dt * ghost_penalty_stab, domain="facets_if"),
]


# -------------------------------- FUNCTIONALS --------------------------------
# Stress expression on the interface in terms of the trial functions; it is
# evaluated for the current solution by applying the form to gfu.vec in
# ComputeDrag
stress = x * mu_fl * grad_u * n_lset - x * p * n_lset

# Test vectors that are constant one in the x- / y-velocity component; the
# inner product with the residual vector yields the force component
drag_x_test, drag_y_test = GridFunction(X), GridFunction(X)
drag_x_test.components[0].Set(CoefficientFunction(1.0))
drag_y_test.components[1].Set(CoefficientFunction(1.0))
res = gfu.vec.CreateVector()                        # Work vector for ComputeDrag


def ComputeDrag():
    """
    Evaluate the force of the fluid on the ball for the current solution.

    The interface stress form is applied to `gfu.vec` (result stored in the
    module-level work vector `res`) and tested with the constant-one test
    vectors of the two velocity components.

    Returns the tuple (drag_x, drag_y), each scaled by -2 * pi.
    NOTE(review): the factor 2 * pi presumably accounts for the integration
    over the angle of the rotationally symmetric problem - confirm.
    """
    interface_form = RestrictedBilinearForm(
        X, element_restriction=els["if"], facet_restriction=facets["none"],
        check_unused=False)
    interface_form += SymbolicBFI(lset_if, form=InnerProduct(stress, v),
                                  definedonelements=els["if"])
    interface_form.Apply(gfu.vec, res)

    scaling = -2 * pi
    force_x = scaling * InnerProduct(res, drag_x_test.vec)
    force_y = scaling * InnerProduct(res, drag_y_test.vec)

    return force_x, force_y


# ---------------------------------- OUTPUT -----------------------------------
# Time series of all recorded quantities; also stored in the checkpoints
functionals = {"time": [], "K_tilde": [], "height":[], "vel": [], "drag_x": [],
               "drag_y": []}

# exist_ok=True avoids the check-then-create race when several runs that
# share the output directory are started at the same time.
os.makedirs(out_dir, exist_ok=True)

# (Re-)create the text output file with the column header. The context
# manager closes the file even if writing fails.
with open(out_dir + "/" + out_file, "w") as fid:
    fid.write("time\tK_tilde\theight\tvel_ball\tdrag_y\n")


def WriteToFile(str_out):
    """
    Append the string `str_out` to the text output file
    `out_dir + "/" + out_file`.

    The file is opened and closed on every call, so the output on disk is
    complete after each time-step. The context manager makes sure the file
    handle is released even if the write raises.
    """
    with open(out_dir + "/" + out_file, "a") as fid:
        fid.write(str_out)

    return None


def CollectAndWriteOutput():
    """
    Record the current state in `functionals` and append it to the text
    output file.

    Stored are the time, the strip width K_tilde (as int), the ball height
    and velocity and both force components. The line written to file
    contains time, K_tilde, height, velocity and drag_y (tab-separated);
    drag_x is only stored in `functionals`.
    """
    current_values = (
        ("time", t.Get()),
        ("K_tilde", int(K_tilde.Get())),
        ("height", height),
        ("vel", vel_ball.Get()),
        ("drag_y", drag_y),
        ("drag_x", drag_x),
    )
    for name, value in current_values:
        functionals[name].append(value)

    str_out = "{:10.8f}\t{:1d}".format(functionals["time"][-1],
                                      functionals["K_tilde"][-1])
    str_out += "".join("\t{:9.7e}".format(functionals[val][-1])
                       for val in ["height", "vel", "drag_y"])
    str_out += "\n"

    WriteToFile(str_out)

    return None


# ------------------------------- VISUALISATION -------------------------------
# Draw velocity and pressure only where the level set is negative (-lsetp1
# positive, i.e. outside the ball); everywhere else the value is set to NaN.
nan_2d_cf = CoefficientFunction((float("NaN"), float("NaN")))
Draw(IfPos(-lsetp1, vel, nan_2d_cf), mesh, "u")
Draw(IfPos(-lsetp1, pre, CoefficientFunction(float("NaN"))), mesh, "p")


# ------------------------------- TIME STEPPING -------------------------------
# Main driver. Without restart, the first time-step of size dt is covered by
# bdf1_steps BDF1 steps of size dt_bdf1; afterwards BDF2 with step dt is used
# until t_end. Within every step, ball velocity / position and the fluid
# problem are iterated until the ball velocity changes by less than tol_ode.
with TaskManager():

    if not restart:
        gfu.vec[:], gfu_last.vec[:], gfu_last2.vec[:] = 0, 0, 0

        # BDF1 warm up
        for it in range(1, bdf1_steps + 1):
            t.Set(it * dt_bdf1)

            # Store data from previous time-step
            vel_ball_last, height_last = vel_ball.Get(), height
            gfu_last.vec.data = gfu.vec
            deform_last.vec.data = deformation.vec
            UpdateMarkers(els["act_old"], els["active"])

            # Sub-iteration
            for sub_it in range(max_inner_it):
                # Update deformation

                # Update velocity and position of ball
                # Relaxed fixed-point update of the BDF1-discretised ODE
                # m * v' = force_b + drag_y, using the most recent drag
                vel_ball_tmp = vel_ball.Get()
                vel_ball_upd = vel_ball_last + (dt_bdf1 / mass_ball) * (force_b + drag_y)
                vel_ball.Set(vel_ball_tmp * (1 - relax_ode) 
                            + relax_ode * vel_ball_upd)
                if abs(vel_ball.Get() - vel_ball_tmp) < tol_ode:
                    print("Used {} subiterations".format(sub_it))
                    break

                height = height_last + dt_bdf1 * vel_ball.Get()

                # Set up parameters and variables
                # Strip width delta and its size in (inner) mesh elements
                delta = c_delta * dt * abs(vel_ball.Get())
                K_tilde.Set(ceil(delta / (h_max / inner_factor)))

                # New mesh deformation for the new ball position; the old
                # velocity is transferred from the old to the new deformation
                deformation = lset_meshadap.CalcDeformation(levelset_func(height))
                for i in range(2):
                    gfu_last_on_new_mesh.components[i].Set(
                        shifted_eval(gfu_last.components[i],
                                     deform_last, deformation))

                UpdateElementInformation(height=height, bdf=1, it=it)

                # Solve linearised system
                mesh.SetDeformation(deformation)
                mStar, mStar_lin, f = BuildNewtonSystem(integrators_bdf1,
                                                        integrators_bdf1_lin,
                                                        integrator_markers,
                                                        els_restr)
                f.Assemble()
                CutFEM_QuasiNewton(a=mStar, alin=mStar_lin, u=gfu, f=f.vec,
                                   freedofs=free_dofs, maxit=maxit_newt,
                                   maxerr=tol_newt, inverse=inverse,
                                   jacobi_update_tol=jacobi_tol_newt, reuse=False)

                # Drag has to be evaluated on the deformed mesh
                drag_x, drag_y = ComputeDrag()
                mesh.UnsetDeformation()

            else:
                # Reached only if the loop was not left via break
                print("WARNING: Subiteration may not have converged")
                print("Last update = {:1.2e}".format(vel_ball.Get() - vel_ball_tmp))

            # Do output
            CollectAndWriteOutput()
            Redraw()

            print("t = {:10.6f}, height = {:6.4f}, vel_ball = {:5.3f} active_els"
                  " = {:}, K = {:d} - 1".format(t.Get(), height, vel_ball.Get(),
                                                sum(els["active"]),
                                                int(K_tilde.Get())))
    else:
        # Copy data into necessary variables
        last_step = _step_pic
        gfu.vec.data = _gfu_pic.vec
        gfu_last.vec.data = _gfu_last_pic.vec
        deformation.vec.data = _deformation_pic.vec
        deform_last.vec.data = _deform_last_pic.vec
        for key, array in els.items():
            UpdateMarkers(array, _els_pic[key])
        for key, array in facets.items():
            UpdateMarkers(array, _facets_pic[key])
        vel_ball.Set(_vel_ball_pic)
        vel_ball_last = _vel_ball_last_pic
        height, height_last = _height_pic, _height_last_pic
        for key, vals in _functionals_pic.items():
            functionals[key] = vals
        drag_y = functionals["drag_y"][-1]

        # Write functionals to file
        # The output file has been re-created above, so the history stored
        # in the checkpoint is written out again first
        for i in range(len(_functionals_pic["time"])):
            str_out = "{:10.8f}\t{:1d}".format(functionals["time"][i], 
                                              functionals["K_tilde"][i])
            for val in ["height", "vel", "drag_y"]:
                str_out += "\t{:9.7e}".format(functionals[val][i])
            str_out += "\n"

            WriteToFile(str_out)

        # Remove unnecessary (more memory intensive) variables
        del unpickler, _param_pic, _gfu_last_pic, _els_pic, \
            _facets_pic, _functionals_pic, _gfu_pic, 


    # BDF2
    if not restart:
        # Reset initial condition
        # After the warm-up the state is at t = dt; the "last" quantities
        # are reset to the initial data at t = 0 (gfu_last2 is still zero),
        # so that the first BDF2 step (it = 2) uses t = 0 and t = dt.
        last_step = 1
        vel_ball_last, height_last = initial_vel, initial_height
        gfu_last.vec.data = gfu_last2.vec

    for it in range(last_step + 1, int(t_end * dt_inv + 0.5) + 1):
        t.Set(it * dt)

        # Store data from previous time-step
        # Shift the history: n-1 -> n-2 and n -> n-1
        vel_ball_last2, height_last2 = vel_ball_last, height_last
        vel_ball_last, height_last = vel_ball.Get(), height
        gfu_last2.vec.data = gfu_last.vec
        gfu_last.vec.data = gfu.vec
        deform_last2.vec.data = deform_last.vec
        deform_last.vec.data = deformation.vec
        UpdateMarkers(els["act_old2"], els["act_old"])
        UpdateMarkers(els["act_old"], els["active"])

        # Sub-iteration
        extend_both = False
        for sub_it in range(max_inner_it):
            # Check contact force condition
            # BDF2 prediction of the height with the current ball velocity;
            # the contact force is only active below the distance dist0
            height = 2 / 3 * (2 * height_last - 1 / 2 * height_last2
                              + dt * vel_ball.Get())
            if height < 0:
                print("WARNING: Part of ball outside of physical domain!")
            if height < dist0:
                force_c = gamma_c * (dist0 - height) / height
            else:
                force_c = 0

            # Update velocity and position of ball
            # Relaxed fixed-point update of the BDF2-discretised ODE
            # m * v' = force_b + drag_y + force_c
            vel_ball_tmp = vel_ball.Get()
            vel_ball_upd = 2 / 3 * (2 * vel_ball_last - 1 / 2 * vel_ball_last2
                            + dt / mass_ball * (force_b + drag_y + force_c))
            vel_ball.Set((1 - relax_ode) * vel_ball_tmp 
                         + relax_ode * vel_ball_upd)

            if abs(vel_ball.Get() - vel_ball_tmp) < tol_ode:
                print("Used {} subiterations".format(sub_it))
                break

            height = 2 / 3 * (2 * height_last - 1 / 2 * height_last2
                              + dt * vel_ball.Get())

            # Set up parameters and variables
            # The strip width is based on the larger (in modulus) of the
            # current velocity and 2 * current - last velocity
            v_delta = max(list(map(abs, [vel_ball.Get(), 2*vel_ball.Get() - vel_ball_last])))
            delta = c_delta * dt * v_delta
            K_tilde.Set(ceil(delta / (h_max / inner_factor)))
            
            # New mesh deformation; both old velocities are transferred from
            # their own deformation to the new one
            deformation = lset_meshadap.CalcDeformation(levelset_func(height))
            for i in range(2):
                gfu_last_on_new_mesh.components[i].Set(
                    shifted_eval(gfu_last.components[i], 
                                 deform_last, deformation))
                gfu_last2_on_new_mesh.components[i].Set(
                    shifted_eval(gfu_last2.components[i], 
                                 deform_last2, deformation))

            UpdateElementInformation(height=height, bdf=2, it=it)

            # Solve linearised system
            mesh.SetDeformation(deformation)
            mStar, mStar_lin, f = BuildNewtonSystem(integrators_bdf2,
                                                    integrators_bdf2_lin,
                                                    integrator_markers,
                                                    els_restr)
            f.Assemble()
            CutFEM_QuasiNewton(a=mStar, alin=mStar_lin, u=gfu, f=f.vec,
                               freedofs=free_dofs, maxit=maxit_newt,
                               maxerr=tol_newt, inverse=inverse,
                               jacobi_update_tol=jacobi_tol_newt, reuse=False)

            # Drag has to be evaluated on the deformed mesh
            drag_x, drag_y = ComputeDrag()
            mesh.UnsetDeformation()

        else:
            # Reached only if the loop was not left via break
            print("WARNING: Subiteration may not have converged")
            print("Last update = {:1.2e}".format(vel_ball.Get() - vel_ball_tmp))

        # Do output
        CollectAndWriteOutput()
        Redraw()
        
        # Write a restart checkpoint (see pickle_freq)
        if pickle_out and (it * pickle_freq) % dt_inv == 0:
            PickleSolution(it)

        print("t = {:10.6f}, height = {:6.4f}, vel_ball = {:5.3f} active_els"
              " = {:}, K = {:d}".format(t.Get(), height, vel_ball.Get(),
                                        sum(els["active"]), int(K_tilde.Get())))


# ------------------------------ POST-PROCESSING ------------------------------
# Elapsed wall-clock time in seconds since the start of the script
end_time = time.time() - start_time

# Split the elapsed time into days, hours, minutes and seconds
seconds_per_day = 24 * 60 * 60
elapsed_days = end_time // seconds_per_day
elapsed_hours = end_time % seconds_per_day // (60 * 60)
elapsed_minutes = end_time % 3600 // 60
elapsed_seconds = end_time % 60

print("\n----------- Total time: {:02.0f}:{:02.0f}:{:02.0f}:{:06.3f}"
      " ----------".format(elapsed_days, elapsed_hours, elapsed_minutes,
                           elapsed_seconds))
