From bad7d57ab31d3890b665982f03e69c66fbbe1720 Mon Sep 17 00:00:00 2001 From: Alon Grinberg Dana Date: Thu, 11 Jul 2024 14:00:06 +0300 Subject: [PATCH] TMP --- arc/checks/ts.py | 7 +- arc/job/adapters/ts/heuristics.py | 2 +- arc/mapping/driver.py | 49 +--- arc/mapping/driver_test.py | 70 +++-- arc/mapping/engine.py | 413 ++++++++++++++++++------------ arc/mapping/engine_test.py | 139 ++++++---- arc/reaction/family.py | 18 +- arc/reaction/family_test.py | 81 ++++++ arc/reaction/reaction.py | 14 +- arc/rmgdb.py | 2 +- 10 files changed, 483 insertions(+), 312 deletions(-) diff --git a/arc/checks/ts.py b/arc/checks/ts.py index 98304c4608..4e80a5d1d1 100644 --- a/arc/checks/ts.py +++ b/arc/checks/ts.py @@ -19,9 +19,7 @@ ) from arc.imports import settings from arc.species.converter import check_xyz_dict, displace_xyz, xyz_to_dmat -from arc.mapping.engine import (get_atom_indices_of_labeled_atoms_in_an_rmg_reaction, - get_rmg_reactions_from_arc_reaction, - ) +from arc.mapping.engine import get_atom_indices_of_labeled_atoms_in_a_reaction from arc.statmech.factory import statmech_factory if TYPE_CHECKING: @@ -330,8 +328,7 @@ def check_normal_mode_displacement(reaction: 'ARCReaction', bond_lone_hydrogens=bond_lone_hs) got_expected_changing_bonds = False for i, rmg_reaction in enumerate(rmg_reactions): - r_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=reaction, - rmg_reaction=rmg_reaction)[0] + r_label_dict = get_atom_indices_of_labeled_atoms_in_a_reaction(arc_reaction=reaction)[0] if r_label_dict is None: continue expected_breaking_bonds, expected_forming_bonds = reaction.get_expected_changing_bonds(r_label_dict=r_label_dict) diff --git a/arc/job/adapters/ts/heuristics.py b/arc/job/adapters/ts/heuristics.py index 76da11d2c9..aa1171a155 100644 --- a/arc/job/adapters/ts/heuristics.py +++ b/arc/job/adapters/ts/heuristics.py @@ -33,7 +33,7 @@ from arc.job.factory import register_job_adapter from arc.plotter import save_geo from 
arc.species.converter import compare_zmats, relocate_zmat_dummy_atoms_to_the_end, zmat_from_xyz, zmat_to_xyz -from arc.mapping.engine import map_arc_rmg_species, map_two_species +from arc.mapping.engine import map_two_species from arc.species.species import ARCSpecies, TSGuess, colliding_atoms from arc.species.zmat import get_parameter_from_atom_indices, remove_1st_atom, up_param diff --git a/arc/mapping/driver.py b/arc/mapping/driver.py index bbdd176183..2451bb636d 100644 --- a/arc/mapping/driver.py +++ b/arc/mapping/driver.py @@ -14,8 +14,7 @@ create_qc_mol, flip_map, fingerprint, - get_atom_indices_of_labeled_atoms_in_an_rmg_reaction, - get_rmg_reactions_from_arc_reaction, + get_atom_indices_of_labeled_atoms_in_a_reaction, glue_maps, label_species_atoms, make_bond_changes, @@ -26,20 +25,19 @@ find_all_bdes, cut_species_based_on_atom_indices, update_xyz, - RESERVED_FINGERPRINT_KEYS,) + RESERVED_FINGERPRINT_KEYS, + ) from arc.common import logger from rmgpy.exceptions import ActionError, AtomTypeError if TYPE_CHECKING: - from rmgpy.data.rmg import RMGDatabase from arc.reaction import ARCReaction def map_reaction(rxn: 'ARCReaction', backend: str = 'ARC', - db: Optional['RMGDatabase'] = None, - flip = False + flip: bool = False ) -> Optional[List[int]]: """ Map a reaction. @@ -47,7 +45,7 @@ def map_reaction(rxn: 'ARCReaction', Args: rxn (ARCReaction): An ARCReaction object instance. backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend. - db (RMGDatabase, optional): The RMG database instance. + flip (bool, optional): Whether to attempt flipping the reaction. Returns: Optional[List[int]]: @@ -57,7 +55,7 @@ def map_reaction(rxn: 'ARCReaction', if flip: logger.warning(f"The requested ARC reaction {rxn} could not be atom mapped using {backend}. 
Trying again with the flipped reaction.") try: - _map = flip_map(map_rxn(rxn.flip_reaction(), backend=backend, db=db)) + _map = flip_map(map_rxn(rxn.flip_reaction(), backend=backend)) except ValueError: return None return _map @@ -65,18 +63,16 @@ def map_reaction(rxn: 'ARCReaction', if rxn.family is None: logger.warning(f'Could not determine the reaction family for {rxn.label}. ' f'Mapping as a general or isomerization reaction.') - _map = map_general_rxn(rxn, backend=backend) - return _map if _map is not None else map_reaction(rxn, backend=backend, db=db, flip=True) + _map = map_general_rxn(rxn) + return _map if _map is not None else map_reaction(rxn, backend=backend, flip=True) try: - _map = map_rxn(rxn, backend=backend, db=db) - except ValueError as e: - return map_reaction(rxn, backend=backend, db=db, flip=True) - return _map if _map is not None else map_reaction(rxn, backend=backend, db=db, flip=True) + _map = map_rxn(rxn, backend=backend) + except ValueError: + return map_reaction(rxn, backend=backend, flip=True) + return _map if _map is not None else map_reaction(rxn, backend=backend, flip=True) def map_general_rxn(rxn: 'ARCReaction', - backend: str = 'ARC', - db: Optional['RMGDatabase'] = None, ) -> Optional[List[int]]: """ Map a general reaction (one that was not categorized into a reaction family by RMG). @@ -84,8 +80,6 @@ def map_general_rxn(rxn: 'ARCReaction', Args: rxn (ARCReaction): An ARCReaction object instance. - backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend. - db (RMGDatabase, optional): The RMG database instance. Returns: Optional[List[int]]: @@ -121,7 +115,6 @@ def map_isomerization_reaction(rxn: 'ARCReaction', Args: rxn (ARCReaction): An ARCReaction object instance. backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend. - db (RMGDatabase, optional): The RMG database instance. 
Returns: Optional[List[int]]: @@ -209,7 +202,6 @@ def map_isomerization_reaction(rxn: 'ARCReaction', def map_rxn(rxn: 'ARCReaction', backend: str = 'ARC', - db: Optional['RMGDatabase'] = None, ) -> Optional[List[int]]: """ A wrapper function for mapping reaction, uses databases for mapping with the correct reaction family parameters. @@ -224,26 +216,16 @@ def map_rxn(rxn: 'ARCReaction', Args: rxn (ARCReaction): An ARCReaction object instance that belongs to the RMG H_Abstraction reaction family. backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend. - db (RMGDatabase, optional): The RMG database instance. Returns: Optional[List[int]]: Entry indices are running atom indices of the reactants, corresponding entry values are running atom indices of the products. """ - # step 1: - rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=rxn, backend=backend) - - if not rmg_reactions: - return None - - r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=rxn, - rmg_reaction=rmg_reactions[0]) + r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_a_reaction(arc_reaction=rxn) - # step 2: assign_labels_to_products(rxn, p_label_dict) - - #step 3: + reactants, products = copy_species_list_for_mapping(rxn.r_species), copy_species_list_for_mapping(rxn.p_species) label_species_atoms(reactants), label_species_atoms(products) @@ -259,13 +241,10 @@ def map_rxn(rxn: 'ARCReaction', r_cuts, p_cuts = update_xyz(r_cuts), update_xyz(p_cuts) - #step 4: pairs_of_reactant_and_products = pairing_reactants_and_products_for_mapping(r_cuts, p_cuts) if len(p_cuts): logger.error(f"Could not find isomorphism for scissored species: {[cut.mol.smiles for cut in p_cuts]}") return None - # step 5: maps = map_pairs(pairs_of_reactant_and_products) - #step 6: return glue_maps(maps, pairs_of_reactant_and_products) diff --git a/arc/mapping/driver_test.py b/arc/mapping/driver_test.py index 
dba1f085c2..adf25c4e57 100644 --- a/arc/mapping/driver_test.py +++ b/arc/mapping/driver_test.py @@ -99,8 +99,8 @@ def setUpClass(cls): (-0.8942590, -0.8537420, 0.0000000)), 'isotopes': (16, 16, 1), 'symbols': ('O', 'O', 'H')} cls.nh2_xyz = """N 0.00022972 0.40059496 0.00000000 - H -0.83174214 -0.19982058 0.00000000 - H 0.83151242 -0.20077438 0.00000000""" + H -0.83174214 -0.19982058 0.00000000 + H 0.83151242 -0.20077438 0.00000000""" cls.n2h4_xyz = """N -0.67026921 -0.02117571 -0.25636419 N 0.64966276 0.05515705 0.30069593 H -1.27787600 0.74907557 0.03694453 @@ -481,7 +481,6 @@ def test_map_abstractions(self): self.assertIn(atom_map[5], [4, 5]) self.assertTrue(check_atom_map(rxn)) - # H + CH3NH2 <=> H2 + CH2NH2 ch3nh2_xyz = {'coords': ((-0.5734111454228507, 0.0203516083213337, 0.03088703933770556), (0.8105595891860601, 0.00017446498908627427, -0.4077728757313545), @@ -560,7 +559,6 @@ def test_map_abstractions(self): self.assertTrue(any(atom_map[r_index] in [6, 7, 8] for r_index in [5, 6, 7, 8])) self.assertTrue(check_atom_map(rxn)) - # CH3OO + CH3CH2OH <=> CH3OOH + CH3CH2O / peroxyl to alkoxyl, modified atom and product order r_1 = ARCSpecies( label="CH3OO", @@ -669,7 +667,6 @@ def test_map_abstractions(self): self.assertEqual(atom_map[23],23) self.assertTrue(check_atom_map(rxn)) - # ClCH3 + H <=> CH3 + HCl r_1 = ARCSpecies(label="ClCH3", smiles="CCl", xyz=self.ch3cl_xyz) r_2 = ARCSpecies(label="H", smiles="[H]", xyz=self.h_rad_xyz) @@ -714,23 +711,23 @@ def test_map_abstractions(self): (0.3717352549047681, -1.308596593192221, 0.7750989547682503), (-2.0374518517222544, -0.751480024679671, 0.37217669645466245))} - p_1_xyz = {'symbols': ('O', 'Cl', 'H'), 'isotopes': (16, 35, 1), 'coords': ( - (-0.3223044372303026, 0.4343354356368888, 0.0), (1.2650242694442462, -0.12042710381137228, 0.0), - (-0.9427198322139436, -0.3139083318255167, 0.0))} + p_1_xyz = {'symbols': ('O', 'Cl', 'H'), 'isotopes': (16, 35, 1), + 'coords': ((-0.3223044372303026, 0.4343354356368888, 0.0), + 
(1.2650242694442462, -0.12042710381137228, 0.0), + (-0.9427198322139436, -0.3139083318255167, 0.0))} p_2_xyz = {'symbols': ('C', 'C', 'C', 'Cl', 'Cl', 'H', 'H', 'H', 'H', 'H'), - 'isotopes': (12, 12, 12, 35, 35, 1, 1, 1, 1, 1), 'coords': ( - (-1.3496376883278178, -0.020445981649800302, -0.1995184115269273), - (-0.051149096449292386, -0.3885500107837139, 0.4222976979623008), - (1.217696701041357, 0.15947991928242372, -0.1242718714010236), - (1.7092794464102241, 1.570982412202936, 0.8295196720275746), - (2.474584210365428, -1.0919019396606517, -0.06869614478411318), - (-1.6045061896547035, 1.0179450876989615, 0.03024632893682861), - (-1.3137314500783486, -0.14754777860704252, -1.2853589013330937), - (-2.1459595425475264, -0.6625965540242661, 0.188478021031359), - (-0.044412318929613885, -0.9093853981117669, 1.373599947353138), - (1.1078359281702537, 0.47202024365290884, -1.1662963382659064))} - + 'isotopes': (12, 12, 12, 35, 35, 1, 1, 1, 1, 1), + 'coords': ((-1.3496376883278178, -0.020445981649800302, -0.1995184115269273), + (-0.051149096449292386, -0.3885500107837139, 0.4222976979623008), + (1.217696701041357, 0.15947991928242372, -0.1242718714010236), + (1.7092794464102241, 1.570982412202936, 0.8295196720275746), + (2.474584210365428, -1.0919019396606517, -0.06869614478411318), + (-1.6045061896547035, 1.0179450876989615, 0.03024632893682861), + (-1.3137314500783486, -0.14754777860704252, -1.2853589013330937), + (-2.1459595425475264, -0.6625965540242661, 0.188478021031359), + (-0.044412318929613885, -0.9093853981117669, 1.373599947353138), + (1.1078359281702537, 0.47202024365290884, -1.1662963382659064))} r_1 = ARCSpecies(label='r1', smiles=smiles[0],xyz=r_1_xyz ) r_2 = ARCSpecies(label='r2', smiles=smiles[1],xyz=r_2_xyz) p_1 = ARCSpecies(label='p1', smiles=smiles[2],xyz=p_1_xyz) @@ -750,7 +747,6 @@ def test_map_abstractions(self): self.assertTrue(check_atom_map(rxn)) # Br abstraction - # OH + CH3Br <=> HOBr + CH3 r_1_xyz = {'symbols': ('O', 'H'), 'isotopes': (16, 
1), 'coords': ((0.48890386738601, 0.0, 0.0), (-0.48890386738601, 0.0, 0.0))} @@ -788,22 +784,22 @@ def test_map_abstractions(self): # [H] + CC(=O)Br <=> [H][Br] + C[C](=O) r_1_xyz = {'symbols': ('H',), 'isotopes': (1,), 'coords': ((0.0, 0.0, 0.0),)} - r_2_xyz = {'symbols': ('C', 'C', 'O', 'Br', 'H', 'H', 'H'), 'isotopes': (12, 12, 16, 79, 1, 1, 1), 'coords': ( - (-0.7087772076387326, -0.08697184565826255, 0.08295914062572969), - (0.7238141593293749, 0.2762480677183181, -0.14965326856248656), - (1.1113560248255752, 1.3624373452907719, -0.554840372311578), - (2.0636725443687616, -1.041297021241265, 0.20693447296577364), - (-0.9844931733249197, -0.9305935329026733, -0.5546432084044857), - (-0.8586221633621384, -0.3455305862905263, 1.134123935245044), - (-1.3469501841979155, 0.7657075730836449, -0.16488069955797996))} + r_2_xyz = {'symbols': ('C', 'C', 'O', 'Br', 'H', 'H', 'H'), 'isotopes': (12, 12, 16, 79, 1, 1, 1), + 'coords': ((-0.7087772076387326, -0.08697184565826255, 0.08295914062572969), + (0.7238141593293749, 0.2762480677183181, -0.14965326856248656), + (1.1113560248255752, 1.3624373452907719, -0.554840372311578), + (2.0636725443687616, -1.041297021241265, 0.20693447296577364), + (-0.9844931733249197, -0.9305935329026733, -0.5546432084044857), + (-0.8586221633621384, -0.3455305862905263, 1.134123935245044), + (-1.3469501841979155, 0.7657075730836449, -0.16488069955797996))} - p_1_xyz = {'symbols': ('C', 'C', 'O', 'H', 'H', 'H'), 'isotopes': (12, 12, 16, 1, 1, 1), 'coords': ( - (-0.4758624005470258, 0.015865899777425058, -0.11215987340300927), - (0.9456990856850401, -0.031530842469194666, 0.2228995599390481), - (2.0897646616994816, -0.06967555524967288, 0.492553667108967), - (-1.08983188764878, -0.06771143046366379, 0.7892594299969324), - (-0.7261604551815313, 0.9578749227991876, -0.6086176800339509), - (-0.7436090040071672, -0.8048229943940851, -0.7839351036079769))} + p_1_xyz = {'symbols': ('C', 'C', 'O', 'H', 'H', 'H'), 'isotopes': (12, 12, 16, 1, 1, 1), + 
'coords': ((-0.4758624005470258, 0.015865899777425058, -0.11215987340300927), + (0.9456990856850401, -0.031530842469194666, 0.2228995599390481), + (2.0897646616994816, -0.06967555524967288, 0.492553667108967), + (-1.08983188764878, -0.06771143046366379, 0.7892594299969324), + (-0.7261604551815313, 0.9578749227991876, -0.6086176800339509), + (-0.7436090040071672, -0.8048229943940851, -0.7839351036079769))} p_2_xyz = {'symbols': ('Br', 'H'), 'isotopes': (79, 1), 'coords': ((0.7644788559644482, 0.0, 0.0), (-0.7644788559644482, 0.0, 0.0))} @@ -819,7 +815,7 @@ def test_map_abstractions(self): self.assertIn(tuple(atom_map[5:]), permutations([5, 6, 7])) self.assertTrue(check_atom_map(rxn)) - #Change Order [H] + CC(=O)Br <=> C[C](=O) + [H][Br] + # Change Order [H] + CC(=O)Br <=> C[C](=O) + [H][Br] r_1 = ARCSpecies(label='r1', smiles='[H]', xyz=r_1_xyz) r_2 = ARCSpecies(label='r2', smiles='CC(=O)Br', xyz=r_2_xyz) p_1 = ARCSpecies(label='p1', smiles='C[C](=O)', xyz=p_1_xyz) diff --git a/arc/mapping/engine.py b/arc/mapping/engine.py index b16760b317..a9933ee268 100644 --- a/arc/mapping/engine.py +++ b/arc/mapping/engine.py @@ -19,6 +19,7 @@ from arc.common import convert_list_index_0_to_1, extremum_list, generate_resonance_structures, logger, key_by_val from arc.exceptions import SpeciesError +from arc.reaction.family import ReactionFamily, get_reaction_family_products from arc.species import ARCSpecies from arc.species.conformers import determine_chirality from arc.species.converter import compare_confs, sort_xyz_using_indices, translate_xyz, xyz_from_data, xyz_to_str @@ -26,195 +27,268 @@ from numpy import unique if TYPE_CHECKING: - from rmgpy.data.kinetics.family import TemplateReaction - from rmgpy.data.rmg import RMGDatabase from rmgpy.molecule.molecule import Atom - from rmgpy.reaction import Reaction from arc.reaction import ARCReaction RESERVED_FINGERPRINT_KEYS = ['self', 'chirality', 'label'] -def get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction: 
'ARCReaction', - rmg_reaction: 'TemplateReaction', - ) -> Tuple[Optional[Dict[str, int]], Optional[Dict[str, int]]]: +def get_atom_indices_of_labeled_atoms_in_a_reaction(arc_reaction: 'ARCReaction', + ) -> Tuple[Dict[str, int], Dict[str, int]]: """ - Get the RMG reaction atom labels and the corresponding 0-indexed atom indices - for all labeled atoms in a TemplateReaction. + Get the atom indices for all labeled atoms in an ARCReaction object instance. Args: arc_reaction (ARCReaction): An ARCReaction object instance. - rmg_reaction (TemplateReaction): A respective RMG family TemplateReaction object instance. + + Todo: + - Currently only considering the first reaction for a single atom map, + should be extended to include all reactions and make atom_map a list of ARCReaction (call it .atom_maps ?). Returns: - Tuple[Optional[Dict[str, int]], Optional[Dict[str, int]]]: - The tuple entries relate to reactants and products. - Keys are labels (e.g., '*1'), values are corresponding 0-indices atoms. 
- """ - if not hasattr(rmg_reaction, 'labeled_atoms') or not rmg_reaction.labeled_atoms: - return None, None - - for spc in rmg_reaction.reactants + rmg_reaction.products: - generate_resonance_structures(object_=spc, save_order=True) - - r_map, p_map = map_arc_rmg_species(arc_reaction=arc_reaction, rmg_reaction=rmg_reaction, concatenate=False) - - reactant_index_dict, product_index_dict = dict(), dict() - reactant_atoms, product_atoms = list(), list() - rmg_reactant_order = [val for _, val in sorted(r_map.items(), key=lambda item: item[0])] - rmg_product_order = [val for _, val in sorted(p_map.items(), key=lambda item: item[0])] - for i in rmg_reactant_order: - reactant_atoms.extend([atom for atom in rmg_reaction.reactants[i].atoms]) - for i in rmg_product_order: - product_atoms.extend([atom for atom in rmg_reaction.products[i].atoms]) - - for labeled_atom_dict, atom_list, index_dict in zip([rmg_reaction.labeled_atoms['reactants'], - rmg_reaction.labeled_atoms['products']], - [reactant_atoms, product_atoms], - [reactant_index_dict, product_index_dict]): - for label, atom_1 in labeled_atom_dict.items(): - for i, atom_2 in enumerate(atom_list): - if atom_1.id == atom_2.id: - index_dict[label] = i - break + Tuple[Dict[str, int], Dict[str, int]]: Keys are labels (e.g., '*1'), + values are corresponding 0-indexed atom indices + in the reactants and in the products. 
+ """ + product_dicts = get_reaction_family_products(arc_reaction) + product_dict = product_dicts[0] + reactant_index_dict, product_index_dict = product_dict['label_map'], dict() + pairs = pair_reaction_products(arc_reaction, product_dict['products']) + maps = dict() + for arc_rxn_idx, prod_idx in pairs.items(): + maps[prod_idx] = map_two_species(arc_reaction.p_species[arc_rxn_idx], + ARCSpecies(label=f'P{prod_idx}', mol=product_dict['products'][prod_idx]), + map_type='dict', + consider_chirality=False, + ) + fam_prods_map_to_rxn_prods = dict() + for i in range(len(maps.keys())): + for key, val in maps[i].items(): + fam_prods_map_to_rxn_prods[key + sum([len(mol.atoms) for mol in product_dict['products'][:i]])] = ( + val + sum([len(spc.mol.atoms) for spc in arc_reaction.r_species[:i]])) + product_index_dict = {key: fam_prods_map_to_rxn_prods[val] for key, val in reactant_index_dict.items()} return reactant_index_dict, product_index_dict -def map_arc_rmg_species(arc_reaction: 'ARCReaction', - rmg_reaction: Union['Reaction', 'TemplateReaction'], - concatenate: bool = True, - ) -> Tuple[Dict[int, Union[List[int], int]], Dict[int, Union[List[int], int]]]: +def pair_reaction_products(reaction: 'ARCReaction', + products: List['Molecule'], + ) -> Dict[int, int]: """ - Map the species pairs in an ARC reaction to those in a respective RMG reaction - which is defined in the same direction. + Map the species pairs in an ARC reaction to those in the given product list. + This function assumes that resonance structures (mol_list) were generated for the ARCReaction object instance. Args: - arc_reaction (ARCReaction): An ARCReaction object instance. - rmg_reaction (Union[Reaction, TemplateReaction]): A respective RMG family TemplateReaction object instance. - concatenate (bool, optional): Whether to return isomorphic species as a single list (``True``, default), - or to return isomorphic species separately (``False``). + reaction (ARCReaction): An ARCReaction object instance. 
+ products (List[Molecule]): Molecule objects that correspond to the ARCReaction products and require pairing. Returns: - Tuple[Dict[int, Union[List[int], int]], Dict[int, Union[List[int], int]]]: - The first tuple entry refers to reactants, the second to products. - Keys are specie indices in the ARC reaction, - values are respective indices in the RMG reaction. - If ``concatenate`` is ``True``, values are lists of integers. Otherwise, values are integers. - """ - if rmg_reaction.is_isomerization(): - if concatenate: - return {0: [0]}, {0: [0]} - else: - return {0: 0}, {0: 0} - r_map, p_map = dict(), dict() - arc_reactants, arc_products = arc_reaction.get_reactants_and_products(arc=True) - for spc_map, rmg_species, arc_species in [(r_map, rmg_reaction.reactants, arc_reactants), - (p_map, rmg_reaction.products, arc_products)]: - for i, arc_spc in enumerate(arc_species): - for j, rmg_obj in enumerate(rmg_species): - rmg_spc = Species(molecule=[rmg_obj]) if isinstance(rmg_obj, Molecule) else rmg_obj - if not isinstance(rmg_spc, Species): - raise ValueError(f'Expected an RMG object instances of Molecule or Species, ' - f'got {rmg_obj} which is a {type(rmg_obj)}.') - generate_resonance_structures(object_=rmg_spc, save_order=True) - rmg_spc_based_on_arc_spc = Species(molecule=arc_spc.mol_list) - generate_resonance_structures(object_=rmg_spc_based_on_arc_spc, save_order=True) - if rmg_spc.is_isomorphic(rmg_spc_based_on_arc_spc, save_order=True): - if i in spc_map.keys() and concatenate: - spc_map[i].append(j) - elif concatenate: - spc_map[i] = [j] - elif i not in spc_map.keys() and j not in spc_map.values(): - spc_map[i] = j + Dict[int, int]: Keys are species indices in the ARC reaction, values are respective indices in the product list.
+ """ + if reaction.is_isomerization(): + return {0: 0} + product_pairs = dict() + arc_products = reaction.get_reactants_and_products(arc=True)[1] + for i, arc_product in enumerate(arc_products): + found = False + for j, prod in enumerate(products): + if j not in product_pairs.values(): + for arc_mol in arc_product.mol_list: + if arc_mol.is_isomorphic(prod): + product_pairs[i] = j + found = True break - if not r_map or not p_map: - raise ValueError(f'Could not match some of the RMG Reaction {rmg_reaction} to the ARC Reaction {arc_reaction}.') - return r_map, p_map - - -def find_equivalent_atoms_in_reactants(arc_reaction: 'ARCReaction', - backend: str = 'ARC', - ) -> List[List[int]]: - """ - Find atom indices that are equivalent in the reactants of an ARCReaction - in the sense that they represent degenerate reaction sites that are indifferentiable in 2D. - Bridges between RMG reaction templates and ARC's 3D TS structures. - Running indices in the returned structure relate to reactant_0 + reactant_1 + ... - - Args: - arc_reaction ('ARCReaction'): The ARCReaction object instance. - backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend. - - Returns: - List[List[int]]: Entries are lists of 0-indices, each such list represents equivalent atoms. 
- """ - rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction, backend=backend) - dicts = [get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(rmg_reaction=rmg_reaction, - arc_reaction=arc_reaction)[0] - for rmg_reaction in rmg_reactions] - equivalence_map = dict() - for index_dict in dicts: - for key, value in index_dict.items(): - if key in equivalence_map: - equivalence_map[key].append(value) - else: - equivalence_map[key] = [value] - equivalent_indices = list(list(set(equivalent_list)) for equivalent_list in equivalence_map.values()) - return equivalent_indices - - -def get_rmg_reactions_from_arc_reaction(arc_reaction: 'ARCReaction', - backend: str = 'ARC', - ) -> Optional[List['TemplateReaction']]: - """ - A helper function for getting RMG reactions from an ARC reaction. - This function calls ``map_two_species()`` so that each species in the RMG reaction is correctly mapped - to the corresponding species in the ARC reaction. It does not attempt to map reactants to products. - - Args: - arc_reaction (ARCReaction): The ARCReaction object instance. - backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend. - - Returns: - Optional[List[TemplateReaction]]: - The respective RMG TemplateReaction object instances (considering resonance structures). 
- """ - if arc_reaction.family is None: - return None - rmg_reactions = arc_reaction.family.generate_reactions(reactants=[spc.mol.copy(deep=True) for spc in arc_reaction.r_species], - products=[spc.mol.copy(deep=True) for spc in arc_reaction.p_species], - prod_resonance=True, - delete_labels=False, - relabel_atoms=False, - ) - for rmg_reaction in rmg_reactions: - r_map, p_map = map_arc_rmg_species(arc_reaction=arc_reaction, rmg_reaction=rmg_reaction, concatenate=False) - try: - ordered_rmg_reactants = [rmg_reaction.reactants[r_map[i]] for i in range(len(rmg_reaction.reactants))] - ordered_rmg_products = [rmg_reaction.products[p_map[i]] for i in range(len(rmg_reaction.products))] - except KeyError: - logger.warning(f'Got a problematic RMG rxn from ARC rxn, trying again') - continue - mapped_rmg_reactants, mapped_rmg_products = list(), list() - for ordered_rmg_mols, arc_species, mapped_mols in zip([ordered_rmg_reactants, ordered_rmg_products], - [arc_reaction.r_species, arc_reaction.p_species], - [mapped_rmg_reactants, mapped_rmg_products], - ): - for rmg_mol, arc_spc in zip(ordered_rmg_mols, arc_species): - mol = arc_spc.copy().mol - # The RMG molecule will get a random 3D conformer, don't consider chirality when mapping. 
- atom_map = map_two_species(mol, rmg_mol, map_type='dict', backend=backend, consider_chirality=False) - if atom_map is None: - continue - new_atoms_list = list() - for i in range(len(rmg_mol.atoms)): - rmg_mol.atoms[atom_map[i]].id = mol.atoms[i].id - new_atoms_list.append(rmg_mol.atoms[atom_map[i]]) - rmg_mol.atoms = new_atoms_list - mapped_mols.append(rmg_mol) - rmg_reaction.reactants, rmg_reaction.products = mapped_rmg_reactants, mapped_rmg_products - return rmg_reactions + if found: + break + if len(product_pairs.keys()) != len(arc_products): + raise ValueError(f'Could not match some of the products in: {reaction}\nto the given products:\n{products}') + return product_pairs + + + + + + + +# def get_rmg_reactions_from_arc_reaction(arc_reaction: 'ARCReaction', +# backend: str = 'ARC', +# ) -> Optional[List['TemplateReaction']]: +# """ +# A helper function for getting RMG reactions from an ARC reaction. +# This function calls ``map_two_species()`` so that each species in the RMG reaction is correctly mapped +# to the corresponding species in the ARC reaction. It does not attempt to map reactants to products. +# +# Args: +# arc_reaction (ARCReaction): The ARCReaction object instance. +# backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend. +# +# Returns: +# Optional[List[TemplateReaction]]: +# The respective RMG TemplateReaction object instances (considering resonance structures). 
+# """ +# if arc_reaction.family is None: +# return None +# rmg_reactions = arc_reaction.family.generate_reactions(reactants=[spc.mol.copy(deep=True) for spc in arc_reaction.r_species], +# products=[spc.mol.copy(deep=True) for spc in arc_reaction.p_species], +# prod_resonance=True, +# delete_labels=False, +# relabel_atoms=False, +# ) +# for rmg_reaction in rmg_reactions: +# r_map, p_map = map_arc_rmg_species(arc_reaction=arc_reaction, rmg_reaction=rmg_reaction, concatenate=False) +# try: +# ordered_rmg_reactants = [rmg_reaction.reactants[r_map[i]] for i in range(len(rmg_reaction.reactants))] +# ordered_rmg_products = [rmg_reaction.products[p_map[i]] for i in range(len(rmg_reaction.products))] +# except KeyError: +# logger.warning(f'Got a problematic RMG rxn from ARC rxn, trying again') +# continue +# mapped_rmg_reactants, mapped_rmg_products = list(), list() +# for ordered_rmg_mols, arc_species, mapped_mols in zip([ordered_rmg_reactants, ordered_rmg_products], +# [arc_reaction.r_species, arc_reaction.p_species], +# [mapped_rmg_reactants, mapped_rmg_products], +# ): +# for rmg_mol, arc_spc in zip(ordered_rmg_mols, arc_species): +# mol = arc_spc.copy().mol +# # The RMG molecule will get a random 3D conformer, don't consider chirality when mapping. 
+# atom_map = map_two_species(mol, rmg_mol, map_type='dict', backend=backend, consider_chirality=False) +# if atom_map is None: +# continue +# new_atoms_list = list() +# for i in range(len(rmg_mol.atoms)): +# rmg_mol.atoms[atom_map[i]].id = mol.atoms[i].id +# new_atoms_list.append(rmg_mol.atoms[atom_map[i]]) +# rmg_mol.atoms = new_atoms_list +# mapped_mols.append(rmg_mol) +# rmg_reaction.reactants, rmg_reaction.products = mapped_rmg_reactants, mapped_rmg_products +# return rmg_reactions +# +# +# def get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction: 'ARCReaction', +# rmg_reaction: 'TemplateReaction', +# ) -> Tuple[Optional[Dict[str, int]], Optional[Dict[str, int]]]: +# """ +# Get the RMG reaction atom labels and the corresponding 0-indexed atom indices +# for all labeled atoms in a TemplateReaction. +# +# Args: +# arc_reaction (ARCReaction): An ARCReaction object instance. +# rmg_reaction (TemplateReaction): A respective RMG family TemplateReaction object instance. +# +# Returns: +# Tuple[Optional[Dict[str, int]], Optional[Dict[str, int]]]: +# The tuple entries relate to reactants and products. +# Keys are labels (e.g., '*1'), values are corresponding 0-indices atoms. 
+# """ +# if not hasattr(rmg_reaction, 'labeled_atoms') or not rmg_reaction.labeled_atoms: +# return None, None +# +# for spc in rmg_reaction.reactants + rmg_reaction.products: +# generate_resonance_structures(object_=spc, save_order=True) +# +# r_map, p_map = map_arc_rmg_species(arc_reaction=arc_reaction, rmg_reaction=rmg_reaction, concatenate=False) +# +# reactant_index_dict, product_index_dict = dict(), dict() +# reactant_atoms, product_atoms = list(), list() +# rmg_reactant_order = [val for _, val in sorted(r_map.items(), key=lambda item: item[0])] +# rmg_product_order = [val for _, val in sorted(p_map.items(), key=lambda item: item[0])] +# for i in rmg_reactant_order: +# reactant_atoms.extend([atom for atom in rmg_reaction.reactants[i].atoms]) +# for i in rmg_product_order: +# product_atoms.extend([atom for atom in rmg_reaction.products[i].atoms]) +# +# for labeled_atom_dict, atom_list, index_dict in zip([rmg_reaction.labeled_atoms['reactants'], +# rmg_reaction.labeled_atoms['products']], +# [reactant_atoms, product_atoms], +# [reactant_index_dict, product_index_dict]): +# for label, atom_1 in labeled_atom_dict.items(): +# for i, atom_2 in enumerate(atom_list): +# if atom_1.id == atom_2.id: +# index_dict[label] = i +# break +# return reactant_index_dict, product_index_dict +# +# +# def map_arc_rmg_species(arc_reaction: 'ARCReaction', +# rmg_reaction: Union['Reaction', 'TemplateReaction'], +# concatenate: bool = True, +# ) -> Tuple[Dict[int, Union[List[int], int]], Dict[int, Union[List[int], int]]]: +# """ +# Map the species pairs in an ARC reaction to those in a respective RMG reaction +# which is defined in the same direction. +# +# Args: +# arc_reaction (ARCReaction): An ARCReaction object instance. +# rmg_reaction (Union[Reaction, TemplateReaction]): A respective RMG family TemplateReaction object instance. 
+# concatenate (bool, optional): Whether to return isomorphic species as a single list (``True``, default), +# or to return isomorphic species separately (``False``). +# +# Returns: +# Tuple[Dict[int, Union[List[int], int]], Dict[int, Union[List[int], int]]]: +# The first tuple entry refers to reactants, the second to products. +# Keys are specie indices in the ARC reaction, +# values are respective indices in the RMG reaction. +# If ``concatenate`` is ``True``, values are lists of integers. Otherwise, values are integers. +# """ +# if rmg_reaction.is_isomerization(): +# if concatenate: +# return {0: [0]}, {0: [0]} +# else: +# return {0: 0}, {0: 0} +# r_map, p_map = dict(), dict() +# arc_reactants, arc_products = arc_reaction.get_reactants_and_products(arc=True) +# for spc_map, rmg_species, arc_species in [(r_map, rmg_reaction.reactants, arc_reactants), +# (p_map, rmg_reaction.products, arc_products)]: +# for i, arc_spc in enumerate(arc_species): +# for j, rmg_obj in enumerate(rmg_species): +# rmg_spc = Species(molecule=[rmg_obj]) if isinstance(rmg_obj, Molecule) else rmg_obj +# if not isinstance(rmg_spc, Species): +# raise ValueError(f'Expected an RMG object instances of Molecule or Species, ' +# f'got {rmg_obj} which is a {type(rmg_obj)}.') +# generate_resonance_structures(object_=rmg_spc, save_order=True) +# rmg_spc_based_on_arc_spc = Species(molecule=arc_spc.mol_list) +# generate_resonance_structures(object_=rmg_spc_based_on_arc_spc, save_order=True) +# if rmg_spc.is_isomorphic(rmg_spc_based_on_arc_spc, save_order=True): +# if i in spc_map.keys() and concatenate: +# spc_map[i].append(j) +# elif concatenate: +# spc_map[i] = [j] +# elif i not in spc_map.keys() and j not in spc_map.values(): +# spc_map[i] = j +# break +# if not r_map or not p_map: +# raise ValueError(f'Could not match some of the RMG Reaction {rmg_reaction} to the ARC Reaction {arc_reaction}.') +# return r_map, p_map +# +# +# def find_equivalent_atoms_in_reactants(arc_reaction: 'ARCReaction', # ? 
+# backend: str = 'ARC', +# ) -> List[List[int]]: +# """ +# Find atom indices that are equivalent in the reactants of an ARCReaction +# in the sense that they represent degenerate reaction sites that are indifferentiable in 2D. +# Bridges between RMG reaction templates and ARC's 3D TS structures. +# Running indices in the returned structure relate to reactant_0 + reactant_1 + ... +# +# Args: +# arc_reaction ('ARCReaction'): The ARCReaction object instance. +# backend (str, optional): Whether to use ``'QCElemental'`` or ``ARC``'s method as the backend. +# +# Returns: +# List[List[int]]: Entries are lists of 0-indices, each such list represents equivalent atoms. +# """ +# rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction, backend=backend) +# dicts = [get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(rmg_reaction=rmg_reaction, +# arc_reaction=arc_reaction)[0] +# for rmg_reaction in rmg_reactions] +# equivalence_map = dict() +# for index_dict in dicts: +# for key, value in index_dict.items(): +# if key in equivalence_map: +# equivalence_map[key].append(value) +# else: +# equivalence_map[key] = [value] +# equivalent_indices = list(list(set(equivalent_list)) for equivalent_list in equivalence_map.values()) +# return equivalent_indices def map_two_species(spc_1: Union[ARCSpecies, Species, Molecule], @@ -1024,7 +1098,6 @@ def make_bond_changes(rxn: 'ARCReaction', r_cuts: the cut products r_label_dict: the dictionary object the find the relevant location. """ - for action in rxn.family.forward_recipe.actions: if action[0].lower() == "CHANGE_BOND".lower(): indicies = r_label_dict[action[1]],r_label_dict[action[3]] @@ -1279,12 +1352,12 @@ def find_all_bdes(rxn: "ARCReaction", label_dict: dict, is_reactants: bool) -> L Args: rxn (ARCReaction): The reaction in question. label_dict (dict): A dictionary of the atom indices to the atom labels. - is_reactants (bool): Whether or not the species list represents reactants or products. 
+ is_reactants (bool): Whether the species list represents reactants or products. Returns: List[Tuple[int, int]]: A list of tuples of the form (atom_index1, atom_index2) for each broken bond. Note that these represent the atom indicies to be cut, and not final BDEs. """ bdes = list() - for action in rxn.family.forward_recipe.actions: + for action in ReactionFamily(rxn.family).actions: if action[0].lower() == ("break_bond" if is_reactants else "form_bond"): bdes.append((label_dict[action[1]] + 1, label_dict[action[3]] + 1)) return bdes diff --git a/arc/mapping/engine_test.py b/arc/mapping/engine_test.py index 8952a15bf5..7a3af908be 100644 --- a/arc/mapping/engine_test.py +++ b/arc/mapping/engine_test.py @@ -37,46 +37,46 @@ def setUpClass(cls): smiles = ['CC(C)F', '[CH3]', 'C[CH](C)', 'CF'] r_1_1_xyz = {'symbols': ('C', 'C', 'C', 'F', 'H', 'H', 'H', 'H', 'H', 'H', 'H'), - 'isotopes': (12, 12, 12, 19, 1, 1, 1, 1, 1, 1, 1), - 'coords': ((1.2509680857915237, 0.00832885083067477, -0.28594855682006387), - (-0.08450322338173592, -0.5786110309038947, 0.12835305965368538), - (-1.196883483105121, 0.4516770584363101, 0.10106807955582568), - (0.03212452836861426, -1.0465351442062332, 1.402047416169314), - (1.2170230403876368, 0.39373449465586885, -1.309310880313081), - (1.5446944155971303, 0.8206316657310906, 0.38700047363833845), - (2.0327466889922805, -0.7555292157466509, -0.22527487012253536), - (-0.3397419937928473, -1.4280299782557704, -0.5129583662636836), - (-0.9791793765226446, 1.2777482351478369, 0.786037216866474), - (-1.340583396165929, 0.8569620299504027, -0.9049411765144166), - (-2.1366652861689137, -0.00037696563964776297, 0.43392760415012316))} + 'isotopes': (12, 12, 12, 19, 1, 1, 1, 1, 1, 1, 1), + 'coords': ((1.2509680857915237, 0.00832885083067477, -0.28594855682006387), + (-0.08450322338173592, -0.5786110309038947, 0.12835305965368538), + (-1.196883483105121, 0.4516770584363101, 0.10106807955582568), + (0.03212452836861426, -1.0465351442062332, 
1.402047416169314), + (1.2170230403876368, 0.39373449465586885, -1.309310880313081), + (1.5446944155971303, 0.8206316657310906, 0.38700047363833845), + (2.0327466889922805, -0.7555292157466509, -0.22527487012253536), + (-0.3397419937928473, -1.4280299782557704, -0.5129583662636836), + (-0.9791793765226446, 1.2777482351478369, 0.786037216866474), + (-1.340583396165929, 0.8569620299504027, -0.9049411765144166), + (-2.1366652861689137, -0.00037696563964776297, 0.43392760415012316))} r_2_1_xyz = {'symbols': ('C', 'H', 'H', 'H'), - 'isotopes': (12, 1, 1, 1), - 'coords' : ((3.3746019998564553e-09, 5.828827384106545e-09, -4.859105107686622e-09), - (1.0669051052331406, -0.17519582095514982, 0.05416492980439295), - (-0.6853171627400634, -0.8375353626879753, -0.028085652887100996), - (-0.3815879458676787, 1.0127311778142964, -0.026079272058187608))} + 'isotopes': (12, 1, 1, 1), + 'coords': ((3.3746019998564553e-09, 5.828827384106545e-09, -4.859105107686622e-09), + (1.0669051052331406, -0.17519582095514982, 0.05416492980439295), + (-0.6853171627400634, -0.8375353626879753, -0.028085652887100996), + (-0.3815879458676787, 1.0127311778142964, -0.026079272058187608))} p_1_1_xyz = {'symbols': ('C', 'C', 'C', 'H', 'H', 'H', 'H', 'H', 'H', 'H'), - 'isotopes': (12, 12, 12, 1, 1, 1, 1, 1, 1, 1), - 'coords': ((-1.288730238258946, 0.06292843803165035, 0.10889818910854648), - (0.01096160773224897, -0.45756396262445836, -0.3934214957819532), - (1.2841030977199492, 0.11324607936811129, 0.12206176848573647), - (-1.4984446521053447, 1.0458196461796345, -0.3223873567509909), - (-1.2824724918369017, 0.14649429503996203, 1.1995362776757934), - (-2.098384694966955, -0.616646552269074, -0.17318515188247927), - (0.027360233461550892, -1.0601383387124987, -1.2952225290380646), - (2.122551165381095, -0.534098313164123, -0.15158596254231563), - (1.2634262459696732, 0.19628891975881263, 1.2125616721427255), - (1.4596297269035956, 1.1036697883919826, -0.307255411416999))} + 'isotopes': (12, 12, 12, 1, 
1, 1, 1, 1, 1, 1), + 'coords': ((-1.288730238258946, 0.06292843803165035, 0.10889818910854648), + (0.01096160773224897, -0.45756396262445836, -0.3934214957819532), + (1.2841030977199492, 0.11324607936811129, 0.12206176848573647), + (-1.4984446521053447, 1.0458196461796345, -0.3223873567509909), + (-1.2824724918369017, 0.14649429503996203, 1.1995362776757934), + (-2.098384694966955, -0.616646552269074, -0.17318515188247927), + (0.027360233461550892, -1.0601383387124987, -1.2952225290380646), + (2.122551165381095, -0.534098313164123, -0.15158596254231563), + (1.2634262459696732, 0.19628891975881263, 1.2125616721427255), + (1.4596297269035956, 1.1036697883919826, -0.307255411416999))} p_2_1_xyz = {'symbols': ('C', 'F', 'H', 'H', 'H'), - 'isotopes': (12, 19, 1, 1, 1), - 'coords': ((-0.060384822736851786, 0.004838867136375763, -0.004814368798794687), - (1.2877092002693546, -0.10318918150563985, 0.10266661058725791), - (-0.2965861926821434, 0.9189121874074381, -0.5532990701789506), - (-0.44047773762823295, -0.8660709320146035, -0.5425894744224189), - (-0.49026044722212864, 0.04550905897643097, 0.9980363028129072))} + 'isotopes': (12, 19, 1, 1, 1), + 'coords': ((-0.060384822736851786, 0.004838867136375763, -0.004814368798794687), + (1.2877092002693546, -0.10318918150563985, 0.10266661058725791), + (-0.2965861926821434, 0.9189121874074381, -0.5532990701789506), + (-0.44047773762823295, -0.8660709320146035, -0.5425894744224189), + (-0.49026044722212864, 0.04550905897643097, 0.9980363028129072))} cls.r_1 = ARCSpecies(label='r1', smiles=smiles[0],xyz=r_1_1_xyz ) cls.r_2 = ARCSpecies(label='r2', smiles=smiles[1],xyz=r_2_1_xyz) cls.p_1 = ARCSpecies(label='p1', smiles=smiles[2],xyz=p_1_1_xyz) @@ -88,10 +88,10 @@ def setUpClass(cls): cls.p_2_2 = ARCSpecies(label='p2', smiles=smiles[3],xyz=p_2_1_xyz) cls.rxn_1 = ARCReaction(r_species=[cls.r_1, cls.r_2], p_species=[cls.p_1, cls.p_2]) - cls.rmg_reactions_rxn_1 = get_rmg_reactions_from_arc_reaction(arc_reaction=cls.rxn_1, 
backend="ARC") - cls.r_label_dict_rxn_1, cls.p_label_dict_rxn_1 = ( - get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=cls.rxn_1, - rmg_reaction=cls.rmg_reactions_rxn_1[0])) + # cls.rmg_reactions_rxn_1 = get_rmg_reactions_from_arc_reaction(arc_reaction=cls.rxn_1, backend="ARC") + # cls.r_label_dict_rxn_1, cls.p_label_dict_rxn_1 = ( + # get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=cls.rxn_1, + # rmg_reaction=cls.rmg_reactions_rxn_1[0])) cls.spc1 = ARCSpecies(label="Test_is_isomorphic_1",smiles="C(CO)CC") cls.spc2 = ARCSpecies(label="Test_is_isomorphic_2",smiles="OCCCC") @@ -536,8 +536,36 @@ def setUpClass(cls): H -1.04996634 -0.37234114 0.91874740 H 1.36260637 0.37153887 -0.86221771""" + def test_get_atom_indices_of_labeled_atoms_in_a_reaction(self): + """Test getting atom indices of labeled atoms in a reaction""" + atom_indices = get_atom_indices_of_labeled_atoms_in_a_reaction(self.arc_reaction_1) + self.assertEqual(atom_indices, ({'*1': 0, '*2': 1, '*3': 5}, {'*1': 0, '*2': 1, '*3': 7})) + + atom_indices = get_atom_indices_of_labeled_atoms_in_a_reaction(self.arc_reaction_2) + print(atom_indices) + self.assertEqual(atom_indices, ({'*1': 0, '*2': 3, '*3': 11}, {'*1': 0, '*2': 3, '*3': 16})) + + def test_pair_reaction_products(self): + """Test pairing reaction and products""" + products = [Molecule(smiles='[CH3]'), Molecule(smiles='O')] + product_pairs = pair_reaction_products(self.arc_reaction_1, products) + self.assertEqual(product_pairs, {0: 0, 1: 1}) + + product_pairs = pair_reaction_products(self.arc_reaction_1, [products[1], products[0]]) + self.assertEqual(product_pairs, {0: 1, 1: 0}) + + rxn_1 = ARCReaction(label='H2NN(T) + N2H4 <=> N2H3 + N2H3', + r_species=[ARCSpecies(label='H2NN(T)', smiles='N[N]'), + ARCSpecies(label='N2H4', smiles='NN')], + p_species=[ARCSpecies(label='N2H3', smiles='N[NH]')]) + prod_mol = Molecule(smiles='N[NH]') + product_pairs = pair_reaction_products(rxn_1, [prod_mol, prod_mol]) + 
self.assertEqual(product_pairs, {0: 0, 1: 1}) + + def test_assign_labels_to_products(self): + """Test assigning labels to products based on the atom map of the reaction""" rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2]) assign_labels_to_products(rxn_1_test, self.p_label_dict_rxn_1) index = 0 @@ -545,7 +573,7 @@ def test_assign_labels_to_products(self): for atom in product.mol.atoms: if not isinstance(atom.label, str) or atom.label != "": self.assertEqual(self.p_label_dict_rxn_1[atom.label], index) - index+=1 + index += 1 def test_inc_vals(self): """Test creating an atom map via map_two_species() and incrementing all values""" @@ -709,7 +737,7 @@ def test_get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(self): """Test the get_atom_indices_of_labeled_atoms_in_an_rmg_reaction() function.""" rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=self.arc_reaction_1) r_dict, p_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=self.arc_reaction_1, - rmg_reaction=rmg_reactions[0]) + rmg_reaction=rmg_reactions[0]) self.assertEqual(r_dict['*1'], 0) self.assertIn(r_dict['*2'], [1, 2, 3, 4]) self.assertEqual(r_dict['*3'], 5) @@ -720,7 +748,7 @@ def test_get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(self): rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=self.arc_reaction_2) r_dict, p_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=self.arc_reaction_2, - rmg_reaction=rmg_reactions[0]) + rmg_reaction=rmg_reactions[0]) self.assertIn(r_dict['*1'], [0, 2]) self.assertIn(r_dict['*2'], [3, 4, 5, 8, 9, 10]) self.assertEqual(r_dict['*3'], 11) @@ -731,7 +759,7 @@ def test_get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(self): rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=self.arc_reaction_4) r_dict, p_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=self.arc_reaction_4, - rmg_reaction=rmg_reactions[0]) + 
rmg_reaction=rmg_reactions[0]) self.assertEqual(r_dict['*1'], 0) self.assertEqual(r_dict['*2'], 2) self.assertIn(r_dict['*3'], [7, 8]) @@ -750,7 +778,7 @@ def test_get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(self): self.assertEqual(self.rxn_2a.p_species[0].mol.atoms[2].radical_electrons, 1) rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=self.rxn_2a) r_dict, p_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=self.rxn_2a, - rmg_reaction=rmg_reactions[0]) + rmg_reaction=rmg_reactions[0]) self.assertEqual(r_dict['*1'], 1) self.assertIn(r_dict['*2'], [0, 2]) self.assertIn(r_dict['*3'], [4, 5, 6, 7, 8, 9]) @@ -763,7 +791,7 @@ def test_get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(self): self.assertEqual(atom.symbol, symbol) rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=self.rxn_2b) r_dict, p_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=self.rxn_2b, - rmg_reaction=rmg_reactions[0]) + rmg_reaction=rmg_reactions[0]) self.assertEqual(r_dict['*1'], 1) self.assertIn(r_dict['*2'], [0, 6]) self.assertIn(r_dict['*3'], [3, 4, 5, 7, 8, 9]) @@ -782,7 +810,7 @@ def test_get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(self): rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=rxn_1) for rmg_reaction in rmg_reactions: r_dict, p_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=rxn_1, - rmg_reaction=rmg_reaction) + rmg_reaction=rmg_reaction) for d in [r_dict, p_dict]: self.assertEqual(len(list(d.keys())), 3) keys = list(d.keys()) @@ -818,21 +846,22 @@ def test_get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(self): self.assertEqual(p_dict, {'*1': 1, '*3': 9, '*2': 16}) self.assertTrue(_check_r_n_p_symbols_between_rmg_and_arc_rxns(rxn_3, rmg_reactions)) - def test_map_arc_rmg_species(self): """Test the map_arc_rmg_species() function.""" - r_map, p_map = map_arc_rmg_species(arc_reaction=ARCReaction(r_species=[ARCSpecies(label='CCjC', 
smiles='C[CH]C')], - p_species=[ARCSpecies(label='CjCC', smiles='[CH2]CC')]), - rmg_reaction=Reaction(reactants=[Species(smiles='C[CH]C')], - products=[Species(smiles='[CH2]CC')]), - concatenate=False) + r_map, p_map = map_arc_rmg_species( + arc_reaction=ARCReaction(r_species=[ARCSpecies(label='CCjC', smiles='C[CH]C')], + p_species=[ARCSpecies(label='CjCC', smiles='[CH2]CC')]), + rmg_reaction=Reaction(reactants=[Species(smiles='C[CH]C')], + products=[Species(smiles='[CH2]CC')]), + concatenate=False) self.assertEqual(r_map, {0: 0}) self.assertEqual(p_map, {0: 0}) - r_map, p_map = map_arc_rmg_species(arc_reaction=ARCReaction(r_species=[ARCSpecies(label='CCjC', smiles='C[CH]C')], - p_species=[ARCSpecies(label='CjCC', smiles='[CH2]CC')]), - rmg_reaction=Reaction(reactants=[Species(smiles='C[CH]C')], - products=[Species(smiles='[CH2]CC')])) + r_map, p_map = map_arc_rmg_species( + arc_reaction=ARCReaction(r_species=[ARCSpecies(label='CCjC', smiles='C[CH]C')], + p_species=[ARCSpecies(label='CjCC', smiles='[CH2]CC')]), + rmg_reaction=Reaction(reactants=[Species(smiles='C[CH]C')], + products=[Species(smiles='[CH2]CC')])) self.assertEqual(r_map, {0: [0]}) self.assertEqual(p_map, {0: [0]}) diff --git a/arc/reaction/family.py b/arc/reaction/family.py index 1011419e52..6d2522d478 100644 --- a/arc/reaction/family.py +++ b/arc/reaction/family.py @@ -393,7 +393,7 @@ def get_reaction_family_products(rxn: 'ARCReaction', Returns: List[dict]: The list of product dictionaries with the reaction family label. - Keys are: 'family', 'group_labels', 'products', 'own_reverse', 'discovered_in_reverse'. + Keys are: 'family', 'group_labels', 'products', 'own_reverse', 'discovered_in_reverse', 'actions'. 
""" family_labels = get_all_families(rmg_family_set=rmg_family_set, consider_rmg_families=consider_rmg_families, @@ -506,6 +506,22 @@ def check_product_isomorphism(products: List['Molecule'], return False +# def get_all_reactions_paths(rxn: 'ARCReaction') -> List['ARCReaction']: +# """ +# Get all possible reaction paths with labeled atoms for a given ARC reaction. +# +# Args: +# rxn ('ARCReaction'): The ARC reaction object. +# +# Returns: +# List['ARCReaction']: A list of reactions, each represents a possible reaction path. +# """ +# reaction_paths = list() +# for reactant in rxn.get_reactants_and_products(arc=True, return_copies=True)[0]: +# all_reactions.extend(reactant.generate_reactions()) +# return all_reactions + + def get_all_families(rmg_family_set: str = 'default', consider_rmg_families: bool = True, consider_arc_families: bool = True, diff --git a/arc/reaction/family_test.py b/arc/reaction/family_test.py index 82cfd84104..4c6be4c92d 100644 --- a/arc/reaction/family_test.py +++ b/arc/reaction/family_test.py @@ -60,6 +60,87 @@ def test_arc_families_path(self): def test_get_reaction_family_products(self): """Test determining the reaction family using product dicts""" + rxn_0a = ARCReaction(r_species=[ARCSpecies(label='CH4', smiles='C'), ARCSpecies(label='O2', smiles='[O][O]')], + p_species=[ARCSpecies(label='CH3', smiles='[CH3]'), + ARCSpecies(label='HO2', smiles='O[O]')]) + products = get_reaction_family_products(rxn_0a) + expected_products = [{'discovered_in_reverse': False, + 'family': 'H_Abstraction', + 'group_labels': ('X_H', 'Y_rad'), + 'label_map': {'*1': 0, '*2': 1, '*3': 5}, + 'own_reverse': True, + 'products': [Molecule(smiles="[O]O"), Molecule(smiles="[CH3]")]}, + {'discovered_in_reverse': False, + 'family': 'H_Abstraction', + 'group_labels': ('X_H', 'Y_rad'), + 'label_map': {'*1': 0, '*2': 1, '*3': 6}, + 'own_reverse': True, + 'products': [Molecule(smiles="[O]O"), Molecule(smiles="[CH3]")]}, + {'discovered_in_reverse': False, + 'family': 
'H_Abstraction', + 'group_labels': ('X_H', 'Y_rad'), + 'label_map': {'*1': 0, '*2': 2, '*3': 5}, + 'own_reverse': True, + 'products': [Molecule(smiles="[O]O"), Molecule(smiles="[CH3]")]}, + {'discovered_in_reverse': False, + 'family': 'H_Abstraction', + 'group_labels': ('X_H', 'Y_rad'), + 'label_map': {'*1': 0, '*2': 2, '*3': 6}, + 'own_reverse': True, + 'products': [Molecule(smiles="[O]O"), Molecule(smiles="[CH3]")]}, + {'discovered_in_reverse': False, + 'family': 'H_Abstraction', + 'group_labels': ('X_H', 'Y_rad'), + 'label_map': {'*1': 0, '*2': 3, '*3': 5}, + 'own_reverse': True, + 'products': [Molecule(smiles="[O]O"), Molecule(smiles="[CH3]")]}, + {'discovered_in_reverse': False, + 'family': 'H_Abstraction', + 'group_labels': ('X_H', 'Y_rad'), + 'label_map': {'*1': 0, '*2': 3, '*3': 6}, + 'own_reverse': True, + 'products': [Molecule(smiles="[O]O"), Molecule(smiles="[CH3]")]}, + {'discovered_in_reverse': False, + 'family': 'H_Abstraction', + 'group_labels': ('X_H', 'Y_rad'), + 'label_map': {'*1': 0, '*2': 4, '*3': 5}, + 'own_reverse': True, + 'products': [Molecule(smiles="[O]O"), Molecule(smiles="[CH3]")]}, + {'discovered_in_reverse': False, + 'family': 'H_Abstraction', + 'group_labels': ('X_H', 'Y_rad'), + 'label_map': {'*1': 0, '*2': 4, '*3': 6}, + 'own_reverse': True, + 'products': [Molecule(smiles="[O]O"), Molecule(smiles="[CH3]")]}] + self.assertEqual(products, expected_products) + + ch4_xyz = """C -0.00000000 -0.00000000 0.00000000 +H -0.87497771 -0.55943190 -0.33815595 +H -0.04050904 1.01567250 -0.39958464 +H 0.00816153 0.03824434 1.09149909 +H 0.90732523 -0.49448494 -0.35375850""" + o2_xyz = {'symbols': ('O', 'O'), 'isotopes': (16, 16), 'coords': ((0.0, 0.0, 0.6029), (0.0, 0.0, -0.6029))} + ch3_xyz = """C -0.00000000 -0.00000000 -0.00000001 +H 1.04110758 -0.29553525 0.02584268 +H -0.77665903 -0.75407986 -0.00884379 +H -0.26444854 1.04961512 -0.01699888""" + ho2_xyz = """O -0.15635718 0.45208323 0.00000000 +O 0.99456866 -0.18605915 0.00000000 +H 
-0.83821148 -0.26602407 0.00000000""" + + rxn_0b = ARCReaction(r_species=[ARCSpecies(label='CH4', xyz=ch4_xyz), ARCSpecies(label='O2', xyz=o2_xyz, multiplicity=3)], + p_species=[ARCSpecies(label='CH3', xyz=ch3_xyz), ARCSpecies(label='HO2', xyz=ho2_xyz)]) + products = get_reaction_family_products(rxn_0b) + expected_products = [{'discovered_in_reverse': True, + 'family': 'Disproportionation', # Todo: should be H_abs after merging Calvin's PR + 'group_labels': 'Root', + 'label_map': {'*1': 0, '*2': 4, '*3': 5, '*4': 6}, + 'own_reverse': False, + 'products': [Molecule(smiles="C"), Molecule(smiles="O=O")]}] + print(products) + self.assertEqual(products, expected_products) + + rxn_1 = ARCReaction(reactants=['NH2', 'NH2'], products=['NH', 'NH3'], r_species=[ARCSpecies(label='NH2', smiles='[NH2]')], p_species=[ARCSpecies(label='NH', smiles='[NH]'), ARCSpecies(label='NH3', smiles='N')]) diff --git a/arc/reaction/reaction.py b/arc/reaction/reaction.py index ece207202f..85331a96f3 100644 --- a/arc/reaction/reaction.py +++ b/arc/reaction/reaction.py @@ -10,7 +10,7 @@ from arc.common import get_logger from arc.exceptions import ReactionError, InputError -from arc.reaction.family import ReactionFamily, determine_reaction_family +from arc.reaction.family import ReactionFamily, get_reaction_family_products from arc.species.converter import (check_xyz_dict, sort_xyz_using_indices, translate_to_center_of_mass, @@ -593,12 +593,12 @@ def determine_family(self, discover_own_reverse_rxns_in_reverse (bool, optional): Whether to discover own reverse reactions in reverse. 
""" if self.rmg_reaction is not None: - product_dicts = determine_reaction_family(rxn=self, - rmg_family_set=rmg_family_set, - consider_rmg_families=consider_rmg_families, - consider_arc_families=consider_arc_families, - discover_own_reverse_rxns_in_reverse=discover_own_reverse_rxns_in_reverse, - ) + product_dicts = get_reaction_family_products(rxn=self, + rmg_family_set=rmg_family_set, + consider_rmg_families=consider_rmg_families, + consider_arc_families=consider_arc_families, + discover_own_reverse_rxns_in_reverse=discover_own_reverse_rxns_in_reverse, + ) if len(product_dicts): family, family_own_reverse = product_dicts[0]['family'], product_dicts[0]['own_reverse'] return family, family_own_reverse diff --git a/arc/rmgdb.py b/arc/rmgdb.py index 0855ad9b6a..313e463da0 100644 --- a/arc/rmgdb.py +++ b/arc/rmgdb.py @@ -160,7 +160,7 @@ def load_rmg_database(rmgdb: RMGDatabase, logger.info('\n\n') -def determine_reaction_family(rmgdb: RMGDatabase, +def get_reaction_family_products(rmgdb: RMGDatabase, reaction: Reaction, save_order: bool = True, ) -> Tuple[Optional['KineticsFamily'], bool]: