module TestElixirs

using LinearAlgebra
using SparseArrays
using Test
using Trixi
using OrdinaryDiffEqSSPRK: SSPRK43

import ForwardDiff

include("test_trixi.jl")

# Begin from a clean slate: if a Trixi.jl output directory is left over from an
# earlier run, delete it before any elixir is executed.
outdir = "out"
if isdir(outdir)
    rm(outdir, recursive = true)
end

# Root directory of the bundled example elixirs, shared by all testsets below
EXAMPLES_DIR = examples_dir()

@testset "Special elixirs" begin
#! format: noindent

@testset "Convergence test" begin
    # Each case runs `convergence_test` on an example elixir and compares the mean
    # experimental orders of convergence returned by `Trixi.calc_mean_convergence`
    # against the expected values (one entry per solution variable).
    @timed_testset "tree_2d_dgsem" begin
        eocs, _ = convergence_test(@__MODULE__,
                                   joinpath(EXAMPLES_DIR, "tree_2d_dgsem",
                                            "elixir_advection_extended.jl"),
                                   3, initial_refinement_level = 2)
        mean_convergence = Trixi.calc_mean_convergence(eocs)
        @test isapprox(mean_convergence[:l2], [4.0], rtol = 0.05)
    end

    @timed_testset "structured_2d_dgsem" begin
        eocs, _ = convergence_test(@__MODULE__,
                                   joinpath(EXAMPLES_DIR,
                                            "structured_2d_dgsem",
                                            "elixir_advection_extended.jl"),
                                   3, cells_per_dimension = (5, 9))
        mean_convergence = Trixi.calc_mean_convergence(eocs)
        @test isapprox(mean_convergence[:l2], [4.0], rtol = 0.05)
    end

    @timed_testset "structured_2d_dgsem coupled" begin
        eocs, _ = convergence_test(@__MODULE__,
                                   joinpath(EXAMPLES_DIR,
                                            "structured_2d_dgsem",
                                            "elixir_advection_coupled.jl"),
                                   3)
        # `semi` is not defined locally: it is the module-level binding created by
        # the elixir that `convergence_test` ran in `@__MODULE__`. `eocs` is
        # indexed per coupled system.
        for i in Trixi.eachsystem(semi)
            mean_convergence = Trixi.calc_mean_convergence(eocs[i])
            @test isapprox(mean_convergence[:l2], [4.0], rtol = 0.05)
        end
    end

    @timed_testset "p4est_2d_dgsem" begin
        # Run convergence test on unrefined mesh: this C callback never flags a
        # quadrant for refinement (it always returns 0)
        no_refine = @cfunction((p4est, which_tree, quadrant)->Cint(0), Cint,
                               (Ptr{Trixi.p4est_t}, Ptr{Trixi.p4est_topidx_t},
                                Ptr{Trixi.p4est_quadrant_t}))
        eocs, _ = convergence_test(@__MODULE__,
                                   joinpath(EXAMPLES_DIR, "p4est_2d_dgsem",
                                            "elixir_euler_source_terms_nonconforming_unstructured_flag.jl"),
                                   2, refine_fn_c = no_refine)
        mean_convergence = Trixi.calc_mean_convergence(eocs)
        @test isapprox(mean_convergence[:linf], [3.2, 3.2, 4.0, 3.7], rtol = 0.05)
    end

    @timed_testset "structured_3d_dgsem" begin
        eocs, _ = convergence_test(@__MODULE__,
                                   joinpath(EXAMPLES_DIR,
                                            "structured_3d_dgsem",
                                            "elixir_advection_basic.jl"),
                                   2, cells_per_dimension = (7, 4, 5))
        mean_convergence = Trixi.calc_mean_convergence(eocs)
        @test isapprox(mean_convergence[:l2], [4.0], rtol = 0.05)
    end

    @timed_testset "p4est_3d_dgsem" begin
        eocs, _ = convergence_test(@__MODULE__,
                                   joinpath(EXAMPLES_DIR, "p4est_3d_dgsem",
                                            "elixir_advection_unstructured_curved.jl"),
                                   2, initial_refinement_level = 0)
        mean_convergence = Trixi.calc_mean_convergence(eocs)
        @test isapprox(mean_convergence[:l2], [2.7], rtol = 0.05)
    end

    @timed_testset "paper_self_gravitating_gas_dynamics" begin
        eocs, _ = convergence_test(@__MODULE__,
                                   joinpath(EXAMPLES_DIR,
                                            "paper_self_gravitating_gas_dynamics",
                                            "elixir_eulergravity_convergence.jl"),
                                   2, tspan = (0.0, 0.25),
                                   initial_refinement_level = 1)
        mean_convergence = Trixi.calc_mean_convergence(eocs)
        @test isapprox(mean_convergence[:l2], 4 * ones(4), atol = 0.4)
    end
end

@timed_testset "Test linear structure (2D)" begin
    # For a linear elixir, `linear_structure` splits the semidiscretization into an
    # operator `A` and a vector `b`. The largest real part of the spectrum of `A`
    # must not exceed a small round-off tolerance. `semi` is the module-level
    # binding set by the most recent `trixi_include`.
    is_stable(A_op) = maximum(real, eigvals(Matrix(A_op))) < 10 * sqrt(eps(real(semi)))

    trixi_include(@__MODULE__,
                  joinpath(EXAMPLES_DIR, "tree_2d_dgsem", "elixir_advection_extended.jl"),
                  tspan = (0.0, 0.0), initial_refinement_level = 2)
    A_advection, _ = linear_structure(semi)
    @test is_stable(A_advection)

    trixi_include(@__MODULE__,
                  joinpath(EXAMPLES_DIR, "tree_2d_dgsem",
                           "elixir_hypdiff_lax_friedrichs.jl"),
                  tspan = (0.0, 0.0), initial_refinement_level = 2)
    A_hypdiff, b_hypdiff = linear_structure(semi)
    @test is_stable(A_hypdiff)

    # Overwriting `b` in place must not change the action of `A`
    u = vec(ode.u0)
    Au_reference = A_hypdiff * u
    b_hypdiff .= 2 .* b_hypdiff .+ u
    @test A_hypdiff * u ≈ Au_reference
end

@testset "Test Jacobian of DG (1D)" begin
    @timed_testset "TreeMesh: Linear advection" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "tree_1d_fdsbp",
                               "elixir_advection_upwind.jl"),
                      tspan = (0.0, 0.0))

        # Round-off tolerance for "non-positive" real parts and "vanishing"
        # imaginary parts of eigenvalues
        tol = 10 * sqrt(eps(real(semi)))
        A_op, _ = linear_structure(semi)

        # The problem is linear: both the AD and the finite-difference Jacobian
        # must reproduce the assembled operator.
        J_ad = jacobian_ad_forward(semi)
        @test Matrix(A_op) ≈ J_ad
        @test maximum(real, eigvals(J_ad)) < tol

        J_fd = jacobian_fd(semi)
        @test Matrix(A_op) ≈ J_fd
        λ_fd = eigvals(J_fd)
        @test maximum(real, λ_fd) < tol

        # See https://github.com/trixi-framework/Trixi.jl/pull/2514
        @test count(z -> real(z) >= -10, λ_fd) > 5
        # See https://github.com/trixi-framework/Trixi.jl/pull/2522
        t0 = zero(real(semi))
        u0_ode = 1e9 * compute_coefficients(t0, semi)
        λ_scaled = eigvals(jacobian_fd(semi; t0, u0_ode))
        @test !any(z -> -200 <= real(z) <= -10 && -100 <= imag(z) <= 100, λ_scaled)
        @test count(z -> isapprox(imag(z), 0.0; atol = tol), λ_scaled) == 2
    end

    @timed_testset "StructuredMesh: Compressible Euler equations" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "structured_1d_dgsem",
                               "elixir_euler_source_terms.jl"),
                      tspan = (0.0, 0.0))

        # Only the largest real part of the spectrum is bounded here; the
        # finite-difference Jacobian gets a looser bound than the AD one.
        jacobian_bounds = ((jacobian_ad_forward, 1e-13), (jacobian_fd, 5e-8))
        for (compute_jacobian, bound) in jacobian_bounds
            λ = eigvals(compute_jacobian(semi))
            @test maximum(real, λ) < bound
        end
    end
end

# Jacobians of 2D semidiscretizations obtained with `jacobian_ad_forward` (and, for
# some cases, `jacobian_fd` or `jacobian_ad_forward_parabolic`). Most cases bound
# the largest real part of the eigenvalues; the others only check that the
# Jacobian can be computed without warnings (`@test_nowarn`).
@testset "Test Jacobian of DG (2D)" begin
    @timed_testset "TreeMesh: Linear advection" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "tree_2d_dgsem",
                               "elixir_advection_extended.jl"),
                      tspan = (0.0, 0.0), initial_refinement_level = 2)
        # Both Jacobians must agree with the assembled linear operator `A`
        A, _ = linear_structure(semi)

        J = jacobian_ad_forward(semi)
        @test Matrix(A) ≈ J
        λ = eigvals(J)
        @test maximum(real, λ) < 10 * sqrt(eps(real(semi)))

        J = jacobian_fd(semi)
        @test Matrix(A) ≈ J
        λ = eigvals(J)
        @test maximum(real, λ) < 10 * sqrt(eps(real(semi)))
    end

    @timed_testset "TreeMesh: Linear advection-diffusion" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "tree_2d_dgsem",
                               "elixir_advection_diffusion.jl"),
                      tspan = (0.0, 0.0), initial_refinement_level = 2)

        J = jacobian_ad_forward(semi)
        λ = eigvals(J)
        @test maximum(real, λ) < 10 * sqrt(eps(real(semi)))

        # The assembled operator must agree exactly (`==`, not `≈`) with the AD
        # Jacobian, both densely and in sparse form
        A, b = linear_structure(semi)
        @test Matrix(A) == J
        @test sparse(A) == sparse(J)
        # Ensure that we do not have excessive memory allocations
        # (e.g., from type instabilities)
        du = zero(b)
        u = zero(b)
        mul!(du, A, u) # compilation run
        @test (@allocated mul!(du, A, u)) == 0

        J_parabolic = jacobian_ad_forward_parabolic(semi)
        λ_parabolic = eigvals(J_parabolic)
        # Parabolic spectrum is real and negative
        @test maximum(real, λ_parabolic) < 2 * 10^(-14)
        @test maximum(imag, λ_parabolic) < 10^(-14)
    end

    @timed_testset "TreeMesh: Compressible Euler equations" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "tree_2d_dgsem",
                               "elixir_euler_density_wave.jl"),
                      tspan = (0.0, 0.0), initial_refinement_level = 1)

        J = jacobian_ad_forward(semi)
        λ = eigvals(J)
        @test maximum(real, λ) < 7.0e-7

        # The finite-difference Jacobian is checked against a looser bound
        J = jacobian_fd(semi)
        λ = eigvals(J)
        @test maximum(real, λ) < 7.0e-3

        # This does not work yet because of the indicators...
        @test_skip begin
            trixi_include(@__MODULE__,
                          joinpath(EXAMPLES_DIR, "tree_2d_dgsem",
                                   "elixir_euler_shockcapturing.jl"),
                          tspan = (0.0, 0.0), initial_refinement_level = 1)
            jacobian_ad_forward(semi)
        end

        # NOTE(review): the two DGMulti testsets below are nested inside
        # "TreeMesh: Compressible Euler equations" although they read like
        # siblings of it; confirm the intended grouping.
        # They construct their own mesh, solver and `semi` instead of including
        # an elixir.
        @timed_testset "DGMulti: Euler, weak form" begin
            equations = CompressibleEulerEquations2D(1.4)
            initial_condition = initial_condition_density_wave

            solver = DGMulti(polydeg = 5, element_type = Quad(),
                             approximation_type = SBP(),
                             surface_integral = SurfaceIntegralWeakForm(flux_central),
                             volume_integral = VolumeIntegralWeakForm())

            # DGMultiMesh is on [-1, 1]^ndims by default
            cells_per_dimension = (2, 2)
            mesh = DGMultiMesh(solver, cells_per_dimension,
                               periodicity = (true, true))

            semi = SemidiscretizationHyperbolic(mesh, equations, initial_condition,
                                                solver;
                                                boundary_conditions = boundary_condition_periodic)

            J = jacobian_ad_forward(semi)
            λ = eigvals(J)
            @test maximum(real, λ) < 7.0e-7
        end

        # Same setup as above, but with a flux-differencing volume integral
        @timed_testset "DGMulti: Euler, SBP & flux differencing" begin
            equations = CompressibleEulerEquations2D(1.4)
            initial_condition = initial_condition_density_wave

            solver = DGMulti(polydeg = 5, element_type = Quad(),
                             approximation_type = SBP(),
                             surface_integral = SurfaceIntegralWeakForm(flux_central),
                             volume_integral = VolumeIntegralFluxDifferencing(flux_central))

            # DGMultiMesh is on [-1, 1]^ndims by default
            cells_per_dimension = (2, 2)
            mesh = DGMultiMesh(solver, cells_per_dimension,
                               periodicity = (true, true))

            semi = SemidiscretizationHyperbolic(mesh, equations, initial_condition,
                                                solver;
                                                boundary_conditions = boundary_condition_periodic)

            J = jacobian_ad_forward(semi)
            λ = eigvals(J)
            @test maximum(real, λ) < 7.0e-7
        end
    end

    @timed_testset "TreeMesh: Navier-Stokes" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "tree_2d_dgsem",
                               "elixir_navierstokes_taylor_green_vortex.jl"),
                      tspan = (0.0, 0.0), initial_refinement_level = 2)

        J = jacobian_ad_forward(semi)
        λ = eigvals(J)
        @test maximum(real, λ) < 0.2

        J_parabolic = jacobian_ad_forward_parabolic(semi)
        λ_parabolic = eigvals(J_parabolic)
        # Parabolic spectrum is real and negative
        @test maximum(real, λ_parabolic) < eps(Float64)
        @test maximum(imag, λ_parabolic) < 10^(-15)
    end

    # Only checks that the AD Jacobian can be computed without warnings
    @timed_testset "TreeMesh: MHD" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "tree_2d_dgsem",
                               "elixir_mhd_alfven_wave.jl"),
                      tspan = (0.0, 0.0), initial_refinement_level = 0)
        @test_nowarn jacobian_ad_forward(semi)
    end

    # Only checks that the AD Jacobian can be computed without warnings
    @timed_testset "UnstructuredMesh2D: Advection" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "unstructured_2d_dgsem",
                               "elixir_advection_basic.jl"),
                      tspan = (0.0, 0.0))

        @test_nowarn jacobian_ad_forward(semi)
    end

    @timed_testset "TreeMesh: EulerGravity" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR,
                               "paper_self_gravitating_gas_dynamics",
                               "elixir_eulergravity_convergence.jl"),
                      tspan = (0.0, 0.0), initial_refinement_level = 1)
        J = jacobian_ad_forward(semi)
        λ = eigvals(J)
        @test maximum(real, λ) < 1.5
    end

    @timed_testset "StructuredMesh: Polytropic Euler equations" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "structured_2d_dgsem",
                               "elixir_eulerpolytropic_wave.jl"),
                      cells_per_dimension = (6, 6),
                      tspan = (0.0, 0.0))

        J = jacobian_ad_forward(semi)
        λ = eigvals(J)
        @test maximum(real, λ) < 0.05
    end

    @timed_testset "P4estMesh: Navier-Stokes" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "p4est_2d_dgsem",
                               "elixir_navierstokes_viscous_shock.jl"),
                      tspan = (0.0, 0.0))

        J = jacobian_ad_forward(semi)
        λ = eigvals(J)
        @test maximum(real, λ) < 0.05

        J_parabolic = jacobian_ad_forward_parabolic(semi)
        λ_parabolic = eigvals(J_parabolic)
        # Parabolic spectrum is real and negative
        @test maximum(real, λ_parabolic) < 8e-14
        @test maximum(imag, λ_parabolic) < 8e-14
    end

    # Only checks that the AD Jacobian can be computed without warnings
    @timed_testset "T8codeMesh: Advection" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "t8code_2d_dgsem",
                               "elixir_advection_unstructured_flag.jl"),
                      tspan = (0.0, 0.0), initial_refinement_level = 0,
                      polydeg = 2)
        @test_nowarn jacobian_ad_forward(semi)
    end
end

@timed_testset "Test linear structure (3D)" begin
    trixi_include(@__MODULE__,
                  joinpath(EXAMPLES_DIR, "tree_3d_dgsem", "elixir_advection_extended.jl"),
                  tspan = (0.0, 0.0), initial_refinement_level = 1)

    # The largest real part of the spectrum of the linear operator must not
    # exceed a small round-off tolerance
    A_op, _ = linear_structure(semi)
    spectral_abscissa = maximum(real, eigvals(Matrix(A_op)))
    @test spectral_abscissa < 10 * sqrt(eps(real(semi)))
end

# Jacobians of 3D semidiscretizations. As in the 1D and 2D groups, the outer
# grouping is a plain `@testset`; only the individual cases are timed.
@testset "Test Jacobian of DG (3D)" begin
    @timed_testset "TreeMesh: Advection" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "tree_3d_dgsem",
                               "elixir_advection_extended.jl"),
                      tspan = (0.0, 0.0), initial_refinement_level = 1)
        A, _ = linear_structure(semi)

        # Both the AD and the finite-difference Jacobian must agree with the
        # assembled linear operator `A`
        J = jacobian_ad_forward(semi)
        @test Matrix(A) ≈ J

        J = jacobian_fd(semi)
        @test Matrix(A) ≈ J
    end

    # The remaining cases only check that the Jacobians can be computed without
    # warnings on coarse meshes with low polynomial degree
    @timed_testset "StructuredMesh: MHD" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "structured_3d_dgsem",
                               "elixir_mhd_alfven_wave.jl"),
                      cells_per_dimension = (2, 2, 2),
                      polydeg = 2,
                      tspan = (0.0, 0.0))

        @test_nowarn jacobian_ad_forward(semi)
    end

    @timed_testset "P4estMesh: Navier-Stokes" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "p4est_3d_dgsem",
                               "elixir_navierstokes_convergence.jl"),
                      initial_refinement_level = 0,
                      tspan = (0.0, 0.0))

        @test_nowarn jacobian_ad_forward(semi)
        @test_nowarn jacobian_ad_forward_parabolic(semi)
    end

    @timed_testset "T8CodeMesh: Advection" begin
        trixi_include(@__MODULE__,
                      joinpath(EXAMPLES_DIR, "t8code_3d_dgsem",
                               "elixir_advection_cubed_sphere.jl"),
                      polydeg = 2,
                      trees_per_face_dimension = 3, layers = 2,
                      tspan = (0.0, 0.0))

        @test_nowarn jacobian_ad_forward(semi)
    end
end

# Differentiate a scalar quantity of a full simulation with respect to the wave
# number `k` of the initial condition: `ForwardDiff.derivative` passes a dual number
# as `k`, and `uEltype = typeof(k)` makes the whole solution carry that type.
@testset "AD using ForwardDiff" begin
    @timed_testset "Euler equations 1D" begin
        function entropy_at_final_time(k) # k is the wave number of the initial condition
            equations = CompressibleEulerEquations1D(1.4)
            mesh = TreeMesh((-1.0,), (1.0,), initial_refinement_level = 3,
                            n_cells_max = 10^4, periodicity = true)
            solver = DGSEM(3, FluxHLL(min_max_speed_naive),
                           VolumeIntegralFluxDifferencing(flux_ranocha))
            initial_condition = (x, t, equations) -> begin
                rho = 2 + sinpi(k * sum(x))
                v1 = 0.1
                p = 10.0
                return prim2cons(SVector(rho, v1, p), equations)
            end
            semi = SemidiscretizationHyperbolic(mesh, equations, initial_condition,
                                                solver;
                                                boundary_conditions = boundary_condition_periodic,
                                                uEltype = typeof(k))
            ode = semidiscretize(semi, (0.0, 1.0))
            summary_callback = SummaryCallback()
            analysis_interval = 100
            analysis_callback = AnalysisCallback(semi, interval = analysis_interval)
            alive_callback = AliveCallback(analysis_interval = analysis_interval)
            callbacks = CallbackSet(summary_callback,
                                    analysis_callback,
                                    alive_callback)
            sol = solve(ode, SSPRK43(), callback = callbacks)
            return Trixi.integrate(entropy, sol.u[end], semi)
        end
        # Without `@test`, this comparison would be evaluated and its result
        # silently discarded, so the testset would never check the derivative
        @test ForwardDiff.derivative(entropy_at_final_time, 1.0) ≈ -0.4524664696235628
    end

    @timed_testset "Linear advection 2D" begin
        function energy_at_final_time(k) # k is the wave number of the initial condition
            equations = LinearScalarAdvectionEquation2D(0.2, -0.7)
            mesh = TreeMesh((-1.0, -1.0), (1.0, 1.0), initial_refinement_level = 3,
                            n_cells_max = 10^4, periodicity = true)
            solver = DGSEM(3, flux_lax_friedrichs)
            initial_condition = (x, t, equation) -> begin
                x_trans = Trixi.x_trans_periodic_2d(x - equation.advection_velocity * t)
                return SVector(sinpi(k * sum(x_trans)))
            end
            semi = SemidiscretizationHyperbolic(mesh, equations, initial_condition,
                                                solver;
                                                boundary_conditions = boundary_condition_periodic,
                                                uEltype = typeof(k))
            ode = semidiscretize(semi, (0.0, 1.0))
            summary_callback = SummaryCallback()
            analysis_interval = 100
            analysis_callback = AnalysisCallback(semi, interval = analysis_interval)
            alive_callback = AliveCallback(analysis_interval = analysis_interval)
            stepsize_callback = StepsizeCallback(cfl = 1.6)
            callbacks = CallbackSet(summary_callback,
                                    analysis_callback,
                                    alive_callback,
                                    stepsize_callback)
            # `dt = 1.0` is only a placeholder; `StepsizeCallback` sets the step
            # size from the CFL condition, so adaptivity is disabled here
            sol = solve(ode, CarpenterKennedy2N54(williamson_condition = false);
                        ode_default_options()..., adaptive = false, dt = 1.0,
                        callback = callbacks)
            return Trixi.integrate(energy_total, sol.u[end], semi)
        end
        # As above: the comparison must be wrapped in `@test` to be checked
        @test ForwardDiff.derivative(energy_at_final_time, 1.0) ≈ 1.4388628342896945e-5
    end

    @timed_testset "elixir_euler_ad.jl" begin
        @test_trixi_include(joinpath(EXAMPLES_DIR, "special_elixirs",
                                     "elixir_euler_ad.jl"))
    end
end
end

# Every elixir has finished: remove the Trixi.jl output directory they wrote into
@test_nowarn rm(outdir; recursive = true)

end #module
