diff --git a/core/actors/_base_actor.py b/core/actors/_base_actor.py index ea8793a0..34b1b672 100644 --- a/core/actors/_base_actor.py +++ b/core/actors/_base_actor.py @@ -3,6 +3,7 @@ from typing import Union, get_args, get_origin from core.commands.base import Command +from core.events.base import Event from core.interfaces.abstract_actor import AbstractActor, Ask, Message from core.queries.base import Query from infrastructure.event_dispatcher.event_dispatcher import EventDispatcher @@ -70,19 +71,27 @@ def _unregister_events(self): self._mailbox.unregister(event, self.on_receive) def _discover_events(self): + allowed_types = (Event, Query, Command) sig = inspect.signature(self.on_receive) - params = list(sig.parameters.values()) + params = sig.parameters.values() - if len(params) < 1: + if not params: return [] - event_type = params[0].annotation + event_type = next(iter(params)).annotation - events = [] + events = ( + get_args(event_type) if get_origin(event_type) is Union else [event_type] + ) - if get_origin(event_type) is Union: - events = get_args(event_type) - else: - events = [event_type] + invalid_events = [ + event for event in events if not issubclass(event, allowed_types) + ] + + if invalid_events: + raise RuntimeError( + f"Disallowed events: {', '.join(e.__name__ for e in invalid_events)}. " + f"Must be subclasses of {allowed_types}." + ) return list(set(events))