diff --git a/src/lightning_utilities/cli/dependencies.py b/src/lightning_utilities/cli/dependencies.py index 04c770e5..d0467ed1 100644 --- a/src/lightning_utilities/cli/dependencies.py +++ b/src/lightning_utilities/cli/dependencies.py @@ -21,7 +21,13 @@ def prune_packages_in_requirements( packages: Union[str, Sequence[str]], req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL ) -> None: - """Remove some packages from given requirement files.""" + """Remove one or more packages from the specified requirement files. + + Args: + packages: A package name or list of package names to remove. + req_files: A path or list of paths to requirement files to process. + + """ if isinstance(packages, str): packages = [packages] if isinstance(req_files, str): @@ -31,7 +37,13 @@ def prune_packages_in_requirements( def _prune_packages(req_file: str, packages: Sequence[str]) -> None: - """Remove some packages from given requirement files.""" + """Remove all occurrences of the given packages (by line prefix) from a requirements file. + + Args: + req_file: Path to a requirements file. + packages: Package names to remove. Lines starting with any of these names will be dropped. + + """ with open(req_file) as fp: lines = fp.readlines() @@ -46,6 +58,12 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None: def _replace_min_req_in_txt(req_file: str) -> None: + """Replace all occurrences of '>=' with '==' in a plain text requirements file. + + Args: + req_file: Path to the requirements.txt-like file to update. + + """ with open(req_file) as fopen: req = fopen.read().replace(">=", "==") with open(req_file, "w") as fw: @@ -53,7 +71,14 @@ def _replace_min_req_in_txt(req_file: str) -> None: def _replace_min_req_in_pyproject_toml(proj_file: str = "pyproject.toml") -> None: - """Replace all `>=` with `==` in the standard pyproject.toml file in [project.dependencies].""" + """Replace all '>=' with '==' in the [project.dependencies] section of a standard pyproject.toml. + + Preserves formatting and comments using tomlkit. + + Args: + proj_file: Path to the pyproject.toml file. + + """ import tomlkit # Load and parse the existing pyproject.toml @@ -77,7 +102,14 @@ def _replace_min_req_in_pyproject_toml(proj_file: str = "pyproject.toml") -> Non def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None: - """Replace the min package version by fixed one.""" + """Convert minimal version specifiers (>=) to pinned ones (==) in the given requirement files. + + Supports plain *.txt requirements and pyproject.toml files. Unsupported file types trigger a warning. + + Args: + req_files: A path or list of paths to requirement files to process. + + """ if isinstance(req_files, str): req_files = [req_files] for fname in req_files: @@ -95,7 +127,14 @@ def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FI def _replace_package_name_in_txt(req_file: str, old_package: str, new_package: str) -> None: - """Replace one package by another with the same version in a given requirement file.""" + """Rename a package in a plain text requirements file, preserving version specifiers and markers. + + Args: + req_file: Path to the requirements.txt-like file to update. + old_package: The original package name to replace. + new_package: The new package name to use. + + """ # load file with open(req_file) as fopen: requirements = fopen.readlines() @@ -108,7 +147,14 @@ def _replace_package_name_in_txt(req_file: str, old_package: str, new_package: s def _replace_package_name_in_pyproject_toml(proj_file: str, old_package: str, new_package: str) -> None: - """Replace one package by another with the same version in the standard pyproject.toml file.""" + """Rename a package in the [project.dependencies] section of a standard pyproject.toml, preserving constraints. + + Args: + proj_file: Path to the pyproject.toml file. + old_package: The original package name to replace. + new_package: The new package name to use. + + """ import tomlkit # Load and parse the existing pyproject.toml @@ -134,7 +180,16 @@ def _replace_package_name_in_pyproject_toml(proj_file: str, old_package: str, ne def replace_package_in_requirements( old_package: str, new_package: str, req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL ) -> None: - """Replace one package by another with same version in given requirement files.""" + """Rename a package across multiple requirement files while keeping version constraints intact. + + Supports plain *.txt requirements and pyproject.toml files. Unsupported file types trigger a warning. + + Args: + old_package: The original package name to replace. + new_package: The new package name to use. + req_files: A path or list of paths to requirement files to process. + + """ if isinstance(req_files, str): req_files = [req_files] for fname in req_files: diff --git a/src/lightning_utilities/core/apply_func.py b/src/lightning_utilities/core/apply_func.py index bbcc545e..2c9765e5 100644 --- a/src/lightning_utilities/core/apply_func.py +++ b/src/lightning_utilities/core/apply_func.py @@ -11,13 +11,17 @@ def is_namedtuple(obj: object) -> bool: - """Check if object is type nametuple.""" + """Return True if the given object is a namedtuple instance. + + This checks for a tuple with the namedtuple-specific attributes `_asdict` and `_fields`. + + """ # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8 return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") def is_dataclass_instance(obj: object) -> bool: - """Check if object is dataclass.""" + """Return True if the given object is a dataclass instance (not a dataclass type).""" # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions return dataclasses.is_dataclass(obj) and not isinstance(obj, type) @@ -195,24 +199,24 @@ def apply_to_collections( wrong_dtype: Optional[Union[type, tuple[type]]] = None, **kwargs: Any, ) -> Any: - """Zips two collections and applies a function to their items of a certain dtype. + """Zip two collections and apply a function to items of a certain dtype. Args: - data1: The first collection - data2: The second collection - dtype: the given function will be applied to all elements of this dtype - function: the function to apply - *args: positional arguments (will be forwarded to calls of ``function``) - wrong_dtype: the given function won't be applied if this type is specified and the given collections - is of the ``wrong_dtype`` even if it is of type ``dtype`` - **kwargs: keyword arguments (will be forwarded to calls of ``function``) + data1: The first collection. If ``None`` and ``data2`` is not ``None``, the arguments are swapped. + data2: The second collection. May be ``None`` to apply ``function`` only to ``data1``. + dtype: The type(s) for which the given ``function`` will be applied to matching elements. + function: The function to apply to matching elements. + *args: Positional arguments forwarded to calls of ``function``. + wrong_dtype: If specified, ``function`` won't be applied to elements of this type even if they match ``dtype``. + **kwargs: Keyword arguments forwarded to calls of ``function``. Returns: - The resulting collection + A collection with the same structure as the input where matching elements are transformed. Raises: - AssertionError: - If sequence collections have different data sizes. + ValueError: If sequence collections have different sizes. + TypeError: If dataclass inputs are mismatched (different types or fields), or if ``data1`` is a + dataclass instance but ``data2`` is not. """ if data1 is None: diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index 945e0ef5..7d4ce60b 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -57,7 +57,7 @@ def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> @classmethod def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]: - """Try to create emun and if it does not match any, return `None`.""" + """Try to create the enum; if no match is found, emit a warning and return ``None``.""" try: return cls.from_str(value, source) except ValueError: diff --git a/src/lightning_utilities/core/imports.py b/src/lightning_utilities/core/imports.py index 0012fae2..e9af9e4d 100644 --- a/src/lightning_utilities/core/imports.py +++ b/src/lightning_utilities/core/imports.py @@ -268,14 +268,14 @@ def __init__(self, module_name: str, callback: Optional[Callable] = None) -> Non self._callback = callback def __getattr__(self, item: str) -> Any: - """Overwrite attribute access to attribute.""" + """Lazily import the underlying module and delegate attribute access to it.""" if self._module is None: self._import_module() return getattr(self._module, item) def __dir__(self) -> list[str]: - """Overwrite attribute access for dictionary.""" + """Lazily import the underlying module and return its attributes for introspection (dir()).""" if self._module is None: self._import_module() @@ -317,11 +317,13 @@ def lazy_import(module_name: str, callback: Optional[Callable] = None) -> LazyMo def requires(*module_path_version: str, raise_exception: bool = True) -> Callable[[Callable[P, T]], Callable[P, T]]: - """Wrap early import failure with some nice exception message. + """Decorator to check optional dependencies at call time with a clear error/warning message. Args: - module_path_version: python package path (e.g. `torch.cuda`) or pip like requiremsnt (e.g. `torch>=2.0.0`) - raise_exception: how strict the check shall be if exit the code or just warn user + module_path_version: Python module paths (e.g., ``"torch.cuda"``) and/or pip-style requirements + (e.g., ``"torch>=2.0.0"``) to verify. + raise_exception: If ``True``, raise ``ModuleNotFoundError`` when requirements are not satisfied; + otherwise emit a warning and proceed to call the function. Example: >>> @requires("libpath", raise_exception=bool(int(os.getenv("LIGHTING_TESTING", "0")))) diff --git a/src/lightning_utilities/core/inheritance.py b/src/lightning_utilities/core/inheritance.py index 895ab5fc..d7fb4d91 100644 --- a/src/lightning_utilities/core/inheritance.py +++ b/src/lightning_utilities/core/inheritance.py @@ -6,7 +6,7 @@ def get_all_subclasses_iterator(cls: type) -> Iterator[type]: - """Iterate over all subclasses.""" + """Depth-first iterator over all subclasses of ``cls`` (recursively).""" def recurse(cl: type) -> Iterator[type]: for subclass in cl.__subclasses__(): @@ -17,5 +17,5 @@ def recurse(cl: type) -> Iterator[type]: def get_all_subclasses(cls: type) -> set[type]: - """List all subclasses of a class.""" + """Return a set containing all subclasses of ``cls`` discovered recursively.""" return set(get_all_subclasses_iterator(cls)) diff --git a/src/lightning_utilities/core/overrides.py b/src/lightning_utilities/core/overrides.py index 5dc4636d..31899a4f 100644 --- a/src/lightning_utilities/core/overrides.py +++ b/src/lightning_utilities/core/overrides.py @@ -7,7 +7,21 @@ def is_overridden(method_name: str, instance: object, parent: type[object]) -> bool: - """Check if a method of a given object was overwritten.""" + """Return True if ``instance`` overrides ``parent.method_name``. + + Supports functions wrapped with ``functools.wraps``, context managers (``__wrapped__``), + ``unittest.mock.Mock(wraps=...)``, and ``functools.partial``. If the parent does not define + ``method_name``, a ``ValueError`` is raised. + + Args: + method_name: The name of the method to check. + instance: The object instance to inspect. + parent: The parent class that declares the original method. + + Returns: + True if the method implementation on the instance differs from the parent's; otherwise False. + + """ instance_attr = getattr(instance, method_name, None) if instance_attr is None: return False diff --git a/src/lightning_utilities/core/rank_zero.py b/src/lightning_utilities/core/rank_zero.py index ec18e533..13d15deb 100644 --- a/src/lightning_utilities/core/rank_zero.py +++ b/src/lightning_utilities/core/rank_zero.py @@ -26,9 +26,10 @@ def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: ... def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[P, Optional[T]]: - """Wrap a function to call internal function only in rank zero. + """Decorator to run the wrapped function only on global rank 0. - Function that can be used as a decorator to enable a function/method being called only on global rank 0. + Set ``rank_zero_only.rank`` before use. On non-zero ranks, the function is skipped and the provided + ``default`` is returned (or ``None`` if not given). """ @@ -86,7 +87,7 @@ def rank_zero_deprecation(message: Union[str, Warning], stacklevel: int = 5, **k def rank_prefixed_message(message: str, rank: Optional[int]) -> str: - """Add a prefix with the rank to a message.""" + """Add a ``[rank: X]`` prefix to the message if ``rank`` is provided; otherwise return the message unchanged.""" if rank is not None: # specify the rank of the process being logged return f"[rank: {rank}] {message}" @@ -94,22 +95,22 @@ def rank_prefixed_message(message: str, rank: Optional[int]) -> str: class WarningCache(set): - """Cache for warnings.""" + """A simple de-duplication cache for messages to avoid emitting the same warning/info multiple times.""" def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None: - """Trigger warning message.""" + """Emit a warning once on global rank 0; subsequent identical messages are suppressed.""" if message not in self: self.add(message) rank_zero_warn(message, stacklevel=stacklevel, **kwargs) def deprecation(self, message: str, stacklevel: int = 6, **kwargs: Any) -> None: - """Trigger deprecation message.""" + """Emit a deprecation warning once on global rank 0; subsequent identical messages are suppressed.""" if message not in self: self.add(message) rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs) def info(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None: - """Trigger info message.""" + """Emit an info-level log once on global rank 0; subsequent identical messages are suppressed.""" if message not in self: self.add(message) rank_zero_info(message, stacklevel=stacklevel, **kwargs) diff --git a/src/lightning_utilities/docs/formatting.py b/src/lightning_utilities/docs/formatting.py index a2c4d0ca..077816fc 100644 --- a/src/lightning_utilities/docs/formatting.py +++ b/src/lightning_utilities/docs/formatting.py @@ -13,11 +13,11 @@ def _transform_changelog(path_in: str, path_out: str) -> None: - """Adjust changelog titles so not to be duplicated. + """Adjust changelog headers to avoid duplication of short subtitles. Args: - path_in: input MD file - path_out: output also MD file + path_in: Input Markdown file path. + path_out: Output Markdown file path. """ with open(path_in) as fp: @@ -99,12 +99,15 @@ def _load_pypi_versions(package_name: str) -> list[str]: def _update_link_based_imported_package(link: str, pkg_ver: str, version_digits: Optional[int]) -> str: - """Adjust the linked external docs to be local. + """Resolve a ``{package.version}`` placeholder in a link using the latest available version. Args: - link: the source link to be replaced - pkg_ver: the target link to be replaced, if ``{package.version}`` is included it will be replaced accordingly - version_digits: for semantic versioning, how many digits to be considered + link: The link template containing a ``{...}`` placeholder to replace. + pkg_ver: A dotted path to resolve the version (e.g., ``"numpy.__version__"``). + version_digits: Number of version components to keep (e.g., ``2`` -> ``"1.26"``). If ``None``, keep all. + + Returns: + The link with the ``{...}`` placeholder replaced by a version string. """ pkg_att = pkg_ver.split(".") diff --git a/src/lightning_utilities/docs/retriever.py b/src/lightning_utilities/docs/retriever.py index 4de873bf..7798b1b0 100644 --- a/src/lightning_utilities/docs/retriever.py +++ b/src/lightning_utilities/docs/retriever.py @@ -10,7 +10,13 @@ def _download_file(file_url: str, folder: str) -> str: - """Download a file from URL to a particular folder.""" + """Download a file from a URL into the given folder. + + If a file with the same name already exists, it will be overwritten. + Returns the basename of the downloaded file. Network-related exceptions from + ``requests.get`` (e.g., timeouts or connection errors) may propagate to the caller. + + """ fname = os.path.basename(file_url) file_path = os.path.join(folder, fname) if os.path.isfile(file_path): @@ -23,11 +29,14 @@ def _download_file(file_url: str, folder: str) -> str: def _search_all_occurrences(list_files: list[str], pattern: str) -> list[str]: - """Search for all occurrences of specific pattern in a collection of files. + """Search for all occurrences of a regular-expression pattern across files. Args: - list_files: list of files to be scanned - pattern: pattern for search, reg. expression + list_files: The list of file paths to scan. + pattern: A regular-expression pattern to search for in each file. + + Returns: + A list with all matches found across the provided files (order preserved per file). """ collected = [] @@ -40,12 +49,12 @@ def _search_all_occurrences(list_files: list[str], pattern: str) -> list[str]: def _replace_remote_with_local(file_path: str, docs_folder: str, pairs_url_path: list[tuple[str, str]]) -> None: - """Replace all URL with local files in a given file. + """Replace all matching remote URLs with local file paths in a given file. Args: - file_path: file for replacement - docs_folder: the location of docs related to the project root - pairs_url_path: pairs of URL and local file path to be swapped + file_path: The file in which replacements should be performed. + docs_folder: The documentation root folder (used to compute relative paths). + pairs_url_path: Pairs of (remote_url, local_relative_path) to replace. """ # drop the default/global path to the docs @@ -69,13 +78,13 @@ def fetch_external_assets( file_pattern: str = "*.rst", retrieve_pattern: str = r"https?://[-a-zA-Z0-9_]+\.s3\.[-a-zA-Z0-9()_\\+.\\/=]+", ) -> None: - """Search all URL in docs, download these files locally and replace online with local version. + """Find S3 (or HTTP) asset URLs in docs, download them locally, and rewrite references to local paths. Args: - docs_folder: the location of docs related to the project root - assets_folder: a folder inside ``docs_folder`` to be created and saving online assets - file_pattern: what kind of files shall be scanned - retrieve_pattern: pattern for reg. expression to search URL/S3 resources + docs_folder: The documentation root relative to the project. + assets_folder: Subfolder inside ``docs_folder`` used to store downloaded assets (created if missing). + file_pattern: Glob pattern of files to scan. + retrieve_pattern: Regular-expression pattern used to find remote asset URLs. """ list_files = glob.glob(os.path.join(docs_folder, "**", file_pattern), recursive=True) diff --git a/src/lightning_utilities/install/requirements.py b/src/lightning_utilities/install/requirements.py index b39a4ca1..e8229c05 100644 --- a/src/lightning_utilities/install/requirements.py +++ b/src/lightning_utilities/install/requirements.py @@ -1,6 +1,13 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # http://www.apache.org/licenses/LICENSE-2.0 # +"""Utilities to parse and adjust Python requirements files. + +This module parses requirement lines while preserving inline comments and pip arguments and +supports relaxing version pins based on a chosen unfreeze strategy: "none", "major", or "all". + +""" + import re from collections.abc import Iterable, Iterator from distutils.version import LooseVersion @@ -11,6 +18,17 @@ class _RequirementWithComment(Requirement): + """Requirement subclass that preserves an inline comment and optional pip argument. + + Attributes: + comment: The trailing comment captured from the requirement line (including the leading '# ...'). + pip_argument: A preceding pip argument line (e.g., ``"--extra-index-url ..."``) associated + with this requirement, or ``None`` if not provided. + strict: Whether the special marker ``"# strict"`` appears in ``comment`` (case-insensitive), in which case + upper bound adjustments are disabled. + + """ + strict_string = "# strict" def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None: @@ -22,7 +40,9 @@ def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = self.strict = self.strict_string in comment.lower() def adjust(self, unfreeze: str) -> str: - """Remove version restrictions unless they are strict. + """Adjust version specifiers according to the selected unfreeze strategy. + + The special marker ``"# strict"`` in the captured comment disables any relaxation of upper bounds. >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# anything").adjust("none") 'arrow<=1.2.2,>=1.2.0' @@ -43,6 +63,15 @@ def adjust(self, unfreeze: str) -> str: >>> _RequirementWithComment("arrow").adjust("major") 'arrow' + Args: + unfreeze: One of: + - ``"none"``: Keep all version specifiers unchanged. + - ``"major"``: Relax the upper bound to the next major version (e.g., ``<2.0``). + - ``"all"``: Drop any upper bound constraint entirely. + + Returns: + The adjusted requirement string. If strict, the original string is returned with the strict marker appended. + """ out = str(self) if self.strict: @@ -64,7 +93,11 @@ def adjust(self, unfreeze: str) -> str: def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]: - r"""Adapted from `pkg_resources.parse_requirements` to include comments. + r"""Adapted from ``pkg_resources.parse_requirements`` to include comments and pip arguments. + + Parses a sequence or string of requirement lines, preserving trailing comments and associating any + preceding pip arguments (``--...``) with the subsequent requirement. Lines starting with ``-r`` or + containing direct URLs are ignored. >>> txt = ['# ignored', '', 'this # is an', '--piparg', 'example', 'foo # strict', 'thing', '-r different/file.txt'] >>> [r.adjust('none') for r in _parse_requirements(txt)] @@ -73,6 +106,12 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen >>> [r.adjust('none') for r in _parse_requirements(txt)] ['this', 'example', 'foo # strict', 'thing'] + Args: + strs: Either an iterable of requirement lines or a single multi-line string. + + Yields: + _RequirementWithComment: Parsed requirement objects with preserved comment and pip argument. + """ lines = yield_lines(strs) pip_argument = None @@ -105,7 +144,7 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> list[str]: - """Load requirements from a file. + """Load, parse, and optionally relax requirement specifiers from a file. >>> import os >>> from lightning_utilities import _PROJECT_ROOT @@ -113,6 +152,18 @@ def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str >>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE ['sphinx<6.0,>=4.0', ...] + Args: + path_dir: Directory containing the requirements file. + file_name: The requirements filename inside ``path_dir``. + unfreeze: Unfreeze strategy: ``"none"``, ``"major"``, or ``"all"`` (see ``_RequirementWithComment.adjust``). + + Returns: + A list of requirement strings adjusted according to ``unfreeze``. + + Raises: + ValueError: If ``unfreeze`` is not one of the supported options. + FileNotFoundError: If the composed path does not exist. + """ if unfreeze not in {"none", "major", "all"}: raise ValueError(f'unsupported option of "{unfreeze}"') diff --git a/src/lightning_utilities/test/warning.py b/src/lightning_utilities/test/warning.py index eec3a216..bcd31486 100644 --- a/src/lightning_utilities/test/warning.py +++ b/src/lightning_utilities/test/warning.py @@ -10,7 +10,16 @@ @contextmanager def no_warning_call(expected_warning: type[Warning] = Warning, match: Optional[str] = None) -> Generator: - """Check that no warning was raised/emitted under this context manager.""" + """Assert that no matching warning is emitted within the context. + + Args: + expected_warning: The warning class (or subclass) to check for. + match: Optional regular expression to match against the warning message. + + Raises: + AssertionError: If a warning of the given type (and matching the regex, if provided) is captured. + + """ with warnings.catch_warnings(record=True) as record: yield