// molecule-generator.js
import {
  cpk_colors,
  vdw_radii,
  getCPKColor,
  createCurvedBondMesh,
  createCurvedRidgedBondMeshRelaxed,
  createTextLabel,
  rotateVector,
  adjustBondEndpoints,
  getAtomVisualRadius,
  calculateDoubleBondMidpoint,
  quadraticBezier,
  createCylinderBetweenPoints,
  drawCurvedBondRelaxed,
  forceRotateTerminalHydrogen
} from "./molecule-utils.js";

import {
  setMoleculeGroup,
  getMoleculeGroup,
  centerMolecule,
  scene
} from "./scene.js";

import { equalizeBondLengthsPreservingAngles } from "./molecule_bond-helpers.js";
import {
  getLastValidSmilesArray,
  setLastValidSmilesArray,
  getCurrentSmiles,
  setCurrentSmiles
} from "./state.js";
import { processToken } from "./smiles-parser.js";

// Flipped to true by loadRDKit() once RDKit.load() resolves;
// createMoleculeGroupFromSmiles() checks it before attempting RDKit 3D generation.
let rdkitReady = false;
// Group that receives the per-molecule text labels created in generateMoleculesFromList().
// Both generate* entry points remove it from the scene and replace it with a fresh THREE.Group.
export let labelsGroup = new THREE.Group();

/**
 * Initializes the global `RDKit` object and marks it as usable for this module.
 *
 * @returns {Promise<void>} Resolves (with no value) after `RDKit.load()` has
 *   resolved and `rdkitReady` has been set to true.
 * @throws {Error} "RDKit failed to load." when no global `RDKit` exists.
 *   Any rejection coming from `RDKit.load()` itself is propagated unchanged.
 */
export async function loadRDKit() {
  if (typeof RDKit === "undefined") {
    console.error("RDKit is not loaded!");
    // Reject with a real Error (not a bare string) so callers get a stack trace.
    throw new Error("RDKit failed to load.");
  }

  await RDKit.load();
  // Only flag readiness after load() has actually succeeded.
  rdkitReady = true;
  console.log("RDKit loaded successfully!");
}

/**
 * Replaces whatever molecule is currently shown with a single molecule built
 * from `smiles`, recentred on the origin of its own bounding box.
 *
 * Side effects: records `smiles` as the current SMILES and as the only entry of
 * the last-valid list, removes the previous molecule group and the label group
 * from the scene, and swaps `labelsGroup` for an empty group.
 *
 * @param {string} smiles - SMILES string to render.
 * @returns {Promise<THREE.Group>} The group that was added to the scene.
 */
export async function generateMoleculeFromSmiles(smiles) {
  setCurrentSmiles(smiles);
  setLastValidSmilesArray([smiles]);

  // Tear down the previous display before building the new one.
  const previousGroup = getMoleculeGroup();
  if (previousGroup) {
    scene.remove(previousGroup);
  }
  scene.remove(labelsGroup);
  labelsGroup = new THREE.Group();

  // Only the group is needed here; the bounding-box helper is ignored.
  const [moleculeGroup] = await createMoleculeGroupFromSmiles(smiles);
  setMoleculeGroup(moleculeGroup);
  scene.add(moleculeGroup);

  // Shift the group so its bounding-box centre moves to its previous origin.
  const bounds = new THREE.Box3().setFromObject(moleculeGroup);
  const boundsCenter = bounds.getCenter(new THREE.Vector3());
  moleculeGroup.position.sub(boundsCenter);

  return moleculeGroup;
}

/**
 * Parses `smiles` into an OCL molecule, preferring coordinates produced by RDKit.
 *
 * When `rdkitReady` is set, the molecule is built and optimized through the
 * global `RDKit` object, exported as a molfile and re-read by OCL. If any step
 * of that path throws, the SMILES is recorded in `rdkitWarnings`, the error is
 * logged, and the OCL-only path below is used instead.
 *
 * @param {string} smiles
 * @returns {object} An OCL molecule.
 */
function loadMoleculeFromSmiles(smiles) {
  if (rdkitReady) {
    try {
      let molBlock;
      const mol = RDKit.Molecule.smilesToMol(smiles);
      try {
        mol.addHs();
        mol.EmbedMolecule();
        mol.MMFFoptimizeMolecule();
        molBlock = mol.molToMolfile();
      } finally {
        // Release the RDKit molecule on failure paths too, not only on success.
        if (mol) mol.delete();
      }
      return OCL.Molecule.fromMolfile(molBlock);
    } catch (e) {
      rdkitWarnings.push(smiles);
      console.error("RDKit 3D generation failed, falling back to OCL:", e);
    }
  }

  const molecule = OCL.Molecule.fromSmiles(smiles);
  molecule.addImplicitHydrogens();
  molecule.ensureHelperArrays(OCL.Molecule.cHelper3D);
  return molecule;
}

/**
 * Snapshots every atom of `molecule` as `{ pos, type, radius }`.
 * The radius is 0.15 unless the model is "spacefilling", in which case it is
 * 0.6 x the `vdw_radii` entry for the atom label (1.5 when the label is unknown).
 */
function collectAtomData(molecule, modelType) {
  const atomCount = molecule.getAllAtoms();
  const atomData = [];
  for (let i = 0; i < atomCount; i++) {
    const label = molecule.getAtomLabel(i);
    const radius = modelType === "spacefilling" ? (vdw_radii[label] || 1.5) * 0.6 : 0.15;
    atomData.push({
      pos: new THREE.Vector3(
        molecule.getAtomX(i),
        molecule.getAtomY(i),
        molecule.getAtomZ(i)
      ),
      type: label,
      radius
    });
  }
  return atomData;
}

/** Builds the lookup key used for bond endpoints: "<bond>_<atom>_<arcSign|null>". */
function bondEndpointKey(bondIdx, atomIdx, arcSign) {
  return `${bondIdx}_${atomIdx}_${arcSign}`;
}

/**
 * Computes the endpoint records (one per bond end; two per end for double bonds,
 * one for each arc sign) that the bond meshes are later drawn between.
 */
function computeBondEndpointRecords(molecule, atomData, modelType) {
  const isSpacefilling = modelType === "spacefilling";
  const doubleBondSeparation = 0.5;
  const records = [];

  const pushPair = (bondIdx, a1, a2, end1, end2, arcSign) => {
    const isDouble = arcSign !== null;
    records.push({
      atom_idx: a1,
      bond_idx: bondIdx,
      endpoint: end1,
      is_double: isDouble,
      arc_sign: arcSign,
      other_atom: a2
    });
    records.push({
      atom_idx: a2,
      bond_idx: bondIdx,
      endpoint: end2,
      is_double: isDouble,
      arc_sign: arcSign,
      other_atom: a1
    });
  };

  // Point displaced from atom `idx` along the direction of `otherPos`: by the
  // atom radius for spacefilling, by -0.05 (i.e. slightly backwards) otherwise.
  const anchorToward = (idx, otherPos) => {
    const center = atomData[idx].pos;
    const r = isSpacefilling ? atomData[idx].radius : -0.05;
    const direction = new THREE.Vector3().subVectors(otherPos, center).normalize();
    return center.clone().add(direction.multiplyScalar(r));
  };

  let doubleBondCounter = 0;
  const numBonds = molecule.getAllBonds();
  for (let i = 0; i < numBonds; i++) {
    const a1 = molecule.getBondAtom(0, i);
    const a2 = molecule.getBondAtom(1, i);
    const order = molecule.getBondOrder(i);
    // atomData radii are already 0.15 outside spacefilling mode.
    const r1 = atomData[a1].radius;
    const r2 = atomData[a2].radius;

    if (order !== 2) {
      const [p1, p2] = adjustBondEndpoints(atomData[a1].pos, atomData[a2].pos, r1, r2);
      pushPair(i, a1, a2, p1, p2, null);
      continue;
    }

    const startAnchor = anchorToward(a1, atomData[a2].pos);
    const endAnchor = anchorToward(a2, atomData[a1].pos);
    const span = new THREE.Vector3().subVectors(endAnchor, startAnchor);
    const axis = span.clone().normalize();

    // Pick any vector perpendicular to the bond axis; switch the trial vector
    // when the bond is (nearly) parallel to Y so the cross product is not degenerate.
    let trial = new THREE.Vector3(0, 1, 0);
    if (Math.abs(axis.dot(trial)) > 0.99) trial = new THREE.Vector3(1, 0, 0);
    let perp = new THREE.Vector3().crossVectors(axis, trial).normalize();

    // Every second double bond gets its perpendicular turned by 90 degrees about the axis.
    if (doubleBondCounter % 2 !== 0) {
      perp = rotateVector(perp, axis, Math.PI / 2);
    }
    doubleBondCounter++;

    for (const sign of [1, -1]) {
      const lateral = perp.clone().multiplyScalar(sign * doubleBondSeparation * Math.PI);
      const startDir = span.clone().add(lateral).normalize();
      const endDir = span.clone().sub(lateral).normalize();
      const p1Arc = startAnchor.clone().add(startDir.multiplyScalar(r1));
      const p2Arc = endAnchor.clone().sub(endDir.multiplyScalar(r2));
      pushPair(i, a1, a2, p1Arc, p2Arc, sign);
    }
  }
  return records;
}

/**
 * For every atom that has at least one non-double bond and exactly one double
 * bond, nudges the double-bond endpoint directions so their angle to each
 * non-double bond direction approaches 109.5 degrees, then re-projects the
 * endpoints to the atom's visual radius. Mutates `rec.endpoint` in place.
 */
function relaxDoubleBondEndpoints(molecule, atomData, endpointRecords) {
  const numRelaxIterations = 20;
  const relaxAlpha = 0.1;
  const targetCos = Math.cos(109.5 * Math.PI / 180);

  const recordsByAtom = new Map();
  endpointRecords.forEach(rec => {
    if (!recordsByAtom.has(rec.atom_idx)) recordsByAtom.set(rec.atom_idx, []);
    recordsByAtom.get(rec.atom_idx).push(rec);
  });

  for (const [atomIdx, records] of recordsByAtom) {
    const fixed = records.filter(rec => !rec.is_double);
    const adjustable = records.filter(rec => rec.is_double);
    const doubleBondIds = new Set(adjustable.map(rec => rec.bond_idx));
    if (fixed.length === 0 || doubleBondIds.size !== 1) continue;

    const center = atomData[atomIdx].pos;
    const visualRadius = getAtomVisualRadius(molecule.getAtomLabel(atomIdx));
    // Directions of the non-double bonds never change, so compute them once.
    const fixedDirs = fixed.map(rec => rec.endpoint.clone().sub(center).normalize());

    // Each adjustable endpoint only depends on the fixed directions, so the
    // records can be relaxed independently of one another.
    adjustable.forEach(rec => {
      let v = rec.endpoint.clone().sub(center).normalize();
      for (let it = 0; it < numRelaxIterations; it++) {
        const grad = new THREE.Vector3(0, 0, 0);
        fixedDirs.forEach(u => {
          const error = v.dot(u) - targetCos;
          const correction = u.clone().sub(v.clone().multiplyScalar(v.dot(u))).multiplyScalar(error);
          grad.add(correction);
        });
        v = v.clone().sub(grad.multiplyScalar(relaxAlpha)).normalize();
      }
      rec.endpoint = center.clone().add(v.clone().multiplyScalar(visualRadius));
    });
  }
}

/** Adds one sphere mesh per atom for the "stick", "building kit" and "spacefilling" models. */
function addAtomMeshes(group, atomData, modelType) {
  if (!["stick", "building kit", "spacefilling"].includes(modelType)) return;
  atomData.forEach(data => {
    // data.radius is 0.15 for stick / building kit and the scaled vdW radius for spacefilling.
    const geometry = new THREE.SphereGeometry(data.radius, 32, 32);
    const material = new THREE.MeshStandardMaterial({
      color: getCPKColor(data.type),
      metalness: 0.0,
      roughness: 0.5
    });
    const atomMesh = new THREE.Mesh(geometry, material);
    atomMesh.position.copy(data.pos);
    group.add(atomMesh);
  });
}

/**
 * Returns the three `{ start, end }` pairs of a triple bond: the original
 * endpoints pushed sideways in three directions 120 degrees apart around the
 * bond axis, each kept at its original distance from its atom centre.
 */
function computeTripleBondStrands(origStart, origEnd, center1, center2) {
  const bondVec = new THREE.Vector3().subVectors(origEnd, origStart).normalize();
  let trial = new THREE.Vector3(0, 1, 0);
  if (Math.abs(bondVec.dot(trial)) > 0.99) trial = new THREE.Vector3(1, 0, 0);
  const perp1 = new THREE.Vector3().crossVectors(bondVec, trial).normalize();
  const spacing = 0.5;
  const offsets = [
    perp1.clone().multiplyScalar(spacing),
    perp1.clone().applyAxisAngle(bondVec, THREE.MathUtils.degToRad(120)).multiplyScalar(spacing),
    perp1.clone().applyAxisAngle(bondVec, THREE.MathUtils.degToRad(-120)).multiplyScalar(spacing)
  ];
  const r1 = origStart.clone().sub(center1).length();
  const r2 = origEnd.clone().sub(center2).length();
  return offsets.map(offset => ({
    start: center1.clone().add(
      origStart.clone().sub(center1).add(offset).normalize().multiplyScalar(r1)
    ),
    end: center2.clone().add(
      origEnd.clone().sub(center2).add(offset).normalize().multiplyScalar(r2)
    )
  }));
}

/**
 * Adds the bond meshes for the "stick" and "building kit" models (other models
 * draw no bonds). Bonds of order 1, 2 and 3 are drawn; other orders are skipped.
 *
 * @param {THREE.Group} group - Receives the meshes.
 * @param {object} molecule - OCL molecule supplying bond topology.
 * @param {Array} atomData - Output of collectAtomData().
 * @param {Map<string, THREE.Vector3>} endpoints - Endpoint lookup keyed by bondEndpointKey().
 * @param {string} modelType
 */
function addBondMeshes(group, molecule, atomData, endpoints, modelType) {
  const isStick = modelType === "stick";
  if (!isStick && modelType !== "building kit") return;

  const numBonds = molecule.getAllBonds();
  for (let i = 0; i < numBonds; i++) {
    const a1 = molecule.getBondAtom(0, i);
    const a2 = molecule.getBondAtom(1, i);
    const order = molecule.getBondOrder(i);
    const pos1 = atomData[a1].pos;
    const pos2 = atomData[a2].pos;

    // Endpoints of the un-split bond; recomputed from the atom centres if the lookup misses.
    const axisEndpoints = () => {
      const start = endpoints.get(bondEndpointKey(i, a1, null));
      const end = endpoints.get(bondEndpointKey(i, a2, null));
      if (start && end) return [start, end];
      return adjustBondEndpoints(pos1, pos2, 0.15, 0.15);
    };

    // Draws one curved strand, always handing the helper a CPK *color*
    // (the same argument type at every call site), coloured after atom a1.
    const addStrand = (start, end, midOffset) => {
      const mid = calculateDoubleBondMidpoint(start, end, pos1, pos2, midOffset);
      const color = getCPKColor(atomData[a1].type);
      group.add(
        isStick
          ? drawCurvedBondRelaxed(start, end, mid, 16, 0.04, color)
          : createCurvedRidgedBondMeshRelaxed(start, end, mid, 16, 0.04, color)
      );
    };

    if (order === 1) {
      const [start, end] = axisEndpoints();
      if (isStick) {
        // Stick single bonds are two straight halves meeting at the midpoint.
        // NOTE(review): both halves use atom a1's color — confirm the second
        // half was not meant to take atom a2's color.
        const mid = new THREE.Vector3().addVectors(start, end).multiplyScalar(0.5);
        const material = new THREE.MeshStandardMaterial({ color: getCPKColor(atomData[a1].type) });
        group.add(createCurvedBondMesh(start, mid, new THREE.Vector3(0, 0, 0), 0.04, material));
        group.add(createCurvedBondMesh(mid, end, new THREE.Vector3(0, 0, 0), 0.04, material));
      } else {
        addStrand(start, end, 0);
      }
    } else if (order === 2) {
      const midOffset = isStick ? 0.24 : 0.4;
      for (const sign of [1, -1]) {
        const p1 = endpoints.get(bondEndpointKey(i, a1, sign));
        const p2 = endpoints.get(bondEndpointKey(i, a2, sign));
        if (p1 && p2) addStrand(p1, p2, midOffset);
      }
    } else if (order === 3) {
      const [origStart, origEnd] = axisEndpoints();
      computeTripleBondStrands(origStart, origEnd, pos1, pos2).forEach(strand => {
        addStrand(strand.start, strand.end, 0.4);
      });
    }
  }
}

/**
 * Builds the three.js representation of the molecule described by `smiles`.
 *
 * @param {string} smiles - SMILES string to render.
 * @param {?string} [modelTypeOverride=null] - "stick", "building kit" or
 *   "spacefilling". When falsy, the value of the `#modelType` DOM element is used.
 * @returns {Promise<Array>} `[group, boundingBoxHelper]`: the group holding all
 *   atom/bond meshes and a yellow `THREE.Box3Helper` around it. Neither is added
 *   to the scene here.
 */
export async function createMoleculeGroupFromSmiles(smiles, modelTypeOverride = null) {
  const molecule = loadMoleculeFromSmiles(smiles);

  const newGroup = new THREE.Group();
  const modelType = modelTypeOverride || document.getElementById("modelType").value;
  equalizeBondLengthsPreservingAngles(molecule, 1.4);

  const atomData = collectAtomData(molecule, modelType);

  // NOTE(review): atoms 3 and 4 are hard-coded here for every molecule with
  // more than four atoms — confirm this is intended and not a leftover tweak.
  if (atomData.length > 4) {
    const bondAxis = new THREE.Vector3()
      .subVectors(atomData[4].pos, atomData[3].pos)
      .normalize();
    forceRotateTerminalHydrogen(molecule, 4, bondAxis);
    // Re-read the coordinates from the molecule after the helper has run.
    atomData.forEach((data, i) => {
      data.pos.set(
        molecule.getAtomX(i),
        molecule.getAtomY(i),
        molecule.getAtomZ(i)
      );
    });
  }

  const endpointRecords = computeBondEndpointRecords(molecule, atomData, modelType);
  relaxDoubleBondEndpoints(molecule, atomData, endpointRecords);

  const endpoints = new Map();
  endpointRecords.forEach(rec => {
    endpoints.set(bondEndpointKey(rec.bond_idx, rec.atom_idx, rec.arc_sign), rec.endpoint);
  });

  // Atoms are added before bonds, matching the order children end up in the group.
  addAtomMeshes(newGroup, atomData, modelType);
  addBondMeshes(newGroup, molecule, atomData, endpoints, modelType);

  const boundingBox = new THREE.Box3Helper(new THREE.Box3().setFromObject(newGroup), 0xffff00);
  return [newGroup, boundingBox];
}

// SMILES strings recorded while createMoleculeGroupFromSmiles() runs, whenever its RDKit
// path throws. generateMoleculesFromList() reports them in a single alert and then empties
// the array.
const rdkitWarnings = [];

/**
 * Renders several molecules side by side, each with a text label showing its SMILES.
 *
 * Every token is first run through `processToken`; falsy results are skipped
 * with a console warning. If nothing valid remains, the user is alerted and the
 * current display and state are left untouched. Otherwise the previous molecule
 * and label groups are replaced, molecules are laid out along X at a spacing of
 * 5 units (centred around x = 0, resting on y = 0), and any SMILES collected in
 * `rdkitWarnings` during generation are reported in one alert.
 *
 * @param {Iterable<string>} tokens - Raw user tokens to validate and render.
 * @returns {Promise<void>}
 */
export async function generateMoleculesFromList(tokens) {
  // Validate before touching the scene, so an all-invalid input does not wipe
  // the molecule that is currently displayed.
  const validSmilesArray = [];
  for (const token of tokens) {
    const valid = await processToken(token);
    if (valid) {
      validSmilesArray.push(valid);
    } else {
      console.warn(`Skipping token "${token}" because it is invalid or empty.`);
    }
  }
  if (validSmilesArray.length === 0) {
    alert("No valid SMILES to display!");
    return;
  }

  const oldGroup = getMoleculeGroup();
  if (oldGroup) scene.remove(oldGroup);
  scene.remove(labelsGroup);
  labelsGroup = new THREE.Group();
  setLastValidSmilesArray(validSmilesArray);

  const containerGroup = new THREE.Group();
  const spacing = 5;
  const numMolecules = validSmilesArray.length;
  const startX = -((numMolecules - 1) * spacing) / 2;
  for (let i = 0; i < numMolecules; i++) {
    const smiles = validSmilesArray[i];
    try {
      const [molGroup] = await createMoleculeGroupFromSmiles(smiles);
      const bbox = new THREE.Box3().setFromObject(molGroup);
      const bboxCenter = bbox.getCenter(new THREE.Vector3());
      const desiredX = startX + i * spacing;
      // Centre the molecule on its slot in X and Z; lift it so its lowest point sits at y = 0.
      molGroup.position.add(
        new THREE.Vector3(desiredX - bboxCenter.x, -bbox.min.y, -bboxCenter.z)
      );
      containerGroup.add(molGroup);

      const labelSprite = createTextLabel(smiles);
      labelSprite.position.set(desiredX, -2, 0);
      labelsGroup.add(labelSprite);
    } catch (err) {
      // A molecule that fails to build leaves its slot empty; the others still render.
      console.error(`Error generating molecule for "${smiles}":`, err);
    }
  }

  if (rdkitWarnings.length > 0) {
    window.alert(
      "Warning - The following SMILES could not be processed correctly:\n" +
      rdkitWarnings.join("\n")
    );
    rdkitWarnings.length = 0; // Clear warnings for future runs.
  }

  setMoleculeGroup(containerGroup);
  scene.add(containerGroup);
  scene.add(labelsGroup);
  centerMolecule();
}
