Skip to content

Commit

Permalink
Update Crest Loop
Browse files Browse the repository at this point in the history
Loop through all the different heuristic TSs to generate new TSs via Crest
  • Loading branch information
calvinp0 committed Dec 15, 2024
1 parent 04755cf commit af9313c
Showing 1 changed file with 129 additions and 81 deletions.
210 changes: 129 additions & 81 deletions arc/job/adapters/ts/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import subprocess
import os
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import re

import numpy as np
import pandas as pd

from rmgpy.exceptions import ActionError
from rmgpy.molecule.molecule import Molecule
Expand All @@ -36,7 +38,15 @@
from arc.job.adapters.common import _initialize_adapter, ts_adapters_by_rmg_family
from arc.job.factory import register_job_adapter
from arc.plotter import save_geo
from arc.species.converter import compare_zmats, relocate_zmat_dummy_atoms_to_the_end, zmat_from_xyz, zmat_to_xyz, str_to_xyz, xyz_to_str, str_to_str
from arc.species.converter import (compare_zmats,
relocate_zmat_dummy_atoms_to_the_end,
zmat_from_xyz,
zmat_to_xyz,
str_to_xyz,
xyz_to_str,
str_to_str,
xyz_to_str,
xyz_to_dmat)
from arc.mapping.engine import map_arc_rmg_species, map_two_species
from arc.species.species import ARCSpecies, TSGuess, colliding_atoms
from arc.species.zmat import get_parameter_from_atom_indices, remove_1st_atom, up_param
Expand Down Expand Up @@ -1050,91 +1060,35 @@ def h_abstraction(arc_reaction: 'ARCReaction',
xyz_guesses.append(xyz_guess)

if xyz_guesses:
# Take the first guess from the list of unique guesses.
xyz_guesses_crest = xyz_guesses[0]

####
from arc.species.converter import xyz_to_dmat
import pandas as pd
import re
# 1. Convert xyz to dmat
ts_dmat = xyz_to_dmat(xyz_guesses_crest)
# 2. Create DataFrame & append ints to columns/index symbols
ts_df = pd.DataFrame(ts_dmat, index=xyz_guesses_crest["symbols"], columns=xyz_guesses_crest["symbols"])
org_labels = list(ts_df.columns)
row_label_mapping = [str(str(label) + str(i)) for i, label in enumerate(org_labels)]
column_label_mapping = [str(str(label) + str(i)) for i, label in enumerate(org_labels)]
ts_df.columns = column_label_mapping
ts_df.index = row_label_mapping
# 3. Filter Index(H), Column(~H)
columns_mask = ~ts_df.columns.str.startswith('H')
columns_to_remove = ts_df.columns[columns_mask]

rows_mask = ts_df.columns.str.startswith('H')
rows_to_keep = ts_df.index[rows_mask]

ts_df_filt = ts_df.loc[rows_to_keep, columns_to_remove]

# 4. Get min values per H
min_values_per_H = ts_df_filt.min(axis=1)
# 5. Get max value H
max_min_H_values = min_values_per_H.max()

max_H_rows = min_values_per_H[min_values_per_H == max_min_H_values].index.tolist()

row_col_pairs = [
(h_row, col)
for h_row in max_H_rows
for col in ts_df_filt.columns[ts_df_filt.loc[h_row] == min_values_per_H[h_row]].tolist()
]
print("Row/Col of the lowest value: ", row_col_pairs)
h_row = ts_df_filt.loc[row_col_pairs[0][0]]
unique_sorted_values = h_row.sort_values().unique()

crest_run = True

if len(unique_sorted_values) >= 2:
second_lowest = unique_sorted_values[1]
print(f"The second lowest unique value in H21 is: {second_lowest}")
cols_second_lowest = h_row[h_row == second_lowest].index.tolist()

if len(cols_second_lowest) == 1:
for i, xyz_guess_crest in enumerate(xyz_guesses):
# 1. Check if dict
if isinstance(xyz_guess_crest, dict):
df_dmat = convert_xyz_to_df(xyz_guess_crest)
elif isinstance(xyz_guess_crest, str):
xyz = str_to_xyz(xyz_guess_crest)
df_dmat = convert_xyz_to_df(xyz)
elif isinstance(xyz_guess_crest, list):
xyz_temp = "\n".join(xyz_guess_crest)
xyz_to_dmat = str_to_xyz(xyz_temp)
df_dmat = convert_xyz_to_df(xyz_to_dmat)

h_str = row_col_pairs[0][0] # 'H21'
print(f"h str = {h_str}")
b_str = row_col_pairs[0][1] # 'C14'
print(f"b str {b_str}")
a_str = cols_second_lowest[0] # 'C4'
print(f"a str {a_str}")

h = int(re.findall(r'\d+', h_str)[0])
b = int(re.findall(r'\d+', b_str)[0])
a = int(re.findall(r'\d+', a_str)[0])

# log info a, b, h1, h2, b_atom, a_atom, h_atom, val_inc
logger.info(f'a: {a}, b: {b}, h: {h}')
else:
crest_run = False
print(f"Received more than one result for second lowest: {cols_second_lowest}")
else:
crest_run = False
print("H21 does not have a second lowest unique value. Will not do CREST")



####

if crest_run:
xyz_guess = crest_ts_conformer_search(xyz_guesses_crest, a, h, b, path=path)
if xyz_guess is not None:
logger.info('Successfully completed crest conformer search:'
f' {xyz_to_str(xyz_guess)}')
xyz_guesses.append(xyz_guess)
try:
h_abs_atoms_dict = get_h_abs_atoms(df_dmat)
a = h_abs_atoms_dict['A']
h = h_abs_atoms_dict['H']
b = h_abs_atoms_dict['B']
xyz_guess = crest_ts_conformer_search(xyz_guess_crest, a, h, b, path=path, xyz_crest_int=i)
if xyz_guess is not None:
logger.info('Successfully completed crest conformer search:'
f' {xyz_to_str(xyz_guess)}')
xyz_guesses.append(xyz_guess)
except (ValueError or KeyError) as e:
logger.error(f'Could not determine the H abstraction atoms, got:\n{e}')

return xyz_guesses


def crest_ts_conformer_search(xyz_guess: dict, a_atom: int, h_atom: int, b_atom: int, path: str = ''
def crest_ts_conformer_search(xyz_guess: dict, a_atom: int, h_atom: int, b_atom: int, path: str = '', xyz_crest_int: int = 0
) -> None:
"""
Perform a conformer search for the TS guess using CREST.
Expand Down Expand Up @@ -1223,5 +1177,99 @@ def crest_ts_conformer_search(xyz_guess: dict, a_atom: int, h_atom: int, b_atom:
print(f"Standard Error: {stderr.decode()}")
return None

def convert_xyz_to_df(xyz: dict) -> pd.DataFrame:
"""
Convert a dictionary of xyz coords to a pandas DataFrame with bond distances
Args:
xyz (dict): The xyz coordinates of the molecule
Return:
pd.DataFrame: The xyz coordinates as a pandas DataFrame
"""
symbols = xyz["symbols"]
symbol_enum = [f"{symbol}{i}" for i, symbol in enumerate(symbols)]
ts_dmat = xyz_to_dmat(xyz)

return pd.DataFrame(ts_dmat, columns=symbol_enum, index=symbol_enum)


def get_h_abs_atoms(dataframe: pd.DataFrame) -> dict:
"""
Get the donating/accepting hydrogen atom, and the two heavy atoms that are bonded to it
Args:
dataframe (pd.DataFrame): The dataframe of the bond distances, columns and index are the atom symbols
Returns:
dict: The hydrogen atom and the two heavy atoms. The keys are 'H', 'A', 'B'
"""
# Ensure there are at least 3 atoms in the TS
if len(dataframe) < 3:
raise ValueError("TS must contain at least 3 atoms.")
if len(dataframe) == 3 and dataframe.index.str.startswith("H").sum() == 2:
# Identify the heavy atom
heavy_atom = dataframe.index[~dataframe.index.str.startswith("H")][0] # Should be the only heavy atom
hydrogen_atoms = dataframe.index[dataframe.index.str.startswith("H")] # List of hydrogen atoms

# Get distances from the heavy atom to the two hydrogens
distances_to_hydrogens = dataframe.loc[heavy_atom, hydrogen_atoms]

# Select the hydrogen with the smallest distance to the heavy atom as `H`
hydrogen_with_min_distance = distances_to_hydrogens.idxmin()

# The other hydrogen becomes `B`
other_hydrogen = hydrogen_atoms[hydrogen_atoms != hydrogen_with_min_distance][0]

return {"H": hydrogen_with_min_distance, "A": heavy_atom, "B": other_hydrogen}

elif len(dataframe) == 4 and dataframe.index.str.startswith("H").sum() == 3:
# Identify the heavy atom
heavy_atom = dataframe.index[~dataframe.index.str.startswith("H")][0] # Should be the only heavy atom
hydrogen_atoms = dataframe.index[dataframe.index.str.startswith("H")] # List of hydrogen atoms

# Remove hydrogens from columns and the heavy atom from rows
filtered_df = dataframe.loc[hydrogen_atoms, [heavy_atom]]

# Sort the distances from the heavy atom to all hydrogens
sorted_distances = filtered_df[heavy_atom].sort_values()

# Select the hydrogen with the second furthest distance
hydrogen_with_max_distance = sorted_distances.index[-2]

# Reset the DataFrame back to the original to find the other hydrogen (`B`)
remaining_hydrogens = hydrogen_atoms[hydrogen_atoms != hydrogen_with_max_distance]
filtered_hydrogens_df = dataframe.loc[[hydrogen_with_max_distance], remaining_hydrogens]

# Find the hydrogen closest to the selected hydrogen (`H`)
closest_hydrogen = filtered_hydrogens_df.idxmin(axis=1).iloc[0]

return {"H": hydrogen_with_max_distance, "A": heavy_atom, "B": closest_hydrogen}

else:

# Filter the DataFrame for hydrogen rows and non-hydrogen columns
hydrogen_rows = dataframe.index[dataframe.index.str.startswith("H")]
heavy_atom_columns = dataframe.columns[~dataframe.columns.str.startswith("H")]

filtered_df = dataframe.loc[hydrogen_rows, heavy_atom_columns]

# Find the hydrogen atom with the smallest bond distance to a heavy atom
min_distances = filtered_df.min(axis=1)
min_distances = min_distances[min_distances <= 2.0]
hydrogen_with_min_distance = min_distances.idxmax()
min_distance_column = filtered_df.loc[hydrogen_with_min_distance].idxmin()

# Handle cases with multiple heavy atoms
remaining_columns = dataframe.columns[
~dataframe.columns.isin([hydrogen_with_min_distance, min_distance_column])
]
remaining_df = dataframe.loc[[hydrogen_with_min_distance], remaining_columns]
second_closest_atom = remaining_df.idxmin(axis=1).iloc[0]

return {"H": hydrogen_with_min_distance, "A": min_distance_column, "B": second_closest_atom}


register_job_adapter('heuristics', HeuristicsAdapter)

0 comments on commit af9313c

Please sign in to comment.