@codeflash-ai codeflash-ai bot commented Oct 21, 2025

📄 131% (1.31x) speedup for prompt_text in framework/py/flwr/cli/utils.py

⏱️ Runtime : 6.40 milliseconds → 2.77 milliseconds (best of 143 runs)

📝 Explanation and details

The optimization moves the expensive typer.style() calls outside the while True loop to avoid repeated computation. In the original code, every time the user enters invalid input, both the prompt styling and error message styling are recalculated. The line profiler shows that typer.style() calls consumed 31.5% and 54.8% of the total runtime respectively.

Key changes:

  • Pre-compute styled_prompt and styled_error once before the loop
  • Reference the pre-computed styled strings inside the loop
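The two key changes can be sketched as below. This is a minimal illustration under stated assumptions, not the actual flwr code: `style` is a hypothetical stand-in for `typer.style()`, the `ask` and `report` parameters are hypothetical hooks that replace `typer.prompt()` and printing so the sketch runs without a terminal, and the "Invalid entry" wording is assumed. Only the names `styled_prompt` and `styled_error` and the non-empty check come from this PR.

```python
from typing import Callable


def style(text: str, color_code: int, bold: bool = False) -> str:
    """Hypothetical stand-in for typer.style(): wrap text in ANSI codes."""
    prefix = f"\033[{color_code}m" + ("\033[1m" if bold else "")
    return f"{prefix}{text}\033[0m"


def prompt_text_sketch(
    text: str,
    ask: Callable[[str], str],
    predicate: Callable[[str], bool] = lambda _: True,
    report: Callable[[str], None] = print,
) -> str:
    """Keep asking until the reply is non-empty and satisfies the predicate."""
    # Style both strings once, before the loop, instead of on every iteration.
    styled_prompt = style(text, 35, bold=True)
    styled_error = style("Invalid entry", 31, bold=True)  # assumed wording

    while True:
        result = ask(styled_prompt)
        if predicate(result) and len(result) > 0:
            return result
        # Reuse the pre-computed error string on every rejection.
        report(styled_error)
```

Injecting `ask` plays the same role as the `monkeypatch` of `typer.prompt` in the generated tests: a scripted sequence of replies drives the loop.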

Why this is faster:
typer.style() performs string formatting and ANSI escape code generation internally, which involves multiple string operations and color calculations. By computing these styled strings once upfront, we eliminate this overhead on every loop iteration when users provide invalid input.

Test case performance:
This optimization is most beneficial for test cases with multiple invalid attempts before valid input, such as:

  • test_prompt_text_large_scale() - 999 invalid inputs before valid one
  • test_large_scale_many_attempts() - 99 invalid inputs before the valid one (100 attempts in total)
  • test_prompt_text_predicate_always_false() - Multiple rejections by predicate

The 131% speedup demonstrates significant savings when the validation loop runs multiple iterations, as styling overhead is eliminated from the hot path.
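To make the effect on the hot path concrete, the sketch below counts styling calls for one prompt session with and without hoisting. It is a simplified model, not the real function: `count_style_calls` and its inner `style` are hypothetical names, the predicate is reduced to the non-empty check, and it assumes (as described above) that the original styled the prompt on every iteration and the error on every rejection.

```python
from typing import List


def count_style_calls(answers: List[str], hoisted: bool) -> int:
    """Count calls to a stand-in style function for one prompt session."""
    calls = 0

    def style(text: str) -> str:
        # Stand-in for typer.style(); only the call count matters here.
        nonlocal calls
        calls += 1
        return f"\033[1m{text}\033[0m"

    pending = iter(answers)
    if hoisted:
        # Optimized variant: both strings are styled once up front.
        styled_prompt = style("prompt")
        styled_error = style("error")

    while True:
        prompt = styled_prompt if hoisted else style("prompt")
        result = next(pending)  # stands in for the user's reply to `prompt`
        if len(result) > 0:
            return calls
        error = styled_error if hoisted else style("error")
        # The real function would print `error` here.


# Same input shape as test_prompt_text_large_scale: 999 empty replies, then a valid one.
large_scale = [""] * 999 + ["final_valid"]
per_iteration_calls = count_style_calls(large_scale, hoisted=False)
hoisted_calls = count_style_calls(large_scale, hoisted=True)
```

In this model the per-iteration variant makes 1000 prompt-styling calls plus 999 error-styling calls, while the hoisted variant makes 2 regardless of how many attempts are needed. The flip side is that when the first reply is already valid, the hoisted variant styles an error string it never uses, so the gain comes entirely from sessions with rejected input.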

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 37 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |

🌀 Generated Regression Tests and Runtime
from typing import Callable, Optional, cast

# imports
import pytest  # used for our unit tests
# function to test
import typer
from cli.utils import prompt_text

# unit tests

# Use pytest's monkeypatch fixture to simulate user input for typer.prompt.

# -------------------------
# 1. Basic Test Cases
# -------------------------

def test_prompt_text_basic_valid(monkeypatch):
    """Test basic valid input with default predicate (always True)."""
    # Simulate user entering "hello"
    monkeypatch.setattr("typer.prompt", lambda *args, **kwargs: "hello")
    codeflash_output = prompt_text("Enter something:"); result = codeflash_output

def test_prompt_text_basic_valid_with_default(monkeypatch):
    """Test input with default value (user presses enter)."""
    monkeypatch.setattr("typer.prompt", lambda *args, **kwargs: kwargs.get("default", "default"))
    codeflash_output = prompt_text("Enter your name:", default="default_name"); result = codeflash_output

def test_prompt_text_basic_predicate(monkeypatch):
    """Test with a predicate that only accepts digits."""
    monkeypatch.setattr("typer.prompt", lambda *args, **kwargs: "12345")
    codeflash_output = prompt_text("Enter digits:", predicate=lambda s: s.isdigit()); result = codeflash_output

# -------------------------
# 2. Edge Test Cases
# -------------------------

def test_prompt_text_empty_input(monkeypatch):
    """Test that empty input is rejected and valid input is accepted."""
    # Simulate user entering "" (empty), then "valid"
    responses = ["", "valid"]
    def fake_prompt(*args, **kwargs):
        return responses.pop(0)
    monkeypatch.setattr("typer.prompt", fake_prompt)
    codeflash_output = prompt_text("Cannot be empty:"); result = codeflash_output

def test_prompt_text_predicate_always_false(monkeypatch):
    """Test that inputs are rejected until one finally satisfies the predicate."""
    # Simulate user entering "anything", then "something", then "ok"
    responses = ["anything", "something", "ok"]
    def fake_prompt(*args, **kwargs):
        return responses.pop(0)
    monkeypatch.setattr("typer.prompt", fake_prompt)
    # Predicate always returns False except for "ok"
    codeflash_output = prompt_text("Try again:", predicate=lambda s: s == "ok"); result = codeflash_output

def test_prompt_text_predicate_and_empty(monkeypatch):
    """Test that input is rejected if predicate returns True but input is empty."""
    # Simulate user entering "", then "good"
    responses = ["", "good"]
    def fake_prompt(*args, **kwargs):
        return responses.pop(0)
    monkeypatch.setattr("typer.prompt", fake_prompt)
    codeflash_output = prompt_text("Enter non-empty:", predicate=lambda s: True); result = codeflash_output

def test_prompt_text_default_is_empty(monkeypatch):
    """Test that default value is empty string and is rejected."""
    # Simulate user pressing enter (returns default="")
    responses = ["", "filled"]
    def fake_prompt(*args, **kwargs):
        # Returns default if provided, else ""
        return responses.pop(0) if responses else kwargs.get("default", "")
    monkeypatch.setattr("typer.prompt", fake_prompt)
    codeflash_output = prompt_text("Enter something:", default=""); result = codeflash_output

def test_prompt_text_predicate_type(monkeypatch):
    """Test predicate that checks for specific type (numeric string)."""
    responses = ["abc", "123"]
    def fake_prompt(*args, **kwargs):
        return responses.pop(0)
    monkeypatch.setattr("typer.prompt", fake_prompt)
    codeflash_output = prompt_text("Enter a number:", predicate=lambda s: s.isdigit()); result = codeflash_output

def test_prompt_text_unicode_input(monkeypatch):
    """Test Unicode input is accepted."""
    monkeypatch.setattr("typer.prompt", lambda *args, **kwargs: "你好世界")
    codeflash_output = prompt_text("Enter Unicode:"); result = codeflash_output

def test_prompt_text_whitespace_input(monkeypatch):
    """Test input with only whitespace is rejected."""
    responses = ["   ", "valid"]
    def fake_prompt(*args, **kwargs):
        return responses.pop(0)
    monkeypatch.setattr("typer.prompt", fake_prompt)
    codeflash_output = prompt_text("No whitespace:", predicate=lambda s: s.strip() != ""); result = codeflash_output

def test_prompt_text_long_string(monkeypatch):
    """Test input with a long string."""
    long_str = "a" * 500
    monkeypatch.setattr("typer.prompt", lambda *args, **kwargs: long_str)
    codeflash_output = prompt_text("Enter long string:"); result = codeflash_output

# -------------------------
# 3. Large Scale Test Cases
# -------------------------

def test_prompt_text_large_scale(monkeypatch):
    """Test with a large number of invalid inputs before a valid one (simulate 1000 tries)."""
    # 999 invalid inputs, then 1 valid
    responses = [""] * 999 + ["final_valid"]
    def fake_prompt(*args, **kwargs):
        return responses.pop(0)
    monkeypatch.setattr("typer.prompt", fake_prompt)
    codeflash_output = prompt_text("Enter non-empty:"); result = codeflash_output

def test_prompt_text_large_input(monkeypatch):
    """Test with a very large input string (1000 chars)."""
    large_input = "x" * 1000
    monkeypatch.setattr("typer.prompt", lambda *args, **kwargs: large_input)
    codeflash_output = prompt_text("Enter large input:"); result = codeflash_output

def test_prompt_text_large_predicate(monkeypatch):
    """Test predicate that only accepts strings longer than 500 characters."""
    responses = ["short"] * 10 + ["x" * 501]
    def fake_prompt(*args, **kwargs):
        return responses.pop(0)
    monkeypatch.setattr("typer.prompt", fake_prompt)
    codeflash_output = prompt_text("Enter long string:", predicate=lambda s: len(s) > 500); result = codeflash_output

# -------------------------
# 4. Additional Robustness Cases
# -------------------------


def test_prompt_text_predicate_exception(monkeypatch):
    """Test that predicate exception is not swallowed."""
    monkeypatch.setattr("typer.prompt", lambda *args, **kwargs: "test")
    def bad_predicate(s):
        raise ValueError("Bad predicate")
    with pytest.raises(ValueError):
        prompt_text("Enter:", predicate=bad_predicate)

def test_prompt_text_default_and_predicate(monkeypatch):
    """Test that default value is accepted if it passes predicate."""
    monkeypatch.setattr("typer.prompt", lambda *args, **kwargs: kwargs.get("default", "default"))
    codeflash_output = prompt_text("Enter:", predicate=lambda s: s.startswith("d"), default="default"); result = codeflash_output

def test_prompt_text_default_and_predicate_reject(monkeypatch):
    """Test that default value is rejected if it fails predicate."""
    responses = ["", "valid"]
    def fake_prompt(*args, **kwargs):
        # First call returns default (""), second call returns "valid"
        return responses.pop(0)
    monkeypatch.setattr("typer.prompt", fake_prompt)
    codeflash_output = prompt_text("Enter:", predicate=lambda s: s == "valid", default=""); result = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from typing import Callable, Optional, cast

# imports
import pytest  # used for our unit tests
# function to test
import typer
from cli.utils import prompt_text

# unit tests

# Helper to monkeypatch typer.prompt for simulating user input
@pytest.fixture
def patch_prompt(monkeypatch):
    def _patch_prompt(return_values):
        # return_values: list of values to return on each call to typer.prompt
        iterator = iter(return_values)
        monkeypatch.setattr(typer, "prompt", lambda *args, **kwargs: next(iterator))
    return _patch_prompt

# Basic Test Cases

def test_basic_valid_input_first_try(patch_prompt):
    # User enters valid input on first try
    patch_prompt(["hello"])
    codeflash_output = prompt_text("Enter something:"); result = codeflash_output

def test_basic_valid_input_with_predicate(patch_prompt):
    # Predicate: only accept input that is all digits
    patch_prompt(["abc", "123"])
    codeflash_output = prompt_text("Enter digits:", predicate=lambda s: s.isdigit()); result = codeflash_output


def test_basic_default_value_rejected_by_predicate(patch_prompt):
    # Default is "Flower", but predicate only allows "Rose"
    patch_prompt(["", "Rose"])
    codeflash_output = prompt_text("Enter flower name:", predicate=lambda s: s == "Rose", default="Flower"); result = codeflash_output

def test_basic_empty_string_not_accepted(patch_prompt):
    # User enters empty string, should not be accepted
    patch_prompt(["", "ok"])
    codeflash_output = prompt_text("Type something:"); result = codeflash_output

def test_basic_predicate_false_then_true(patch_prompt):
    # Predicate rejects first input, accepts second
    patch_prompt(["bad", "good"])
    codeflash_output = prompt_text("Type 'good':", predicate=lambda s: s == "good"); result = codeflash_output

# Edge Test Cases

def test_edge_whitespace_only_input(patch_prompt):
    # Input is whitespace only, should not be accepted
    patch_prompt(["   ", "valid"])
    codeflash_output = prompt_text("Enter text:", predicate=lambda s: not s.isspace()); result = codeflash_output

def test_edge_long_string_input(patch_prompt):
    # Input is a long string
    long_str = "a" * 1000
    patch_prompt([long_str])
    codeflash_output = prompt_text("Enter long string:"); result = codeflash_output

def test_edge_predicate_always_false(patch_prompt):
    # Predicate rejects every input except "valid", so the loop keeps prompting
    patch_prompt(["foo", "bar", "baz", "qux", "valid"])
    # Only the fifth reply, "valid", is accepted
    codeflash_output = prompt_text("Type 'valid':", predicate=lambda s: s == "valid"); result = codeflash_output

def test_edge_predicate_allows_empty_but_len_check_blocks(patch_prompt):
    # Predicate allows empty, but len(result) > 0 blocks it
    patch_prompt(["", "nonempty"])
    codeflash_output = prompt_text("Enter nonempty:", predicate=lambda s: True); result = codeflash_output

def test_edge_special_characters_input(patch_prompt):
    # Input contains special characters
    patch_prompt(["@#$%^&*()"])
    codeflash_output = prompt_text("Enter special chars:"); result = codeflash_output

def test_edge_unicode_input(patch_prompt):
    # Input is unicode
    patch_prompt(["你好"])
    codeflash_output = prompt_text("Enter unicode:"); result = codeflash_output

def test_edge_default_is_none_and_user_enters_empty(patch_prompt):
    # Default is None, user enters empty string, should not be accepted
    patch_prompt(["", "something"])
    codeflash_output = prompt_text("Prompt:", default=None); result = codeflash_output

def test_edge_predicate_with_side_effect(patch_prompt):
    # Predicate that counts calls
    state = {"count": 0}
    def pred(s):
        state["count"] += 1
        return s == "pass"
    patch_prompt(["fail", "pass"])
    codeflash_output = prompt_text("Enter 'pass':", predicate=pred); result = codeflash_output

# Large Scale Test Cases

def test_large_scale_many_attempts(patch_prompt):
    # Simulate 99 invalid attempts followed by one valid input (100 attempts in total)
    invalids = ["x"] * 99
    patch_prompt(invalids + ["valid"])
    codeflash_output = prompt_text("Type 'valid':", predicate=lambda s: s == "valid"); result = codeflash_output


def test_large_scale_large_input(patch_prompt):
    # User enters a string of length 999
    large_input = "y" * 999
    patch_prompt([large_input])
    codeflash_output = prompt_text("Enter large input:"); result = codeflash_output

def test_large_scale_predicate_complex(patch_prompt):
    # Predicate checks for palindrome and length > 500
    palindrome = "a" * 501 + "b" + "a" * 501
    patch_prompt(["abc", "notapalindrome", palindrome])
    def is_large_palindrome(s):
        return s == s[::-1] and len(s) > 500
    codeflash_output = prompt_text("Enter large palindrome:", predicate=is_large_palindrome); result = codeflash_output

def test_large_scale_predicate_rejects_all_but_last(patch_prompt):
    # 999 invalid, 1 valid
    patch_prompt(["bad"] * 999 + ["good"])
    codeflash_output = prompt_text("Type 'good':", predicate=lambda s: s == "good"); result = codeflash_output

# Additional edge: ensure input is rejected when the predicate disallows it, even if a default is set

def test_default_rejected_if_predicate_disallows(patch_prompt):
    patch_prompt(["", "notdefault"])
    codeflash_output = prompt_text("Prompt:", predicate=lambda s: s == "notdefault", default="default"); result = codeflash_output

# Additional edge: ensure prompt_text returns str type
def test_return_type_is_str(patch_prompt):
    patch_prompt(["abc"])
    codeflash_output = prompt_text("Prompt:"); result = codeflash_output

# Additional edge: ensure prompt_text does not accept None as input
def test_none_input_not_accepted(patch_prompt):
    patch_prompt([None, "nonone"])
    codeflash_output = prompt_text("Prompt:", predicate=lambda s: s is not None); result = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from cli.utils import prompt_text

To edit these changes git checkout codeflash/optimize-prompt_text-mh17i9dq and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 October 21, 2025 23:39
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Oct 21, 2025