import os
import tempfile
import uuid

import xarray as xr     # noqa
import pandas as pd # noqa
from rasterio.crs import CRS    # noqa

from main.models import DataVariable, WmsLayer, WmsLayerLegend    # noqa
from main.lib.s3_utils import upload_file, remove_file # noqa

from . import File, FileCollection, Handler

# Size limit handed to is_file_too_large(): 2 * 10**9, i.e. 2 GB in decimal
# units (presumably bytes -- confirm against is_file_too_large).
MAX_FILE_SIZE = 2_000_000_000


class NetcdfHandler(Handler):

    def create_layer(self, file: File):
        """Create or refresh one WMS layer per variable of a NetCDF file.

        Files above ``MAX_FILE_SIZE`` are logged and skipped. Processing is
        also skipped as soon as one already registered layer of this file has
        a ``last_upload`` that the file is not newer than.

        For every variable in ``file.netcdf_variables`` the matching
        ``WmsLayer`` is fetched or a new one is built. The variable is then
        converted through ``_process_netcdf`` and published: as an image
        mosaic with one granule per date when ``file.netcdf_dates`` is not
        empty, otherwise as a single COG coverage store. Finally the layer's
        ``last_upload`` is set to ``file.last_modified`` and the layer is
        saved.
        """
        if self.is_file_too_large(file, MAX_FILE_SIZE):
            self.log('File is too large: ' + file.absolute_path)
            return

        # One up-to-date layer is enough to consider the whole file unchanged.
        for known_layer in WmsLayer.objects.filter(file_path=file.absolute_path):
            if not known_layer.last_upload:
                continue
            if not file.is_newer_than(known_layer.last_upload):
                print('No update for ' + file.absolute_path)
                return

        workspace = file.bucket.name

        file.read_netcdf_data(self.fs)

        for var_name in file.netcdf_variables:
            variable = self.get_or_create_variable(var_name, file.bucket)

            try:
                layer = WmsLayer.objects.select_related('variable').get(
                    file_path=file.absolute_path,
                    variable=variable,
                )
                self.log('Add/Update Layer: ' + layer.layer_name)
                store_name = layer.get_store_name()
            except WmsLayer.DoesNotExist:
                # First time this variable of this file is seen: build a new,
                # not yet saved layer record.
                store_name = file.get_unique_store_name(self.api, var_name)
                layer_name = workspace + ':' + store_name
                layer = WmsLayer(
                    wms_url=self.get_geoserver_url(workspace),
                    scale_factor=1,
                    layer_name=layer_name,
                    variable=variable,
                    no_data_value=-9999,
                    file_path=file.absolute_path,
                    bucket=file.bucket,
                    store_uuid=uuid.uuid4(),
                )
                self.log('Create Layer: ' + layer_name)

            if len(file.netcdf_dates) == 0:
                # create single cog-datastore for netcdf without time
                upload_urls = self._process_netcdf(layer)
                if len(upload_urls) > 0:
                    self.api.create_cog_coveragestore(workspace, store_name, upload_urls[0])
            else:
                store_prefix = "public/" + str(layer.store_uuid) + "/"
                uploaded_files = FileCollection("geoserver", store_prefix)

                missing_dates = uploaded_files.get_missing_dates(file.netcdf_dates)
                upload_urls = self._process_netcdf(layer, missing_dates)

                if len(upload_urls) > 0:
                    granule_urls = upload_urls
                    if not self.api.is_coveragestore_existing(file.bucket.name, store_name):
                        # The first COG creates the mosaic store; only the
                        # remaining ones are added as granules.
                        self.api.create_imagemosaic(file.bucket.name, store_name, upload_urls[0])
                        granule_urls = upload_urls[1:]
                    for granule_url in granule_urls:
                        self.api.add_granule(workspace, store_name, granule_url)

                    dates_to_delete = uploaded_files.get_unnecessary_dates(file.netcdf_dates)
                    self.remove_granules_from_imagemosaic(layer, store_name, dates_to_delete)

                    layer.time_steps = self.api.get_wms_timesteps(layer)

            layer.last_upload = file.last_modified
            layer.save()

            # NOTE: GeoWebCache seeding after saving is currently disabled:
            # if layer.store_in_geowebcache:
            #     layer_legends = WmsLayerLegend.objects.filter(wms_layer=layer)
            #     for layer_legend in layer_legends:
            #         self.api.modify_gwc(layer.layer_name, layer_legend.legend.name_en, "seed")


    def delete_layer(self, file: File):
        """Delete every WMS layer registered for ``file``.

        For each ``WmsLayer`` whose ``file_path`` equals
        ``file.absolute_path``, ``self.api.delete_datastore`` is called for
        the layer's store and the layer record is deleted afterwards.
        """
        matching_layers = WmsLayer.objects.filter(file_path=file.absolute_path)
        for wms_layer in matching_layers:
            workspace = file.bucket.name
            store_name = wms_layer.get_store_name()
            self.api.delete_datastore(workspace, store_name)
            # The record is removed only after the datastore call has returned.
            wms_layer.delete()


    def _process_netcdf(self, layer: WmsLayer, missing_dates: list[str]=None):
        upload_urls = []

        # Open nc_file from S3
        with self.fs.open(layer.file_path, 'rb') as netcdf_file:
            ds = xr.open_dataset(netcdf_file, engine='h5netcdf').load()

            self.check_netcdf(ds)

            # Standardize coordinate names to 'lat' and 'lon'
            if 'latitude' in ds.coords:
                ds = ds.rename({'latitude': 'lat'})
            if 'longitude' in ds.coords:
                ds = ds.rename({'longitude': 'lon'})

            if missing_dates is None:
                print("Process netcdf without date.")
                missing_dates = [None]
            else:
                if len(missing_dates) == 0:
                    print("✅ All dates are already converted.")
                    # self.log("✅ All dates are already converted.")
                elif len(missing_dates) > 0:
                    print("Missing dates to process: {}".format(missing_dates))
                    # self.log("Missing dates to process: {}".format(missing_dates))

            crs = get_crs_from_netcdf(ds, layer.variable.name)
            if not crs:
                print("⚠️ No CRS found in the NetCDF!")
                crs = "EPSG:4326"

            # For loop to convert every date of the nc_file into a COG tiff
            for date in missing_dates:
                # Name file according to date
                destination_key = self._get_destination_key(layer, date)

                try:
                    if date:
                        # Read individual metric into data array (only one time)
                        da = ds[layer.variable.name].sel(time=date, method="nearest")
                    else:
                        da = ds[layer.variable.name]

                    # lon fix: Realign the x dimension to -180 origin for dataset
                    da = da.assign_coords(lon=(((da.lon + 180) % 360) - 180)).sortby("lon")

                    # lat fix: Reverse the DataArray's y dimension to comply with raster common practice
                    if da.lat.values[-1] > da.lat.values[0]:
                        da = da.isel(lat=slice(None, None, -1))

                    # Set raster data attributes, missing in netcdf file
                    da.rio.set_spatial_dims(x_dim="lon", y_dim="lat", inplace=True)

                    da.rio.write_crs(crs, inplace=True)
                    if 'grid_mapping' in da.attrs:
                        del da.attrs['grid_mapping']

                except (AttributeError, KeyError, ValueError) as e:
                    self.log("ERROR: Cannot read dataset {}: {}".format(layer.file_path, e))
                    return upload_urls

                # Convert nc to COG tiff using gdal
                with tempfile.TemporaryDirectory() as tmpdir:
                    tif_path = os.path.join(tmpdir, "temp.tif")
                    cog_path = os.path.join(tmpdir, "final_cog.tif")

                    try:
                        da.rio.to_raster(tif_path)

                        self.create_cog(tif_path, cog_path)

                        is_success = upload_file(self.get_cog_bucket_name(), destination_key, cog_path)

                        if is_success:
                            public_url = os.environ.get("MINIO_ENDPOINT") + f"/{self.get_cog_bucket_name()}/{destination_key}"
                            upload_urls.append(public_url)
                            print(f"✅ Upload successful: {public_url}")
                            # self.log(f"✅ Upload successful: {public_url}")
                        else:
                            print(f"❌ Upload failed!")
                            # self.log(f"❌ Upload failed!")
                    except Exception as e:
                        self.log(f"⚠️ Cannot transform NetCDF to COG! {str(e)}")


            ds.close()

        return upload_urls


    def remove_granules_from_imagemosaic(self, layer: WmsLayer, store_name: str, dates_to_delete: list[str]):

        if len(dates_to_delete) > 0:

            granule_list = self.api.cat.list_granules(store_name, store_name, layer.bucket.name)

            for date in dates_to_delete:
                destination_key = self._get_destination_key(layer, date)

                cog_url = os.environ.get("MINIO_ENDPOINT") + f"/{self.get_cog_bucket_name()}/{destination_key}"

                granule_id = self._get_granule_id(cog_url, granule_list['features'])

                if granule_id:
                    self.api.cat.delete_granule(store_name, store_name, granule_id, layer.bucket.name)
                    remove_file(self.get_cog_bucket_name(), destination_key)
                    print(f"[🧹] Deleted granule: {cog_url}")
                    # self.log(f"[🧹] Deleted granule: {cog_url}")
                else:
                    print(f"❌ Deleting granule failed.")
                    # self.log(f"❌ Deleting granule failed.")


    @staticmethod
    def _get_granule_id(url, granules):
        for granule in granules:
            if granule["properties"]["location"] == url:
                return granule["id"]
        return None


    def check_netcdf(self, ds):

        # Define acceptable dimension names for each required concept
        required_dims = {
            'time': ['time', 'Time'],
            'latitude': ['lat', 'latitude', 'Latitude'],
            'longitude': ['lon', 'longitude', 'Longitude']
        }

        try:
            # Check for presence of at least one alias for each required conceptual dimension
            for dim_label, aliases in required_dims.items():
                if not any(alias in ds.dims for alias in aliases):
                    self.log("Error: Missing required dimension: {}".format(dim_label))
                    return False

            # --- Time parsing check ---
            time_aliases = required_dims['time']
            time_dim_name = next((name for name in time_aliases if name in ds.dims), None)

            if time_dim_name:
                try:
                    time_values = ds[time_dim_name].values
                    pd.to_datetime(time_values)  # Just checking for parse-ability
                except Exception:
                    self.log("Time values could not be parsed as dates.")
                    return False

        except Exception as e:
            self.log(f"Error opening file: {str(e)}")
            return False

        return True


def get_crs_from_netcdf(ds, var_name=None):
    """Look up the coordinate reference system declared in a NetCDF dataset.

    Lookup order:

    1. The variable's ``grid_mapping`` attribute: when it names a variable of
       ``ds``, the first non-empty of that variable's ``crs_wkt``,
       ``spatial_ref``, ``crs``, ``WKT`` or ``wkt`` attributes is parsed with
       ``CRS.from_wkt``.
    2. The ``crs``, ``proj4`` or ``projection`` attribute of the variable,
       then of the dataset, parsed with ``CRS.from_string``.

    Values that cannot be parsed are reported with ``print`` and skipped.

    Args:
        ds: Dataset to inspect.
        var_name: Name of the data variable. When omitted, or when ``ds`` has
            no such variable, only the dataset-level attributes are checked.

    Returns:
        The first CRS that could be parsed, or ``None`` when none was found.
    """
    # var_name is optional in the signature, so only index the dataset when
    # the variable is really there instead of failing on ds[None].
    da = ds[var_name] if var_name is not None and var_name in ds else None

    if da is not None:
        # Check grid_mapping attribute
        grid_mapping_var_name = da.attrs.get("grid_mapping")

        if grid_mapping_var_name and grid_mapping_var_name in ds:
            attrs = ds[grid_mapping_var_name].attrs
            wkt_keys = ("crs_wkt", "spatial_ref", "crs", "WKT", "wkt")
            wkt = next((attrs[key] for key in wkt_keys if attrs.get(key)), None)
            if wkt:
                try:
                    return CRS.from_wkt(wkt)
                except Exception as e:
                    print(f"⚠️ Failed to parse WKT: {e}")

    # Fallback: check variable-level CRS keys first, then dataset-level ones
    sources = [ds] if da is None else [da, ds]
    for source in sources:
        for key in ("crs", "proj4", "projection"):
            val = source.attrs.get(key)
            if val:
                try:
                    return CRS.from_string(val)
                except Exception as e:
                    print(f"⚠️ Failed to parse CRS string: {e}")

    return None
