#!/usr/bin/env python3

from boutdata import collect
from boutdata.data import BoutOptionsFile
from boututils.run_wrapper import launch_safe, shell_safe

from netCDF4 import Dataset
import numpy
from pathlib import Path
from sys import exit

# Cases to run: each entry is (directory containing BOUT.inp, number of
# processes handed to launch_safe for that case).
test_directories = [('data', 12)]

# A case fails when the 'error_max' value it reports exceeds this.
tolerance = 1e-6

# Run the default make target before any case is launched.
shell_safe('make')

# Overall result; set to False by any case that crashes or is out of tolerance.
success = True
for directory, nproc in test_directories:
    # For each case: make sure the grid file holds a 3D 'rhs' field (generating
    # it with the create-initial-profiles executable when it is absent), run
    # the solver test, then compare the 'error_max' it wrote with `tolerance`.

    options = BoutOptionsFile(Path(directory).joinpath('BOUT.inp'))
    gridfilename = options['grid']

    # Read everything needed from the grid in one pass. The context manager
    # closes the file even when the compatibility check below raises.
    with Dataset(gridfilename) as grid:
        need_rhs = 'rhs' not in grid.variables
        if need_rhs:
            gridshape = grid['dx'].shape
            ny = grid['ny'][...]
            # Test for the variable explicitly instead of relying on which
            # exception type a failed lookup raises.
            if 'y_boundary_guards' in grid.variables:
                myg = grid['y_boundary_guards'][...]
            else:
                myg = 0
            ny_inner = grid['ny_inner'][...]
        elif 'mz' in options and grid['nz'][...] != options['mz']:
            # check grid is compatible with current input file
            raise ValueError("mz requested in BOUT.inp is not the same as nz in grid "
                             "file. To use this mz, use a clean copy of the grid file "
                             "that does not have 'rhs' or 'nz' added to it.")

    if need_rhs:
        # Classify the y-layout of the grid before launching anything, so an
        # unsupported grid fails immediately rather than after two simulations.
        ysize = gridshape[1]
        if ysize == ny:
            yguards, add_upper_target = False, False
        elif ysize == ny + 2*myg:
            yguards, add_upper_target = True, False
        elif ysize == ny + 4*myg:
            yguards, add_upper_target = True, True
        else:
            # str() is required: ysize is an int and cannot be joined to a str.
            raise ValueError('unsupported grid: ny=' + str(ny) + ', MYG=' + str(myg)
                             + ' but y-dimension size is ' + str(ysize))

        # need to generate initial profile
        shell_safe('make -f makefile.create-initial-profiles')

        # parallel diffusion
        launch_safe('./create-initial-profiles parallel=true nout=1 timestep=10 -d '
                    + directory, nproc=nproc, output=directory+'/parallel.log 2>&1')

        # perpendicular diffusion, restarted from and appended to the run above
        launch_safe('./create-initial-profiles perpendicular=true nout=1 timestep=1.e-5 -d '
                    + directory + ' restart append', nproc=nproc,
                    output=directory+'/perpendicular.log 2>&1')

        # Keep only the last time point of the generated field.
        initial = collect('initial', path=directory, xguards=True,
                          yguards=yguards)[-1, :, :, :]
        nz = initial.shape[2]

        if add_upper_target:
            # The grid has 2*myg more y-points than the collected data. Copy the
            # first ny_inner+myg points unchanged and shift the remainder up by
            # 2*myg, leaving the 2*myg points in between as zeros.
            padded = numpy.zeros((initial.shape[0], initial.shape[1]+2*myg, nz))
            padded[:, :ny_inner+myg, :] = initial[:, :ny_inner+myg, :]
            padded[:, ny_inner+3*myg:, :] = initial[:, ny_inner+myg:, :]
            initial = padded

        # Append the z-dimension, 'nz' and the generated 'rhs' to the grid file.
        with Dataset(gridfilename, 'a') as grid:
            grid.createDimension('z', nz)
            grid.createVariable('nz', 'i4')
            grid['nz'][...] = nz
            grid.createVariable('rhs', float, dimensions=('x', 'y', 'z'))
            grid['rhs'][...] = initial

    # run test; './' matches the create-initial-profiles launches above so the
    # executable in the current directory is used regardless of PATH
    try:
        launch_safe('./test-laplace3d -d ' + directory, nproc=nproc,
                    output=directory+'/output.log 2>&1')
    except RuntimeError:
        print(directory + ' run crashed')
        success = False
        continue

    error_max = collect('error_max', path=directory)

    # Written as "not <=" so that a NaN error counts as a failure; a plain
    # "error_max > tolerance" is False for NaN and would report a pass.
    if not (error_max <= tolerance):
        print(directory + ' failed with maximum error ' + str(error_max))
        success = False
    else:
        print(directory + ' passed with maximum error ' + str(error_max))

# Report the overall result through the process exit status: 1 if any case
# crashed or exceeded the tolerance, otherwise announce success and return 0.
if not success:
    exit(1)

print('All passed')
exit(0)
