from implementations.utils.calc.datetime_difference_to_hours import datetime_difference_to_hours
# from implementations.utils.create_custom_colormap import underlying_mechanism_cmap
from contracts.data_processors import DataProcessor
from implementations.utils.plot_preppers.show_colourscale import show_colourscale
from implementations.utils.plot_preppers.scatter_prep import scatter_prepper
from implementations.utils.plot_preppers.export_to_svg import get_svg_config
from contracts.plotter_options import PlotterOptions
from contracts.plotter import Plotter
# from implementations.utils.get_colour import get_colour
from implementations.utils.plotly_colour_helpers.get_plotly_colour import get_plotly_colour
import plotly.graph_objects as go
import plotly.colors
from utils.logging import decorate_class_with_logging, DEBUG_PLOTTER


@decorate_class_with_logging(log_level=DEBUG_PLOTTER)
class ScatterDataPlotter(Plotter):
    """
    ScatterDataPlotter
    ==================

    A Plotly-based line scatter plotter for generic x–y datasets. This class
    implements the `Plotter` interface and draws one trace per
    `DataProcessor` instance.

    Overview
    --------
    `ScatterDataPlotter`:

    - Plots a chosen x observable versus a chosen y observable
    - Accepts a dictionary of `DataProcessor` objects keyed by label
    - Uses `PlotterOptions` to set the legend title and line/marker styling
    - Supports a "presentation" mode with thicker lines
    - Supports "time_evolved" plots by switching to the "Magenta"
      colourscale and colouring each trace by its elapsed time as a
      fraction of the largest elapsed time across all traces

    Each trace is coloured from the "colours" option (keyed by label) when
    one is assigned and none of the assigned colours is None; otherwise it
    takes a colour from the configured colourscale based on its index (or
    its elapsed-time fraction in time-evolved mode).

    Options
    -------
    - Read from the options passed to `ready_plot`: "experiment_datetime",
      "legend_title", "mode" (defaulted to "lines" there if missing).
    - Read from the constructor options: "presentation", "time_evolved",
      "colours", "reverse_x_axis", "mode".
    - Managed by the plotter itself: "line", "marker", "colourscale",
      "time_range".

    NOTE(review): the two options objects are presumably the same instance
    in normal use — confirm against callers.

    Usage Notes
    -----------
    Call `ready_plot` to prepare the figure and configure options, then
    optionally call `set_axes_titles` to override axis labels. Finally,
    call `draw_plot` to add all traces and display the figure using the SVG
    export configuration.
    """

    def __init__(self, title, x_observable: str, y_observable: str, options: PlotterOptions):
        """
        :param title: Figure title shown at the top of the plot.
        :param x_observable: Name of the observable plotted on the x axis.
        :param y_observable: Name of the observable plotted on the y axis.
        :param options: Plotter options; default "line", "marker" and
            "colourscale" entries are added to it here.
        """
        # Kept for backward compatibility; the active time range is stored in
        # the "time_range" option by `ready_plot`.
        self.time_range = None
        self.title = title
        self.fig = go.Figure()
        self.x_observable = x_observable
        self.y_observable = y_observable

        self.data_processors = None
        self.experiment_time = None
        self.titles_set = False

        self.options = options
        self.options.add_option(label="line", value={"width": 1})
        self.options.add_option(label="marker", value={"size": 1})
        self.options.add_option(label='colourscale', value=plotly.colors.qualitative.Plotly)

    def ready_plot(self, data_processors: dict[str, DataProcessor], options: PlotterOptions):
        """
        Prepare the figure and resolve plot options before drawing.

        :param data_processors: Data processors keyed by label; one trace is
            drawn per processor by `draw_plot`.
        :param options: Options supplying "experiment_datetime",
            "legend_title" and "mode". If no "mode" is present, "lines" is
            added to it.
        """
        self.fig = scatter_prepper(self.fig)
        self.fig.update_layout(
            title={'text': self.title},
        )

        # Reference datetime passed to processors when asking for elapsed time
        self.experiment_time = options.get_option("experiment_datetime")

        # If no mode is specified in options, assume lines only
        if not options.has_options('mode'):
            options.add_option(label='mode', value='lines')

        # Show a titled legend only when a legend title was supplied
        legend_title = options.get_option("legend_title")
        if legend_title:
            self.fig.update_layout(legend_title=legend_title)
        else:
            self.fig.update_layout(showlegend=False)

        self.data_processors = data_processors

        # Line should be thicker for presentations
        line_width = 5 if self.options.get_option("presentation") else 1
        self.options.update_option(label="line", value={"width": line_width})

        # Time-evolved data uses a continuous colourscale, normalised by the
        # largest elapsed time across all processors.
        if self.options.get_option("time_evolved"):
            self.options.update_option(label="colourscale", value=plotly.colors.get_colorscale("Magenta"))

            times = [
                self.data_processors[lbl].get_data("elapsed_time", experiment_datetime=self.experiment_time)
                for lbl in self.data_processors
            ]
            # max() of an empty sequence raises; with no processors there is
            # nothing to normalise, so "time_range" is left unset.
            if times:
                self.options.add_option(label="time_range", value=max(times))

        # TODO: Inverting the colourmap for time-evolved data was previously
        #  done manually ("Magenta_r" / underlying_mechanism_cmap) and needs
        #  to be handled properly.
        # TODO: Support the "normalise" option (plot the
        #  "<normalise>_normalised" y observable) and a "vertical_gridlines"
        #  option (major and minor x-axis gridlines).

    def set_axes_titles(self, x_title, y_title):
        """
        Set the x and y axis titles and mark them as externally set, so that
        `draw_plot` does not overwrite them with observable units.
        """
        self.fig.update_layout(
            xaxis_title=x_title,
            yaxis_title=y_title,
        )
        self.titles_set = True

    def _elapsed_time_fraction(self, processor):
        """
        Return the processor's elapsed time divided by the "time_range" option.

        :raises TypeError: If the fraction cannot be computed (for example
            when "time_range" was never set).
        """
        try:
            elapsed = processor.get_data("elapsed_time", experiment_datetime=self.experiment_time)
            return elapsed / self.options.get_option("time_range")
        except ZeroDivisionError:
            # The largest elapsed time is zero, so there is no range to
            # normalise over; place the trace at the start of the colourscale.
            return 0.0
        except TypeError as te:
            raise TypeError(f"Error when trying to compute percentage of elapsed time. {te}") from te

    def draw_plot(self, *args, **kwargs):
        """
        Add one trace per data processor and display the figure.

        Positional and keyword arguments are forwarded to each processor's
        ``get_data`` calls for the x/y observables and for "range".

        :raises TypeError: In time-evolved mode, if a trace's elapsed-time
            fraction cannot be computed.
        """
        # FEATURE REQUEST: Draw plots with errors

        # These options do not change while traces are added, so read them once.
        # Fall back to lines only (the default chosen in ready_plot) in case
        # the mode was added to a different options object.
        mode = self.options.get_option("mode") or "lines"
        time_evolved = self.options.get_option("time_evolved")
        # A missing "colours" option means no per-label colours were assigned.
        colours = self.options.get_option("colours") or {}
        # Assigned colours are ignored entirely if any of them is None.
        use_assigned_colours = None not in colours.values()

        last_processor = None
        for index, lbl in enumerate(self.data_processors):
            processor = self.data_processors[lbl]
            last_processor = processor

            # Colourscale position: elapsed-time fraction or trace index
            value = self._elapsed_time_fraction(processor) if time_evolved else index

            if use_assigned_colours and lbl in colours:
                current_colour = colours[lbl]
            else:
                current_colour = get_plotly_colour(self.options.get_option("colourscale"), value)

            # Apply the colour to markers and/or lines depending on the mode
            if 'marker' in mode:
                marker = self.options.get_option("marker")
                marker['color'] = current_colour
                marker["colorscale"] = self.options.get_option("colourscale")
                self.options.update_option(label="marker", value=marker)
            if 'line' in mode:
                line = self.options.get_option("line")
                line['color'] = current_colour
                self.options.update_option(label="line", value=line)

            self.fig.add_trace(go.Scatter(
                x=processor.get_data(self.x_observable, *args, **kwargs),
                y=processor.get_data(self.y_observable, *args, **kwargs),
                # Elapsed time (via datetime_difference_to_hours) as customdata for the colourbar
                customdata=[datetime_difference_to_hours(
                    processor.get_data("elapsed_time", experiment_datetime=self.experiment_time)
                )],
                mode=mode,
                name=processor.get_data('label'),
                marker=self.options.get_option("marker"),
                line=self.options.get_option("line"),
            ))

        # Nothing was drawn, so there are no units or ranges to derive.
        if last_processor is not None:
            # Pin the x-axis range supplied by the data. Only the last
            # processor's range ever took effect, since each update overwrote
            # the previous one.
            # NOTE(review): presumably "range" is given in descending order to
            # reverse the axis — confirm.
            if self.options.get_option("reverse_x_axis"):
                self.fig.update_layout(
                    xaxis=dict(
                        range=last_processor.get_data("range", *args, **kwargs),
                    )
                )

            # Grab axis titles from observables if they have not yet been externally set
            if not self.titles_set:
                self.set_axes_titles(
                    last_processor.get_units(self.x_observable),
                    last_processor.get_units(self.y_observable)
                )

        # Show a colourbar for time evolved data
        if time_evolved:
            self.fig = show_colourscale(self.options.get_option("colourscale"), self.fig)

        self.fig.show(config=get_svg_config())
