Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Oct 25, 2025

📄 287% (2.87x) speedup for select_multikrum in framework/py/flwr/serverapp/strategy/multikrum.py

⏱️ Runtime : 31.7 milliseconds 8.18 milliseconds (best of 186 runs)

📝 Explanation and details

The optimization achieves a 287% speedup by vectorizing key operations in the MultiKrum algorithm and eliminating redundant array operations.

Key optimizations applied:

  1. Optimized array flattening in compute_distances: Replaced the inefficient list comprehension [np.concatenate(rec.to_numpy_ndarrays(), axis=None).ravel() for rec in records] with a manual loop that handles single-array cases more efficiently. This avoids unnecessary concatenation when an ArrayRecord contains only one array (which is common), reducing from 50.5% to 27.3% of function time.

  2. Vectorized closest indices computation: Eliminated the expensive Python loop that called np.argsort() for each distance row individually. The original code spent 22.3% of time in the loop calling np.argsort(...).tolist() for each row. The optimization uses np.argsort(distance_matrix, axis=1) once to sort all rows simultaneously, then slices to get closest indices.

  3. Vectorized score calculation: Replaced the list comprehension that computed scores row-by-row with np.take_along_axis() followed by .sum(axis=1). This eliminates the expensive loop that was taking 50.1% of total time in select_multikrum, using NumPy's optimized indexing instead of Python iteration.

Performance characteristics: The optimizations are most effective for scenarios with:

  • Multiple clients (more distance matrix rows to process)
  • Models with single arrays per record (avoids concatenation overhead)
  • Larger model parameters (more benefit from vectorized operations)

The test results show consistent speedups across all scenarios, with the vectorized operations scaling better as the number of nodes and model size increase.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 30 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 1 Passed
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import numpy as np
# imports
import pytest
from serverapp.strategy.multikrum import select_multikrum


# Minimal stubs for required classes and types
class Array:
    """Stub for Array class, wraps a numpy array."""
    def __init__(self, dtype, shape, backend, data):
        self._array = np.frombuffer(data, dtype=dtype).reshape(shape)
        self.data = data

    def numpy(self):
        return self._array

class ArrayRecord(dict):
    """Stub for ArrayRecord, behaves like a dict of Arrays."""
    def __init__(self, numpy_ndarrays=None):
        super().__init__()
        if numpy_ndarrays is not None:
            for i, arr in enumerate(numpy_ndarrays):
                # Use float32 for all arrays for simplicity
                self[f"arr{i}"] = Array("float32", arr.shape, "numpy.ndarray", arr.astype("float32").tobytes())

    def to_numpy_ndarrays(self, keep_input=True):
        return [v.numpy() for v in self.values()]

class RecordDict:
    """Stub for RecordDict, contains array_records (dict of ArrayRecord)."""
    def __init__(self, array_record):
        self.array_records = {"weights": array_record}

    def __getitem__(self, key):
        return self.array_records[key]
from serverapp.strategy.multikrum import select_multikrum

# -------------------- UNIT TESTS --------------------

# ----------- BASIC TEST CASES -----------

def test_single_node_krum_returns_itself():
    """Single node, should always select itself regardless of malicious count."""
    arr = np.array([[1.0, 2.0]], dtype="float32")
    record = RecordDict(ArrayRecord([arr]))
    codeflash_output = select_multikrum([record], num_malicious_nodes=0, num_nodes_to_select=1); selected = codeflash_output

def test_two_nodes_no_malicious_selects_closest():
    """Two nodes, no malicious, should select the closest (both are closest to each other)."""
    arr1 = np.array([[1.0, 2.0]], dtype="float32")
    arr2 = np.array([[1.1, 2.1]], dtype="float32")
    rec1 = RecordDict(ArrayRecord([arr1]))
    rec2 = RecordDict(ArrayRecord([arr2]))
    codeflash_output = select_multikrum([rec1, rec2], num_malicious_nodes=0, num_nodes_to_select=1); selected = codeflash_output

def test_three_nodes_no_malicious_selects_closest():
    """Three nodes, no malicious, should select the one closest to the other two."""
    arrs = [
        np.array([[1.0, 2.0]], dtype="float32"),
        np.array([[1.1, 2.1]], dtype="float32"),
        np.array([[10.0, 20.0]], dtype="float32"),
    ]
    recs = [RecordDict(ArrayRecord([a])) for a in arrs]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=0, num_nodes_to_select=1); selected = codeflash_output

def test_multi_krum_selects_multiple_closest():
    """Multi-Krum: select two closest out of four nodes."""
    arrs = [
        np.array([[1.0, 2.0]], dtype="float32"),
        np.array([[1.1, 2.1]], dtype="float32"),
        np.array([[10.0, 20.0]], dtype="float32"),
        np.array([[0.9, 1.9]], dtype="float32"),
    ]
    recs = [RecordDict(ArrayRecord([a])) for a in arrs]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=0, num_nodes_to_select=2); selected = codeflash_output
    # The three first/last are close, so two of them should be selected
    selected_sets = set(selected)

# ----------- EDGE TEST CASES -----------

def test_all_nodes_identical():
    """All nodes have identical parameters; any can be selected."""
    arr = np.array([[5.0, 5.0]], dtype="float32")
    recs = [RecordDict(ArrayRecord([arr])) for _ in range(5)]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=1, num_nodes_to_select=2); selected = codeflash_output

def test_num_nodes_to_select_greater_than_nodes():
    """num_nodes_to_select > number of nodes: should select all nodes."""
    arrs = [np.array([[i, i+1]], dtype="float32") for i in range(3)]
    recs = [RecordDict(ArrayRecord([a])) for a in arrs]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=0, num_nodes_to_select=5); selected = codeflash_output

def test_num_malicious_nodes_negative():
    """Negative num_malicious_nodes should be treated as zero."""
    arrs = [np.array([[i, i+1]], dtype="float32") for i in range(4)]
    recs = [RecordDict(ArrayRecord([a])) for a in arrs]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=-2, num_nodes_to_select=1); selected = codeflash_output

def test_num_malicious_nodes_too_large():
    """num_malicious_nodes >= n-2: num_closest becomes 1, should not crash."""
    arrs = [np.array([[i, i+1]], dtype="float32") for i in range(4)]
    recs = [RecordDict(ArrayRecord([a])) for a in arrs]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=3, num_nodes_to_select=1); selected = codeflash_output

def test_empty_contents_raises():
    """Empty contents should raise IndexError."""
    with pytest.raises(IndexError):
        select_multikrum([], num_malicious_nodes=0, num_nodes_to_select=1)

def test_single_node_multi_krum():
    """Single node, multi-krum: should select itself if asked for more than one."""
    arr = np.array([[1.0, 2.0]], dtype="float32")
    record = RecordDict(ArrayRecord([arr]))
    codeflash_output = select_multikrum([record], num_malicious_nodes=0, num_nodes_to_select=2); selected = codeflash_output

def test_high_dimensional_arrays():
    """Test with high-dimensional arrays."""
    arrs = [
        np.ones((10, 10), dtype="float32"),
        np.ones((10, 10), dtype="float32") * 2,
        np.ones((10, 10), dtype="float32") * 3,
    ]
    recs = [RecordDict(ArrayRecord([a])) for a in arrs]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=0, num_nodes_to_select=1); selected = codeflash_output


def test_large_number_of_nodes():
    """Test with 100 nodes, all similar except one outlier."""
    arrs = [np.ones((5,), dtype="float32") for _ in range(99)]
    arrs.append(np.ones((5,), dtype="float32") * 100)  # Outlier
    recs = [RecordDict(ArrayRecord([a])) for a in arrs]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=10, num_nodes_to_select=5); selected = codeflash_output
    # Should select only from the similar nodes (not the outlier)
    for s in selected:
        arr = list(s.array_records.values())[0].to_numpy_ndarrays()[0]

def test_large_model_parameters():
    """Test with large model parameters (but < 100MB)."""
    arrs = [
        np.ones((1000, 10), dtype="float32"),
        np.ones((1000, 10), dtype="float32") * 2,
        np.ones((1000, 10), dtype="float32") * 3,
    ]
    recs = [RecordDict(ArrayRecord([a])) for a in arrs]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=0, num_nodes_to_select=2); selected = codeflash_output
    for s in selected:
        arr = list(s.array_records.values())[0].to_numpy_ndarrays()[0]

def test_large_scale_multi_krum():
    """Test Multi-Krum with 500 nodes and select 10; all nodes identical."""
    arr = np.ones((10,), dtype="float32")
    recs = [RecordDict(ArrayRecord([arr])) for _ in range(500)]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=50, num_nodes_to_select=10); selected = codeflash_output
    for s in selected:
        pass

def test_large_scale_with_malicious_outliers():
    """Test with 100 nodes, 10 malicious outliers, select 5."""
    arrs = [np.ones((5,), dtype="float32") for _ in range(90)]
    arrs += [np.ones((5,), dtype="float32") * 100 for _ in range(10)]  # Malicious
    recs = [RecordDict(ArrayRecord([a])) for a in arrs]
    codeflash_output = select_multikrum(recs, num_malicious_nodes=10, num_nodes_to_select=5); selected = codeflash_output
    # Should not select any malicious nodes
    for s in selected:
        arr = list(s.array_records.values())[0].to_numpy_ndarrays()[0]
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from collections import OrderedDict

import numpy as np
# imports
import pytest
from serverapp.strategy.multikrum import select_multikrum


# Mocks and minimal implementations for dependencies
class Array:
    """Minimal Array class for testing."""
    def __init__(self, dtype, shape, source, data):
        # dtype: str, shape: list[int], source: str, data: bytes
        self.dtype = dtype
        self.shape = shape
        self.source = source
        self.data = data
        self._np = None

    def numpy(self):
        if self._np is not None:
            return self._np
        # Deserialize from bytes (for test, simply use np.frombuffer)
        arr = np.frombuffer(self.data, dtype=self.dtype)
        arr = arr.reshape(self.shape)
        self._np = arr
        return arr

class TypedDict(dict):
    """Minimal TypedDict implementation for testing."""
    def __init__(self, *args, **kwargs):
        super().__init__()
    def __setitem__(self, key, value):
        super().__setitem__(key, value)

class ArrayRecord(TypedDict):
    """Minimal ArrayRecord for testing."""
    def __init__(self, numpy_ndarrays=None):
        super().__init__()
        if numpy_ndarrays is not None:
            for i, arr in enumerate(numpy_ndarrays):
                # Use dtype, shape, source, data
                arr_bytes = arr.astype(arr.dtype).tobytes()
                self[f"arr_{i}"] = Array(str(arr.dtype), list(arr.shape), "numpy.ndarray", arr_bytes)
                self[f"arr_{i}"]._np = arr

    def to_numpy_ndarrays(self, keep_input=True):
        # Return in insertion order
        return [v.numpy() for v in self.values()]

class RecordDict:
    """Minimal RecordDict for testing."""
    def __init__(self, array_record):
        self.array_records = {"params": array_record}
    def __getitem__(self, key):
        return self.array_records[key]
from serverapp.strategy.multikrum import select_multikrum

# ========================
# Basic Test Cases
# ========================

def make_recorddict_from_arrays(arrays):
    """Helper to create RecordDict from list of numpy arrays."""
    return RecordDict(ArrayRecord(numpy_ndarrays=arrays))

def arrays_close(a, b, tol=1e-8):
    """Helper to compare two lists of numpy arrays."""
    if len(a) != len(b):
        return False
    for x, y in zip(a, b):
        if not np.allclose(x, y, atol=tol):
            return False
    return True

def extract_arrays_from_recorddict(rd):
    """Extract numpy arrays from RecordDict."""
    return list(rd.array_records.values())[0].to_numpy_ndarrays()

def test_multikrum_basic_two_identical():
    # Two identical nodes, no malicious, select one (Krum)
    arr = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    contents = [make_recorddict_from_arrays([arr]), make_recorddict_from_arrays([arr])]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=0, num_nodes_to_select=1); selected = codeflash_output
    selected_arr = extract_arrays_from_recorddict(selected[0])[0]

def test_multikrum_basic_three_distinct():
    # Three nodes, one is closer to the other two
    arr1 = np.array([0.0, 0.0, 0.0], dtype=np.float32)
    arr2 = np.array([1.0, 1.0, 1.0], dtype=np.float32)
    arr3 = np.array([10.0, 10.0, 10.0], dtype=np.float32)
    contents = [
        make_recorddict_from_arrays([arr1]),
        make_recorddict_from_arrays([arr2]),
        make_recorddict_from_arrays([arr3]),
    ]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=0, num_nodes_to_select=1); selected = codeflash_output
    # arr1 and arr2 are closer to each other than arr3
    selected_arr = extract_arrays_from_recorddict(selected[0])[0]

def test_multikrum_basic_multikrum_select_two():
    # Four nodes, select two, one is an outlier
    arr1 = np.array([0.0, 0.0], dtype=np.float32)
    arr2 = np.array([0.1, 0.1], dtype=np.float32)
    arr3 = np.array([0.2, 0.2], dtype=np.float32)
    arr4 = np.array([100.0, 100.0], dtype=np.float32)
    contents = [
        make_recorddict_from_arrays([arr1]),
        make_recorddict_from_arrays([arr2]),
        make_recorddict_from_arrays([arr3]),
        make_recorddict_from_arrays([arr4]),
    ]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=0, num_nodes_to_select=2); selected = codeflash_output
    # arr4 is an outlier, so arr1, arr2, arr3 are closer; two of those three should be selected
    selected_arrays = [extract_arrays_from_recorddict(r)[0] for r in selected]
    # At least one of arr1, arr2, arr3 must be present, arr4 must not be present
    for arr in selected_arrays:
        pass

# ========================
# Edge Test Cases
# ========================

def test_multikrum_all_identical_nodes():
    # All nodes identical, any can be selected
    arr = np.ones((5,), dtype=np.float32)
    contents = [make_recorddict_from_arrays([arr]) for _ in range(5)]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=0, num_nodes_to_select=3); selected = codeflash_output
    for s in selected:
        pass

def test_multikrum_single_node():
    # Only one node, should select itself
    arr = np.array([42.0], dtype=np.float32)
    contents = [make_recorddict_from_arrays([arr])]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=0, num_nodes_to_select=1); selected = codeflash_output

def test_multikrum_more_to_select_than_nodes():
    # num_nodes_to_select > number of nodes: should select all nodes
    arrs = [np.array([i], dtype=np.float32) for i in range(3)]
    contents = [make_recorddict_from_arrays([arr]) for arr in arrs]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=0, num_nodes_to_select=5); selected = codeflash_output
    selected_arrays = [extract_arrays_from_recorddict(r)[0] for r in selected]
    for arr in arrs:
        pass

def test_multikrum_num_malicious_equals_nodes_minus_two():
    # n = 5, f = 3, num_closest = max(1, n-f-2) = 0 -> fallback to 1
    arrs = [np.array([i], dtype=np.float32) for i in range(5)]
    contents = [make_recorddict_from_arrays([arr]) for arr in arrs]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=3, num_nodes_to_select=1); selected = codeflash_output
    # Should not crash, and select one node

def test_multikrum_num_malicious_too_large():
    # n = 4, f = 10, num_closest = max(1, n-f-2) = 1
    arrs = [np.array([i], dtype=np.float32) for i in range(4)]
    contents = [make_recorddict_from_arrays([arr]) for arr in arrs]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=10, num_nodes_to_select=2); selected = codeflash_output

def test_multikrum_empty_input():
    # Should raise an error if contents is empty
    with pytest.raises(IndexError):
        select_multikrum([], num_malicious_nodes=0, num_nodes_to_select=1)

def test_multikrum_high_dimensional_arrays():
    # Test with 2D arrays
    arr1 = np.ones((2,2), dtype=np.float32)
    arr2 = np.zeros((2,2), dtype=np.float32)
    arr3 = np.full((2,2), 5.0, dtype=np.float32)
    contents = [
        make_recorddict_from_arrays([arr1]),
        make_recorddict_from_arrays([arr2]),
        make_recorddict_from_arrays([arr3]),
    ]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=0, num_nodes_to_select=1); selected = codeflash_output
    # arr1 and arr2 are closer to each other than arr3
    selected_arr = extract_arrays_from_recorddict(selected[0])[0]

def test_multikrum_multiple_arrays_per_record():
    # Each ArrayRecord contains two arrays
    arrs = [
        [np.array([1.0, 2.0]), np.array([3.0, 4.0])],
        [np.array([1.1, 2.1]), np.array([3.1, 4.1])],
        [np.array([10.0, 20.0]), np.array([30.0, 40.0])],
    ]
    contents = [make_recorddict_from_arrays(a) for a in arrs]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=0, num_nodes_to_select=1); selected = codeflash_output
    selected_arrs = extract_arrays_from_recorddict(selected[0])

# ========================
# Large Scale Test Cases
# ========================

def test_multikrum_large_number_of_nodes():
    # 100 nodes, each with a 10-element vector, one outlier
    n = 100
    arrs = [np.ones((10,), dtype=np.float32) for _ in range(n-1)]
    arrs.append(np.full((10,), 1000.0, dtype=np.float32))  # outlier
    contents = [make_recorddict_from_arrays([arr]) for arr in arrs]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=5, num_nodes_to_select=10); selected = codeflash_output
    # Outlier should not be selected
    selected_arrays = [extract_arrays_from_recorddict(r)[0] for r in selected]
    for arr in selected_arrays:
        pass
    # All selected arrays should be close to ones
    for arr in selected_arrays:
        pass

def test_multikrum_large_arrays():
    # 10 nodes, each with a 1000-element vector, one outlier
    n = 10
    arrs = [np.zeros((1000,), dtype=np.float32) for _ in range(n-1)]
    arrs.append(np.full((1000,), 100.0, dtype=np.float32))
    contents = [make_recorddict_from_arrays([arr]) for arr in arrs]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=2, num_nodes_to_select=3); selected = codeflash_output
    selected_arrays = [extract_arrays_from_recorddict(r)[0] for r in selected]
    for arr in selected_arrays:
        pass
    for arr in selected_arrays:
        pass

def test_multikrum_scalability_many_nodes_and_arrays():
    # 50 nodes, each with 3 arrays of shape (5,)
    n = 50
    arrs = [
        [np.full((5,), i, dtype=np.float32) for _ in range(3)]
        for i in range(n)
    ]
    # Make one node an outlier
    arrs[-1] = [np.full((5,), 1e6, dtype=np.float32) for _ in range(3)]
    contents = [make_recorddict_from_arrays(a) for a in arrs]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=5, num_nodes_to_select=5); selected = codeflash_output
    selected_arrays = [extract_arrays_from_recorddict(r) for r in selected]
    # Outlier should not be selected
    for arrays in selected_arrays:
        for arr in arrays:
            pass
    # The selected nodes should be among those with i in range 0..n-2

def test_multikrum_performance_within_limits():
    # 200 nodes, each with a 50-element vector (should be < 100MB)
    n = 200
    arrs = [np.random.randn(50).astype(np.float32) for _ in range(n)]
    contents = [make_recorddict_from_arrays([arr]) for arr in arrs]
    codeflash_output = select_multikrum(contents, num_malicious_nodes=10, num_nodes_to_select=20); selected = codeflash_output
    # All selected must be from input
    selected_arrays = [extract_arrays_from_recorddict(r)[0] for r in selected]
    for arr in selected_arrays:
        pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from flwr.common.record.arrayrecord import ArrayRecord
from flwr.common.record.recorddict import RecordSet
from serverapp.strategy.multikrum import select_multikrum
import pytest

def test_select_multikrum():
    with pytest.raises(ValueError, match='need\\ at\\ least\\ one\\ array\\ to\\ concatenate'):
        select_multikrum([RecordSet(records={}, parameters_records={'': ArrayRecord()}, metrics_records={}, configs_records={})], 0, 0)
🔎 Concolic Coverage Tests and Runtime

To edit these changes git checkout codeflash/optimize-select_multikrum-mh69burg and push.

Codeflash

The optimization achieves a **287% speedup** by vectorizing key operations in the MultiKrum algorithm and eliminating redundant array operations.

**Key optimizations applied:**

1. **Optimized array flattening in `compute_distances`**: Replaced the inefficient list comprehension `[np.concatenate(rec.to_numpy_ndarrays(), axis=None).ravel() for rec in records]` with a manual loop that handles single-array cases more efficiently. This avoids unnecessary concatenation when an `ArrayRecord` contains only one array (which is common), reducing from 50.5% to 27.3% of function time.

2. **Vectorized closest indices computation**: Eliminated the expensive Python loop that called `np.argsort()` for each distance row individually. The original code spent 22.3% of time in the loop calling `np.argsort(...).tolist()` for each row. The optimization uses `np.argsort(distance_matrix, axis=1)` once to sort all rows simultaneously, then slices to get closest indices.

3. **Vectorized score calculation**: Replaced the list comprehension that computed scores row-by-row with `np.take_along_axis()` followed by `.sum(axis=1)`. This eliminates the expensive loop that was taking 50.1% of total time in `select_multikrum`, using NumPy's optimized indexing instead of Python iteration.

**Performance characteristics**: The optimizations are most effective for scenarios with:
- Multiple clients (more distance matrix rows to process)
- Models with single arrays per record (avoids concatenation overhead) 
- Larger model parameters (more benefit from vectorized operations)

The test results show consistent speedups across all scenarios, with the vectorized operations scaling better as the number of nodes and model size increase.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 October 25, 2025 12:29
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Oct 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant