import logging
import torch
from typing import List
import cupy as cp

from livereco.api.parameters import Measurement
from livereco.api.parameters.beam_setup import BeamSetup
from livereco.api.parameters.options import Options
from livereco.api.parameters import DataDimensions
from livereco.core.reconstruction.single_projection.context import Context
from livereco.core.reconstruction.logging import *

from .reconstruct import reconstruct as reconstruct_stage
from livereco.api.viewer.viewer import Viewer

def reconstruct(
    measurements: List[Measurement],
    beam_setup: BeamSetup,
    options: List[Options],
    data_dimensions: DataDimensions,
    viewer: List[Viewer],
):
    """Run a staged single-projection reconstruction, one stage per ``options`` entry.

    For every entry of ``options`` (in order), the shared ``Context`` is switched
    to that stage via ``context.set_stage(index)`` and ``reconstruct_stage`` is run
    on it. After each stage, cached device memory of the torch and cupy allocators
    is released. A snapshot is logged after every stage except the last; the final
    result is logged once after the loop.

    Args:
        measurements: Input measurements; logged and passed to ``Context``.
        beam_setup: Beam setup passed to ``Context`` and ``log_params``.
        options: Per-stage options, processed in list order. The stage's
            ``padding.down_sampling_factor`` appears in the log line and in the
            snapshot / result names.
        data_dimensions: Passed to ``Context`` and ``log_params``.
        viewer: Passed to ``Context``.

    Returns:
        ``(context.oref_predicted, context.se_losses_all,
        context.data_dimensions.fov_size)`` read from the context after the last
        stage. (Presumably the reconstructed object, the collected losses and the
        field-of-view size -- confirm against ``Context``.)

    Raises:
        IndexError: presumably when ``options`` is empty, from ``options[-1]``
            when naming the final result (unless ``Context`` rejects it earlier).
    """
    log_input(measurements)

    # NOTE(review): stdlib ``logging`` has no ``comment`` function; this relies on
    # the ``livereco.core.reconstruction.logging`` star import (or a patch made
    # elsewhere) providing it -- confirm.
    logging.comment("Initialization")

    log_params(measurements, beam_setup, options, data_dimensions)

    context = Context(
        viewer=viewer,
        options=options,
        measurements=measurements,
        beam_setup=beam_setup,
        data_dimensions=data_dimensions,
    )

    last_index = len(options) - 1
    for options_index in range(len(options)):
        # Context manager instead of bare range_push/range_pop so the NVTX range
        # stack stays balanced even if a stage raises.
        with torch.cuda.nvtx.range("reconstruction"):
            with torch.cuda.nvtx.range("prepare run"):
                context.set_stage(options_index)

            logging.info(
                f"{'Downsampling':<17}{str(context.current_options.padding.down_sampling_factor)}"
            )

            reconstruct_stage(context)

            # Advance the offset by this stage's regularization iteration count so
            # the next stage continues numbering from here.
            context.current_iter_offset = (
                context.current_iter_offset
                + context.current_options.regularization_object.iterations
            )

            # Return cached-but-unused device memory from both allocators before
            # the next stage allocates.
            torch.cuda.empty_cache()
            cp.get_default_memory_pool().free_all_blocks()

            # The last stage is logged as the final result after the loop, so
            # only intermediate stages get a snapshot here.
            if options_index < last_index:
                log_results(
                    "snapshot_x" + str(context.current_options.padding.down_sampling_factor),
                    [context.oref_predicted],
                    context.data_dimensions,
                )

    log_results(
        "result_x" + str(options[-1].padding.down_sampling_factor),
        [context.oref_predicted],
        context.data_dimensions,
    )

    return context.oref_predicted, context.se_losses_all, context.data_dimensions.fov_size
