from contracts.data_processors import DataProcessorCore
from implementations.data.data_types.lbic_image import LBICImage
from implementations.utils.calc.list_calc import flatten_matrix
import numpy as np


class LBICImageProcessor(DataProcessorCore):
    def __init__(self, lbic_image: LBICImage):
        super().__init__(lbic_image)

        self.data = lbic_image
        self._processing_functions = {
            "current_list": self.flatten_currents,
            "horizontal_profile": self.get_horizontal_profile,
            "vertical_profile": self.get_vertical_profile,
            "diagonal_profile": self.get_diagonal_profile,
            "currents_in_area": self.get_currents_in_area
        }
        self.processed_data = {}
        for key in self._processing_functions:
            self.processed_data[key] = None

        self._processed_observables = self.processed_data.keys()

    def validate_observables(self, *args):
        # Checks whether all desired observables can be obtained for this data and catches relevant errors
        for observable in args:
            self.get_data(observable)

    def get_data(self, observable: str, *args, **kwargs):
        # If observable is from raw data delegate to Data
        if observable in self.data.get_allowed_observables():
            return self.data.get_data(observable)

        # Compute processed data if needed
        if observable in self.processed_data.keys():
            if self.processed_data[observable] is None:
                self.processed_data[observable] = self._processing_functions[observable](*args, **kwargs)
            return self.processed_data[observable]['data']
        else:
            raise ValueError(f"LBICImageProcessor does not contain {observable} data")

    def flatten_currents(self, source: str, *args, **kwargs):
        currents = []
        match source:
            case "full":
                currents = self.get_data("current")
            case "area":
                currents = self.get_data(observable="currents_in_area", *args, **kwargs)
        try:
            data = flatten_matrix(currents)
        except ValueError:
            raise ValueError("LBIC Processor: No currents were found in data")

        return {"units": "$Current (A)$", "data": data}

    def get_horizontal_profile(self, y_position: float = 0.0, *args, **kwargs):
        profile_index = np.abs([yval - y_position for yval in self.get_data("y_axis")]).argmin()
        data = self.get_data("current")[profile_index]
        return {"units": "$Current ~(A)$", "data": data}

    def get_currents_in_area(self, top_left_x=0, top_left_y=0, bottom_right_x=0, bottom_right_y=0, *args, **kwargs):
        step_size = self.get_data('step_size')
        top_left_x_index = round(top_left_x/step_size)
        top_left_y_index = round(top_left_y/step_size)
        bottom_right_x_index = round(bottom_right_x/step_size)
        bottom_right_y_index = round(bottom_right_y/step_size)

        if top_left_x_index == 0 and top_left_y_index ==0 and bottom_right_x_index == 0 and bottom_right_y_index == 0:
            raise ValueError("LBIC Processor: No area defined")

        currents = self.get_data("current")
        # currents_in_area = []
        # for row in currents[top_left_y_index:bottom_right_y_index+1]:
        #     current_row = []
        #     for item in row[top_left_x_index:bottom_right_x_index+1]:
        #         current_row.append(item)
        #     currents_in_area.append(current_row)
        currents_in_area = [current_row[top_left_x_index:bottom_right_x_index+1] for current_row in currents[top_left_y_index:bottom_right_y_index+1]]
        return {"units": "$Current ~(A)$", "data": currents_in_area}

    def get_vertical_profile(self, *args, **kwargs):
        raise NotImplementedError

    def get_diagonal_profile(self, *args, **kwargs):
        raise NotImplementedError
