From 73db4d45ce2b3b3e7a10b1b9d5a248beea6ba1ab Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Mon, 13 Jan 2025 22:54:43 +0800 Subject: [PATCH] fix: fix #8847 --- torchvision/models/_api.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 0999bf7ba6b..c539319df19 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -7,7 +7,7 @@ from functools import partial from inspect import signature from types import ModuleType -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union +from typing import Any, Callable, Dict, get_args, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union from torch import nn @@ -168,14 +168,13 @@ def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]: if "weights" not in sig.parameters: raise ValueError("The method is missing the 'weights' argument.") - ann = signature(fn).parameters["weights"].annotation + ann = sig.parameters["weights"].annotation weights_enum = None if isinstance(ann, type) and issubclass(ann, WeightsEnum): weights_enum = ann else: # handle cases like Union[Optional, T] - # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 - for t in ann.__args__: # type: ignore[union-attr] + for t in get_args(ann): # type: ignore[union-attr] if isinstance(t, type) and issubclass(t, WeightsEnum): weights_enum = t break