import itertools
import math
import openmm

# NOTE(review): the right-hand sides below look like template placeholders
# that are substituted with Python literals before this script is executed
# (the doubled braces used elsewhere in this file point to str.format-style
# rendering) -- confirm against the code that generates this script.

# Improper tables, looked up by residue or patch template name.  Each entry is
# iterated as a collection of impropers, and each improper is iterated as
# (residue offset, atom name) pairs.  Impropers listed in delete_impropers for
# a patch are removed before that patch's own impropers are added.
residue_impropers = {residue_impropers}
patch_impropers = {patch_impropers}
delete_impropers = {delete_impropers}

# Sort keys (used through their get method) that fix the order in which
# residue and patch templates are processed.
residue_order = {residue_order}
patch_order = {patch_order}

# Improper parameter tables keyed by tuples of atom classes, where None in a
# key acts as a wildcard.  The stored values are unpacked as the trailing
# arguments of addTorsion on the corresponding force.
periodic_improper_types = {periodic_improper_types}
harmonic_improper_types = {harmonic_improper_types}

def extract_atom_name(raw_atom_name, templates):
    """Split a raw atom type name into its template name and atom name.

    The raw name is expected to look like "TEMPLATE-ATOM", optionally
    preceded by "Drude-" for Drude particles.  The template part (everything
    before the last hyphen) is added to the templates set, and the atom part
    is returned, prefixed with "D" for Drude particles.

    Returns None, leaving templates untouched, when the name carries no
    template information (no hyphen left once any Drude marker is stripped).
    """
    drude_marker = "Drude-"
    is_drude = raw_atom_name.startswith(drude_marker)
    atom_name_prefix = "D" if is_drude else ""
    if is_drude:
        raw_atom_name = raw_atom_name[len(drude_marker):]

    # Splitting once from the right yields either one part (no hyphen) or
    # exactly two parts (template name, atom name).
    atom_name_parts = raw_atom_name.rsplit("-", maxsplit=1)
    if len(atom_name_parts) != 2:
        return None

    templates.add(atom_name_parts[0])
    return f"{{atom_name_prefix}}{{atom_name_parts[1]}}"

def is_valid_improper(atom_indices):
    """Check that four atoms form an improper centered on the first or last.

    The improper is accepted when either the first atom is bonded to the
    other three, or the fourth atom is bonded to the other three, according
    to data.bondedToAtom.
    """
    first, second, third, fourth = atom_indices
    neighbors_of_first = data.bondedToAtom[first]
    neighbors_of_fourth = data.bondedToAtom[fourth]

    if all(other in neighbors_of_first for other in (second, third, fourth)):
        return True
    return all(other in neighbors_of_fourth for other in (first, second, third))

def find_improper(improper_types, atom_classes):
    """Look up improper parameters for atom classes, allowing wildcards.

    Keys of improper_types are tuples of atom classes in which None stands
    for a wildcard.  Candidate keys are tried with as few wildcards as
    possible first (an exact match, then every single-wildcard key, and so
    on), and the first key present in improper_types wins.

    Returns the matching entry, or None if no candidate key is present.
    """
    class_count = len(atom_classes)
    positions = range(class_count)
    for wildcard_count in range(class_count + 1):
        for masked_positions in itertools.combinations(positions, wildcard_count):
            candidate = tuple(
                None if position in masked_positions else atom_class
                for position, atom_class in enumerate(atom_classes)
            )
            if candidate in improper_types:
                return improper_types[candidate]
    return None

# One entry per residue, in topology order: a mapping from atom name to
# (atom index, atom class), the residue templates that have impropers, and
# the patch templates that have impropers.
residue_data = []

for residue in topology.residues():
    templates = set()

    if residue in templateForResidue:
        # The template name has the form "residue-patch_n-patch_m-...": the
        # first part is the residue name, and each later part is a patch name
        # followed by an underscore-separated suffix that is dropped.
        template_data = templateForResidue[residue]
        residue_name, *patch_names = template_data.name.split("-")
        templates.add(residue_name)
        for patch_name in patch_names:
            templates.add(patch_name.rsplit("_", maxsplit=1)[0])
        atom_names = [template_data.atoms[data.atomTemplateIndexes[atom]].name for atom in residue.atoms()]
    else:
        # Multi-residue patch: recover template and atom names from the atom
        # type names instead.
        atom_names = []
        for atom in residue.atoms():
            atom_name = extract_atom_name(data.atomType[atom], templates)
            if atom_name is None:
                # An atom without template name data (it might come from,
                # e.g., another solvent force field file) disqualifies the
                # whole residue.
                atom_names = None
                break
            atom_names.append(atom_name)
        if atom_names is None:
            # Keep a placeholder so that list positions still line up with
            # residue order.
            residue_data.append(([], [], []))
            continue

    # Keep only the templates that actually define impropers, in the
    # processing order given by the order tables.
    residue_templates = [name for name in templates if name in residue_impropers]
    residue_templates.sort(key=residue_order.get)
    patch_templates = [name for name in templates if name in patch_impropers]
    patch_templates.sort(key=patch_order.get)

    # Index atoms by name for the improper lookup.  A repeated name would
    # overwrite an earlier entry, which the length check below detects.
    atom_data = dict()
    for atom, atom_name in zip(residue.atoms(), atom_names):
        atom_data[atom_name] = (atom.index, self._atomTypes[data.atomType[atom]].atomClass)
    if len(atom_data) != len(atom_names):
        raise ValueError(f"CHARMM: atom name collision in residue with index {{residue.index}}")

    residue_data.append((atom_data, residue_templates, patch_templates))

# Torsion argument tuples (four atom indices followed by the parameters found
# in the type tables) collected for each kind of improper.
periodic_impropers = []
harmonic_impropers = []

for residue_index, (atom_data, residue_templates, patch_templates) in enumerate(residue_data):
    # Build the list of impropers to attempt for this residue.  A list (not a
    # set) keeps the order stable; the per-residue count should be small
    # enough that the linear membership tests do not matter.
    impropers = []
    for template_name in residue_templates:
        for improper in residue_impropers.get(template_name, ()):
            if improper not in impropers:
                impropers.append(improper)
    for patch_name in patch_templates:
        # Deletions requested by a patch apply before its own additions.
        for improper in delete_impropers.get(patch_name, ()):
            if improper in impropers:
                impropers.remove(improper)
        for improper in patch_impropers.get(patch_name, ()):
            if improper not in impropers:
                impropers.append(improper)

    for improper in impropers:
        # Resolve each (offset, atom name) pair to an atom index and class.
        atom_indices = []
        atom_classes = []
        for offset, atom_name in improper:
            neighbor_index = residue_index + offset
            if neighbor_index < 0 or neighbor_index >= len(residue_data):
                break
            neighbor_atoms = residue_data[neighbor_index][0]
            if atom_name not in neighbor_atoms:
                break
            atom_index, atom_class = neighbor_atoms[atom_name]
            atom_indices.append(atom_index)
            atom_classes.append(atom_class)
        else:
            # Every atom was found.  Still reject combinations that are not
            # bonded like an improper (this could happen if an offset reaches
            # an adjacent residue in, e.g., an entirely different chain).
            if not is_valid_improper(atom_indices):
                continue

            periodic_improper = find_improper(periodic_improper_types, atom_classes)
            harmonic_improper = find_improper(harmonic_improper_types, atom_classes)
            has_periodic = periodic_improper is not None
            has_harmonic = harmonic_improper is not None

            # Exactly one of the two parameter tables must provide a match.
            if has_periodic and has_harmonic:
                raise ValueError(f"CHARMM: both periodic and harmonic impropers found for {{atom_classes}}")
            if not has_periodic and not has_harmonic:
                raise ValueError(f"CHARMM: neither periodic nor harmonic improper found for {{atom_classes}}")

            if has_periodic:
                periodic_impropers.append((*atom_indices, *periodic_improper))
            else:
                harmonic_impropers.append((*atom_indices, *harmonic_improper))

if periodic_impropers:
    # Reuse the first PeriodicTorsionForce already in the system, creating
    # and registering a new one only when none exists.
    periodic_force = next(
        (
            existing_force
            for existing_force in sys.getForces()
            if isinstance(existing_force, openmm.PeriodicTorsionForce)
        ),
        None,
    )
    if periodic_force is None:
        periodic_force = openmm.PeriodicTorsionForce()
        sys.addForce(periodic_force)

    for torsion_arguments in periodic_impropers:
        periodic_force.addTorsion(*torsion_arguments)

if harmonic_impropers:
    # Harmonic impropers always go into a new custom force.  The energy
    # expression uses the smaller of the two angular distances between theta
    # and theta0 around the circle.
    harmonic_force = openmm.CustomTorsionForce(
        f"k*min(dtheta, 2*pi - dtheta)^2; dtheta = abs(theta - theta0); pi = {{math.pi}}"
    )
    for parameter_name in ("k", "theta0"):
        harmonic_force.addPerTorsionParameter(parameter_name)

    for torsion_arguments in harmonic_impropers:
        harmonic_force.addTorsion(*torsion_arguments)

    sys.addForce(harmonic_force)
