@codeflash-ai codeflash-ai bot commented Oct 1, 2025

📄 42% (0.42x) speedup for KIEPredictor.get_text in doctr/models/kie_predictor/pytorch.py

⏱️ Runtime : 438 microseconds → 308 microseconds (best of 93 runs)

📝 Explanation and details

The optimization replaces an inefficient loop built on repeated list concatenations with a single flattened list comprehension.

Key optimization:

  • Original: Used text += [item[0] for item in value] inside a loop, which creates a new list comprehension on each iteration and then concatenates it to the existing text list
  • Optimized: Uses a single flattened list comprehension [item[0] for value in text_pred.values() for item in value] that builds the entire result list in one pass

Why this is faster:

  • Each iteration of the original loop materializes a temporary list from the inner comprehension and then extends text with it via +=, paying an allocation and an extend call per key
  • With many dictionary keys, this per-key overhead dominates the runtime
  • The flattened comprehension builds the result list in a single pass, with no intermediate lists
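The two patterns can be timed head to head. A minimal micro-benchmark sketch, using synthetic data shaped like the report's many-keys test case (the dictionary shape is an assumption for illustration):

```python
import timeit

# Synthetic text_pred-shaped data: many keys, one (text, confidence) pair each
text_pred = {f"field{i}": [(f"text{i}", 1.0)] for i in range(1000)}

def concat_loop(pred):
    # Original pattern: per-key temporary list plus repeated +=
    text = []
    for value in pred.values():
        text += [item[0] for item in value]
    return text

def flat_comprehension(pred):
    # Optimized pattern: one pass, no intermediate lists
    return [item[0] for value in pred.values() for item in value]

# Both patterns produce the same flattened list
assert concat_loop(text_pred) == flat_comprehension(text_pred)

loop_time = timeit.timeit(lambda: concat_loop(text_pred), number=1000)
flat_time = timeit.timeit(lambda: flat_comprehension(text_pred), number=1000)
print(f"loop: {loop_time:.4f}s  flat: {flat_time:.4f}s")
```

Absolute numbers vary by machine, but the flattened version should consistently come out ahead on many-key inputs.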

Performance characteristics from tests:

  • Small inputs (1-5 items): 8-26% faster
  • Large inputs with many keys: 40-87% faster (e.g., test_large_many_keys_single_item_each shows 86% speedup)
  • Single key with many items: 13% faster
  • Mixed scenarios with empty lists: 28-83% faster

The optimization is most effective when there are many dictionary keys, as it eliminates the per-key temporary lists and repeated concatenation calls.
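Putting the before/after together: the exact body of `get_text` is not shown in this report, but from the quoted snippets and the tests' expected outputs (e.g. `"Hello World"` for two fields) it evidently joins the extracted strings with spaces. A sketch under that assumption:

```python
from typing import Any

def get_text_original(text_pred: dict[str, list[tuple[str, Any]]]) -> str:
    # Before: repeated += concatenation, one temporary list per key
    text: list[str] = []
    for value in text_pred.values():
        text += [item[0] for item in value]
    return " ".join(text)

def get_text_optimized(text_pred: dict[str, list[tuple[str, Any]]]) -> str:
    # After: a single flattened comprehension, built in one pass
    return " ".join([item[0] for value in text_pred.values() for item in value])

pred = {"field1": [("Hello", 0.99)], "field2": [("World", 0.98)]}
assert get_text_original(pred) == get_text_optimized(pred) == "Hello World"
```

Both versions traverse values in dict insertion order, so the output string is identical; only the intermediate-list churn differs.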

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 61 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
from typing import Any

# imports
import pytest  # used for our unit tests
from doctr.models.builder import KIEDocumentBuilder
from doctr.models.classification.predictor import OrientationPredictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.kie_predictor.base import _KIEPredictor
from doctr.models.kie_predictor.pytorch import KIEPredictor
from doctr.models.predictor.base import _OCRPredictor
from doctr.models.recognition.predictor import RecognitionPredictor
from torch import nn

# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

class _KIEPredictor(_OCRPredictor):
    """Implements an object able to localize and identify text elements in a set of documents

    Args:
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
            without rotated textual elements.
        straighten_pages: if True, estimates the page general orientation based on the median line orientation.
            Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
            accordingly. Doing so will improve performances for documents with page-uniform rotations.
        preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
        symmetric_pad: if True and preserve_aspect_ratio is True, pads the image symmetrically.
        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        kwargs: keyword args of `DocumentBuilder`
    """

    crop_orientation_predictor: OrientationPredictor | None
    page_orientation_predictor: OrientationPredictor | None

    def __init__(
        self,
        assume_straight_pages: bool = True,
        straighten_pages: bool = False,
        preserve_aspect_ratio: bool = True,
        symmetric_pad: bool = True,
        detect_orientation: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            assume_straight_pages,
            straighten_pages,
            preserve_aspect_ratio,
            symmetric_pad,
            detect_orientation,
            **kwargs,
        )

        # Remove the following arguments from kwargs after initialization of the parent class
        kwargs.pop("disable_page_orientation", None)
        kwargs.pop("disable_crop_orientation", None)

        self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)


# unit tests

# Use the static method directly for testing
get_text = KIEPredictor.get_text

# --------------------------
# 1. BASIC TEST CASES
# --------------------------

def test_basic_single_key_single_item():
    # One key, one value, single item
    pred = {'field1': [('hello', 0.99)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.45μs -> 1.26μs (15.0% faster)

def test_basic_single_key_multiple_items():
    # One key, multiple values
    pred = {'field1': [('hello', 0.99), ('world', 0.98)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.51μs -> 1.34μs (12.7% faster)

def test_basic_multiple_keys_single_item_each():
    # Multiple keys, single item each
    pred = {'field1': [('foo', 1.0)], 'field2': [('bar', 0.9)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.63μs -> 1.43μs (14.4% faster)

def test_basic_multiple_keys_multiple_items():
    # Multiple keys, multiple items each
    pred = {
        'field1': [('a', 0.9), ('b', 0.8)],
        'field2': [('c', 0.7), ('d', 0.6)],
    }
    codeflash_output = get_text(pred); result = codeflash_output # 1.86μs -> 1.47μs (26.3% faster)

def test_basic_empty_dict():
    # No keys at all
    pred = {}
    codeflash_output = get_text(pred); result = codeflash_output # 766ns -> 844ns (9.24% slower)

def test_basic_key_with_empty_list():
    # Key with empty list
    pred = {'field1': []}
    codeflash_output = get_text(pred); result = codeflash_output # 1.16μs -> 980ns (18.6% faster)

def test_basic_multiple_keys_some_empty():
    # Some keys have empty lists
    pred = {
        'field1': [('a', 1.0)],
        'field2': [],
        'field3': [('b', 1.0)],
    }
    codeflash_output = get_text(pred); result = codeflash_output # 1.79μs -> 1.42μs (26.5% faster)

# --------------------------
# 2. EDGE TEST CASES
# --------------------------

def test_edge_text_with_spaces():
    # Texts with spaces
    pred = {'field1': [('hello world', 0.9)], 'field2': [('foo bar', 0.8)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.54μs -> 1.31μs (17.8% faster)

def test_edge_text_with_empty_strings():
    # Texts with empty string
    pred = {'field1': [('', 1.0), ('nonempty', 0.9)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.45μs -> 1.29μs (12.1% faster)

def test_edge_text_all_empty_strings():
    # All items are empty strings
    pred = {'field1': [('', 1.0), ('', 0.8)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.40μs -> 1.23μs (13.3% faster)

def test_edge_text_with_special_characters():
    # Special characters in text
    pred = {'field1': [('!@#', 0.7), ('\n\t', 0.6)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.37μs -> 1.23μs (11.6% faster)

def test_edge_text_with_numbers_and_unicode():
    # Numbers and unicode characters
    pred = {'field1': [('123', 1.0), ('你好', 0.99), ('café', 0.98)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.86μs -> 1.65μs (12.5% faster)

def test_edge_text_with_none_as_key():
    # None as a key (should work, as dict keys can be anything hashable)
    pred = {None: [('foo', 1.0)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.27μs -> 1.16μs (9.31% faster)

def test_edge_text_with_tuple_as_key():
    # Tuple as a key
    pred = {('tuple', 1): [('bar', 0.8)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.25μs -> 1.15μs (8.07% faster)

def test_edge_item_tuple_with_extra_elements():
    # Item tuples with more than 2 elements: only first element should be used
    pred = {'field1': [('foo', 1.0, 'extra'), ('bar', 0.9, 123)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.46μs -> 1.31μs (11.7% faster)

def test_edge_item_tuple_with_single_element():
    # Item tuples with only one element: should not fail, just use first element
    pred = {'field1': [('solo',)]}
    codeflash_output = get_text(pred); result = codeflash_output # 1.32μs -> 1.12μs (17.4% faster)



def test_edge_non_dict_input():
    # Input is not a dict (should raise AttributeError)
    pred = [('foo', 1.0)]
    with pytest.raises(AttributeError):
        get_text(pred) # 1.90μs -> 1.79μs (6.02% faster)

def test_edge_dict_with_non_list_value():
    # Value is not a list (should raise TypeError)
    pred = {'field1': ('foo', 1.0)}
    with pytest.raises(TypeError):
        get_text(pred) # 2.45μs -> 2.58μs (5.00% slower)

def test_edge_dict_with_none_value():
    # Value is None (should raise TypeError)
    pred = {'field1': None}
    with pytest.raises(TypeError):
        get_text(pred) # 1.64μs -> 1.86μs (11.9% slower)

def test_edge_dict_with_non_iterable_value():
    # Value is int (should raise TypeError)
    pred = {'field1': 123}
    with pytest.raises(TypeError):
        get_text(pred) # 1.63μs -> 1.89μs (13.9% slower)

def test_edge_dict_with_nested_empty_lists():
    # Nested empty lists
    pred = {'field1': [], 'field2': []}
    codeflash_output = get_text(pred); result = codeflash_output # 1.36μs -> 1.06μs (28.8% faster)

def test_edge_dict_with_nested_lists():
    # Value is a nested list (should raise TypeError)
    pred = {'field1': [[('foo', 1.0)]]}
    with pytest.raises(TypeError):
        get_text(pred) # 4.24μs -> 4.18μs (1.34% faster)

# --------------------------
# 3. LARGE SCALE TEST CASES
# --------------------------

def test_large_single_key_many_items():
    # Large number of items in a single key
    pred = {'field1': [(f'text{i}', 1.0) for i in range(1000)]}
    codeflash_output = get_text(pred); result = codeflash_output # 24.7μs -> 21.8μs (13.1% faster)
    expected = ' '.join(f'text{i}' for i in range(1000))

def test_large_many_keys_single_item_each():
    # Many keys, each with a single item
    pred = {f'field{i}': [(f'text{i}', 1.0)] for i in range(1000)}
    codeflash_output = get_text(pred); result = codeflash_output # 88.4μs -> 47.5μs (86.1% faster)
    expected = ' '.join(f'text{i}' for i in range(1000))

def test_large_many_keys_many_items_each():
    # Many keys, each with multiple items
    pred = {f'field{i}': [(f'text{i}_{j}', 1.0) for j in range(5)] for i in range(200)}
    codeflash_output = get_text(pred); result = codeflash_output # 40.7μs -> 29.0μs (40.2% faster)
    expected = ' '.join(f'text{i}_{j}' for i in range(200) for j in range(5))

def test_large_some_keys_empty_some_large():
    # Some keys empty, some with many items
    pred = {f'field{i}': [] for i in range(500)}
    pred.update({f'bigfield{i}': [(f'word{i}_{j}', 1.0) for j in range(10)] for i in range(10)})
    codeflash_output = get_text(pred); result = codeflash_output # 32.2μs -> 17.6μs (82.7% faster)
    expected = ' '.join(f'word{i}_{j}' for i in range(10) for j in range(10))

def test_large_key_with_long_strings():
    # Key with very long strings
    long_str = 'x' * 1000
    pred = {'field1': [(long_str, 0.99) for _ in range(10)]}
    codeflash_output = get_text(pred); result = codeflash_output # 2.46μs -> 2.03μs (20.8% faster)
    expected = ' '.join([long_str] * 10)

def test_large_all_empty_strings():
    # Large number of empty strings
    pred = {'field1': [('', 1.0) for _ in range(1000)]}
    codeflash_output = get_text(pred); result = codeflash_output # 26.9μs -> 25.2μs (6.77% faster)
    expected = ' '.join([''] * 1000)

def test_large_mixed_types():
    # Large dict with some malformed entries
    pred = {f'field{i}': [(f'text{i}', 1.0)] for i in range(995)}
    pred['badfield1'] = 'notalist'
    pred['badfield2'] = None
    with pytest.raises(TypeError):
        get_text(pred) # 82.1μs -> 43.7μs (87.7% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from typing import Any

# imports
import pytest  # used for our unit tests
from doctr.models.builder import KIEDocumentBuilder
from doctr.models.classification.predictor import OrientationPredictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.kie_predictor.base import _KIEPredictor
from doctr.models.kie_predictor.pytorch import KIEPredictor
from doctr.models.predictor.base import _OCRPredictor
from doctr.models.recognition.predictor import RecognitionPredictor
from torch import nn

# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

class _KIEPredictor(_OCRPredictor):
    """Implements an object able to localize and identify text elements in a set of documents

    Args:
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
            without rotated textual elements.
        straighten_pages: if True, estimates the page general orientation based on the median line orientation.
            Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
            accordingly. Doing so will improve performances for documents with page-uniform rotations.
        preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
        symmetric_pad: if True and preserve_aspect_ratio is True, pads the image symmetrically.
        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        kwargs: keyword args of `DocumentBuilder`
    """

    crop_orientation_predictor: OrientationPredictor | None
    page_orientation_predictor: OrientationPredictor | None

    def __init__(
        self,
        assume_straight_pages: bool = True,
        straighten_pages: bool = False,
        preserve_aspect_ratio: bool = True,
        symmetric_pad: bool = True,
        detect_orientation: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            assume_straight_pages,
            straighten_pages,
            preserve_aspect_ratio,
            symmetric_pad,
            detect_orientation,
            **kwargs,
        )

        # Remove the following arguments from kwargs after initialization of the parent class
        kwargs.pop("disable_page_orientation", None)
        kwargs.pop("disable_crop_orientation", None)

        self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)

# unit tests

@pytest.mark.parametrize(
    "input_dict,expected",
    [
        # Basic: single key, single value
        ({"field1": [("Hello", 0.99)]}, "Hello"),
        # Basic: single key, multiple values
        ({"field1": [("Hello", 0.99), ("World", 0.98)]}, "Hello World"),
        # Basic: multiple keys, single value each
        ({"field1": [("Hello", 0.99)], "field2": [("World", 0.98)]}, "Hello World"),
        # Basic: multiple keys, multiple values each
        ({"field1": [("Hello", 0.99), ("there", 0.97)], "field2": [("World", 0.98), ("!", 0.96)]}, "Hello there World !"),
        # Basic: values with empty string
        ({"field1": [("", 0.99)], "field2": [("World", 0.98)]}, " World"),
        # Basic: values with spaces
        ({"field1": [("Hello", 0.99)], "field2": [(" ", 0.98)], "field3": [("World", 0.97)]}, "Hello   World"),
        # Basic: values with punctuation
        ({"field1": [("Hello,", 0.99)], "field2": [("World!", 0.98)]}, "Hello, World!"),
        # Basic: values with numbers
        ({"field1": [("123", 0.99)], "field2": [("456", 0.98)]}, "123 456"),
        # Basic: values with special characters
        ({"field1": [("@home", 0.99)], "field2": [("#work", 0.98)]}, "@home #work"),
    ]
)
def test_get_text_basic(input_dict, expected):
    # Test basic scenarios for get_text
    codeflash_output = KIEPredictor.get_text(input_dict) # 15.3μs -> 12.9μs (18.5% faster)

@pytest.mark.parametrize(
    "input_dict,expected",
    [
        # Edge: empty dictionary
        ({}, ""),
        # Edge: key with empty list
        ({"field1": []}, ""),
        # Edge: multiple keys, some with empty lists
        ({"field1": [], "field2": [("Hello", 0.99)]}, "Hello"),
        # Edge: all keys with empty lists
        ({"field1": [], "field2": []}, ""),
        # Edge: keys with only empty strings
        ({"field1": [("", 0.99)], "field2": [("", 0.98)]}, " "),
        # Edge: keys with mixed empty and non-empty strings
        ({"field1": [("", 0.99), ("Hi", 0.98)], "field2": [("", 0.97), ("There", 0.96)]}, " Hi  There"),
        # Edge: keys with None as value (should raise TypeError)
        pytest.param({"field1": None}, None, marks=pytest.mark.xfail(raises=TypeError)),
        # Edge: values not as list (should raise TypeError)
        pytest.param({"field1": ("Hello", 0.99)}, None, marks=pytest.mark.xfail(raises=TypeError)),
        # Edge: values as list but not tuples (should raise TypeError)
        pytest.param({"field1": ["Hello", "World"]}, None, marks=pytest.mark.xfail(raises=TypeError)),
        # Edge: tuple with missing confidence (should raise IndexError)
        pytest.param({"field1": [("Hello",), ("World", 0.98)]}, None, marks=pytest.mark.xfail(raises=IndexError)),
        # Edge: tuple with extra elements (should use only first)
        ({"field1": [("Hello", 0.99, "extra"), ("World", 0.98, "extra2")]}, "Hello World"),
        # Edge: keys with non-string first element (should convert to string)
        ({"field1": [(123, 0.99)], "field2": [(None, 0.98)]}, "123 None"),
    ]
)
def test_get_text_edge(input_dict, expected):
    # Test edge scenarios for get_text
    if expected is not None:
        codeflash_output = KIEPredictor.get_text(input_dict) # 13.3μs -> 11.6μs (14.1% faster)
    else:
        # For xfail cases, the test will fail as expected
        KIEPredictor.get_text(input_dict)

def test_get_text_ordering():
    # Edge: ordering is by dict value traversal, not key order
    # Dicts preserve insertion order in Python 3.7+
    input_dict = {
        "fieldA": [("first", 0.9)],
        "fieldB": [("second", 0.8)],
        "fieldC": [("third", 0.7)]
    }
    codeflash_output = KIEPredictor.get_text(input_dict) # 2.19μs -> 1.70μs (28.6% faster)
    # If keys are inserted in different order, output changes
    input_dict2 = {
        "fieldB": [("second", 0.8)],
        "fieldC": [("third", 0.7)],
        "fieldA": [("first", 0.9)]
    }
    codeflash_output = KIEPredictor.get_text(input_dict2) # 788ns -> 543ns (45.1% faster)

def test_get_text_large_scale():
    # Large Scale: many keys and values (under 1000 elements)
    num_keys = 50
    num_items_per_key = 20
    input_dict = {}
    expected_items = []
    for k in range(num_keys):
        key = f"field{k}"
        # Each value is a tuple (text, confidence)
        items = [(f"text{k}_{i}", 0.9 + i * 0.001) for i in range(num_items_per_key)]
        input_dict[key] = items
        expected_items.extend([f"text{k}_{i}" for i in range(num_items_per_key)])
    # The output should be all texts in the order of dict values traversal
    expected = " ".join(expected_items)
    codeflash_output = KIEPredictor.get_text(input_dict) # 29.9μs -> 25.1μs (19.2% faster)

def test_get_text_large_scale_empty_and_nonempty():
    # Large Scale: mixture of empty and non-empty lists
    num_keys = 100
    input_dict = {}
    expected_items = []
    for k in range(num_keys):
        key = f"field{k}"
        if k % 2 == 0:
            items = []
        else:
            items = [(f"text{k}", 0.9)]
            expected_items.append(f"text{k}")
        input_dict[key] = items
    expected = " ".join(expected_items)
    codeflash_output = KIEPredictor.get_text(input_dict) # 9.60μs -> 5.57μs (72.2% faster)

def test_get_text_large_scale_long_strings():
    # Large Scale: long strings
    long_str = "a" * 1000
    input_dict = {
        "field1": [(long_str, 0.99)],
        "field2": [(long_str[::-1], 0.98)]
    }
    expected = f"{long_str} {long_str[::-1]}"
    codeflash_output = KIEPredictor.get_text(input_dict) # 1.80μs -> 1.54μs (17.4% faster)

def test_get_text_large_scale_unicode():
    # Large Scale: unicode and non-ASCII characters
    input_dict = {
        "field1": [("你好", 0.99), ("世界", 0.98)],
        "field2": [("😊", 0.97), ("🌍", 0.96)],
    }
    expected = "你好 世界 😊 🌍"
    codeflash_output = KIEPredictor.get_text(input_dict) # 2.14μs -> 1.80μs (18.7% faster)

def test_get_text_large_scale_performance():
    # Large Scale: 1000 elements, check output and performance
    input_dict = {"field": [(str(i), 0.99) for i in range(1000)]}
    expected = " ".join(str(i) for i in range(1000))
    codeflash_output = KIEPredictor.get_text(input_dict) # 23.1μs -> 20.7μs (11.6% faster)

def test_get_text_type_error():
    # Edge: input not a dict
    with pytest.raises(AttributeError):
        KIEPredictor.get_text([("Hello", 0.99)]) # 1.87μs -> 1.81μs (3.15% faster)
    with pytest.raises(AttributeError):
        KIEPredictor.get_text("Hello World") # 1.04μs -> 969ns (7.22% faster)

def test_get_text_empty_tuple():
    # Edge: tuple is empty (should raise IndexError)
    with pytest.raises(IndexError):
        KIEPredictor.get_text({"field": [()]}) # 1.92μs -> 1.87μs (2.73% faster)

def test_get_text_confidence_not_used():
    # Edge: confidence value is ignored
    input_dict = {"field": [("Hello", 0.99), ("World", 0.01)]}
    expected = "Hello World"
    codeflash_output = KIEPredictor.get_text(input_dict) # 1.51μs -> 1.31μs (15.8% faster)

def test_get_text_non_string_text():
    # Edge: first element is not string, should convert to string
    input_dict = {"field": [(None, 0.99), (123, 0.98), (True, 0.97)]}
    expected = "None 123 True"
    codeflash_output = KIEPredictor.get_text(input_dict)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-KIEPredictor.get_text-mg7ra9vu` and push.

Codeflash
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 October 1, 2025 09:00
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Oct 1, 2025