@codeflash-ai codeflash-ai bot commented Oct 30, 2025

📄 371% (3.71x) speedup for FsmWithContext._check_transitions in wandb/sdk/lib/fsm.py

⏱️ Runtime: 321 microseconds → 68.2 microseconds (best of 67 runs)
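The headline figures can be cross-checked with a few lines of arithmetic. This is a minimal sketch that assumes the percentage is computed as (original − optimized) / optimized. That formula is an assumption, but it reproduces the reported 371%. The helper name `speedup_percent` is hypothetical.

```python
def speedup_percent(original_us: float, optimized_us: float) -> float:
    """Relative speedup as a percentage of the optimized runtime (assumed formula)."""
    return (original_us - optimized_us) / optimized_us * 100.0


original_us = 321.0   # reported original runtime
optimized_us = 68.2   # reported optimized runtime

percent = round(speedup_percent(original_us, optimized_us))
ratio = round(original_us / optimized_us, 2)

print(percent)  # 371
print(ratio)    # 4.71
```

Under this reading, "3.71x" is the relative gain (371% expressed as a multiple), while the optimized code runs about 4.71 times as fast as the original in absolute terms.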

📝 Explanation and details

The optimized code achieves a 371% speedup by replacing expensive isinstance() checks against protocol classes with fast type(obj) in tuple lookups and by reducing attribute access overhead.

Key Optimizations:

  1. Precomputed Type Tuples: During initialization, the code creates tuples of state types for each protocol (_fsm_state_exit_types, _fsm_state_stay_types, etc.). This converts runtime isinstance(obj, Protocol) checks into type(obj) in precomputed_tuple lookups, which are significantly faster.

  2. Attribute Access Reduction: Local variables (state, state_type) cache frequently used values, so self._state is looked up once and type(self._state) is called once per invocation rather than repeatedly inside tight loops.

  3. Table Lookup Optimization: In _check_transitions, the table lookup self._table[type(self._state)] is computed once and stored in entries, avoiding repeated dictionary lookups.
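The three optimizations above can be sketched on a toy state machine. This is an illustration under stated assumptions, not the wandb implementation: `TinyFsm`, `SupportsExit`, `SupportsStay`, `Idle`, and `Running` are hypothetical names, and the table holds plain `(condition, target)` tuples instead of entry objects.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SupportsExit(Protocol):  # hypothetical stand-in for an exit protocol
    def on_exit(self, inputs: Any) -> Any: ...


@runtime_checkable
class SupportsStay(Protocol):  # hypothetical stand-in for a stay protocol
    def on_stay(self, inputs: Any) -> None: ...


class Idle:
    def on_exit(self, inputs: Any) -> str:
        return "left idle"


class Running:
    def __init__(self) -> None:
        self.stays = 0

    def on_stay(self, inputs: Any) -> None:
        self.stays += 1


class TinyFsm:
    """Illustrative only; mirrors the idea, not the real class."""

    def __init__(self, states, table):
        self._table = table
        self._state_dict = {type(s): s for s in states}
        self._state = states[0]
        # Pay the isinstance() cost once per state, at construction time.
        self._exit_types = tuple(type(s) for s in states if isinstance(s, SupportsExit))
        self._stay_types = tuple(type(s) for s in states if isinstance(s, SupportsStay))

    def check_transitions(self, inputs):
        state = self._state               # one attribute lookup, cached locally
        state_type = type(state)          # one type() call, cached locally
        entries = self._table[state_type]  # one dictionary lookup
        for condition, target in entries:
            if condition(inputs):
                context = None
                if state_type in self._exit_types:    # replaces isinstance()
                    context = state.on_exit(inputs)
                if target is state_type:
                    if state_type in self._stay_types:  # replaces isinstance()
                        state.on_stay(inputs)
                else:
                    self._state = self._state_dict[target]
                return context
        return None


running = Running()
fsm = TinyFsm(
    [Idle(), running],
    {Idle: [(lambda x: x == "go", Running)], Running: [(lambda x: True, Running)]},
)
print(fsm.check_transitions("go"))    # left idle
print(fsm.check_transitions("tick"))  # None
print(running.stays)                  # 1
```

The protocol checks still run, but only once per registered state in `__init__`; the hot path is left with tuple membership tests on an already computed `state_type`.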

Why This Works:

  • isinstance() calls dominated the original runtime (59.5% + 13.3% + 22.1% = ~95% of _transition time)
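  • Type membership tests (type(x) in tuple) scan the tuple linearly, but for the handful of state types involved they are effectively constant time and much faster than isinstance() against protocol classes
  • Attribute access (self._state) goes through Python's attribute lookup machinery on the instance and its class, while local variables are read directly from the function's frame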
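The cost difference described above can be measured with `timeit`, and the same snippet shows the one semantic difference between the two checks. This is a hedged sketch with hypothetical names (`SupportsStay`, `Base`, `Derived`); timings depend on the machine and Python version, so no numbers are claimed here.

```python
import timeit
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SupportsStay(Protocol):  # hypothetical protocol for illustration
    def on_stay(self, inputs: Any) -> None: ...


class Base:
    def on_stay(self, inputs: Any) -> None:
        pass


class Derived(Base):
    pass


stay_types = (Base,)  # precomputed from Base only
obj = Derived()

# isinstance() against a runtime-checkable protocol is structural:
# any object with an on_stay attribute matches.
structural = isinstance(obj, SupportsStay)

# type(obj) in tuple is an exact-type match: Derived is not Base.
exact = type(obj) in stay_types

isinstance_s = timeit.timeit(lambda: isinstance(obj, SupportsStay), number=20_000)
lookup_s = timeit.timeit(lambda: type(obj) in stay_types, number=20_000)

print(f"structural={structural}, exact={exact}")
print(f"isinstance: {isinstance_s:.4f}s  type-in-tuple: {lookup_s:.4f}s")
```

The exact-match behavior does not change results when the tuples are built from the concrete types of the registered state instances, as the description says they are, because every state the machine can occupy was classified with its own type at initialization.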

Performance by Test Case:
The optimization excels particularly with high-transition workloads:

  • Simple transitions: 828-1036% faster
  • Complex state protocols: 114-638% faster
  • Large-scale scenarios with many transitions: 213-2170% faster
  • Edge cases with no transitions show minimal impact (roughly 1-8% slower in the runs below, likely due to setup overhead)

This optimization maintains identical behavior while dramatically improving performance for FSM-heavy workloads.

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 88 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
from typing import Any, Callable, Dict, Optional, Sequence, Type

# imports
import pytest
from wandb.sdk.lib.fsm import FsmWithContext


class FsmStateOutput:
    def output(self, inputs): pass

class FsmStateEnter:
    def on_enter(self, inputs): pass

class FsmStateEnterWithContext:
    def on_enter(self, inputs, context=None): pass

class FsmStateStay:
    def on_stay(self, inputs): pass

class FsmStateExit:
    def on_exit(self, inputs): pass

# FsmEntry for testing
class FsmEntry:
    def __init__(self, condition: Callable[[Any], bool], target_state: Type, action: Optional[Callable[[Any], None]] = None):
        self.condition = condition
        self.target_state = target_state
        self.action = action

# --- Helper classes for test states ---

class StateA(FsmStateEnter):
    def __init__(self):
        self.entered = False
    def on_enter(self, inputs):
        self.entered = True

class StateB(FsmStateStay):
    def __init__(self):
        self.stayed = False
    def on_stay(self, inputs):
        self.stayed = True

class StateC(FsmStateExit):
    def __init__(self):
        self.exited = False
    def on_exit(self, inputs):
        self.exited = True
        return "context_from_exit"

class StateD(FsmStateEnterWithContext):
    def __init__(self):
        self.entered_with_context = None
    def on_enter(self, inputs, context=None):
        self.entered_with_context = context

class StateE(FsmStateOutput):
    def __init__(self):
        self.output_called = False
    def output(self, inputs):
        self.output_called = True

# --- Test suite for _check_transitions ---

# 1. Basic Test Cases

def test_basic_transition_to_new_state():
    """Test normal transition from StateA to StateB with condition True."""
    state_a = StateA()
    state_b = StateB()
    table = {
        StateA: [FsmEntry(lambda x: x == "go", StateB, None)],
        StateB: []
    }
    fsm = FsmWithContext([state_a, state_b], table)
    fsm._check_transitions("go") # 21.9μs -> 1.93μs (1036% faster)

def test_basic_no_transition_when_condition_false():
    """Test that no transition occurs if condition is False."""
    state_a = StateA()
    state_b = StateB()
    table = {
        StateA: [FsmEntry(lambda x: x == "stop", StateB, None)],
        StateB: []
    }
    fsm = FsmWithContext([state_a, state_b], table)
    fsm._check_transitions("go") # 922ns -> 929ns (0.753% slower)

def test_basic_action_called_on_transition():
    """Test that the action is called during transition."""
    state_a = StateA()
    state_b = StateB()
    called = []
    def action(inputs): called.append(inputs)
    table = {
        StateA: [FsmEntry(lambda x: True, StateB, action)],
        StateB: []
    }
    fsm = FsmWithContext([state_a, state_b], table)
    fsm._check_transitions("anything") # 20.8μs -> 2.24μs (828% faster)

def test_basic_on_enter_called_on_new_state():
    """Test that on_enter is called when entering a new state."""
    state_a = StateA()
    state_b = StateB()
    table = {
        StateA: [FsmEntry(lambda x: True, StateB, None)],
        StateB: []
    }
    fsm = FsmWithContext([state_a, state_b], table)
    fsm._check_transitions("trigger") # 20.2μs -> 1.96μs (934% faster)

def test_basic_on_stay_called_when_staying():
    """Test that on_stay is called if staying in the same state and state supports on_stay."""
    state_b = StateB()
    table = {
        StateB: [FsmEntry(lambda x: True, StateB, None)]
    }
    fsm = FsmWithContext([state_b], table)
    fsm._check_transitions("trigger") # 15.8μs -> 1.93μs (719% faster)

# 2. Edge Test Cases

def test_edge_no_entries_in_table():
    """Test behavior when there are no entries for current state."""
    state_a = StateA()
    table = {StateA: []}
    fsm = FsmWithContext([state_a], table)
    fsm._check_transitions("anything") # 597ns -> 649ns (8.01% slower)

def test_edge_multiple_entries_first_true():
    """Test that only the first matching entry is used."""
    state_a = StateA()
    state_b = StateB()
    state_c = StateC()
    table = {
        StateA: [
            FsmEntry(lambda x: x == "foo", StateB, None),
            FsmEntry(lambda x: True, StateC, None)
        ],
        StateB: [],
        StateC: []
    }
    fsm = FsmWithContext([state_a, state_b, state_c], table)
    fsm._check_transitions("foo") # 20.7μs -> 1.95μs (964% faster)

def test_edge_multiple_entries_second_true():
    """Test that second entry is used if first condition is False."""
    state_a = StateA()
    state_b = StateB()
    state_c = StateC()
    table = {
        StateA: [
            FsmEntry(lambda x: x == "bar", StateB, None),
            FsmEntry(lambda x: x == "baz", StateC, None)
        ],
        StateB: [],
        StateC: []
    }
    fsm = FsmWithContext([state_a, state_b, state_c], table)
    fsm._check_transitions("baz") # 21.0μs -> 2.26μs (828% faster)

def test_edge_context_passed_from_exit_to_enter_with_context():
    """Test that context from on_exit is passed to on_enter of next state."""
    state_c = StateC()
    state_d = StateD()
    table = {
        StateC: [FsmEntry(lambda x: True, StateD, None)],
        StateD: []
    }
    fsm = FsmWithContext([state_c, state_d], table)
    fsm._check_transitions("trigger") # 5.53μs -> 2.58μs (114% faster)

def test_edge_on_enter_called_without_context():
    """Test that on_enter is called without context if context is None."""
    state_a = StateA()
    state_d = StateD()
    table = {
        StateA: [FsmEntry(lambda x: True, StateD, None)],
        StateD: []
    }
    fsm = FsmWithContext([state_a, state_d], table)
    fsm._check_transitions("trigger") # 16.0μs -> 2.17μs (638% faster)

def test_edge_state_with_output_protocol():
    """Test that output protocol is not called by _check_transitions."""
    state_e = StateE()
    table = {StateE: [FsmEntry(lambda x: False, StateE, None)]}
    fsm = FsmWithContext([state_e], table)
    fsm._check_transitions("input") # 857ns -> 901ns (4.88% slower)


def test_edge_action_raises_exception():
    """Test that exception in action propagates."""
    state_a = StateA()
    state_b = StateB()
    def action(inputs): raise ValueError("fail")
    table = {
        StateA: [FsmEntry(lambda x: True, StateB, action)],
        StateB: []
    }
    fsm = FsmWithContext([state_a, state_b], table)
    with pytest.raises(ValueError):
        fsm._check_transitions("trigger") # 2.13μs -> 1.86μs (14.2% faster)

def test_edge_table_missing_current_state_key():
    """Test behavior if current state type is not in table (should raise KeyError)."""
    state_a = StateA()
    table = {}
    fsm = FsmWithContext([state_a], table)
    with pytest.raises(KeyError):
        fsm._check_transitions("anything") # 983ns -> 944ns (4.13% faster)

# 3. Large Scale Test Cases

def test_large_scale_many_states_and_entries():
    """Test FSM with many states and entries for scalability."""
    class DummyState(FsmStateEnter):
        def __init__(self, idx):
            self.idx = idx
            self.entered = False
        def on_enter(self, inputs):
            self.entered = True

    num_states = 100
    states = [DummyState(i) for i in range(num_states)]
    # Intended: each state transitions to the next state if input == its index.
    # Note: every DummyState instance shares one type, so this comprehension
    # produces a single key whose value is the last entry built (idx == 99).
    table = {
        type(s): [FsmEntry(lambda x, idx=s.idx: x == idx, type(states[(s.idx+1)%num_states]), None)]
        for s in states
    }
    fsm = FsmWithContext(states, table)
    # Because the table holds only the idx == 99 entry, neither call below
    # satisfies the condition, so both measure the no-transition path.
    fsm._check_transitions(0) # 991ns -> 970ns (2.16% faster)
    fsm._check_transitions(1) # 470ns -> 493ns (4.67% slower)


def test_large_scale_first_entry_true_rest_false():
    """Test that only first matching entry is used among many entries."""
    class DummyState(FsmStateEnter):
        def __init__(self):
            self.entered = False
        def on_enter(self, inputs): self.entered = True
    state_a = DummyState()
    state_b = DummyState()
    entries = [FsmEntry(lambda x: True, DummyState, None)] + [FsmEntry(lambda x: False, DummyState, None) for _ in range(999)]
    table = {DummyState: entries}
    fsm = FsmWithContext([state_a, state_b], table)
    fsm._check_transitions("input") # 39.6μs -> 1.74μs (2170% faster)

def test_large_scale_many_transitions():
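    """Test FSM with a chain of transitions.

    Note: all ChainState instances share one type, so the table below collapses
    to a single key holding the idx == 49 entry; only the final loop iteration
    satisfies its condition.
    """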
    class ChainState(FsmStateEnter):
        def __init__(self, idx):
            self.idx = idx
            self.entered = False
        def on_enter(self, inputs): self.entered = True
    num_states = 50
    states = [ChainState(i) for i in range(num_states)]
    table = {
        type(states[i]): [FsmEntry(lambda x, idx=i: x == idx, type(states[(i+1)%num_states]), None)]
        for i in range(num_states)
    }
    fsm = FsmWithContext(states, table)
    for i in range(num_states):
        fsm._check_transitions(i) # 52.6μs -> 16.8μs (213% faster)

To edit these changes, git checkout codeflash/optimize-FsmWithContext._check_transitions-mhdq0pkk and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 October 30, 2025 17:51
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Oct 30, 2025