Skip to content

Commit

Permalink
remove debug prints and take some notes
Browse files Browse the repository at this point in the history
This commit was sponsored by Matt Campbell, Jason Walker, and my other
patrons.  If you want to join them, you can support my work at
https://patreon.com/creatorglyph.
  • Loading branch information
glyph committed Dec 10, 2023
1 parent b3c182c commit 479bc3e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 55 deletions.
2 changes: 0 additions & 2 deletions automat/_test/test_typical.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,7 @@ def get(self):
def persistent(self) -> None:
pass

print('setting FirstState.ephemeral to enter')
FirstState.ephemeral.enter(Ephemeral)
print("i set it")

C = builder.buildClass()

Expand Down
127 changes: 74 additions & 53 deletions automat/_typical.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ def _getCore(


def _getCoreAttribute(attr: str) -> ValueBuilder:
# TODO: automatically getting attributes from the core object rather than
# the input signature is probably just a bad idea, way too much magic. it
# exists because the "state just constructed" hook of __post_init__ (or
# __init__ sometimes I guess) is an awkward way of populating
# derived-but-cached attributes. But it would probably be best to just get
# rid of these sematics and see if there's some explicit / opt-in version
# of this we could add as an API later.
def _coreGetter(
syntheticSelf: _TypicalInstance[InputsProto, StateCore],
stateCore: object,
Expand All @@ -163,8 +170,6 @@ def _stateBuilder(
stateFactory: Callable[..., Any],
suppliers: list[tuple[str, ValueBuilder]] = [],
):
print("suppliers", stateFactory, suppliers)

def _(
syntheticSelf: _TypicalInstance[InputsProto, StateCore],
stateCore: object,
Expand All @@ -173,8 +178,6 @@ def _(
kwargs: Dict[str, object],
) -> object:
toPass: Dict[str, object] = {}
print("sig", str(inputSignature.parameters))
print("a, kw", args, kwargs)
toPass.update(
inputSignature.bind(object(), *args, **kwargs).arguments
) # here's where we get the transition-supplied arguments
Expand All @@ -187,12 +190,61 @@ def _(
)
for (extraParamName, extraParamFactory) in suppliers
}
print("toPass", toPass, "computedParams", computedParams)
toPass.update(computedParams)
return stateFactory(**toPass)

return _


def _valueSuppliers(
factorySignature: Signature,
transitionSignature: Signature,
stateFactories: Dict[str, Callable[..., UserStateType]],
stateCoreType: type[object],
inputProtocols: frozenset[ProtocolAtRuntime[object]],
) -> Iterable[tuple[str, ValueBuilder]]:

factoryNeeds = set(factorySignature.parameters)
transitionSupplies = set(transitionSignature.parameters)
notSuppliedParams = factoryNeeds - transitionSupplies

for maybeTypeMismatch in factoryNeeds & transitionSupplies:
if (
transitionSignature.parameters[maybeTypeMismatch].annotation
!= factorySignature.parameters[maybeTypeMismatch].annotation
):
if (
factorySignature.parameters[maybeTypeMismatch].default
== Parameter.empty
):
notSuppliedParams.add(maybeTypeMismatch)

for notSuppliedByTransitionName in notSuppliedParams:
# These are the parameters we will need to supply.
notSuppliedByTransition = factorySignature.parameters[
notSuppliedByTransitionName
]
parameterType = notSuppliedByTransition.annotation
if parameterType.__name__ in stateFactories:
yield (
(
notSuppliedByTransitionName,
_getOtherState(parameterType.__name__),
)
)
elif parameterType is stateCoreType:
yield ((notSuppliedByTransitionName, _getCore))
elif parameterType in inputProtocols:
yield ((notSuppliedByTransitionName, _getSynthSelf))
else:
yield (
(
notSuppliedByTransitionName,
_getCoreAttribute(notSuppliedByTransitionName),
)
)


def _buildStateBuilder(
stateCoreType: type[object],
stateFactory: Callable[..., Any],
Expand All @@ -207,6 +259,10 @@ def _buildStateBuilder(
@param transitionMethod: The method from the state-machine protocol, which
documents its public parameters.
"""
# TODO: benchmark the generated function, it's probably going to be pretty
# performance sensitive, and probably switch over to codegen a-la attrs or
# dataclassess since that will probably be faster.

# the transition signature is empty / no arguments for the initial state
# build
transitionSignature = (
Expand All @@ -221,47 +277,16 @@ def _buildStateBuilder(
# supply them in other ways (default values will not be respected,
# attributes won't be pulled from the state core, etc)

def _valueSuppliers() -> Iterable[tuple[str, ValueBuilder]]:
factoryNeeds = set(factorySignature.parameters)
transitionSupplies = set(transitionSignature.parameters)
print("for:", stateFactory)
print("upon:", transitionMethod)
print("needs:", factoryNeeds)
print("supplies:", transitionSupplies)
notSuppliedParams = factoryNeeds - transitionSupplies
print("not supplied:", stateCoreType, notSuppliedParams)
# bug: even if it *is* supplied by the transition, the type must match as well.
for maybeTypeMismatch in factoryNeeds & transitionSupplies:
if transitionSignature.parameters[maybeTypeMismatch].annotation != factorySignature.parameters[maybeTypeMismatch].annotation:
if factorySignature.parameters[maybeTypeMismatch].default == Parameter.empty:
notSuppliedParams.add(maybeTypeMismatch)
for notSuppliedByTransitionName in notSuppliedParams:
# These are the parameters we will need to supply.
notSuppliedByTransition = factorySignature.parameters[
notSuppliedByTransitionName
]
parameterType = notSuppliedByTransition.annotation
if parameterType.__name__ in stateFactories:
yield (
(
notSuppliedByTransitionName,
_getOtherState(parameterType.__name__),
)
)
elif parameterType is stateCoreType:
yield ((notSuppliedByTransitionName, _getCore))
elif parameterType in inputProtocols:
yield ((notSuppliedByTransitionName, _getSynthSelf))
else:
yield (
(
notSuppliedByTransitionName,
_getCoreAttribute(notSuppliedByTransitionName),
)
)

print("building", transitionMethod, transitionSignature)
return _stateBuilder(transitionSignature, stateFactory, list(_valueSuppliers()))
suppliers = list(
_valueSuppliers(
factorySignature,
transitionSignature,
stateFactories,
stateCoreType,
inputProtocols,
)
)
return _stateBuilder(transitionSignature, stateFactory, suppliers)


_baseMethods = set(dir(Protocol))
Expand All @@ -284,10 +309,6 @@ def method(self: _TypicalInstance[InputsProto, StateCore], *a, **kw) -> object:
oldStateObject = self._stateCluster[oldStateName]
[[outputMethodName], tracer] = self._transitioner.transition(inputMethodName)
newStateName = self._transitioner._state
print(f"transition! {oldStateObject}")
print(
f"{oldStateName}.{inputMethodName}() [.{outputMethodName}()] => {newStateName}"
)
# here we need to invoke the output method
if outputMethodName is None:
self._stateCluster[newStateName] = errorState()
Expand All @@ -300,9 +321,10 @@ def method(self: _TypicalInstance[InputsProto, StateCore], *a, **kw) -> object:
newBuilt = self._stateCluster[newStateName] = stateBuilder(
self, self._stateCore, self._stateCluster, a, kw
)
print(f"built new state: {newBuilt}")
result = realMethod(*a, **kw)
if newStateName != oldStateName and not oldStateObject.__persistState__: # type:ignore[attr-defined]
if (
newStateName != oldStateName and not oldStateObject.__persistState__
): # type:ignore[attr-defined]
del self._stateCluster[oldStateName]
return result

Expand Down Expand Up @@ -579,7 +601,6 @@ def buildClass(self) -> _TypicalClass[InputsProto, StateCore, P]:
stateClassName,
newStateFactory,
) in buildAfterFactories:
print("buildAfter", output, stateCoreType, stateClassName)
assert (
getattr(output, "__stateBuilder__", None) is None
), "duplicate state builder"
Expand Down

0 comments on commit 479bc3e

Please sign in to comment.