using Oceananigans.Operators: assumed_field_location

#####
##### Default boundary conditions
#####

"""
    DefaultBoundaryCondition(boundary_condition = NoFluxBoundaryCondition())

Placeholder for a boundary condition that has not been specified explicitly.
The wrapped `boundary_condition` is what `default_prognostic_bc` returns for
`Center`-located fields in `Bounded`, `LeftConnected`, and `RightConnected` directions.
"""
struct DefaultBoundaryCondition{BC}
    boundary_condition::BC
end

function DefaultBoundaryCondition()
    return DefaultBoundaryCondition(NoFluxBoundaryCondition())
end

# Periodic, fully-connected and flat directions get a fixed default for any location
# (the `Nothing`-location methods further below take precedence over these).
default_prognostic_bc(::Grids.Periodic, loc, default) = PeriodicBoundaryCondition()
default_prognostic_bc(::FullyConnected, loc, default) = CommunicationBoundaryCondition()
default_prognostic_bc(::Flat,           loc, default) = nothing

# Bounded and partially-connected directions: `Center`-located fields use the condition
# wrapped by the `DefaultBoundaryCondition` placeholder...
default_prognostic_bc(::Union{Bounded, LeftConnected, RightConnected}, ::Center, default) =
    default.boundary_condition

# TODO: make model constructors enforce impenetrability on velocity components to simplify this code
# ...while `Face`-located fields are made impenetrable.
default_prognostic_bc(::Union{Bounded, LeftConnected, RightConnected}, ::Face, default) =
    ImpenetrableBoundaryCondition()

# A `Nothing` location in a direction means no boundary condition there, whatever the topology.
# Periodic, FullyConnected and Flat need their own methods so that they are more specific
# than the location-agnostic methods at the top.
default_prognostic_bc(::Union{Bounded, LeftConnected, RightConnected}, ::Nothing, default) = nothing
default_prognostic_bc(::Flat,           ::Nothing, default) = nothing
default_prognostic_bc(::Grids.Periodic, ::Nothing, default) = nothing
default_prognostic_bc(::FullyConnected, ::Nothing, default) = nothing

# Auxiliary fields reuse the prognostic defaults with a no-flux placeholder...
default_auxiliary_bc(topo, loc) = default_prognostic_bc(topo, loc, DefaultBoundaryCondition())

# ...except that `Face`-located auxiliary fields get no boundary condition (instead of an
# impenetrable one) on bounded and partially-connected boundaries.
default_auxiliary_bc(::Union{Bounded, RightConnected, LeftConnected}, ::Face) = nothing

#####
##### Field boundary conditions
#####

# Boundary conditions of a field on its six domain boundaries plus an immersed boundary.
# Each boundary has its own type parameter, so entries may be `nothing`, placeholders,
# or concrete boundary conditions independently of one another.
mutable struct FieldBoundaryConditions{W, E, S, N, B, T, I}
    west::W       # x-direction, left end (i = 1)
    east::E       # x-direction, right end (i = Nx)
    south::S      # y-direction, left end (j = 1)
    north::N      # y-direction, right end (j = Ny)
    bottom::B     # z-direction, left end (k = 1)
    top::T        # z-direction, right end (k = Nz)
    immersed::I   # solid/fluid interface of an immersed boundary
end

# Build boundary conditions for a (possibly windowed) field: in every direction whose
# index is a `UnitRange` rather than `:`, both boundary conditions are replaced by `nothing`.
function FieldBoundaryConditions(indices::Tuple, west, east, south, north, bottom, top, immersed)
    x_pair = window_boundary_conditions(indices[1], west, east)
    y_pair = window_boundary_conditions(indices[2], south, north)
    z_pair = window_boundary_conditions(indices[3], bottom, top)

    return FieldBoundaryConditions(x_pair..., y_pair..., z_pair..., immersed)
end

# Re-window an existing set of boundary conditions for `indices`.
# Fields are read explicitly (in struct order) rather than by splatting a generator over
# `fieldnames`, which keeps the call type-stable.
FieldBoundaryConditions(indices::Tuple, bcs::FieldBoundaryConditions) =
    FieldBoundaryConditions(indices, bcs.west, bcs.east, bcs.south, bcs.north,
                            bcs.bottom, bcs.top, bcs.immersed)

# A full (`:`) index keeps the pair of boundary conditions unchanged.
function window_boundary_conditions(::Colon, left_bc, right_bc)
    return (left_bc, right_bc)
end

# A windowed (`UnitRange`) index discards both boundary conditions.
function window_boundary_conditions(::UnitRange, left_bc, right_bc)
    return (nothing, nothing)
end

"""
    FieldBoundaryConditions(; kwargs...)

Return a template for boundary conditions on prognostic fields.

Keyword arguments
=================

Keyword arguments specify boundary conditions on the 7 possible boundaries:

- `west`: left end point in the `x`-direction where `i = 1`
- `east`: right end point in the `x`-direction where `i = grid.Nx`
- `south`: left end point in the `y`-direction where `j = 1`
- `north`: right end point in the `y`-direction where `j = grid.Ny`
- `bottom`: right end point in the `z`-direction where `k = 1`
- `top`: right end point in the `z`-direction where `k = grid.Nz`
- `immersed`: boundary between solid and fluid for immersed boundaries

If a boundary condition is unspecified, the default for prognostic fields
and the topology in the boundary-normal direction is used:

 - `PeriodicBoundaryCondition` for `Periodic` directions
 - `NoFluxBoundaryCondition` for `Bounded` directions and `Centered`-located fields
 - `ImpenetrableBoundaryCondition` for `Bounded` directions and `Face`-located fields
 - `nothing` for `Flat` directions and/or `Nothing`-located fields
"""
FieldBoundaryConditions(default_bounded_bc = NoFluxBoundaryCondition();
                        west = DefaultBoundaryCondition(default_bounded_bc),
                        east = DefaultBoundaryCondition(default_bounded_bc),
                        south = DefaultBoundaryCondition(default_bounded_bc),
                        north = DefaultBoundaryCondition(default_bounded_bc),
                        bottom = DefaultBoundaryCondition(default_bounded_bc),
                        top = DefaultBoundaryCondition(default_bounded_bc),
                        immersed = DefaultBoundaryCondition(default_bounded_bc)) = 
    FieldBoundaryConditions(west, east, south, north, bottom, top, immersed)

"""
    FieldBoundaryConditions(grid, location, indices=(:, :, :);
                            west     = default_auxiliary_bc(topology(grid, 1)(), location[1]()),
                            east     = default_auxiliary_bc(topology(grid, 1)(), location[1]()),
                            south    = default_auxiliary_bc(topology(grid, 2)(), location[2]()),
                            north    = default_auxiliary_bc(topology(grid, 2)(), location[2]()),
                            bottom   = default_auxiliary_bc(topology(grid, 3)(), location[3]()),
                            top      = default_auxiliary_bc(topology(grid, 3)(), location[3]()),
                            immersed = NoFluxBoundaryCondition())

Return boundary conditions for auxiliary fields (fields whose values are
derived from a model's prognostic fields) on `grid` and at `location`.

Keyword arguments
=================

Keyword arguments specify boundary conditions on the 6 possible boundaries:

- `west`, left end point in the `x`-direction where `i = 1`
- `east`, right end point in the `x`-direction where `i = grid.Nx`
- `south`, left end point in the `y`-direction where `j = 1`
- `north`, right end point in the `y`-direction where `j = grid.Ny`
- `bottom`, right end point in the `z`-direction where `k = 1`
- `top`, right end point in the `z`-direction where `k = grid.Nz`
- `immersed`: boundary between solid and fluid for immersed boundaries

If a boundary condition is unspecified, the default for auxiliary fields
and the topology in the boundary-normal direction is used:

- `PeriodicBoundaryCondition` for `Periodic` directions
- `GradientBoundaryCondition(0)` for `Bounded` directions and `Centered`-located fields
- `nothing` for `Bounded` directions and `Face`-located fields
- `nothing` for `Flat` directions and/or `Nothing`-located fields
"""
FieldBoundaryConditions(grid, location, indices=(:, :, :);
                        west     = default_auxiliary_bc(topology(grid, 1)(), location[1]()),
                        east     = default_auxiliary_bc(topology(grid, 1)(), location[1]()),
                        south    = default_auxiliary_bc(topology(grid, 2)(), location[2]()),
                        north    = default_auxiliary_bc(topology(grid, 2)(), location[2]()),
                        bottom   = default_auxiliary_bc(topology(grid, 3)(), location[3]()),
                        top      = default_auxiliary_bc(topology(grid, 3)(), location[3]()),
                        immersed = NoFluxBoundaryCondition()) =
    FieldBoundaryConditions(indices, west, east, south, north, bottom, top, immersed)

#####
##### Boundary condition "regularization"
#####
##### TODO: this probably belongs in Oceananigans.Models
#####

# Generic fallback for immersed boundary conditions: always returns
# `NoFluxBoundaryCondition()`, and warns when the user supplied anything other than
# the `DefaultBoundaryCondition` placeholder, since that condition will be ignored.
function regularize_immersed_boundary_condition(ibc, grid, loc, field_name, args...)
    # Nothing to report for an unspecified (placeholder) immersed condition.
    ibc isa DefaultBoundaryCondition && return NoFluxBoundaryCondition()

    msg = """
              $field_name was assigned an immersed $ibc, but this is not supported on
              $(summary(grid))
              The immersed boundary condition on $field_name will have no effect.
          """
    @warn msg

    return NoFluxBoundaryCondition()
end

# Resolve a `DefaultBoundaryCondition` placeholder into the prognostic default for the
# grid topology and the field location in direction `dim`.
function regularize_boundary_condition(default::DefaultBoundaryCondition, grid, loc, dim, args...)
    topo = topology(grid, dim)()
    ℓ = loc[dim]()
    return default_prognostic_bc(topo, ℓ, default)
end

# Any other boundary condition is passed through unchanged unless a more specific method applies.
regularize_boundary_condition(bc, args...) = bc

# Convert all `Number` boundary conditions to `eltype(grid)`
function regularize_boundary_condition(bc::BoundaryCondition{C, <:Number}, grid, args...) where C
    FT = eltype(grid)
    value = convert(FT, bc.condition)
    return BoundaryCondition(bc.classification, value)
end

""" 
    regularize_field_boundary_conditions(bcs::FieldBoundaryConditions,
                                         grid::AbstractGrid,
                                         field_name::Symbol,
                                         prognostic_names=nothing)

Compute default boundary conditions and attach field locations to ContinuousBoundaryFunction
boundary conditions for prognostic model field boundary conditions.

!!! warn "No support for `ContinuousBoundaryFunction` for immersed boundary conditions"
    Do not regularize immersed boundary conditions.

    Currently, there is no support `ContinuousBoundaryFunction` for immersed boundary
    conditions.
"""
function regularize_field_boundary_conditions(bcs::FieldBoundaryConditions,
                                              grid::AbstractGrid,
                                              field_name::Symbol,
                                              prognostic_names=nothing)

    loc = assumed_field_location(field_name)
    
    west     = regularize_boundary_condition(bcs.west,   grid, loc, 1, LeftBoundary,  prognostic_names)
    east     = regularize_boundary_condition(bcs.east,   grid, loc, 1, RightBoundary, prognostic_names)
    south    = regularize_boundary_condition(bcs.south,  grid, loc, 2, LeftBoundary,  prognostic_names)
    north    = regularize_boundary_condition(bcs.north,  grid, loc, 2, RightBoundary, prognostic_names)
    bottom   = regularize_boundary_condition(bcs.bottom, grid, loc, 3, LeftBoundary,  prognostic_names)
    top      = regularize_boundary_condition(bcs.top,    grid, loc, 3, RightBoundary, prognostic_names)

    immersed = regularize_immersed_boundary_condition(bcs.immersed, grid, loc, field_name, prognostic_names)

    return FieldBoundaryConditions(west, east, south, north, bottom, top, immersed)
end

# For nested NamedTuples of boundary conditions (eg diffusivity boundary conditions):
# each entry is regularized with its key as the field name, and the result keeps the same
# keys in the same order. `group_name` is accepted for dispatch but not used.
function regularize_field_boundary_conditions(boundary_conditions::NamedTuple,
                                              grid::AbstractGrid,
                                              group_name::Symbol,
                                              prognostic_names=nothing)

    names = keys(boundary_conditions)

    regularized = map(names) do name
        regularize_field_boundary_conditions(boundary_conditions[name], grid, name, prognostic_names)
    end

    return NamedTuple{names}(regularized)
end

# A field whose boundary conditions are `missing` stays `missing`.
function regularize_field_boundary_conditions(::Missing,
                                              grid::AbstractGrid,
                                              field_name::Symbol,
                                              prognostic_names=nothing)
    return missing
end

#####
##### Outer interface for model constructors
#####

# Entry point for model constructors: regularize every entry of `boundary_conditions`,
# using each key as the field name. The result has the same keys in the same order.
function regularize_field_boundary_conditions(boundary_conditions::NamedTuple,
                                              grid::AbstractGrid,
                                              prognostic_names::Tuple)
    names = keys(boundary_conditions)

    regularized = map(names) do name
        regularize_field_boundary_conditions(boundary_conditions[name], grid, name, prognostic_names)
    end

    return NamedTuple{names}(regularized)
end
