diff --git a/automat/_test/test_typical.py b/automat/_test/test_typical.py index 49c7749..b252084 100644 --- a/automat/_test/test_typical.py +++ b/automat/_test/test_typical.py @@ -312,9 +312,7 @@ def get(self): def persistent(self) -> None: pass -print('setting FirstState.ephemeral to enter') FirstState.ephemeral.enter(Ephemeral) -print("i set it") C = builder.buildClass() diff --git a/automat/_typical.py b/automat/_typical.py index d8fdb21..95e5b67 100644 --- a/automat/_typical.py +++ b/automat/_typical.py @@ -140,6 +140,13 @@ def _getCore( def _getCoreAttribute(attr: str) -> ValueBuilder: + # TODO: automatically getting attributes from the core object rather than + # the input signature is probably just a bad idea, way too much magic. it + # exists because the "state just constructed" hook of __post_init__ (or + # __init__ sometimes I guess) is an awkward way of populating + # derived-but-cached attributes. But it would probably be best to just get + # rid of these sematics and see if there's some explicit / opt-in version + # of this we could add as an API later. def _coreGetter( syntheticSelf: _TypicalInstance[InputsProto, StateCore], stateCore: object, @@ -163,8 +170,6 @@ def _stateBuilder( stateFactory: Callable[..., Any], suppliers: list[tuple[str, ValueBuilder]] = [], ): - print("suppliers", stateFactory, suppliers) - def _( syntheticSelf: _TypicalInstance[InputsProto, StateCore], stateCore: object, @@ -173,8 +178,6 @@ def _( kwargs: Dict[str, object], ) -> object: toPass: Dict[str, object] = {} - print("sig", str(inputSignature.parameters)) - print("a, kw", args, kwargs) toPass.update( inputSignature.bind(object(), *args, **kwargs).arguments ) # here's where we get the transition-supplied arguments @@ -187,12 +190,61 @@ def _( ) for (extraParamName, extraParamFactory) in suppliers } - print("toPass", toPass, "computedParams", computedParams) toPass.update(computedParams) return stateFactory(**toPass) + return _ +def _valueSuppliers( + factorySignature: Signature, + transitionSignature: Signature, + stateFactories: Dict[str, Callable[..., UserStateType]], + stateCoreType: type[object], + inputProtocols: frozenset[ProtocolAtRuntime[object]], +) -> Iterable[tuple[str, ValueBuilder]]: + + factoryNeeds = set(factorySignature.parameters) + transitionSupplies = set(transitionSignature.parameters) + notSuppliedParams = factoryNeeds - transitionSupplies + + for maybeTypeMismatch in factoryNeeds & transitionSupplies: + if ( + transitionSignature.parameters[maybeTypeMismatch].annotation + != factorySignature.parameters[maybeTypeMismatch].annotation + ): + if ( + factorySignature.parameters[maybeTypeMismatch].default + == Parameter.empty + ): + notSuppliedParams.add(maybeTypeMismatch) + + for notSuppliedByTransitionName in notSuppliedParams: + # These are the parameters we will need to supply. + notSuppliedByTransition = factorySignature.parameters[ + notSuppliedByTransitionName + ] + parameterType = notSuppliedByTransition.annotation + if parameterType.__name__ in stateFactories: + yield ( + ( + notSuppliedByTransitionName, + _getOtherState(parameterType.__name__), + ) + ) + elif parameterType is stateCoreType: + yield ((notSuppliedByTransitionName, _getCore)) + elif parameterType in inputProtocols: + yield ((notSuppliedByTransitionName, _getSynthSelf)) + else: + yield ( + ( + notSuppliedByTransitionName, + _getCoreAttribute(notSuppliedByTransitionName), + ) + ) + + def _buildStateBuilder( stateCoreType: type[object], stateFactory: Callable[..., Any], @@ -207,6 +259,10 @@ def _buildStateBuilder( @param transitionMethod: The method from the state-machine protocol, which documents its public parameters. """ + # TODO: benchmark the generated function, it's probably going to be pretty + # performance sensitive, and probably switch over to codegen a-la attrs or + # dataclassess since that will probably be faster. + # the transition signature is empty / no arguments for the initial state # build transitionSignature = ( @@ -221,47 +277,16 @@ def _buildStateBuilder( # supply them in other ways (default values will not be respected, # attributes won't be pulled from the state core, etc) - def _valueSuppliers() -> Iterable[tuple[str, ValueBuilder]]: - factoryNeeds = set(factorySignature.parameters) - transitionSupplies = set(transitionSignature.parameters) - print("for:", stateFactory) - print("upon:", transitionMethod) - print("needs:", factoryNeeds) - print("supplies:", transitionSupplies) - notSuppliedParams = factoryNeeds - transitionSupplies - print("not supplied:", stateCoreType, notSuppliedParams) - # bug: even if it *is* supplied by the transition, the type must match as well. - for maybeTypeMismatch in factoryNeeds & transitionSupplies: - if transitionSignature.parameters[maybeTypeMismatch].annotation != factorySignature.parameters[maybeTypeMismatch].annotation: - if factorySignature.parameters[maybeTypeMismatch].default == Parameter.empty: - notSuppliedParams.add(maybeTypeMismatch) - for notSuppliedByTransitionName in notSuppliedParams: - # These are the parameters we will need to supply. - notSuppliedByTransition = factorySignature.parameters[ - notSuppliedByTransitionName - ] - parameterType = notSuppliedByTransition.annotation - if parameterType.__name__ in stateFactories: - yield ( - ( - notSuppliedByTransitionName, - _getOtherState(parameterType.__name__), - ) - ) - elif parameterType is stateCoreType: - yield ((notSuppliedByTransitionName, _getCore)) - elif parameterType in inputProtocols: - yield ((notSuppliedByTransitionName, _getSynthSelf)) - else: - yield ( - ( - notSuppliedByTransitionName, - _getCoreAttribute(notSuppliedByTransitionName), - ) - ) - - print("building", transitionMethod, transitionSignature) - return _stateBuilder(transitionSignature, stateFactory, list(_valueSuppliers())) + suppliers = list( + _valueSuppliers( + factorySignature, + transitionSignature, + stateFactories, + stateCoreType, + inputProtocols, + ) + ) + return _stateBuilder(transitionSignature, stateFactory, suppliers) _baseMethods = set(dir(Protocol)) @@ -284,10 +309,6 @@ def method(self: _TypicalInstance[InputsProto, StateCore], *a, **kw) -> object: oldStateObject = self._stateCluster[oldStateName] [[outputMethodName], tracer] = self._transitioner.transition(inputMethodName) newStateName = self._transitioner._state - print(f"transition! {oldStateObject}") - print( - f"{oldStateName}.{inputMethodName}() [.{outputMethodName}()] => {newStateName}" - ) # here we need to invoke the output method if outputMethodName is None: self._stateCluster[newStateName] = errorState() @@ -300,9 +321,10 @@ def method(self: _TypicalInstance[InputsProto, StateCore], *a, **kw) -> object: newBuilt = self._stateCluster[newStateName] = stateBuilder( self, self._stateCore, self._stateCluster, a, kw ) - print(f"built new state: {newBuilt}") result = realMethod(*a, **kw) - if newStateName != oldStateName and not oldStateObject.__persistState__: # type:ignore[attr-defined] + if ( + newStateName != oldStateName and not oldStateObject.__persistState__ + ): # type:ignore[attr-defined] del self._stateCluster[oldStateName] return result @@ -579,7 +601,6 @@ def buildClass(self) -> _TypicalClass[InputsProto, StateCore, P]: stateClassName, newStateFactory, ) in buildAfterFactories: - print("buildAfter", output, stateCoreType, stateClassName) assert ( getattr(output, "__stateBuilder__", None) is None ), "duplicate state builder"