#!/opt/anaconda3/bin/python3

import streamlit as st
import glob
import os
import read_lif
import numpy as np
import cv2
import matplotlib.pyplot as plt
import math
import plotly.graph_objects as go
import SessionState
from plotly.offline import plot
from skimage.measure import label, regionprops
from skimage import data
from skimage import color
from skimage.morphology import extrema
from skimage import exposure
from scipy import spatial

def bytescale(data, in_min, in_max):
    data = np.clip(data, in_min, in_max)
    data = (data - float(in_min)) / float(in_max - in_min)
    return np.array(data * 255, dtype=np.uint8)


def pixel_list_by_radius(radius):
    """Return the integer (dx, dy) offsets lying strictly inside a circle of ``radius``.

    An offset is kept when sqrt(dx**2 + dy**2) < radius. The quadrant with
    dx, dy >= 0 is scanned with dx in the outer loop. Each hit is then mirrored
    into the other three quadrants, skipping mirrors that would duplicate an
    offset lying on an axis. A radius of 0 or less yields an empty list.
    """
    limit = math.ceil(radius)
    offsets = []
    for dx in range(limit + 1):
        for dy in range(limit + 1):
            if not math.sqrt(dx * dx + dy * dy) < radius:
                continue
            # (x, y, include?) — negated copies are only distinct for non-zero coordinates.
            mirrors = (
                (dx, dy, True),
                (-dx, dy, dx > 0),
                (dx, -dy, dy > 0),
                (-dx, -dy, dx > 0 and dy > 0),
            )
            offsets.extend((mx, my) for mx, my, keep in mirrors if keep)
    return offsets
    

@st.cache
def load_zero_frame(filename, sel_series, sel_channel):
    """Read time point T=0 of one channel of a series and rescale it to uint8.

    The frame is contrast-stretched between its own minimum and maximum via
    ``bytescale``. The function is wrapped in ``st.cache``, so repeated calls
    with the same arguments can reuse the stored result.
    """
    reader = read_lif.Reader(filename)
    chosen = reader.getSeries()[sel_series]
    first_frame = chosen.getFrame2D(T=0, channel=sel_channel, dtype=np.uint16)
    lo, hi = np.min(first_frame), np.max(first_frame)
    return bytescale(first_frame, lo, hi)

@st.cache
def load_final_frame(filename, sel_series, sel_channel):
    """Read the last time point of one channel of a series and rescale it to uint8.

    The last index is ``getNbFrames() - 1``. The frame is contrast-stretched
    between its own minimum and maximum via ``bytescale``. The function is
    wrapped in ``st.cache``, so repeated calls with the same arguments can
    reuse the stored result.
    """
    reader = read_lif.Reader(filename)
    chosen = reader.getSeries()[sel_series]
    last_t = chosen.getNbFrames() - 1
    last_frame = chosen.getFrame2D(T=last_t, channel=sel_channel, dtype=np.uint16)
    lo, hi = np.min(last_frame), np.max(last_frame)
    return bytescale(last_frame, lo, hi)

@st.cache
def load_time_series(filename, sel_series, sel_channel):
    """Load every time point of one channel into a 3-D float64 array.

    The result has shape ``(shape[0], shape[1], n_frames)``, where ``shape``
    comes from ``get2DShape()``. Time is the last axis. Frames are read as
    uint16 and stored unscaled. The rest of this script indexes axis 0 as rows
    and axis 1 as columns — presumably matching ``get2DShape()``'s order;
    confirm against read_lif.
    """
    reader = read_lif.Reader(filename)
    chosen = reader.getSeries()[sel_series]
    n_frames = chosen.getNbFrames()
    dims = chosen.get2DShape()
    stack = np.zeros((dims[0], dims[1], n_frames))
    for t in range(n_frames):
        stack[:, :, t] = chosen.getFrame2D(T=t, channel=sel_channel, dtype=np.uint16)
    return stack

@st.cache
def get_peaks_positions(blur, percent):
    """Return the centroids of the h-maxima of ``blur`` as (x, y) tuples.

    Despite its name, ``percent`` is passed unchanged as the ``h`` argument of
    ``skimage.morphology.extrema.h_maxima``. It is therefore an absolute height
    in the image's intensity units, not a percentage. Each connected maximum
    region yields one entry, ordered as ``regionprops`` returns them. Because
    ``regionprops`` centroids are (row, col), they are swapped to
    (col, row) = (x, y).
    """
    maxima_mask = extrema.h_maxima(blur, percent)
    positions = []
    for region in regionprops(label(maxima_mask)):
        centre = region.centroid
        positions.append((centre[1], centre[0]))
    return positions




st.sidebar.title("Input selection")
# Candidate inputs are the .lif files in the current working directory.
lif_files = glob.glob("*.lif")
if not lif_files:
    # Nothing to analyse: stop with a clear message instead of passing an
    # empty selection on to read_lif.Reader.
    st.error("No .lif files found in the current working directory.")
    st.stop()
filename = st.sidebar.selectbox(
    'Input File',
    lif_files)

# Series headers provide the names shown in the series selector.
reader = read_lif.Reader(filename)
series = reader.getSeriesHeaders()

sel_series_i = st.sidebar.selectbox(
    'Series',
    list(range(len(series))),format_func=lambda x:series[x].getName())

channels = series[sel_series_i].getChannels()

# Channels are selected by index and labelled with their 'ChannelTag' attribute.
sel_channel_i = st.sidebar.selectbox(
    'Channel',
    list(range(len(channels))),format_func=lambda x:channels[x].getAttribute('ChannelTag'))

# Load the first (T=0) and the last time point of the selected channel,
# each contrast-stretched to uint8.
image_b = load_zero_frame(filename,sel_series_i,sel_channel_i)
image_b_final = load_final_frame(filename,sel_series_i,sel_channel_i)

# Blur image
st.sidebar.title("Peak picking parameters")
# cv2.GaussianBlur requires a positive, odd kernel size, so the slider only
# offers odd values (1, 3, ..., 21); an even size would make OpenCV raise.
blurring = st.sidebar.slider("Blurring", min_value=1, max_value=21, value=5, step=2)
blur = cv2.GaussianBlur(image_b,(blurring,blurring),0)
blur_final = cv2.GaussianBlur(image_b_final,(blurring,blurring),0)


# Peak thresholds. Both values are used as the h-maxima height, in the 0-255
# intensity units of the blurred uint8 frames (not a percentage).
percent = st.sidebar.slider("H-height", min_value=1., max_value=100., value=10.,step=0.1,format="%.3f")
percent_final = st.sidebar.slider("Exclude based on final H-height", min_value=1., max_value=100., value=10.,step=0.1,format="%.3f")


coor_list_start = get_peaks_positions(blur, percent)
coor_list_final = get_peaks_positions(blur_final, percent_final)

# Keep only the first-frame peaks that have no final-frame peak within
# FINAL_MATCH_RADIUS pixels.
FINAL_MATCH_RADIUS = 3
if coor_list_final:
    final_tree = spatial.KDTree(coor_list_final)
    # KDTree.query reports the index n (number of tree points) when no
    # neighbour lies within distance_upper_bound.
    coor_list = [
        x for x in coor_list_start
        if final_tree.query(x, distance_upper_bound=FINAL_MATCH_RADIUS)[1] == len(coor_list_final)
    ]
else:
    # No peaks in the final frame (a KDTree cannot be built from zero
    # points), so nothing is excluded.
    coor_list = list(coor_list_start)

# Classification buttons: each press adds one to the matching tally and
# advances to the next peak. NOTE(review): the labels suggest the tallies
# count steps per peak trace — confirm the intended meaning.
skip = st.button('skip')
step1 = st.button('1 step')
step2 = st.button('2 step')
step3 = st.button('3 step')
step4 = st.button('4 step')
step5 = st.button('5 step')

# Current peak index and per-button tallies persisted across reruns.
# NOTE(review): position starts at 1, so peak 0 is not shown first — confirm intent.
ss = SessionState.get(position=1, count0=0, count1=0, count2=0, count3=0, count4=0, count5=0)
# Placeholder reserved here (above the tallies); the peak slider is placed
# into it once the position has been updated.
widget = st.empty()

for pressed, counter in ((skip, 'count0'), (step1, 'count1'), (step2, 'count2'),
                         (step3, 'count3'), (step4, 'count4'), (step5, 'count5')):
    if pressed:
        ss.position = ss.position + 1
        setattr(ss, counter, getattr(ss, counter) + 1)
st.write('skip=',ss.count0,'1 step=',ss.count1, '2 step=',ss.count2, '3 step=',ss.count3, '4 step=',ss.count4, '5 step=',ss.count5)

if not coor_list:
    st.warning("No peaks left after filtering; adjust the peak picking parameters.")
    st.stop()

# Valid indices into coor_list are 0..last_peak_i. Button presses can move the
# position past the last peak, so clamp it to keep the slider value in range
# and the coor_list lookups below valid.
# NOTE(review): once the last peak is reached, further presses still add to
# the tallies without advancing.
last_peak_i = len(coor_list) - 1
ss.position = min(ss.position, last_peak_i)
if last_peak_i > 0:
    ss.position = widget.slider('Peak', 0, last_peak_i, ss.position)
else:
    widget.text("Peak 0 (only one peak found)")

selected_keypoint_i=ss.position

st.sidebar.title("Trace calculation parameters")
extraction_radius = st.sidebar.slider("Extraction radius", min_value=0.0, max_value=20.0, value=3.0,step=0.1)

time_series_images = load_time_series(filename,sel_series_i,sel_channel_i)


# Selected peak centre rounded to whole pixels, as (x, y) = (column, row).
coord = (int(round(coor_list[selected_keypoint_i][0])),int(round(coor_list[selected_keypoint_i][1])))
pixel_list = pixel_list_by_radius(extraction_radius)

# Intensity trace: for every time point, sum the pixels inside the extraction
# disc. Offsets falling outside the image are dropped. Without this, negative
# indices would silently wrap to the opposite edge, and indices past the end
# would raise IndexError.
n_rows, n_cols = time_series_images.shape[:2]
offsets = np.array(pixel_list, dtype=int).reshape(-1, 2)
cols = coord[0] + offsets[:, 0]
rows = coord[1] + offsets[:, 1]
inside = (rows >= 0) & (rows < n_rows) & (cols >= 0) & (cols < n_cols)
trace = time_series_images[rows[inside], cols[inside], :].sum(axis=0).tolist()

st.line_chart(trace)

# Global intensity window over the whole stack, so all crops share one scale.
min_v = np.min(time_series_images)
max_v = np.max(time_series_images)

# Ten time points spread evenly over the series (duplicates are possible when
# there are fewer than ten frames).
plot_points = np.linspace(0,time_series_images.shape[2]-1,num = 10,dtype=int)

# Montage of TILE x TILE crops centred on the peak, one per plot point, placed
# in (TILE + 1)-px-wide slots; the last column of each slot stays 0 as a
# separator. The crop window is clipped to the image bounds, so peaks near the
# border no longer raise a shape mismatch. The uncovered part of a tile stays 0.
# (np.clip is used instead of the builtins because the name max may be
# rebound at module level.)
TILE = 30
HALF_TILE = TILE // 2
n_rows, n_cols = time_series_images.shape[:2]
top = coord[1] - HALF_TILE
left = coord[0] - HALF_TILE
r0, r1 = int(np.clip(top, 0, n_rows)), int(np.clip(top + TILE, 0, n_rows))
c0, c1 = int(np.clip(left, 0, n_cols)), int(np.clip(left + TILE, 0, n_cols))

image_to = np.zeros((TILE, (TILE + 1) * len(plot_points)))
for i, timepoint in enumerate(plot_points):
    crop = bytescale(time_series_images[r0:r1, c0:c1, timepoint], min_v, max_v)
    x0 = (TILE + 1) * i
    image_to[r0 - top:r1 - top, x0 + c0 - left:x0 + c1 - left] = crop

# NOTE(review): image_to is currently not displayed (the lines below are commented out).
#plt.imshow(image_to)
#st.pyplot()
plt.clf()
st.write(coor_list[selected_keypoint_i])

# Show the blurred first frame with all picked peaks boxed, on button press.
show_image=st.button('SHOW IMAGE')
if show_image:
    # Create figure
    fig = go.Figure()

    # blur is indexed [row, col], so shape[0] is the image height and
    # shape[1] its width.
    img_height, img_width = blur.shape[:2]
    scale_factor = 1.0

    # Add invisible scatter trace.
    # This trace is added to help the autoresize logic work.
    fig.add_trace(
        go.Scatter(
            x=[0, img_width * scale_factor],
            y=[0, img_height * scale_factor],
            mode="markers",
            marker_opacity=0
        )
    )

    # Configure axes
    fig.update_xaxes(
        visible=False,
        range=[0, img_width * scale_factor]
    )

    fig.update_yaxes(
        visible=False,
        range=[0, img_height * scale_factor],
        # the scaleanchor attribute ensures that the aspect ratio stays constant
        scaleanchor="x"
    )

    # NOTE(review): with an increasing y axis, heatmap row 0 is drawn at the
    # bottom, so the image appears vertically flipped compared with usual image
    # viewers. The peak boxes use the same coordinates, so they still line up.
    fig.add_trace(
        go.Heatmap(
            z=blur,
            showscale=False,
            colorscale='Greys',
            reversescale=True
        )
    )

    # Configure other layout
    fig.update_layout(
        width=img_width * scale_factor,
        height=img_height * scale_factor,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
    )

    # One 4x4 px box per peak (points are (x, y)); the selected peak is red,
    # all others yellow.
    shapes = [  go.layout.Shape(
                type="rect",
                xref="x",
                yref="y",
                x0=point[0]-2,
                y0=point[1]-2,
                x1=point[0]+2,
                y1=point[1]+2,
                line=dict(
                    color="red" if i == selected_keypoint_i else "yellow",
                    width=3,
                )
            ) for i, point in enumerate(coor_list)]
    fig.update_layout(
        shapes=shapes)
    st.plotly_chart(fig)




