#!/usr/bin/env python2

import vis
import visvis as vv
import numpy as np
import sys
import os
import subprocess
import argparse
import matplotlib.pyplot as pl

# Command-line interface: a single positional argument naming the state
# archive to visualise.
parser = argparse.ArgumentParser(
    prog="visOutput",
    description="visualises obsidian states",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('inputfile',
                    help="e.g. mason0.npz")
# NOTE: the options below are disabled; they are kept as a sketch of planned
# flags. NOTE(review): the help text of '-i' looks copy-pasted from another
# tool -- confirm the wording before enabling it.
# parser.add_argument('-o', '--outputfile', default='vis.png',
#         help="outputfile")
# parser.add_argument('-i', '--interactive', default=False,
#         help="number of shard instances to launch")
args = parser.parse_args()

# Raw output of `tput cols` (terminal width followed by a newline).  Nothing
# in this file reads the value, so a missing `tput` binary or a failing call
# (e.g. no usable terminal) must not stop the viewer from starting: fall back
# to None in that case.
try:
    console_width = subprocess.check_output(["tput", "cols"])
except (OSError, subprocess.CalledProcessError):
    console_width = None

# Position in `properties` of the property that main() renders as a volume.
propIdx = 0

# Keys of the property arrays that Data reads from the input archive.
# Data.properties keeps this order, so `propIdx` indexes both lists.
properties = [
    'Density',
    'LogSusceptibility',
    'ThermalConductivity',
    'ThermalProductivity',
    'LogResistivityX',
    'LogResistivityY',
    'LogResistivityZ',
    'ResistivityPhase',
    'PWaveVelocity',
]

class Data(object):
    """Arrays of one obsidian state archive, reshaped for plotting.

    Attributes are filled in by ``__init__``:

    resolution  -- flattened 'resolution' array; main() reads entries 0, 1
                   and 2 as the number of grid points along x, y and z.
    bounds      -- first column of 'x_bounds', 'y_bounds' and 'z_bounds'
                   stacked row-wise; main() reads ``bounds[i, 0]`` and
                   ``bounds[i, 1]`` as the limits of axis ``i``.
    nlayers     -- number of archive keys starting with 'layer'.
    nboundaries -- number of archive keys starting with 'boundary'.
    properties  -- one array per name in the module-level ``properties``
                   list, in that order.
    layers      -- arrays 'layer0', 'layer1', ...
    boundaries  -- arrays 'boundary0', 'boundary1', ...
    samples     -- length of the first axis of 'boundary0' as stored
                   (i.e. before reshaping).
    """

    def __init__(self, inputFile):
        """Load ``inputFile`` (a path accepted by ``np.load``) and reshape it.

        Raises ``KeyError`` if an expected key is missing and ``IndexError``
        if the archive holds no 'boundary*' arrays.
        """
        # Read from the path handed to the constructor rather than the global
        # CLI namespace, and close the archive once everything is in memory.
        with np.load(inputFile) as raw:
            self.resolution = raw['resolution'].flatten()
            self.bounds = np.array([raw['x_bounds'],
                                    raw['y_bounds'],
                                    raw['z_bounds']])[:, :, 0]
            keys = list(raw)
            self.nlayers = sum(1 for key in keys if key.startswith('layer'))
            self.nboundaries = sum(1 for key in keys
                                   if key.startswith('boundary'))
            self.properties = [raw[name] for name in properties]
            self.layers = [raw['layer' + str(k)]
                           for k in range(self.nlayers)]
            # Boundaries are counted separately from layers, so index them by
            # their own count instead of assuming it equals nlayers.
            self.boundaries = [raw['boundary' + str(k)]
                               for k in range(self.nboundaries)]

        # Must be read before the reshape below changes the leading axis.
        self.samples = self.boundaries[0].shape[0]

        # Reshape into (sample, z, y, x) volumes and (sample, y, x) surfaces;
        # -1 lets numpy infer the leading axis.  order='f' is kept from the
        # stored layout -- NOTE(review): presumably the writer flattens in
        # column-major order; confirm against the producer of the archive.
        nx = int(self.resolution[0])
        ny = int(self.resolution[1])
        nz = int(self.resolution[2])
        newshapeVox = (-1, nz, ny, nx)
        newshapeSurf = (-1, ny, nx)
        self.layers = [k.reshape(newshapeVox, order='f')
                       for k in self.layers]
        self.properties = [k.reshape(newshapeVox, order='f')
                           for k in self.properties]
        self.boundaries = [k.reshape(newshapeSurf, order='f')
                           for k in self.boundaries]

def _drawSurfaces(d, frame, X, Y):
    """Draw every boundary surface of sample ``frame`` into visvis figure 1."""
    vv.figure(1)
    vv.xlabel('Eastings (m)')
    vv.ylabel('Northings (m)')
    vv.zlabel('Depth (m)')
    axes = vv.gca()
    axes.camera.fov = 70
    # Negative z aspect: NOTE(review) presumably flips the z axis so depth
    # increases downwards -- confirm against the visvis daspect docs.
    axes.daspect = 1, 1, -1
    nSurfaces = float(len(d.boundaries))
    for i, boundary in enumerate(d.boundaries):
        # One flat jet colour per surface; vv.surf receives it as a
        # 1x1x3 RGB array (alpha channel dropped).
        rgba = pl.cm.jet(i / nSurfaces)
        colour = np.array([[[rgba[0], rgba[1], rgba[2]]]])
        vv.surf(X, Y, boundary[frame], colour)


def _drawVolume(d, frame):
    """Render property ``propIdx`` of sample ``frame`` into visvis figure 2."""
    vv.figure(2)
    vv.xlabel('Eastings (m)')
    vv.ylabel('Northings (m)')
    vv.zlabel('Depth (m)')
    axes = vv.gca()
    axes.camera.fov = 70
    axes.daspect = 1, 1, -1
    vv.volshow(d.properties[propIdx][frame], cm=vv.CM_JET, renderStyle='ray')
    # Colormap editor wibject attached to the volume axes.
    vv.ColormapEditor(axes)


def main():
    """Show the boundary surfaces and one property volume for each sample.

    For every sample in the input archive, figure 1 receives the boundary
    surfaces and figure 2 the volume of ``properties[propIdx]``; the visvis
    application is then run before moving on to the next sample.
    """
    d = Data(args.inputfile)
    app = vv.use()

    # Grid coordinates are the same for every sample, so build them once.
    X = np.linspace(d.bounds[0, 0], d.bounds[0, 1], d.resolution[0])
    Y = np.linspace(d.bounds[1, 0], d.bounds[1, 1], d.resolution[1])

    for frame in range(d.samples):
        _drawSurfaces(d, frame, X, Y)
        _drawVolume(d, frame)
        # Start app
        app.Run()

# Start the viewer only when this file is executed as a script.
if __name__ == "__main__":
    main()
