Skip to content
16 changes: 11 additions & 5 deletions src/lightning/pytorch/utilities/model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from lightning_utilities.core.imports import RequirementCache
from torch import nn
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import Concatenate, ParamSpec, override

import lightning.pytorch as pl

Expand Down Expand Up @@ -104,26 +104,32 @@ def _check_mixed_imports(instance: object) -> None:
_R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method


class _restricted_classmethod_impl(Generic[_T, _R_co, _P]):
class _restricted_classmethod_impl(classmethod, Generic[_T, _P, _R_co]):
"""Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
instead of a class type."""

method: Callable[Concatenate[type[_T], _P], _R_co]

def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None:
super().__init__(method)
self.method = method

def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]:
@override
def __get__(self, instance: _T, cls: Optional[type[_T]] = None) -> Callable[_P, _R_co]: # type: ignore[override]
# The wrapper ensures that the method can be inspected, but not called on an instance
@functools.wraps(self.method)
def wrapper(*args: Any, **kwargs: Any) -> _R_co:
# Workaround for https://github.com/pytorch/pytorch/issues/67146
is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
cls_type = cls if cls is not None else type(instance)
if instance is not None and not is_scripting:
raise TypeError(
f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
f"The classmethod `{cls_type.__name__}.{self.method.__name__}` cannot be called on an instance."
" Please call it on the class type and make sure the return value is used."
)
return self.method(cls, *args, **kwargs)
return self.method(cls_type, *args, **kwargs)

wrapper.__func__ = self.method
return wrapper


Expand Down
Loading