diff --git a/dmlcloud/util/logging.py b/dmlcloud/util/logging.py index d8fcaba..5e6c56b 100644 --- a/dmlcloud/util/logging.py +++ b/dmlcloud/util/logging.py @@ -11,7 +11,7 @@ import dmlcloud from . import slurm from .git import git_hash -from .thirdparty import try_get_version +from .thirdparty import is_imported, ML_MODULES, try_get_version class IORedirector: @@ -138,23 +138,9 @@ def general_diagnostics() -> str: except (FileNotFoundError, IndexError): pass - msg += f' - torch: {torch.__version__}\n' - if try_get_version('torchvision'): - msg += f' - torchvision: {try_get_version("torchvision")}\n' - if try_get_version('torchtext'): - msg += f' - torchtext: {try_get_version("torchtext")}\n' - if try_get_version('torchaudio'): - msg += f' - torchaudio: {try_get_version("torchaudio")}\n' - if try_get_version('einops'): - msg += f' - einops: {try_get_version("einops")}\n' - if try_get_version('numpy'): - msg += f' - numpy: {try_get_version("numpy")}\n' - if try_get_version('pandas'): - msg += f' - pandas: {try_get_version("pandas")}\n' - if try_get_version('xarray'): - msg += f' - xarray: {try_get_version("xarray")}\n' - if try_get_version('sklearn'): - msg += f' - sklearn: {try_get_version("sklearn")}\n' + for module_name in ML_MODULES: + if is_imported(module_name): + msg += f' - {module_name}: {try_get_version(module_name)}\n' if 'SLURM_JOB_ID' in os.environ: msg += '* SLURM:\n' diff --git a/dmlcloud/util/thirdparty.py b/dmlcloud/util/thirdparty.py index 9693023..b217c84 100644 --- a/dmlcloud/util/thirdparty.py +++ b/dmlcloud/util/thirdparty.py @@ -1,8 +1,26 @@ import importlib +import sys from types import ModuleType from typing import Optional +ML_MODULES = [ + 'torch', + 'torchvision', + 'torchtext', + 'torchaudio', + 'einops', + 'numpy', + 'pandas', + 'xarray', + 'sklearn', +] + + +def is_imported(name: str) -> bool: + return name in sys.modules + + def try_import(name: str) -> Optional[ModuleType]: try: return importlib.import_module(name)