import zmq
import traceback

import livereco.api
from livereco.api.viewer import ZeroMQViewer
from livereco.api.viewer import LossViewer
from livereco.server.flatfield_correction import FlatfieldCorrection
from livereco.server.reconstruction import Reconstruction
from livereco.server.find_focus import FindFocus
from livereco.logging.logger import Logger

# --- Server setup -----------------------------------------------------------
# Command channel: a ZeroMQ PULL socket listening on every interface at
# livereco.api.server_port. PULL is receive-only, so clients never get a
# reply over this socket.
context = zmq.Context()
socket = context.socket(zmq.PULL)
socket.bind("tcp://*:" + str(livereco.api.server_port))

# Output sinks shared by all processing stages. ZeroMQViewer is given
# livereco.api.viewer_port (presumably where results are published -- confirm
# in livereco.api.viewer).
zeromqViewer = ZeroMQViewer(livereco.api.viewer_port)
lossViewer = LossViewer()

# Each stage gets its own list holding the same two viewer instances; stages
# are constructed in the order flatfield -> reconstruction -> focus search.
flat, rec, findfoc = (
    stage_cls([zeromqViewer, lossViewer])
    for stage_cls in (FlatfieldCorrection, Reconstruction, FindFocus)
)

# Main request loop: block on the PULL socket, decode one JSON command and
# dispatch it by its "function" key. Failures while handling a request are
# printed and the server keeps serving; ZeroMQ errors and Ctrl-C still stop it.
while True:
    # Every request starts from the default log level; "find_focus" raises it
    # temporarily below.
    Logger.current_log_level = Logger.level_num_image_info

    print("Waiting for calls...")
    try:
        message = socket.recv_json()
    except ValueError:
        # Malformed JSON / undecodable bytes (json and UTF-8 decode errors are
        # ValueError subclasses) must not take the server down. ZMQError is
        # deliberately not caught so a closed/terminated socket ends the loop
        # instead of spinning.
        traceback.print_exc()
        continue

    # `except Exception` rather than a bare `except:` so KeyboardInterrupt
    # and SystemExit still stop the server.
    try:
        # Read the command inside the try: a message lacking "function" (or
        # one that is not a JSON object) is reported instead of crashing.
        function = message["function"]
        print("Calling: ", function)

        if function == "reconfigure_logger":
            Logger.configure(
                working_dir=message["working_dir"], session_name=message["session_name"]
            )
        elif function == "reconstruct":
            rec.reconstruct_x(
                message["flatfield_correction_params"], message["reco_params"]
            )
        elif function == "find_focus":
            # Switch to the "final image" log level for the focus search and
            # restore it even if find_focus raises.
            previous_log_level = Logger.current_log_level
            Logger.current_log_level = Logger.level_num_image_final
            try:
                findfoc.find_focus(
                    message["flatfield_correction_params"], message["reco_params"]
                )
            finally:
                Logger.current_log_level = previous_log_level
        elif function == "correct_flatfield":
            flat.correct_flatfield(message["flatfield_correction_params"])
        elif function == "calculate_flatfield_components":
            flat.calc_flatfield_components(message["flatfield_components_params"])
        elif function == "add_flatfield":
            flat.add_flatfield(message["measurement"])
        elif function == "reset_flatfield_list":
            flat.reset_flatfield_list()
        else:
            # Previously unknown commands were dropped silently.
            print("Unknown function: ", function)
    except Exception:
        traceback.print_exc()
