# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
@muladd begin
#! format: noindent

include("abstract_tree.jl")
include("serial_tree.jl")
include("parallel_tree.jl")

# Name of the mesh's type (without type parameters) as a `String`, e.g. "TreeMesh"
get_name(mesh::AbstractMesh) = string(nameof(typeof(mesh)))

# Composite type to hold the actual tree in addition to other mesh-related data
# that is not strictly part of the tree.
# The mesh is really just about the connectivity, size, and location of the individual
# tree nodes. Neighbor information for interfaces or for the large sides of mortars is
# solver-specific and might not be needed by all solvers (or might be needed in a
# different form). Also, these data can be performance critical, so a mesh would have to
# store them efficiently for all solvers. On the other hand, different solvers might use
# different cells of a shared mesh, so "efficient" is again solver dependent.
"""
    TreeMesh{NDIMS} <: AbstractMesh{NDIMS}

A Cartesian mesh based on trees of hypercubes to support adaptive mesh refinement.
"""
mutable struct TreeMesh{NDIMS, TreeType <: AbstractTree{NDIMS}, RealT <: Real} <:
               AbstractMesh{NDIMS}
    # Tree holding the individual cells of the mesh
    tree::TreeType
    # Empty on construction; presumably set to the file the mesh was last saved to
    current_filename::String
    # `true` on construction; presumably cleared once the mesh has been saved
    unsaved_changes::Bool
    # These are needed for distributed memory (i.e., MPI) parallelization. Both are empty
    # on construction (presumably filled when a parallel mesh is partitioned).
    first_cell_by_rank::OffsetVector{Int, Vector{Int}}
    n_cells_by_rank::OffsetVector{Int, Vector{Int}}

    # Build a mesh from the tree type's single-argument constructor; `n_cells_max` is
    # passed through unchanged (presumably as the tree's cell capacity).
    function TreeMesh{NDIMS, TreeType, RealT}(n_cells_max::Integer) where {
            NDIMS, TreeType <: AbstractTree{NDIMS}, RealT <: Real}
        return new(TreeType(n_cells_max), "", true,
                   OffsetVector(Int[], 0), OffsetVector(Int[], 0))
    end

    # TODO: Taal refactor, order of important arguments, use of n_cells_max?
    # Build a mesh whose tree is constructed from the maximum number of cells, the domain
    # center, the domain length, and the periodicity (all forwarded unchanged).
    function TreeMesh{NDIMS, TreeType, RealT}(n_cells_max::Integer,
                                              domain_center::AbstractArray{RealT},
                                              domain_length::RealT,
                                              periodicity = false) where {
            NDIMS, TreeType <: AbstractTree{NDIMS}, RealT <: Real}
        tree = TreeType(n_cells_max, domain_center, domain_length, periodicity)
        return new(tree, "", true, OffsetVector(Int[], 0), OffsetVector(Int[], 0))
    end
end

# Dimension-specific aliases: a `TreeMesh` of the given dimension with any tree type of
# matching dimension (the real type is left free)
const TreeMesh1D = TreeMesh{1, <:AbstractTree{1}}
const TreeMesh2D = TreeMesh{2, <:AbstractTree{2}}
const TreeMesh3D = TreeMesh{3, <:AbstractTree{3}}

# Aliases distinguishing meshes backed by a serial tree from those backed by a parallel tree
const TreeMeshSerial{NDIMS} = TreeMesh{NDIMS, <:SerialTree{NDIMS}}
const TreeMeshParallel{NDIMS} = TreeMesh{NDIMS, <:ParallelTree{NDIMS}}

# Trait reporting whether a `TreeMesh` uses MPI, decided purely by its tree type
@inline mpi_parallel(::TreeMeshSerial) = False()
@inline mpi_parallel(::TreeMeshParallel) = True()

# A serial mesh lives on a single process, so there is nothing to partition
function partition!(::TreeMeshSerial)
    return nothing
end

# Constructor for passing the dimension and mesh type as an argument: `NDIMS` is deduced
# from `TreeType <: AbstractTree{NDIMS}`, and all remaining positional arguments are
# forwarded to `TreeMesh{NDIMS, TreeType, RealT}`.
function TreeMesh(::Type{TreeType}, args...;
                  RealT = Float64) where {NDIMS, TreeType <: AbstractTree{NDIMS}}
    MeshType = TreeMesh{NDIMS, TreeType, RealT}
    return MeshType(args...)
end

# Constructor accepting a single number as center (as opposed to an array) for 1D.
# The scalar center is wrapped into a one-element `SVector` and forwarded to the
# array-based constructor. The maximum number of cells may be any `Integer`, which is
# consistent with the other `TreeMesh` constructors (previously only `Int` was accepted).
function TreeMesh{1, TreeType, RealT}(n_cells_max::Integer, center::RealT, len::RealT,
                                      periodicity = false) where {
        TreeType <: AbstractTree{1}, RealT <: Real}
    return TreeMesh{1, TreeType, RealT}(n_cells_max, SVector{1, RealT}(center), len,
                                        periodicity)
end

# Convenience constructor accepting the domain center as an `NTuple`: the tuple is
# converted to an `SVector` and forwarded together with the other arguments.
function TreeMesh{NDIMS, TreeType, RealT}(n_cells_max::Integer,
                                          domain_center::NTuple{NDIMS, RealT},
                                          domain_length::RealT,
                                          periodicity = false) where {
        NDIMS, TreeType <: AbstractTree{NDIMS}, RealT <: Real}
    center = SVector{NDIMS, RealT}(domain_center)
    return TreeMesh{NDIMS, TreeType, RealT}(n_cells_max, center, domain_length,
                                            periodicity)
end

"""
    TreeMesh(coordinates_min::NTuple{NDIMS, Real},
             coordinates_max::NTuple{NDIMS, Real};
             n_cells_max,
             periodicity = false,
             initial_refinement_level,
             refinement_patches = (),
             coarsening_patches = (),
             RealT = Float64) where {NDIMS}

Create a `TreeMesh` in `NDIMS` dimensions with real type `RealT` covering the domain defined by
`coordinates_min` and `coordinates_max`. The mesh is initialized with a uniform
refinement to the specified `initial_refinement_level`. Further refinement and
coarsening patches can be specified using `refinement_patches` and
`coarsening_patches`, respectively. The maximum number of cells allowed in the mesh is
given by `n_cells_max`. The periodicity in each dimension can be specified using the
`periodicity` argument (default: non-periodic in all dimensions). If it is a single
`Bool`, the same periodicity is applied in all dimensions; otherwise, a tuple of
`Bool`s of length `NDIMS` must be provided. Note that the domain must be a hypercube, i.e.,
all dimensions must have the same length.
"""
function TreeMesh(coordinates_min::NTuple{NDIMS, Real},
                  coordinates_max::NTuple{NDIMS, Real};
                  n_cells_max,
                  periodicity = false,
                  initial_refinement_level,
                  refinement_patches = (),
                  coarsening_patches = (),
                  RealT = Float64) where {NDIMS}
    # Validate the integer-valued keyword arguments before doing any work
    (n_cells_max isa Integer && n_cells_max > 0) ||
        throw(ArgumentError("`n_cells_max` must be a positive integer (provided `n_cells_max = $n_cells_max`)"))
    (initial_refinement_level isa Integer && initial_refinement_level >= 0) ||
        throw(ArgumentError("`initial_refinement_level` must be a non-negative integer (provided `initial_refinement_level = $initial_refinement_level`)"))

    # Coordinates of another real type are still accepted (they are converted below),
    # but a warning is emitted for every such entry
    for (i, (x_min, x_max)) in enumerate(zip(coordinates_min, coordinates_max))
        x_min isa RealT ||
            @warn("Element $i in `coordinates_min` is not of type $RealT (provided `coordinates_min[$i] = $(x_min)`)")
        x_max isa RealT ||
            @warn("Element $i in `coordinates_max` is not of type $RealT (provided `coordinates_max[$i] = $(x_max)`)")
    end

    coordinates_min_max_check(coordinates_min, coordinates_max)

    # TreeMesh requires equal domain lengths in all dimensions: every extent must match
    # the first one up to floating-point tolerance
    domain_center = @. convert(RealT, (coordinates_min + coordinates_max) / 2)
    domain_length = convert(RealT, coordinates_max[1] - coordinates_min[1])
    is_hypercube = all(2:NDIMS) do dim
        coordinates_max[dim] - coordinates_min[dim] ≈ domain_length
    end
    is_hypercube ||
        throw(ArgumentError("The TreeMesh domain must be a hypercube (provided `coordinates_max` .- `coordinates_min` = $(coordinates_max .- coordinates_min))"))

    # TODO: MPI, create nice interface for a parallel tree/mesh
    use_mpi = mpi_isparallel()
    if use_mpi && mpi_isroot() && NDIMS != 2
        println(stderr,
                "ERROR: The TreeMesh supports parallel execution with MPI only in 2 dimensions")
        MPI.Abort(mpi_comm(), 1)
    end
    TreeType = use_mpi ? ParallelTree{NDIMS, RealT} : SerialTree{NDIMS, RealT}

    # Create the mesh, then bring it into its initial refined/coarsened state
    # (which also partitions it among MPI ranks when running in parallel)
    mesh = @trixi_timeit timer() "creation" TreeMesh{NDIMS, TreeType, RealT}(n_cells_max,
                                                                             domain_center,
                                                                             domain_length,
                                                                             periodicity)
    initialize!(mesh, initial_refinement_level, refinement_patches, coarsening_patches)

    return mesh
end

# Bring a freshly constructed `TreeMesh` into its initial state. The steps are:
# 1. Refine uniformly to `initial_refinement_level`.
# 2. Apply each refinement patch in order. A patch of `type` "box" uses
#    `coordinates_min`/`coordinates_max`; a patch of `type` "sphere" uses `center`/`radius`.
# 3. Apply each coarsening patch in order; only `type` "box" is supported.
# 4. Partition the mesh among MPI ranks (a no-op for serial meshes).
# Any other patch type raises an error. Always returns `nothing`.
function initialize!(mesh::TreeMesh, initial_refinement_level,
                     refinement_patches, coarsening_patches)
    tree = mesh.tree

    @trixi_timeit timer() "initial refinement" refine_uniformly!(tree,
                                                                 initial_refinement_level)

    # TODO: Taal refactor, use multiple dispatch?
    @trixi_timeit timer() "refinement patches" for patch in refinement_patches
        patch_type = patch.type
        if patch_type == "box"
            refine_box!(tree, patch.coordinates_min, patch.coordinates_max)
        elseif patch_type == "sphere"
            refine_sphere!(tree, patch.center, patch.radius)
        else
            error("unknown refinement patch type '$(patch_type)'")
        end
    end

    # TODO: Taal refactor, use multiple dispatch
    @trixi_timeit timer() "coarsening patches" for patch in coarsening_patches
        patch_type = patch.type
        patch_type == "box" || error("unknown coarsening patch type '$(patch_type)'")
        coarsen_box!(tree, patch.coordinates_min, patch.coordinates_max)
    end

    partition!(mesh)

    return nothing
end

# 1D convenience method: wrap the scalar bounds into 1-tuples and forward all keyword
# arguments to the tuple-based constructor
TreeMesh(coordinates_min::Real, coordinates_max::Real; kwargs...) =
    TreeMesh((coordinates_min,), (coordinates_max,); kwargs...)

# One-line representation: the type parameters followed by the current number of cells
# stored in the tree (`mesh.tree.length`)
function Base.show(io::IO, mesh::TreeMesh{NDIMS, TreeType}) where {NDIMS, TreeType}
    print(io, "TreeMesh{", NDIMS, ", ", TreeType, "}")
    print(io, " with length ", mesh.tree.length)
    return nothing
end

# Multi-line REPL representation as a summary box. In compact contexts, the one-line
# `show` is used instead.
function Base.show(io::IO, ::MIME"text/plain",
                   mesh::TreeMesh{NDIMS, TreeType}) where {NDIMS, TreeType}
    get(io, :compact, false) && return show(io, mesh)

    tree = mesh.tree
    setup = [
        "center" => tree.center_level_0,
        "length" => tree.length_level_0,
        "periodicity" => tree.periodicity,
        "current #cells" => tree.length,
        "#leaf-cells" => count_leaf_cells(tree),
        "maximum #cells" => tree.capacity
    ]
    return summary_box(io, "TreeMesh{$NDIMS, $TreeType}", setup)
end

# The mesh has the same number of dimensions as its underlying tree
@inline function Base.ndims(mesh::TreeMesh)
    return ndims(mesh.tree)
end

# Obtain the mesh filename from a restart file (method used when `mpi_parallel` is
# `False()`). The file name stored in the restart file's "mesh_file" attribute is
# joined with the directory that contains the restart file.
function get_restart_mesh_filename(restart_filename, ::False)
    directory = first(splitdir(restart_filename))

    # Assigned inside the `do` block, which closes the file afterwards
    mesh_file = ""
    h5open(restart_filename, "r") do file
        mesh_file = read(attributes(file)["mesh_file"])
        return nothing
    end

    return joinpath(directory, mesh_file)
end

# Edge length of the level-0 cell raised to the number of dimensions of the mesh
total_volume(mesh::TreeMesh) = mesh.tree.length_level_0^ndims(mesh)

# Periodicity queries (with or without a specific `dimension`) are answered by the tree
function isperiodic(mesh::TreeMesh)
    return isperiodic(mesh.tree)
end
function isperiodic(mesh::TreeMesh, dimension)
    return isperiodic(mesh.tree, dimension)
end

# Real type of the mesh, i.e., the third type parameter of `TreeMesh`
function Base.real(::TreeMesh{N, T, R}) where {N, T, R}
    return R
end

include("parallel_tree_mesh.jl")
end # @muladd
