#!/usr/bin/env python3
import boutpp as bc
import numpy as np
from boutdata import collect

# requires boutpp
# requires not make

# Start BOUT++ with the same command-line style arguments the executable
# takes ("-d input" presumably selects the "input" directory -- the later
# collect() calls read from path="input").
init_args = "-d input".split(" ")
bc.init(init_args)

# Constant payloads for the two fields; each is a single-element 3D array.
ones = np.array([[[1.0]]])
twos = np.array([[[2.0]]])

# Build two Field3D objects (passing None as the mesh; presumably this
# selects the global mesh -- TODO confirm) and fill them with 1.0 and 2.0.
f = bc.Field3D.fromMesh(None)
f.setAll(ones)
f2 = bc.Field3D.fromMesh(None)
f2.setAll(twos)
print(f.getAll())

# Write both fields to the default output file.  Note that the variable
# stored under the key "f2d" is a Field3D as well, despite its name.
dump = bc.Options(f3d=f, f2d=f2)
bc.writeDefaultOutputFile(dump)

# Accumulates one line of text per failed check further down.
errorlist = ""

# Read the freshly written variables back through boutdata.collect so the
# boutpp results below can be compared against plain numpy arrays.
# (tind=-1 presumably selects the last time index and yguards=True keeps
# the y guard cells -- confirm against the boutdata documentation.)
n = collect("f2d", path="input", tind=-1, yguards=True)
print(n.shape)
T = collect("f3d", path="input", tind=-1, yguards=True)
# Report the shape of T here; the original printed n.shape a second time.
print(T.shape)

# Fetch the global mesh; the two identical messages bracket the lookup.
print("initialized mesh")
mesh = bc.Mesh.getGlobal()
print("initialized mesh")

# Operands as BOUT++ fields: dens is filled from the collected "f2d" data
# (the reshape to (1, 1, 1) requires n to hold exactly one element), temp is
# read straight from the output files.
dens = bc.Field3D.fromMesh(mesh)
dens[:, :, :] = n.reshape((1, 1, 1))
temp = bc.Field3D.fromCollect("f3d", path="input")

# --- Addition, checked through strided / reversed slicing of the result ---
pres = dens + temp
# Show how a "::-2" slice resolves for a length-5 axis: (4, -1, -2).
print(slice(None, None, -2).indices(5))
p = pres[::5, 0, ::-2]
pn = (n + T).reshape(p.shape)
addition_ok = (p == pn).all()
if not addition_ok:
    errorlist = errorlist + "addition not working\n"

# --- Multiplication (field * field * scalar), checked on the full array ---
pres = dens * temp * 1
p = pres.getAll()
pn = (n * T).reshape(p.shape)
multiplication_ok = (p == pn).all()
if not multiplication_ok:
    errorlist = errorlist + "multiplication not working\n"

print(p.shape)
print(pn.shape)
