from contracts.data_processors import DataProcessorCore
from implementations.data.data_types.iv_scatter import IVData
import implementations.utils.calc.iv_calc as iv_calc
from implementations.utils.calc.is_monotonic import is_monotonic
from utils.errors.errors import VocNotFoundError, IscNotFoundError, ObservableNotComputableError
import warnings


class IVScatterDataProcessor(DataProcessorCore):
    """
    Derive IV-curve observables from an IVData object.

    Every observable name registered in ``self._processing_functions`` maps to a method that
    returns a ``{"units": ..., "data": ...}`` dict. Observables are requested through
    ``self.get_data(name)``, which is provided by the parent class (the results are meant to be
    stored in ``self.processed_data``; exact caching behaviour lives in DataProcessorCore).
    """

    def __init__(self, iv_data: IVData):
        self.data = iv_data

        self._processing_functions = {
            "voltage": self.get_voltage,
            "current": self.get_current,
            "power": self.calculate_power,
            "forward_voltage": self.get_forward_voltage,
            "forward_current": self.get_forward_current,
            "forward_power": self.get_forward_power,
            "reverse_voltage": self.get_reverse_voltage,
            "reverse_current": self.get_reverse_current,
            "reverse_power": self.get_reverse_power,
            "truncated_voltage": self.get_truncated_voltage,
            "truncated_current": self.get_truncated_current,
            "truncated_power": self.get_truncated_power,
            "current_difference": self.get_current_difference,
            "power_difference": self.get_power_difference,
            "isc": self.find_isc,
            "voc": self.find_voc,
            "mpp_power": self.calculate_mpp_power,
            "mpp_voltage": self.calculate_mpp_voltage,
            "mpp_current": self.calculate_mpp_current,
            "mpp_resistance": self.calculate_mpp_resistance,
            "fill_factor": self.calculate_fill_factor,
            "series_resistance": self.calculate_series_resistance,
            "shunt_resistance": self.calculate_shunt_resistance,
            "parameters": self.get_parameters,
            "elapsed_time": self.elapsed_time
        }

        # Flag needed in case forward and reverse need to be concatenated.
        # None  -> get_voltage() has not concatenated the sweeps yet.
        # True  -> the reverse sweep was flipped before being appended to the forward sweep.
        # False -> the reverse sweep was appended as-is.
        # get_current() uses it to order the currents the same way as the voltages.
        self._voltage_reversed = None

        # Is filled by parent class_utils function get.data()
        self.processed_data = dict.fromkeys(self._processing_functions)

        # Live view of the observable names this processor can compute.
        self._processed_observables = self.processed_data.keys()

    # CHECK: Is this really necessary?  Ignores Datatype observables
    # def get_allowed_observables(self):
    #     return self._processed_observables

    def validate_observables(self, *args):
        """
        Check whether all desired observables can be obtained for this data.

        Raises:
            ObservableNotComputableError: if computing an observable raises VocNotFoundError.
        Observables whose computation raises ValueError only produce a warning.
        """
        for observable in args:
            try:
                self.get_data(observable)
            except VocNotFoundError as err:
                raise ObservableNotComputableError from err
            except ValueError:
                warnings.warn(f'Observable(s) {observable} could not be computed for file {self.get_data("label")}')

    def get_voltage(self) -> dict:
        """
        This will get called if the datareader does not provide voltage data.
        I.e. it should mean it only provides forward and reverse curves separately.

        Concatenates the forward and reverse voltages. An increasing reverse sweep is flipped
        first, and ``self._voltage_reversed`` records whether that happened.

        Raises:
            ValueError: if either sweep is non-monotonic, or the two sweeps have different units.
        """
        forward = self.get_data("forward_voltage")
        reverse = self.get_data("reverse_voltage")

        if not (is_monotonic(forward) and is_monotonic(reverse)):
            raise ValueError("IVScatterDataProcessor: Trying to concatenate non-monotonic voltages")

        voltages = list(forward)
        # Reverse if increasing
        if is_monotonic(reverse, increasing=True):
            voltages += reverse[::-1]
            self._voltage_reversed = True
        else:
            voltages += reverse
            self._voltage_reversed = False

        forward_units = self.get_units("forward_voltage")
        reverse_units = self.get_units("reverse_voltage")

        if not forward_units == reverse_units:
            raise ValueError("IVScatterDataProcessor: Incompatible units for forward and reverse voltages")
        return {"units": forward_units, "data": voltages}

    def get_current(self) -> dict:
        """
        Concatenate the forward and reverse currents in the same order used by get_voltage().

        Raises:
            ValueError: if the forward and reverse currents have different units.
        """
        # Check voltages first, so the ordering flag is set.
        if self._voltage_reversed is None:
            self.get_data("voltage")
        # Copy before extending: growing the list returned by get_data() in place would
        # corrupt the stored "forward_current" observable.
        currents = list(self.get_data("forward_current"))
        reverse = self.get_data("reverse_current")

        currents += reverse[::-1] if self._voltage_reversed else reverse

        forward_units = self.get_units("forward_current")
        reverse_units = self.get_units("reverse_current")

        if not forward_units == reverse_units:
            raise ValueError("IVScatterDataProcessor: Incompatible units for forward and reverse currents")
        return {"units": forward_units, "data": currents}

    def calculate_power(self) -> dict:
        """
        Return the point-wise product current * voltage.

        Reads current and voltage directly from the raw IVData (``self.data``), not
        through ``self.get_data``.
        """
        # NOTE(review): this bypasses self.get_data, so it relies on the datareader providing
        # "current"/"voltage" itself -- confirm that is intended.
        current = self.data.get_data("current")
        voltage = self.data.get_data("voltage")
        power = [i * v for i, v in zip(current, voltage)]

        return {"units": "Power (W)", "data": power}

    def get_forward_voltage(self) -> dict:
        """Return the forward-sweep voltages from _split_iv_curve()."""
        forward_voltage, _, _, _ = self._split_iv_curve()
        return {"units": "Voltage (V)", "data": forward_voltage}

    def get_forward_current(self) -> dict:
        """Return the forward-sweep currents from _split_iv_curve()."""
        _, _, forward_current, _ = self._split_iv_curve()
        return {"units": "Current (A)", "data": forward_current}

    def get_forward_power(self) -> dict:
        """Return the point-wise product of forward voltage and forward current."""
        voltage = self.get_data("forward_voltage")
        current = self.get_data("forward_current")
        power = [v * i for v, i in zip(voltage, current)]
        return {"units": "Power (W)", "data": power}

    def get_reverse_voltage(self) -> dict:
        """Return the reverse-sweep voltages from _split_iv_curve()."""
        _, reverse_voltage, _, _ = self._split_iv_curve()
        return {"units": "Voltage (V)", "data": reverse_voltage}

    def get_reverse_current(self) -> dict:
        """Return the reverse-sweep currents from _split_iv_curve()."""
        _, _, _, reverse_current = self._split_iv_curve()
        return {"units": "Current (A)", "data": reverse_current}

    def get_reverse_power(self) -> dict:
        """Return the point-wise product of reverse voltage and reverse current."""
        voltage = self.get_data("reverse_voltage")
        current = self.get_data("reverse_current")
        power = [v * i for v, i in zip(voltage, current)]
        return {"units": "Power (W)", "data": power}

    def get_truncated_voltage(self) -> dict:
        """Return the forward voltage trimmed by iv_calc.trim_iv using Isc and Voc."""
        isc = self.get_data("isc")
        voc = self.get_data("voc")
        voltage = self.get_data("forward_voltage")
        current = self.get_data("forward_current")

        trimmed_voltage = iv_calc.trim_iv(voltage, current, voltage, isc, voc)
        return {"units": "Voltage (V)", "data": trimmed_voltage}

    def get_truncated_current(self) -> dict:
        """Return the forward current trimmed by iv_calc.trim_iv using Isc and Voc."""
        isc = self.get_data("isc")
        voc = self.get_data("voc")
        voltage = self.get_data("forward_voltage")
        current = self.get_data("forward_current")

        trimmed_current = iv_calc.trim_iv(voltage, current, current, isc, voc)
        return {"units": "Current (A)", "data": trimmed_current}

    def get_truncated_power(self) -> dict:
        """
        Truncate the forward power curve to between Isc and Voc.

        Raises:
            ObservableNotComputableError: if iv_calc.trim_iv raises IndexError.
        """
        isc = self.get_data("isc")
        voc = self.get_data("voc")
        voltage = self.get_data("forward_voltage")
        current = self.get_data("forward_current")
        power = self.get_data("forward_power")
        try:
            trimmed_power = iv_calc.trim_iv(voltage, current, power, isc, voc)
        except IndexError as err:
            raise ObservableNotComputableError(
                "Error in power truncation, issue probably related to dark measurement") from err

        return {"units": "Power (W)", "data": trimmed_power}

    def get_current_difference(self) -> dict:
        """Return forward current minus the reversed reverse current, paired point-wise."""
        fw_current = self.get_data("forward_current")
        rv_current = self.get_data("reverse_current")
        return {"units": "Current (A)", "data": [i - j for i, j in zip(fw_current, rv_current[::-1])]}

    def get_power_difference(self) -> dict:
        """Return forward power minus the reversed reverse power, paired point-wise."""
        fw_power = self.get_data("forward_power")
        rv_power = self.get_data("reverse_power")
        return {"units": "Power (W)", "data": [p - q for p, q in zip(fw_power, rv_power[::-1])]}

    def find_isc(self) -> dict:
        """
        The short-circuit current is the current at zero voltage. This is the y-crossing in an IV curve.

        Raises:
            IscNotFoundError: if iv_calc.find_crossing raises IndexError.
        """
        current = self.get_data("forward_current")
        voltage = self.get_data("forward_voltage")
        try:
            isc = iv_calc.find_crossing(voltage, current)
        except IndexError as err:
            raise IscNotFoundError from err
        # CHECK: Is there any physical check that can be placed on Isc?
        return {"units": "Current (A)", "data": isc}

    def find_voc(self) -> dict:
        """
        The open-circuit voltage is the voltage where there is no current. This is the x-crossing in an IV curve.
        Equivalently, one can find the y-crossing in a VI curve.

        Raises:
            VocNotFoundError: if no crossing is found, or the crossing is negative.
        """
        current = self.get_data("forward_current")
        voltage = self.get_data("forward_voltage")
        try:
            voc = iv_calc.find_crossing(current, voltage)
        except IndexError as err:
            raise VocNotFoundError("Voc could not be determined from the data") from err
        if voc < 0:
            raise VocNotFoundError("Voc was found to be negative")
        return {"units": "Voltage (V)", "data": voc}

    def calculate_mpp_power(self) -> dict:
        """
        Return the truncated-power value with the largest magnitude (sign preserved).

        Raises:
            ObservableNotComputableError: if the truncated power is empty.
        """
        power = self.get_data("truncated_power")

        if len(power) < 1:
            raise ObservableNotComputableError("Truncated power is empty")

        # max() returns the first element with the largest |p|, keeping its sign.
        return {"units": "Power (W)", "data": max(power, key=abs)}

    def calculate_mpp_voltage(self) -> dict:
        """
        Return the voltage at the first point whose power equals the MPP power.

        Raises ValueError (from list.index) if that power value does not occur in "power".
        """
        max_power = self.get_data("mpp_power")
        power = self.get_data("power")
        voltage = self.get_data("voltage")

        return {"units": "Voltage (V)", "data": voltage[power.index(max_power)]}

    def calculate_mpp_current(self) -> dict:
        """
        Return the current at the first point whose power equals the MPP power.

        Raises ValueError (from list.index) if that power value does not occur in "power".
        """
        max_power = self.get_data("mpp_power")
        power = self.get_data("power")
        current = self.get_data("current")

        return {"units": "Current (A)", "data": current[power.index(max_power)]}

    def calculate_mpp_resistance(self) -> dict:
        """
        Return |V_mpp / I_mpp|.

        Raises:
            ObservableNotComputableError: if the MPP current is zero.
        """
        mpp_current = self.get_data("mpp_current")
        mpp_voltage = self.get_data("mpp_voltage")
        try:
            resistance = abs(mpp_voltage / mpp_current)
        except ZeroDivisionError as err:
            raise ObservableNotComputableError("MPP current is zero, MPP resistance is undefined") from err
        return {"units": r"Resistance (\Omega)", "data": resistance}

    def calculate_fill_factor(self) -> dict:
        """
        Return P_mpp / (Voc * Isc). Warns if the result falls outside [0, 1].

        Raises:
            ObservableNotComputableError: if Voc * Isc is zero.
        """
        voc = self.get_data("voc")
        isc = self.get_data("isc")
        max_power = self.get_data("mpp_power")

        try:
            ff = max_power/(voc * isc)
        except ZeroDivisionError as err:
            raise ObservableNotComputableError("Voc or Isc are too small throwing a div by zero error") from err

        if ff < 0.0 or 1.0 < ff:
            warnings.warn(f"Fill factor is {ff} and should be in the [0, 1] interval")

        return {"units": "Fill ~factor", "data": ff}

    def calculate_series_resistance(self) -> dict:
        """Return iv_calc.find_local_slope(voltage, current, 0) on the full IV curve."""
        current = self.get_data("current")
        voltage = self.get_data("voltage")
        return {"units": r"Resistance (\Omega)", "data": iv_calc.find_local_slope(voltage, current, 0)}

    def calculate_shunt_resistance(self) -> dict:
        """
        Return 1 / iv_calc.find_local_slope(current, voltage, 0) on the full IV curve.

        Raises:
            ObservableNotComputableError: if that local slope is zero.
        """
        current = self.get_data("current")
        voltage = self.get_data("voltage")
        slope = iv_calc.find_local_slope(current, voltage, 0)
        try:
            resistance = 1 / slope
        except ZeroDivisionError as err:
            raise ObservableNotComputableError("Local slope is zero, shunt resistance is undefined") from err
        return {"units": r"Resistance (\Omega)", "data": resistance}

    def get_parameters(self) -> dict:
        """Collect the label and the scalar IV parameters into a single dict."""
        return {"units": "N/A", "data": {
                "label": self.get_data("label"),
                "isc": self.get_data("isc"),
                "voc": self.get_data("voc"),
                "fill_factor": self.get_data("fill_factor"),
                "mpp_power": self.get_data("mpp_power"),
                "mpp_voltage": self.get_data("mpp_voltage"),
                "mpp_current": self.get_data("mpp_current"),
                "mpp_resistance": self.get_data("mpp_resistance"),
                "rsh": self.get_data("shunt_resistance"),
                "rs": self.get_data("series_resistance")
            }
        }

    def _split_iv_curve(self):
        """
        Split the raw IV data into forward/reverse sweeps.

        Returns:
            (forward_voltage, reverse_voltage, forward_current, reverse_current), as produced
            by iv_calc.split_forward_reverse.
        Raises:
            ObservableNotComputableError: if the curve cannot be split (ValueError from iv_calc).
        """
        voltage = self.data.get_data("voltage")
        current = self.data.get_data("current")

        # Throw error if the iv-curve cannot be split
        try:
            return iv_calc.split_forward_reverse(voltage, current)
        except ValueError as err:
            raise ObservableNotComputableError from err

    def elapsed_time(self, *args, **kwargs):
        """
        Return this data's "datetime" minus the ``experiment_datetime`` keyword argument.

        Raises KeyError if ``experiment_datetime`` is not passed.
        """
        # The reference timestamp comes from **kwargs (positional args are ignored).
        reference_datetime = kwargs["experiment_datetime"]
        data_datetime = self.get_data("datetime")
        return {"units": "$Elapsed ~time ~(hrs)$", "data": data_datetime - reference_datetime}
