#!/usr/bin/env python
# __BEGIN_LICENSE__
#  Copyright (c) 2009-2026, United States Government as represented by the
#  Administrator of the National Aeronautics and Space Administration. All
#  rights reserved.
#
#  The NGT platform is licensed under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance with the
#  License. You may obtain a copy of the License at
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
# __END_LICENSE__

"""
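Compute the effective refractive index of water for a given temperature, salinity,
and either a single wavelength or a spectral response. With a spectral response,
the index is evaluated at the response-weighted mean wavelength. Uses the Quan
and Fry (1995) or Parrish (2020) empirical equations.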

References:
  Quan and Fry (1995): https://github.com/geojames/global_refractive_index_532
  Parrish (2020):
    https://research.engr.oregonstate.edu/parrish/
      index-refraction-seawater-and-freshwater-function-wavelength-and-temperature
"""

import sys
import os
import argparse
import numpy as np

# Set up the path to Python modules about to load. sys.path[0] is the
# directory of the running script; the sibling ../Python and ../libexec
# directories hold the helper modules in a development tree and in a
# packaged install, respectively.
basepath    = os.path.abspath(sys.path[0])
pythonpath  = os.path.abspath(basepath + '/../Python')  # for dev ASP
libexecpath = os.path.abspath(basepath + '/../libexec') # for packaged ASP
sys.path.insert(0, basepath) # prepend to Python path
sys.path.insert(0, pythonpath)
sys.path.insert(0, libexecpath)
# Each insert goes to the front, so the resulting search order is
# libexecpath, pythonpath, basepath, then the original sys.path entries.

# This import must stay after the sys.path setup above so it can be found.
import asp_system_utils

# Prepend to system PATH so executables in libexecpath take precedence.
# NOTE(review): nothing in this file launches executables directly; presumably
# this is for helper code — confirm. Raises KeyError if PATH is unset.
os.environ["PATH"] = libexecpath + os.pathsep + os.environ["PATH"]

# Define wavelength limits (in nm). The first pair is the valid range from the
# Parrish equation documentation; wavelengths outside it only trigger a
# warning. The second pair is a hard sanity limit: spectral response samples
# outside it are dropped, and an effective wavelength outside it is an error.
# The spectral response tables from vendors usually go beyond the strict
# Parrish limits, but have small responses outside the valid range.
min_wl_warn = 400
max_wl_warn = 700
min_wl_err  = 300
max_wl_err  = 1100

def read_csv(filename):
    """Read a two-column spectral response table.

    The first line is always treated as a header and skipped. Each remaining
    line must start with a wavelength (nm) and a response, separated by commas
    and/or whitespace; extra columns are ignored. A row is dropped when it is
    blank, has fewer than two columns, does not parse as numbers, holds a
    non-finite value, has a non-positive response, or has a wavelength outside
    [min_wl_err, max_wl_err]. A warning is printed to stderr if any kept
    wavelength lies outside [min_wl_warn, max_wl_warn].

    Args:
        filename: Path to the spectral response file.

    Returns:
        Tuple (wavelengths, responses) of equal-length 1-D numpy arrays.

    Exits the process with status 1 if the file cannot be read, is empty, or
    contains no usable rows.
    """
    # Only the file access can fail here; parsing errors are handled per line.
    try:
        with open(filename, 'r') as f:
            lines = f.readlines()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading spectral response file: {e}")
        sys.exit(1)

    if not lines:
        print("Error: Empty spectral response file.")
        sys.exit(1)

    # Skip header
    print(f"Skipping first line: {lines[0].strip()}", file=sys.stderr)

    wavelengths = []
    responses = []
    warn_range = False
    skipped_extreme = False

    for raw_line in lines[1:]:
        line = raw_line.strip()
        if not line:
            continue
        # Handle comma or space delimiters
        parts = line.replace(',', ' ').split()
        if len(parts) < 2:
            continue

        try:
            wl = float(parts[0])
            resp = float(parts[1])
        except ValueError:
            print(f"Warning: Skipping invalid line: {line}", file=sys.stderr)
            continue

        # float() accepts 'nan' and 'inf'. NaN compares False against every
        # bound below, so it would otherwise be kept and turn the weighted
        # mean into NaN; an infinite response does the same.
        if not (np.isfinite(wl) and np.isfinite(resp)):
            print(f"Warning: Skipping non-finite line: {line}", file=sys.stderr)
            continue

        # Only include positive responses
        if resp <= 0:
            continue

        # Skip wavelengths outside extreme range (don't include in computation)
        if wl < min_wl_err or wl > max_wl_err:
            skipped_extreme = True
            continue

        # Check wavelength range for warning
        if wl < min_wl_warn or wl > max_wl_warn:
            warn_range = True

        wavelengths.append(wl)
        responses.append(resp)

    if skipped_extreme:
        print(f"Warning: Wavelengths outside {min_wl_err}-{max_wl_err} nm "
              "were skipped.", file=sys.stderr)

    if warn_range:
        print("Warning: Some wavelengths are outside the valid range of "
              f"{min_wl_warn}-{max_wl_warn} nm for the Parrish equation. "
              "Results may still be accurate if the spectral response "
              "for those is small.", file=sys.stderr)

    # wavelengths and responses are appended together, so one check covers both.
    if not wavelengths:
        print("Error: No valid data found in spectral response file.")
        sys.exit(1)

    return np.array(wavelengths), np.array(responses)

def _parrish_coefficients(salinity):
    """Return Parrish (2020) coefficients for the given salinity.

    Coefficients are tabulated for freshwater (S = 0) and seawater
    (S = 35 ppt) and linearly interpolated between them; salinities above
    35 ppt are linearly extrapolated.

    Args:
        salinity: Salinity in parts per thousand (ppt).

    Returns:
        Dict with keys 'a'..'e' for n = a*T^2 + b*lambda^2 + c*T + d*lambda + e.
    """
    # Seawater (S = 35)
    params_35 = {
        'a': -1.50156e-6,
        'b': 1.07085e-7,
        'c': -4.27594e-5,
        'd': -1.60476e-4,
        'e': 1.39807,
    }

    # Freshwater (S = 0)
    params_0 = {
        'a': -1.97812e-6,
        'b': 1.03223e-7,
        'c': -8.58125e-6,
        'd': -1.54834e-4,
        'e': 1.38919,
    }

    s_factor = salinity / 35.0
    return {key: params_0[key] + s_factor * (params_35[key] - params_0[key])
            for key in params_0}

def _parrish_index(T, lam, params):
    """Evaluate the Parrish (2020) equation at T (deg C) and lam (nm)."""
    # n = a*T^2 + b*lambda^2 + c*T + d*lambda + e
    return (params['a'] * T**2 +
            params['b'] * lam**2 +
            params['c'] * T +
            params['d'] * lam +
            params['e'])

def _quan_fry_index(T, S, lam):
    """Evaluate the Quan and Fry (1995) equation at T (deg C), S (ppt), lam (nm)."""
    return (1.31405 +
            (0.0001779 + -0.00000105 * T + 0.000000016 * T**2) * S +
            -0.00000202 * T**2 +
            ((15.868 + 0.01155*S + -0.00423 *T)/lam) +
            (-4382/lam**2) + (1145500/lam**3))

def main():
    """Parse command-line arguments and print the effective refractive index.

    Prints an error message and exits with status 1 on invalid input. Exits
    through argparse's usage error if neither or both of --spectral-response
    and --wavelength are given.
    """
    usage = ("refr_index --salinity <val> --temperature <val> " +
             "[--spectral-response <file> | --wavelength <val>]")
    parser = argparse.ArgumentParser(usage=usage,
                                     description="Compute effective refraction index.")

    # None means "not given", so any user-supplied number (including -1)
    # goes through the range checks below instead of a "not specified" error.
    parser.add_argument('--salinity', type=float, default=None,
                        help='Salinity in parts per thousand (ppt).')
    parser.add_argument('--temperature', type=float, default=None,
                        help='Temperature in degrees Celsius.')
    parser.add_argument('--mode', type=str, default='Quan-Fry',
                        help='Refractive index equation: Quan-Fry (default) or Parrish.')

    # Deliberately not required=True: argparse would then reject a bare
    # '--version' before it could be handled. Presence is enforced below.
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--spectral-response', type=str, default=None,
                       help='CSV file containing the spectral response.')
    group.add_argument('--wavelength', type=float, default=None,
                       help='Calculate refraction index for single wavelength (nm).')

    parser.add_argument('-v', '--version', action='store_true',
                        help='Display the version of software.')

    args = parser.parse_args()

    if args.version:
        asp_system_utils.print_version_and_exit()

    if args.spectral_response is None and args.wavelength is None:
        parser.error("one of the arguments --spectral-response --wavelength "
                     "is required")

    # Mode is case-insensitive.
    mode = args.mode.lower()
    if mode not in ('quan-fry', 'parrish'):
        print(f"Error: Invalid mode '{args.mode}'. Must be 'Quan-Fry' or 'Parrish'.")
        sys.exit(1)

    # Input validation
    if args.salinity is None:
        print("Error: --salinity must be specified.")
        sys.exit(1)
    if args.temperature is None:
        print("Error: --temperature must be specified.")
        sys.exit(1)

    # 'nan' and 'inf' parse as floats but would propagate into the result.
    if not np.isfinite(args.salinity) or args.salinity < 0:
        print("Error: Salinity must be a finite, non-negative number.")
        sys.exit(1)
    if args.salinity > 35:
        print(f"Warning: Salinity {args.salinity} ppt is above 35 ppt; the " +
              "result is an extrapolation and may be inaccurate.",
              file=sys.stderr)

    # NaN fails both comparisons, so it is rejected here as well.
    if not (0 <= args.temperature <= 30):
        print("Error: Temperature must be between 0 and 30 degrees Celsius.")
        sys.exit(1)

    if args.spectral_response is not None:
        if not os.path.isfile(args.spectral_response):
            print(f"Error: Spectral response file not found: {args.spectral_response}")
            sys.exit(1)
        wavelengths, responses = read_csv(args.spectral_response)
    else:
        # Single wavelength case: one sample with unit weight.
        wavelengths = np.array([args.wavelength])
        responses = np.array([1.0])

    total_response = np.sum(responses)
    if total_response == 0:
        print("Error: Sum of spectral responses is zero.")
        sys.exit(1)

    # Compute effective wavelength using spectral response weights
    eff_wl = np.sum(wavelengths * responses) / total_response

    # Check the hard limit first, so an unusable wavelength reports only the
    # error. The isfinite test catches e.g. '--wavelength nan', which would
    # otherwise slip past both comparisons.
    if not np.isfinite(eff_wl) or eff_wl < min_wl_err or eff_wl > max_wl_err:
        print(f"Error: Effective wavelength {eff_wl:.2f} nm is " +
              f"out of supported range ({min_wl_err}-{max_wl_err} nm).")
        sys.exit(1)
    if eff_wl < min_wl_warn or eff_wl > max_wl_warn:
        print(f"Warning: Effective wavelength {eff_wl:.2f} nm is " +
              f"outside the valid range of {min_wl_warn}-{max_wl_warn} nm for " +
              "the Parrish equation. Result may be inaccurate.", file=sys.stderr)

    # Compute refractive index using the effective wavelength
    T = args.temperature
    S = args.salinity
    lam = eff_wl

    if mode == 'quan-fry':
        effective_index = _quan_fry_index(T, S, lam)
    else:
        effective_index = _parrish_index(T, lam, _parrish_coefficients(S))

    print(f"Effective index: {effective_index:.6f}")

# Run the command-line tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
