Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions wandb/sdk/lib/fsm.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,19 @@ def __init__(
) -> None:
self._states = states
self._table = table
# Precompute and store type(s) for states with performance-critical isinstance checks
self._fsm_state_exit_types = tuple(
type(s) for s in states if isinstance(s, FsmStateExit)
)
self._fsm_state_stay_types = tuple(
type(s) for s in states if isinstance(s, FsmStateStay)
)
self._fsm_state_enter_with_ctx_types = tuple(
type(s) for s in states if isinstance(s, FsmStateEnterWithContext)
)
self._fsm_state_enter_types = tuple(
type(s) for s in states if isinstance(s, FsmStateEnter)
)
self._state_dict = {type(s): s for s in states}
self._state = self._state_dict[type(states[0])]

Expand All @@ -133,23 +146,29 @@ def _transition(
if action:
action(inputs)

# Use cached types to avoid frequent repeated isinstance work
state = self._state
state_type = type(state)
context = None
if isinstance(self._state, FsmStateExit):
context = self._state.on_exit(inputs)
if state_type in self._fsm_state_exit_types:
context = state.on_exit(inputs)

prev_state = type(self._state)
prev_state = state_type
if prev_state == new_state:
if isinstance(self._state, FsmStateStay):
self._state.on_stay(inputs)
if prev_state in self._fsm_state_stay_types:
state.on_stay(inputs)
else:
self._state = self._state_dict[new_state]
if context and isinstance(self._state, FsmStateEnterWithContext):
self._state.on_enter(inputs, context=context)
elif isinstance(self._state, FsmStateEnter):
self._state.on_enter(inputs)
new_state_obj = self._state_dict[new_state]
self._state = new_state_obj
if context and type(new_state_obj) in self._fsm_state_enter_with_ctx_types:
new_state_obj.on_enter(inputs, context=context)
elif type(new_state_obj) in self._fsm_state_enter_types:
new_state_obj.on_enter(inputs)

def _check_transitions(self, inputs: T_FsmInputs) -> None:
for entry in self._table[type(self._state)]:
state_type = type(self._state)
entries = self._table[state_type]
for entry in entries:
if entry.condition(inputs):
self._transition(inputs, entry.target_state, entry.action)
return
Expand Down