diff --git a/arc/mapping/driver.py b/arc/mapping/driver.py index 4c3cd090d2..e378eedd69 100644 --- a/arc/mapping/driver.py +++ b/arc/mapping/driver.py @@ -12,7 +12,6 @@ import arc.rmgdb as rmgdb from arc.mapping.engine import (assign_labels_to_products, are_adj_elements_in_agreement, - cut_species_for_mapping, create_qc_mol, flip_map, fingerprint, @@ -24,7 +23,9 @@ map_pairs, iterative_dfs, map_two_species, pairing_reactants_and_products_for_mapping, - prepare_reactants_and_products_for_scissors, + copy_species_list_for_mapping, + find_all_bdes, + cut_species_based_on_atom_indices, update_xyz, RESERVED_FINGERPRINT_KEYS,) from arc.common import logger @@ -234,17 +235,24 @@ def map_rxn(rxn: 'ARCReaction', """ # step 1: rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=rxn, backend=backend) + + if not rmg_reactions: + return None + r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=rxn, rmg_reaction=rmg_reactions[0]) # step 2: assign_labels_to_products(rxn, p_label_dict) - reactants, products, loc_r, loc_p = prepare_reactants_and_products_for_scissors(rxn, r_label_dict, p_label_dict) + #step 3: - label_species_atoms(reactants) - label_species_atoms(products) - r_cuts = cut_species_for_mapping(reactants, loc_r) - p_cuts = cut_species_for_mapping(products, loc_p) + reactants, products = copy_species_list_for_mapping(rxn.r_species), copy_species_list_for_mapping(rxn.p_species) + label_species_atoms(reactants), label_species_atoms(products) + + r_bdes, p_bdes = find_all_bdes(rxn, r_label_dict, True), find_all_bdes(rxn, p_label_dict, False) + + r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes) + p_cuts = cut_species_based_on_atom_indices(products, p_bdes) try: make_bond_changes(rxn, r_cuts, r_label_dict) diff --git a/arc/mapping/engine.py b/arc/mapping/engine.py index c11ba816a6..3d1d16ab88 100644 --- a/arc/mapping/engine.py +++ b/arc/mapping/engine.py @@ -1016,89 +1016,6 @@ def flip_map(atom_map: Optional[List[int]]) -> Optional[List[int]]: return flipped_map -def prepare_reactants_and_products_for_scissors(rxn: 'ARCReaction', - r_label_dict: dict, - p_label_dict: dict, - ) -> Tuple[List[ARCSpecies], List[ARCSpecies], List[int], List[int]]: - """ - Prepares the species to be scissored. - - Args: - rxn: ARC reaction object to be mapped - r_label_dict: the labels of the reactants - p_label_dict: the labels of the products - - Returns: - The species in the reactants and products where a bond was broken or formed. - """ - - breaks = list() - forms = list() - actions = rxn.family.forward_recipe.actions - for action in actions: - if action[0].lower() == "BREAK_BOND".lower(): - breaks.append(action) - elif action[0].lower() == "FORM_BOND".lower(): - forms.append(action) - reactants, products, loc_r, loc_p = [0]*len(rxn.r_species), [0]*len(rxn.p_species), [0]*len(rxn.r_species), [0]*len(rxn.p_species) - for broken_bond in breaks: - location = 0 - index = 0 - for reactant in rxn.r_species: - if not r_label_dict[broken_bond[1]] < reactant.number_of_atoms + index: - location += 1 - index += reactant.number_of_atoms - else: - if loc_r[location] >= 1: - loc_r[location] += 1 - reactants[location].bdes += [(r_label_dict[broken_bond[1]] + 1 - index, r_label_dict[broken_bond[3]] + 1 - index)] - else: - loc_r[location] += 1 - reactants[location] = ARCSpecies(label="".join(sorted( - [key_by_val(r_label_dict, r_label_dict[broken_bond[1]]), - key_by_val(p_label_dict, p_label_dict[broken_bond[3]])])), - mol = reactant.mol.copy(deep=True), - bdes = [(r_label_dict[broken_bond[1]] + 1 - index, - r_label_dict[broken_bond[3]] + 1 - index)]) - reactants[location].final_xyz = reactant.get_xyz() - for mol1, mol2 in zip(reactants[location].mol.atoms, reactant.mol.atoms): - mol1.label = mol2.label - break - - for formed_bond in forms: - location = 0 - index = 0 - for product in rxn.p_species: - if not p_label_dict[formed_bond[1]] < product.number_of_atoms + index: - location += 1 - index += product.number_of_atoms - else: - if loc_p[location] >= 1: - loc_p[location]+=1 - products[location].bdes += [(p_label_dict[formed_bond[1]] + 1 - index, p_label_dict[formed_bond[3]] + 1 - index)] - else: - loc_p[location] += 1 - products[location] = ARCSpecies(label="".join(sorted( - [key_by_val(p_label_dict, p_label_dict[formed_bond[1]]), - key_by_val(p_label_dict, p_label_dict[formed_bond[3]])])), - mol = product.mol.copy(deep=True), - xyz = product.get_xyz(), - bdes = [(p_label_dict[formed_bond[1]] + 1 - index, - p_label_dict[formed_bond[3]] + 1 - index)]) - for mol1, mol2 in zip(products[location].mol.atoms, product.mol.atoms): - mol1.label = mol2.label - break - for index, value in enumerate(loc_r): - if value == 0: - reactants[index] = rxn.r_species[index] - - for index, value in enumerate(loc_p): - if value == 0: - products[index] = rxn.p_species[index] - - return reactants, products, loc_r, loc_p - - def make_bond_changes(rxn: 'ARCReaction', r_cuts: List[ARCSpecies], r_label_dict: Dict): @@ -1165,95 +1082,6 @@ def assign_labels_to_products(rxn: 'ARCReaction', atom_index+=1 -def multiple_cut_on_species(spc, bdes): - """ - A function that recursively calles an is called by cut_species_for_mapping. - Is used to cut a species and reassign the other BDE's to the other cuts, and recalling cut_species_for_mapping. - Args: - spc (ARCSpecies): a species with more then one BDE's (marks for scission). - bdes (list(tuple(int))): the required BDE's. - Returns: - list(ARCSpecies): a list of the cut products. - """ - bdes = spc.bdes - spc.bdes = [bdes[0]] - bdes = bdes[1:] - try: - cuts = spc.scissors() - except SpeciesError: - return None - for species in cuts: - species.final_xyz = species.get_xyz(generate=False) - indinces = [int(atom.label) for atom in species.mol.copy(deep=True).atoms] - new_bdes = list() - for bde in bdes: - if bde[0]-1 in indinces and bde[1]-1 in indinces: - new_bdes.append((indinces.index(bde[0]-1) + 1, indinces.index(bde[1]-1) + 1)) - species.bdes = new_bdes - return cut_species_for_mapping(cuts, [len(species.bdes or list()) for species in cuts]) - - -def cut_species_for_mapping(species, locs): - """ - A function for performing the necessary scission of species for mapping purposes. Can perform appropriate scission of multiple bonds at once. - Args: - species (list(ARCSpecies)): the species (reactants or products), marked for scission. - locs (list(int)): the number of cuts that is required for each species. - Returns: - list(ARCSpecies): a list of the cut products. - """ - cuts = list() - for spc, loc in zip(species, locs): - spc.final_xyz = spc.get_xyz() - if spc.mol.copy(deep=True).smiles == "[H][H]" and loc != 0: # scissors should return one species - labels = [atom.label for atom in spc.mol.copy(deep=True).atoms] - try: - H1 = spc.scissors()[0] - except SpeciesError: - return None - H2 = H1.copy() - H2.mol.atoms[0].label = labels[0] if H1.mol.atoms[0].label != labels[0] else labels[1] - cuts += [H1, H2] - elif loc == 0: - cuts += [spc] - elif loc == 1: - try: - cuts += spc.scissors() - except SpeciesError: - return None - else: - bdes = spc.bdes - cuts += multiple_cut_on_species(spc, bdes) - return cuts - - -def find_main_cut_product(cuts: List["ARCSpecies"], - reactant: "ARCSpecies", - bde: Tuple[int] - )->Tuple["ARCSpecies", "ARCSpecies"]: - """ - Differentiate the main product from scissors product. - Strategy: use the other BDE, if the indices of the atoms are in the bdes, it should be able to cut there. - - Args: - cuts: The cut products - reactant: The reactant with multiple bdes - bde: the BDE used for scissors. - - Returns: - Tuple["ARCSpecies", "ARCSpecies"] in the correct order. - """ - list_atom_labels_cuts_0 = [int(atom.label)+1 for atom in cuts[0].mol.atoms] - bdes = reactant.bdes - for bd in bdes: - if bd == bde: - continue - elif bd[0] not in list_atom_labels_cuts_0: - return cuts[1], cuts[0] - - return cuts[0], cuts[1] - - def update_xyz(spcs: List[ARCSpecies]) -> List[ARCSpecies]: """ A helper function, updates the xyz values of each species after cutting. This is important, since the @@ -1363,31 +1191,103 @@ def glue_maps(maps, pairs_of_reactant_and_products): p_atoms = pair[1].mol.atoms for map_index, r_atom in zip(_map, r_atoms): am_dict[int(r_atom.label)] = int(p_atoms[map_index].label) - return [val for key, val in sorted(am_dict.items(), key=lambda item: item[0])] - - -def cuts_on_cycle_of_labeled_mol(spc: 'ARCSpecies')-> bool: - """A helper function determining whether or not the scission site opens a cycle. + return [val for _, val in sorted(am_dict.items(), key=lambda item: item[0])] - Args: - spc1: ARCSpecies with a bdes - Returns: - True if the scission site is on a ring, None if the speceis is unlabeled, False otherwise""" - if not any([atom.label for atom in spc.mol.atoms]): - raise ValueError("cuts_on_cycle_of_labeled_mol recives labeled ARCSpecies only, got an unlabeld species") +def determine_bdes_on_spc_based_on_atom_labels(spc: "ARCSpecies", bde: Tuple[int, int]) -> bool: + """ + A function for determining whether or not the species in question containt the bond specified by the bond dissociation indices. + Also, assigns the correct BDE to the species. + + Args: + spc (ARCSpecies): The species in question, with labels atom indices. + bde (Tuple[int, int]): The bde in question. + add_bdes (bool): Whether or not to add the bde to the species. - if not spc.mol.is_cyclic(): + Returns: + bool: Whether or not the bde is based on the atom labels. + """ + bde = convert_list_index_0_to_1(bde, direction=-1) + index1, index2 = bde[0], bde[1] + new_bde = list() + atoms = list() + for index, atom in enumerate(spc.mol.atoms): + if atom.label == str(index1) or atom.label == str(index2): + new_bde.append(index+1) + atoms.append(atom) + if len(new_bde) == 2: + break + + if len(new_bde) == 2 and atoms[1] in atoms[0].bonds.keys(): + spc.bdes = [tuple(new_bde)] + return True + else: return False - k = spc.mol.get_deterministic_sssr() - labels = [[] for i in range(len(k))] - for index, cycle in enumerate(k): - for atom in cycle: - labels[index].append(int(atom.label)) + +def cut_species_based_on_atom_indices(species: List["ARCSpecies"], bdes: List[Tuple[int, int]]) -> Optional[List["ARCSpecies"]]: + """ + A function for scissoring species based on their atom indices. + Args: + species (List[ARCSpecies]): The species list that requires scission. + bdes (List[Tuple[int, int]]): A list of the atoms between which the bond should be scissored. The atoms are described using the atom labels, and not the actuall atom positions. + Returns: + Optional[List["ARCSpecies"]]: The species list input after the scission. + """ + if not bdes: + return species - for bde in spc.bdes: - for cycle in labels: - if bde[0]-1 in cycle and bde[1]-1 in cycle: - return True - return False + for bde in bdes: + for index, spc in enumerate(species): + if determine_bdes_on_spc_based_on_atom_labels(spc, bde): + candidate = species.pop(index) + candidate.final_xyz = candidate.get_xyz() + if candidate.mol.copy(deep=True).smiles == "[H][H]": + labels = [atom.label for atom in candidate.mol.copy(deep=True).atoms] + try: + h1 = candidate.scissors()[0] + except SpeciesError: + return None + h2 = h1.copy() + h2.mol.atoms[0].label = labels[0] if h1.mol.atoms[0].label != labels[0] else labels[1] + species += [h1, h2] + else: + try: + species += candidate.scissors() + except SpeciesError: + return None + break + + return species + + +def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpecies"]: + """ + A helper function for copying the species list for mapping. Also keeps the atom indices when copying. + Args: + species (List[ARCSpecies]): The species list to be copied. + Returns: + List[ARCSpecies]: The copied species list. + """ + copies = [spc.copy() for spc in species] + for copy, spc in zip(copies, species): + for atom1, atom2 in zip(copy.mol.atoms, spc.mol.atoms): + atom1.label = atom2.label + return copies + + +def find_all_bdes(rxn: "ARCReaction", label_dict: dict, is_reactants: bool) -> List[Tuple[int, int]]: + """ + A function for finding all the broken(/formed) bonds during a chemical reaction, based on the atom indices. + Args: + rxn (ARCReaction): The reaction in question. + label_dict (dict): A dictionary of the atom indices to the atom labels. + is_reactants (bool): Whether or not the species list represents reactants or products. + Returns: + List[Tuple[int, int]]: A list of tuples of the form (atom_index1, atom_index2) for each broken bond. Note that these represent the atom indicies to be cut, and not final BDEs. + """ + bdes = list() + for action in rxn.family.forward_recipe.actions: + if action[0].lower() == ("break_bond" if is_reactants else "form_bond"): + bdes.append((label_dict[action[1]] + 1, label_dict[action[3]] + 1)) + return bdes diff --git a/arc/mapping/engine_test.py b/arc/mapping/engine_test.py index 701006a9fb..eaf6cd2368 100644 --- a/arc/mapping/engine_test.py +++ b/arc/mapping/engine_test.py @@ -552,26 +552,6 @@ def test_assign_labels_to_products(self): self.assertEqual(self.p_label_dict_rxn_1[atom.label], index) index+=1 - def test_prepare_reactants_and_products_for_scissors(self): - rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2]) - rxn_1_test.determine_family(self.db) - assign_labels_to_products(rxn_1_test, self.p_label_dict_rxn_1) - reactants, products, loc_r, loc_p = prepare_reactants_and_products_for_scissors(rxn_1_test, - self.r_label_dict_rxn_1, - self.p_label_dict_rxn_1) - - for reactant in reactants: - self.assertIn(reactant.mol.smiles,["CC(C)F", "[CH3]"]) - - for product in products: - self.assertIn(product.mol.smiles,["C[CH]C", "CF"]) - - self.assertEqual(loc_r, [1,0]) - self.assertEqual(loc_p, [0,1]) - - self.assertEqual(reactants[0].bdes,[(1+1,3+1)]) - self.assertEqual(products[1].bdes, [(11+1-products[0].number_of_atoms, 10+1-products[0].number_of_atoms)]) - def test_inc_vals(self): """Test creating an atom map via map_two_species() and incrementing all values""" spc1 = ARCSpecies(label='CH4', smiles='C', xyz=self.ch4_xyz) @@ -583,9 +563,9 @@ def test_label_species_atoms(self): rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2]) rxn_1_test.determine_family(self.db) assign_labels_to_products(rxn_1_test, self.p_label_dict_rxn_1) - reactants, products, loc_r, loc_p = prepare_reactants_and_products_for_scissors(rxn_1_test, - self.r_label_dict_rxn_1, - self.p_label_dict_rxn_1) + + reactants, products = copy_species_list_for_mapping(rxn_1_test.r_species), copy_species_list_for_mapping(rxn_1_test.p_species) + label_species_atoms(reactants) label_species_atoms(products) @@ -601,15 +581,18 @@ def test_label_species_atoms(self): self.assertEqual(atom.label,str(index)) index +=1 - def test_cut_species_for_mapping(self): + def test_cut_species_based_on_atom_indices(self): """test the cut_species_for_mapping function""" rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2]) rxn_1_test.determine_family(self.db) - reactants, products, loc_r, loc_p = prepare_reactants_and_products_for_scissors(rxn_1_test, - self.r_label_dict_rxn_1, - self.p_label_dict_rxn_1) - r_cuts = cut_species_for_mapping(reactants, loc_r) - p_cuts = cut_species_for_mapping(products, loc_p) + reactants, products = copy_species_list_for_mapping(rxn_1_test.r_species), copy_species_list_for_mapping(rxn_1_test.p_species) + label_species_atoms(reactants), label_species_atoms(products) + + r_bdes, p_bdes = find_all_bdes(rxn_1_test, self.r_label_dict_rxn_1, True), find_all_bdes(rxn_1_test, self.p_label_dict_rxn_1, False) + + r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes) + p_cuts = cut_species_based_on_atom_indices(products, p_bdes) + self.assertIn("C[CH]C", [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts]) self.assertIn("[F]", [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts]) @@ -621,34 +604,35 @@ def test_cut_species_for_mapping(self): spc = ARCSpecies(label="test", smiles="CNC", bdes = [(1, 2), (2, 3)]) for i, a in enumerate(spc.mol.atoms): a.label=str(i) - cuts = cut_species_for_mapping([spc], [2]) + cuts = cut_species_based_on_atom_indices([spc], [(1, 2), (2, 3)]) self.assertEqual(len(cuts), 3) for cut in cuts: self.assertTrue(any([cut.mol.copy(deep=True).is_isomorphic(ARCSpecies(label="1", smiles="[CH3]").mol), cut.mol.copy(deep=True).is_isomorphic(ARCSpecies(label="2", smiles="[NH]").mol)])) - cuts = cut_species_for_mapping([ARCSpecies(label="H2", smiles="[H][H]", bdes=[(1, 2)])], [1]) + h2 = ARCSpecies(label="H2", smiles="[H][H]") + label_species_atoms([h2]) + + cuts = cut_species_based_on_atom_indices([h2], [(1, 2)]) self.assertEqual(len(cuts), 2) for cut in cuts: self.assertEqual(cut.get_xyz()["symbols"], ('H',)) - - def test_multiple_cut_on_species(self): - """test the multiple_cut_on_species function""" - spc = ARCSpecies(label="test", smiles="NCN", bdes = [(1, 2), (2, 3)]) - for i, a in enumerate(spc.mol.atoms): - a.label=str(i) - spc.final_xyz = spc.get_xyz() - cuts = multiple_cut_on_species(spc, spc.bdes) - for cut in cuts: - self.assertTrue(any([cut.mol.copy(deep=True).is_isomorphic(ARCSpecies(label="1", smiles="[CH2]").mol), - cut.mol.copy(deep=True).is_isomorphic(ARCSpecies(label="2", smiles="[NH2]").mol)])) + + spcs = [ARCSpecies(label="r", smiles = 'O=C(O)CCF')] + label_species_atoms(spcs) + cuts = cut_species_based_on_atom_indices(spcs, [(6, 5), (4, 2), (3, 7)]) + self.assertEqual(len(cuts), 4) def test_r_cut_p_cut_isomorphic(self): rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2]) rxn_1_test.determine_family(self.db) - reactants, products, loc_r, loc_p = prepare_reactants_and_products_for_scissors(rxn_1_test,self.r_label_dict_rxn_1,self.p_label_dict_rxn_1) - r_cuts = cut_species_for_mapping(reactants, loc_r) - p_cuts = cut_species_for_mapping(products, loc_p) + reactants, products = copy_species_list_for_mapping(rxn_1_test.r_species), copy_species_list_for_mapping(rxn_1_test.p_species) + label_species_atoms(reactants), label_species_atoms(products) + + r_bdes, p_bdes = find_all_bdes(rxn_1_test, self.r_label_dict_rxn_1, True), find_all_bdes(rxn_1_test, self.p_label_dict_rxn_1, False) + + r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes) + p_cuts = cut_species_based_on_atom_indices(products, p_bdes) self.assertTrue(r_cut_p_cut_isomorphic(self.spc1,self.spc2)) for r_cut in r_cuts: @@ -664,11 +648,14 @@ def test_pairing_reactants_and_products_for_mapping(self): smiles = ["[F]", "C[CH]C", "[CH3]"] rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2]) rxn_1_test.determine_family(self.db) - reactants, products, loc_r, loc_p = prepare_reactants_and_products_for_scissors(rxn_1_test, - self.r_label_dict_rxn_1, - self.p_label_dict_rxn_1) - r_cuts = cut_species_for_mapping(reactants, loc_r) - p_cuts = cut_species_for_mapping(products, loc_p) + reactants, products = copy_species_list_for_mapping(rxn_1_test.r_species), copy_species_list_for_mapping(rxn_1_test.p_species) + label_species_atoms(reactants), label_species_atoms(products) + + r_bdes, p_bdes = find_all_bdes(rxn_1_test, self.r_label_dict_rxn_1, True), find_all_bdes(rxn_1_test, self.p_label_dict_rxn_1, False) + + r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes) + p_cuts = cut_species_based_on_atom_indices(products, p_bdes) + pairs_of_reactant_and_products = pairing_reactants_and_products_for_mapping(r_cuts, p_cuts) for pair in pairs_of_reactant_and_products: @@ -682,28 +669,29 @@ def test_map_pairs(self): rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=rxn_1_test, backend="ARC") r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=rxn_1_test, rmg_reaction=rmg_reactions[0]) - assign_labels_to_products(rxn_1_test, p_label_dict) - reactants, products,loc_r,loc_p = prepare_reactants_and_products_for_scissors(rxn_1_test, r_label_dict, p_label_dict) - label_species_atoms(reactants) - label_species_atoms(products) - r_cuts = cut_species_for_mapping(reactants, loc_r) - p_cuts = cut_species_for_mapping(products, loc_p) + reactants, products = copy_species_list_for_mapping(rxn_1_test.r_species), copy_species_list_for_mapping(rxn_1_test.p_species) + label_species_atoms(reactants), label_species_atoms(products) + + r_bdes, p_bdes = find_all_bdes(rxn_1_test, r_label_dict, True), find_all_bdes(rxn_1_test, p_label_dict, False) + + r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes) + p_cuts = cut_species_based_on_atom_indices(products, p_bdes) pairs_of_reactant_and_products = pairing_reactants_and_products_for_mapping(r_cuts, p_cuts) maps = map_pairs(pairs_of_reactant_and_products) - for Map in maps: - if len(Map) == 1: - self.assertEqual(Map[0], 0) - elif len(Map) == 4: - self.assertEqual(Map[0], 0) - for i in Map[1:]: + for map_ in maps: + if len(map_) == 1: + self.assertEqual(map_[0], 0) + elif len(map_) == 4: + self.assertEqual(map_[0], 0) + for i in map_[1:]: self.assertIn(i, [1, 2, 3]) - self.assertEqual(len(np.unique(Map)), len(Map)) + self.assertEqual(len(np.unique(map_)), len(map_)) else: - self.assertEqual(Map[:3], [0, 1, 2]) - self.assertIn(tuple(Map[3:6]), list(itertools.permutations([3, 4, 5]))) - self.assertEqual(Map[6], 6) - self.assertIn(tuple(Map[7:]), list(itertools.permutations([7, 8, 9]))) - self.assertEqual(len(np.unique(Map)), len(Map)) + self.assertEqual(map_[:3], [0, 1, 2]) + self.assertIn(tuple(map_[3:6]), list(itertools.permutations([3, 4, 5]))) + self.assertEqual(map_[6], 6) + self.assertIn(tuple(map_[7:]), list(itertools.permutations([7, 8, 9]))) + self.assertEqual(len(np.unique(map_)), len(map_)) def test_glue_maps(self): rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2]) @@ -712,11 +700,13 @@ def test_glue_maps(self): r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=rxn_1_test, rmg_reaction=rmg_reactions[0]) assign_labels_to_products(rxn_1_test, p_label_dict) - reactants, products,loc_r,loc_p = prepare_reactants_and_products_for_scissors(rxn_1_test, r_label_dict, p_label_dict) - label_species_atoms(reactants) - label_species_atoms(products) - r_cuts = cut_species_for_mapping(reactants, loc_r) - p_cuts = cut_species_for_mapping(products, loc_p) + reactants, products = copy_species_list_for_mapping(rxn_1_test.r_species), copy_species_list_for_mapping(rxn_1_test.p_species) + label_species_atoms(reactants), label_species_atoms(products) + + r_bdes, p_bdes = find_all_bdes(rxn_1_test, r_label_dict, True), find_all_bdes(rxn_1_test, p_label_dict, False) + + r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes) + p_cuts = cut_species_based_on_atom_indices(products, p_bdes) pairs_of_reactant_and_products = pairing_reactants_and_products_for_mapping(r_cuts, p_cuts) maps = map_pairs(pairs_of_reactant_and_products) atom_map = glue_maps(maps,pairs_of_reactant_and_products) @@ -1410,9 +1400,10 @@ def test_make_bond_changes(self): r_label_dict, p_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=rxn, rmg_reaction=rmg_reactions[0]) assign_labels_to_products(rxn, p_label_dict) - reactants, _, loc_r, _ = prepare_reactants_and_products_for_scissors(rxn, r_label_dict, p_label_dict) - label_species_atoms(reactants) - r_cuts = cut_species_for_mapping(reactants, loc_r) + reactants, products = copy_species_list_for_mapping(rxn.r_species), copy_species_list_for_mapping(rxn.p_species) + label_species_atoms(reactants), label_species_atoms(products) + r_bdes, _ = find_all_bdes(rxn, r_label_dict, True), find_all_bdes(rxn, p_label_dict, False) + r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes) self.assertFalse(r_cuts[1].mol.is_isomorphic(rxn.p_species[1].mol)) make_bond_changes(rxn=rxn, r_cuts=r_cuts, @@ -1444,25 +1435,6 @@ def test_update_xyz(self): atoms = [atom.element.symbol for atom in spc.mol.atoms] for label1,label2 in zip(atoms, xyz): self.assertEqual(label1, label2) - - def test_cuts_on_cycle_of_labeled_mol(self): - """test the cuts_on_cycle_of_labeled_mol function""" - spc1 = ARCSpecies(label = "A", smiles="NC1=NC=NC2=C1N=CN2", bdes = [(6, 7)]) - try: - cuts_on_cycle_of_labeled_mol(spc1) - except ValueError as e: - self.assertEqual(e.args[0], "cuts_on_cycle_of_labeled_mol recives labeled ARCSpecies only, got an unlabeld species") - for index, atom in enumerate(spc1.mol.atoms): - atom.label = str(index) - self.assertTrue(cuts_on_cycle_of_labeled_mol(spc1)) - spc1.bdes = [(1, 2)] - self.assertFalse(cuts_on_cycle_of_labeled_mol(spc1)) - spc1.bdes = [(1, 2), (6, 7)] - self.assertTrue(cuts_on_cycle_of_labeled_mol(spc1)) - spc2 = ARCSpecies(label = "propane", smiles = "CCC",bdes = [(1, 2)]) - for index, atom in enumerate(spc2.mol.atoms): - atom.label = str(index) - self.assertFalse(cuts_on_cycle_of_labeled_mol(spc2)) def test_add_adjacent_hydrogen_atoms_to_map_based_on_a_specific_torsion(self): "test the add_adjacent_hydrogen_atoms_to_map_based_on_a_specific_torsion function" @@ -1502,5 +1474,31 @@ def test_add_adjacent_hydrogen_atoms_to_map_based_on_a_specific_torsion(self): for key in range(4): self.assertEqual(out_dict[key], key) + def test_find_all_bdes(self): + """tests the find_all_bdes function""" + rxn = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2]) + rxn.determine_family(self.db) + bdes = find_all_bdes(rxn, self.r_label_dict_rxn_1, True) + self.assertEqual(bdes, [(2, 4)]) + + def test_copy_species_list_for_mapping(self): + """tests the copy_species_list_for_mapping function""" + species = [ARCSpecies(label="test_1", smiles = "BrC(F)Cl"), ARCSpecies(label="test_2", smiles = "OOC(F)CCCONNO")] + label_species_atoms(species) + species_copy = copy_species_list_for_mapping(species) + for s1, s2 in zip(species, species_copy): + self.assertIsNot(s1, s2) + self.assertTrue(s1.mol.is_isomorphic(s2.mol)) + for atom1, atom2 in zip(s1.mol.atoms, s2.mol.atoms): + self.assertIsNot(atom1, atom2) + self.assertEqual(atom1.label, atom2.label) + + def test_determine_bdes_on_spc_based_on_atom_labels(self): + """tests the determine_bdes_indices_based_on_atom_labels function""" + spc = ARCSpecies(label="ethane", smiles="CC") + label_species_atoms([spc]) + self.assertTrue(determine_bdes_on_spc_based_on_atom_labels(spc, (1, 2))) + self.assertFalse(determine_bdes_on_spc_based_on_atom_labels(spc, (2, 3))) + if __name__ == '__main__': unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) \ No newline at end of file diff --git a/arc/reaction_test.py b/arc/reaction_test.py index b95ea99c80..93753ac62e 100644 --- a/arc/reaction_test.py +++ b/arc/reaction_test.py @@ -35,7 +35,9 @@ def setUpClass(cls): """ cls.maxDiff = None cls.rmgdb = rmgdb.make_rmg_database_object() - rmgdb.load_families_only(cls.rmgdb) + cls.rmgdb.kinetics.families = None + rmgdb.load_families_only(cls.rmgdb, "all") + cls.h2_xyz = {'coords': ((0, 0, 0.3736550), (0, 0, -0.3736550)), 'isotopes': (1, 1), 'symbols': ('H', 'H')} cls.o2_xyz = {'coords': ((0, 0, 0.6487420), (0, 0, -0.6487420)), 'isotopes': (16, 16), 'symbols': ('O', 'O')} cls.co_xyz = {'coords': ((0, 0, -0.6748240), (0, 0, 0.5061180)), 'isotopes': (12, 16), 'symbols': ('C', 'O')} @@ -1645,6 +1647,24 @@ def test_get_atom_map(self): self.assertIn(atom_map[int(atom.label)], c_symmetry_h_2 if atom.symbol == "C" else h_symmetry2) self.assertTrue(check_atom_map(rxn=rxn)) + def test_mapping_XY_elimination_hydroxyl(self): + """Test mapping of a reaction family XY_elimination_hydroxyl""" + rxn = ARCReaction(r_species=[ARCSpecies(label="r", smiles = 'O=C(O)CCF')], + p_species=[ARCSpecies(label="p1", smiles='C=C'), + ARCSpecies(label="p2", smiles="F"), + ARCSpecies(label="p3", smiles="O=C=O")]) + rxn.determine_family(rmg_database=self.rmgdb) + if not rxn.family: # reaction family not found for some reason. + rxn.family = self.rmgdb.kinetics.families["XY_elimination_hydroxyl"] + atom_map = rxn.atom_map + self.assertIsNotNone(rxn.family) + self.assertTrue(check_atom_map(rxn=rxn)) + self.assertIn(atom_map[:3], [[8, 9, 10], [10, 9, 8]]) + self.assertIn(atom_map[3:5], [[0, 1], [1, 0]]) + self.assertEqual(atom_map[5:7], [6, 7]) + self.assertIn(atom_map[7:9], [[4, 5], [5, 4]]) + self.assertIn(atom_map[9:], [[2, 3], [3, 2]]) + def test_get_reactants_xyz(self): """Test getting a combined string/dict representation of the cartesian coordinates of all reactant species""" ch3nh2_xyz = {'coords': ((-0.5734111454228507, 0.0203516083213337, 0.03088703933770556),