diff --git a/automat/_test/test_type_based.py b/automat/_test/test_type_based.py index 4b35733..117f18c 100644 --- a/automat/_test/test_type_based.py +++ b/automat/_test/test_type_based.py @@ -342,3 +342,25 @@ def test_buildLock(self) -> None: builder.state("hello") with self.assertRaises(AlreadyBuiltError): builder.build() + + def test_methodMembership(self) -> None: + """ + Input methods must be members of their protocol. + """ + builder = TypeMachineBuilder(TestProtocol, NoOpCore) + state = builder.state("test-state") + def stateful(proto: TestProtocol, core: NoOpCore) -> int: + return 4 + state2 = builder.state("state2", stateful) + def change(self: TestProtocol) -> None: + "fake copy" + def rogue(self: TestProtocol) -> int: + "not present" + return 3 + with self.assertRaises(ValueError): + state.upon(change) + with self.assertRaises(ValueError) as ve: + state2.upon(change) + print(ve.exception) + with self.assertRaises(ValueError): + state.upon(rogue) diff --git a/automat/_typified.py b/automat/_typified.py index 413e5f7..3078d5c 100644 --- a/automat/_typified.py +++ b/automat/_typified.py @@ -252,6 +252,7 @@ class TypifiedState(Generic[InputProtocol, Core]): def upon( self, input: Callable[Concatenate[InputProtocol, P], R] ) -> UponFromNo[InputProtocol, Core, P, R]: + self.builder._checkMembership(input) return UponFromNo(self, input) def _produce_outputs( @@ -292,6 +293,7 @@ def upon( UponFromData[InputProtocol, Core, P, R, Data] | UponFromNo[InputProtocol, Core, P, R] ): + self.builder._checkMembership(input) if nodata: return UponFromNo(self, input) else: @@ -590,3 +592,13 @@ def build(self) -> TypifiedMachine[InputProtocol, Core]: ) return TypifiedMachine(runtime_type, self._automaton) + + def _checkMembership(self, input: Callable[..., object]) -> None: + """ + Ensure that ``input`` is a valid member function of the input protocol, + not just a function that happens to take the right first argument. + """ + if (checked := getattr(self.protocol, input.__name__, None)) is not input: + raise ValueError( + f"{input.__qualname__} is not a member of {self.protocol.__module__}.{self.protocol.__name__}" + )