Fixed: allow multiple scissions in the atom mapping process. #701

Merged on Oct 3, 2023 (24 commits)
Commits
7380b0b
Changed map_rxn to allow the changes in this PR.
kfir4444 Sep 26, 2023
477a51a
Make sure to load all reaction families in reaction_test
kfir4444 Sep 26, 2023
ada371a
Test: added XY_elimination_hydroxyl to test_get_atom_map
kfir4444 Sep 26, 2023
6fc8510
removed prepare_reactants_and_products_for_scissors
kfir4444 Sep 26, 2023
5e87baf
removed multiple_cut_on_species
kfir4444 Sep 26, 2023
3cfaedd
removed cut_species_for_mapping
kfir4444 Sep 26, 2023
89193a3
removed find_main_cut_product
kfir4444 Sep 26, 2023
cb52c6a
minor: removed unused variable.
kfir4444 Sep 26, 2023
02a80e5
Added determine_bdes_on_spc_based_on_atom_labels function
kfir4444 Sep 26, 2023
639642b
Test: determine_bdes_on_spc_based_on_atom_labels
kfir4444 Sep 26, 2023
2d0a263
Added cut_species_based_on_atom_indices function
kfir4444 Sep 26, 2023
eff3a0a
Test: cut_species_based_on_atom_indices
kfir4444 Sep 26, 2023
0d33ff3
Added copy_species_list_for_mapping function
kfir4444 Sep 26, 2023
be36edf
Test: test_copy_species_list_for_mapping
kfir4444 Sep 26, 2023
5b26a44
Added find_all_bdes function
kfir4444 Sep 26, 2023
b4308e7
Test: find_all_bdes
kfir4444 Sep 26, 2023
98515e8
Changed the tests in engine_test
kfir4444 Sep 26, 2023
d5bd59d
removed test_cuts_on_cycle_of_labeled_mol
kfir4444 Sep 26, 2023
6aea8b2
removed test_prepare_reactants_and_products_for_scissors
kfir4444 Sep 26, 2023
ae68d2d
fixup: engine test
kfir4444 Sep 27, 2023
52a5090
removed the cuts_on_cycle_of_labeled_mol function
kfir4444 Sep 27, 2023
6cab95c
docs: cut_species_based_on_atom_indices (minor change)
kfir4444 Sep 27, 2023
9897b48
Changed the rmgdb loading to the correct function in reaction test.
kfir4444 Sep 27, 2023
5bfeaac
Made sure that rmg_reactions has entries before continuing with mapping
kfir4444 Oct 2, 2023
22 changes: 15 additions & 7 deletions arc/mapping/driver.py
@@ -12,7 +12,6 @@
import arc.rmgdb as rmgdb
from arc.mapping.engine import (assign_labels_to_products,
are_adj_elements_in_agreement,
cut_species_for_mapping,
create_qc_mol,
flip_map,
fingerprint,
@@ -24,7 +23,9 @@
map_pairs,
iterative_dfs, map_two_species,
pairing_reactants_and_products_for_mapping,
prepare_reactants_and_products_for_scissors,
copy_species_list_for_mapping,
find_all_bdes,
cut_species_based_on_atom_indices,
update_xyz,
RESERVED_FINGERPRINT_KEYS,)
from arc.common import logger
@@ -234,17 +235,24 @@
"""
# step 1:
rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=rxn, backend=backend)

if not rmg_reactions:
return None

[Codecov annotation: added line #L240 in arc/mapping/driver.py was not covered by tests.]

r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=rxn,
rmg_reaction=rmg_reactions[0])

# step 2:
assign_labels_to_products(rxn, p_label_dict)
reactants, products, loc_r, loc_p = prepare_reactants_and_products_for_scissors(rxn, r_label_dict, p_label_dict)

#step 3:
label_species_atoms(reactants)
label_species_atoms(products)
r_cuts = cut_species_for_mapping(reactants, loc_r)
p_cuts = cut_species_for_mapping(products, loc_p)
reactants, products = copy_species_list_for_mapping(rxn.r_species), copy_species_list_for_mapping(rxn.p_species)
label_species_atoms(reactants), label_species_atoms(products)

r_bdes, p_bdes = find_all_bdes(rxn, r_label_dict, True), find_all_bdes(rxn, p_label_dict, False)

r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes)
p_cuts = cut_species_based_on_atom_indices(products, p_bdes)

try:
make_bond_changes(rxn, r_cuts, r_label_dict)
288 changes: 94 additions & 194 deletions arc/mapping/engine.py
@@ -1016,89 +1016,6 @@
return flipped_map


def prepare_reactants_and_products_for_scissors(rxn: 'ARCReaction',
r_label_dict: dict,
p_label_dict: dict,
) -> Tuple[List[ARCSpecies], List[ARCSpecies], List[int], List[int]]:
"""
Prepares the species to be scissored.

Args:
rxn: ARC reaction object to be mapped
r_label_dict: the labels of the reactants
p_label_dict: the labels of the products

Returns:
The species in the reactants and products where a bond was broken or formed.
"""

breaks = list()
forms = list()
actions = rxn.family.forward_recipe.actions
for action in actions:
if action[0].lower() == "BREAK_BOND".lower():
breaks.append(action)
elif action[0].lower() == "FORM_BOND".lower():
forms.append(action)
reactants, products, loc_r, loc_p = [0]*len(rxn.r_species), [0]*len(rxn.p_species), [0]*len(rxn.r_species), [0]*len(rxn.p_species)
for broken_bond in breaks:
location = 0
index = 0
for reactant in rxn.r_species:
if not r_label_dict[broken_bond[1]] < reactant.number_of_atoms + index:
location += 1
index += reactant.number_of_atoms
else:
if loc_r[location] >= 1:
loc_r[location] += 1
reactants[location].bdes += [(r_label_dict[broken_bond[1]] + 1 - index, r_label_dict[broken_bond[3]] + 1 - index)]
else:
loc_r[location] += 1
reactants[location] = ARCSpecies(label="".join(sorted(
[key_by_val(r_label_dict, r_label_dict[broken_bond[1]]),
key_by_val(p_label_dict, p_label_dict[broken_bond[3]])])),
mol = reactant.mol.copy(deep=True),
bdes = [(r_label_dict[broken_bond[1]] + 1 - index,
r_label_dict[broken_bond[3]] + 1 - index)])
reactants[location].final_xyz = reactant.get_xyz()
for mol1, mol2 in zip(reactants[location].mol.atoms, reactant.mol.atoms):
mol1.label = mol2.label
break

for formed_bond in forms:
location = 0
index = 0
for product in rxn.p_species:
if not p_label_dict[formed_bond[1]] < product.number_of_atoms + index:
location += 1
index += product.number_of_atoms
else:
if loc_p[location] >= 1:
loc_p[location]+=1
products[location].bdes += [(p_label_dict[formed_bond[1]] + 1 - index, p_label_dict[formed_bond[3]] + 1 - index)]
else:
loc_p[location] += 1
products[location] = ARCSpecies(label="".join(sorted(
[key_by_val(p_label_dict, p_label_dict[formed_bond[1]]),
key_by_val(p_label_dict, p_label_dict[formed_bond[3]])])),
mol = product.mol.copy(deep=True),
xyz = product.get_xyz(),
bdes = [(p_label_dict[formed_bond[1]] + 1 - index,
p_label_dict[formed_bond[3]] + 1 - index)])
for mol1, mol2 in zip(products[location].mol.atoms, product.mol.atoms):
mol1.label = mol2.label
break
for index, value in enumerate(loc_r):
if value == 0:
reactants[index] = rxn.r_species[index]

for index, value in enumerate(loc_p):
if value == 0:
products[index] = rxn.p_species[index]

return reactants, products, loc_r, loc_p


def make_bond_changes(rxn: 'ARCReaction',
r_cuts: List[ARCSpecies],
r_label_dict: Dict):
@@ -1165,95 +1082,6 @@
atom_index+=1


def multiple_cut_on_species(spc, bdes):
"""
A function that recursively calls, and is called by, cut_species_for_mapping.
It cuts a species, reassigns the remaining BDEs to the resulting cuts, and calls cut_species_for_mapping again.
Args:
spc (ARCSpecies): a species with more than one BDE (marks for scission).
bdes (list(tuple(int))): the required BDE's.
Returns:
list(ARCSpecies): a list of the cut products.
"""
bdes = spc.bdes
spc.bdes = [bdes[0]]
bdes = bdes[1:]
try:
cuts = spc.scissors()
except SpeciesError:
return None
for species in cuts:
species.final_xyz = species.get_xyz(generate=False)
indinces = [int(atom.label) for atom in species.mol.copy(deep=True).atoms]
new_bdes = list()
for bde in bdes:
if bde[0]-1 in indinces and bde[1]-1 in indinces:
new_bdes.append((indinces.index(bde[0]-1) + 1, indinces.index(bde[1]-1) + 1))
species.bdes = new_bdes
return cut_species_for_mapping(cuts, [len(species.bdes or list()) for species in cuts])


def cut_species_for_mapping(species, locs):
"""
A function for performing the necessary scission of species for mapping purposes. Can perform appropriate scission of multiple bonds at once.
Args:
species (list(ARCSpecies)): the species (reactants or products), marked for scission.
locs (list(int)): the number of cuts that is required for each species.
Returns:
list(ARCSpecies): a list of the cut products.
"""
cuts = list()
for spc, loc in zip(species, locs):
spc.final_xyz = spc.get_xyz()
if spc.mol.copy(deep=True).smiles == "[H][H]" and loc != 0: # scissors should return one species
labels = [atom.label for atom in spc.mol.copy(deep=True).atoms]
try:
H1 = spc.scissors()[0]
except SpeciesError:
return None
H2 = H1.copy()
H2.mol.atoms[0].label = labels[0] if H1.mol.atoms[0].label != labels[0] else labels[1]
cuts += [H1, H2]
elif loc == 0:
cuts += [spc]
elif loc == 1:
try:
cuts += spc.scissors()
except SpeciesError:
return None
else:
bdes = spc.bdes
cuts += multiple_cut_on_species(spc, bdes)
return cuts


def find_main_cut_product(cuts: List["ARCSpecies"],
reactant: "ARCSpecies",
bde: Tuple[int]
)->Tuple["ARCSpecies", "ARCSpecies"]:
"""
Differentiate the main product from scissors product.
Strategy: use the other BDE, if the indices of the atoms are in the bdes, it should be able to cut there.

Args:
cuts: The cut products
reactant: The reactant with multiple bdes
bde: the BDE used for scissors.

Returns:
Tuple["ARCSpecies", "ARCSpecies"] in the correct order.
"""
list_atom_labels_cuts_0 = [int(atom.label)+1 for atom in cuts[0].mol.atoms]
bdes = reactant.bdes
for bd in bdes:
if bd == bde:
continue
elif bd[0] not in list_atom_labels_cuts_0:
return cuts[1], cuts[0]

return cuts[0], cuts[1]


def update_xyz(spcs: List[ARCSpecies]) -> List[ARCSpecies]:
"""
A helper function, updates the xyz values of each species after cutting. This is important, since the
@@ -1363,31 +1191,103 @@
p_atoms = pair[1].mol.atoms
for map_index, r_atom in zip(_map, r_atoms):
am_dict[int(r_atom.label)] = int(p_atoms[map_index].label)
return [val for key, val in sorted(am_dict.items(), key=lambda item: item[0])]


def cuts_on_cycle_of_labeled_mol(spc: 'ARCSpecies')-> bool:
"""A helper function determining whether or not the scission site opens a cycle.
return [val for _, val in sorted(am_dict.items(), key=lambda item: item[0])]

Args:
spc1: ARCSpecies with a bdes

Returns:
True if the scission site is on a ring, None if the species is unlabeled, False otherwise"""
if not any([atom.label for atom in spc.mol.atoms]):
raise ValueError("cuts_on_cycle_of_labeled_mol recives labeled ARCSpecies only, got an unlabeld species")
def determine_bdes_on_spc_based_on_atom_labels(spc: "ARCSpecies", bde: Tuple[int, int]) -> bool:
"""
A function for determining whether or not the species in question contains the bond specified by the bond dissociation indices.
Also assigns the correct BDE to the species.

Args:
spc (ARCSpecies): The species in question, with labeled atom indices.
bde (Tuple[int, int]): The bde in question.

if not spc.mol.is_cyclic():
Returns:
bool: Whether or not the bde is based on the atom labels.
"""
bde = convert_list_index_0_to_1(bde, direction=-1)
index1, index2 = bde[0], bde[1]
new_bde = list()
atoms = list()
for index, atom in enumerate(spc.mol.atoms):
if atom.label == str(index1) or atom.label == str(index2):
new_bde.append(index+1)
atoms.append(atom)
if len(new_bde) == 2:
break

if len(new_bde) == 2 and atoms[1] in atoms[0].bonds.keys():
spc.bdes = [tuple(new_bde)]
return True
else:
return False
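
The label lookup above can be tried on plain data. The following is a minimal sketch, not the ARC implementation: `positions_from_labels` is a hypothetical name, atoms are reduced to a list of label strings, and bonds are stored as pairs of 0-based positions. It also assumes that `convert_list_index_0_to_1(bde, direction=-1)` turns the 1-indexed pair into 0-indexed labels, which is inferred from the call in the diff rather than confirmed.

```python
from typing import FrozenSet, List, Optional, Set, Tuple


def positions_from_labels(atom_labels: List[str],
                          bonds: Set[FrozenSet[int]],
                          bde: Tuple[int, int],
                          ) -> Optional[Tuple[int, int]]:
    """
    Hypothetical stand-in: translate a 1-indexed pair of original atom labels
    into 1-indexed positions within this fragment, or return None if the
    fragment does not contain that bond.
    """
    # Assumed conversion: 1-indexed bde -> 0-indexed label strings.
    wanted = {str(bde[0] - 1), str(bde[1] - 1)}
    found = [pos for pos, label in enumerate(atom_labels) if label in wanted]
    # Both atoms must be present and bonded to each other.
    if len(found) == 2 and frozenset(found) in bonds:
        return found[0] + 1, found[1] + 1
    return None


# A fragment left over from an earlier cut: its atoms kept their original
# labels "3", "4", "5" but now sit at positions 0, 1, 2.
labels = ["3", "4", "5"]
bonds = {frozenset((0, 1)), frozenset((1, 2))}

bonded = positions_from_labels(labels, bonds, (5, 6))     # labels "4" and "5"
not_bonded = positions_from_labels(labels, bonds, (4, 6))  # labels "3" and "5"
absent = positions_from_labels(labels, bonds, (1, 2))      # labels "0" and "1"
print(bonded, not_bonded, absent)
```

Because the lookup goes through labels rather than positions, the same bond pair stays valid after earlier scissions have reshuffled the atoms into new fragments.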

k = spc.mol.get_deterministic_sssr()
labels = [[] for i in range(len(k))]
for index, cycle in enumerate(k):
for atom in cycle:
labels[index].append(int(atom.label))

def cut_species_based_on_atom_indices(species: List["ARCSpecies"], bdes: List[Tuple[int, int]]) -> Optional[List["ARCSpecies"]]:
"""
A function for scissoring species based on their atom indices.
Args:
species (List[ARCSpecies]): The species list that requires scission.
bdes (List[Tuple[int, int]]): A list of the atoms between which the bond should be scissored. The atoms are described using the atom labels, and not the actual atom positions.
Returns:
Optional[List["ARCSpecies"]]: The input species list after scission (cut species are replaced by their fragments); None if a scission fails.
"""
if not bdes:
return species

for bde in spc.bdes:
for cycle in labels:
if bde[0]-1 in cycle and bde[1]-1 in cycle:
return True
return False
for bde in bdes:
for index, spc in enumerate(species):
if determine_bdes_on_spc_based_on_atom_labels(spc, bde):
candidate = species.pop(index)
candidate.final_xyz = candidate.get_xyz()
if candidate.mol.copy(deep=True).smiles == "[H][H]":
labels = [atom.label for atom in candidate.mol.copy(deep=True).atoms]
try:
h1 = candidate.scissors()[0]
except SpeciesError:
return None

[Codecov annotation: added lines #L1249 - L1250 in arc/mapping/engine.py were not covered by tests.]
h2 = h1.copy()
h2.mol.atoms[0].label = labels[0] if h1.mol.atoms[0].label != labels[0] else labels[1]
species += [h1, h2]
else:
try:
species += candidate.scissors()
[Review comment, alongd (Member), Sep 25, 2023: The docstring says we return the species after scission, but it looks like we add it to the species list, and the original species is also returned (also above when treating H2). Is this on purpose? If so, the docstring needs to be updated.]

[Reply, Collaborator Author: The species are returned after all required scissions are done. If none were requested, then no scission is done. It does return them after scission, so I do not understand the confusion.]


except SpeciesError:
return None

[Codecov annotation: added lines #L1257 - L1258 in arc/mapping/engine.py were not covered by tests.]
break

return species
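
The pop-and-extend loop above is what allows several scissions on one species: each cut removes the species that owns the bond and appends its fragments, so a later cut can land on a fragment produced by an earlier one. Below is a minimal sketch of that idea on a toy labeled graph. It does not use `ARCSpecies.scissors()`; `split_fragment` and `cut_by_labels` are hypothetical names, and labels are plain integers here rather than ARC's string labels.

```python
from typing import FrozenSet, List, Set, Tuple

Bond = FrozenSet[int]
Fragment = Tuple[Set[int], Set[Bond]]  # (atom labels, bonds between labels)


def split_fragment(atoms: Set[int], bonds: Set[Bond]) -> List[Fragment]:
    """Return the connected components of a labeled graph."""
    remaining, parts = set(atoms), []
    while remaining:
        stack, seen = [remaining.pop()], set()
        while stack:
            atom = stack.pop()
            seen.add(atom)
            for bond in bonds:
                if atom in bond:
                    stack.extend(bond - seen)  # unvisited neighbors only
        remaining -= seen
        parts.append((seen, {b for b in bonds if b <= seen}))
    return parts


def cut_by_labels(fragments: List[Fragment],
                  cuts: List[Tuple[int, int]],
                  ) -> List[Fragment]:
    """Apply each cut to whichever fragment currently holds that bond."""
    fragments = list(fragments)
    for a, b in cuts:
        bond = frozenset((a, b))
        for index, (atoms, bonds) in enumerate(fragments):
            if bond in bonds:
                fragments.pop(index)  # the cut fragment is replaced ...
                fragments += split_fragment(atoms, bonds - {bond})  # ... by its pieces
                break
    return fragments


# A four-atom chain 0-1-2-3 cut twice: the second cut hits a piece of the first.
chain = [({0, 1, 2, 3}, {frozenset((0, 1)), frozenset((1, 2)), frozenset((2, 3))})]
pieces = cut_by_labels(chain, [(0, 1), (2, 3)])
print(sorted(sorted(atoms) for atoms, _ in pieces))

# Cutting one bond of a three-membered ring opens it without splitting it.
ring = [({0, 1, 2}, {frozenset((0, 1)), frozenset((1, 2)), frozenset((0, 2))})]
opened = cut_by_labels(ring, [(0, 1)])
print(len(opened))
```

Mutating the list inside the inner loop is safe in this sketch only because the loop breaks immediately after the pop, which mirrors the `break` in the diff.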


def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpecies"]:
"""
A helper function for copying the species list for mapping. Also keeps the atom indices when copying.
Args:
species (List[ARCSpecies]): The species list to be copied.
Returns:
List[ARCSpecies]: The copied species list.
"""
copies = [spc.copy() for spc in species]
for copy, spc in zip(copies, species):
for atom1, atom2 in zip(copy.mol.atoms, spc.mol.atoms):
atom1.label = atom2.label
return copies
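
The explicit label loop above suggests that copying a species does not carry atom labels over by itself; that is an inference from the diff, not something the PR states. The sketch below illustrates the pattern with toy classes. `ToyAtom`, `ToySpecies`, and `copy_with_labels` are hypothetical stand-ins, and the label-dropping `copy()` is an assumption made for illustration.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyAtom:
    element: str
    label: str = ""


@dataclass
class ToySpecies:
    atoms: List[ToyAtom] = field(default_factory=list)

    def copy(self) -> "ToySpecies":
        # Assumption for this sketch: a copy rebuilds the atoms and loses labels.
        return ToySpecies(atoms=[ToyAtom(a.element) for a in self.atoms])


def copy_with_labels(species: List[ToySpecies]) -> List[ToySpecies]:
    """Copy each species, then restore the atom labels position by position."""
    copies = [spc.copy() for spc in species]
    for new, old in zip(copies, species):
        for new_atom, old_atom in zip(new.atoms, old.atoms):
            new_atom.label = old_atom.label
    return copies


original = [ToySpecies([ToyAtom("C", "0"), ToyAtom("H", "1")])]
copies = copy_with_labels(original)
copies[0].atoms[0].label = "changed"  # edits to the copy stay in the copy
print(original[0].atoms[0].label, copies[0].atoms[1].label)
```

Working on labeled copies keeps the species stored on the reaction untouched while the copies are cut for mapping.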


def find_all_bdes(rxn: "ARCReaction", label_dict: dict, is_reactants: bool) -> List[Tuple[int, int]]:
"""
A function for finding all the broken (or formed) bonds during a chemical reaction, based on the atom indices.
Args:
rxn (ARCReaction): The reaction in question.
label_dict (dict): A dictionary mapping the atom labels to the atom indices.
is_reactants (bool): Whether the label dictionary describes the reactants (broken bonds) or the products (formed bonds).
Returns:
List[Tuple[int, int]]: A list of tuples of the form (atom_index1, atom_index2) for each broken (or formed) bond. Note that these represent the atom indices to be cut, and not final BDEs.
"""
bdes = list()
for action in rxn.family.forward_recipe.actions:
if action[0].lower() == ("break_bond" if is_reactants else "form_bond"):
bdes.append((label_dict[action[1]] + 1, label_dict[action[3]] + 1))
return bdes
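
The recipe scan in find_all_bdes can be exercised without ARC or RMG by feeding it plain lists. The sketch below is a stand-in under stated assumptions: `bonds_from_recipe` is a hypothetical name, the action layout `[name, label_1, order, label_2]` and the 0-indexed `label_dict` values are read off the indexing in the diff (`action[0]`, `action[1]`, `action[3]`, and the `+ 1`), and the example recipe and label dictionaries are made up.

```python
from typing import Dict, List, Sequence, Tuple


def bonds_from_recipe(actions: Sequence[Sequence],
                      label_dict: Dict[str, int],
                      is_reactants: bool,
                      ) -> List[Tuple[int, int]]:
    """Hypothetical stand-in for find_all_bdes operating on plain data."""
    # Reactants are cut where bonds break; products where bonds form.
    wanted = "break_bond" if is_reactants else "form_bond"
    bonds = []
    for action in actions:
        if action[0].lower() == wanted:
            # Shift the 0-indexed atom indices to 1-indexed pairs.
            bonds.append((label_dict[action[1]] + 1, label_dict[action[3]] + 1))
    return bonds


# Made-up recipe: break *1-*2, form *2-*3; other actions are ignored.
actions = [["BREAK_BOND", "*1", 1, "*2"],
           ["FORM_BOND", "*2", 1, "*3"],
           ["GAIN_RADICAL", "*1", "1"]]
r_label_dict = {"*1": 0, "*2": 1, "*3": 5}
p_label_dict = {"*1": 0, "*2": 4, "*3": 5}

r_bonds = bonds_from_recipe(actions, r_label_dict, is_reactants=True)
p_bonds = bonds_from_recipe(actions, p_label_dict, is_reactants=False)
print(r_bonds)  # [(1, 2)]
print(p_bonds)  # [(5, 6)]
```

Cutting the products at the formed bonds is what makes the two sides comparable: after scission, reactants and products are reduced to fragments that can be paired and mapped one to one.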