Skip to content

Commit

Permalink
TMP3 now driver
Browse files Browse the repository at this point in the history
  • Loading branch information
alongd committed Dec 21, 2024
1 parent 2d162ce commit 960c507
Show file tree
Hide file tree
Showing 7 changed files with 363 additions and 399 deletions.
2 changes: 1 addition & 1 deletion arc/family/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import arc.family.family

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'arc.family.family' is imported with both 'import' and 'import from'.
from arc.family.family import ReactionFamily
from arc.family.family import ReactionFamily, determine_possible_reaction_products_from_family
from arc.family.family import get_reaction_family_products
2 changes: 2 additions & 0 deletions arc/family/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ def apply_recipe(self,
raise ValueError(f'Unknown action "{action[0]}" encountered.')
if 'validAromatic' in structure.props and not structure.props['validAromatic']:
structure.kekulize()
for atom in structure.atoms:
atom.update_charge()
structures = structure.split()
if self.product_num != len(structures):
return None
Expand Down
53 changes: 18 additions & 35 deletions arc/mapping/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from typing import TYPE_CHECKING, Dict, List, Optional

from arc.mapping.engine import (are_adj_elements_in_agreement,
from arc.family import determine_possible_reaction_products_from_family
from arc.mapping.engine import (RESERVED_FINGERPRINT_KEYS,
are_adj_elements_in_agreement,
create_qc_mol,
flip_map,
fingerprint,
Expand All @@ -23,7 +25,7 @@
find_all_breaking_bonds,
cut_species_based_on_atom_indices,
update_xyz,
RESERVED_FINGERPRINT_KEYS,)
)
from arc.common import logger
from arc.species.converter import check_molecule_list_order

Expand Down Expand Up @@ -53,7 +55,7 @@ def map_reaction(rxn: 'ARCReaction',
if flip:
logger.warning(f"The requested ARC reaction {rxn} could not be atom mapped using {backend}. Trying again with the flipped reaction.")
try:
_map = flip_map(map_rxn(rxn.flip_reaction(), backend=backend, db=db))
_map = flip_map(map_rxn(rxn.flip_reaction(), backend=backend))
except ValueError:
return None
return _map
Expand All @@ -62,12 +64,12 @@ def map_reaction(rxn: 'ARCReaction',
logger.warning(f'Could not determine the reaction family for {rxn.label}. '
f'Mapping as a general or isomerization reaction.')
_map = map_general_rxn(rxn, backend=backend)
return _map if _map is not None else map_reaction(rxn, backend=backend, db=db, flip=True)
return _map if _map is not None else map_reaction(rxn, backend=backend, flip=True)
try:
_map = map_rxn(rxn, backend=backend, db=db)
except ValueError as e:
return map_reaction(rxn, backend=backend, db=db, flip=True)
return _map if _map is not None else map_reaction(rxn, backend=backend, db=db, flip=True)
_map = map_rxn(rxn, backend=backend)
except ValueError:
return map_reaction(rxn, backend=backend, flip=True)
return _map if _map is not None else map_reaction(rxn, backend=backend, flip=True)


def map_general_rxn(rxn: 'ARCReaction',
Expand Down Expand Up @@ -114,7 +116,6 @@ def map_isomerization_reaction(rxn: 'ARCReaction',
Args:
rxn (ARCReaction): An ARCReaction object instance.
backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend.
Returns:
Optional[List[int]]:
Expand Down Expand Up @@ -206,12 +207,10 @@ def map_rxn(rxn: 'ARCReaction',
"""
A wrapper function for mapping reaction, uses databases for mapping with the correct reaction family parameters.
Strategy:
1) get_rmg_reactions_from_arc_reaction, get_atom_indices_of_labeled_atoms_in_an_rmg_reaction.
2) (For bimolecular reactions) Find the species in which the bond is broken.
3) Scissor the reactant(s) and product(s).
4) Match pair species.
5) Map_two_species.
6) Join maps together.
1) Scissor the reactant(s) and product(s).
2) Match pair species.
3) Map_two_species.
4) Join maps together.
Args:
rxn (ARCReaction): An ARCReaction object instance that belongs to the RMG H_Abstraction reaction family.
Expand All @@ -222,43 +221,30 @@ def map_rxn(rxn: 'ARCReaction',
Entry indices are running atom indices of the reactants,
corresponding entry values are running atom indices of the products.
"""
# step 1:
rmg_reactions = 5

if not rmg_reactions:
return None

r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=rxn,
rmg_reaction=rmg_reactions[0])

# step 2:
assign_labels_to_products(rxn, p_label_dict)

# step 3:
reactants, products = copy_species_list_for_mapping(rxn.r_species), copy_species_list_for_mapping(rxn.p_species)
reactants, products = rxn.get_reactants_and_products(arc=True, return_copies=False)
reactants, products = copy_species_list_for_mapping(reactants), copy_species_list_for_mapping(products)
label_species_atoms(reactants), label_species_atoms(products)

r_bdes, p_bdes = find_all_breaking_bonds(rxn, True), find_all_breaking_bonds(rxn, False)

r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes)
p_cuts = cut_species_based_on_atom_indices(products, p_bdes)

product_dicts = determine_possible_reaction_products_from_family(rxn, family_label=rxn.family)
try:
r_label_dict = product_dicts[0]['r_label_map']
make_bond_changes(rxn, r_cuts, r_label_dict)
except (ValueError, IndexError, ActionError, AtomTypeError) as e:
logger.warning(e)

r_cuts, p_cuts = update_xyz(r_cuts), update_xyz(p_cuts)

# step 4:
pairs_of_reactant_and_products = pairing_reactants_and_products_for_mapping(r_cuts, p_cuts)
if len(p_cuts):
logger.error(f"Could not find isomorphism for scissored species: {[cut.mol.smiles for cut in p_cuts]}")
return None
# step 5:
maps = map_pairs(pairs_of_reactant_and_products)

# step 6:
return glue_maps(maps, pairs_of_reactant_and_products)


Expand All @@ -277,7 +263,6 @@ def convert_label_dict(label_dict: Dict[str, int],
Returns:
Dict[str, int]: The converted label dictionary.
"""
print(f'original label_dict: {label_dict}')
if len(reference_mol_list) != len(mol_list):
raise ValueError(f'The number of reference molecules ({len(reference_mol_list)}) '
f'does not match the number of molecules ({len(mol_list)}).')
Expand All @@ -289,7 +274,6 @@ def convert_label_dict(label_dict: Dict[str, int],
return {label: atom_map[index] for label, index in label_dict.items()}
elif len(reference_mol_list) == 2:
ordered = check_molecule_list_order(mols_1=reference_mol_list, mols_2=mol_list)
print(f'ordered: {ordered}')
atom_map_1 = map_two_species(reference_mol_list[0], mol_list[0]) if ordered else map_two_species(reference_mol_list[1], mol_list[0])
atom_map_2 = map_two_species(reference_mol_list[1], mol_list[1]) if ordered else map_two_species(reference_mol_list[0], mol_list[1])
if atom_map_1 is None or atom_map_2 is None:
Expand All @@ -298,5 +282,4 @@ def convert_label_dict(label_dict: Dict[str, int],
return None
atom_map = atom_map_1 + [index + len(atom_map_1) for index in atom_map_2] if ordered else \
atom_map_2 + [index + len(atom_map_2) for index in atom_map_1]
print(f'atom_map: {atom_map}')
return {label: atom_map[index] for label, index in label_dict.items()}
Loading

0 comments on commit 960c507

Please sign in to comment.