from copy import deepcopy
import logging
from pathlib import Path
from threading import Lock
from typing import List

from pywps import ComplexInput, Process
from pywps.app.exceptions import ProcessError
import xarray as xr
from xclim.subset import subset_bbox, subset_gridpoint

from finch.processes.utils import dataset_to_netcdf

from . import wpsio
from .utils import (
    RequestInputs,
    process_threaded,
    single_input_or_none,
    try_opendap,
    write_log,
)

# Module logger. It uses the fixed "PYWPS" logger name rather than __name__,
# presumably so these messages follow pywps's own logging configuration.
LOGGER = logging.getLogger("PYWPS")


def _parse_coordinates(name: str, value) -> List[float]:
    """Convert a coordinate request value into a list of floats.

    `value` is either a string of one or more comma separated numbers, or a
    non-string value that is passed to ``float`` as a single coordinate.

    Raises:
        ProcessError: if any entry cannot be converted to a float.
    """
    try:
        parts = value.split(",")
    except AttributeError:
        # No `split` method: not a string, so treat it as one scalar coordinate.
        parts = [value]
    try:
        return [float(part) for part in parts]
    except (TypeError, ValueError) as err:
        raise ProcessError(
            f"Invalid value for {name}: expected a number or comma separated numbers"
        ) from err


def finch_subset_gridpoint(
    process: Process, netcdf_inputs: List[ComplexInput], request_inputs: RequestInputs
) -> List[Path]:
    """Parse wps `request_inputs` based on their name and subset `netcdf_inputs`.

    The expected names of the inputs are as follows (taken from `wpsio.py`):
     - lat: Latitude coordinate, can be a comma separated list of floats
     - lon: Longitude coordinate, can be a comma separated list of floats
     - start_date: Initial date for temporal subsetting.
     - end_date: Final date for temporal subsetting.
    Optional "variable" inputs restrict which variables of each dataset are kept.

    Each dataset is subset with xclim's `subset_gridpoint` once per (lon, lat)
    pair. The per-point results are concatenated along a "region" dimension.
    Datasets whose subset is empty are skipped with a warning.

    Returns:
        Paths of the netCDF files written to `process.workdir`, each named
        ``<input stem>_sub<input suffix>``. Their order follows completion
        order inside `process_threaded` and may not match `netcdf_inputs`.
        TODO confirm against process_threaded.

    Raises:
        ProcessError: if lat/lon cannot be parsed as numbers, or if they do
            not contain the same number of values.
    """
    longitudes = _parse_coordinates(
        wpsio.lon.identifier, request_inputs[wpsio.lon.identifier][0].data
    )
    latitudes = _parse_coordinates(
        wpsio.lat.identifier, request_inputs[wpsio.lat.identifier][0].data
    )
    # zip() below would silently drop unmatched coordinates, so reject mismatches.
    if len(longitudes) != len(latitudes):
        raise ProcessError(
            f"lat and lon must have the same number of values, "
            f"got {len(latitudes)} and {len(longitudes)}"
        )

    start_date = single_input_or_none(request_inputs, wpsio.start_date.identifier)
    end_date = single_input_or_none(request_inputs, wpsio.end_date.identifier)
    variables = [r.data for r in request_inputs.get("variable", [])]

    # if not subsetting by time, it's not necessary to decode times
    time_subset = start_date is not None or end_date is not None

    n_files = len(netcdf_inputs)
    count = 0

    output_files = []

    # Guards the shared progress counter when files are handled concurrently.
    lock = Lock()

    def _subset(resource: ComplexInput):
        nonlocal count

        dataset = try_opendap(resource, decode_times=time_subset)

        with lock:
            count += 1
            write_log(
                process,
                f"Subsetting file {count} of {n_files}",
                subtask_percentage=(count - 1) * 100 // n_files,
            )

        dataset = dataset[variables] if variables else dataset

        subsets = []
        for longitude, latitude in zip(longitudes, latitudes):
            subset = subset_gridpoint(
                dataset,
                lon=longitude,
                lat=latitude,
                start_date=start_date,
                end_date=end_date,
            )
            subsets.append(subset)

        subsetted = xr.concat(subsets, dim="region")

        # Any zero-length dimension means there is nothing worth writing.
        if not all(subsetted.dims.values()):
            LOGGER.warning(f"Subset is empty for dataset: {resource.url}")
            return

        p = Path(resource._file or resource._build_file_name(resource.url))
        output_filename = Path(process.workdir) / (p.stem + "_sub" + p.suffix)

        dataset_to_netcdf(subsetted, output_filename)

        output_files.append(output_filename)

    process_threaded(_subset, netcdf_inputs)

    return output_files


def finch_subset_bbox(
    process: Process, netcdf_inputs: List[ComplexInput], request_inputs: RequestInputs
) -> List[Path]:
    """Subset every dataset in `netcdf_inputs` to a lat/lon bounding box.

    The box and time bounds are read from `request_inputs` using the
    identifiers declared in `wpsio.py`:
     - lon0, lon1: longitude bounds of the box
     - lat0, lat1: latitude bounds of the box (lat1 and lon1 must be given
       together or omitted together)
     - start_date, end_date: optional bounds for temporal subsetting
    Optional "variable" inputs restrict which variables of each dataset are kept.

    Datasets whose subset is empty are skipped with a warning.

    Returns:
        Paths of the netCDF files written to `process.workdir`, each named
        ``<input stem>_sub<input suffix>``.

    Raises:
        ProcessError: if exactly one of lat1 / lon1 is provided.
    """
    lon0 = single_input_or_none(request_inputs, wpsio.lon0.identifier)
    lat0 = single_input_or_none(request_inputs, wpsio.lat0.identifier)
    lon1 = single_input_or_none(request_inputs, wpsio.lon1.identifier)
    lat1 = single_input_or_none(request_inputs, wpsio.lat1.identifier)
    start_date = single_input_or_none(request_inputs, wpsio.start_date.identifier)
    end_date = single_input_or_none(request_inputs, wpsio.end_date.identifier)
    variables = [item.data for item in request_inputs.get("variable", [])]

    # Exactly one of the two missing is an error: they form a pair.
    if (lat1 is None) != (lon1 is None):
        raise ProcessError("lat1 and lon1 must be both omitted or provided")

    total = len(netcdf_inputs)
    # Times only need decoding when a temporal bound was requested.
    decode_times = not (start_date is None and end_date is None)

    written: List[Path] = []
    progress_lock = Lock()
    started = 0

    def _subset_one(resource):
        nonlocal started

        ds = try_opendap(resource, decode_times=decode_times)

        # Shared counter for progress reporting; updated under the lock.
        with progress_lock:
            started += 1
            write_log(
                process,
                f"Subsetting file {started} of {total}",
                subtask_percentage=(started - 1) * 100 // total,
            )

        if variables:
            ds = ds[variables]

        result = subset_bbox(
            ds,
            lon_bnds=[lon0, lon1],
            lat_bnds=[lat0, lat1],
            start_date=start_date,
            end_date=end_date,
        )

        # A zero-length dimension anywhere means the subset holds no data.
        if not all(result.dims.values()):
            LOGGER.warning(f"Subset is empty for dataset: {resource.url}")
            return

        source = Path(resource._file or resource._build_file_name(resource.url))
        target = Path(process.workdir) / f"{source.stem}_sub{source.suffix}"

        dataset_to_netcdf(result, target)
        written.append(target)

    process_threaded(_subset_one, netcdf_inputs)
    return written
