using Printf
using Glob
using JLD2
using PyCall

using Oceananigans.Utils: prettytime

np = pyimport("numpy")
ma = pyimport("numpy.ma")
plt = pyimport("matplotlib.pyplot")
mticker = pyimport("matplotlib.ticker")
ccrs = pyimport("cartopy.crs")
cmocean = pyimport("cmocean")

"""
    plot_cubed_sphere_tracer_field!(fig, ax, var, grid; add_colorbar, transform, cmap, vmin, vmax)

Draw the first vertical level of `var` on every face of a cubed-sphere grid into the
Cartopy axes `ax` using `pcolormesh`, then set `ax` to a global extent.

# Arguments
- `fig`: Matplotlib figure that owns `ax`; only used to add a colorbar axis.
- `ax`: Cartopy axes to draw into.
- `var`: indexable by face number; `var[face][:, :, 1]` is the 2D slice plotted for each face.
  Presumably `var[face]` holds interior cell values only (`Nx × Ny × Nz`) — TODO confirm.
- `grid`: indexable by path strings (e.g. a JLD2 group) with a `faces` group and
  `faces/<n>/Nx`, `Ny`, `Hx`, `Hy`, `λᶠᶠᵃ`, `φᶠᶠᵃ` entries. Sizes and halos are read from face 1
  and assumed to be the same on every face.

# Keywords
- `add_colorbar`: if `true`, add a horizontal colorbar (using the last face's mesh) below the subplots.
- `transform`, `cmap`, `vmin`, `vmax`: forwarded to `pcolormesh`.

Returns `ax`.
"""
function plot_cubed_sphere_tracer_field!(fig, ax, var, grid; add_colorbar, transform, cmap, vmin, vmax)

    Nf = grid["faces"] |> keys |> length
    Nx = grid["faces/1/Nx"]
    Ny = grid["faces/1/Ny"]
    Hx = grid["faces/1/Hx"]
    Hy = grid["faces/1/Hy"]

    # pcolormesh needs (Nx+1) × (Ny+1) cell corners for Nx × Ny cell values. Assuming the stored
    # corner arrays are 1-based with Hx (Hy) halo points on each side, the interior corners are
    # the Nx+1 (Ny+1) points right after the leading halo. The previous upper bound `Nx+2Hx` was
    # only correct for a halo of width 1.
    corner_is = 1+Hx:Nx+Hx+1
    corner_js = 1+Hy:Ny+Hy+1

    mesh = nothing

    for face in 1:Nf
        λᶠᶠᵃ = grid["faces/$face/λᶠᶠᵃ"][corner_is, corner_js]
        φᶠᶠᵃ = grid["faces/$face/φᶠᶠᵃ"][corner_is, corner_js]

        var_face = var[face][:, :, 1]

        # Mask NaNs so those cells are left blank instead of breaking the color scaling.
        var_face_masked = ma.masked_where(np.isnan(var_face), var_face)

        mesh = ax.pcolormesh(λᶠᶠᵃ, φᶠᶠᵃ, var_face_masked; transform, cmap, vmin, vmax)
    end

    # Add colorbar below all the subplots (all faces share vmin/vmax, so any face's mesh works).
    if add_colorbar && !isnothing(mesh)
        ax_cbar = fig.add_axes([0.25, 0.1, 0.5, 0.02])
        fig.colorbar(mesh, cax=ax_cbar, orientation="horizontal")
    end

    ax.set_global()

    return ax
end

# Draw one globe (subplot `position` in a 1×3 row) showing `field` with dashed lat/lon gridlines.
function _plot_three_spheres_panel!(fig, position, field, grid; projection, transform, vmin, vmax)
    ax = fig.add_subplot(1, 3, position, projection=projection)
    plot_cubed_sphere_tracer_field!(fig, ax, field, grid; transform, cmap=cmocean.cm.balance, vmin, vmax, add_colorbar=false)
    gl = ax.gridlines(color="gray", alpha=0.5, linestyle="--")
    gl.xlocator = mticker.FixedLocator(-180:30:180)
    gl.ylocator = mticker.FixedLocator(-80:20:80)
    return ax
end

"""
    animate_rossby_haurwitz_three_spheres(; projection=ccrs.NearsidePerspective(central_longitude=0, central_latitude=0))

Render every iteration saved under `timeseries/t` in `cubed_sphere_rossby_haurwitz.jld2`
as a PNG frame. Each frame shows `u`, `v` and `η` side by side on three globes drawn with
`projection`. The frames are then stitched into
`cubed_sphere_rossby_haurwitz_three_spheres.mp4` at 10 frames per second with `ffmpeg`,
and every file matching `cubed_sphere_rossby_haurwitz_*.png` is deleted.

Requires an `ffmpeg` executable on the `PATH`. Returns `nothing`.
"""
function animate_rossby_haurwitz_three_spheres(; projection=ccrs.NearsidePerspective(central_longitude=0, central_latitude=0))

    filename_pattern = "cubed_sphere_rossby_haurwitz_%04d.png"
    output_filename  = "cubed_sphere_rossby_haurwitz_three_spheres.mp4"

    ## Extract data and render frames

    # The do-block form closes the file even if rendering a frame throws.
    jldopen("cubed_sphere_rossby_haurwitz.jld2") do file

        # Sort numerically so frames are in time order regardless of the key order in the file.
        iterations = sort(parse.(Int, keys(file["timeseries/t"])))

        # Loop-invariant: the grid and the data coordinate system are the same for every frame.
        grid = file["grid"]
        transform = ccrs.PlateCarree()

        ## Matplotlib/Cartopy frames of u, v, η

        for (n, i) in enumerate(iterations)
            @info "Plotting iteration $i/$(iterations[end]) (frame $n/$(length(iterations)))..."

            u = file["timeseries/u/$i"]
            v = file["timeseries/v/$i"]
            η = file["timeseries/η/$i"]

            t = prettytime(file["timeseries/t/$i"])
            plot_title = "Rossby-Haurwitz wave (mode 4) at t = $t"

            fig = plt.figure(figsize=(16, 9))

            # (field, vmin, vmax) for each panel, left to right.
            panels = ((u, -80, 80), (v, -80, 80), (η, 8000, 8250))

            for (position, (field, vmin, vmax)) in enumerate(panels)
                _plot_three_spheres_panel!(fig, position, field, grid; projection, transform, vmin, vmax)
            end

            fig.suptitle(plot_title, y=0.75)

            # @sprintf needs a literal format string; it must match `filename_pattern` above.
            filename = @sprintf("cubed_sphere_rossby_haurwitz_%04d.png", n)
            fig.savefig(filename, dpi=200, bbox_inches="tight")
            plt.close(fig)
        end
    end

    # `-framerate 10` sets the input rate, so every rendered frame appears in the movie.
    # (An `fps=10` filter on the image2 default 25 fps input silently dropped ~60% of frames.)
    # The crop filter is needed in case we end up with an odd number of pixels in width or
    # height, which yuv420p does not allow. See: https://stackoverflow.com/a/29582287
    run(`ffmpeg -y -framerate 10 -i $filename_pattern -c:v libx264 -vf "crop=trunc(iw/2)*2:trunc(ih/2)*2" -pix_fmt yuv420p $output_filename`)

    foreach(rm, glob("cubed_sphere_rossby_haurwitz_*.png"))

    return nothing
end
