Skip to content

Commit

Permalink
Merge branch 'crest' into crest_debug_statement
Browse files Browse the repository at this point in the history
  • Loading branch information
calvinp0 committed Dec 15, 2024
2 parents b7eb9f1 + af9313c commit 789c87e
Showing 1 changed file with 129 additions and 81 deletions.
210 changes: 129 additions & 81 deletions arc/job/adapters/ts/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import subprocess
import os
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import re

import numpy as np
import pandas as pd

from rmgpy.exceptions import ActionError
from rmgpy.molecule.molecule import Molecule
Expand All @@ -36,7 +38,15 @@
from arc.job.adapters.common import _initialize_adapter, ts_adapters_by_rmg_family
from arc.job.factory import register_job_adapter
from arc.plotter import save_geo
from arc.species.converter import compare_zmats, relocate_zmat_dummy_atoms_to_the_end, zmat_from_xyz, zmat_to_xyz, str_to_xyz, xyz_to_str, str_to_str
from arc.species.converter import (compare_zmats,
relocate_zmat_dummy_atoms_to_the_end,
zmat_from_xyz,
zmat_to_xyz,
str_to_xyz,
xyz_to_str,
str_to_str,
xyz_to_str,
xyz_to_dmat)
from arc.mapping.engine import map_arc_rmg_species, map_two_species
from arc.species.species import ARCSpecies, TSGuess, colliding_atoms
from arc.species.zmat import get_parameter_from_atom_indices, remove_1st_atom, up_param
Expand Down Expand Up @@ -1050,91 +1060,35 @@ def h_abstraction(arc_reaction: 'ARCReaction',
xyz_guesses.append(xyz_guess)

if xyz_guesses:
# Take the first guess from the list of unique guesses.
xyz_guesses_crest = xyz_guesses[0]

####
from arc.species.converter import xyz_to_dmat
import pandas as pd
import re
# 1. Convert xyz to dmat
ts_dmat = xyz_to_dmat(xyz_guesses_crest)
# 2. Create DataFrame & append ints to columns/index symbols
ts_df = pd.DataFrame(ts_dmat, index=xyz_guesses_crest["symbols"], columns=xyz_guesses_crest["symbols"])
org_labels = list(ts_df.columns)
row_label_mapping = [str(str(label) + str(i)) for i, label in enumerate(org_labels)]
column_label_mapping = [str(str(label) + str(i)) for i, label in enumerate(org_labels)]
ts_df.columns = column_label_mapping
ts_df.index = row_label_mapping
# 3. Filter Index(H), Column(~H)
columns_mask = ~ts_df.columns.str.startswith('H')
columns_to_remove = ts_df.columns[columns_mask]

rows_mask = ts_df.columns.str.startswith('H')
rows_to_keep = ts_df.index[rows_mask]

ts_df_filt = ts_df.loc[rows_to_keep, columns_to_remove]

# 4. Get min values per H
min_values_per_H = ts_df_filt.min(axis=1)
# 5. Get max value H
max_min_H_values = min_values_per_H.max()

max_H_rows = min_values_per_H[min_values_per_H == max_min_H_values].index.tolist()

row_col_pairs = [
(h_row, col)
for h_row in max_H_rows
for col in ts_df_filt.columns[ts_df_filt.loc[h_row] == min_values_per_H[h_row]].tolist()
]
print("Row/Col of the lowest value: ", row_col_pairs)
h_row = ts_df_filt.loc[row_col_pairs[0][0]]
unique_sorted_values = h_row.sort_values().unique()

crest_run = True

if len(unique_sorted_values) >= 2:
second_lowest = unique_sorted_values[1]
print(f"The second lowest unique value in H21 is: {second_lowest}")
cols_second_lowest = h_row[h_row == second_lowest].index.tolist()

if len(cols_second_lowest) == 1:
for i, xyz_guess_crest in enumerate(xyz_guesses):
# 1. Check if dict
if isinstance(xyz_guess_crest, dict):
df_dmat = convert_xyz_to_df(xyz_guess_crest)
elif isinstance(xyz_guess_crest, str):
xyz = str_to_xyz(xyz_guess_crest)
df_dmat = convert_xyz_to_df(xyz)
elif isinstance(xyz_guess_crest, list):
xyz_temp = "\n".join(xyz_guess_crest)
xyz_to_dmat = str_to_xyz(xyz_temp)
df_dmat = convert_xyz_to_df(xyz_to_dmat)

h_str = row_col_pairs[0][0] # 'H21'
print(f"h str = {h_str}")
b_str = row_col_pairs[0][1] # 'C14'
print(f"b str {b_str}")
a_str = cols_second_lowest[0] # 'C4'
print(f"a str {a_str}")

h = int(re.findall(r'\d+', h_str)[0])
b = int(re.findall(r'\d+', b_str)[0])
a = int(re.findall(r'\d+', a_str)[0])

# log info a, b, h1, h2, b_atom, a_atom, h_atom, val_inc
logger.info(f'a: {a}, b: {b}, h: {h}')
else:
crest_run = False
print(f"Received more than one result for second lowest: {cols_second_lowest}")
else:
crest_run = False
print("H21 does not have a second lowest unique value. Will not do CREST")



####

if crest_run:
xyz_guess = crest_ts_conformer_search(xyz_guesses_crest, a, h, b, path=path)
if xyz_guess is not None:
logger.info('Successfully completed crest conformer search:'
f' {xyz_to_str(xyz_guess)}')
xyz_guesses.append(xyz_guess)
try:
h_abs_atoms_dict = get_h_abs_atoms(df_dmat)
a = h_abs_atoms_dict['A']
h = h_abs_atoms_dict['H']
b = h_abs_atoms_dict['B']
xyz_guess = crest_ts_conformer_search(xyz_guess_crest, a, h, b, path=path, xyz_crest_int=i)
if xyz_guess is not None:
logger.info('Successfully completed crest conformer search:'
f' {xyz_to_str(xyz_guess)}')
xyz_guesses.append(xyz_guess)
except (ValueError or KeyError) as e:
logger.error(f'Could not determine the H abstraction atoms, got:\n{e}')

return xyz_guesses


def crest_ts_conformer_search(xyz_guess: dict, a_atom: int, h_atom: int, b_atom: int, path: str = ''
def crest_ts_conformer_search(xyz_guess: dict, a_atom: int, h_atom: int, b_atom: int, path: str = '', xyz_crest_int: int = 0
) -> None:
"""
Perform a conformer search for the TS guess using CREST.
Expand Down Expand Up @@ -1223,5 +1177,99 @@ def crest_ts_conformer_search(xyz_guess: dict, a_atom: int, h_atom: int, b_atom:
print(f"Standard Error: {stderr.decode()}")
return None

def convert_xyz_to_df(xyz: dict) -> pd.DataFrame:
    """
    Build a labeled distance-matrix DataFrame from an xyz dictionary.

    Each atom gets a label made of its element symbol followed by its position
    in ``xyz['symbols']`` (e.g., the first atom, if a carbon, becomes ``'C0'``).
    The same labels are used for both the rows and the columns, and the cell
    values are whatever ``xyz_to_dmat()`` returns for ``xyz``.

    Args:
        xyz (dict): The xyz coordinates of the molecule, must contain a ``'symbols'`` entry.

    Returns:
        pd.DataFrame: The distance matrix, with rows and columns labeled ``<symbol><atom index>``.
    """
    atom_labels = list()
    for atom_index, element in enumerate(xyz["symbols"]):
        atom_labels.append(f"{element}{atom_index}")
    distance_matrix = xyz_to_dmat(xyz)
    labeled_dmat = pd.DataFrame(data=distance_matrix, index=atom_labels, columns=atom_labels)
    return labeled_dmat


def get_h_abs_atoms(dataframe: pd.DataFrame) -> dict:
    """
    Identify the migrating hydrogen atom of an H-abstraction TS guess and its two closest partners.

    The input is a square distance table whose row and column labels are atom labels
    of the form ``<element symbol><atom index>`` (e.g., ``'H21'``, ``'C4'``),
    such as the one generated by ``convert_xyz_to_df()``.

    Three cases are treated:

    - Three atoms, exactly two of them hydrogens: ``A`` is the heavy atom, ``H`` is the hydrogen
      closest to it, and ``B`` is the other hydrogen.
    - Four atoms, exactly three of them hydrogens: ``A`` is the heavy atom, ``H`` is the hydrogen
      with the middle (second largest) distance to it, and ``B`` is the hydrogen closest to ``H``.
    - Any other composition: for each hydrogen the shortest distance to a non-hydrogen atom is
      computed, hydrogens for which it exceeds 2.0 are discarded, and ``H`` is the remaining
      hydrogen with the *largest* such distance (the most stretched bond). ``A`` is the
      non-hydrogen atom closest to ``H``, and ``B`` is the closest atom of any element to ``H``
      other than ``H`` itself and ``A``.

    Args:
        dataframe (pd.DataFrame): The distance table; index and columns are the atom labels.

    Raises:
        ValueError: If there are fewer than three atoms, or if no hydrogen atom
                    with a non-hydrogen neighbor within the cutoff could be found.

    Returns:
        dict: The labels of the selected atoms. Keys are 'H', 'A', and 'B'.
    """
    # A hydrogen label is 'H' optionally followed by digits only.
    # A plain ``startswith('H')`` test would also pick up He, Hg, Hf, Ho, and Hs.
    hydrogen_pattern = r"H\d*$"
    # Cutoff for considering a hydrogen as bonded to a heavy atom
    # (same units as the distance table, presumably Angstrom).
    max_h_heavy_distance = 2.0

    n_atoms = len(dataframe)
    if n_atoms < 3:
        raise ValueError("TS must contain at least 3 atoms.")

    is_h_row = dataframe.index.str.match(hydrogen_pattern)
    is_h_column = dataframe.columns.str.match(hydrogen_pattern)
    hydrogen_atoms = dataframe.index[is_h_row]
    n_hydrogens = len(hydrogen_atoms)

    if n_atoms == 3 and n_hydrogens == 2:
        # A single heavy atom with two hydrogens.
        heavy_atom = dataframe.index[~is_h_row][0]
        distances_to_heavy = dataframe.loc[heavy_atom, hydrogen_atoms]
        migrating_h = distances_to_heavy.idxmin()
        other_h = hydrogen_atoms[hydrogen_atoms != migrating_h][0]
        return {"H": migrating_h, "A": heavy_atom, "B": other_h}

    if n_atoms == 4 and n_hydrogens == 3:
        # A single heavy atom with three hydrogens.
        heavy_atom = dataframe.index[~is_h_row][0]
        distances_to_heavy = dataframe.loc[hydrogen_atoms, heavy_atom].sort_values()
        # The middle hydrogen: neither the closest nor the farthest from the heavy atom.
        migrating_h = distances_to_heavy.index[-2]
        other_hydrogens = hydrogen_atoms[hydrogen_atoms != migrating_h]
        closest_h = dataframe.loc[[migrating_h], other_hydrogens].idxmin(axis=1).iloc[0]
        return {"H": migrating_h, "A": heavy_atom, "B": closest_h}

    # General case.
    heavy_atoms = dataframe.columns[~is_h_column]
    if n_hydrogens == 0 or len(heavy_atoms) == 0:
        raise ValueError("TS must contain at least one hydrogen atom and one non-hydrogen atom "
                         "to determine the H abstraction atoms.")
    h_to_heavy = dataframe.loc[hydrogen_atoms, heavy_atoms]

    shortest_per_h = h_to_heavy.min(axis=1)
    shortest_per_h = shortest_per_h[shortest_per_h <= max_h_heavy_distance]
    if shortest_per_h.empty:
        # Without this check ``idxmax()`` fails with an uninformative message.
        raise ValueError(f"No hydrogen atom was found within {max_h_heavy_distance} "
                         f"of a non-hydrogen atom.")
    # Among all hydrogens, take the one whose nearest heavy atom is the farthest.
    migrating_h = shortest_per_h.idxmax()
    closest_heavy = h_to_heavy.loc[migrating_h].idxmin()

    # The second partner is searched among all other atoms, hydrogens included.
    candidates = dataframe.columns[~dataframe.columns.isin([migrating_h, closest_heavy])]
    second_closest = dataframe.loc[[migrating_h], candidates].idxmin(axis=1).iloc[0]

    return {"H": migrating_h, "A": closest_heavy, "B": second_closest}


# Register the HeuristicsAdapter class under the 'heuristics' key via the job factory's register_job_adapter().
register_job_adapter('heuristics', HeuristicsAdapter)

0 comments on commit 789c87e

Please sign in to comment.