import openmm

# NOTE(review): the single-brace names below look like placeholders that are
# substituted with Python literals when this template is rendered (the rest
# of the file escapes its own braces by doubling them) -- confirm against the
# code that renders this template.

# Maps a residue template name to its anisotropy entries. Each entry maps an
# atom name to a tuple of (atom 2 name, atom 3 name, atom 4 name, aniso12,
# aniso34).
residue_anisotropies = {residue_anisotropies}
# Same layout as residue_anisotropies, but keyed by patch name.
patch_anisotropies = {patch_anisotropies}
# Maps a patch name to the atom names whose anisotropy entries are removed
# before that patch's own entries are applied.
delete_anisotropies = {delete_anisotropies}

# Sort keys (looked up by template name) that fix the order in which residue
# and patch anisotropies are applied.
residue_order = {residue_order}
patch_order = {patch_order}

def extract_atom_name(raw_atom_name, templates):
    """
    Split a raw atom type name into a template name and an atom name.

    The raw name may start with "Drude-"; that marker is stripped and the
    returned atom name is prefixed with "D" instead.  The remainder is split
    at its last "-": the left part is recorded in ``templates`` and the right
    part is the atom name.

    Parameters
    ----------
    raw_atom_name : str
        The raw name to parse.
    templates : set
        Set that receives the template part of the name (modified in place).

    Returns
    -------
    str or None
        The atom name (with a "D" prefix for Drude particles), or None if
        the name has no "-" separator and so carries no template name.  In
        that case ``templates`` is left untouched.
    """
    drude_marker = "Drude-"
    is_drude = raw_atom_name.startswith(drude_marker)
    if is_drude:
        raw_atom_name = raw_atom_name[len(drude_marker):]

    # Without a separator there is no template information to extract.
    if "-" not in raw_atom_name:
        return None

    template_name, _, bare_atom_name = raw_atom_name.rpartition("-")
    templates.add(template_name)
    return ("D" if is_drude else "") + bare_atom_name

# One entry per residue, in topology iteration order:
# (atom lookup, residue template names, patch template names).
residue_data = []

for residue in topology.residues():
    # Names of every residue or patch template that contributed to this residue.
    templates = set()

    # NOTE(review): templateForResidue is used here as a mapping keyed by
    # residue objects -- confirm this matches the caller that provides it.
    if residue in templateForResidue:
        # The template name is split on "-": the first part is the residue
        # name, and each later part is a patch name followed by a final
        # "_suffix" that is stripped off.
        matched_template = templateForResidue[residue]
        residue_name, *patch_parts = matched_template.name.split("-")
        templates.add(residue_name)
        templates.update(patch_part.rsplit("_", maxsplit=1)[0] for patch_part in patch_parts)
        atom_names = [
            matched_template.atoms[data.atomTemplateIndexes[atom]].name for atom in residue.atoms()
        ]
    else:
        # Multi-residue patch; recover template and atom names from the atom
        # type names instead.
        atom_names = []
        for atom in residue.atoms():
            parsed_name = extract_atom_name(data.atomType[atom], templates)
            if parsed_name is None:
                # An atom without template name data (it might come from,
                # e.g., another solvent force field file): give up on the
                # whole residue.
                atom_names = None
                break
            atom_names.append(parsed_name)
        if atom_names is None:
            # Keep a placeholder so that positions in residue_data still
            # line up with the residues of the topology.
            residue_data.append(([], [], []))
            continue

    # Keep only the templates that define anisotropies, in application order.
    residue_templates = [name for name in templates if name in residue_anisotropies]
    residue_templates.sort(key=residue_order.get)
    patch_templates = [name for name in templates if name in patch_anisotropies]
    patch_templates.sort(key=patch_order.get)

    # Map each atom name to its (index, class) pair.  A repeated name
    # overwrites the earlier entry, which the length check below detects.
    atom_data = {{}}
    for atom, name in zip(residue.atoms(), atom_names):
        atom_data[name] = (atom.index, self._atomTypes[data.atomType[atom]].atomClass)
    if len(atom_names) != len(atom_data):
        raise ValueError(f"CHARMM: atom name collision in residue with index {{residue.index}}")

    residue_data.append((atom_data, residue_templates, patch_templates))

# Maps the index of atom 1 of each anisotropy to a tuple of
# (atom 2 index, atom 3 index, atom 4 index, aniso12, aniso34).
anisotropy_table = {{}}

for atom_data, residue_templates, patch_templates in residue_data:
    # Collect the candidate anisotropies for this residue: residue entries
    # first, then each patch removes its deletions and adds its own entries.
    candidates = {{}}
    for residue_name in residue_templates:
        candidates.update(residue_anisotropies.get(residue_name, {{}}))
    for patch_name in patch_templates:
        for deleted_atom in delete_anisotropies.get(patch_name, []):
            candidates.pop(deleted_atom, None)
        candidates.update(patch_anisotropies.get(patch_name, {{}}))

    # Resolve the atom names to indices; an anisotropy is dropped when any
    # of its four atoms is missing from this residue.
    for name_1, (name_2, name_3, name_4, aniso12, aniso34) in candidates.items():
        found = [atom_data.get(name) for name in (name_1, name_2, name_3, name_4)]
        if any(entry is None for entry in found):
            continue
        index_1, index_2, index_3, index_4 = (entry[0] for entry in found)
        anisotropy_table[index_1] = (index_2, index_3, index_4, aniso12, aniso34)

# Write the collected anisotropies into every DrudeForce of the system.
for force in sys.getForces():
    if isinstance(force, openmm.DrudeForce):
        for particle_index in range(force.getNumParticles()):
            (
                particle,
                particle1,
                particle2,
                particle3,
                particle4,
                charge,
                polarizability,
                aniso12,
                aniso34,
            ) = force.getParticleParameters(particle_index)

            # The table is keyed by the index stored as particle1; entries
            # without a match are left exactly as they were.
            override = anisotropy_table.get(particle1)
            if override is None:
                continue

            # Only the three reference particles and the two anisotropy
            # values are replaced; everything else is written back unchanged.
            particle2, particle3, particle4, aniso12, aniso34 = override
            force.setParticleParameters(
                particle_index,
                particle,
                particle1,
                particle2,
                particle3,
                particle4,
                charge,
                polarizability,
                aniso12,
                aniso34,
            )
