#!/usr/bin/env python
#
from __future__ import division, print_function
from vtkplotter import Plotter, ProgressBar, printc, humansort, __version__
from vtkplotter import vtkio
import vtk
import sys, argparse, os

# Command-line interface. Every option below ends up as an attribute of
# `args` (dashes in the long option name become underscores).
pr = argparse.ArgumentParser(description="version "+str(__version__)+""" -                              
                             check out home page https://github.com/marcomusy/vtkplotter""")
pr.add_argument('files', nargs='*',             help="Input filename(s)")
pr.add_argument("-c", "--color", type=str,      help="mesh color [integer or color name]", default=None, metavar='')
# -1 means "not given on the command line"
pr.add_argument("-a", "--alpha",    type=float, help="alpha value [0-1]", default=-1, metavar='')
pr.add_argument("-w", "--wireframe",            help="use wireframe representation", action="store_true")
# -1 means "not given on the command line"
pr.add_argument("-p", "--point-size", type=float, help="specify point size", default=-1, metavar='')
pr.add_argument("-k", "--show-scalars",         help="use scalars as colors", action="store_true")
pr.add_argument("-x", "--axes-type", type=int,  help="specify axes type [0-5]", default=4, metavar='')
pr.add_argument("-i", "--no-camera-share",      help="do not share camera in renderers", action="store_true")
pr.add_argument("-l", "--legend-off",           help="do not show legends", action="store_true")
pr.add_argument("-f", "--full-screen",          help="full screen mode", action="store_true")
# empty string means "not given"; the actual default is chosen after parsing
pr.add_argument("-bg","--background", type=str, help="background color [integer or color name]", default='', metavar='')
pr.add_argument("-z", "--zoom", type=float,     help="zooming factor", default=1, metavar='')
# store_true, so that passing -q really switches quiet mode on (it is off by default)
pr.add_argument("-q", "--quiet",                help="quiet mode, less verbose", action="store_true")
pr.add_argument("-n", "--multirenderer-mode",   help="Multi renderer Mode: files go to separate renderers", action="store_true")
pr.add_argument("-s", "--scrolling-mode",       help="Scrolling Mode: use arrows to scroll files", action="store_true")
pr.add_argument("-g", "--ray-cast-mode",        help="GPU Ray-casting Mode for 3D image files", action="store_true")
pr.add_argument("-gz", "--z-spacing", type=float, help="Volume z-spacing factor [1]", default=None, metavar='')
pr.add_argument("-gy", "--y-spacing", type=float, help="Volume y-spacing factor [1]", default=None, metavar='')
pr.add_argument("--slicer",                     help="Slicer Mode for 3D image files", action="store_true")
pr.add_argument("--lego",                       help="Voxel rendering for 3D image files", action="store_true")
pr.add_argument("--cmap",                       help="Voxel rendering color map name", default='jet', metavar='')
# no type given: the value is converted with int() where it is used
pr.add_argument("--mode",                       help="Voxel rendering composite mode", default=0, metavar='')
args = pr.parse_args()

# Put the input files in natural order. The return value is not used, so
# humansort presumably sorts the list in place.
humansort(args.files)
nfiles = len(args.files)
if nfiles == 0:
    sys.exit()  # nothing to show

wsize = "full" if args.full_screen else "auto"

if args.lego:
    # lego mode replaces the background and axes type when they are still
    # at their parser defaults ('' and 4)
    if args.background == "":
        args.background = "white"
    if args.axes_type == 4:
        args.axes_type = 1

if args.background == "":
    args.background = "blackboard"

# types 4 and 5 are not good for scrolling
if args.scrolling_mode and args.axes_type in (4, 5):
    args.axes_type = 8

# N is the number of files that will actually be loaded in normal mode
N = None
if args.multirenderer_mode:
    if nfiles > 200:
        printc("~lightning Warning: option '-n' allows a maximum of 200 files", c=1)
        printc("         you are trying to load ", nfiles, " files.\n", c=1)
    N = min(nfiles, 200)
    vp = Plotter(size=wsize, N=N, bg=args.background)
    if args.axes_type == 1:
        vp.axes = 0
else:
    N = nfiles
    vp = Plotter(size=wsize, bg=args.background)
    vp.axes = args.axes_type

vp.verbose = not args.quiet
vp.sharecam = not args.no_camera_share

alpha = 1
# legends are pointless for a single file
leg = not (args.legend_off or nfiles == 1)
wire = bool(args.wireframe)
# the two modes are mutually exclusive: multirenderer mode wins
if args.scrolling_mode and args.multirenderer_mode:
    args.scrolling_mode = False

# opacity values of the three sliders of the ray-casting mode
_alphaslider0, _alphaslider1, _alphaslider2 = 0.33, 0.66, 1  # defaults

########################################################################
def _showVoxelImage():
    """Show the first input file as a ray-cast volume with interactive controls.

    Works on the module-level ``args`` and ``vp``. On top of the volume it
    adds to the plotter: a slider selecting the color map, three sliders
    setting the opacity at three scalar levels, a button toggling the
    rendering mode, a ``j`` key binding toggling jittering, and a histogram
    of scalar values sampled from the image. The three opacity values are
    kept in the module-level ``_alphaslider0/1/2`` variables.
    """
    import vtkplotter.colors as vc
    from vtkplotter import Volume, histogram
    import numpy as np

    printc("GPU Ray-casting Mode", c="b", invert=1)
    printc("Press j to toggle jittering", c="b", invert=0, bold=0)
    printc("      q to quit.", c="b", invert=0, bold=0)

    vp.show(interactive=0)

    filename = args.files[0]

    img = vtkio.load(filename).imagedata()
    # optional rescaling of the voxel spacing along z and/or y
    if args.z_spacing:
        ispa = img.GetSpacing()
        img.SetSpacing(ispa[0], ispa[1], ispa[2] * args.z_spacing)
    if args.y_spacing:
        ispa = img.GetSpacing()
        img.SetSpacing(ispa[0], ispa[1] * args.y_spacing, ispa[2])

    volume = Volume(img, mode=int(args.mode))
    volumeProperty = volume.GetProperty()

    smin, smax = img.GetScalarRange()
    if smax > 1e10:
        print("Warning, high scalar range detected:", smax)
        smax = abs(10 * smin) + 0.1
        print("         reset to:", smax)

    # scalar values at which the three opacity sliders act
    x0alpha = smin + (smax - smin) * 0.33
    x1alpha = smin + (smax - smin) * 0.66
    x2alpha = smin + (smax - smin) * 1.00

    ############################## color map slider
    # Create transfer mapping scalar value to color
    colorTransferFunction = volumeProperty.GetRGBTransferFunction()
    cmaps = [args.cmap, 'rainbow', 'viridis', 'bone', 'hot', 'plasma',
             'winter', 'cool', 'gist_earth', 'coolwarm', 'tab10']
    # 21 evenly spaced colors per map
    cols_cmaps = [vc.colorMap(range(0, 21), cm, 0, 20) for cm in cmaps]
    Ncols = len(cmaps)
    # slider color: dark on a bright background, bright otherwise
    csl = (0.9, 0.9, 0.9)
    if sum(vc.getColor(args.background)) > 1.5:
        csl = (0.1, 0.1, 0.1)

    def setCMAP(k):
        """Load the k-th color map into the color transfer function."""
        cols = cols_cmaps[k]
        colorTransferFunction.RemoveAllPoints()
        # one control point per sampled color, so that smax is mapped onto
        # the last color of the map
        scalars = np.linspace(smin, smax, num=len(cols), endpoint=True)
        for s, (r, g, b) in zip(scalars, cols):
            colorTransferFunction.AddRGBPoint(s, r, g, b)
    setCMAP(0)

    def sliderColorMap(widget, event):
        """Slider callback: switch to the color map under the slider."""
        sliderRep = widget.GetRepresentation()
        k = int(sliderRep.GetValue())
        sliderRep.SetTitleText(cmaps[k])
        setCMAP(k)

    wcmap = vp.addSlider2D(
        sliderColorMap,
        0, Ncols-1,
        value=0, showValue=0,
        title=cmaps[0],
        c=csl,
        pos=[(0.8, 0.05), (0.965, 0.05)],
    )
    wcmap.GetRepresentation().SetTitleHeight(0.018)

    ############################## alpha sliders
    # Create transfer mapping scalar value to opacity
    opacityTransferFunction = volumeProperty.GetScalarOpacity()

    def setOTF():
        """Rebuild the opacity transfer function from the slider values."""
        opacityTransferFunction.RemoveAllPoints()
        # the lowest 10% of the scalar range is kept fully transparent
        opacityTransferFunction.AddPoint(smin, 0.0)
        opacityTransferFunction.AddPoint(smin + (smax - smin) * 0.1, 0.0)
        opacityTransferFunction.AddPoint(x0alpha, _alphaslider0)
        opacityTransferFunction.AddPoint(x1alpha, _alphaslider1)
        opacityTransferFunction.AddPoint(x2alpha, _alphaslider2)
    setOTF()

    def sliderA0(widget, event):
        global _alphaslider0
        _alphaslider0 = widget.GetRepresentation().GetValue()
        setOTF()
    w0 = vp.addSlider2D(
        sliderA0, 0, 1, value=_alphaslider0, pos=[(0.855, 0.1), (0.855, 0.26)], c=csl, showValue=0
    )

    def sliderA1(widget, event):
        global _alphaslider1
        _alphaslider1 = widget.GetRepresentation().GetValue()
        setOTF()
    w1 = vp.addSlider2D(
        sliderA1, 0, 1, value=_alphaslider1, pos=[(0.9075, 0.1), (0.9075, 0.26)], c=csl, showValue=0
    )

    def sliderA2(widget, event):
        global _alphaslider2
        _alphaslider2 = widget.GetRepresentation().GetValue()
        setOTF()
    w2 = vp.addSlider2D(
        sliderA2, 0, 1, value=_alphaslider2, pos=[(0.96, 0.1), (0.96, 0.26)], c=csl, showValue=0,
        title="Opacity levels",
    )
    w2.GetRepresentation().SetTitleHeight(0.016)

    # add a button
    def buttonfuncMode():
        """Button callback: flip the volume mode between 0 and 1."""
        s = volume.mode()
        snew = (s+1)%2
        volume.mode(snew)
        bum.switch()

    bum = vp.addButton(
        buttonfuncMode,
        pos=(.7, .035),
        states=["composite", "max proj."],
        c=["bb", "gray"],
        bc=["gray", "bb"],  # colors of states
        font="arial",
        size=16,
        bold=0,
        italic=False,
    ).status(int(args.mode))

    def keyfuncJitter(key):  # toggle jittering
        if key != 'j':
            return
        s = volume.jittering()
        snew = (s+1)%2
        volume.jittering(snew)
        vp.interactor.Render()
    volume.jittering(True)
    vp.keyPressFunction = keyfuncJitter  # make it known to Plotter class

    def CheckAbort(obj, event):
        """Abort the running render as soon as another event is pending."""
        if obj.GetEventPending() != 0:
            obj.SetAbortRender(1)
    vp.window.AddObserver("AbortCheckEvent", CheckAbort)

    # add histogram of scalar, built from at most 100000 random voxels
    dims = img.GetDimensions()
    nvx = min(100000, dims[0]*dims[1]*dims[2])
    np.random.seed(0)  # same sample at every run
    # draw each index within the extent of its own axis, so that the whole
    # volume is sampled even when it is not a cube
    ixs, iys, izs = (np.random.randint(0, d, size=nvx) for d in dims)
    data = [
        img.GetScalarComponentAsFloat(int(ix), int(iy), int(iz), 0)
        for ix, iy, iz in zip(ixs, iys, izs)
    ]

    plot = histogram(data, bins=40, logscale=1, c="gray", bg='gray', pos=(0.78, 0.065))
    plot.GetPosition2Coordinate().SetValue(0.197, 0.20, 0)
    plot.SetNumberOfXLabels(2)
    plot.GetXAxisActor2D().SetFontFactor(0.8)
    plot.GetProperty().SetOpacity(0.5)

    vp.add(plot)
    vp.add(volume)

    vp.show(viewup='z', zoom=1.2, interactive=1)
    # switch the three opacity sliders off once the interactive show returns
    w0.SetEnabled(0)
    w1.SetEnabled(0)
    w2.SetEnabled(0)


##########################################################
# special case of SLC/TIFF volumes with -g option
if args.ray_cast_mode or args.z_spacing or args.y_spacing:
    # giving a spacing factor (-gz / -gy) selects this mode even without -g
    if args.axes_type in [1, 2, 3]:
        vp.axes = 4  # axes types 1, 2 and 3 are replaced by type 4 in this mode
    _showVoxelImage()
    # sys.exit() rather than exit(): the latter is a helper installed by the
    # site module and is not guaranteed to exist
    sys.exit()

##########################################################
# special case of SLC/TIFF/DICOM volumes with --slicer option
elif args.slicer:

    filename = args.files[0]
    img = vtkio.load(filename).imagedata()

    ren1 = vtk.vtkRenderer()
    renWin = vtk.vtkRenderWindow()
    renWin.AddRenderer(ren1)
    iren = vtk.vtkRenderWindowInteractor()
    iren.SetRenderWindow(renWin)

    im = vtk.vtkImageResliceMapper()
    im.SetInputData(img)
    im.SliceFacesCameraOn()
    im.SliceAtFocalPointOn()
    im.BorderOn()

    ip = vtk.vtkImageProperty()
    ip.SetInterpolationTypeToLinear()

    ia = vtk.vtkImageSlice()
    ia.SetMapper(im)
    ia.SetProperty(ip)

    ren1.AddViewProp(ia)
    ren1.SetBackground(0.6, 0.6, 0.7)
    renWin.SetSize(900, 900)

    iren = vtk.vtkRenderWindowInteractor()
    style = vtk.vtkInteractorStyleImage()
    style.SetInteractionModeToImage3D()
    iren.SetInteractorStyle(style)
    renWin.SetInteractor(iren)

    renWin.Render()
    cam1 = ren1.GetActiveCamera()
    cam1.ParallelProjectionOn()
    ren1.ResetCameraClippingRange()
    cam1.Zoom(1.3)
    renWin.Render()

    printc("Slicer Mode:", invert=1, c="m")
    printc(
        """Press  SHIFT+Left mouse    to rotate the camera for oblique slicing
       SHIFT+Middle mouse  to slice perpendicularly through the image
       Left mouse and Drag to modify luminosity and contrast
       X                   to Reset to sagittal view
       Y                   to Reset to coronal view
       Z                   to Reset to axial view
       R                   to Reset the Window/Levels
       Q                   to Quit.""",
        c="m",
    )

    iren.Start()
    exit()


########################################################################
# normal mode for single VOXEL file with Isosurface Slider or LEGO mode
elif nfiles == 1 and (
       ".slc" in args.files[0]
    or ".vti" in args.files[0]
    or ".tif" in args.files[0]
    or ".mhd" in args.files[0]
    or ".nrrd" in args.files[0]
    or ".dem" in args.files[0]
):
    from vtkplotter import Actor, legosurface

    image = vtkio.loadImageData(args.files[0])
    scrange = image.GetScalarRange()
    threshold = (scrange[1] - scrange[0]) / 3.0 + scrange[0]
    #printHistogram(image, minbin=3, height=8, bins=40, logscale=1)

    if args.lego:
        sliderpos = ((0.79, 0.035), (0.975, 0.035))
        slidertitle = ""
        showval = False
        mbg = "white"
        act = legosurface(image, vmin=threshold, cmap=args.cmap)
        act.addScalarBar(act, horizontal=1, vmin=scrange[0], vmax=scrange[1])
    else:
        sliderpos = 4
        slidertitle = "isosurface threshold"
        showval = True
        mbg = "bb"
        cf = vtk.vtkContourFilter()
        cf.SetInputData(image)
        cf.UseScalarTreeOn()
        cf.ComputeScalarsOff()
        ic = "gold"
        if args.color is not None:
            if args.color.isdigit():
                ic = int(args.color)
            else:
                ic = args.color
        if args.show_scalars:
            ic = None

        cf.SetValue(0, threshold)
        cf.Update()
        act = Actor(cf.GetOutput(), ic, alpha=abs(args.alpha)).wire(args.wireframe)
        act.phong()

    ############################## threshold slider
    def sliderThres(widget, event):
        if args.lego:
            a = legosurface(image, vmin=widget.GetRepresentation().GetValue(), cmap=args.cmap)
        else:
            cf.SetValue(0, widget.GetRepresentation().GetValue())
            cf.Update()
            poly = cf.GetOutput()
            a = Actor(poly, ic, alpha=act.alpha()).wire(args.wireframe)
            a.phong()
        vp.actors = []
        vp.renderer.RemoveActor(vp.getActors()[0])
        vp.renderer.AddActor(a)
        vp.renderer.Render()

    dr = scrange[1] - scrange[0]
    vp.addSlider2D(
        sliderThres,
        scrange[0] + 0.025 * dr,
        scrange[1] - 0.025 * dr,
        value=threshold,
        pos=sliderpos,
        title=slidertitle,
        showValue=showval,
    )

    def CheckAbort(obj, event):
        if obj.GetEventPending() != 0:
            obj.SetAbortRender(1)

    vp.window.AddObserver("AbortCheckEvent", CheckAbort)

    act.legend(leg)
    vp.show(act, zoom=args.zoom, viewup="z")


########################################################################
# NORMAL mode for single or multiple files, or multiren mode
elif (not args.scrolling_mode) or nfiles == 1:

    actors = []
    for i in range(N):
        f = args.files[i]

        if ".neutral" in f.lower() or ".xml" in f.lower() or ".gmsh" in f.lower():
            alpha = 0.05
            wire = True
        else:
            wire = False

        if 0 < args.alpha <= 1:
            alpha = args.alpha

        if args.wireframe:
            wire = True

        ic = i  # default here
        if args.color is not None:
            if args.color.isdigit():
                ic = int(args.color)
            else:
                ic = args.color

        if args.show_scalars:
            ic = None

        actor = vp.load(f, c=ic, alpha=alpha)
        if hasattr(actor, "wire"):
            actor.wire(wire)

        if leg:
            actor.legend(os.path.basename(f))

        actors.append(actor)

        if args.point_size > 0:
            try:
                ps = actor.GetProperty().GetPointSize()
                actor.GetProperty().SetPointSize(args.point_size)
                actor.GetProperty().SetRepresentationToPoints()
            except AttributeError:
                print("cannot set point size for", f)

        if args.multirenderer_mode:
            actor._legend = None
            vp.show(actor, at=i, interactive=False, zoom=args.zoom)
            vp.actors = actors

    if args.multirenderer_mode:
        vp.interactor.Start()
    else:
        vp.show(interactive=True, zoom=args.zoom)

########################################################################
# scrolling mode  -s
else:
    import numpy

    if 0 < args.alpha <= 1:
        alpha = args.alpha

    n = len(args.files)
    pb = ProgressBar(0, n)

    actors = []
    for i, f in enumerate(args.files):
        pb.print("..loading")
        ic = "gold"
        if args.color is not None:
            if args.color.isdigit():
                ic = int(args.color)
            else:
                ic = args.color
        if args.show_scalars:
            ic = None

        actor = vp.load(f, c=ic, alpha=alpha)
        if hasattr(actor, "wire"):
            actor.wire(wire)
        actor.legend(leg)
        actors.append(actor)
        if args.point_size > 0:
            try:
                ps = actor.GetProperty().GetPointSize()
                actor.GetProperty().SetPointSize(args.point_size)
                actor.GetProperty().SetRepresentationToPoints()
            except AttributeError:
                print("cannot set point size for", f)

    # calculate max actors bounds
    bns = []
    for a in actors:
        if a and a.GetPickable():
            b = a.GetBounds()
            if b:
                bns.append(b)
    if len(bns):
        max_bns = numpy.max(bns, axis=0)
        min_bns = numpy.min(bns, axis=0)
        vbb = (min_bns[0], max_bns[1], min_bns[2], max_bns[3], min_bns[4], max_bns[5])

    def _scroll(iren, event):  # observer
        global _kact
        actors[_kact].off()

        key = iren.GetKeySym()

        n = len(actors)
        if key == "Right" and _kact < n - 1:
            _kact += 1
        elif key == "Left" and _kact > 0:
            _kact -= 1

        step = int(n / 5)
        if key == "Up":
            _kact += step
            if _kact > n - 1:
                _kact = n - 1
        elif key == "Down":
            _kact -= step
            if _kact < 0:
                _kact = 0
        sliderRep.SetValue(_kact)
        if 0 <= _kact < n:
            vp.clickedActor = actors[_kact]

        actors[_kact].on()
        fn = args.files[_kact].split("/")[-1]
        printc("showing file #", _kact, "\t", fn, "\r", c="y", bold=0, end="")

    vp.interactor.AddObserver("KeyPressEvent", _scroll)

    vp.show(actors[0], interactive=False, zoom=args.zoom)

    if isinstance(vp.axes_instances[0], vtk.vtkCubeAxesActor):
        vp.axes_instances[0].SetBounds(vbb)

    cb = (1, 1, 1)
    if numpy.sum(vp.renderer.GetBackground()) > 1.5:
        cb = (0.1, 0.1, 0.1)

    def sliderf(widget, event):
        global _kact
        actors[_kact].off()
        _kact = int(widget.GetRepresentation().GetValue())
        actors[_kact].on()
        fn = args.files[_kact].split("/")[-1]
        printc("showing file #", _kact, "\t", fn, "\r", c="y", bold=0, end="")

    wid = vp.addSlider2D(sliderf, 0, n - 1, pos=4, c=cb, showValue=False)
    wid.SetAnimationModeToJump()
    sliderRep = wid.GetRepresentation()

    _kact = 0
    for a in actors:
        a.off()
    actors[0].on()

    printc("Scrolling Mode", c="y", invert=1, end="")
    printc(", press:", c="y")
    printc("- Right and Left keys to move through files", c="y")
    printc("- Up and Down keys to fast forward and backward", c="y")
    vp.show(actors, interactive=True, zoom=args.zoom)
    print()
