diff --git a/CHANGELOG.md b/CHANGELOG.md index d30bcdb..98b8c43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Changed +- to set the elemental composition it is now possible to use dicts with not only int but also the element symbols (str) +- dict keys for elemental compositions will now always be checked for validity +- Renamed GP3-xTB to g-xTB + +### Added +- `GXTBConfig` class for the g-xTB method, supporting SCF cycles check + +### Fixed +- version string is now correctly formatted and printed ## [0.5.0] - 2024-12-16 ### Changed diff --git a/README.md b/README.md index 204fd3d..ba95768 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,8 @@ def main(): config.generate.max_num_atoms = 15 config.generate.element_composition = "Ce:1-1" # alternatively as a dictionary: config.generate.element_composition = {39:(1,1)} + # or: config.generate.element_composition = {"Ce":(1,1)"} + # or as mixed-key dict, e.g. for Ce and O: {"Ce":(1,1), 7:(2,2)} config.generate.forbidden_elements = "21-30,39-48,57-80" # alternatively as a list: config.generate.forbidden_elements = [20,21,22,23] # 24,25,26... diff --git a/src/mindlessgen/cli/cli_parser.py b/src/mindlessgen/cli/cli_parser.py index 12f02e2..46bb365 100644 --- a/src/mindlessgen/cli/cli_parser.py +++ b/src/mindlessgen/cli/cli_parser.py @@ -258,6 +258,19 @@ def cli_parser(argv: Sequence[str] | None = None) -> dict: required=False, help="Maximum number of SCF cycles in ORCA.", ) + ### g-xTB specific arguments ### + parser.add_argument( + "--gxtb-path", + type=str, + required=False, + help="Path to the g-xTB binary.", + ) + parser.add_argument( + "--gxtb-scf-cycles", + type=int, + required=False, + help="Maximum number of SCF cycles in g-xTB.", + ) args = parser.parse_args(argv) args_dict = vars(args) @@ -306,6 +319,11 @@ def cli_parser(argv: Sequence[str] | None = None) -> dict: "gridsize": args_dict["orca_gridsize"], "scf_cycles": args_dict["orca_scf_cycles"], } + # g-xTB specific arguments + rev_args_dict["gxtb"] = { + "gxtb_path": args_dict["gxtb_path"], + "scf_cycles": args_dict["gxtb_scf_cycles"], + } # Postprocessing arguments rev_args_dict["postprocess"] = { "engine": args_dict["postprocess_engine"], diff --git a/src/mindlessgen/generator/main.py b/src/mindlessgen/generator/main.py index e1ec908..38e910e 100644 --- a/src/mindlessgen/generator/main.py +++ b/src/mindlessgen/generator/main.py @@ -10,7 +10,7 @@ import warnings from ..molecules import generate_random_molecule, Molecule -from ..qm import XTB, get_xtb_path, QMMethod, ORCA, get_orca_path, GP3, get_gp3_path +from ..qm import XTB, get_xtb_path, QMMethod, ORCA, get_orca_path, GXTB, get_gxtb_path from ..molecules import iterative_optimization, postprocess_mol from ..prog import ConfigManager @@ -45,12 +45,16 @@ def generator(config: ConfigManager) -> tuple[list[Molecule] | None, int]: config.refine.engine, config, get_xtb_path, - get_orca_path, # GP3 cannot be used anyway + get_orca_path, # g-xTB cannot be used anyway ) if config.general.postprocess: postprocess_engine: QMMethod | None = setup_engines( - config.postprocess.engine, config, get_xtb_path, get_orca_path, get_gp3_path + config.postprocess.engine, + config, + get_xtb_path, + get_orca_path, + get_gxtb_path, ) else: postprocess_engine = None @@ -228,22 +232,41 @@ def single_molecule_generator( if config.general.gxtb_development: # additional g-xTB engine after refinement and postprocessing for development purposes - gp3 = GP3(get_gp3_path()) - try: - _gxtb_dev_check( - optimized_molecule, - gp3, - config.general.gxtb_scf_cycles, - config.general.verbosity, - ) - except (RuntimeError, ValueError) as e: - if config.general.verbosity > 0: - print(f"g-xTB postprocessing failed for cycle {cycle + 1}.") - if config.general.verbosity > 1: - print(e) - return None - if config.general.verbosity > 1: - print("g-xTB postprocessing successful.") + gxtb = GXTB(get_gxtb_path(), config.gxtb) + if not config.general.postprocess and config.postprocess.engine == "gxtb": + try: + _gxtb_scf_check( + optimized_molecule, + gxtb, + config.general.gxtb_scf_cycles, + config.general.verbosity, + ) + except (RuntimeError, ValueError) as e: + if config.general.verbosity > 0: + print( + f"g-xTB postprocessing (SCF cycles check) failed for cycle {cycle + 1}." + ) + if config.general.verbosity > 1: + print(e) + return None + if config.general.verbosity > 1: + print("g-xTB postprocessing (SCF cycles check) successful.") + if config.general.gxtb_ipea: + try: + _gxtb_ipea_check( + optimized_molecule, + gxtb, + config.general.gxtb_scf_cycles, + config.general.verbosity, + ) + except (RuntimeError, ValueError) as e: + if config.general.verbosity > 0: + print(f"g-xTB postprocessing failed for cycle {cycle + 1}.") + if config.general.verbosity > 1: + print(e) + return None + if config.general.verbosity > 1: + print("g-xTB postprocessing successful.") if not stop_event.is_set(): stop_event.set() # Signal other processes to stop @@ -285,7 +308,7 @@ def setup_engines( cfg: ConfigManager, xtb_path_func: Callable, orca_path_func: Callable, - gp3_path_func: Callable | None = None, + gxtb_path_func: Callable | None = None, ): """ Set up the required engine. @@ -306,19 +329,19 @@ def setup_engines( except ImportError as e: raise ImportError("orca not found.") from e return ORCA(path, cfg.orca) - elif engine_type == "gp3": - if gp3_path_func is None: - raise ImportError("No callable function for determining the gp3 path.") - path = gp3_path_func() + elif engine_type == "gxtb": + if gxtb_path_func is None: + raise ImportError("No callable function for determining the g-xTB path.") + path = gxtb_path_func(cfg.gxtb.gxtb_path) if not path: - raise ImportError("'gp3' binary could not be found.") - return GP3(path) + raise ImportError("'gxtb' binary could not be found.") + return GXTB(path, cfg.gxtb) else: raise NotImplementedError("Engine not implemented.") -def _gxtb_dev_check( - mol: Molecule, gp3: GP3, scf_iter_limit: int, verbosity: int = 0 +def _gxtb_ipea_check( + mol: Molecule, gxtb: GXTB, scf_iter_limit: int, verbosity: int = 0 ) -> None: """ ONLY FOR IN-HOUSE g-xTB DEVELOPMENT PURPOSES: Check the SCF iterations of the cation and anion. @@ -333,7 +356,7 @@ def _gxtb_dev_check( + "For the g-xTB cationic calculation, we are increasing it by 1. " + "(Could be ill-defined.)" ) - gxtb_output = gp3.singlepoint(tmp_mol, verbosity=verbosity) + gxtb_output = gxtb.singlepoint(tmp_mol, verbosity=verbosity) # gp3_output looks like this: # [...] # 13 -155.03101038 0.00000000 0.00000001 16.45392733 8 F @@ -361,7 +384,33 @@ def _gxtb_dev_check( + "For the g-xTB anionic calculation, we are increasing it by 1. " + "(Could be ill-defined.)" ) - gxtb_output = gp3.singlepoint(tmp_mol, verbosity=verbosity) + gxtb_output = gxtb.singlepoint(tmp_mol, verbosity=verbosity) + # Check for the number of scf iterations + scf_iterations = 0 + for line in gxtb_output.split("\n"): + if "scf iterations" in line: + scf_iterations = int(line.split()[0]) + break + if scf_iterations == 0: + raise ValueError("SCF iterations not found in GP3 output.") + if scf_iterations > scf_iter_limit: + raise ValueError(f"SCF iterations exceeded limit of {scf_iter_limit}.") + + +def _gxtb_scf_check( + mol: Molecule, gxtb: GXTB, scf_iter_limit: int, verbosity: int = 0 +) -> None: + """ + ONLY FOR IN-HOUSE g-xTB DEVELOPMENT PURPOSES: Check the SCF iterations with g-xTB. + """ + # 1) Single point calculation with g-xTB for the cation + gxtb_output = gxtb.singlepoint(mol, verbosity=verbosity) + # gp3_output looks like this: + # [...] + # 13 -155.03101038 0.00000000 0.00000001 16.45392733 8 F + # 13 scf iterations + # eigenvalues + # [...] # Check for the number of scf iterations scf_iterations = 0 for line in gxtb_output.split("\n"): diff --git a/src/mindlessgen/molecules/postprocess.py b/src/mindlessgen/molecules/postprocess.py index 17f6ae4..4189cc9 100644 --- a/src/mindlessgen/molecules/postprocess.py +++ b/src/mindlessgen/molecules/postprocess.py @@ -37,6 +37,6 @@ def postprocess_mol( postprocmol = mol except RuntimeError as e: raise RuntimeError( - "Single point calculation in postprocessing failed." + f"Single point calculation in postprocessing failed with error: {e}" ) from e return postprocmol diff --git a/src/mindlessgen/prog/__init__.py b/src/mindlessgen/prog/__init__.py index 8dee820..3ccbccb 100644 --- a/src/mindlessgen/prog/__init__.py +++ b/src/mindlessgen/prog/__init__.py @@ -7,6 +7,7 @@ GeneralConfig, XTBConfig, ORCAConfig, + GXTBConfig, GenerateConfig, RefineConfig, PostProcessConfig, @@ -17,6 +18,7 @@ "GeneralConfig", "XTBConfig", "ORCAConfig", + "GXTBConfig", "GenerateConfig", "RefineConfig", "PostProcessConfig", diff --git a/src/mindlessgen/prog/config.py b/src/mindlessgen/prog/config.py index 8f83247..6097668 100644 --- a/src/mindlessgen/prog/config.py +++ b/src/mindlessgen/prog/config.py @@ -12,6 +12,8 @@ import numpy as np import toml +from mindlessgen.molecules.molecule import PSE_SYMBOLS + from ..molecules import PSE_NUMBERS @@ -49,6 +51,7 @@ def __init__(self: GeneralConfig) -> None: ############################################################ ## g-xTB-specific settings not intended for general use #### self._gxtb_development: bool = False + self._gxtb_ipea: bool = False self._gxtb_scf_cycles: int = 100 ### End of g-xTB-specific settings ######################### ############################################################ @@ -194,6 +197,22 @@ def gxtb_development(self, gxtb_development: bool): raise TypeError("gxtb_development should be a boolean.") self._gxtb_development = gxtb_development + @property + def gxtb_ipea(self): + """ + Check for cation and anion with g-xTB. + """ + return self._gxtb_ipea + + @gxtb_ipea.setter + def gxtb_ipea(self, gxtb_ipea: bool): + """ + Set the g-xTB IPEA flag. + """ + if not isinstance(gxtb_ipea, bool): + raise TypeError("gxtb_ipea should be a boolean.") + self._gxtb_ipea = gxtb_ipea + @property def gxtb_scf_cycles(self): """ @@ -336,18 +355,19 @@ def element_composition(self): @element_composition.setter def element_composition( - self, composition: None | str | dict[int, tuple[int | None, int | None]] + self, composition: None | str | dict[int | str, tuple[int | None, int | None]] ) -> None: """ - If composition_str: str, it should be a string with the format: + If composition: str: Parses the element_composition string and stores the parsed data in the _element_composition dictionary. Format: "C:2-10, H:10-20, O:1-5, N:1-*" - If composition_str: dict, it should be a dictionary with integer keys and tuple values. Will be stored as is. + If composition: dict: + Should be a dictionary with integer/string keys and tuple values. Will be stored as is. Arguments: composition_str (str): String with the element composition - composition_str (dict): Dictionary with integer keys and tuple values + composition_str (dict): Dictionary with integer/str keys and tuple values Raises: TypeError: If composition_str is not a string or a dictionary AttributeError: If the element is not found in the periodic table @@ -358,25 +378,50 @@ def element_composition( if not composition: return + + # Will return if composition dict does not contain either int or str keys and tuple[int | None, int | None] values + # Will also return if dict is valid after setting property if isinstance(composition, dict): + tmp = {} + + # Check validity and also convert str keys into atomic numbers for key, value in composition.items(): if ( - not isinstance(key, int) + not (isinstance(key, int) or isinstance(key, str)) or not isinstance(value, tuple) or len(value) != 2 or not all(isinstance(val, int) or val is None for val in value) ): raise TypeError( - "Element composition dictionary should be a dictionary with integer keys and tuple values (int, int)." + "Element composition dictionary should be a dictionary with either integer or string keys and tuple values (int, int)." ) - self._element_composition = composition + + # Convert str keys + if isinstance(key, str): + element_number = PSE_NUMBERS.get(key.lower(), None) + if element_number is None: + raise KeyError( + f"Element {key} not found in the periodic table." + ) + tmp[element_number - 1] = composition[key] + # Check int keys + else: + if key + 1 in PSE_SYMBOLS: + tmp[key] = composition[key] + else: + raise KeyError( + f"Element with atomic number {key+1} (provided key: {key}) not found in the periodic table." + ) + self._element_composition = tmp return + if not isinstance(composition, str): raise TypeError( "Element composition should be a string (will be parsed) or " - + "a dictionary with integer keys and tuple values." + + "a dictionary with integer/string keys and tuple values." ) + # Parsing composition string element_dict: dict[int, tuple[int | None, int | None]] = {} elements = composition.split(",") # remove leading and trailing whitespaces @@ -583,9 +628,11 @@ def check_config(self, verbosity: int = 1) -> None: if ( np.sum( [ - self.element_composition.get(i, (0, 0))[0] - if self.element_composition.get(i, (0, 0))[0] is not None - else 0 + ( + self.element_composition.get(i, (0, 0))[0] + if self.element_composition.get(i, (0, 0))[0] is not None + else 0 + ) for i in self.element_composition ] ) @@ -739,7 +786,7 @@ def engine(self, engine: str): """ if not isinstance(engine, str): raise TypeError("Postprocess engine should be a string.") - if engine not in ["xtb", "orca", "gp3"]: + if engine not in ["xtb", "orca", "gxtb"]: raise ValueError("Postprocess engine can only be xtb or orca.") self._engine = engine @@ -941,6 +988,53 @@ def scf_cycles(self, max_scf_cycles: int): self._scf_cycles = max_scf_cycles +class GXTBConfig(BaseConfig): + """ + Configuration class for g-xTB. + """ + + def __init__(self: GXTBConfig) -> None: + self._gxtb_path: str | Path = "gxtb" + self._scf_cycles: int = 100 + + def get_identifier(self) -> str: + return "gxtb" + + @property + def gxtb_path(self): + """ + Get the g-xTB path. + """ + return self._gxtb_path + + @gxtb_path.setter + def gxtb_path(self, gxtb_path: str | Path): + """ + Set the g-xTB path. + """ + if not isinstance(gxtb_path, str | Path): + raise TypeError("gxtb_path should be a string or Path.") + self._gxtb_path = gxtb_path + + @property + def scf_cycles(self): + """ + Get the maximum number of SCF cycles. + """ + return self._scf_cycles + + @scf_cycles.setter + def scf_cycles(self, max_scf_cycles: int): + """ + Set the maximum number of SCF cycles. + """ + if not isinstance(max_scf_cycles, int): + raise TypeError("Max SCF cycles should be an integer.") + if max_scf_cycles < 1: + raise ValueError("Max SCF cycles should be greater than 0.") + self._scf_cycles = max_scf_cycles + + class ConfigManager: """ Overall configuration manager for the program. @@ -956,6 +1050,7 @@ def __init__(self, config_file: str | Path | None = None): self.refine = RefineConfig() self.postprocess = PostProcessConfig() self.generate = GenerateConfig() + self.gxtb = GXTBConfig() if config_file: self.load_from_toml(config_file) diff --git a/src/mindlessgen/qm/__init__.py b/src/mindlessgen/qm/__init__.py index a5e8e3a..32237f8 100644 --- a/src/mindlessgen/qm/__init__.py +++ b/src/mindlessgen/qm/__init__.py @@ -5,7 +5,7 @@ from .base import QMMethod from .xtb import XTB, get_xtb_path from .orca import ORCA, get_orca_path -from .gp3 import GP3, get_gp3_path +from .gxtb import GXTB, get_gxtb_path __all__ = [ "XTB", @@ -13,6 +13,6 @@ "QMMethod", "ORCA", "get_orca_path", - "GP3", - "get_gp3_path", + "GXTB", + "get_gxtb_path", ] diff --git a/src/mindlessgen/qm/gp3.py b/src/mindlessgen/qm/gxtb.py similarity index 54% rename from src/mindlessgen/qm/gp3.py rename to src/mindlessgen/qm/gxtb.py index 0ca4d40..970c6ca 100644 --- a/src/mindlessgen/qm/gp3.py +++ b/src/mindlessgen/qm/gxtb.py @@ -1,5 +1,5 @@ """ -This module contains all interactions with the GP3-xTB binary +This module contains all interactions with the g-xTB binary for next-gen tight-binding calculations. """ @@ -9,37 +9,39 @@ from tempfile import TemporaryDirectory from ..molecules import Molecule +from ..prog import GXTBConfig from .base import QMMethod -class GP3(QMMethod): +class GXTB(QMMethod): """ - This class handles all interaction with the GP3 external dependency. + This class handles all interaction with the g-xTB external dependency. """ - def __init__(self, path: str | Path) -> None: + def __init__(self, path: str | Path, gxtbcfg: GXTBConfig) -> None: """ - Initialize the GP3 class. + Initialize the GXTB class. """ if isinstance(path, str): self.path: Path = Path(path).resolve() elif isinstance(path, Path): self.path = path else: - raise TypeError("gp3_path should be a string or a Path object.") + raise TypeError("gxtb_path should be a string or a Path object.") + self.cfg = gxtbcfg def singlepoint(self, molecule: Molecule, verbosity: int = 1) -> str: """ - Perform a single-point calculation using GP3-xTB. + Perform a single-point calculation using g-xTB. """ # Create a unique temporary directory using TemporaryDirectory context manager - with TemporaryDirectory(prefix="gp3_") as temp_dir: + with TemporaryDirectory(prefix="gxtb_") as temp_dir: temp_path = Path(temp_dir).resolve() # write the molecule to a temporary file molecule.write_xyz_to_file(str(temp_path / "molecule.xyz")) - # run gp3 + # run g-xTB arguments = [ "-c", "molecule.xyz", @@ -53,19 +55,37 @@ def singlepoint(self, molecule: Molecule, verbosity: int = 1) -> str: f.write(str(molecule.uhf)) if verbosity > 2: - print(f"Running command: gp3 {' '.join(arguments)}") + print(f"Running command: gxtb {' '.join(arguments)}") - gp3_log_out, gp3_log_err, return_code = self._run( + gxtb_log_out, gxtb_log_err, return_code = self._run( temp_path=temp_path, arguments=arguments ) if verbosity > 2: - print(gp3_log_out) + print(gxtb_log_out) if return_code != 0: raise RuntimeError( - f"GP3-xTB failed with return code {return_code}:\n{gp3_log_err}" + f"g-xTB failed with return code {return_code}:\n{gxtb_log_err}" + ) + # gp3_output looks like this: + # [...] + # 13 -155.03101038 0.00000000 0.00000001 16.45392733 8 F + # 13 scf iterations + # eigenvalues + # [...] + # Check for the number of scf iterations + scf_iterations = 0 + for line in gxtb_log_out.split("\n"): + if "scf iterations" in line: + scf_iterations = int(line.strip().split()[0]) + break + if scf_iterations == 0: + raise RuntimeError("SCF iterations not found in GP3 output.") + if scf_iterations > self.cfg.scf_cycles: + raise RuntimeError( + f"SCF iterations exceeded limit of {self.cfg.scf_cycles}." ) - return gp3_log_out + return gxtb_log_out def check_gap( self, molecule: Molecule, threshold: float = 0.5, verbosity: int = 1 @@ -83,13 +103,13 @@ def check_gap( # Perform a single point calculation try: - gp3_out = self.singlepoint(molecule) + gxtb_out = self.singlepoint(molecule) except RuntimeError as e: raise RuntimeError("Single point calculation failed.") from e # Parse the output to get the gap hlgap = None - for line in gp3_out.split("\n"): + for line in gxtb_out.split("\n"): if "gap (eV)" in line and "dE" not in line: # check if "alpha->alpha" is present in the same line # then, the line looks as follows: @@ -102,10 +122,10 @@ def check_gap( hlgap = float(line.split()[3]) break if hlgap is None: - raise ValueError("GP3-xTB gap not determined.") + raise ValueError("g-xTB gap not determined.") if verbosity > 1: - print(f"GP3-xTB HOMO-LUMO gap: {hlgap:5f}") + print(f"g-xTB HOMO-LUMO gap: {hlgap:5f}") return hlgap > threshold @@ -113,68 +133,68 @@ def optimize( self, molecule: Molecule, max_cycles: int | None = None, verbosity: int = 1 ) -> Molecule: """ - Optimize a molecule using GP3-xTB. + Optimize a molecule using g-xTB. """ - raise NotImplementedError("Optimization is not yet implemented for GP3-xTB.") + raise NotImplementedError("Optimization is not yet implemented for g-xTB.") def _run(self, temp_path: Path, arguments: list[str]) -> tuple[str, str, int]: """ - Run GP3-xTB with the given arguments. + Run g-xTB with the given arguments. Arguments: - arguments (list[str]): The arguments to pass to GP3-xTB. + arguments (list[str]): The arguments to pass to g-xTB. Returns: - tuple[str, str, int]: The output of the GP3-xTB calculation (stdout and stderr) + tuple[str, str, int]: The output of the g-xTB calculation (stdout and stderr) and the return code """ try: - gp3_out = sp.run( + gxtb_out = sp.run( [str(self.path)] + arguments, cwd=temp_path, capture_output=True, check=True, ) - # get the output of the GP3-xTB calculation (of both stdout and stderr) - gp3_log_out = gp3_out.stdout.decode("utf8") - gp3_log_err = gp3_out.stderr.decode("utf8") + # get the output of the g-xTB calculation (of both stdout and stderr) + gxtb_log_out = gxtb_out.stdout.decode("utf8") + gxtb_log_err = gxtb_out.stderr.decode("utf8") if ( - "no SCF convergence" in gp3_log_out - or "nuclear repulsion" not in gp3_log_out + "no SCF convergence" in gxtb_log_out + or "nuclear repulsion" not in gxtb_log_out ): raise sp.CalledProcessError( 1, str(self.path), - gp3_log_out.encode("utf8"), - gp3_log_err.encode("utf8"), + gxtb_log_out.encode("utf8"), + gxtb_log_err.encode("utf8"), ) - return gp3_log_out, gp3_log_err, 0 + return gxtb_log_out, gxtb_log_err, 0 except sp.CalledProcessError as e: - gp3_log_out = e.stdout.decode("utf8") - gp3_log_err = e.stderr.decode("utf8") - return gp3_log_out, gp3_log_err, e.returncode + gxtb_log_out = e.stdout.decode("utf8") + gxtb_log_err = e.stderr.decode("utf8") + return gxtb_log_out, gxtb_log_err, e.returncode -# TODO: 1. Convert this to a @staticmethod of Class GP3 +# TODO: 1. Convert this to a @staticmethod of Class GXTB # 2. Rename to `get_method` or similar to enable an abstract interface # 3. Add the renamed method to the ABC `QMMethod` # 4. In `main.py`: Remove the passing of the path finder functions as arguments # and remove the boiler plate code to make it more general. -def get_gp3_path(binary_name: str | Path | None = None) -> Path: +def get_gxtb_path(binary_name: str | Path | None = None) -> Path: """ - Get the path to the GP3 binary based on different possible names + Get the path to the g-xTB binary based on different possible names that are searched for in the PATH. """ - default_gp3_names: list[str | Path] = ["gp3", "gp3_dev"] + default_gxtb_names: list[str | Path] = ["gxtb", "gxtb_dev"] # put binary name at the beginning of the lixt to prioritize it if binary_name is not None: - binary_names = [binary_name] + default_gp3_names + binary_names = [binary_name] + default_gxtb_names else: - binary_names = default_gp3_names - # Get gp3 path from 'which gp3' command + binary_names = default_gxtb_names + # Get g-xTB path from 'which gxtb' command for binpath in binary_names: - which_gp3 = shutil.which(binpath) - if which_gp3: - gp3_path = Path(which_gp3).resolve() - return gp3_path - raise ImportError("'gp3' binary could not be found.") + which_gxtb = shutil.which(binpath) + if which_gxtb: + gxtb_path = Path(which_gxtb).resolve() + return gxtb_path + raise ImportError("'gxtb' binary could not be found.") diff --git a/test/test_generate/test_generate_molecule.py b/test/test_generate/test_generate_molecule.py index a020536..190843f 100644 --- a/test/test_generate/test_generate_molecule.py +++ b/test/test_generate/test_generate_molecule.py @@ -41,9 +41,9 @@ def test_generate_atom_list(min_atoms, max_atoms, default_generate_config): assert np.sum(atom_list) <= max_atoms -# Test the element composition property of the GenerateConfig class -def test_generate_config_element_composition(default_generate_config): - """Test the element composition property of the GenerateConfig class.""" +# Test the element composition property of the GenerateConfig class with a composition string +def test_generate_config_element_composition_string(default_generate_config): + """Test the element composition property of the GenerateConfig class with a composition string.""" default_generate_config.min_num_atoms = 10 default_generate_config.max_num_atoms = 15 default_generate_config.element_composition = "C:2-2, N:3-3, O:1-1" @@ -55,6 +55,60 @@ def test_generate_config_element_composition(default_generate_config): assert atom_list[7] == 1 +# Test the element composition property of the GenerateConfig class with an int key composition dict +def test_generate_config_element_composition_dict_int(default_generate_config): + """Test the element composition property of the GenerateConfig class with an int key composition dict.""" + + # Pure int keys + default_generate_config.min_num_atoms = 10 + default_generate_config.max_num_atoms = 15 + default_generate_config.element_composition = { + 5: (2, 2), + 6: (3, 3), + 7: (1, 1), + } # NOTE: mind 0-based indexing for atomic numbers + atom_list = generate_atom_list(default_generate_config, verbosity=1) + + # Check that the atom list contains the correct number of atoms for each element + assert atom_list[5] == 2 + assert atom_list[6] == 3 + assert atom_list[7] == 1 + + +# Test the element composition property of the GenerateConfig class with an int key composition dict +def test_generate_config_element_composition_dict_string(default_generate_config): + """Test the element composition property of the GenerateConfig class with a str key composition dict.""" + + default_generate_config.min_num_atoms = 10 + default_generate_config.max_num_atoms = 15 + default_generate_config.element_composition = { + "C": (2, 2), + "N": (3, 3), + "O": (1, 1), + } + atom_list = generate_atom_list(default_generate_config, verbosity=1) + + # Check that the atom list contains the correct number of atoms for each element + assert atom_list[5] == 2 + assert atom_list[6] == 3 + assert atom_list[7] == 1 + + +# Test the element composition property of the GenerateConfig class with an int key composition dict +def test_generate_config_element_composition_dict_mixed(default_generate_config): + """Test the element composition property of the GenerateConfig class with a str key composition dict.""" + + default_generate_config.min_num_atoms = 10 + default_generate_config.max_num_atoms = 15 + default_generate_config.element_composition = {5: (2, 2), "N": (3, 3), "O": (1, 1)} + atom_list = generate_atom_list(default_generate_config, verbosity=1) + + # Check that the atom list contains the correct number of atoms for each element + assert atom_list[5] == 2 + assert atom_list[6] == 3 + assert atom_list[7] == 1 + + # Test the forbidden_elements property of the GenerateConfig class def test_generate_config_forbidden_elements(default_generate_config): """Test the forbidden_elements property of the GenerateConfig class."""