import struct
import warnings

import numpy as np
# Record layout of per-frame time data: an "i4" (4-byte int) timestep
# followed by an "f8" (8-byte float) time, packed without padding.
# Same layout readBinaryParticleData uses for the ".tbdat" time files.
ttype = np.dtype([ ("step", "i4"), ("time", "f8")])

def _read_bdat_header(header):
    """Parse a particle-data header file; return ``(ver, sizes, dtype)``.

    Two layouts are recognised:

    * new: line 1 is a single version number, line 2 is
      ``<no. fields> <no. particles>``;
    * old (reported as ``ver == 0``): line 1 is
      ``<no. metadata values> <no. fields> <no. particles>``.

    The following line names a numpy scalar type (e.g. ``float64``), which is
    resolved as ``np.<name>``.
    """
    with open(header, "r") as f:
        # split() without an argument tolerates trailing blanks and "\r\n".
        first = f.readline().split()
        if len(first) == 1:
            ver = int(first[0])
            sizes = [int(x) for x in f.readline().split()]  # no. fields, no. particles
        else:
            ver = 0  # old format: metadata count, no. fields, no. particles
            sizes = [int(x) for x in first]
        dtype_name = f.readline().strip()
    # NOTE(review): eval() executes arbitrary code taken from the header file.
    # Only read headers from trusted sources.
    dtype = eval("np.%s" % dtype_name)
    return ver, sizes, dtype


def _tbdat_path(bdatpath):
    """Return the ``.tbdat`` time-data path that belongs to ``bdatpath``.

    Only the last ``.bdat`` occurrence is swapped, so directory names that
    happen to contain ``.bdat`` are left alone.  A path without ``.bdat`` is
    returned unchanged.
    """
    head, sep, tail = bdatpath.rpartition(".bdat")
    if not sep:
        return bdatpath
    return head + ".tbdat" + tail


def readBinaryParticleData(bdatpath):
    """
    Read binary particle data prefixed with timestep+time.

    The header is read from ``bdatpath + "h"``.

    Returns ``(tdat, pdat)``:

    * ``tdat``: with the old header format, the metadata values that precede
      each frame in the data file, shape ``(nframes, n_meta)``.  With the new
      format, the structured ``step``/``time`` records from the sibling
      ``.tbdat`` file, one per frame.
    * ``pdat``: particle data of shape ``(nframes, a, b)``, where ``a, b`` are
      the last two header sizes.  Presumably ``(frames, fields, particles)``;
      TODO confirm.
    """
    ver, sizes, dtype = _read_bdat_header(bdatpath + "h")
    rawdat = np.fromfile(bdatpath, dtype=dtype)

    if ver == 0:
        # Old format: each frame is sizes[0] metadata values followed by the
        # particle block.  Assumes the file holds whole frames only; otherwise
        # reshape raises ValueError.
        nelem = sizes[0] + np.prod(sizes[1:])
        nframes = rawdat.shape[0] // nelem
        framedat = rawdat.reshape(nframes, nelem)
        tdat = framedat[:, :sizes[0]]
        # NOTE(review): axis order unverified.  The header comments suggest
        # (frames, fields, particles); an older note said
        # (frames, particles, data).  TODO confirm.
        pdat = framedat[:, sizes[0]:].reshape(nframes, sizes[1], sizes[2])
    else:
        # New format: time data lives in a sibling .tbdat file with one
        # (step, time) record per frame (same layout as module-level ttype).
        time_dtype = np.dtype([("step", "i4"), ("time", "f8")])
        tdat = np.fromfile(_tbdat_path(bdatpath), dtype=time_dtype)
        nframes = tdat.shape[0]
        pdat = rawdat.reshape(nframes, sizes[0], sizes[1])
        # Workaround for the C++ writer.  If every volume (field 0) is 0 in
        # the first frame, it was never computed, so borrow frame 1.  This
        # needs a second frame; without one, the data is returned as-is.
        if nframes > 1 and np.all(pdat[0, 0, :] == 0.0):
            pdat[0, :, :] = pdat[1, :, :]

    return tdat, pdat

def readheader(path):
    """
    Read the grid dimensions from a binary restart header.

    The header begins with seven native ``int`` values (4 bytes each):
    version, header size, step, nx, ny, nz, ndof.  Anything stored after
    them (extra data) is ignored.

    Returns ``(nx, ny, nz, ndof)``.
    """
    n_ints = 7
    with open(path, "rb") as fh:
        fields = struct.unpack("%di" % n_ints, fh.read(4 * n_ints))
    _version, _hsize, _step, nx, ny, nz, ndof = fields
    return nx, ny, nz, ndof


def loadRestart(path, dtype, dimshape=False):
    """
    pass path to .bin file + base dtype

    @param path: path to the ``.bin`` data file.  Its binary header is read
        from ``path + "h"`` via readheader().
    @param dtype: numpy dtype of the raw data elements.
    @param dimshape: Try to reduce to appropriate dimensions, i.e. drop every
        axis whose length is 1.

    @return: array of shape ``(ndof, nz, ny, nx)``, with length-1 axes
        removed when ``dimshape`` is set.  If the number of values in the file
        does not match the header dimensions, a warning is issued and the flat
        1-D array is returned (best effort).
    """
    header = path + "h"
    nx, ny, nz, ndof = readheader(header)
    raw = np.fromfile(path, dtype=dtype)
    shape = (ndof, nz, ny, nx)
    try:
        raw = raw.reshape(shape)
    except ValueError as err:
        # Size mismatch between file and header.  Keep the flat data so the
        # caller can still inspect it, but make the problem visible.
        warnings.warn(
            "loadRestart: cannot reshape %d values from %s to %s (%s); "
            "returning flat array" % (raw.size, path, shape, err),
            stacklevel=2,
        )
        return raw

    if dimshape:
        raw = raw.reshape(tuple(n for n in shape if n != 1))
    return raw



def getPhasefieldX(phidat, phiidat, needle):
    """
    Collapse the leading axis of ``phidat`` using the index field ``phiidat``.

    For every position in the trailing axes, numba_getPhasefieldX picks the
    value of ``phidat`` at the first leading-axis entry whose ``phiidat`` value
    equals ``needle``.  Positions with no match are 0.

    ``phiidat`` must hold as many elements as ``phidat``; both are flattened
    and compared element by element.  The result has shape
    ``(1,) + phidat.shape[1:]``.
    """
    dims = np.array(phidat.shape, dtype=np.int64)
    n_total = np.prod(dims)
    n_inner = np.prod(dims[1:])  # elements per leading-axis slice

    flat_idx = phiidat.reshape(int(n_total))
    flat_phi = phidat.reshape(int(n_total))
    picked = numba_getPhasefieldX(flat_phi, flat_idx, needle, n_inner, n_total)

    out_shape = dims.copy()
    out_shape[0] = 1  # the leading axis is collapsed to length 1
    return picked.reshape(out_shape)

try:
    import numba
except ImportError:
    # numba is optional.  Fall back to a no-op ``njit`` so decorated functions
    # still run, just as plain (slower) Python.
    class numba:  # mimics the module name on purpose
        @staticmethod
        def njit(*args, **kwargs):
            """Identity stand-in for ``numba.njit``.

            Supports bare ``@numba.njit`` as well as the call form
            ``@numba.njit(...)`` (signature strings / keyword options).
            """
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]
            return lambda func: func

@numba.njit
def numba_getPhasefieldX(linphi, linphii, needle, innerprod, shprod):
    """
    Kernel behind getPhasefieldX, compiled with numba when it is available.

    Walks the flat arrays ``linphii`` (indices) and ``linphi`` (values) over
    their first ``shprod`` elements.  For each slot ``idx % innerprod`` it
    keeps the ``linphi`` value at the FIRST position where
    ``linphii == needle``.  Slots that never match stay 0.

    Returns a float64 array of length ``innerprod``.
    """
    ret = np.zeros(innerprod)
    unfilled = np.ones(innerprod, dtype=np.int8)  # 1 until the slot's first match
    remaining = innerprod
    for idx in range(shprod):
        if linphii[idx] != needle:
            continue
        retidx = idx % innerprod
        # An explicit branch keeps later matches from touching a stored value.
        # An arithmetic blend such as phi*0 + ret would turn ret into NaN
        # whenever phi is inf or NaN (inf*0 == NaN).
        if unfilled[retidx]:
            ret[retidx] = linphi[idx]
            unfilled[retidx] = 0
            remaining -= 1
            if remaining == 0:
                break  # every slot is filled; later matches cannot change ret
    return ret
