from contracts.data_processors import DataProcessor
from contracts.plotter_options import PlotterOptions
from implementations.utils.plot_preppers.heatmap_prepper import heatmap_prepper, add_profiles
from contracts.plotter import Plotter
import plotly.graph_objects as go
from utils.logging import decorate_class_with_logging, DEBUG_PLOTTER


@decorate_class_with_logging(log_level=DEBUG_PLOTTER)
class HeatmapPlotter(Plotter):
    """Plotter that renders a single observable as a Plotly heatmap.

    Typical call order is ``ready_plot`` -> (optionally) ``set_options`` ->
    ``draw_plot``. ``draw_plot`` requires that ``ready_plot`` has stored a
    processor first.
    """

    def __init__(self, title: str, observable: str, options: PlotterOptions):
        """Create the plotter and register the default z-range options.

        Args:
            title: Text used as the figure title.
            observable: Name of the observable requested from the processor
                for the heatmap's z values.
            options: Options container; ``zrange`` and ``zauto`` defaults are
                added to it here.
        """
        # Colour scale selected through set_options(); None means "not chosen
        # yet", in which case draw_plot() uses the "turbid" scale.
        self.color = None
        self.options = options
        self.observable = observable

        # Default range behaviour
        self.options.add_option(label="zrange", value=[0, 0])
        self.options.add_option(label="zauto", value=True)

        self.title = title
        # Single-entry dict of {label: DataProcessor}, set by ready_plot().
        self.image_processor = None
        self.fig = go.Figure()

    def ready_plot(self, processor: dict[str, DataProcessor], options: PlotterOptions) -> None:
        """Prepare the figure layout and store the processor to be drawn.

        Args:
            processor: Dict that must contain exactly one ``DataProcessor``.
            options: Options container queried for ``legend_title``.

        Raises:
            ValueError: If ``processor`` does not hold exactly one entry, or
                if that entry is not a ``DataProcessor`` instance.
        """
        self.fig = heatmap_prepper(self.fig)
        self.fig.update_layout(
            title={'text': self.title},
            legend_title=options.get_option("legend_title"),
        )

        if len(processor) != 1:
            raise ValueError(f"Must pass exactly one processor, passed {len(processor)} processors instead")

        this_processor = next(iter(processor.values()))
        if not isinstance(this_processor, DataProcessor):
            raise ValueError(f"Processor dict {processor}, contains value of incorrect type "
                             f"{type(this_processor)}, expected instance of DataProcessor")

        # The whole dict is kept (not just the value); draw_plot() unpacks it.
        self.image_processor = processor

    # FIXME: Make part of ready_plot
    def set_options(self, zrange: list = None, colour: str = 'turbid', profiles: bool = None) -> None:
        """Configure the z-range, colour scale and optional profiles.

        Args:
            zrange: Two-element ``[zmin, zmax]`` sequence. When given, it is
                stored as the ``zrange`` option and ``zauto`` is switched off.
            colour: Colour scale name used when the heatmap is drawn.
            profiles: When truthy, ``add_profiles`` is applied to the figure.
        """
        # Set range
        if zrange is not None:
            self.options.update_option(label="zrange", value=zrange)
            self.options.update_option(label="zauto", value=False)

        # Set colour. The value is also kept on the instance so that
        # draw_plot() actually honours it instead of a hard-coded scale.
        # NOTE(review): add_option is used here (not update_option); confirm
        # it tolerates set_options() being called more than once.
        self.options.add_option(label="colour", value=colour)
        self.color = colour

        # Add profiles if requested
        if profiles:
            add_profiles(self.fig)

    def draw_plot(self) -> None:
        """Add the heatmap trace for the stored processor and show the figure.

        Raises:
            RuntimeError: If ``ready_plot`` has not stored a processor yet.
        """
        if self.image_processor is None:
            raise RuntimeError("No processor set; call ready_plot() before draw_plot()")

        processor = next(iter(self.image_processor.values()))
        selected_zrange = self.options.get_option("zrange")
        self.fig.update_layout(yaxis_title=processor.get_units("y_axis"))
        self.fig.add_trace(
            go.Heatmap(
                x=processor.get_data("x_axis"),
                y=processor.get_data("y_axis"),
                z=processor.get_data(observable=self.observable),
                zauto=self.options.get_option("zauto"),
                zmin=selected_zrange[0],
                zmax=selected_zrange[1],
                # Use the colour chosen via set_options(); default to "turbid"
                # when set_options() was never called.
                colorscale=self.color or "turbid",
                reversescale=True,
                colorbar=dict(title=processor.get_units("current")),
                # hovertemplate won't work with units as currently defined
                hovertemplate='x (mm): %{x}<br>y (mm): %{y}<br>I (A): %{z}<extra></extra>'
            )
        )
        self.fig.show()
