|
15 | 15 | import inspect
|
16 | 16 | import logging
|
17 | 17 | import os
|
18 |
| -from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar |
| 18 | +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar |
19 | 19 |
|
20 | 20 | from lightning_utilities.core.imports import RequirementCache
|
21 | 21 | from torch import nn
|
@@ -104,14 +104,15 @@ def _check_mixed_imports(instance: object) -> None:
|
104 | 104 | _R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method
|
105 | 105 |
|
106 | 106 |
|
107 |
| -class _restricted_classmethod_impl(Generic[_T, _R_co, _P]): |
| 107 | +class _restricted_classmethod_impl(classmethod): |
108 | 108 | """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
|
109 | 109 | instead of a class type."""
|
110 | 110 |
|
111 | 111 | def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None:
|
| 112 | + super().__init__(method) |
112 | 113 | self.method = method
|
113 | 114 |
|
114 |
| - def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]: |
| 115 | + def __get__(self, instance: Optional[_T], cls: type[_T] | None = None) -> Callable[_P, _R_co]: |
115 | 116 | # The wrapper ensures that the method can be inspected, but not called on an instance
|
116 | 117 | @functools.wraps(self.method)
|
117 | 118 | def wrapper(*args: Any, **kwargs: Any) -> _R_co:
|
|
0 commit comments