import torch
from typing import List
from copy import deepcopy

from livereco.api.parameters import Measurement
from livereco.api.parameters.beam_setup import BeamSetup
from livereco.api.parameters import DataDimensions
from livereco.api.parameters.options import Options
from livereco.api.viewer.viewer import Viewer
from livereco.core.reconstruction.utils import get_filter_kernels
from livereco.core.models.fresnel_propagator import FresnelPropagator
from livereco.core.models.cone_beam import ConeBeam
from livereco.core.preprocessing.process_padding_options import process_padding_options
from livereco.core.reconstruction.logging import *
from livereco.core.reconstruction.transformation import *

class Context:
    """Mutable state shared across the stages of a reconstruction run.

    ``options`` holds one entry per stage. The constructor stores the
    untouched inputs (``*_original``) and immediately activates stage 0 via
    :meth:`set_stage`; each later call to :meth:`set_stage` re-derives the
    padded measurements, beam setup, data dimensions, filter kernels and
    propagation model from those originals using that stage's options.
    """

    def __init__(self,
                 measurements: List[Measurement]=None,
                 beam_setup: BeamSetup=None,
                 options: List[Options]=None,
                 data_dimensions: DataDimensions=None,
                 viewer: List[Viewer]=None,
                 oref_predicted=None,
                 nesterov_vt=None):
        """Store the inputs and activate the first stage.

        Args:
            measurements: Measurements as supplied by the caller (unpadded).
            beam_setup: Beam setup as supplied by the caller (unpadded).
            options: Per-stage options; index ``i`` configures stage ``i``.
            data_dimensions: Data dimensions as supplied by the caller.
            viewer: Viewers, stored as-is on the context.
            oref_predicted: Optional initial guess. When ``None`` the guess,
                the Nesterov state and the probe are zero-initialised by
                :meth:`set_stage`; otherwise they are resized to the stage's
                total size.
            nesterov_vt: Optional Nesterov state accompanying
                ``oref_predicted``.

        Raises:
            ValueError: Propagated from :meth:`set_stage` when a required
                input is missing or ``options`` is empty.
        """
        # Every tensor created by this context is placed on the first CUDA device.
        self.torch_device = torch.device('cuda:0')

        # Untouched inputs; each stage derives its padded versions from these.
        self.measurements_original = measurements
        self.beam_setup_original = beam_setup
        self.data_dimensions_original = data_dimensions
        self.options = options

        # Loss buffer, extended by every stage with one slot per iteration.
        self.se_losses_all = torch.empty(0, device=self.torch_device)

        # Per-stage derived state, filled in by set_stage().
        self.current_options = None
        self.filter_kernel_obj_phase = None
        self.filter_kernel_obj_absorption = None
        self.filter_kernel_vt = None
        self.absorption_min = None
        self.phaseshift_max = None
        self.measurements = None
        self.beam_setup = None
        self.data_dimensions = None

        self.viewer = viewer
        self.oref_predicted = oref_predicted
        self.nesterov_vt = nesterov_vt

        # -1 means "no stage active yet", so the set_stage(0) call below
        # is not short-circuited by the "already active" check.
        self.current_stage = -1
        self.current_iter_offset = 0

        self.set_stage(0)

    def set_stage(self, stage_index) -> None:
        """Activate stage ``stage_index`` and rebuild all stage-dependent state.

        Calling this with the currently active stage is a no-op.

        Args:
            stage_index: Index into ``self.options`` of the stage to activate.

        Raises:
            ValueError: If any of the original measurements, beam setup,
                options or data dimensions is ``None``, or if ``stage_index``
                is not smaller than the number of stages.
        """
        required = (
            self.measurements_original,
            self.beam_setup_original,
            self.options,
            self.data_dimensions_original,
        )
        # Identity checks: `None in [...]` would fall back to `==`, which is
        # not reliable for objects that overload equality.
        if any(part is None for part in required):
            raise ValueError("Host context not complete.")

        stage_count = len(self.options)
        # NOTE(review): only the upper bound is validated; negative indices
        # are not rejected here.
        if stage_index >= stage_count:
            raise ValueError(
                f"Invalid stage index {stage_index}. "
                f"Can be at maximum {stage_count - 1}."
            )

        if self.current_stage == stage_index:
            return

        self.current_stage = stage_index
        logging.comment(f"Stage {self.current_stage + 1}/{stage_count}")

        self.current_options = self.options[self.current_stage]
        regularization = self.current_options.regularization_object

        # Stored with flipped sign relative to the imaginary part of values_min.
        self.absorption_min = -torch.tensor(
            regularization.values_min.imag,
            device=self.torch_device,
            dtype=torch.float,
        )

        # Append one zeroed loss slot per iteration of this stage.
        stage_losses = torch.zeros(
            regularization.iterations,
            device=self.torch_device,
            dtype=torch.float,
        )
        self.se_losses_all = torch.cat((self.se_losses_all, stage_losses))

        self.phaseshift_max = torch.tensor(
            regularization.values_max.real,
            device=self.torch_device,
            dtype=torch.float,
        )

        # Snapshot the probe before self.beam_setup is replaced below: later
        # stages carry over the previous stage's probe, the first stage
        # starts from the original beam setup.
        if self.current_stage > 0:
            current_probe_refractive = deepcopy(self.beam_setup.probe_refractive)
        else:
            current_probe_refractive = deepcopy(self.beam_setup_original.probe_refractive)

        self.measurements, self.beam_setup, self.data_dimensions = process_padding_options(
            self.measurements_original,
            self.beam_setup_original,
            self.data_dimensions_original,
            self.current_options.padding,
        )
        total_size = self.data_dimensions.total_size

        if self.oref_predicted is None:
            # No guess available: start guess, Nesterov state and probe at zero.
            self.oref_predicted = torch.zeros(
                total_size,
                device=self.torch_device,
                dtype=torch.cfloat,
            )
            self.nesterov_vt = torch.zeros(
                total_size,
                device=self.torch_device,
                dtype=torch.cfloat,
            )
            self.beam_setup.probe_refractive = torch.zeros_like(self.oref_predicted)
        else:
            # Carry the existing guess, Nesterov state and probe over to the
            # total size of this stage.
            (
                self.oref_predicted,
                self.nesterov_vt,
                self.beam_setup.probe_refractive,
            ) = resize_guess(
                self.oref_predicted,
                self.nesterov_vt,
                total_size,
                current_probe_refractive,
            )

        self.filter_kernel_obj_phase, self.filter_kernel_obj_absorption = get_filter_kernels(
            regularization.gaussian_filter_fwhm,
            total_size,
            self.torch_device,
        )

        # Only the first returned kernel is used for the Nesterov state.
        self.filter_kernel_vt, _ = get_filter_kernels(
            self.current_options.nesterov_object.gaussian_filter_fwhm,
            total_size,
            self.torch_device,
        )

        # One get_fr value per measurement, computed against the padded beam setup.
        self.model = FresnelPropagator(
            [
                ConeBeam.get_fr(self.beam_setup, measurement)
                for measurement in self.measurements
            ],
            total_size,
            self.oref_predicted.device,
        )

        log_preprocessed_params(self.beam_setup, self.data_dimensions)