Skip to content
28 changes: 16 additions & 12 deletions src/psygnal/_group_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import sys
import warnings
import weakref
from collections.abc import Iterable
from contextlib import suppress
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Literal,
Optional,
TypeVar,
Expand Down Expand Up @@ -39,6 +41,7 @@

T = TypeVar("T", bound=type)
S = TypeVar("S")
GroupType = TypeVar("GroupType", bound=SignalGroup)


_EQ_OPERATORS: dict[type, dict[str, EqOperator]] = {}
Expand Down Expand Up @@ -153,10 +156,10 @@ def _psygnal_relocate_info_(self, emission_info: EmissionInfo) -> EmissionInfo:

def _build_dataclass_signal_group(
cls: type,
signal_group_class: type[SignalGroup],
signal_group_class: type[GroupType],
equality_operators: Iterable[tuple[str, EqOperator]] | None = None,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = None,
) -> type[SignalGroup]:
) -> type[GroupType]:
"""Build a SignalGroup with events for each field in a dataclass.

Parameters
Expand Down Expand Up @@ -424,7 +427,7 @@ def _setattr_and_emit_(self: object, name: str, value: Any) -> None:
return _inner(super_setattr) if super_setattr else _inner


class SignalGroupDescriptor:
class SignalGroupDescriptor(Generic[GroupType]):
"""Create a [`psygnal.SignalGroup`][] on first instance attribute access.

This descriptor is designed to be used as a class attribute on a dataclass-like
Expand Down Expand Up @@ -544,12 +547,12 @@ def __init__(
warn_on_no_fields: bool = True,
cache_on_instance: bool = True,
patch_setattr: bool = True,
signal_group_class: type[SignalGroup] | None = None,
signal_group_class: type[GroupType] | None = None,
collect_fields: bool = True,
connect_child_events: bool = True,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = None,
):
grp_cls = signal_group_class or SignalGroup
grp_cls = signal_group_class or cast("type[GroupType]", SignalGroup)
if not (isinstance(grp_cls, type) and issubclass(grp_cls, SignalGroup)):
raise TypeError( # pragma: no cover
f"'signal_group_class' must be a subclass of SignalGroup, not {grp_cls}"
Expand All @@ -574,11 +577,11 @@ def __init__(
self._patch_setattr = patch_setattr
self._connect_child_events = connect_child_events

self._signal_group_class: type[SignalGroup] = grp_cls
self._signal_group_class: type[GroupType] = grp_cls
self._collect_fields = collect_fields
self._signal_aliases = signal_aliases

self._signal_groups: dict[int, type[SignalGroup]] = {}
self._signal_groups: dict[int, type[GroupType]] = {}

def __set_name__(self, owner: type, name: str) -> None:
"""Called when this descriptor is added to class `owner` as attribute `name`."""
Expand Down Expand Up @@ -618,11 +621,11 @@ def _do_patch_setattr(self, owner: type, with_aliases: bool = True) -> None:
def __get__(self, instance: None, owner: type) -> SignalGroupDescriptor: ...

@overload
def __get__(self, instance: object, owner: type) -> SignalGroup: ...
def __get__(self, instance: object, owner: type) -> GroupType: ...

def __get__(
self, instance: object, owner: type
) -> SignalGroup | SignalGroupDescriptor:
) -> GroupType | SignalGroupDescriptor:
"""Return a SignalGroup instance for `instance`."""
if instance is None:
return self
Expand Down Expand Up @@ -652,15 +655,16 @@ def __get__(
lambda: connect_child_events(instance, recurse=True, _group=grp)
)

return self._instance_map[obj_id]
return cast("GroupType", self._instance_map[obj_id])

def _get_signal_group(self, owner: type) -> type[SignalGroup]:
def _get_signal_group(self, owner: type) -> type[GroupType]:
type_id = id(owner)
if type_id not in self._signal_groups:
self._signal_groups[type_id] = self._create_group(owner)
return self._signal_groups[type_id]

def _create_group(self, owner: type) -> type[SignalGroup]:
def _create_group(self, owner: type) -> type[GroupType]:
# Do not collect fields from owner class, copy the SignalGroup
if not self._collect_fields:
# Do not collect fields from owner class
Group = copy.deepcopy(self._signal_group_class)
Expand Down
Loading