Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 62 additions & 7 deletions src/lightning_utilities/cli/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
def prune_packages_in_requirements(
packages: Union[str, Sequence[str]], req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
) -> None:
"""Remove some packages from given requirement files."""
"""Remove one or more packages from the specified requirement files.

Args:
packages: A package name or list of package names to remove.
req_files: A path or list of paths to requirement files to process.

"""
if isinstance(packages, str):
packages = [packages]
if isinstance(req_files, str):
Expand All @@ -31,7 +37,13 @@ def prune_packages_in_requirements(


def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
"""Remove some packages from given requirement files."""
"""Remove all occurrences of the given packages (by line prefix) from a requirements file.

Args:
req_file: Path to a requirements file.
packages: Package names to remove. Lines starting with any of these names will be dropped.

"""
with open(req_file) as fp:
lines = fp.readlines()

Expand All @@ -46,14 +58,27 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None:


def _replace_min_req_in_txt(req_file: str) -> None:
"""Replace all occurrences of '>=' with '==' in a plain text requirements file.

Args:
req_file: Path to the requirements.txt-like file to update.

"""
with open(req_file) as fopen:
req = fopen.read().replace(">=", "==")
with open(req_file, "w") as fw:
fw.write(req)


def _replace_min_req_in_pyproject_toml(proj_file: str = "pyproject.toml") -> None:
"""Replace all `>=` with `==` in the standard pyproject.toml file in [project.dependencies]."""
"""Replace all '>=' with '==' in the [project.dependencies] section of a standard pyproject.toml.

Preserves formatting and comments using tomlkit.

Args:
proj_file: Path to the pyproject.toml file.

"""
import tomlkit

# Load and parse the existing pyproject.toml
Expand All @@ -77,7 +102,14 @@ def _replace_min_req_in_pyproject_toml(proj_file: str = "pyproject.toml") -> Non


def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
"""Replace the min package version by fixed one."""
"""Convert minimal version specifiers (>=) to pinned ones (==) in the given requirement files.

Supports plain *.txt requirements and pyproject.toml files. Unsupported file types trigger a warning.

Args:
req_files: A path or list of paths to requirement files to process.

"""
if isinstance(req_files, str):
req_files = [req_files]
for fname in req_files:
Expand All @@ -95,7 +127,14 @@ def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FI


def _replace_package_name_in_txt(req_file: str, old_package: str, new_package: str) -> None:
"""Replace one package by another with the same version in a given requirement file."""
"""Rename a package in a plain text requirements file, preserving version specifiers and markers.

Args:
req_file: Path to the requirements.txt-like file to update.
old_package: The original package name to replace.
new_package: The new package name to use.

"""
# load file
with open(req_file) as fopen:
requirements = fopen.readlines()
Expand All @@ -108,7 +147,14 @@ def _replace_package_name_in_txt(req_file: str, old_package: str, new_package: s


def _replace_package_name_in_pyproject_toml(proj_file: str, old_package: str, new_package: str) -> None:
"""Replace one package by another with the same version in the standard pyproject.toml file."""
"""Rename a package in the [project.dependencies] section of a standard pyproject.toml, preserving constraints.

Args:
proj_file: Path to the pyproject.toml file.
old_package: The original package name to replace.
new_package: The new package name to use.

"""
import tomlkit

# Load and parse the existing pyproject.toml
Expand All @@ -134,7 +180,16 @@ def _replace_package_name_in_pyproject_toml(proj_file: str, old_package: str, ne
def replace_package_in_requirements(
old_package: str, new_package: str, req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
) -> None:
"""Replace one package by another with same version in given requirement files."""
"""Rename a package across multiple requirement files while keeping version constraints intact.

Supports plain *.txt requirements and pyproject.toml files. Unsupported file types trigger a warning.

Args:
old_package: The original package name to replace.
new_package: The new package name to use.
req_files: A path or list of paths to requirement files to process.

"""
if isinstance(req_files, str):
req_files = [req_files]
for fname in req_files:
Expand Down
32 changes: 18 additions & 14 deletions src/lightning_utilities/core/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@


def is_namedtuple(obj: object) -> bool:
    """Return True if ``obj`` is an instance of a namedtuple.

    Uses the same duck-typed heuristic as PyTorch's scatter/gather utilities:
    a tuple carrying the namedtuple-specific ``_asdict`` and ``_fields`` attributes.

    """
    # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
    if not isinstance(obj, tuple):
        return False
    return all(hasattr(obj, attr) for attr in ("_asdict", "_fields"))


def is_dataclass_instance(obj: object) -> bool:
    """Return True if ``obj`` is an instance of a dataclass (not the dataclass type itself)."""
    # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
    if isinstance(obj, type):
        # a dataclass *type* also satisfies is_dataclass, so rule it out first
        return False
    return dataclasses.is_dataclass(obj)

Expand Down Expand Up @@ -195,24 +199,24 @@ def apply_to_collections(
wrong_dtype: Optional[Union[type, tuple[type]]] = None,
**kwargs: Any,
) -> Any:
"""Zips two collections and applies a function to their items of a certain dtype.
"""Zip two collections and apply a function to items of a certain dtype.

Args:
data1: The first collection
data2: The second collection
dtype: the given function will be applied to all elements of this dtype
function: the function to apply
*args: positional arguments (will be forwarded to calls of ``function``)
wrong_dtype: the given function won't be applied if this type is specified and the given collections
is of the ``wrong_dtype`` even if it is of type ``dtype``
**kwargs: keyword arguments (will be forwarded to calls of ``function``)
data1: The first collection. If ``None`` and ``data2`` is not ``None``, the arguments are swapped.
data2: The second collection. May be ``None`` to apply ``function`` only to ``data1``.
dtype: The type(s) for which the given ``function`` will be applied to matching elements.
function: The function to apply to matching elements.
*args: Positional arguments forwarded to calls of ``function``.
wrong_dtype: If specified, ``function`` won't be applied to elements of this type even if they match ``dtype``.
**kwargs: Keyword arguments forwarded to calls of ``function``.

Returns:
The resulting collection
A collection with the same structure as the input where matching elements are transformed.

Raises:
AssertionError:
If sequence collections have different data sizes.
ValueError: If sequence collections have different sizes.
TypeError: If dataclass inputs are mismatched (different types or fields), or if ``data1`` is a
dataclass instance but ``data2`` is not.

"""
if data1 is None:
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_utilities/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") ->

@classmethod
def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]:
"""Try to create emun and if it does not match any, return `None`."""
"""Try to create the enum; if no match is found, emit a warning and return ``None``."""
try:
return cls.from_str(value, source)
except ValueError:
Expand Down
12 changes: 7 additions & 5 deletions src/lightning_utilities/core/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,14 @@ def __init__(self, module_name: str, callback: Optional[Callable] = None) -> Non
self._callback = callback

def __getattr__(self, item: str) -> Any:
    """Trigger the lazy import on first attribute access and forward the lookup to the real module."""
    if self._module is None:
        # first touch: perform the deferred import now
        self._import_module()
    return getattr(self._module, item)

def __dir__(self) -> list[str]:
"""Overwrite attribute access for dictionary."""
"""Lazily import the underlying module and return its attributes for introspection (dir())."""
if self._module is None:
self._import_module()

Expand Down Expand Up @@ -317,11 +317,13 @@ def lazy_import(module_name: str, callback: Optional[Callable] = None) -> LazyMo


def requires(*module_path_version: str, raise_exception: bool = True) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Wrap early import failure with some nice exception message.
"""Decorator to check optional dependencies at call time with a clear error/warning message.

Args:
module_path_version: python package path (e.g. `torch.cuda`) or pip like requiremsnt (e.g. `torch>=2.0.0`)
raise_exception: how strict the check shall be if exit the code or just warn user
module_path_version: Python module paths (e.g., ``"torch.cuda"``) and/or pip-style requirements
(e.g., ``"torch>=2.0.0"``) to verify.
raise_exception: If ``True``, raise ``ModuleNotFoundError`` when requirements are not satisfied;
otherwise emit a warning and proceed to call the function.

Example:
>>> @requires("libpath", raise_exception=bool(int(os.getenv("LIGHTING_TESTING", "0"))))
Expand Down
4 changes: 2 additions & 2 deletions src/lightning_utilities/core/inheritance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def get_all_subclasses_iterator(cls: type) -> Iterator[type]:
"""Iterate over all subclasses."""
"""Depth-first iterator over all subclasses of ``cls`` (recursively)."""

def recurse(cl: type) -> Iterator[type]:
for subclass in cl.__subclasses__():
Expand All @@ -17,5 +17,5 @@ def recurse(cl: type) -> Iterator[type]:


def get_all_subclasses(cls: type) -> set[type]:
    """Return the set of all direct and indirect subclasses of ``cls``."""
    # materialize the recursive iterator into a de-duplicated set
    return {subclass for subclass in get_all_subclasses_iterator(cls)}
16 changes: 15 additions & 1 deletion src/lightning_utilities/core/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,21 @@


def is_overridden(method_name: str, instance: object, parent: type[object]) -> bool:
"""Check if a method of a given object was overwritten."""
"""Return True if ``instance`` overrides ``parent.method_name``.

Supports functions wrapped with ``functools.wraps``, context managers (``__wrapped__``),
``unittest.mock.Mock(wraps=...)``, and ``functools.partial``. If the parent does not define
``method_name``, a ``ValueError`` is raised.

Args:
method_name: The name of the method to check.
instance: The object instance to inspect.
parent: The parent class that declares the original method.

Returns:
True if the method implementation on the instance differs from the parent's; otherwise False.

"""
instance_attr = getattr(instance, method_name, None)
if instance_attr is None:
return False
Expand Down
15 changes: 8 additions & 7 deletions src/lightning_utilities/core/rank_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: ...


def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[P, Optional[T]]:
"""Wrap a function to call internal function only in rank zero.
"""Decorator to run the wrapped function only on global rank 0.

Function that can be used as a decorator to enable a function/method being called only on global rank 0.
Set ``rank_zero_only.rank`` before use. On non-zero ranks, the function is skipped and the provided
``default`` is returned (or ``None`` if not given).

"""

Expand Down Expand Up @@ -86,30 +87,30 @@ def rank_zero_deprecation(message: Union[str, Warning], stacklevel: int = 5, **k


def rank_prefixed_message(message: str, rank: Optional[int]) -> str:
    """Prefix ``message`` with ``[rank: X]`` when a rank is given; otherwise return it unchanged.

    Args:
        message: The message to decorate.
        rank: The global rank of the emitting process, or ``None`` to leave the message as-is.

    """
    if rank is None:
        return message
    # identify which process produced the message in multi-process logs
    return f"[rank: {rank}] {message}"


class WarningCache(set):
    """Set-backed de-duplication cache so each warning/deprecation/info message is emitted at most once."""

    def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
        """Emit a warning on global rank 0; identical repeat messages are suppressed."""
        if message in self:
            return
        self.add(message)
        rank_zero_warn(message, stacklevel=stacklevel, **kwargs)

    def deprecation(self, message: str, stacklevel: int = 6, **kwargs: Any) -> None:
        """Emit a deprecation warning on global rank 0; identical repeat messages are suppressed."""
        if message in self:
            return
        self.add(message)
        rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)

    def info(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
        """Emit an info-level log on global rank 0; identical repeat messages are suppressed."""
        if message in self:
            return
        self.add(message)
        rank_zero_info(message, stacklevel=stacklevel, **kwargs)
17 changes: 10 additions & 7 deletions src/lightning_utilities/docs/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@


def _transform_changelog(path_in: str, path_out: str) -> None:
"""Adjust changelog titles so not to be duplicated.
"""Adjust changelog headers to avoid duplication of short subtitles.

Args:
path_in: input MD file
path_out: output also MD file
path_in: Input Markdown file path.
path_out: Output Markdown file path.

"""
with open(path_in) as fp:
Expand Down Expand Up @@ -99,12 +99,15 @@ def _load_pypi_versions(package_name: str) -> list[str]:


def _update_link_based_imported_package(link: str, pkg_ver: str, version_digits: Optional[int]) -> str:
"""Adjust the linked external docs to be local.
"""Resolve a ``{package.version}`` placeholder in a link using the latest available version.

Args:
link: the source link to be replaced
pkg_ver: the target link to be replaced, if ``{package.version}`` is included it will be replaced accordingly
version_digits: for semantic versioning, how many digits to be considered
link: The link template containing a ``{...}`` placeholder to replace.
pkg_ver: A dotted path to resolve the version (e.g., ``"numpy.__version__"``).
version_digits: Number of version components to keep (e.g., ``2`` -> ``"1.26"``). If ``None``, keep all.

Returns:
The link with the ``{...}`` placeholder replaced by a version string.

"""
pkg_att = pkg_ver.split(".")
Expand Down
Loading
Loading