import os
import pathlib
import sys
from copy import deepcopy

import matplotlib
import matplotlib.pyplot as plt
import numpy

from livereco.logging.logger import Logger
from livereco.api.plotter import NelderMeadPlotter
from livereco.api.parameters import *
from livereco.api.functions.find_focus.find_focus import find_focus
from livereco.core.utils.fileio import load_img_data
from livereco.api.paths.project_paths import ProjectPaths

# Select the Qt5Agg matplotlib backend for this script.
matplotlib.use("Qt5Agg")

# Initial z01 estimate: the fixed value 81708500.0 shifted by a uniform random
# offset drawn from [-3e6, 3e6), so every run starts from a different guess.
_unit_offset = 2 * (numpy.random.random_sample() - 0.5)  # uniform in [-1, 1)
z01_guess = 81708500.0 + _unit_offset * 3.0 * 1e6

# Confidence value passed alongside the guess to the Measurement.
z01_confidence = 5.0e6

# Directory containing this script (symlinks resolved); the hologram and the
# log directory are addressed relative to it.
_script_dir = os.path.dirname(os.path.realpath(__file__))

project_paths = ProjectPaths(
    root_dir=str(pathlib.Path(__file__).parent.resolve()) + "/",
    session_name="tooth_find_focus",
    session_id=0,
)

# Point the session at the input hologram and at the shared log directory.
project_paths.data_path = _script_dir + "/../data/holograms/tooth.tiff"
project_paths.logs_dir = _script_dir + "/../logs"

# Set the log level first, then configure the logger for this session.
Logger.current_log_level = Logger.level_num_loss
Logger.configure(
    session_name=project_paths.session_logs_name,
    working_dir=project_paths.logs_dir,
)

# Flat-field offset correction; reused below as Padding's a0 and inside the
# values_min bound of every regularization stage.
flatfield_offset_corr = 0.96

setup = BeamSetup(
    energy=17.0,
    px_size=6500.0,
    z02=19652000000.0,
)

# One measurement: the hologram on disk paired with the randomized z01 guess.
_hologram = load_img_data(project_paths.data_path)
measurements = [
    Measurement(
        data_path=project_paths.data_path,
        data=_hologram,
        z01=z01_guess,
        z01_confidence=z01_confidence,
    ),
]

# Padding template; each Options object receives its own deep copy of it.
padding_options = Padding(
    padding_mode=Padding.PaddingMode.MIRROR_ALL,
    padding_factor=4.0,
    down_sampling_factor=16,
    cutting_band=0,
    a0=flatfield_offset_corr,
)

def _make_stage_options(*, iterations, update_rate, l2_weight, object_fwhm, nesterov_fwhm):
    """Build the Options for one reconstruction stage.

    The stages share everything except the values passed in here; the shared
    parts (values_min bound, Nesterov update rate, verbose interval, padding)
    are filled in identically for each of them.

    Args:
        iterations: Iteration count for the object regularization.
        update_rate: Update rate for the object regularization.
        l2_weight: Complex L2 weight for the object regularization.
        object_fwhm: Complex Gaussian filter FWHM for the object regularization.
        nesterov_fwhm: Complex Gaussian filter FWHM for the Nesterov regularization.

    Returns:
        A new Options instance holding its own deep copy of ``padding_options``.
    """
    return Options(
        regularization_object=Regularization(
            iterations=iterations,
            update_rate=update_rate,
            l2_weight=l2_weight,
            # Lower bound: smallest positive normal float for the real part,
            # log of the flat-field offset correction for the imaginary part.
            values_min=sys.float_info.min + 1j * numpy.log(flatfield_offset_corr),
            gaussian_filter_fwhm=object_fwhm,
        ),
        nesterov_object=Regularization(
            update_rate=1.0,
            gaussian_filter_fwhm=nesterov_fwhm,
        ),
        verbose_interval=100,
        # Deep copy so later per-stage padding tweaks do not leak between stages.
        padding=deepcopy(padding_options),
    )


options_warmup = _make_stage_options(
    iterations=700,
    update_rate=0.9,
    l2_weight=0.0 + 10.0 * 1j,
    object_fwhm=2.0 + 0.0j,
    nesterov_fwhm=8.0 + 8.0j,
)

options_upscale_4 = _make_stage_options(
    iterations=300,
    update_rate=1.1,
    l2_weight=0.0 + 10.0 * 1j,
    object_fwhm=2.0 + 8.0j,
    nesterov_fwhm=16.0 + 16.0j,
)

# Same as options_upscale_4 but with more iterations and a ten times smaller L2 weight.
options_upscale_4_lowreg = _make_stage_options(
    iterations=500,
    update_rate=1.1,
    l2_weight=0.0 + 1.0 * 1j,
    object_fwhm=2.0 + 8.0j,
    nesterov_fwhm=16.0 + 16.0j,
)

data_dimensions = DataDimensions(
    total_size=(2048, 2048),
    fov_size=(2048, 2048),
    window_type="blackman",
)

# Both upscale stages use a down-sampling factor of 4 on their own padding copy.
for _stage_options in (options_upscale_4, options_upscale_4_lowreg):
    _stage_options.padding.down_sampling_factor = 4

# Bundle beam setup, measurements and the stage list (in list order:
# warm-up, upscale, low-regularization upscale) into one parameter object.
reco_params = RecoParams(
    beam_setup=setup,
    output_path=project_paths.output_dir,
    measurements=measurements,
    reco_options=[
        options_warmup,
        options_upscale_4,
        options_upscale_4_lowreg,
    ],
    data_dimensions=data_dimensions,
)

# Run the focus search with a single plotter attached, then report the outcome.
plotter = NelderMeadPlotter()
result, z01_records, loss_values_history = find_focus(
    reco_params,
    plotter=[plotter],
)

iteration_count = len(z01_records)
print("Found z01=", result, " after ", iteration_count, " iterations")
plotter.finish()