import sys
import numpy as np
import pyvista as pv
import re
import glob
from util import *
import genpvd

import vtk
try:
    from vtkmodules.util.numpy_support import numpy_to_vtk
except:
    from vtk.util.numpy_support import numpy_to_vtk

class VTIWriter():
    """
    Simple wrapper around VTK image I/O.

    Initialize, then add data of the same dimensions.
    The data added for the first time after init/write determines the dimensions.
    Adding multiple fields to the same file works via append.

    fnbase needs to contain a %d or the like for the step indicator and a .[p]vti ending.

    Parameters:
        bpath:  output directory; a trailing "/" is appended when missing and
                the directory is created via mkdirp.
        fnbase: file name pattern, formatted with the step number on write().
        dx:     grid spacing, used for all three axes.
        kwargs: accepted for call compatibility, currently unused.

    Raises:
        ValueError: if fnbase contains neither ".vti" nor ".pvti".
    """
    def __init__(self, bpath, fnbase, dx, **kwargs):
        if (bpath[-1] != "/"):
            bpath = bpath + "/"
        self.bpath = bpath
        # Reject names that carry neither of the supported extensions.
        # (The previous condition could only trigger when *both* were present.)
        if ".vti" not in fnbase and ".pvti" not in fnbase:
            raise ValueError(
                " %s  doesn't have a .vti extension! assumed image data, so add a .vti extension!" % (fnbase)
            )
        self.fnbase = fnbase
        self.step = 0
        self.time = 0
        self.dx = dx
        self.data = []
        # Flattened numpy buffers handed to VTK; kept alive until write().
        self._arrays = []
        mkdirp(bpath) # ensure path exists

        self.image = None

        self.writer = vtk.vtkXMLImageDataWriter()
        # Appended mode is the effective data mode (a prior SetDataModeToBinary
        # call would simply be overridden by this one).
        self.writer.SetDataModeToAppended()
        self.writer.SetCompressionLevel(1)
        self.writer.SetCompressorTypeToLZ4()

    def addData(self, indata, name):
        """
        Adds indata to the VTI object with the given name.

        If it's the first one entered, this also determines the dimensions of the image.
        Arrays with fewer than three dimensions are padded with dummy axes of length 1.
        The values are stored as cell data, flattened in C order.
        """
        # Pad the shape with 1s so there are always (at least) three entries,
        # and use plain Python ints for the VTK calls.
        shape = [int(n) for n in indata.shape]
        shape += [1] * (3 - len(shape))

        if self.image is None:
            self.image = vtk.vtkImageData()
            # The extent counts points, so [0, n] along an axis yields n cells,
            # matching the number of values stored as cell data below.
            # NOTE(review): shape[0] is used as the x extent while the data is
            # flattened with the last axis varying fastest; VTIReader.readvti
            # reshapes to (z, y, x). Confirm the intended axis convention for
            # non-cubic data.
            self.image.SetExtent(0, shape[0], 0, shape[1], 0, shape[2])
            self.image.SetSpacing(self.dx, self.dx, self.dx)
            self.image.SetOrigin(0, 0, 0)

        # Convert numpy array to VTK array
        lindata = indata.reshape(-1)

        # numpy_to_vtk does not copy by default, so the VTK array shares memory
        # with lindata. Hold a reference until the file has been written.
        self._arrays.append(lindata)

        arr = numpy_to_vtk(lindata)
        arr.SetName(name) # necessary
        self.image.GetCellData().AddArray(arr)

    def write(self, step, time):
        """
        Writes the accumulated data for this time/step combination.

        The step is stored in the field data array "TimeValue" (int32) and the
        time in "RealTimeValue" (float32). Afterwards the writer is reset, so
        the next addData call starts a new image. addData must have been
        called at least once before write.
        """
        path = self.bpath + self.fnbase % (step)
        fielddata = self.image.GetFieldData()

        steptuple = vtk.vtkTypeInt32Array()
        steptuple.InsertNextTypedTuple([step])
        steptuple.SetName("TimeValue")
        fielddata.AddArray(steptuple)

        timetuple = vtk.vtkTypeFloat32Array()
        timetuple.InsertNextTypedTuple([time])
        timetuple.SetName("RealTimeValue")
        fielddata.AddArray(timetuple)

        self.writer.SetInputDataObject(self.image)
        self.writer.SetFileName(path)
        # Write() runs the writer itself; an extra Update() beforehand would
        # only produce the same file a second time.
        self.writer.Write()

        self.step = step
        self.time = time

        # Reset so the next addData determines the dimensions afresh.
        self.image = None
        self._arrays = []

class VTIReader():
    """
    Simple wrapper around pyvista to deal with timeseries without pvd files.

    This relies on an existing time.dat which is currently being written by my simulations.
    It is always written together with visualization I/O to track the frame-step-time relationship.
    If time.dat is missing, a pvd file plus alt_time.dat is generated via genpvd as a fallback.

    After initialization, you can loop over frames like this:
    for frame in reader:
        evaluate(frame.data["myfield"])

    By default, all fields will be read. If you only need a subset of the fields, pass
        fieldnames=["name1", "name2", ...]
    to the constructor.

    This basically supports arbitrary (unique) looping over the files by using setStepList.
    Mixing setStepList and setTimestep is unsupported.

    Note that the reader is its own iterator: iteration continues from the
    current frame index and is not rewound automatically.
    """
    def __init__(self, bpath, fnbase, **kwargs):
        # File paths are built as bpath + fnbase, so make sure the directory
        # ends with a separator (same convention as VTIWriter).
        if bpath and not bpath.endswith("/"):
            bpath = bpath + "/"
        self.bpath = bpath
        self.fnbase = fnbase
        self.step = 0
        self.step_index = 0
        self.time = 0
        self.maxstep = np.iinfo(np.int64).max
        self.jumpstep = 0
        self.fieldnames = None
        self.reloadSteps()
        # reloadsteps reads our own time data
        # alternatively: generate a pvd file on init and use that?

        # Forwarded to readvti on every readStep (e.g. fieldnames=[...]).
        self.kwargs = kwargs

        self.data = None

    def __len__(self):
        """Number of frames from the jump position to the end (never negative)."""
        return max(0, len(self.timedat) - int(self.jumpstep))

    def __iter__(self):
        return self

    def __next__(self):
        """
        Load the frame at the current index and advance.

        Stops when the index runs past the time data or the step exceeds maxstep.
        """
        try:
            # to get step 0
            self.step = int(self.timedat[self.step_index, 0])
            self.time = self.timedat[self.step_index, 1]
        except IndexError:
            raise StopIteration
        if self.step > self.maxstep:
            raise StopIteration
        try:
            self.readStep(self.step, selfkwargs=True)
        except IndexError:
            # kept for compatibility: an IndexError while reading ends iteration
            raise StopIteration
        self.step_index += 1
        return self

    def setMaxStep(self, step):
        """Stop iterating once a frame's step number exceeds this value."""
        self.maxstep = step

    def setTimestep(self, step, offset=0):
        """
        Jump to this timestep, with an optional offset in the frame sense.

        If the step is not present in the time data, a warning is printed and
        the first frame is used as the base position.
        """
        filt = self.timedat[:, 0] == step
        if not filt.any():
            print("step %s not found in time data, falling back to the first frame" % (step))
        hit = int(np.argmax(filt))
        self.step_index = hit + offset
        self.jumpstep = self.step_index

    def setStepList(self, steplist):
        """
        Only go over the timesteps specified in steplist (if they exist).

        """
        filt = np.isin(self.timedat[:, 0], steplist)
        self.timedat = self.timedat[filt]

    def setIndex(self, index):
        """Set the current frame index; -1 selects the last frame."""
        if (index == -1):
            index = self.timedat.shape[0] - 1
        self.step_index = index

    def readvti(self, path, fieldnames=None):
        """
        Read one file and return a dict of field name -> numpy array.

        Only the cell arrays listed in fieldnames (or self.fieldnames, or all
        arrays when neither is given) are loaded. Arrays are reshaped to
        (z, y, x); if the z extent is a single cell, the leading axis is
        dropped. Entries of the file's field data are merged into the result.
        """
        reader = pv.get_reader(path)
        if fieldnames is None and self.fieldnames is not None:
            fieldnames = self.fieldnames
        if fieldnames is not None:
            # switch off everything that was not requested to save I/O
            for name in reader.cell_array_names:
                if name not in fieldnames:
                    reader.disable_cell_array(name)
        else:
            fieldnames = reader.cell_array_names # all names
        data = reader.read()
        exts = data.GetExtent()
        # extent differences give the number of cells along each axis
        lenx, leny, lenz = (exts[1]-exts[0]), (exts[3]-exts[2]), (exts[5]-exts[4])
        twod = (lenz == 1)
        # reverse because it's zyx
        lens = np.array([lenz, leny, lenx])

        dat = {}
        for fieldname in fieldnames:
            idat = data.cell_data[fieldname].reshape(lens)
            if twod:
                idat = idat[0, :, :]
            dat[fieldname] = idat

        dat.update(data.field_data)
        return dat

    def readStep(self, step, selfkwargs=False): #, **kwargs
        """
        Read the file belonging to the given step into self.data.

        The constructor kwargs are always forwarded to readvti; selfkwargs is
        accepted for compatibility and has no effect.
        """
        path = self.bpath + self.fnbase % (step)
        self.data = self.readvti(path, **self.kwargs)

    def reloadSteps(self):
        """
        (Re)load the step/time table into self.timedat.

        Reads time.dat from bpath. If that file does not exist, a pvd file is
        generated through genpvd when missing and alt_time.dat is read instead.
        The table is always stored two-dimensional (one row per frame), and
        self.step / self.time are refreshed from the current frame index when
        that index is valid.

        Raises:
            Exception:  if neither time.dat nor the pvd fallback is available.
            ValueError: if the table has fewer than two columns.
        """
        import os
        timefile = os.path.join(self.bpath, "time.dat")
        try:
            timedat = np.loadtxt(timefile)
        except FileNotFoundError:
            # try finding pvd
            try:
                fmt = extract_format(self.fnbase)
                pvdpath = genpvd.makepvdpath(self.bpath + self.fnbase, fmt)
                if not os.path.isfile(pvdpath):
                    # generate pvd + alt_time_dat
                    genpvd.genPVD(self.bpath + self.fnbase, fmts=fmt)
                timedat = np.loadtxt(os.path.join(self.bpath, "alt_time.dat"))
            except Exception as err:
                print("timedat not found and could not generate pvd, giving up")
                raise Exception(
                    "timedat not found and could not generate pvd, giving up"
                ) from err

        # loadtxt returns a 1-D array for a single row (and for an empty file);
        # normalise so that iteration and indexing always see rows.
        if timedat.ndim < 2:
            if timedat.size == 0:
                timedat = timedat.reshape(0, 2)
            else:
                timedat = timedat.reshape(1, -1)
        if timedat.shape[1] < 2:
            raise ValueError("time data needs at least a step and a time column")
        self.timedat = timedat

        try:
            self.step = int(self.timedat[self.step_index, 0])
            self.time = self.timedat[self.step_index, 1]
        except IndexError:
            # current index lies outside the table (e.g. after iterating to
            # the end); keep the previous step/time values
            pass

