#!/usr/bin/env python3
""" 

Produces a geometry-scan XYZ trajectory annotated with: the characteristic MO pair (from the largest
contraction element of o or v) and its p/d AO weights on every atom; the TOTAL SOP norm and the
FINAL SINGLE POINT ENERGY in each XYZ comment line; and the per-atom SOP norm as a 5th column on each atom line.

Input files in SCAN_DIR must be named:
  1_<SUFFIX>, 2_<SUFFIX>, ..., n_<SUFFIX>

Processes all of these ORCA calculations into one file described above, i.e., produces one file per trajectory.

Usage:
  python SOPEL-processing-lowdin.py --suffix="suffix" -o "out.xyz" \
      --singlet=1 --triplet=2 --partition=mulliken

Partitioning of AO SOC matrices h_ij^M(A):
  - mulliken (default):
      W_A = A-A block + 0.5*(A-rest + rest-A blocks)
  - lowdin:
      1) Reconstruct the AO overlap matrix S from the full MO coefficient matrix C via
           S = (C C^T)^(-1)
         (valid for a square, linearly independent AO basis and a complete MO set)
      2) Build the symmetric orthogonalization matrices S^(1/2) and S^(-1/2)
      3) Transform the AO SOC matrix to the orthogonalized AO basis:
           H~ = S^(-1/2) H S^(-1/2)
      4) Partition in the orthogonalized AO basis with the atom projector P_A:
           W~_A = 0.5 * (P_A H~ + H~ P_A)
         where P_A is diagonal with ones on orthogonalized AOs centered on atom A
      5) Transform to MO basis with the orthogonalized MO coefficients C~ = S^(1/2) C

Characteristic MO selection:
  - For chosen singlet state (ordinal in SINGLETS) and triplet state (ordinal in TRIPLETS):
      Build S (virt x occ) and T (virt x occ) amplitude matrices from printed excitations (c=...).
      v = T† S        (occ x occ)   where v_{ji} = sum_f (T_{fj})* S_{fi}
      o = T* S^T      (virt x virt) where o_{ji} = sum_f (T_{jf})* S_{if}
    Pick the single largest |element| among v and o (on a tie, the element from o wins).
    If from v: characteristic MO pair are (occ_j, occ_i).
    If from o: characteristic MO pair are (virt_j, virt_i).

p/d weights for MO p on atom A:
  weight_A(type) = sum_{μ in (A,type)} |C_{μ,p}|^2
  where μ runs over ALL shells (e.g., 2px + 3px + ... are all counted into px).
  This is a coefficient-based "character" measure (not overlap-population).

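SOP TOTAL norm in comment line (printed as TOTAL_SOC_cm-1):
  TOTAL_SOC_cm-1 = sqrt( sum_M |sum_A Sigma(A,M)|^2 ) * Eh_to_cm-1
Per-atom SOP norm (5th column of each atom line):
  ATOM_SOP_cm-1 = sqrt( sum_M |Sigma(A,M)|^2 ) * Eh_to_cm-1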

Output OUT_FILE (XYZ trajectory with extra columns):
  N
  comment (step index, E_el_Eh, state ordinals, partition scheme, characteristic MO pair, TOTAL SOP norm)
  Elem x y z  ATOM_SOP_cm-1  [8 weights for MO1 for this atom] [8 weights for MO2 for this atom]
"""

from __future__ import annotations

import argparse
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

# ---------------------------- CONFIG ---------------------------------

SCAN_DIR = Path(".")
SUFFIX = "suffix"   # files: 1_<SUFFIX>, 2_<SUFFIX>, ... # flag --suffix
OUT_FILE = Path("out-movie.xyz")  # flag -o

SINGLET_ORD = 1                      # 1-based ordinal within SINGLETS section  # flag --singlet
TRIPLET_ORD = 2                      # 1-based ordinal within TRIPLETS section  # flag --triplet

HBAR_AU = 1.0                        # value of hbar used in the SOC prefactor  # flag --hbar-au
OCC_THRESH = 1e-3                    # occupied if occ > OCC_THRESH  # flag --occ-thresh
EH_TO_CM1 = 219474.63137             # 1 Eh in cm^-1
PARTITION_SCHEME = "mulliken"        # flag --partition {mulliken,lowdin}
LOWDIN_EIG_TOL = 1e-10               # reject nearly linearly dependent AO spaces  # flag --lowdin-eig-tol

# --------------------------------------------------------------------

# Generic signed float, with optional exponent.
FLOAT_RE = re.compile(r"[+-]?(?:\d+\.\d*|\.\d+|\d+)(?:[Ee][+-]?\d+)?")
FINAL_SP_EN_RE = re.compile(
    r"^\s*FINAL\s+SINGLE(?:-|\s)POINT\s+ENERGY\s+([+-]?(?:\d+\.\d*|\.\d+|\d+)(?:[Ee][+-]?\d+)?)",
    re.I | re.M,
)
# Header line of an MO column block: ONE or more MO indices and nothing else.
# MOs are printed in fixed-width column blocks, so the final block may hold a
# single column. Requiring >= 2 integers here silently dropped that last MO
# (leaving a zero column in C).
MO_HEADER_RE = re.compile(r"^\s+(?:\d+\s+)*\d+\s*$")
# AO row of an MO block: "<atom index><element> <AO label> <coefficients...>".
AO_LINE_RE = re.compile(r"^\s*(\d+)([A-Za-z]{1,2})\s+(\S+)\s+(.*)$")

STATE_RE = re.compile(r"^\s*STATE\s+(\d+):", re.I)
# Excitation line "i[a|b] -> a[a|b] : weight (c= amplitude)".
EXC_RE = re.compile(
    r"^\s*(\d+)[ab]?\s*->\s*(\d+)[ab]?\s*:\s*([0-9.Ee+-]+)\s*\(c=\s*([0-9.Ee+-]+)\)",
    re.I,
)

COORD_ANG_RE = re.compile(r"CARTESIAN COORDINATES \(ANGSTROEM\)\s*\n-+\s*\n", re.I)

# Section markers delimiting the AO-basis SOC matrices (used with extract_between).
SOCX_RE = r"Matrix elements for SOC\(X\) in AO basis"
SOCY_RE = r"Matrix elements for SOC\(Y\) in AO basis"
SOCZ_RE = r"Matrix elements for SOC\(Z\) in AO basis"
SOC_MO_RE = r"SOC\(X\) in MO basis"

# Orbital types reported per atom (order defines the output column order).
ORBITALS = ["px", "py", "pz", "dxy", "dxz", "dyz", "dx2-y2", "dz2"]
# Spin-projection components M of the triplet; the operator component used is q = -M.
Ms = [-1, 0, +1]


@dataclass(frozen=True, eq=True)
class Excitation:
    """One excitation line ``occ -> virt : weight (c= amp)`` of a TD-DFT state.

    The MO indices hold the printed indices minus the MO index base
    (see parse_states_ordinal).
    """

    occ_mo: int  # source (occupied) MO index
    virt_mo: int  # target (virtual) MO index
    amp: float  # printed amplitude c
    weight: float  # printed weight


@dataclass(eq=True)
class ParsedState:
    """An excited state parsed from a SINGLETS/TRIPLETS section, in printed order."""

    printed_state_no: int  # number following "STATE" in the output
    excitations: List[Excitation]  # excitation lines listed under that STATE header


def read_text(path: Path) -> str:
    """Return the whole text of *path*, silently dropping undecodable bytes."""
    with path.open(errors="ignore") as handle:
        return handle.read()


def parse_last_cart_coords_angstrom(text: str) -> Tuple[List[str], np.ndarray]:
    """Parse the LAST 'CARTESIAN COORDINATES (ANGSTROEM)' block of an ORCA output.

    Atom lines ``<elem> <x> <y> <z>`` are read until the first blank line,
    all-dash line, or line with fewer than four fields.

    Returns:
      (element symbols, float array of shape (n_atoms, 3) in Angstrom)

    Raises:
      ValueError: if the section is missing or holds no atom lines.
    """
    headers = list(COORD_ANG_RE.finditer(text))
    if not headers:
        raise ValueError("Could not find 'CARTESIAN COORDINATES (ANGSTROEM)' section.")
    body = text[headers[-1].end():]

    symbols: List[str] = []
    rows: List[List[float]] = []
    for raw in body.splitlines():
        stripped = raw.strip()
        # Blank line or dashed rule terminates the block.
        if not stripped or set(stripped) <= {"-"}:
            break
        fields = raw.split()
        if len(fields) < 4:
            break
        xyz = [float(v) for v in fields[1:4]]
        symbols.append(fields[0])
        rows.append(xyz)

    if not symbols:
        raise ValueError("Failed to parse any atoms from CARTESIAN COORDINATES (ANGSTROEM).")
    return symbols, np.array(rows, dtype=float)


def parse_nbf(text: str) -> int:
    """Return the number of contracted basis functions reported in the output.

    The first line of the form 'Number of contracted basis functions ... <n>'
    is used.

    Raises:
      ValueError: if no such line exists.
    """
    m = re.search(r"Number of contracted basis functions\s+\.\.\.\s+(\d+)", text)
    if not m:
        # Quote the phrase that is actually searched for, so a failure points the
        # user at the right output line.
        raise ValueError("Could not find 'Number of contracted basis functions ...' in output.")
    return int(m.group(1))


def parse_final_single_point_energy(text: str) -> float:
    """Return the energy (Hartree) from the last 'FINAL SINGLE POINT ENERGY' line.

    Both spellings 'FINAL SINGLE POINT ENERGY' and 'FINAL SINGLE-POINT ENERGY'
    are accepted (see FINAL_SP_EN_RE).

    Raises:
      ValueError: if no such line is present.
    """
    hits = list(FINAL_SP_EN_RE.finditer(text))
    if not hits:
        raise ValueError("Could not find 'FINAL SINGLE POINT ENERGY' in output.")
    return float(hits[-1].group(1))


def parse_mo_coefficients(
    text: str, nbf: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[int], List[str], Dict[int, str], int]:
    """Parse the first 'MOLECULAR ORBITALS' section of an ORCA output.

    The section runs until the next 'LOEWDIN/MULLIKEN POPULATION ANALYSIS'
    header (or end of text). It is a sequence of column blocks, each made of:
    a header line of MO indices, an orbital-energy line, an occupation line,
    a dashed separator, then one ``<atom><elem> <label> <coeffs...>`` line per AO.
    The AO -> atom mapping and AO labels are taken from the first block.
    MOs that never appear keep zero coefficients, energy and occupation.

    Returns:
      C (nbf x nbf) with C[AO row, MO], eps (nbf), occ (nbf),
      ao_atom (nbf): AO row -> atom index (as printed in MO section),
      ao_label (nbf): AO label token (e.g. 2px, 3dxy, ...),
      atom_elems {atom_index: element},
      mo_index_base (0 or 1)

    Raises:
      ValueError: if the section is missing or malformed. Malformed means too
        few coefficients, energies or occupations for a block, more AO rows than
        nbf, an MO index outside 0..nbf-1, or an incomplete AO -> atom mapping.
    """
    m = re.search(r"MOLECULAR ORBITALS.*?\n-+\n", text, re.S)
    if not m:
        raise ValueError("Could not find 'MOLECULAR ORBITALS' section.")
    start = m.end()

    # Stop at the first population-analysis header after the MO listing.
    end_match = re.search(r"(LOEWDIN|MULLIKEN)\s+POPULATION\s+ANALYSIS", text[start:], re.I)
    end = start + end_match.start() if end_match else len(text)
    lines = text[start:end].splitlines()

    C = np.zeros((nbf, nbf), dtype=float)
    eps = np.zeros(nbf, dtype=float)
    occ = np.zeros(nbf, dtype=float)

    ao_atom: List[int] = [None] * nbf  # type: ignore[assignment]
    ao_label: List[str] = [""] * nbf
    atom_elems: Dict[int, str] = {}

    mo_index_base: Optional[int] = None
    i = 0
    first_block = True

    while i < len(lines):
        if not MO_HEADER_RE.match(lines[i]):
            i += 1
            continue

        header = lines[i]
        cols = [int(x) for x in header.split()]
        if mo_index_base is None:
            mo_index_base = 0 if min(cols) == 0 else 1

        if i + 3 >= len(lines):
            break  # truncated block at the end of the section

        mo0s = [mo - mo_index_base for mo in cols]
        if any(mo0 < 0 or mo0 >= nbf for mo0 in mo0s):
            # A negative index would otherwise wrap around silently in NumPy.
            raise ValueError(f"MO index out of range (nbf={nbf}) in header: {header!r}")

        energies = [float(x) for x in FLOAT_RE.findall(lines[i + 1])][: len(cols)]
        occs = [float(x) for x in FLOAT_RE.findall(lines[i + 2])][: len(cols)]
        if len(energies) < len(cols) or len(occs) < len(cols):
            raise ValueError(f"Missing orbital energies/occupations for MO block: {header!r}")
        i += 4  # header, energies, occs, dashed

        row = 0
        while i < len(lines) and lines[i].strip() != "":
            if MO_HEADER_RE.match(lines[i]):
                break
            m_ao = AO_LINE_RE.match(lines[i])
            if not m_ao:
                break

            aidx = int(m_ao.group(1))
            elem = m_ao.group(2)
            lbl = m_ao.group(3)
            nums = [float(x) for x in FLOAT_RE.findall(m_ao.group(4))]
            if len(nums) < len(cols):
                raise ValueError(f"Not enough MO coefficients in line: {lines[i]}")
            if row >= nbf:
                raise ValueError("Parsed more AO rows than nbf in MO block.")

            if first_block:
                # The AO rows repeat in every block; record the mapping once.
                ao_atom[row] = aidx
                ao_label[row] = lbl
                atom_elems[aidx] = elem

            C[row, mo0s] = nums[: len(cols)]
            row += 1
            i += 1

        first_block = False
        eps[mo0s] = energies
        occ[mo0s] = occs

    if mo_index_base is None:
        raise ValueError("Failed to detect MO index base in MO section.")
    if any(x is None for x in ao_atom):
        raise ValueError("Failed to parse AO->atom mapping from MO section.")

    return C, eps, occ, ao_atom, ao_label, atom_elems, mo_index_base


def maybe_shift_atom_indexing(
    ao_atom: List[int], atom_elems: Dict[int, str], n_atoms_geom: int
) -> Tuple[List[int], Dict[int, str]]:
    """Convert AO-center atom indices to 0-based when they look 1-based.

    "Looks 1-based" means the distinct indices number exactly n_atoms_geom,
    the smallest is 1 and the largest is n_atoms_geom. Otherwise both inputs
    are returned unchanged (the very same objects).
    """
    centers = sorted(set(ao_atom))
    looks_one_based = (
        len(centers) == n_atoms_geom and centers[0] == 1 and centers[-1] == n_atoms_geom
    )
    if not looks_one_based:
        return ao_atom, atom_elems
    shifted_atoms = [a - 1 for a in ao_atom]
    shifted_elems = {k - 1: v for k, v in atom_elems.items()}
    return shifted_atoms, shifted_elems


def extract_between(text: str, start_pat: str, end_pat: str) -> str:
    """Return the text after the first *start_pat* match up to the next *end_pat* match.

    Both are regular expressions matched case-insensitively. If *end_pat* does
    not occur after the start match, everything to the end of *text* is returned.

    Raises:
      ValueError: if *start_pat* does not occur in *text*.
    """
    opening = re.search(start_pat, text, re.I)
    if opening is None:
        raise ValueError(f"Start pattern not found: {start_pat}")
    tail = text[opening.end():]
    closing = re.search(end_pat, tail, re.I)
    if closing is None:
        return tail
    return tail[: closing.start()]


def extract_excited_section(text: str, which: str) -> str:
    """Return the TD-DFT/TDA SINGLETS or TRIPLETS excited-state section (case-insensitive *which*).

    The SINGLETS section ends at the TRIPLETS header. The TRIPLETS section ends
    at the 'TD-DFT/TDA SPIN-ORBIT COUPLING' header.

    Raises:
      ValueError: if *which* is neither SINGLETS nor TRIPLETS, or the start
        header is missing (from extract_between).
    """
    section_bounds = {
        "SINGLETS": (
            r"TD-DFT/TDA EXCITED STATES \(SINGLETS\)",
            r"TD-DFT/TDA EXCITED STATES \(TRIPLETS\)",
        ),
        "TRIPLETS": (
            r"TD-DFT/TDA EXCITED STATES \(TRIPLETS\)",
            r"TD-DFT/TDA SPIN-ORBIT COUPLING",
        ),
    }
    bounds = section_bounds.get(which.upper())
    if bounds is None:
        raise ValueError("which must be SINGLETS or TRIPLETS")
    start_pat, end_pat = bounds
    return extract_between(text, start_pat, end_pat)


def parse_states_ordinal(section: str, mo_index_base: int) -> List[ParsedState]:
    """Parse the STATE blocks of one excited-state section, in printed order.

    Each 'STATE n:' line opens a new ParsedState. Each following excitation line
    is appended to the most recent state. Excitation lines before the first
    STATE are ignored. The MO indices are shifted by *mo_index_base*.
    """
    states: List[ParsedState] = []
    for line in section.splitlines():
        header = STATE_RE.match(line)
        if header:
            states.append(ParsedState(printed_state_no=int(header.group(1)), excitations=[]))
            continue

        hit = EXC_RE.match(line)
        if hit is None or not states:
            continue
        occ_mo = int(hit.group(1)) - mo_index_base
        virt_mo = int(hit.group(2)) - mo_index_base
        weight = float(hit.group(3))
        amp = float(hit.group(4))
        states[-1].excitations.append(
            Excitation(occ_mo=occ_mo, virt_mo=virt_mo, amp=amp, weight=weight)
        )
    return states


def normalize_ao_label(lbl: str) -> str:
    """Map an AO label such as '2px' or '3dx2y2' to an ORBITALS key ('px', 'dx2-y2').

    The label is lower-cased and the leading shell number is dropped. The
    x2-y2 and z^2 spellings are then unified.
    """
    label = lbl.lower().lstrip("0123456789")  # 2px -> px
    for old, new in (("dx2y2", "dx2-y2"), ("d(x2-y2)", "dx2-y2"), ("dz^2", "dz2")):
        label = label.replace(old, new)
    return label


def pd_contributions_all_atoms_to_mo(
    C: np.ndarray,
    mo: int,
    ao_atom: List[int],
    ao_label: List[str],
    n_atoms_geom: int,
) -> Dict[int, Dict[str, float]]:
    """Per-atom p/d character of MO *mo*.

    contribution_A(type) = sum_{μ in (A,type)} |C_{μ,mo}|^2

    Returns {atom index: {orbital type: contribution}} for atoms
    0..n_atoms_geom-1. AOs centred outside that range, or whose normalized label
    is not in ORBITALS, are skipped. Atom indexing is assumed to match the
    geometry ordering (0-based) after maybe_shift_atom_indexing().
    """
    column = C[:, mo]
    contribs: Dict[int, Dict[str, float]] = {
        atom: dict.fromkeys(ORBITALS, 0.0) for atom in range(n_atoms_geom)
    }

    for mu, atom in enumerate(ao_atom):
        if not 0 <= atom < n_atoms_geom:
            continue
        bucket = contribs[atom]
        orbital = normalize_ao_label(ao_label[mu])
        if orbital not in bucket:
            continue  # only the p/d types in ORBITALS are tracked
        bucket[orbital] += float(column[mu] * column[mu])

    return contribs


def parse_orca_block_matrix(section: str, n: int) -> np.ndarray:
    """Parse an n x n matrix printed as a sequence of column blocks.

    Each block has a header line holding the block's column indices (integers
    only), followed by exactly n lines ``<row index> <value> ...``. Wide matrices
    are split into fixed-width blocks, so the LAST block may hold only one or two
    columns. The header pattern therefore accepts one or more integers; the old
    ">= 3 integers" pattern silently left those trailing columns at zero.

    Raises:
      ValueError: on truncated input, a malformed row line, a row with the
        wrong number of values, or if some column index never appears in any
        block header.
    """
    lines = section.splitlines()
    header_re = re.compile(r"^\s+(?:\d+\s+)*\d+\s*$")
    row_re = re.compile(r"^\s*(\d+)\s+")
    mat = np.zeros((n, n), dtype=float)
    seen_cols = np.zeros(n, dtype=bool)

    i = 0
    while i < len(lines):
        if header_re.match(lines[i]):
            cols = [int(x) for x in lines[i].split()]
            ncols = len(cols)
            for _ in range(n):
                i += 1
                if i >= len(lines):
                    raise ValueError("Unexpected EOF while reading matrix rows.")
                line = lines[i]
                if not row_re.match(line):
                    raise ValueError(f"Expected row line, got: {line!r}")
                parts = line.split()
                row = int(parts[0])
                vals = [float(x) for x in parts[1 : 1 + ncols]]
                if len(vals) != ncols:
                    raise ValueError(f"Row {row} has {len(vals)} values, expected {ncols}.")
                mat[row, cols] = vals
            seen_cols[cols] = True
        i += 1

    if not seen_cols.all():
        missing = [int(c) for c in np.flatnonzero(~seen_cols)]
        raise ValueError(f"Matrix columns never printed: {missing[:10]} (n={n}).")
    return mat


def parse_soc_ao_matrices(text: str, nbf: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the (X, Y, Z) AO-basis SOC matrices, each nbf x nbf.

    Each matrix is read from the text between its own header and the next
    header (Z ends at the 'SOC(X) in MO basis' header).
    """
    bounds = ((SOCX_RE, SOCY_RE), (SOCY_RE, SOCZ_RE), (SOCZ_RE, SOC_MO_RE))
    sections = [extract_between(text, begin, finish) for begin, finish in bounds]
    hx, hy, hz = (parse_orca_block_matrix(chunk, nbf) for chunk in sections)
    return hx, hy, hz


def cart_to_spherical(hx: np.ndarray, hy: np.ndarray, hz: np.ndarray, q: int) -> np.ndarray:
    """Return the spherical-tensor component h_q of the Cartesian operator (hx, hy, hz).

    Condon–Shortley phase convention:
      h_0  = hz
      h_+1 = -(hx + i hy) / sqrt(2)
      h_-1 =  (hx - i hy) / sqrt(2)

    Raises:
      ValueError: if q is not -1, 0 or +1.
    """
    if q not in (-1, 0, 1):
        raise ValueError("q must be -1, 0, +1")
    if q == 0:
        return hz.astype(complex)
    root_two = math.sqrt(2.0)
    if q == 1:
        return -(hx + 1j * hy) / root_two
    return (hx - 1j * hy) / root_two


def build_W_mulliken(Hq_total: np.ndarray, ao_atom: List[int], A: int) -> np.ndarray:
    """Mulliken-style share of AO operator Hq_total belonging to atom A.

    Keeps the A-A block in full and half of the A-rest and rest-A blocks. The
    rest-rest block is zero. Summing over all atoms recovers Hq_total.
    Returns a complex nbf x nbf matrix.
    """
    centers = np.asarray(ao_atom)
    on_atom = np.flatnonzero(centers == A)
    off_atom = np.flatnonzero(centers != A)

    share = np.zeros((Hq_total.shape[0], Hq_total.shape[0]), dtype=complex)
    own_block = np.ix_(on_atom, on_atom)
    share[own_block] = Hq_total[own_block]
    for cross in (np.ix_(on_atom, off_atom), np.ix_(off_atom, on_atom)):
        share[cross] = 0.5 * Hq_total[cross]
    return share


def reconstruct_overlap_from_mos(C: np.ndarray) -> np.ndarray:
    """Recover the AO overlap matrix S from a complete, square MO coefficient matrix.

    From C^T S C = I it follows that S = (C C^T)^(-1). Both C C^T and the
    result are explicitly symmetrized to remove round-off asymmetry.
    """
    gram = C @ C.T
    gram = (gram + gram.T) * 0.5
    overlap = np.linalg.inv(gram)
    return (overlap + overlap.T) * 0.5


def symmetric_matrix_half_powers(S: np.ndarray, eig_tol: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return (S^(1/2), S^(-1/2)) of a real symmetric positive-definite matrix.

    Both are built from the eigendecomposition S = V diag(w) V^T and are
    explicitly symmetrized.

    Raises:
      ValueError: if the smallest eigenvalue is <= eig_tol.
    """
    eigvals, eigvecs = np.linalg.eigh(S)
    smallest = float(np.min(eigvals))
    if smallest <= eig_tol:
        raise ValueError(
            "Overlap matrix is not safely positive definite for Löwdin partitioning; "
            f"smallest eigenvalue = {smallest:.3e}."
        )

    root = np.sqrt(eigvals)
    # V * d scales column j of V by d[j], i.e. V @ diag(d).
    half = (eigvecs * root) @ eigvecs.T
    minus_half = (eigvecs * (1.0 / root)) @ eigvecs.T
    return 0.5 * (half + half.T), 0.5 * (minus_half + minus_half.T)


def build_atom_projector(ao_atom: List[int], A: int) -> np.ndarray:
    """Diagonal float matrix P_A with 1 on the AOs centred on atom A and 0 elsewhere."""
    on_atom = [1.0 if center == A else 0.0 for center in ao_atom]
    return np.diag(np.asarray(on_atom, dtype=float))


def build_W_lowdin_orthogonal(Hq_orth: np.ndarray, PA: np.ndarray) -> np.ndarray:
    """Symmetrized atom share 0.5 * (P_A H~ + H~ P_A) of an orthogonal-basis operator H~."""
    left = PA @ Hq_orth
    right = Hq_orth @ PA
    return 0.5 * (left + right)


def build_amp_matrix(excs: List[Excitation], occ_idx: List[int], virt_idx: List[int]) -> np.ndarray:
    """Assemble the complex (virt x occ) amplitude matrix of one excited state.

    Row/column positions follow the order of *virt_idx* / *occ_idx*. Repeated
    (occ, virt) pairs are summed. Excitations that touch an MO outside those
    lists are ignored.
    """
    row_of = {mo: r for r, mo in enumerate(virt_idx)}
    col_of = {mo: c for c, mo in enumerate(occ_idx)}
    amps = np.zeros((len(virt_idx), len(occ_idx)), dtype=complex)
    for ex in excs:
        col = col_of.get(ex.occ_mo)
        row = row_of.get(ex.virt_mo)
        if col is None or row is None:
            continue
        amps[row, col] += ex.amp
    return amps


def characteristic_pair_from_contractions(
    sing_state: ParsedState,
    trip_state: ParsedState,
    occ: np.ndarray,
    occ_thresh: float,
) -> Tuple[str, int, int, float, int, int]:
    """Pick the MO pair with the largest singlet/triplet amplitude contraction.

    With S, T the (virt x occ) amplitude matrices:
      v = T^† S   (occ x occ),   o = T^* S^T   (virt x virt)
    The single largest |element| across v and o wins; a tie goes to o.

    Returns:
      kind: "occ" or "virt"
      mo1, mo2: MO indices (0-based internal)
      max_abs: largest absolute contraction value
      j, i: indices in the contraction matrix (row j, col i)

    Raises:
      ValueError: if occupations give no occupied or no virtual MOs.
    """
    occupied = [k for k, n_k in enumerate(occ) if n_k > occ_thresh]
    virtual = [k for k, n_k in enumerate(occ) if n_k <= occ_thresh]
    if not (occupied and virtual):
        raise ValueError("Failed to classify occupied/virtual MOs from occupations.")

    S = build_amp_matrix(sing_state.excitations, occupied, virtual)
    T = build_amp_matrix(trip_state.excitations, occupied, virtual)

    occ_mag = np.abs(T.conj().T @ S)   # |v|, (occ x occ)
    virt_mag = np.abs(T.conj() @ S.T)  # |o|, (virt x virt)

    jv, iv = np.unravel_index(np.argmax(occ_mag), occ_mag.shape)
    jo, io = np.unravel_index(np.argmax(virt_mag), virt_mag.shape)
    vmax = float(occ_mag[jv, iv])
    omax = float(virt_mag[jo, io])

    if omax >= vmax:
        return "virt", virtual[jo], virtual[io], omax, jo, io
    return "occ", occupied[jv], occupied[iv], vmax, jv, iv


def compute_sigmas_partitioned(
    C: np.ndarray,
    occ: np.ndarray,
    ao_atom: List[int],
    hx: np.ndarray,
    hy: np.ndarray,
    hz: np.ndarray,
    sing_state: ParsedState,
    trip_state: ParsedState,
    hbar: float,
    occ_thresh: float,
    partition: str,
    lowdin_eig_tol: float,
) -> Tuple[Dict[int, Dict[int, complex]], Dict[int, complex]]:
    """Atom-partitioned singlet–triplet SOC matrix elements Sigma(A, M).

    For each M in Ms, the spherical AO operator component q = -M is split into
    per-atom shares W_A. Two schemes are supported. 'mulliken' partitions in
    the AO basis. 'lowdin' partitions in the symmetrically orthogonalized AO
    basis, with S rebuilt from C. Each share is transformed to the MO basis and
    contracted with o (virt x virt) and v (occ x occ):
      Sigma(A, M) = (-1)^M (hbar/2) [ sum(o * h_virt) - sum(v * h_occ^T) ]

    Returns:
      ({atom: {M: Sigma(A, M)}}, {M: sum_A Sigma(A, M)})

    Raises:
      ValueError: if no occupied or no virtual MOs are found, if *partition* is
        unknown, or if the overlap is not safely positive definite ('lowdin').
    """
    occ_idx = [k for k, n_k in enumerate(occ) if n_k > occ_thresh]
    virt_idx = [k for k, n_k in enumerate(occ) if n_k <= occ_thresh]
    if not occ_idx or not virt_idx:
        raise ValueError("Failed to classify occupied/virtual MOs from occupations.")

    S_amp = build_amp_matrix(sing_state.excitations, occ_idx, virt_idx)
    T_amp = build_amp_matrix(trip_state.excitations, occ_idx, virt_idx)
    v = T_amp.conj().T @ S_amp       # (occ x occ)
    o = T_amp.conj() @ S_amp.T       # (virt x virt)

    atoms = sorted(set(ao_atom))
    sigmas_by_atom: Dict[int, Dict[int, complex]] = {A: {} for A in atoms}
    totals: Dict[int, complex] = {M: 0j for M in Ms}

    scheme = partition.lower()
    if scheme not in {"mulliken", "lowdin"}:
        raise ValueError("partition must be 'mulliken' or 'lowdin'.")
    use_lowdin = scheme == "lowdin"

    if use_lowdin:
        overlap = reconstruct_overlap_from_mos(C)
        S_half, S_mhalf = symmetric_matrix_half_powers(overlap, eig_tol=lowdin_eig_tol)
        # Orthogonalized MO coefficients C~ = S^(1/2) C.
        C_work = S_half @ C
        projectors = {A: build_atom_projector(ao_atom, A) for A in atoms}
    else:
        C_work = C

    virt_block = np.ix_(virt_idx, virt_idx)
    occ_block = np.ix_(occ_idx, occ_idx)
    half_hbar = hbar / 2.0

    for M in Ms:
        Hq = cart_to_spherical(hx, hy, hz, -M)
        if use_lowdin:
            # H~ = S^(-1/2) H S^(-1/2)
            Hq = S_mhalf @ Hq @ S_mhalf

        for A in atoms:
            if use_lowdin:
                W_A = build_W_lowdin_orthogonal(Hq, projectors[A])
            else:
                W_A = build_W_mulliken(Hq, ao_atom, A)

            h_mo = C_work.T @ W_A @ C_work
            term_virt = np.sum(o * h_mo[virt_block])
            term_occ = np.sum(v * h_mo[occ_block].T)

            sigma = ((-1) ** M) * half_hbar * (term_virt - term_occ)
            sigmas_by_atom[A][M] = sigma
            totals[M] += sigma

    return sigmas_by_atom, totals


def discover_outputs(scan_dir: Path, suffix: str) -> List[Tuple[int, Path]]:
    """List the regular files named ``<int>_<suffix>`` in *scan_dir*, sorted by the integer.

    Returns (step number, path) pairs. Directories and non-matching names are
    skipped.
    """
    name_re = re.compile(r"^(\d+)_" + re.escape(suffix) + r"$")
    found: List[Tuple[int, Path]] = []
    for entry in scan_dir.iterdir():
        if not entry.is_file():
            continue
        hit = name_re.match(entry.name)
        if hit:
            found.append((int(hit.group(1)), entry))
    return sorted(found, key=lambda pair: pair[0])


def fmt8(contrib: Dict[str, float]) -> List[str]:
    """Format the ORBITALS entries of *contrib*, in ORBITALS order, as '.10g' strings."""
    return [format(contrib[key], ".10g") for key in ORBITALS]


def _xyz_frame(
    idx: int,
    path: Path,
    singlet_ord: int,
    triplet_ord: int,
    occ_thresh: float,
    hbar_au: float,
    partition: str,
    lowdin_eig_tol: float,
) -> List[str]:
    """Build one XYZ frame (atom count, comment line, atom lines) from one ORCA output."""
    text = read_text(path)

    energy_eh = parse_final_single_point_energy(text)
    elems, xyz = parse_last_cart_coords_angstrom(text)
    n_atoms = len(elems)

    nbf = parse_nbf(text)
    C, _eps, occ, ao_atom, ao_label, atom_elems, mo_base = parse_mo_coefficients(text, nbf)
    ao_atom, _atom_elems = maybe_shift_atom_indexing(ao_atom, atom_elems, n_atoms)

    singlets = parse_states_ordinal(extract_excited_section(text, "SINGLETS"), mo_base)
    triplets = parse_states_ordinal(extract_excited_section(text, "TRIPLETS"), mo_base)
    if not 1 <= singlet_ord <= len(singlets):
        raise SystemExit(f"In {path.name}: singlet_ord={singlet_ord} out of range (1..{len(singlets)}).")
    if not 1 <= triplet_ord <= len(triplets):
        raise SystemExit(f"In {path.name}: triplet_ord={triplet_ord} out of range (1..{len(triplets)}).")
    s_state = singlets[singlet_ord - 1]
    t_state = triplets[triplet_ord - 1]

    kind, mo1, mo2, _maxabs, _j, _i = characteristic_pair_from_contractions(
        s_state, t_state, occ, occ_thresh
    )

    weights1 = pd_contributions_all_atoms_to_mo(C, mo1, ao_atom, ao_label, n_atoms)
    if mo2 == mo1:
        weights2 = weights1
    else:
        weights2 = pd_contributions_all_atoms_to_mo(C, mo2, ao_atom, ao_label, n_atoms)

    hx, hy, hz = parse_soc_ao_matrices(text, nbf)
    per_atom, totals = compute_sigmas_partitioned(
        C=C,
        occ=occ,
        ao_atom=ao_atom,
        hx=hx,
        hy=hy,
        hz=hz,
        sing_state=s_state,
        trip_state=t_state,
        hbar=hbar_au,
        occ_thresh=occ_thresh,
        partition=partition,
        lowdin_eig_tol=lowdin_eig_tol,
    )

    def norm_cm1(components: Dict[int, complex]) -> float:
        # Euclidean norm over the M components, converted Eh -> cm^-1.
        return math.sqrt(sum(abs(components[M]) ** 2 for M in Ms)) * EH_TO_CM1

    total_cm1 = norm_cm1(totals)
    atom_cm1 = [norm_cm1(per_atom[A]) if A in per_atom else 0.0 for A in range(n_atoms)]

    comment = " ".join(
        [
            f"i={idx}",
            f"E_el_Eh={energy_eh:.12f}",
            f"S_n={singlet_ord}",
            f"T_n={triplet_ord}",
            f"partition={partition.lower()}",
            f"MOs={kind}",
            f"MO1={mo1 + mo_base}",
            f"MO2={mo2 + mo_base}",
            f"TOTAL_SOC_cm-1={total_cm1:.10g}",
        ]
    )
    frame = [str(n_atoms), comment]

    blank = {k: 0.0 for k in ORBITALS}
    for a, (el, (x, y, z)) in enumerate(zip(elems, xyz)):
        fields = [el, f"{x:.10f}", f"{y:.10f}", f"{z:.10f}", f"{atom_cm1[a]:.10g}"]
        fields += fmt8(weights1.get(a, blank))
        fields += fmt8(weights2.get(a, blank))
        frame.append(" ".join(fields))
    return frame


def main(
    scan_dir: Path = SCAN_DIR,
    suffix: str = SUFFIX,
    out_file: Path = OUT_FILE,
    singlet_ord: int = SINGLET_ORD,
    triplet_ord: int = TRIPLET_ORD,
    occ_thresh: float = OCC_THRESH,
    hbar_au: float = HBAR_AU,
    partition: str = PARTITION_SCHEME,
    lowdin_eig_tol: float = LOWDIN_EIG_TOL,
) -> None:
    """Process every ``<i>_<suffix>`` output in *scan_dir* into one XYZ trajectory file.

    Frames are written in ascending step order. Nothing is written unless every
    file was processed successfully.

    Raises:
      SystemExit: if no matching files exist or a state ordinal is out of range.
    """
    files = discover_outputs(scan_dir, suffix)
    if not files:
        raise SystemExit(f"No files found in {scan_dir} matching: <i>_{suffix}")

    out_lines: List[str] = []
    for idx, path in files:
        out_lines.extend(
            _xyz_frame(
                idx,
                path,
                singlet_ord=singlet_ord,
                triplet_ord=triplet_ord,
                occ_thresh=occ_thresh,
                hbar_au=hbar_au,
                partition=partition,
                lowdin_eig_tol=lowdin_eig_tol,
            )
        )

    out_file.write_text("\n".join(out_lines) + "\n")
    print(f"Wrote: {out_file}  (steps={len(files)}, partition={partition.lower()})")


def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Each default comes from the matching CONFIG constant, and each option's
    ``dest`` matches a keyword parameter of ``main``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Parse ORCA scan outputs and write an XYZ trajectory with per-atom SOP and "
            "MO p/d contributions."
        )
    )
    parser.add_argument(
        "scan_dir",
        nargs="?",
        default=str(SCAN_DIR),
        help="Directory containing i_<suffix> ORCA output files (default: SCAN_DIR in script).",
    )

    # (flags, add_argument keyword settings), registered in this order.
    option_specs = (
        (("--suffix",), {
            "default": SUFFIX,
            "help": "Filename suffix after the numeric prefix (default: SUFFIX in script).",
        }),
        (("-o", "--out"), {
            "dest": "out_file",
            "default": str(OUT_FILE),
            "help": "Output XYZ filename (default: OUT_FILE in script).",
        }),
        (("--singlet",), {
            "dest": "singlet_ord",
            "type": int,
            "default": SINGLET_ORD,
            "help": "1-based ordinal of the singlet state in the SINGLETS section.",
        }),
        (("--triplet",), {
            "dest": "triplet_ord",
            "type": int,
            "default": TRIPLET_ORD,
            "help": "1-based ordinal of the triplet state in the TRIPLETS section.",
        }),
        (("--occ-thresh",), {
            "type": float,
            "default": OCC_THRESH,
            "help": "Occupation threshold to classify occupied vs virtual MOs.",
        }),
        (("--hbar-au",), {
            "type": float,
            "default": HBAR_AU,
            "help": "Value of ħ in atomic units used in the SOC prefactor.",
        }),
        (("--partition",), {
            "choices": ["mulliken", "lowdin"],
            "default": PARTITION_SCHEME,
            "help": "Atomic partitioning of the AO SOC matrices.",
        }),
        (("--lowdin-eig-tol",), {
            "type": float,
            "default": LOWDIN_EIG_TOL,
            "help": "Minimum allowed overlap eigenvalue for Löwdin partitioning.",
        }),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    return parser.parse_args()


if __name__ == "__main__":
    # The argparse dest names are exactly main()'s keyword parameters.
    options = vars(_parse_args())
    # Path-like options arrive as strings; main() expects Path objects.
    for path_key in ("scan_dir", "out_file"):
        options[path_key] = Path(options[path_key])
    main(**options)
