Skip to content

Commit

Permalink
TMP
Browse files (browse the repository at this point in the history)
  • Loading branch information
alongd committed Dec 23, 2024
1 parent fb544a8 commit da39278
Show file tree
Hide file tree
Showing 10 changed files with 537 additions and 373 deletions.
1 change: 0 additions & 1 deletion arc/checks/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from arc.imports import settings
from arc.species.converter import check_xyz_dict, displace_xyz, xyz_to_dmat
from arc.mapping.engine import get_atom_indices_of_labeled_atoms_in_an_rmg_reaction
from arc.statmech.factory import statmech_factory

if TYPE_CHECKING:
Expand Down
11 changes: 11 additions & 0 deletions arc/family/family_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,17 @@ def test_get_isomorphic_subgraph(self):
)
self.assertEqual(isomorphic_subgraph, {0: '*3', 4: '*1', 7: '*2'})

# def test_order_species_list(self):
# """Test the order_species_list() function"""
# spc1 = ARCSpecies(label='spc1', smiles='C')
# spc2 = ARCSpecies(label='spc2', smiles='CC')
# ordered_species_list = order_species_list(species_list=[spc2, spc1], reference_species=[spc1, spc2])
# self.assertEqual(ordered_species_list, [spc1, spc2])
# ordered_species_list = order_species_list(species_list=[spc2, spc1], reference_species=[spc2, spc1])
# self.assertEqual(ordered_species_list, [spc2, spc1])
# ordered_species_list = order_species_list(species_list=[spc2.mol, spc1], reference_species=[spc2, spc1.mol])
# self.assertEqual(ordered_species_list, [spc2.mol, spc1])


if __name__ == '__main__':
unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
119 changes: 70 additions & 49 deletions arc/mapping/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,34 @@

from typing import TYPE_CHECKING, List, Optional

from arc.mapping.engine import (assign_labels_to_products,
from arc.family import determine_possible_reaction_products_from_family
from arc.mapping.engine import (RESERVED_FINGERPRINT_KEYS,
are_adj_elements_in_agreement,
create_qc_mol,
flip_map,
fingerprint,
get_atom_indices_of_labeled_atoms_in_an_rmg_reaction,
get_rmg_reactions_from_arc_reaction,
glue_maps,
label_species_atoms,
make_bond_changes,
map_pairs,
iterative_dfs, map_two_species,
pairing_reactants_and_products_for_mapping,
copy_species_list_for_mapping,
find_all_bdes,
find_all_breaking_bonds,
cut_species_based_on_atom_indices,
update_xyz,
RESERVED_FINGERPRINT_KEYS,)
)
from arc.common import logger
from arc.species.converter import check_molecule_list_order

from rmgpy.exceptions import ActionError, AtomTypeError

if TYPE_CHECKING:
from rmgpy.data.rmg import RMGDatabase
from rmgpy.molecule.molecule import Molecule
from arc.reaction import ARCReaction


def map_reaction(rxn: 'ARCReaction',
backend: str = 'ARC',
db: Optional['RMGDatabase'] = None,
flip = False
) -> Optional[List[int]]:
"""
Expand All @@ -47,7 +45,6 @@ def map_reaction(rxn: 'ARCReaction',
Args:
rxn (ARCReaction): An ARCReaction object instance.
backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend.
db (RMGDatabase, optional): The RMG database instance.
Returns:
Optional[List[int]]:
Expand All @@ -57,7 +54,7 @@ def map_reaction(rxn: 'ARCReaction',
if flip:
logger.warning(f"The requested ARC reaction {rxn} could not be atom mapped using {backend}. Trying again with the flipped reaction.")
try:
_map = flip_map(map_rxn(rxn.flip_reaction(), backend=backend, db=db))
_map = flip_map(map_rxn(rxn.flip_reaction(), backend=backend))
except ValueError:
return None
return _map
Expand All @@ -66,17 +63,16 @@ def map_reaction(rxn: 'ARCReaction',
logger.warning(f'Could not determine the reaction family for {rxn.label}. '
f'Mapping as a general or isomerization reaction.')
_map = map_general_rxn(rxn, backend=backend)
return _map if _map is not None else map_reaction(rxn, backend=backend, db=db, flip=True)
return _map if _map is not None else map_reaction(rxn, backend=backend, flip=True)
try:
_map = map_rxn(rxn, backend=backend, db=db)
except ValueError as e:
return map_reaction(rxn, backend=backend, db=db, flip=True)
return _map if _map is not None else map_reaction(rxn, backend=backend, db=db, flip=True)
_map = map_rxn(rxn, backend=backend)
except ValueError:
return map_reaction(rxn, backend=backend, flip=True)
return _map if _map is not None else map_reaction(rxn, backend=backend, flip=True)


def map_general_rxn(rxn: 'ARCReaction',
backend: str = 'ARC',
db: Optional['RMGDatabase'] = None,
) -> Optional[List[int]]:
"""
Map a general reaction (one that was not categorized into a reaction family by RMG).
Expand All @@ -85,7 +81,6 @@ def map_general_rxn(rxn: 'ARCReaction',
Args:
rxn (ARCReaction): An ARCReaction object instance.
backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend.
db (RMGDatabase, optional): The RMG database instance.
Returns:
Optional[List[int]]:
Expand Down Expand Up @@ -120,8 +115,6 @@ def map_isomerization_reaction(rxn: 'ARCReaction',
Args:
rxn (ARCReaction): An ARCReaction object instance.
backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend.
db (RMGDatabase, optional): The RMG database instance.
Returns:
Optional[List[int]]:
Expand Down Expand Up @@ -209,63 +202,91 @@ def map_isomerization_reaction(rxn: 'ARCReaction',

def map_rxn(rxn: 'ARCReaction',
backend: str = 'ARC',
db: Optional['RMGDatabase'] = None,
) -> Optional[List[int]]:
"""
A wrapper function for mapping reaction, uses databases for mapping with the correct reaction family parameters.
Strategy:
1) get_rmg_reactions_from_arc_reaction, get_atom_indices_of_labeled_atoms_in_an_rmg_reaction.
2) (For bimolecular reactions) Find the species in which the bond is broken.
3) Scissor the reactant(s) and product(s).
4) Match pair species.
5) Map_two_species.
6) Join maps together.
1) Scissor the reactant(s) and product(s).
2) Match pair species.
3) Map_two_species.
4) Join maps together.
Args:
rxn (ARCReaction): An ARCReaction object instance that belongs to the RMG H_Abstraction reaction family.
backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend.
db (RMGDatabase, optional): The RMG database instance.
Returns:
Optional[List[int]]:
Entry indices are running atom indices of the reactants,
corresponding entry values are running atom indices of the products.
"""
# step 1:
rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=rxn, backend=backend)

if not rmg_reactions:
return None

r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=rxn,
rmg_reaction=rmg_reactions[0])

# step 2:
assign_labels_to_products(rxn, p_label_dict)

#step 3:
reactants, products = copy_species_list_for_mapping(rxn.r_species), copy_species_list_for_mapping(rxn.p_species)
reactants, products = rxn.get_reactants_and_products(arc=True, return_copies=False)
reactants, products = copy_species_list_for_mapping(reactants), copy_species_list_for_mapping(products)
label_species_atoms(reactants), label_species_atoms(products)

r_bdes, p_bdes = find_all_bdes(rxn, r_label_dict, True), find_all_bdes(rxn, p_label_dict, False)
r_bdes, p_bdes = find_all_breaking_bonds(rxn, True), find_all_breaking_bonds(rxn, False)

r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes)
p_cuts = cut_species_based_on_atom_indices(products, p_bdes)

try:
make_bond_changes(rxn, r_cuts, r_label_dict)
except (ValueError, IndexError, ActionError, AtomTypeError) as e:
logger.warning(e)
print(f'\n\n 3.1 ***********\nr_cuts: {[cut.mol.copy(deep=True).to_smiles() for cut in r_cuts]}]\n')

product_dicts = determine_possible_reaction_products_from_family(rxn, family_label=rxn.family)
# try:
# r_label_dict = product_dicts[0]['r_label_map']
# make_bond_changes(rxn, r_cuts, r_label_dict)
# except (ValueError, IndexError, ActionError, AtomTypeError) as e:
# logger.warning(e)

print(f'\n\n 5.1 ***********\nr_cuts: {[cut.mol.copy(deep=True).to_smiles() for cut in r_cuts]}]\n')
# print(f'\n\n 5.2 ***********\np_cuts: {[cut.mol.copy(deep=True).to_smiles() for cut in p_cuts]}]\n')

r_cuts, p_cuts = update_xyz(r_cuts), update_xyz(p_cuts)

#step 4:
print(f'\n\n 9.1 ***********\nr_cuts: {[cut.mol.copy(deep=True).to_smiles() for cut in r_cuts]}]\n')
# print(f'\n\n 9.2 ***********\np_cuts: {[cut.mol.copy(deep=True).to_smiles() for cut in p_cuts]}]\n')

pairs_of_reactant_and_products = pairing_reactants_and_products_for_mapping(r_cuts, p_cuts)
if len(p_cuts):
logger.error(f"Could not find isomorphism for scissored species: {[cut.mol.smiles for cut in p_cuts]}")
return None
# step 5:
maps = map_pairs(pairs_of_reactant_and_products)

#step 6:
return glue_maps(maps, pairs_of_reactant_and_products)


# def convert_label_dict(label_dict: Dict[str, int],
# reference_mol_list: List['Molecule'],
# mol_list: List['Molecule'],
# ) -> Optional[Dict[str, int]]:
# """
# Convert the label dictionary to the correct atom indices in the reaction and reference molecules.
#
# Args:
# label_dict (Dict[str, int]): A dictionary of atom labels (e.g., '*1') to atom indices.
# reference_mol_list (List[Molecule]): The list of molecules to which label_dict values refer.
# mol_list (List[Molecule]): The list of molecules to which label_dict values should be converted.
#
# Returns:
# Dict[str, int]: The converted label dictionary.
# """
# if len(reference_mol_list) != len(mol_list):
# raise ValueError(f'The number of reference molecules ({len(reference_mol_list)}) '
# f'does not match the number of molecules ({len(mol_list)}).')
# if len(reference_mol_list) == 1:
# atom_map = map_two_species(reference_mol_list[0], mol_list[0])
# if atom_map is None:
# print(f'Could not map {reference_mol_list[0].to_smiles()} to {mol_list[0].to_smiles()}')
# return None
# return {label: atom_map[index] for label, index in label_dict.items()}
# elif len(reference_mol_list) == 2:
# ordered = check_molecule_list_order(mols_1=reference_mol_list, mols_2=mol_list)
# atom_map_1 = map_two_species(reference_mol_list[0], mol_list[0]) if ordered else map_two_species(reference_mol_list[1], mol_list[0])
# atom_map_2 = map_two_species(reference_mol_list[1], mol_list[1]) if ordered else map_two_species(reference_mol_list[0], mol_list[1])
# if atom_map_1 is None or atom_map_2 is None:
# print(f'Could not map {reference_mol_list[0].to_smiles()} to {mol_list[0].to_smiles()} '
# f'or {reference_mol_list[1].to_smiles()} to {mol_list[1].to_smiles()}')
# return None
# atom_map = atom_map_1 + [index + len(atom_map_1) for index in atom_map_2] if ordered else \
# atom_map_2 + [index + len(atom_map_2) for index in atom_map_1]
# return {label: atom_map[index] for label, index in label_dict.items()}
Loading

0 comments on commit da39278

Please sign in to comment.