diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 0000000..ffdf97d --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,35 @@ +# PR Labeler configuration file +# Automatically add labels based on modified file paths + +docs: + - changed-files: + - any-glob-to-any-file: '**/*.md' + +ci: + - changed-files: + - any-glob-to-any-file: + - '.github/**/*' + - '.pre-commit-config.yaml' + +tests: + - changed-files: + - any-glob-to-any-file: 'tests/**/*' + +core: + - changed-files: + - any-glob-to-any-file: 'vllm_fl/**/*' + +examples: + - changed-files: + - any-glob-to-any-file: 'examples/**/*' + +benchmarks: + - changed-files: + - any-glob-to-any-file: 'benchmarks/**/*' + +build: + - changed-files: + - any-glob-to-any-file: + - 'setup.py' + - 'requirements*.txt' + - 'pyproject.toml' diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 0000000..3f0e93a --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,24 @@ +name: "Pull Request Labeler" + +on: + pull_request_target: + types: [opened, synchronize, reopened] + +permissions: + contents: read + pull-requests: write + +jobs: + label: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Apply labels + uses: actions/labeler@v5 + continue-on-error: true # Don't fail if config not yet on main branch + with: + repo-token: "${{ secrets.GITHUB_TOKEN }}" + sync-labels: true \ No newline at end of file diff --git a/.gitignore b/.gitignore index c6473d8..b03a82e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ *.egg-info __pycache__/ +build/ +# Coverage +.coverage +.coverage.* +htmlcov/ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ca32a8f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,61 @@ +[build-system] +requires = ["setuptools>=45", "setuptools-scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +markers = [ + "gpu: marks tests as requiring single GPU (deselect with '-m \"not gpu\"')", + "multi_gpu: marks tests as requiring multiple GPUs", + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "integration: marks tests as integration tests", + "e2e: marks tests as end-to-end tests", + "flaggems: marks tests as requiring flag_gems library", + "functional: marks tests as functional tests", +] +addopts = "-v --tb=short" +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::UserWarning", +] + +[tool.coverage.run] +source = ["vllm_fl"] +omit = [ + "tests/*", + "examples/*", + "benchmarks/*", +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise NotImplementedError", + "if TYPE_CHECKING:", + "if __name__ == .__main__.:", +] + +[tool.ruff] +line-length = 100 +target-version = "py39" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions +] +ignore = [ + "E501", # line too long (handled by formatter) + "B008", # do not perform function calls in argument defaults +] + +[tool.ruff.lint.isort] +known-first-party = ["vllm_fl"] diff --git a/requirements.txt b/requirements.txt index 6ae7876..705aa39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +vllm==0.13.0 decorator pyyaml scipy diff --git a/tests/e2e/conftest.py
b/tests/e2e_tests/conftest.py similarity index 100% rename from tests/e2e/conftest.py rename to tests/e2e_tests/conftest.py diff --git a/tests/e2e/test_offline_inference.py b/tests/e2e_tests/test_offline_inference.py similarity index 97% rename from tests/e2e/test_offline_inference.py rename to tests/e2e_tests/test_offline_inference.py index 7627ab8..e19ad8f 100644 --- a/tests/e2e/test_offline_inference.py +++ b/tests/e2e_tests/test_offline_inference.py @@ -4,7 +4,7 @@ import pytest import vllm # noqa: F401 from conftest import VllmRunner -import vllm_flagos +import vllm_fl # noqa: F401 MODELS = [ # "Qwen/Qwen3-0.6B", diff --git a/tests/test_offline_minicmp.py b/tests/e2e_tests/test_offline_minicmp.py similarity index 100% rename from tests/test_offline_minicmp.py rename to tests/e2e_tests/test_offline_minicmp.py diff --git a/tests/test_offline_qwen3_next.py b/tests/e2e_tests/test_offline_qwen3_next.py similarity index 100% rename from tests/test_offline_qwen3_next.py rename to tests/e2e_tests/test_offline_qwen3_next.py diff --git a/tests/test_vllm_serve_minicmp.py b/tests/e2e_tests/test_vllm_serve_minicmp.py similarity index 100% rename from tests/test_vllm_serve_minicmp.py rename to tests/e2e_tests/test_vllm_serve_minicmp.py diff --git a/tests/test_vllm_serve_qwen3_next.py b/tests/e2e_tests/test_vllm_serve_qwen3_next.py similarity index 100% rename from tests/test_vllm_serve_qwen3_next.py rename to tests/e2e_tests/test_vllm_serve_qwen3_next.py diff --git a/tests/functional_tests/__init__.py b/tests/functional_tests/__init__.py new file mode 100644 index 0000000..edb92d6 --- /dev/null +++ b/tests/functional_tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 BAAI. All rights reserved. +"""Functional tests for vllm_fl.""" diff --git a/tests/functional_tests/compilation/__init__.py b/tests/functional_tests/compilation/__init__.py new file mode 100644 index 0000000..99883a5 --- /dev/null +++ b/tests/functional_tests/compilation/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 BAAI. All rights reserved. +"""Compilation functional tests.""" diff --git a/tests/functional_tests/compilation/test_graph_capture.py b/tests/functional_tests/compilation/test_graph_capture.py new file mode 100644 index 0000000..6bd9b63 --- /dev/null +++ b/tests/functional_tests/compilation/test_graph_capture.py @@ -0,0 +1,147 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Functional tests for graph capture and replay. +Tests CUDA/NPU graph functionality for model optimization. + +Note: Unit tests for GraphOptions, GraphEntry, and GraphWrapper are in +unit_tests/compilation/test_graph.py. This file only contains functional +tests that require actual GPU execution. 
+""" + +import pytest +import torch +from dataclasses import dataclass + + +# Mark all tests as requiring GPU +pytestmark = pytest.mark.gpu + + +class TestWeakRefTensors: + """Test weak reference tensor functionality.""" + + def test_weak_ref_tensors_function(self): + """Test weak_ref_tensors function exists.""" + try: + from vllm_fl.compilation.graph import weak_ref_tensors + assert weak_ref_tensors is not None + except ImportError: + pytest.skip("weak_ref_tensors not available") + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + def test_weak_ref_tensors_with_cuda_tensor(self): + """Test weak_ref_tensors with CUDA tensor.""" + try: + from vllm_fl.compilation.graph import weak_ref_tensors + except ImportError: + pytest.skip("weak_ref_tensors not available") + + tensor = torch.randn(4, 8, device="cuda") + result = weak_ref_tensors(tensor) + # Result should be either the tensor or a weak reference + assert result is not None + + +class TestGraphCaptureFlow: + """Test the complete graph capture flow.""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + def test_cuda_graph_basic_capture(self): + """Test basic CUDA graph capture and replay.""" + # Simple test without vllm_fl dependencies + device = torch.device("cuda") + + # Create a simple computation + def computation(x): + return x * 2 + 1 + + # Create input tensor + x = torch.randn(4, 8, device=device) + + # Warmup + y = computation(x) + + # Capture graph + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + y = computation(x) + + # Replay graph + g.replay() + + # Verify output + expected = x * 2 + 1 + assert torch.allclose(y, expected) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + def test_cuda_graph_with_different_inputs(self): + """Test CUDA graph with different input values.""" + device = torch.device("cuda") + + # Static input buffer + static_input = torch.randn(4, 8, device=device) + static_output = torch.empty(4, 8, device=device) + + def computation(x, out): + out.copy_(x * 2) + + # Warmup + computation(static_input, static_output) + + # Capture graph + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + computation(static_input, static_output) + + # Test with new input values (copy to static buffer) + new_input = torch.ones(4, 8, device=device) + static_input.copy_(new_input) + + # Replay + g.replay() + + expected = new_input * 2 + assert torch.allclose(static_output, expected) + + +class TestGraphCacheManagement: + """Test graph cache management functionality.""" + + def test_batch_descriptor_hashing(self): + """Test that batch descriptors can be used as dict keys.""" + @dataclass(frozen=True) + class MockBatchDescriptor: + num_tokens: int + max_num_reqs: int + + desc1 = MockBatchDescriptor(num_tokens=16, max_num_reqs=4) + desc2 = MockBatchDescriptor(num_tokens=16, max_num_reqs=4) + desc3 = MockBatchDescriptor(num_tokens=32, max_num_reqs=8) + + cache = {} + cache[desc1] = "graph1" + cache[desc3] = "graph3" + + # Same values should hash to same key + assert cache[desc2] == "graph1" + assert cache[desc3] == "graph3" + + def test_graph_entry_storage(self): + """Test storing graph entries in cache.""" + try: + from vllm_fl.compilation.graph import GraphEntry + except ImportError: + pytest.skip("GraphEntry not available") + + @dataclass(frozen=True) + class MockBatchDescriptor: + num_tokens: int + + cache = {} + desc = MockBatchDescriptor(num_tokens=16) + + entry = GraphEntry(batch_descriptor=desc) + cache[desc] = 
entry + + assert cache[desc].batch_descriptor.num_tokens == 16 diff --git a/tests/functional_tests/conftest.py b/tests/functional_tests/conftest.py new file mode 100644 index 0000000..e765f35 --- /dev/null +++ b/tests/functional_tests/conftest.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Functional test fixtures and configuration. +""" + +import os +import pytest +import torch + + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line("markers", "gpu: marks tests as requiring GPU") + config.addinivalue_line("markers", "multi_gpu: marks tests as requiring multiple GPUs") + config.addinivalue_line("markers", "flaggems: marks tests as requiring flag_gems library") + + +@pytest.fixture(scope="session") +def has_gpu(): + """Check if GPU is available.""" + return torch.cuda.is_available() + + +@pytest.fixture(scope="session") +def device(has_gpu): + """Get the test device.""" + if has_gpu: + return torch.device("cuda:0") + return torch.device("cpu") + + +@pytest.fixture(scope="session") +def gpu_count(): + """Get the number of available GPUs.""" + if torch.cuda.is_available(): + return torch.cuda.device_count() + return 0 + + +@pytest.fixture +def reset_dispatch_manager(): + """Reset dispatch manager before and after test.""" + from vllm_fl.dispatch import reset_default_manager, reset_global_policy + + reset_default_manager() + reset_global_policy() + yield + reset_default_manager() + reset_global_policy() + + +@pytest.fixture +def clean_env(): + """Clean dispatch-related environment variables.""" + env_vars = [ + "VLLM_FL_PREFER", + "VLLM_FL_STRICT", + "VLLM_FL_CONFIG", + "VLLM_FL_DENY_VENDORS", + "VLLM_FL_ALLOW_VENDORS", + "VLLM_FL_PER_OP", + "VLLM_FL_DISPATCH_DEBUG", + ] + + # Save original values + original = {k: os.environ.get(k) for k in env_vars} + + # Clear env vars + for k in env_vars: + os.environ.pop(k, None) + + yield + + # Restore original values + for k, v in original.items(): + if v is not None: + os.environ[k] = v + else: + os.environ.pop(k, None) + + +def skip_if_no_gpu(fn): + """Decorator to skip test if no GPU is available.""" + return pytest.mark.skipif( + not torch.cuda.is_available(), + reason="GPU not available" + )(fn) + + +def skip_if_no_multi_gpu(fn): + """Decorator to skip test if less than 2 GPUs available.""" + return pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Multiple GPUs not available" + )(fn) diff --git a/tests/functional_tests/distributed/__init__.py b/tests/functional_tests/distributed/__init__.py new file mode 100644 index 0000000..7a95e8f --- /dev/null +++ b/tests/functional_tests/distributed/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 BAAI. All rights reserved. +"""Distributed functional tests.""" diff --git a/tests/functional_tests/distributed/test_collective_ops.py b/tests/functional_tests/distributed/test_collective_ops.py new file mode 100644 index 0000000..bb8e330 --- /dev/null +++ b/tests/functional_tests/distributed/test_collective_ops.py @@ -0,0 +1,209 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Functional tests for distributed collective operations. +Tests correctness of collective operations like all_reduce, reduce_scatter, etc. + +NOTE: These tests require multiple GPUs and a distributed environment. +They are designed to be run with pytest-mpi or similar multi-process test runners. 
+""" + +import pytest +import torch +from typing import List +from unittest.mock import MagicMock, patch + + +# Mark all tests as requiring multiple GPUs +pytestmark = [pytest.mark.multi_gpu, pytest.mark.gpu] + + +class TestCollectiveOpsBasic: + """Basic tests for collective operations that can run without actual distributed setup.""" + + def test_communicator_fl_import(self): + """Test that CommunicatorFL can be imported.""" + try: + from vllm_fl.distributed.communicator import CommunicatorFL + assert CommunicatorFL is not None + except ImportError as e: + pytest.skip(f"CommunicatorFL not available: {e}") + + def test_pyflagcx_import(self): + """Test that PyFlagcxCommunicator can be imported.""" + try: + from vllm_fl.distributed.device_communicators.flagcx import PyFlagcxCommunicator + assert PyFlagcxCommunicator is not None + except ImportError as e: + pytest.skip(f"PyFlagcxCommunicator not available: {e}") + + +class TestAllReduceCorrectness: + """Test all_reduce operation correctness.""" + + @staticmethod + def reference_all_reduce(tensors: List[torch.Tensor]) -> torch.Tensor: + """Reference implementation of all_reduce (sum).""" + return sum(tensors) + + @pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Multiple GPUs not available" + ) + def test_all_reduce_sum_correctness(self): + """Test all_reduce sum produces correct results.""" + # This test would need actual distributed setup + # For now, test the reference implementation + tensors = [ + torch.tensor([1.0, 2.0, 3.0]), + torch.tensor([4.0, 5.0, 6.0]), + ] + expected = torch.tensor([5.0, 7.0, 9.0]) + result = self.reference_all_reduce(tensors) + assert torch.allclose(result, expected) + + +class TestReduceScatterCorrectness: + """Test reduce_scatter operation correctness.""" + + @staticmethod + def reference_reduce_scatter( + input_tensor: torch.Tensor, + world_size: int + ) -> List[torch.Tensor]: + """Reference implementation of reduce_scatter.""" + # Split input into chunks + chunks = input_tensor.chunk(world_size, dim=0) + # Each rank gets the reduced chunk at its position + return list(chunks) + + def test_reduce_scatter_reference(self): + """Test reference reduce_scatter implementation.""" + input_tensor = torch.tensor([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + ]) + world_size = 2 + + result = self.reference_reduce_scatter(input_tensor, world_size) + + assert len(result) == world_size + assert torch.allclose(result[0], torch.tensor([[1.0, 2.0], [3.0, 4.0]])) + assert torch.allclose(result[1], torch.tensor([[5.0, 6.0], [7.0, 8.0]])) + + +class TestAllGatherCorrectness: + """Test all_gather operation correctness.""" + + @staticmethod + def reference_all_gather(tensors: List[torch.Tensor]) -> torch.Tensor: + """Reference implementation of all_gather.""" + return torch.cat(tensors, dim=0) + + def test_all_gather_reference(self): + """Test reference all_gather implementation.""" + tensors = [ + torch.tensor([[1.0, 2.0]]), + torch.tensor([[3.0, 4.0]]), + ] + expected = torch.tensor([ + [1.0, 2.0], + [3.0, 4.0], + ]) + + result = self.reference_all_gather(tensors) + assert torch.allclose(result, expected) + + +class TestSendRecvCorrectness: + """Test point-to-point send/recv operations.""" + + @pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Multiple GPUs not available" + ) + def test_send_recv_mock(self): + """Test send/recv with mocked communicator.""" + # Create mock communicator + mock_comm = MagicMock() + 
mock_comm.disabled = False + + tensor = torch.randn(4, 8) + + # Simulate send + mock_comm.send(tensor, dst=1) + mock_comm.send.assert_called_once() + + # Simulate recv + mock_comm.recv.return_value = tensor.clone() + received = mock_comm.recv(tensor.shape, tensor.dtype, src=0) + assert received.shape == tensor.shape + + +class TestCommunicatorDisabled: + """Test communicator behavior when disabled.""" + + def test_disabled_all_reduce_returns_none(self): + """Test that disabled communicator all_reduce returns None.""" + mock_comm = MagicMock() + mock_comm.disabled = True + + def mock_all_reduce(tensor, out=None, op=None, stream=None): + if mock_comm.disabled: + return None + return torch.empty_like(tensor) + + mock_comm.all_reduce = mock_all_reduce + + result = mock_comm.all_reduce(torch.randn(4, 8)) + assert result is None + + def test_disabled_send_returns_early(self): + """Test that disabled communicator send returns early.""" + mock_comm = MagicMock() + mock_comm.disabled = True + + call_count = [0] + + def mock_send(tensor, dst, stream=None): + if mock_comm.disabled: + return + call_count[0] += 1 + # Would do actual send here + + mock_comm.send = mock_send + mock_comm.send(torch.randn(4, 8), dst=1) + + assert call_count[0] == 0 + + +class TestDistributedUtils: + """Test distributed utility functions.""" + + def test_world_size_calculation(self): + """Test world size calculation logic.""" + # Single GPU case + world_size_1 = 1 + assert world_size_1 == 1 + + # Multi-GPU case + world_size_4 = 4 + tp_size = 2 + pp_size = 2 + assert world_size_4 == tp_size * pp_size + + def test_rank_calculation(self): + """Test rank calculation in tensor/pipeline parallel.""" + world_size = 4 + tp_size = 2 + pp_size = 2 + + # Calculate expected ranks + for global_rank in range(world_size): + tp_rank = global_rank % tp_size + pp_rank = global_rank // tp_size + + assert 0 <= tp_rank < tp_size + assert 0 <= pp_rank < pp_size diff --git a/tests/functional_tests/ops/__init__.py b/tests/functional_tests/ops/__init__.py new file mode 100644 index 0000000..8ef0672 --- /dev/null +++ b/tests/functional_tests/ops/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 BAAI. All rights reserved. +"""Ops functional tests.""" diff --git a/tests/functional_tests/ops/test_ops_correctness.py b/tests/functional_tests/ops/test_ops_correctness.py new file mode 100644 index 0000000..4df2e39 --- /dev/null +++ b/tests/functional_tests/ops/test_ops_correctness.py @@ -0,0 +1,311 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Functional tests for ops correctness. +Tests numerical correctness of operator implementations +by comparing against reference PyTorch implementations. 
+""" + +import pytest +import torch +import torch.nn.functional as F +from typing import Tuple + + +# Skip all tests in this module if GPU not available +pytestmark = pytest.mark.gpu + + +def allclose(a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-3) -> bool: + """Check if two tensors are close within tolerance.""" + return torch.allclose(a, b, rtol=rtol, atol=atol) + + +class TestSiluAndMulCorrectness: + """Test SiluAndMul operator correctness.""" + + @pytest.fixture + def test_shapes(self): + """Common test shapes for SiluAndMul.""" + return [ + (1, 64), + (4, 128), + (16, 256), + (32, 512), + (64, 1024), + ] + + @staticmethod + def reference_silu_and_mul(x: torch.Tensor) -> torch.Tensor: + """Reference implementation of SiluAndMul.""" + half = x.shape[-1] // 2 + return F.silu(x[..., :half]) * x[..., half:] + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + def test_silu_and_mul_forward(self, test_shapes, device): + """Test SiluAndMul forward pass correctness.""" + try: + from vllm_fl.dispatch import call_op + except ImportError: + pytest.skip("vllm_fl.dispatch not available") + + for batch_size, hidden_size in test_shapes: + # Input must have even hidden size for SiluAndMul + x = torch.randn(batch_size, hidden_size * 2, device=device, dtype=torch.float32) + + # Get reference result + ref_result = self.reference_silu_and_mul(x) + + # Get FL result + try: + fl_result = call_op("silu_and_mul", x) + + # Check correctness + assert fl_result.shape == ref_result.shape, ( + f"Shape mismatch: {fl_result.shape} vs {ref_result.shape}" + ) + assert allclose(fl_result, ref_result), ( + f"Value mismatch for shape ({batch_size}, {hidden_size * 2})" + ) + except RuntimeError as e: + if "No available implementation" in str(e): + pytest.skip(f"silu_and_mul not registered: {e}") + raise + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + def test_silu_and_mul_dtypes(self, device): + """Test SiluAndMul with different dtypes.""" + try: + from vllm_fl.dispatch import call_op + except ImportError: + pytest.skip("vllm_fl.dispatch not available") + + dtypes = [torch.float32, torch.float16, torch.bfloat16] + x_fp32 = torch.randn(4, 128, device=device, dtype=torch.float32) + + for dtype in dtypes: + x = x_fp32.to(dtype) + ref_result = self.reference_silu_and_mul(x) + + try: + fl_result = call_op("silu_and_mul", x) + # Use looser tolerance for half precision + tol = 1e-2 if dtype in [torch.float16, torch.bfloat16] else 1e-3 + assert allclose(fl_result, ref_result, rtol=tol, atol=tol), ( + f"Value mismatch for dtype {dtype}" + ) + except RuntimeError as e: + if "No available implementation" in str(e): + pytest.skip(f"silu_and_mul not registered for {dtype}: {e}") + raise + + +class TestRMSNormCorrectness: + """Test RMSNorm operator correctness.""" + + @pytest.fixture + def test_shapes(self): + """Common test shapes for RMSNorm.""" + return [ + (1, 64, 128), # (batch, seq, hidden) + (4, 32, 256), + (8, 16, 512), + ] + + @staticmethod + def reference_rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6 + ) -> torch.Tensor: + """Reference implementation of RMSNorm.""" + variance = x.pow(2).mean(-1, keepdim=True) + x_normed = x * torch.rsqrt(variance + eps) + return x_normed * weight + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + def test_rms_norm_forward(self, test_shapes, device): + """Test RMSNorm forward pass correctness.""" + try: + from vllm_fl.dispatch import 
call_op + except ImportError: + pytest.skip("vllm_fl.dispatch not available") + + eps = 1e-6 + for batch_size, seq_len, hidden_size in test_shapes: + x = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=torch.float32) + weight = torch.ones(hidden_size, device=device, dtype=torch.float32) + + # Get reference result + ref_result = self.reference_rms_norm(x, weight, eps) + + # Get FL result + try: + fl_result = call_op("rms_norm", x, None, weight, eps) + + # Handle tuple return (output, residual) + if isinstance(fl_result, tuple): + fl_result = fl_result[0] + + assert fl_result.shape == ref_result.shape, ( + f"Shape mismatch: {fl_result.shape} vs {ref_result.shape}" + ) + assert allclose(fl_result, ref_result), ( + f"Value mismatch for shape ({batch_size}, {seq_len}, {hidden_size})" + ) + except RuntimeError as e: + if "No available implementation" in str(e): + pytest.skip(f"rms_norm not registered: {e}") + raise + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + def test_rms_norm_with_residual(self, device): + """Test RMSNorm with residual connection.""" + try: + from vllm_fl.dispatch import call_op + except ImportError: + pytest.skip("vllm_fl.dispatch not available") + + batch_size, seq_len, hidden_size = 4, 32, 256 + eps = 1e-6 + + x = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=torch.float32) + residual = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=torch.float32) + weight = torch.ones(hidden_size, device=device, dtype=torch.float32) + + try: + result = call_op("rms_norm", x, residual, weight, eps) + + # Should return tuple (normalized, updated_residual) when residual is provided + if isinstance(result, tuple): + normalized, updated_residual = result + assert normalized.shape == x.shape + assert updated_residual.shape == x.shape + except RuntimeError as e: + if "No available implementation" in str(e): + pytest.skip(f"rms_norm not registered: {e}") + raise + + +class TestRotaryEmbeddingCorrectness: + """Test RotaryEmbedding operator correctness.""" + + @staticmethod + def reference_rotary_embedding( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + positions: torch.Tensor, + rotary_interleaved: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Reference implementation of rotary embedding.""" + + def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dims.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + # Gather cos/sin by positions + cos_pos = cos[positions] + sin_pos = sin[positions] + + # Add head dimension + while cos_pos.dim() < q.dim(): + cos_pos = cos_pos.unsqueeze(1) + sin_pos = sin_pos.unsqueeze(1) + + # Apply rotary embedding + q_embed = (q * cos_pos) + (rotate_half(q) * sin_pos) + k_embed = (k * cos_pos) + (rotate_half(k) * sin_pos) + + return q_embed, k_embed + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + def test_rotary_embedding_basic(self, device): + """Test basic rotary embedding functionality.""" + try: + from vllm_fl.dispatch import call_op + except ImportError: + pytest.skip("vllm_fl.dispatch not available") + + num_tokens = 16 + num_heads = 8 + head_size = 64 + rotary_dim = head_size + max_position = 2048 + + # Create test inputs + q = torch.randn(num_tokens, num_heads, head_size, device=device, dtype=torch.float32) + k = torch.randn(num_tokens, num_heads, head_size, device=device, dtype=torch.float32) + 
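# Descriptive comment (added): sequential positions index rows of the precomputed cos/sin cache built below, matching the gather done in the reference implementation. +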
positions = torch.arange(num_tokens, device=device) + + # Create cos/sin cache + inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, device=device).float() / rotary_dim)) + t = torch.arange(max_position, device=device).float() + freqs = torch.outer(t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + + try: + q_out, k_out = call_op( + "rotary_embedding", + q[..., :rotary_dim], + k[..., :rotary_dim], + cos, + sin, + positions, + False, # rotary_interleaved + False, # inplace + ) + + assert q_out.shape == q[..., :rotary_dim].shape + assert k_out.shape == k[..., :rotary_dim].shape + except RuntimeError as e: + if "No available implementation" in str(e): + pytest.skip(f"rotary_embedding not registered: {e}") + raise + + +class TestOpsEdgeCases: + """Test edge cases for operators.""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + def test_empty_tensor_handling(self, device): + """Test handling of empty tensors.""" + try: + from vllm_fl.dispatch import call_op + except ImportError: + pytest.skip("vllm_fl.dispatch not available") + + # Create empty tensor + x = torch.empty(0, 64, device=device, dtype=torch.float32) + + # Some ops may handle empty tensors, others may raise + # This test documents the behavior + try: + result = call_op("silu_and_mul", x) + assert result.shape[0] == 0 + except (RuntimeError, ValueError): + # Empty tensor handling is implementation-dependent + pass + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + def test_large_batch_handling(self, device): + """Test handling of large batch sizes.""" + try: + from vllm_fl.dispatch import call_op + except ImportError: + pytest.skip("vllm_fl.dispatch not available") + + # Large batch + x = torch.randn(1024, 256, device=device, dtype=torch.float32) + + try: + result = call_op("silu_and_mul", x) + assert result.shape == (1024, 128) + except RuntimeError as e: + if "No available implementation" in str(e): + pytest.skip(f"silu_and_mul not registered: {e}") + raise diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py new file mode 100644 index 0000000..8806137 --- /dev/null +++ b/tests/unit_tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 BAAI. All rights reserved. +"""Unit tests for vllm_fl.""" diff --git a/tests/unit_tests/compilation/__init__.py b/tests/unit_tests/compilation/__init__.py new file mode 100644 index 0000000..8c90136 --- /dev/null +++ b/tests/unit_tests/compilation/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2025 BAAI. All rights reserved. diff --git a/tests/unit_tests/compilation/test_graph.py b/tests/unit_tests/compilation/test_graph.py new file mode 100644 index 0000000..7ad08a0 --- /dev/null +++ b/tests/unit_tests/compilation/test_graph.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for compilation graph module. 
+""" + +import pytest +from unittest.mock import MagicMock + + +class TestGraphOptions: + """Test GraphOptions dataclass.""" + + def test_default_values(self): + from vllm_fl.compilation.graph import GraphOptions + + options = GraphOptions() + + assert options.debug_log_enable is True + assert options.gc_disable is False + assert options.weak_ref_output is True + + def test_custom_values(self): + from vllm_fl.compilation.graph import GraphOptions + + options = GraphOptions( + debug_log_enable=False, + gc_disable=True, + weak_ref_output=False, + ) + + assert options.debug_log_enable is False + assert options.gc_disable is True + assert options.weak_ref_output is False + + +class TestGraphEntry: + """Test GraphEntry dataclass.""" + + def test_default_values(self): + from vllm_fl.compilation.graph import GraphEntry + + mock_batch_desc = MagicMock() + + entry = GraphEntry(batch_descriptor=mock_batch_desc) + + assert entry.batch_descriptor is mock_batch_desc + assert entry.graph is None + assert entry.output is None + assert entry.input_addresses is None diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py new file mode 100644 index 0000000..8a0a386 --- /dev/null +++ b/tests/unit_tests/conftest.py @@ -0,0 +1,145 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Unit test fixtures and configuration. + +This module provides shared fixtures for all unit tests. +""" + +import os +import pytest +import torch +from unittest.mock import MagicMock, NonCallableMagicMock + + +# ============================================================================= +# Environment Detection Helpers +# ============================================================================= + +def has_cuda(): + """Check if CUDA is available.""" + return torch.cuda.is_available() + + +def has_flagcx(): + """Check if flagcx is available.""" + flagcx_path = os.getenv('FLAGCX_PATH') + if not flagcx_path: + return False + lib_path = os.path.join(flagcx_path, "build/lib/libflagcx.so") + return os.path.exists(lib_path) + + +def has_vllm_profiler(): + """Check if vllm profiler is available.""" + try: + from vllm.profiler.wrapper import TorchProfilerWrapper + return True + except ImportError: + return False + + +# ============================================================================= +# Basic Fixtures +# ============================================================================= + +@pytest.fixture +def mock_tensor(): + """Create a simple tensor for testing.""" + return torch.randn(2, 4, 8) + + +@pytest.fixture +def device(): + """Get the available device.""" + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +@pytest.fixture +def cpu_device(): + """Always return CPU device.""" + return torch.device("cpu") + + +# ============================================================================= +# Mock Factory Fixtures +# ============================================================================= + +@pytest.fixture +def mock_module(): + """ + Create a mock that behaves like a Python module. + + Use this when you need to mock a module object that: + - Is not callable + - May have specific attributes + """ + return NonCallableMagicMock(spec=['__name__', '__file__']) + + +@pytest.fixture +def mock_module_with_register(): + """ + Create a mock module with a register function. + + Useful for testing plugin discovery. 
+ """ + module = NonCallableMagicMock(spec=['register']) + module.register = MagicMock() + return module + + +@pytest.fixture +def mock_process_group(): + """ + Create a mock torch distributed ProcessGroup. + + Useful for testing distributed communication. + """ + group = MagicMock() + group.rank.return_value = 0 + group.size.return_value = 1 + return group + + +# ============================================================================= +# Tensor Fixtures +# ============================================================================= + +@pytest.fixture +def batch_tensors(): + """Create a batch of tensors for testing.""" + return { + 'small': torch.randn(2, 8), + 'medium': torch.randn(4, 16, 32), + 'large': torch.randn(8, 32, 64, 128), + } + + +@pytest.fixture +def dtype_tensors(): + """Create tensors with different dtypes.""" + return { + 'float32': torch.randn(2, 4, dtype=torch.float32), + 'float16': torch.randn(2, 4, dtype=torch.float16), + 'bfloat16': torch.randn(2, 4, dtype=torch.bfloat16), + } + + +# ============================================================================= +# Pytest Markers +# ============================================================================= + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line( + "markers", "gpu: marks tests as requiring GPU" + ) + config.addinivalue_line( + "markers", "flagcx: marks tests as requiring flagcx" + ) + config.addinivalue_line( + "markers", "slow: marks tests as slow" + ) diff --git a/tests/unit_tests/dispatch/__init__.py b/tests/unit_tests/dispatch/__init__.py new file mode 100644 index 0000000..8c90136 --- /dev/null +++ b/tests/unit_tests/dispatch/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2025 BAAI. All rights reserved. diff --git a/tests/unit_tests/dispatch/test_call_op.py b/tests/unit_tests/dispatch/test_call_op.py new file mode 100644 index 0000000..0e76084 --- /dev/null +++ b/tests/unit_tests/dispatch/test_call_op.py @@ -0,0 +1,398 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for dispatch call_op and resolve_op convenience functions. + +This module tests the high-level dispatch API exposed through the +dispatch module's __init__.py, ensuring the full dispatch pipeline +works correctly from call_op -> manager -> registry -> implementation. 
+""" + +import os +import pytest +from unittest.mock import patch, MagicMock + +from vllm_fl.dispatch import ( + call_op, + resolve_op, + get_default_manager, + reset_default_manager, + OpRegistry, + OpImpl, + BackendImplKind, + BackendPriority, + SelectionPolicy, + set_global_policy, + reset_global_policy, + policy_context, + with_preference, + PREFER_DEFAULT, + PREFER_VENDOR, + PREFER_REFERENCE, +) + + +class TestCallOp: + """Test call_op convenience function.""" + + @pytest.fixture(autouse=True) + def reset_all(self): + """Reset global state before and after each test.""" + reset_default_manager() + reset_global_policy() + yield + reset_default_manager() + reset_global_policy() + + @pytest.fixture + def setup_test_op(self): + """Setup a test operator in the registry.""" + manager = get_default_manager() + manager._state.initialized = True + manager._state.init_pid = os.getpid() + + def impl_fn(x, multiplier=2): + return x * multiplier + + impl = OpImpl( + op_name="test_call_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=impl_fn, + ) + manager.registry.register_impl(impl) + return impl_fn + + def test_call_op_basic(self, setup_test_op): + result = call_op("test_call_op", 5) + assert result == 10 + + def test_call_op_with_kwargs(self, setup_test_op): + result = call_op("test_call_op", 5, multiplier=3) + assert result == 15 + + def test_call_op_nonexistent_raises(self): + manager = get_default_manager() + manager._state.initialized = True + manager._state.init_pid = os.getpid() + + with pytest.raises(RuntimeError, match="No available implementation"): + call_op("nonexistent_op", 1) + + def test_call_op_uses_default_manager(self): + manager = get_default_manager() + manager._state.initialized = True + manager._state.init_pid = os.getpid() + + call_tracker = {"called": False} + + def tracking_fn(x): + call_tracker["called"] = True + return x + + manager.registry.register_impl(OpImpl( + op_name="track_op", + impl_id="default.track", + kind=BackendImplKind.DEFAULT, + fn=tracking_fn, + )) + + call_op("track_op", 1) + assert call_tracker["called"] is True + + +class TestResolveOp: + """Test resolve_op convenience function.""" + + @pytest.fixture(autouse=True) + def reset_all(self): + reset_default_manager() + reset_global_policy() + yield + reset_default_manager() + reset_global_policy() + + @pytest.fixture + def setup_test_op(self): + manager = get_default_manager() + manager._state.initialized = True + manager._state.init_pid = os.getpid() + + def impl_fn(x): + return x * 2 + + impl = OpImpl( + op_name="test_resolve_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=impl_fn, + ) + manager.registry.register_impl(impl) + return impl_fn + + def test_resolve_op_returns_function(self, setup_test_op): + fn = resolve_op("test_resolve_op") + assert callable(fn) + assert fn is setup_test_op + + def test_resolve_op_can_be_called(self, setup_test_op): + fn = resolve_op("test_resolve_op") + result = fn(5) + assert result == 10 + + def test_resolve_op_nonexistent_raises(self): + manager = get_default_manager() + manager._state.initialized = True + manager._state.init_pid = os.getpid() + + with pytest.raises(RuntimeError, match="No available implementation"): + resolve_op("nonexistent_op") + + +class TestCallOpWithPolicy: + """Test call_op behavior with different policies.""" + + @pytest.fixture(autouse=True) + def reset_all(self): + reset_default_manager() + reset_global_policy() + yield + reset_default_manager() + reset_global_policy() + + @pytest.fixture + def 
setup_multi_impl_op(self): + """Setup an operator with multiple implementations.""" + manager = get_default_manager() + manager._state.initialized = True + manager._state.init_pid = os.getpid() + + results = {"default": 0, "vendor": 0, "reference": 0} + + def default_fn(x): + results["default"] += 1 + return x * 2 + + def vendor_fn(x): + results["vendor"] += 1 + return x * 3 + + def reference_fn(x): + results["reference"] += 1 + return x * 4 + + manager.registry.register_impl(OpImpl( + op_name="policy_op", + impl_id="default.impl", + kind=BackendImplKind.DEFAULT, + fn=default_fn, + priority=BackendPriority.DEFAULT, + )) + manager.registry.register_impl(OpImpl( + op_name="policy_op", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=vendor_fn, + priority=BackendPriority.VENDOR, + vendor="CUDA", + )) + manager.registry.register_impl(OpImpl( + op_name="policy_op", + impl_id="reference.pytorch", + kind=BackendImplKind.REFERENCE, + fn=reference_fn, + priority=BackendPriority.REFERENCE, + )) + + return results + + def test_call_op_default_policy_uses_default(self, setup_multi_impl_op): + results = setup_multi_impl_op + + result = call_op("policy_op", 5) + + assert result == 10 # default_fn: x * 2 + assert results["default"] == 1 + assert results["vendor"] == 0 + assert results["reference"] == 0 + + def test_call_op_vendor_policy(self, setup_multi_impl_op): + results = setup_multi_impl_op + + set_global_policy(SelectionPolicy(prefer=PREFER_VENDOR)) + # Clear cache after policy change + get_default_manager().bump_policy_epoch() + + result = call_op("policy_op", 5) + + assert result == 15 # vendor_fn: x * 3 + assert results["vendor"] == 1 + + def test_call_op_reference_policy(self, setup_multi_impl_op): + results = setup_multi_impl_op + + set_global_policy(SelectionPolicy(prefer=PREFER_REFERENCE)) + get_default_manager().bump_policy_epoch() + + result = call_op("policy_op", 5) + + assert result == 20 # reference_fn: x * 4 + assert results["reference"] == 1 + + def test_call_op_with_policy_context(self, setup_multi_impl_op): + results = setup_multi_impl_op + + # Default call + result1 = call_op("policy_op", 5) + assert result1 == 10 + assert results["default"] == 1 + + # With vendor preference context + with with_preference("vendor"): + get_default_manager().bump_policy_epoch() + result2 = call_op("policy_op", 5) + assert result2 == 15 + assert results["vendor"] == 1 + + # Back to default after context + get_default_manager().bump_policy_epoch() + result3 = call_op("policy_op", 5) + assert result3 == 10 + assert results["default"] == 2 + + +class TestCallOpIntegration: + """Integration tests for the full dispatch pipeline.""" + + @pytest.fixture(autouse=True) + def reset_all(self): + reset_default_manager() + reset_global_policy() + yield + reset_default_manager() + reset_global_policy() + + def test_full_pipeline_with_multiple_ops(self): + """Test calling multiple different operators.""" + manager = get_default_manager() + manager._state.initialized = True + manager._state.init_pid = os.getpid() + + # Register multiple operators + manager.registry.register_impl(OpImpl( + op_name="add_op", + impl_id="default.add", + kind=BackendImplKind.DEFAULT, + fn=lambda x, y: x + y, + )) + manager.registry.register_impl(OpImpl( + op_name="mul_op", + impl_id="default.mul", + kind=BackendImplKind.DEFAULT, + fn=lambda x, y: x * y, + )) + manager.registry.register_impl(OpImpl( + op_name="sub_op", + impl_id="default.sub", + kind=BackendImplKind.DEFAULT, + fn=lambda x, y: x - y, + )) + + # Call each 
operator + assert call_op("add_op", 2, 3) == 5 + assert call_op("mul_op", 2, 3) == 6 + assert call_op("sub_op", 5, 3) == 2 + + def test_resolve_and_call_consistency(self): + """Test that resolve_op and call_op give consistent results.""" + manager = get_default_manager() + manager._state.initialized = True + manager._state.init_pid = os.getpid() + + manager.registry.register_impl(OpImpl( + op_name="consistent_op", + impl_id="default.impl", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x * 10, + )) + + # Both should give same result + fn = resolve_op("consistent_op") + result1 = fn(5) + result2 = call_op("consistent_op", 5) + + assert result1 == result2 == 50 + + @patch.dict(os.environ, {"VLLM_FL_STRICT": "1"}) + def test_fallback_chain(self): + """Test fallback from failed impl to successful one.""" + manager = get_default_manager() + manager._state.initialized = True + manager._state.init_pid = os.getpid() + + call_sequence = [] + + def failing_impl(x): + call_sequence.append("failing") + raise RuntimeError("Intentional failure") + + def success_impl(x): + call_sequence.append("success") + return x * 2 + + manager.registry.register_impl(OpImpl( + op_name="fallback_chain_op", + impl_id="default.failing", + kind=BackendImplKind.DEFAULT, + fn=failing_impl, + priority=200, + )) + manager.registry.register_impl(OpImpl( + op_name="fallback_chain_op", + impl_id="reference.success", + kind=BackendImplKind.REFERENCE, + fn=success_impl, + priority=100, + )) + + result = call_op("fallback_chain_op", 5) + + assert result == 10 + assert call_sequence == ["failing", "success"] + + def test_vendor_filtering(self): + """Test that vendor filtering works through call_op.""" + manager = get_default_manager() + manager._state.initialized = True + manager._state.init_pid = os.getpid() + + manager.registry.register_impl(OpImpl( + op_name="vendor_filter_op", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=lambda x: x * 2, + vendor="CUDA", + )) + manager.registry.register_impl(OpImpl( + op_name="vendor_filter_op", + impl_id="vendor.amd", + kind=BackendImplKind.VENDOR, + fn=lambda x: x * 3, + vendor="AMD", + )) + manager.registry.register_impl(OpImpl( + op_name="vendor_filter_op", + impl_id="reference.pytorch", + kind=BackendImplKind.REFERENCE, + fn=lambda x: x * 4, + )) + + # Deny AMD, prefer vendor -> should use CUDA + set_global_policy(SelectionPolicy( + prefer=PREFER_VENDOR, + deny_vendors=frozenset({"AMD"}) + )) + manager.bump_policy_epoch() + + result = call_op("vendor_filter_op", 5) + assert result == 10 # CUDA: x * 2 diff --git a/tests/unit_tests/dispatch/test_discovery.py b/tests/unit_tests/dispatch/test_discovery.py new file mode 100644 index 0000000..c6633cd --- /dev/null +++ b/tests/unit_tests/dispatch/test_discovery.py @@ -0,0 +1,134 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for dispatch discovery module. 
+""" + +import os +import pytest +from unittest.mock import patch, MagicMock, NonCallableMagicMock + +from vllm_fl.dispatch.discovery import ( + discover_plugins, + discover_from_env_modules, + get_discovered_plugins, + clear_discovered_plugins, + _call_register_function, + PLUGIN_MODULES_ENV, +) + + +class TestCallRegisterFunction: + def test_direct_callable(self): + registry = MagicMock() + fn = MagicMock() + + result = _call_register_function(fn, registry, "test") + + assert result is True + fn.assert_called_once_with(registry) + + def test_module_with_register_function(self): + registry = MagicMock() + module = NonCallableMagicMock(spec=["register"]) # Only has register attr + module.register = MagicMock() + + result = _call_register_function(module, registry, "test") + + assert result is True + module.register.assert_called_once_with(registry) + + def test_module_with_vllm_fl_register(self): + registry = MagicMock() + module = NonCallableMagicMock(spec=["vllm_fl_register"]) # Only has vllm_fl_register attr + module.vllm_fl_register = MagicMock() + + result = _call_register_function(module, registry, "test") + + assert result is True + module.vllm_fl_register.assert_called_once_with(registry) + + def test_callable_raises_exception(self): + registry = MagicMock() + fn = MagicMock(side_effect=Exception("test error")) + + result = _call_register_function(fn, registry, "test") + + assert result is False + + def test_no_register_function(self): + registry = MagicMock() + module = NonCallableMagicMock(spec=[]) # Not callable, no register function + + result = _call_register_function(module, registry, "test") + + assert result is False + + +class TestDiscoverFromEnvModules: + @pytest.fixture(autouse=True) + def clear_plugins(self): + clear_discovered_plugins() + yield + clear_discovered_plugins() + + def test_empty_env_var(self): + with patch.dict(os.environ, {PLUGIN_MODULES_ENV: ""}): + registry = MagicMock() + result = discover_from_env_modules(registry) + assert result == 0 + + def test_no_env_var(self): + env = os.environ.copy() + env.pop(PLUGIN_MODULES_ENV, None) + with patch.dict(os.environ, env, clear=True): + registry = MagicMock() + result = discover_from_env_modules(registry) + assert result == 0 + + def test_import_error_handling(self): + with patch.dict(os.environ, {PLUGIN_MODULES_ENV: "nonexistent_module"}): + registry = MagicMock() + result = discover_from_env_modules(registry) + assert result == 0 + plugins = get_discovered_plugins() + assert len(plugins) == 1 + assert plugins[0][2] is False # success = False + + +class TestDiscoverPlugins: + @pytest.fixture(autouse=True) + def clear_plugins(self): + clear_discovered_plugins() + yield + clear_discovered_plugins() + + def test_none_registry(self): + result = discover_plugins(None) + assert result == 0 + + def test_empty_discovery(self): + with patch.dict(os.environ, {PLUGIN_MODULES_ENV: ""}): + with patch( + "vllm_fl.dispatch.discovery._get_entry_points", return_value=[] + ): + registry = MagicMock() + result = discover_plugins(registry) + assert result == 0 + + +class TestGetDiscoveredPlugins: + @pytest.fixture(autouse=True) + def clear_plugins(self): + clear_discovered_plugins() + yield + clear_discovered_plugins() + + def test_returns_copy(self): + plugins1 = get_discovered_plugins() + plugins2 = get_discovered_plugins() + assert plugins1 is not plugins2 + + def test_initially_empty(self): + plugins = get_discovered_plugins() + assert plugins == [] diff --git a/tests/unit_tests/dispatch/test_manager.py 
b/tests/unit_tests/dispatch/test_manager.py new file mode 100644 index 0000000..e5a5261 --- /dev/null +++ b/tests/unit_tests/dispatch/test_manager.py @@ -0,0 +1,795 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for dispatch manager module. + +This module tests the core OpManager class which handles: +- Operator resolution and selection +- Dispatch caching with policy epoch invalidation +- Fallback mechanisms for failed implementations +- Multi-process safety (fork handling) +""" + +import os +import pytest +import threading +from unittest.mock import patch, MagicMock + +from vllm_fl.dispatch.manager import ( + OpManager, + get_default_manager, + reset_default_manager, + _OpManagerState, +) +from vllm_fl.dispatch.registry import OpRegistry +from vllm_fl.dispatch.types import OpImpl, BackendImplKind, BackendPriority +from vllm_fl.dispatch.policy import ( + SelectionPolicy, + set_global_policy, + reset_global_policy, + PREFER_DEFAULT, + PREFER_VENDOR, + PREFER_REFERENCE, +) + + +class TestOpManagerState: + """Test _OpManagerState dataclass.""" + + def test_default_values(self): + state = _OpManagerState() + assert state.init_pid == -1 + assert state.initialized is False + assert state.policy_epoch == 0 + + +class TestOpManagerBasic: + """Test basic OpManager functionality.""" + + @pytest.fixture + def registry(self): + return OpRegistry() + + @pytest.fixture + def manager(self, registry): + return OpManager(registry=registry) + + @pytest.fixture + def sample_impl(self): + return OpImpl( + op_name="test_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x * 2, + priority=BackendPriority.DEFAULT, + ) + + @pytest.fixture + def reference_impl(self): + return OpImpl( + op_name="test_op", + impl_id="reference.test", + kind=BackendImplKind.REFERENCE, + fn=lambda x: x * 2, + priority=BackendPriority.REFERENCE, + ) + + @pytest.fixture + def vendor_impl(self): + return OpImpl( + op_name="test_op", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=lambda x: x * 3, + priority=BackendPriority.VENDOR, + vendor="CUDA", + ) + + def test_init_with_custom_registry(self, registry): + manager = OpManager(registry=registry) + assert manager.registry is registry + + def test_init_creates_default_registry(self): + manager = OpManager() + assert manager.registry is not None + assert isinstance(manager.registry, OpRegistry) + + def test_registry_property(self, manager, registry): + assert manager.registry is registry + + +class TestOpManagerPolicyEpoch: + """Test policy epoch and cache invalidation.""" + + @pytest.fixture(autouse=True) + def reset_policy(self): + reset_global_policy() + yield + reset_global_policy() + + @pytest.fixture + def manager(self): + return OpManager() + + def test_bump_policy_epoch(self, manager): + initial_epoch = manager._state.policy_epoch + manager.bump_policy_epoch() + assert manager._state.policy_epoch == initial_epoch + 1 + + def test_bump_policy_epoch_clears_cache(self, manager): + # Add something to cache + manager._dispatch_cache[("test", "fp", 0)] = lambda x: x + assert len(manager._dispatch_cache) == 1 + + manager.bump_policy_epoch() + assert len(manager._dispatch_cache) == 0 + + def test_bump_policy_epoch_clears_failed_impls(self, manager): + manager._failed_impls["test_op"] = {"impl1", "impl2"} + manager.bump_policy_epoch() + assert len(manager._failed_impls) == 0 + + +class TestOpManagerFailedImpls: + """Test failed implementation tracking.""" + + @pytest.fixture + def manager(self): + return OpManager() + + def 
test_clear_failed_impls_all(self, manager): + manager._failed_impls["op1"] = {"impl1"} + manager._failed_impls["op2"] = {"impl2"} + + manager.clear_failed_impls() + assert manager._failed_impls == {} + + def test_clear_failed_impls_specific_op(self, manager): + manager._failed_impls["op1"] = {"impl1"} + manager._failed_impls["op2"] = {"impl2"} + + manager.clear_failed_impls("op1") + assert "op1" not in manager._failed_impls + assert "op2" in manager._failed_impls + + def test_clear_failed_impls_nonexistent_op(self, manager): + manager._failed_impls["op1"] = {"impl1"} + manager.clear_failed_impls("nonexistent") + assert "op1" in manager._failed_impls + + def test_get_failed_impls_all(self, manager): + manager._failed_impls["op1"] = {"impl1", "impl2"} + manager._failed_impls["op2"] = {"impl3"} + + result = manager.get_failed_impls() + assert result == {"op1": {"impl1", "impl2"}, "op2": {"impl3"}} + + def test_get_failed_impls_specific_op(self, manager): + manager._failed_impls["op1"] = {"impl1"} + manager._failed_impls["op2"] = {"impl2"} + + result = manager.get_failed_impls("op1") + assert result == {"op1": {"impl1"}} + + def test_get_failed_impls_returns_copy(self, manager): + manager._failed_impls["op1"] = {"impl1"} + result = manager.get_failed_impls() + + # Modify the result + result["op1"].add("impl2") + + # Original should not be modified + assert manager._failed_impls["op1"] == {"impl1"} + + +class TestOpManagerResolve: + """Test operator resolution logic.""" + + @pytest.fixture(autouse=True) + def reset_policy(self): + reset_global_policy() + yield + reset_global_policy() + + @pytest.fixture + def registry(self): + return OpRegistry() + + @pytest.fixture + def manager(self, registry): + mgr = OpManager(registry=registry) + # Mark as initialized to skip builtin registration + mgr._state.initialized = True + mgr._state.init_pid = os.getpid() + return mgr + + def test_resolve_single_impl(self, manager, registry): + impl = OpImpl( + op_name="single_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x * 2, + ) + registry.register_impl(impl) + + fn = manager.resolve("single_op") + assert fn is impl.fn + + def test_resolve_prefers_default_by_policy(self, manager, registry): + default_fn = lambda x: x * 2 + reference_fn = lambda x: x * 3 + + registry.register_impl(OpImpl( + op_name="multi_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=default_fn, + priority=BackendPriority.DEFAULT, + )) + registry.register_impl(OpImpl( + op_name="multi_op", + impl_id="reference.test", + kind=BackendImplKind.REFERENCE, + fn=reference_fn, + priority=BackendPriority.REFERENCE, + )) + + # Default policy prefers "flagos" (DEFAULT) + fn = manager.resolve("multi_op") + assert fn is default_fn + + def test_resolve_prefers_vendor_with_policy(self, manager, registry): + default_fn = lambda x: x * 2 + vendor_fn = lambda x: x * 3 + + registry.register_impl(OpImpl( + op_name="vendor_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=default_fn, + )) + registry.register_impl(OpImpl( + op_name="vendor_op", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=vendor_fn, + vendor="CUDA", + )) + + # Set policy to prefer vendor + set_global_policy(SelectionPolicy(prefer=PREFER_VENDOR)) + + fn = manager.resolve("vendor_op") + assert fn is vendor_fn + + def test_resolve_prefers_reference_with_policy(self, manager, registry): + default_fn = lambda x: x * 2 + reference_fn = lambda x: x * 3 + + registry.register_impl(OpImpl( + op_name="ref_op", + 
impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=default_fn, + )) + registry.register_impl(OpImpl( + op_name="ref_op", + impl_id="reference.test", + kind=BackendImplKind.REFERENCE, + fn=reference_fn, + )) + + # Set policy to prefer reference + set_global_policy(SelectionPolicy(prefer=PREFER_REFERENCE)) + + fn = manager.resolve("ref_op") + assert fn is reference_fn + + def test_resolve_filters_denied_vendors(self, manager, registry): + default_fn = lambda x: x * 2 + vendor_fn = lambda x: x * 3 + + registry.register_impl(OpImpl( + op_name="deny_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=default_fn, + )) + registry.register_impl(OpImpl( + op_name="deny_op", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=vendor_fn, + vendor="CUDA", + )) + + # Deny CUDA vendor, prefer vendor + set_global_policy(SelectionPolicy( + prefer=PREFER_VENDOR, + deny_vendors=frozenset({"CUDA"}) + )) + + # Should fall back to default since CUDA is denied + fn = manager.resolve("deny_op") + assert fn is default_fn + + def test_resolve_filters_by_allow_vendors(self, manager, registry): + vendor_cuda_fn = lambda x: x * 2 + vendor_amd_fn = lambda x: x * 3 + + registry.register_impl(OpImpl( + op_name="allow_op", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=vendor_cuda_fn, + vendor="CUDA", + )) + registry.register_impl(OpImpl( + op_name="allow_op", + impl_id="vendor.amd", + kind=BackendImplKind.VENDOR, + fn=vendor_amd_fn, + vendor="AMD", + )) + + # Only allow CUDA vendor + set_global_policy(SelectionPolicy( + prefer=PREFER_VENDOR, + allow_vendors=frozenset({"CUDA"}) + )) + + fn = manager.resolve("allow_op") + assert fn is vendor_cuda_fn + + def test_resolve_no_impl_raises(self, manager): + with pytest.raises(RuntimeError, match="No available implementation"): + manager.resolve("nonexistent_op") + + def test_resolve_caches_result(self, manager, registry): + impl = OpImpl( + op_name="cache_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x, + ) + registry.register_impl(impl) + + # First call + fn1 = manager.resolve("cache_op") + + # Second call should return cached result + fn2 = manager.resolve("cache_op") + + assert fn1 is fn2 + assert len(manager._dispatch_cache) == 1 + + def test_resolve_cache_invalidated_by_policy_change(self, manager, registry): + impl = OpImpl( + op_name="epoch_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x, + ) + registry.register_impl(impl) + + # First resolve + manager.resolve("epoch_op") + assert len(manager._dispatch_cache) == 1 + + # Bump epoch (simulates policy change) + manager.bump_policy_epoch() + assert len(manager._dispatch_cache) == 0 + + def test_resolve_filters_unavailable_impls(self, manager, registry): + available_fn = lambda x: x * 2 + unavailable_fn = lambda x: x * 3 + + # Create an unavailable implementation + def check_available(): + return False + unavailable_fn._is_available = check_available + + registry.register_impl(OpImpl( + op_name="avail_op", + impl_id="default.unavailable", + kind=BackendImplKind.DEFAULT, + fn=unavailable_fn, + priority=200, # Higher priority but unavailable + )) + registry.register_impl(OpImpl( + op_name="avail_op", + impl_id="default.available", + kind=BackendImplKind.DEFAULT, + fn=available_fn, + priority=100, + )) + + fn = manager.resolve("avail_op") + assert fn is available_fn + + +class TestOpManagerCall: + """Test operator call with fallback support.""" + + @pytest.fixture(autouse=True) + def reset_policy(self): + 
reset_global_policy() + yield + reset_global_policy() + + @pytest.fixture + def registry(self): + return OpRegistry() + + @pytest.fixture + def manager(self, registry): + mgr = OpManager(registry=registry) + mgr._state.initialized = True + mgr._state.init_pid = os.getpid() + return mgr + + def test_call_invokes_implementation(self, manager, registry): + result_value = [0] + + def impl_fn(x): + result_value[0] = x * 2 + return result_value[0] + + registry.register_impl(OpImpl( + op_name="call_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=impl_fn, + )) + + result = manager.call("call_op", 5) + assert result == 10 + assert result_value[0] == 10 + + def test_call_passes_args_and_kwargs(self, manager, registry): + def impl_fn(a, b, c=10): + return a + b + c + + registry.register_impl(OpImpl( + op_name="args_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=impl_fn, + )) + + result = manager.call("args_op", 1, 2, c=3) + assert result == 6 + + @patch.dict(os.environ, {"VLLM_FL_STRICT": "1"}) + def test_call_fallback_on_failure(self, manager, registry): + call_order = [] + + def failing_fn(x): + call_order.append("failing") + raise RuntimeError("Primary failed") + + def fallback_fn(x): + call_order.append("fallback") + return x * 2 + + registry.register_impl(OpImpl( + op_name="fallback_op", + impl_id="default.primary", + kind=BackendImplKind.DEFAULT, + fn=failing_fn, + priority=200, + )) + registry.register_impl(OpImpl( + op_name="fallback_op", + impl_id="reference.fallback", + kind=BackendImplKind.REFERENCE, + fn=fallback_fn, + priority=100, + )) + + result = manager.call("fallback_op", 5) + assert result == 10 + assert call_order == ["failing", "fallback"] + + @patch.dict(os.environ, {"VLLM_FL_STRICT": "1"}) + def test_call_tracks_failed_impls(self, manager, registry): + def failing_fn(x): + raise RuntimeError("Failed") + + def success_fn(x): + return x + + registry.register_impl(OpImpl( + op_name="track_op", + impl_id="default.failing", + kind=BackendImplKind.DEFAULT, + fn=failing_fn, + priority=200, + )) + registry.register_impl(OpImpl( + op_name="track_op", + impl_id="reference.success", + kind=BackendImplKind.REFERENCE, + fn=success_fn, + priority=100, + )) + + manager.call("track_op", 1) + + # Check that failed impl is tracked + failed = manager.get_failed_impls("track_op") + assert "default.failing" in failed["track_op"] + + @patch.dict(os.environ, {"VLLM_FL_STRICT": "1"}) + def test_call_all_impls_fail_raises(self, manager, registry): + def failing_fn1(x): + raise RuntimeError("Failed 1") + + def failing_fn2(x): + raise RuntimeError("Failed 2") + + registry.register_impl(OpImpl( + op_name="allfail_op", + impl_id="default.fail1", + kind=BackendImplKind.DEFAULT, + fn=failing_fn1, + )) + registry.register_impl(OpImpl( + op_name="allfail_op", + impl_id="reference.fail2", + kind=BackendImplKind.REFERENCE, + fn=failing_fn2, + )) + + with pytest.raises(RuntimeError, match="implementation.*failed"): + manager.call("allfail_op", 1) + + @patch.dict(os.environ, {"VLLM_FL_STRICT": "0"}) + def test_call_no_fallback_when_disabled(self, manager, registry): + def failing_fn(x): + raise RuntimeError("Primary failed") + + def fallback_fn(x): + return x * 2 + + registry.register_impl(OpImpl( + op_name="nofallback_op", + impl_id="default.primary", + kind=BackendImplKind.DEFAULT, + fn=failing_fn, + priority=200, + )) + registry.register_impl(OpImpl( + op_name="nofallback_op", + impl_id="reference.fallback", + kind=BackendImplKind.REFERENCE, + fn=fallback_fn, + 
priority=100, + )) + + # Should raise immediately without trying fallback + with pytest.raises(RuntimeError, match="Primary failed"): + manager.call("nofallback_op", 5) + + +class TestOpManagerResolveCandidates: + """Test resolve_candidates method.""" + + @pytest.fixture(autouse=True) + def reset_policy(self): + reset_global_policy() + yield + reset_global_policy() + + @pytest.fixture + def registry(self): + return OpRegistry() + + @pytest.fixture + def manager(self, registry): + mgr = OpManager(registry=registry) + mgr._state.initialized = True + mgr._state.init_pid = os.getpid() + return mgr + + def test_resolve_candidates_returns_sorted_list(self, manager, registry): + fn1 = lambda x: x + fn2 = lambda x: x + fn3 = lambda x: x + + registry.register_impl(OpImpl( + op_name="multi_op", + impl_id="default.impl", + kind=BackendImplKind.DEFAULT, + fn=fn1, + priority=BackendPriority.DEFAULT, + )) + registry.register_impl(OpImpl( + op_name="multi_op", + impl_id="vendor.impl", + kind=BackendImplKind.VENDOR, + fn=fn2, + priority=BackendPriority.VENDOR, + vendor="CUDA", + )) + registry.register_impl(OpImpl( + op_name="multi_op", + impl_id="reference.impl", + kind=BackendImplKind.REFERENCE, + fn=fn3, + priority=BackendPriority.REFERENCE, + )) + + candidates = manager.resolve_candidates("multi_op") + + # Default policy: flagos > vendor > reference + assert len(candidates) == 3 + assert candidates[0].impl_id == "default.impl" + assert candidates[1].impl_id == "vendor.impl" + assert candidates[2].impl_id == "reference.impl" + + def test_resolve_candidates_respects_policy_order(self, manager, registry): + fn1 = lambda x: x + fn2 = lambda x: x + + registry.register_impl(OpImpl( + op_name="order_op", + impl_id="default.impl", + kind=BackendImplKind.DEFAULT, + fn=fn1, + )) + registry.register_impl(OpImpl( + op_name="order_op", + impl_id="reference.impl", + kind=BackendImplKind.REFERENCE, + fn=fn2, + )) + + # Set policy to prefer reference + set_global_policy(SelectionPolicy(prefer=PREFER_REFERENCE)) + + candidates = manager.resolve_candidates("order_op") + + # Reference should come first + assert candidates[0].impl_id == "reference.impl" + assert candidates[1].impl_id == "default.impl" + + def test_resolve_candidates_no_impl_raises(self, manager): + with pytest.raises(RuntimeError, match="No available implementation"): + manager.resolve_candidates("nonexistent") + + +class TestOpManagerGetSelectedImplId: + """Test get_selected_impl_id method.""" + + @pytest.fixture(autouse=True) + def reset_policy(self): + reset_global_policy() + yield + reset_global_policy() + + @pytest.fixture + def registry(self): + return OpRegistry() + + @pytest.fixture + def manager(self, registry): + mgr = OpManager(registry=registry) + mgr._state.initialized = True + mgr._state.init_pid = os.getpid() + return mgr + + def test_get_selected_impl_id(self, manager, registry): + fn = lambda x: x + + registry.register_impl(OpImpl( + op_name="id_op", + impl_id="default.test", + kind=BackendImplKind.DEFAULT, + fn=fn, + )) + + impl_id = manager.get_selected_impl_id("id_op") + assert impl_id == "default.test" + + +class TestOpManagerThreadSafety: + """Test thread safety of OpManager.""" + + @pytest.fixture + def registry(self): + return OpRegistry() + + @pytest.fixture + def manager(self, registry): + mgr = OpManager(registry=registry) + mgr._state.initialized = True + mgr._state.init_pid = os.getpid() + return mgr + + def test_concurrent_resolve(self, manager, registry): + # Register multiple implementations + for i in range(5): + 
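+            # Bind i as a default argument so each lambda keeps its own offset
+            # instead of closing over the final loop value.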
registry.register_impl(OpImpl( + op_name=f"thread_op_{i}", + impl_id=f"default.impl_{i}", + kind=BackendImplKind.DEFAULT, + fn=lambda x, i=i: x + i, + )) + + errors = [] + results = [] + + def resolve_op(op_idx): + try: + fn = manager.resolve(f"thread_op_{op_idx}") + results.append((op_idx, fn(10))) + except Exception as e: + errors.append(e) + + threads = [ + threading.Thread(target=resolve_op, args=(i % 5,)) + for i in range(20) + ] + + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + assert len(results) == 20 + + def test_concurrent_bump_policy_epoch(self, manager, registry): + registry.register_impl(OpImpl( + op_name="epoch_test", + impl_id="default.impl", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x, + )) + + errors = [] + + def bump_and_resolve(): + try: + manager.bump_policy_epoch() + manager.resolve("epoch_test") + except Exception as e: + errors.append(e) + + threads = [ + threading.Thread(target=bump_and_resolve) + for _ in range(10) + ] + + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + + +class TestGlobalDefaultManager: + """Test global default manager functions.""" + + @pytest.fixture(autouse=True) + def reset_manager(self): + reset_default_manager() + yield + reset_default_manager() + + def test_get_default_manager_singleton(self): + manager1 = get_default_manager() + manager2 = get_default_manager() + assert manager1 is manager2 + + def test_reset_default_manager(self): + manager1 = get_default_manager() + reset_default_manager() + manager2 = get_default_manager() + assert manager1 is not manager2 + + def test_get_default_manager_creates_instance(self): + manager = get_default_manager() + assert isinstance(manager, OpManager) diff --git a/tests/unit_tests/dispatch/test_policy.py b/tests/unit_tests/dispatch/test_policy.py new file mode 100644 index 0000000..6eedaae --- /dev/null +++ b/tests/unit_tests/dispatch/test_policy.py @@ -0,0 +1,265 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for dispatch policy module. 
+""" + +import os +import pytest +import tempfile +from unittest.mock import patch + +from vllm_fl.dispatch.policy import ( + SelectionPolicy, + PolicyManager, + PREFER_DEFAULT, + PREFER_VENDOR, + PREFER_REFERENCE, + VALID_PREFER_VALUES, + get_policy, + set_global_policy, + reset_global_policy, + policy_context, + with_preference, + with_strict_mode, +) + + +class TestSelectionPolicy: + def test_default_values(self): + policy = SelectionPolicy() + assert policy.prefer == PREFER_DEFAULT + assert policy.strict is False + assert policy.per_op_order == () + assert policy.deny_vendors == frozenset() + assert policy.allow_vendors is None + + def test_invalid_prefer_value_raises(self): + with pytest.raises(ValueError, match="Invalid prefer value"): + SelectionPolicy(prefer="invalid") + + def test_valid_prefer_values(self): + for prefer in VALID_PREFER_VALUES: + policy = SelectionPolicy(prefer=prefer) + assert policy.prefer == prefer + + def test_from_dict(self): + policy = SelectionPolicy.from_dict( + prefer="vendor", + strict=True, + per_op_order={"silu": ["vendor", "flagos"]}, + deny_vendors={"AMD"}, + allow_vendors={"CUDA"}, + ) + assert policy.prefer == "vendor" + assert policy.strict is True + assert policy.deny_vendors == frozenset({"AMD"}) + assert policy.allow_vendors == frozenset({"CUDA"}) + + def test_get_default_order_flagos(self): + policy = SelectionPolicy(prefer=PREFER_DEFAULT) + order = policy.get_default_order() + assert order == ["flagos", "vendor", "reference"] + + def test_get_default_order_vendor(self): + policy = SelectionPolicy(prefer=PREFER_VENDOR) + order = policy.get_default_order() + assert order == ["vendor", "flagos", "reference"] + + def test_get_default_order_reference(self): + policy = SelectionPolicy(prefer=PREFER_REFERENCE) + order = policy.get_default_order() + assert order == ["reference", "flagos", "vendor"] + + def test_is_vendor_allowed_deny_list(self): + policy = SelectionPolicy(deny_vendors=frozenset({"AMD"})) + assert policy.is_vendor_allowed("CUDA") is True + assert policy.is_vendor_allowed("AMD") is False + + def test_is_vendor_allowed_allow_list(self): + policy = SelectionPolicy(allow_vendors=frozenset({"CUDA"})) + assert policy.is_vendor_allowed("CUDA") is True + assert policy.is_vendor_allowed("AMD") is False + + def test_is_vendor_allowed_combined(self): + policy = SelectionPolicy( + allow_vendors=frozenset({"CUDA", "AMD"}), + deny_vendors=frozenset({"AMD"}), + ) + assert policy.is_vendor_allowed("CUDA") is True + assert policy.is_vendor_allowed("AMD") is False + + def test_get_per_op_order(self): + policy = SelectionPolicy.from_dict( + per_op_order={"silu": ["vendor", "flagos"], "rms_norm": ["reference"]}, + ) + assert policy.get_per_op_order("silu") == ["vendor", "flagos"] + assert policy.get_per_op_order("rms_norm") == ["reference"] + assert policy.get_per_op_order("nonexistent") is None + + def test_fingerprint_uniqueness(self): + policy1 = SelectionPolicy(prefer=PREFER_DEFAULT) + policy2 = SelectionPolicy(prefer=PREFER_VENDOR) + policy3 = SelectionPolicy(prefer=PREFER_DEFAULT, strict=True) + + assert policy1.fingerprint() != policy2.fingerprint() + assert policy1.fingerprint() != policy3.fingerprint() + assert policy2.fingerprint() != policy3.fingerprint() + + def test_frozen_dataclass(self): + policy = SelectionPolicy() + with pytest.raises(AttributeError): + policy.prefer = "vendor" + + +class TestPolicyManager: + @pytest.fixture(autouse=True) + def reset_policy(self): + """Reset global policy before and after each test.""" + 
reset_global_policy() + yield + reset_global_policy() + + def test_get_instance_singleton(self): + manager1 = PolicyManager.get_instance() + manager2 = PolicyManager.get_instance() + assert manager1 is manager2 + + def test_get_policy_returns_default(self): + policy = get_policy() + assert isinstance(policy, SelectionPolicy) + assert policy.prefer == PREFER_DEFAULT + + def test_set_global_policy(self): + new_policy = SelectionPolicy(prefer=PREFER_VENDOR) + old_policy = set_global_policy(new_policy) + + current = get_policy() + assert current.prefer == PREFER_VENDOR + assert old_policy.prefer == PREFER_DEFAULT + + def test_reset_global_policy(self): + set_global_policy(SelectionPolicy(prefer=PREFER_VENDOR)) + reset_global_policy() + + policy = get_policy() + assert policy.prefer == PREFER_DEFAULT + + def test_policy_epoch_bumps(self): + manager = PolicyManager.get_instance() + epoch1 = manager.get_policy_epoch() + + set_global_policy(SelectionPolicy(prefer=PREFER_VENDOR)) + epoch2 = manager.get_policy_epoch() + + assert epoch2 > epoch1 + + +class TestPolicyContext: + @pytest.fixture(autouse=True) + def reset_policy(self): + reset_global_policy() + yield + reset_global_policy() + + def test_policy_context_override(self): + original = get_policy() + assert original.prefer == PREFER_DEFAULT + + override_policy = SelectionPolicy(prefer=PREFER_VENDOR) + with policy_context(override_policy): + inside = get_policy() + assert inside.prefer == PREFER_VENDOR + + after = get_policy() + assert after.prefer == PREFER_DEFAULT + + def test_with_preference(self): + with with_preference("vendor"): + policy = get_policy() + assert policy.prefer == "vendor" + + policy = get_policy() + assert policy.prefer == PREFER_DEFAULT + + def test_with_strict_mode(self): + with with_strict_mode(): + policy = get_policy() + assert policy.strict is True + + policy = get_policy() + assert policy.strict is False + + def test_nested_contexts(self): + with with_preference("vendor"): + assert get_policy().prefer == "vendor" + with with_strict_mode(): + policy = get_policy() + assert policy.strict is True + assert get_policy().prefer == "vendor" + + assert get_policy().prefer == PREFER_DEFAULT + + +class TestPolicyFromEnv: + @pytest.fixture(autouse=True) + def reset_policy(self): + reset_global_policy() + yield + reset_global_policy() + + def test_policy_from_env_prefer(self): + with patch.dict(os.environ, {"VLLM_FL_PREFER": "vendor"}): + reset_global_policy() + policy = get_policy() + assert policy.prefer == "vendor" + + def test_policy_from_env_strict(self): + with patch.dict(os.environ, {"VLLM_FL_STRICT": "1"}): + reset_global_policy() + policy = get_policy() + assert policy.strict is True + + def test_policy_from_env_deny_vendors(self): + with patch.dict(os.environ, {"VLLM_FL_DENY_VENDORS": "AMD,Intel"}): + reset_global_policy() + policy = get_policy() + assert "AMD" in policy.deny_vendors + assert "Intel" in policy.deny_vendors + + +class TestPolicyFromConfig: + def test_policy_from_config_file(self): + config_content = """ +prefer: vendor +strict: true +allow_vendors: + - CUDA +deny_vendors: + - AMD +op_backends: + silu: + - vendor + - flagos +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(config_content) + f.flush() + config_path = f.name + + try: + from vllm_fl.dispatch.policy import policy_from_config + policy = policy_from_config(config_path) + + assert policy.prefer == "vendor" + assert policy.strict is True + assert "CUDA" in policy.allow_vendors + assert "AMD" in 
policy.deny_vendors + assert policy.get_per_op_order("silu") == ["vendor", "flagos"] + finally: + os.unlink(config_path) + + def test_policy_from_nonexistent_config(self): + from vllm_fl.dispatch.policy import policy_from_config + with pytest.raises(FileNotFoundError): + policy_from_config("/nonexistent/path/config.yaml") diff --git a/tests/unit_tests/dispatch/test_registry.py b/tests/unit_tests/dispatch/test_registry.py new file mode 100644 index 0000000..77e0d97 --- /dev/null +++ b/tests/unit_tests/dispatch/test_registry.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for dispatch registry. +""" + +import pytest +from vllm_fl.dispatch.registry import OpRegistry, OpRegistrySnapshot +from vllm_fl.dispatch.types import BackendImplKind, OpImpl + + +class TestOpRegistry: + @pytest.fixture + def registry(self): + return OpRegistry() + + @pytest.fixture + def sample_impl(self): + return OpImpl( + op_name="silu", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x, + priority=100, + ) + + @pytest.fixture + def another_impl(self): + return OpImpl( + op_name="silu", + impl_id="reference.pytorch", + kind=BackendImplKind.REFERENCE, + fn=lambda x: x, + priority=50, + ) + + def test_register_impl(self, registry, sample_impl): + registry.register_impl(sample_impl) + result = registry.get_implementation("silu", "default.flagos") + assert result == sample_impl + + def test_register_impl_duplicate_raises(self, registry, sample_impl): + registry.register_impl(sample_impl) + with pytest.raises(ValueError, match="Duplicate impl_id"): + registry.register_impl(sample_impl) + + def test_register_many(self, registry, sample_impl, another_impl): + registry.register_many([sample_impl, another_impl]) + assert registry.get_implementation("silu", "default.flagos") == sample_impl + assert registry.get_implementation("silu", "reference.pytorch") == another_impl + + def test_get_implementations(self, registry, sample_impl, another_impl): + registry.register_many([sample_impl, another_impl]) + impls = registry.get_implementations("silu") + assert len(impls) == 2 + assert sample_impl in impls + assert another_impl in impls + + def test_get_implementations_empty(self, registry): + impls = registry.get_implementations("nonexistent") + assert impls == [] + + def test_get_implementation_not_found(self, registry): + result = registry.get_implementation("nonexistent", "any_id") + assert result is None + + def test_list_operators(self, registry, sample_impl, another_impl): + rms_impl = OpImpl( + op_name="rms_norm", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x, + ) + registry.register_many([sample_impl, another_impl, rms_impl]) + ops = registry.list_operators() + assert set(ops) == {"silu", "rms_norm"} + + def test_list_operators_empty(self, registry): + assert registry.list_operators() == [] + + def test_clear(self, registry, sample_impl): + registry.register_impl(sample_impl) + registry.clear() + assert registry.list_operators() == [] + assert registry.get_implementation("silu", "default.flagos") is None + + def test_snapshot(self, registry, sample_impl, another_impl): + registry.register_many([sample_impl, another_impl]) + snapshot = registry.snapshot() + + assert isinstance(snapshot, OpRegistrySnapshot) + assert "silu" in snapshot.impls_by_op + assert len(snapshot.impls_by_op["silu"]) == 2 + + def test_snapshot_is_immutable_copy(self, registry, sample_impl): + registry.register_impl(sample_impl) + snapshot = registry.snapshot() + + # 
Register more after snapshot + new_impl = OpImpl( + op_name="rms_norm", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x, + ) + registry.register_impl(new_impl) + + # Snapshot should not contain the new impl + assert "rms_norm" not in snapshot.impls_by_op + + def test_thread_safety(self, registry): + """Test basic thread safety of registry operations.""" + import threading + + errors = [] + + def register_impl(i): + try: + impl = OpImpl( + op_name=f"op_{i}", + impl_id=f"impl_{i}", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x, + ) + registry.register_impl(impl) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=register_impl, args=(i,)) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + assert len(registry.list_operators()) == 10 diff --git a/tests/unit_tests/dispatch/test_types.py b/tests/unit_tests/dispatch/test_types.py new file mode 100644 index 0000000..9b9aa00 --- /dev/null +++ b/tests/unit_tests/dispatch/test_types.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for dispatch type definitions. +""" + +import pytest +from vllm_fl.dispatch.types import ( + BackendImplKind, + BackendPriority, + OpImpl, + match_token, +) + + +class TestBackendImplKind: + def test_enum_values(self): + assert BackendImplKind.DEFAULT.value == "flagos" + assert BackendImplKind.REFERENCE.value == "reference" + assert BackendImplKind.VENDOR.value == "vendor" + + def test_str_representation(self): + assert str(BackendImplKind.DEFAULT) == "flagos" + assert str(BackendImplKind.REFERENCE) == "reference" + assert str(BackendImplKind.VENDOR) == "vendor" + + +class TestBackendPriority: + def test_priority_ordering(self): + assert BackendPriority.DEFAULT > BackendPriority.VENDOR + assert BackendPriority.VENDOR > BackendPriority.REFERENCE + assert BackendPriority.DEFAULT > BackendPriority.REFERENCE + + +class TestOpImpl: + def test_create_default_impl(self): + fn = lambda x: x + impl = OpImpl( + op_name="silu", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=fn, + priority=BackendPriority.DEFAULT, + ) + assert impl.op_name == "silu" + assert impl.impl_id == "default.flagos" + assert impl.kind == BackendImplKind.DEFAULT + assert impl.vendor is None + + def test_create_vendor_impl_requires_vendor_name(self): + fn = lambda x: x + with pytest.raises(ValueError, match="must specify vendor name"): + OpImpl( + op_name="silu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=fn, + ) + + def test_create_vendor_impl_with_vendor_name(self): + fn = lambda x: x + impl = OpImpl( + op_name="silu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=fn, + vendor="CUDA", + ) + assert impl.vendor == "CUDA" + + def test_is_available_default(self): + fn = lambda x: x + impl = OpImpl( + op_name="silu", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=fn, + ) + assert impl.is_available() is True + + def test_is_available_with_checker(self): + def fn(x): + return x + + fn._is_available = lambda: True + impl = OpImpl( + op_name="silu", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=fn, + ) + assert impl.is_available() is True + + fn._is_available = lambda: False + assert impl.is_available() is False + + def test_is_available_handles_exception(self): + def fn(x): + return x + + fn._is_available = lambda: 1 / 0 # Raises ZeroDivisionError + impl = OpImpl( + op_name="silu", + impl_id="default.flagos", + 
kind=BackendImplKind.DEFAULT, + fn=fn, + ) + assert impl.is_available() is False + + def test_frozen_dataclass(self): + fn = lambda x: x + impl = OpImpl( + op_name="silu", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=fn, + ) + with pytest.raises(AttributeError): + impl.op_name = "new_name" + + +class TestMatchToken: + @pytest.fixture + def default_impl(self): + return OpImpl( + op_name="silu", + impl_id="default.flagos", + kind=BackendImplKind.DEFAULT, + fn=lambda x: x, + ) + + @pytest.fixture + def reference_impl(self): + return OpImpl( + op_name="silu", + impl_id="reference.pytorch", + kind=BackendImplKind.REFERENCE, + fn=lambda x: x, + ) + + @pytest.fixture + def vendor_impl(self): + return OpImpl( + op_name="silu", + impl_id="vendor.cuda", + kind=BackendImplKind.VENDOR, + fn=lambda x: x, + vendor="CUDA", + ) + + def test_match_flagos_token(self, default_impl, reference_impl, vendor_impl): + assert match_token(default_impl, "flagos") is True + assert match_token(reference_impl, "flagos") is False + assert match_token(vendor_impl, "flagos") is False + + def test_match_reference_token(self, default_impl, reference_impl, vendor_impl): + assert match_token(default_impl, "reference") is False + assert match_token(reference_impl, "reference") is True + assert match_token(vendor_impl, "reference") is False + + def test_match_vendor_token(self, default_impl, reference_impl, vendor_impl): + assert match_token(default_impl, "vendor") is False + assert match_token(reference_impl, "vendor") is False + assert match_token(vendor_impl, "vendor") is True + + def test_match_vendor_specific(self, vendor_impl): + assert match_token(vendor_impl, "vendor:CUDA") is True + assert match_token(vendor_impl, "vendor:AMD") is False + + def test_match_impl_id(self, default_impl, reference_impl): + assert match_token(default_impl, "impl:default.flagos") is True + assert match_token(default_impl, "impl:reference.pytorch") is False + assert match_token(reference_impl, "impl:reference.pytorch") is True + + def test_unknown_token_returns_false(self, default_impl): + assert match_token(default_impl, "unknown") is False + assert match_token(default_impl, "") is False diff --git a/tests/unit_tests/distributed/__init__.py b/tests/unit_tests/distributed/__init__.py new file mode 100644 index 0000000..8c90136 --- /dev/null +++ b/tests/unit_tests/distributed/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2025 BAAI. All rights reserved. diff --git a/tests/unit_tests/distributed/test_communicator.py b/tests/unit_tests/distributed/test_communicator.py new file mode 100644 index 0000000..412d2b5 --- /dev/null +++ b/tests/unit_tests/distributed/test_communicator.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for distributed communicator module. + +Note: Tests require FLAGCX_PATH environment variable to be set. +Tests are skipped if flagcx is not available. +""" + +import os +import pytest + + +def has_flagcx(): + """Check if flagcx is available.""" + flagcx_path = os.getenv('FLAGCX_PATH') + if not flagcx_path: + return False + lib_path = os.path.join(flagcx_path, "build/lib/libflagcx.so") + return os.path.exists(lib_path) + + +# Skip all tests if flagcx is not available (communicator depends on it) +pytestmark = pytest.mark.skipif( + not has_flagcx(), + reason="FLAGCX_PATH not set or flagcx library not found" +) + + +# Note: CommunicatorFL requires multi-process GPU environment for meaningful tests. +# Integration tests should be in functional_tests/. 
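+# The module-level skipif above ensures anything added here is skipped cleanly when flagcx is missing.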
+# Unit tests here are minimal as the class requires distributed infrastructure. diff --git a/tests/unit_tests/distributed/test_flagcx.py b/tests/unit_tests/distributed/test_flagcx.py new file mode 100644 index 0000000..21ab7a4 --- /dev/null +++ b/tests/unit_tests/distributed/test_flagcx.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for flagcx communicator module. + +Note: Tests require FLAGCX_PATH environment variable and the flagcx Python bindings. +Tests are skipped if flagcx is not available. + +Integration tests for actual distributed operations should be in functional_tests/. +""" + +import os +import pytest + + +def has_flagcx(): + """Check if flagcx is available (both library and Python bindings).""" + flagcx_path = os.getenv('FLAGCX_PATH') + if not flagcx_path: + return False + lib_path = os.path.join(flagcx_path, "build/lib/libflagcx.so") + if not os.path.exists(lib_path): + return False + # Also check Python bindings + try: + from plugin.interservice.flagcx_wrapper import flagcxDataTypeEnum + return True + except ImportError: + return False + + +# Mark all tests in this module as requiring flagcx +pytestmark = pytest.mark.skipif( + not has_flagcx(), + reason="FLAGCX_PATH not set, flagcx library not found, or Python bindings unavailable" +) + + +# Note: PyFlagcxCommunicator requires multi-GPU distributed environment for meaningful tests. +# Unit tests for dtype/op conversions are moved here but require the plugin module. +# Integration tests should be in functional_tests/. diff --git a/tests/unit_tests/flaggems/__init__.py b/tests/unit_tests/flaggems/__init__.py new file mode 100644 index 0000000..8c90136 --- /dev/null +++ b/tests/unit_tests/flaggems/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2025 BAAI. All rights reserved. diff --git a/tests/flaggems/test_flaggems_get_ops.py b/tests/unit_tests/flaggems/test_flaggems_get_ops.py similarity index 58% rename from tests/flaggems/test_flaggems_get_ops.py rename to tests/unit_tests/flaggems/test_flaggems_get_ops.py index 453b5c4..1764d3a 100644 --- a/tests/flaggems/test_flaggems_get_ops.py +++ b/tests/unit_tests/flaggems/test_flaggems_get_ops.py @@ -1,5 +1,8 @@ -# Copyright (c) 2026 BAAI. All rights reserved. +# Copyright (c) 2025 BAAI. All rights reserved. +""" +Unit tests for FlagGems ops discovery functionality. +""" from vllm_fl.utils import get_flaggems_all_ops diff --git a/tests/flaggems/test_gems_whitelist.py b/tests/unit_tests/flaggems/test_gems_whitelist.py similarity index 95% rename from tests/flaggems/test_gems_whitelist.py rename to tests/unit_tests/flaggems/test_gems_whitelist.py index 36eaef7..46c21d9 100644 --- a/tests/flaggems/test_gems_whitelist.py +++ b/tests/unit_tests/flaggems/test_gems_whitelist.py @@ -1,4 +1,4 @@ -# Copyright (c) 2026 BAAI. All rights reserved. +# Copyright (c) 2025 BAAI. All rights reserved. """ Unit tests for FlagGems operator whitelist/blacklist functionality. 
@@ -70,15 +70,15 @@ def test_use_flaggems_op_blacklist_only_listed_disallowed(monkeypatch): assert use_flaggems_op("other_op") is True -def test_use_flaggems_op_whitelist_and_blacklist_same_op_raises(monkeypatch): - """When same op is in both whitelist and blacklist, ValueError is raised.""" +def test_use_flaggems_op_whitelist_and_blacklist_both_set_raises(monkeypatch): + """When both whitelist and blacklist are set, ValueError is raised.""" _env_for_flaggems_enabled(monkeypatch) monkeypatch.setenv("VLLM_FL_FLAGOS_WHITELIST", "silu_and_mul,rms_norm") monkeypatch.setenv("VLLM_FL_FLAGOS_BLACKLIST", "rms_norm,other") with pytest.raises(ValueError) as exc_info: - use_flaggems_op("rms_norm") - assert "rms_norm" in str(exc_info.value) + use_flaggems_op("silu_and_mul") + # Implementation disallows setting both whitelist and blacklist simultaneously assert "VLLM_FL_FLAGOS_WHITELIST" in str(exc_info.value) assert "VLLM_FL_FLAGOS_BLACKLIST" in str(exc_info.value) diff --git a/tests/unit_tests/ops/__init__.py b/tests/unit_tests/ops/__init__.py new file mode 100644 index 0000000..8c90136 --- /dev/null +++ b/tests/unit_tests/ops/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2025 BAAI. All rights reserved. diff --git a/tests/unit_tests/ops/test_activation.py b/tests/unit_tests/ops/test_activation.py new file mode 100644 index 0000000..c2063c3 --- /dev/null +++ b/tests/unit_tests/ops/test_activation.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for activation ops. +""" + +import pytest +import torch +from unittest.mock import patch + + +class TestSiluAndMulFL: + """Test SiluAndMulFL class behavior.""" + + @pytest.fixture + def mock_call_op(self): + with patch("vllm_fl.ops.activation.call_op") as mock: + yield mock + + @pytest.fixture + def mock_parent_init(self): + with patch("vllm_fl.ops.activation.SiluAndMul.__init__", return_value=None): + yield + + def test_forward_oot_dispatches_correctly(self, mock_parent_init, mock_call_op): + """Test forward_oot calls dispatch system with correct op name and input.""" + from vllm_fl.ops.activation import SiluAndMulFL + + mock_call_op.return_value = torch.randn(2, 4) + layer = SiluAndMulFL() + x = torch.randn(2, 8) + + result = layer.forward_oot(x) + + mock_call_op.assert_called_once_with("silu_and_mul", x) + assert result.shape == (2, 4) diff --git a/tests/unit_tests/ops/test_layernorm.py b/tests/unit_tests/ops/test_layernorm.py new file mode 100644 index 0000000..ba8aa76 --- /dev/null +++ b/tests/unit_tests/ops/test_layernorm.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for layernorm ops. 
+""" + +import pytest +import torch +from unittest.mock import patch + + +class TestRMSNormFL: + """Test RMSNormFL class behavior.""" + + @pytest.fixture + def mock_call_op(self): + with patch("vllm_fl.ops.layernorm.call_op") as mock: + yield mock + + def test_init_creates_weight_parameter(self): + """Test that initialization creates weight parameter with correct shape.""" + from vllm_fl.ops.layernorm import RMSNormFL + + hidden_size = 128 + eps = 1e-5 + layer = RMSNormFL(hidden_size=hidden_size, eps=eps) + + assert layer.variance_epsilon == eps + assert layer.weight.shape == (hidden_size,) + + def test_forward_oot_dispatches_without_residual(self, mock_call_op): + """Test forward_oot calls dispatch system correctly without residual.""" + from vllm_fl.ops.layernorm import RMSNormFL + + hidden_size = 128 + mock_call_op.return_value = torch.randn(2, hidden_size) + + layer = RMSNormFL(hidden_size=hidden_size) + x = torch.randn(2, hidden_size) + + result = layer.forward_oot(x) + + mock_call_op.assert_called_once() + call_args = mock_call_op.call_args + assert call_args[0][0] == "rms_norm" + assert torch.equal(call_args[0][1], x) + assert call_args[0][2] is None # residual should be None + + def test_forward_oot_dispatches_with_residual(self, mock_call_op): + """Test forward_oot passes residual to dispatch system.""" + from vllm_fl.ops.layernorm import RMSNormFL + + hidden_size = 128 + mock_call_op.return_value = (torch.randn(2, hidden_size), torch.randn(2, hidden_size)) + + layer = RMSNormFL(hidden_size=hidden_size) + x = torch.randn(2, hidden_size) + residual = torch.randn(2, hidden_size) + + result = layer.forward_oot(x, residual=residual) + + mock_call_op.assert_called_once() + call_args = mock_call_op.call_args + assert call_args[0][0] == "rms_norm" + assert torch.equal(call_args[0][2], residual) diff --git a/tests/unit_tests/ops/test_numerical.py b/tests/unit_tests/ops/test_numerical.py new file mode 100644 index 0000000..83511cf --- /dev/null +++ b/tests/unit_tests/ops/test_numerical.py @@ -0,0 +1,588 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Numerical correctness tests for operator implementations. + +This module tests that operator implementations produce numerically correct +results by comparing against reference (PyTorch) implementations. + +These tests verify: +- Output shape correctness +- Numerical accuracy within tolerance +- Handling of edge cases (zeros, large values, etc.) 
+- Different dtypes (float32, float16, bfloat16) +""" + +import pytest +import torch +import torch.nn.functional as F + + +# ============================================================================= +# Reference Implementations (PyTorch baseline) +# ============================================================================= + +def reference_silu_and_mul(x: torch.Tensor) -> torch.Tensor: + """Reference SiLU and multiply implementation.""" + d = x.shape[-1] // 2 + x1, x2 = x[..., :d], x[..., d:] + return F.silu(x1) * x2 + + +def reference_rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + residual: torch.Tensor = None, +): + """Reference RMS normalization implementation.""" + if residual is not None: + x = x + residual + new_residual = x.clone() + + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + epsilon) + output = weight * x + + if residual is not None: + return output, new_residual + return output + + +def reference_rotary_embedding( + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool = True, +): + """ + Reference rotary position embedding implementation. + + Applies rotary position embedding to query and key tensors. + """ + def apply_rotary(x, cos, sin, is_neox_style): + if is_neox_style: + # GPT-NeoX style: split at half + d = x.shape[-1] // 2 + x1, x2 = x[..., :d], x[..., d:] + rotated = torch.cat([-x2, x1], dim=-1) + else: + # GPT-J style: interleaved + x1 = x[..., ::2] + x2 = x[..., 1::2] + rotated = torch.stack([-x2, x1], dim=-1).flatten(-2) + + return x * cos + rotated * sin + + q_embed = apply_rotary(query, cos, sin, is_neox_style) + k_embed = apply_rotary(key, cos, sin, is_neox_style) + + return q_embed, k_embed + + +# ============================================================================= +# Test Fixtures +# ============================================================================= + +@pytest.fixture(params=[torch.float32, torch.float16]) +def dtype(request): + """Test with different floating point dtypes.""" + return request.param + + +@pytest.fixture(params=[(2, 128), (4, 256), (8, 512)]) +def hidden_size_config(request): + """Different batch and hidden size configurations.""" + return request.param + + +@pytest.fixture +def device(): + """Get available device (CPU for unit tests).""" + return torch.device("cpu") + + +# ============================================================================= +# SiLU and Multiply Tests +# ============================================================================= + +class TestSiluAndMulNumerical: + """Numerical correctness tests for SiLU and multiply operation.""" + + def test_basic_correctness(self, device): + """Test basic numerical correctness.""" + from vllm_fl.dispatch.backends.reference.impl.activation import silu_and_mul_torch + + x = torch.randn(2, 16, device=device, dtype=torch.float32) + + result = silu_and_mul_torch(x) + expected = reference_silu_and_mul(x) + + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) + + def test_output_shape(self, hidden_size_config, device): + """Test that output shape is correct (half of input).""" + from vllm_fl.dispatch.backends.reference.impl.activation import silu_and_mul_torch + + batch, hidden = hidden_size_config + x = torch.randn(batch, hidden * 2, device=device) + + result = silu_and_mul_torch(x) + + assert result.shape == (batch, hidden) + + def test_dtype_preservation(self, dtype, device): + """Test that dtype is preserved.""" + from 
vllm_fl.dispatch.backends.reference.impl.activation import silu_and_mul_torch + + x = torch.randn(2, 16, device=device, dtype=dtype) + + result = silu_and_mul_torch(x) + + assert result.dtype == dtype + + def test_zero_input(self, device): + """Test with zero input tensor.""" + from vllm_fl.dispatch.backends.reference.impl.activation import silu_and_mul_torch + + x = torch.zeros(2, 16, device=device) + + result = silu_and_mul_torch(x) + expected = reference_silu_and_mul(x) + + torch.testing.assert_close(result, expected) + + def test_large_values(self, device): + """Test with large input values.""" + from vllm_fl.dispatch.backends.reference.impl.activation import silu_and_mul_torch + + x = torch.randn(2, 16, device=device) * 100 + + result = silu_and_mul_torch(x) + expected = reference_silu_and_mul(x) + + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + def test_negative_values(self, device): + """Test with negative input values.""" + from vllm_fl.dispatch.backends.reference.impl.activation import silu_and_mul_torch + + x = -torch.abs(torch.randn(2, 16, device=device)) + + result = silu_and_mul_torch(x) + expected = reference_silu_and_mul(x) + + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) + + def test_3d_input(self, device): + """Test with 3D input tensor.""" + from vllm_fl.dispatch.backends.reference.impl.activation import silu_and_mul_torch + + x = torch.randn(2, 4, 16, device=device) + + result = silu_and_mul_torch(x) + expected = reference_silu_and_mul(x) + + assert result.shape == (2, 4, 8) + torch.testing.assert_close(result, expected) + + +# ============================================================================= +# RMS Normalization Tests +# ============================================================================= + +class TestRMSNormNumerical: + """Numerical correctness tests for RMS normalization.""" + + def test_basic_correctness(self, device): + """Test basic numerical correctness.""" + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + hidden_size = 128 + x = torch.randn(2, hidden_size, device=device, dtype=torch.float32) + weight = torch.ones(hidden_size, device=device, dtype=torch.float32) + epsilon = 1e-5 + + result = rms_norm_torch(x, None, weight, epsilon) + expected = reference_rms_norm(x, weight, epsilon) + + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) + + def test_with_residual(self, device): + """Test RMS norm with residual connection.""" + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + hidden_size = 128 + x = torch.randn(2, hidden_size, device=device, dtype=torch.float32) + residual = torch.randn(2, hidden_size, device=device, dtype=torch.float32) + weight = torch.ones(hidden_size, device=device, dtype=torch.float32) + epsilon = 1e-5 + + result_out, result_res = rms_norm_torch(x, residual, weight, epsilon) + expected_out, expected_res = reference_rms_norm(x, weight, epsilon, residual) + + torch.testing.assert_close(result_out, expected_out, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(result_res, expected_res, rtol=1e-5, atol=1e-5) + + def test_output_shape(self, hidden_size_config, device): + """Test that output shape matches input shape.""" + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + batch, hidden = hidden_size_config + x = torch.randn(batch, hidden, device=device) + weight = torch.ones(hidden, device=device) + epsilon = 1e-5 + + result = rms_norm_torch(x, None, 
weight, epsilon) + + assert result.shape == x.shape + + def test_dtype_preservation(self, dtype, device): + """Test that dtype is preserved.""" + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + hidden_size = 128 + x = torch.randn(2, hidden_size, device=device, dtype=dtype) + weight = torch.ones(hidden_size, device=device, dtype=dtype) + epsilon = 1e-5 + + result = rms_norm_torch(x, None, weight, epsilon) + + assert result.dtype == dtype + + def test_normalization_effect(self, device): + """Test that normalization actually normalizes the tensor.""" + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + hidden_size = 128 + # Create input with known variance + x = torch.randn(2, hidden_size, device=device) * 10 + weight = torch.ones(hidden_size, device=device) + epsilon = 1e-5 + + result = rms_norm_torch(x, None, weight, epsilon) + + # After RMS norm, the RMS should be approximately 1 + rms = result.pow(2).mean(-1).sqrt() + torch.testing.assert_close( + rms, + torch.ones_like(rms), + rtol=0.1, + atol=0.1 + ) + + def test_epsilon_effect(self, device): + """Test that epsilon prevents division by zero.""" + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + hidden_size = 128 + x = torch.zeros(2, hidden_size, device=device) + weight = torch.ones(hidden_size, device=device) + epsilon = 1e-5 + + # Should not raise or produce NaN/Inf + result = rms_norm_torch(x, None, weight, epsilon) + + assert not torch.isnan(result).any() + assert not torch.isinf(result).any() + + def test_weight_scaling(self, device): + """Test that weight properly scales the output.""" + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + hidden_size = 128 + x = torch.randn(2, hidden_size, device=device) + weight1 = torch.ones(hidden_size, device=device) + weight2 = torch.ones(hidden_size, device=device) * 2 + epsilon = 1e-5 + + result1 = rms_norm_torch(x, None, weight1, epsilon) + result2 = rms_norm_torch(x, None, weight2, epsilon) + + # Result with weight=2 should be twice result with weight=1 + torch.testing.assert_close(result2, result1 * 2, rtol=1e-5, atol=1e-5) + + def test_3d_input(self, device): + """Test with 3D input tensor (batch, seq, hidden).""" + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + x = torch.randn(2, 4, 128, device=device) + weight = torch.ones(128, device=device) + epsilon = 1e-5 + + result = rms_norm_torch(x, None, weight, epsilon) + expected = reference_rms_norm(x, weight, epsilon) + + assert result.shape == x.shape + torch.testing.assert_close(result, expected) + + +# ============================================================================= +# Rotary Embedding Tests +# ============================================================================= + +class TestRotaryEmbeddingNumerical: + """Numerical correctness tests for rotary position embedding.""" + + def test_basic_correctness_4d(self, device): + """Test basic numerical correctness with 4D tensors.""" + from vllm_fl.dispatch.backends.reference.impl.rotary import rotary_embedding_torch + + batch, num_heads, seq_len, head_dim = 2, 4, 8, 64 + max_seq_len = 16 + rotary_dim = head_dim // 2 + + # [batch, num_heads, seq_len, head_dim] + query = torch.randn(batch, num_heads, seq_len, head_dim, device=device) + key = torch.randn(batch, num_heads, seq_len, head_dim, device=device) + + # Create cos/sin cache [max_seq_len, rotary_dim] + freqs = 1.0 / (10000 ** (torch.arange(0, 
rotary_dim, device=device).float() / rotary_dim)) + angles = torch.arange(max_seq_len, device=device).unsqueeze(1) * freqs.unsqueeze(0) + cos = torch.cos(angles) # [max_seq_len, rotary_dim] + sin = torch.sin(angles) # [max_seq_len, rotary_dim] + + # For 4D query, position_ids should be 2D [batch, seq_len] + positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch, -1) + + result_q, result_k = rotary_embedding_torch( + query, key, cos, sin, + position_ids=positions, + rotary_interleaved=False, + inplace=False, + ) + + # Should have same shape as input + assert result_q.shape == query.shape + assert result_k.shape == key.shape + assert not torch.isnan(result_q).any() + assert not torch.isnan(result_k).any() + + def test_output_shape_4d(self, device): + """Test that output shapes match input shapes for 4D tensors.""" + from vllm_fl.dispatch.backends.reference.impl.rotary import rotary_embedding_torch + + batch, num_heads, seq_len, head_dim = 4, 8, 16, 32 + max_seq_len = 32 + rotary_dim = head_dim // 2 + + # [batch, num_heads, seq_len, head_dim] + query = torch.randn(batch, num_heads, seq_len, head_dim, device=device) + key = torch.randn(batch, num_heads, seq_len, head_dim, device=device) + + # For 4D query, use 2D position_ids [batch, seq_len] + positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch, -1) + cos = torch.randn(max_seq_len, rotary_dim, device=device) + sin = torch.randn(max_seq_len, rotary_dim, device=device) + + result_q, result_k = rotary_embedding_torch( + query, key, cos, sin, + position_ids=positions, + rotary_interleaved=False, + inplace=False, + ) + + assert result_q.shape == query.shape + assert result_k.shape == key.shape + + def test_output_shape_3d(self, device): + """Test that output shapes match input shapes for 3D tensors.""" + from vllm_fl.dispatch.backends.reference.impl.rotary import rotary_embedding_torch + + seq_len, num_heads, head_dim = 16, 8, 32 + max_seq_len = 32 + rotary_dim = head_dim // 2 + + # [seq_len, num_heads, head_dim] + query = torch.randn(seq_len, num_heads, head_dim, device=device) + key = torch.randn(seq_len, num_heads, head_dim, device=device) + + # For 3D query, use 1D position_ids [seq_len] + positions = torch.arange(seq_len, device=device) + cos = torch.randn(max_seq_len, rotary_dim, device=device) + sin = torch.randn(max_seq_len, rotary_dim, device=device) + + result_q, result_k = rotary_embedding_torch( + query, key, cos, sin, + position_ids=positions, + rotary_interleaved=False, + inplace=False, + ) + + assert result_q.shape == query.shape + assert result_k.shape == key.shape + + def test_dtype_preservation_3d(self, dtype, device): + """Test that dtype is preserved with 3D tensors.""" + from vllm_fl.dispatch.backends.reference.impl.rotary import rotary_embedding_torch + + seq_len, num_heads, head_dim = 8, 4, 32 + max_seq_len = 16 + rotary_dim = head_dim // 2 + + # Use 3D tensors for simpler testing + query = torch.randn(seq_len, num_heads, head_dim, device=device, dtype=dtype) + key = torch.randn(seq_len, num_heads, head_dim, device=device, dtype=dtype) + + positions = torch.arange(seq_len, device=device) + cos = torch.randn(max_seq_len, rotary_dim, device=device, dtype=dtype) + sin = torch.randn(max_seq_len, rotary_dim, device=device, dtype=dtype) + + result_q, result_k = rotary_embedding_torch( + query, key, cos, sin, + position_ids=positions, + rotary_interleaved=False, + inplace=False, + ) + + assert result_q.dtype == dtype + assert result_k.dtype == dtype + + def 
test_interleaved_vs_neox_style(self, device): + """Test both interleaved and neox rotary styles.""" + from vllm_fl.dispatch.backends.reference.impl.rotary import rotary_embedding_torch + + seq_len, num_heads, head_dim = 8, 4, 64 + max_seq_len = 16 + rotary_dim = head_dim // 2 + + # Use 3D tensors + query = torch.randn(seq_len, num_heads, head_dim, device=device) + key = torch.randn(seq_len, num_heads, head_dim, device=device) + + positions = torch.arange(seq_len, device=device) + cos = torch.randn(max_seq_len, rotary_dim, device=device) + sin = torch.randn(max_seq_len, rotary_dim, device=device) + + # Test neox style (default) + result_q_neox, result_k_neox = rotary_embedding_torch( + query, key, cos, sin, + position_ids=positions, + rotary_interleaved=False, + inplace=False, + ) + + # Test interleaved style + result_q_interleaved, result_k_interleaved = rotary_embedding_torch( + query, key, cos, sin, + position_ids=positions, + rotary_interleaved=True, + inplace=False, + ) + + # Results should be different between styles + assert not torch.allclose(result_q_neox, result_q_interleaved) + assert not torch.allclose(result_k_neox, result_k_interleaved) + + # But both should have valid outputs (no NaN/Inf) + assert not torch.isnan(result_q_neox).any() + assert not torch.isnan(result_q_interleaved).any() + + def test_rotary_embedding_mathematically(self, device): + """Test that rotary embedding produces expected rotation behavior.""" + from vllm_fl.dispatch.backends.reference.impl.rotary import rotary_embedding_torch + + seq_len, num_heads, head_dim = 4, 2, 8 + max_seq_len = 8 + rotary_dim = head_dim // 2 + + # Create simple test tensors + query = torch.ones(seq_len, num_heads, head_dim, device=device) + key = torch.ones(seq_len, num_heads, head_dim, device=device) + + # Create cos/sin with known values + positions = torch.arange(seq_len, device=device) + cos = torch.ones(max_seq_len, rotary_dim, device=device) + sin = torch.zeros(max_seq_len, rotary_dim, device=device) + + # With cos=1 and sin=0, output should equal input (no rotation) + result_q, result_k = rotary_embedding_torch( + query, key, cos, sin, + position_ids=positions, + rotary_interleaved=False, + inplace=False, + ) + + torch.testing.assert_close(result_q, query, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(result_k, key, rtol=1e-5, atol=1e-5) + + +# ============================================================================= +# Cross-Implementation Consistency Tests +# ============================================================================= + +class TestCrossImplementationConsistency: + """Test consistency between different backend implementations.""" + + def test_silu_reference_matches_pytorch(self, device): + """Test that reference implementation matches PyTorch exactly.""" + from vllm_fl.dispatch.backends.reference.impl.activation import silu_and_mul_torch + + x = torch.randn(4, 32, device=device) + + result = silu_and_mul_torch(x) + expected = reference_silu_and_mul(x) + + # Should be exactly equal (same implementation) + torch.testing.assert_close(result, expected, rtol=0, atol=0) + + def test_rms_norm_reference_matches_pytorch(self, device): + """Test that reference RMS norm matches our baseline.""" + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + x = torch.randn(4, 64, device=device) + weight = torch.randn(64, device=device) + epsilon = 1e-6 + + result = rms_norm_torch(x, None, weight, epsilon) + expected = reference_rms_norm(x, weight, epsilon) + + 
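+        # Loose tolerances allow for minor floating-point differences between the backend and the baseline.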
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) + + +# ============================================================================= +# Edge Case Tests +# ============================================================================= + +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_silu_single_element(self, device): + """Test SiLU with single element batch.""" + from vllm_fl.dispatch.backends.reference.impl.activation import silu_and_mul_torch + + x = torch.randn(1, 4, device=device) + result = silu_and_mul_torch(x) + + assert result.shape == (1, 2) + assert not torch.isnan(result).any() + + def test_rms_norm_single_element(self, device): + """Test RMS norm with single element batch.""" + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + x = torch.randn(1, 64, device=device) + weight = torch.ones(64, device=device) + + result = rms_norm_torch(x, None, weight, 1e-5) + + assert result.shape == (1, 64) + assert not torch.isnan(result).any() + + def test_very_small_values(self, device): + """Test with very small input values.""" + from vllm_fl.dispatch.backends.reference.impl.activation import silu_and_mul_torch + from vllm_fl.dispatch.backends.reference.impl.normalization import rms_norm_torch + + x_silu = torch.randn(2, 8, device=device) * 1e-7 + x_norm = torch.randn(2, 32, device=device) * 1e-7 + weight = torch.ones(32, device=device) + + result_silu = silu_and_mul_torch(x_silu) + result_norm = rms_norm_torch(x_norm, None, weight, 1e-5) + + assert not torch.isnan(result_silu).any() + assert not torch.isnan(result_norm).any() + assert not torch.isinf(result_silu).any() + assert not torch.isinf(result_norm).any() diff --git a/tests/unit_tests/ops/test_rotary_embedding.py b/tests/unit_tests/ops/test_rotary_embedding.py new file mode 100644 index 0000000..fddcb41 --- /dev/null +++ b/tests/unit_tests/ops/test_rotary_embedding.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for rotary embedding ops. 
+""" + +import pytest +import torch +from unittest.mock import patch + + +class TestRotaryEmbeddingFL: + """Test RotaryEmbeddingFL class behavior.""" + + @pytest.fixture + def mock_call_op(self): + with patch("vllm_fl.ops.rotary_embedding.call_op") as mock: + yield mock + + @pytest.fixture + def mock_parent_init(self): + with patch("vllm_fl.ops.rotary_embedding.RotaryEmbedding.__init__", return_value=None): + yield + + def test_forward_oot_dispatches_correctly(self, mock_parent_init, mock_call_op): + """Test forward_oot calls dispatch system with correct arguments.""" + from vllm_fl.ops.rotary_embedding import RotaryEmbeddingFL + + layer = RotaryEmbeddingFL( + head_size=64, + rotary_dim=32, + max_position_embeddings=2048, + base=10000.0, + is_neox_style=True, + dtype=torch.float32, + ) + + # Manually set attributes that parent __init__ would set + layer.head_size = 64 + layer.rotary_dim = 32 + layer.is_neox_style = True + layer.cos_sin_cache = torch.randn(2048, 64) + + mock_call_op.return_value = (torch.randn(4, 8, 32), torch.randn(4, 8, 32)) + + positions = torch.tensor([0, 1, 2, 3]) + query = torch.randn(4, 8, 64) + key = torch.randn(4, 8, 64) + + result = layer.forward_oot(positions, query, key) + + mock_call_op.assert_called_once() + call_args = mock_call_op.call_args + assert call_args[0][0] == "rotary_embedding" diff --git a/tests/unit_tests/worker/__init__.py b/tests/unit_tests/worker/__init__.py new file mode 100644 index 0000000..8c90136 --- /dev/null +++ b/tests/unit_tests/worker/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2025 BAAI. All rights reserved. diff --git a/tests/unit_tests/worker/test_model_runner.py b/tests/unit_tests/worker/test_model_runner.py new file mode 100644 index 0000000..f7ee06d --- /dev/null +++ b/tests/unit_tests/worker/test_model_runner.py @@ -0,0 +1,299 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for model runner module. + +This module follows a layered testing strategy: +- Layer 1: Pure functions and data classes (no external dependencies) +- Layer 2: Methods with mocked dependencies +- Layer 3: Integration tests (in functional_tests/, requires GPU) + +Note: These tests require vllm >= 0.13.0 with full installation. 
+""" + +import pytest +import numpy as np +import torch +from unittest.mock import MagicMock + + +# ============================================================================= +# Test Utilities - Check availability before importing +# ============================================================================= + +def has_vllm_model_runner(): + """Check if vllm model runner dependencies are available.""" + try: + from vllm_fl.worker.model_runner import ModelRunnerFL + return True + except ImportError: + return False + + +# Skip all tests if vllm model runner is not available +pytestmark = pytest.mark.skipif( + not has_vllm_model_runner(), + reason="vllm_fl.worker.model_runner not available" +) + + +# ============================================================================= +# Layer 1: ExecuteModelState Data Structure Tests +# ============================================================================= + +class TestExecuteModelState: + """Test ExecuteModelState NamedTuple behavior and contract.""" + + def test_fields_match_expected_contract(self): + """Verify ExecuteModelState has exact fields required by execute_model pipeline.""" + from vllm_fl.worker.model_runner import ExecuteModelState + + expected_fields = ( + 'scheduler_output', 'logits', 'spec_decode_metadata', + 'spec_decode_common_attn_metadata', 'hidden_states', + 'sample_hidden_states', 'aux_hidden_states', + 'ec_connector_output', 'cudagraph_stats' + ) + assert ExecuteModelState._fields == expected_fields, ( + "ExecuteModelState fields changed - this may break execute_model consumers" + ) + + def test_immutability_prevents_accidental_mutation(self): + """Ensure state cannot be mutated after creation (important for pipeline safety).""" + from vllm_fl.worker.model_runner import ExecuteModelState + + state = ExecuteModelState( + scheduler_output=MagicMock(), + logits=torch.randn(4, 1000), + spec_decode_metadata=None, + spec_decode_common_attn_metadata=None, + hidden_states=torch.randn(4, 512), + sample_hidden_states=torch.randn(4, 512), + aux_hidden_states=None, + ec_connector_output=None, + cudagraph_stats=None, + ) + + with pytest.raises(AttributeError): + state.logits = torch.randn(4, 1000) + + def test_unpacking_for_downstream_processing(self): + """Test that state can be unpacked correctly for downstream use.""" + from vllm_fl.worker.model_runner import ExecuteModelState + + mock_scheduler = MagicMock() + mock_logits = torch.randn(4, 1000) + + state = ExecuteModelState( + scheduler_output=mock_scheduler, + logits=mock_logits, + spec_decode_metadata=None, + spec_decode_common_attn_metadata=None, + hidden_states=None, + sample_hidden_states=None, + aux_hidden_states=None, + ec_connector_output=None, + cudagraph_stats=None, + ) + + # Simulate downstream unpacking + scheduler, logits, *rest = state + assert scheduler is mock_scheduler + assert torch.equal(logits, mock_logits) + + +# ============================================================================= +# Layer 2: _get_cumsum_and_arange Algorithm Tests +# ============================================================================= + +class TestGetCumsumAndArange: + """Test _get_cumsum_and_arange method - critical for batch processing.""" + + @pytest.fixture + def mock_model_runner(self): + """Create a minimal mock of ModelRunnerFL for testing.""" + from vllm_fl.worker.model_runner import ModelRunnerFL + + mock_runner = MagicMock(spec=ModelRunnerFL) + mock_runner.arange_np = np.arange(10000) + mock_runner._get_cumsum_and_arange = 
ModelRunnerFL._get_cumsum_and_arange.__get__( + mock_runner, ModelRunnerFL + ) + return mock_runner + + def test_multi_sequence_batch(self, mock_model_runner): + """Test cumsum and per-sequence arange for typical multi-sequence batch.""" + num_tokens = np.array([2, 5, 3]) + + cu_num_tokens, arange = mock_model_runner._get_cumsum_and_arange(num_tokens) + + # Cumsum: [2, 7, 10] - used for indexing into flattened batch + np.testing.assert_array_equal(cu_num_tokens, np.array([2, 7, 10])) + + # Arange: per-sequence position indices [0,1 | 0,1,2,3,4 | 0,1,2] + expected_arange = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + np.testing.assert_array_equal(arange, expected_arange) + + def test_single_sequence(self, mock_model_runner): + """Test with single sequence (common in generation phase).""" + num_tokens = np.array([5]) + + cu_num_tokens, arange = mock_model_runner._get_cumsum_and_arange(num_tokens) + + np.testing.assert_array_equal(cu_num_tokens, np.array([5])) + np.testing.assert_array_equal(arange, np.array([0, 1, 2, 3, 4])) + + def test_all_single_token_sequences(self, mock_model_runner): + """Test batch where each sequence has 1 token (decode phase).""" + num_tokens = np.array([1, 1, 1, 1]) + + cu_num_tokens, arange = mock_model_runner._get_cumsum_and_arange(num_tokens) + + np.testing.assert_array_equal(cu_num_tokens, np.array([1, 2, 3, 4])) + np.testing.assert_array_equal(arange, np.array([0, 0, 0, 0])) + + def test_large_sequences(self, mock_model_runner): + """Test with larger sequences to verify correct boundary handling.""" + num_tokens = np.array([10, 20, 30]) + + cu_num_tokens, arange = mock_model_runner._get_cumsum_and_arange(num_tokens) + + assert cu_num_tokens[-1] == 60 + assert len(arange) == 60 + # Verify boundaries: first seq 0-9, second seq 0-19, third seq 0-29 + np.testing.assert_array_equal(arange[:10], np.arange(10)) + np.testing.assert_array_equal(arange[10:30], np.arange(20)) + np.testing.assert_array_equal(arange[30:60], np.arange(30)) + + def test_dtype_preservation(self, mock_model_runner): + """Test that dtype is correctly applied to cumsum output.""" + num_tokens = np.array([2, 3]) + + cu_num_tokens, _ = mock_model_runner._get_cumsum_and_arange( + num_tokens, cumsum_dtype=np.int64 + ) + + assert cu_num_tokens.dtype == np.int64 + + +# ============================================================================= +# Layer 2: _pad_for_sequence_parallelism Logic Tests +# ============================================================================= + +class TestPadForSequenceParallelism: + """Test sequence parallelism padding logic.""" + + @pytest.fixture + def mock_model_runner(self): + """Create mock model runner for padding tests.""" + from vllm_fl.worker.model_runner import ModelRunnerFL + + mock_runner = MagicMock(spec=ModelRunnerFL) + mock_runner.vllm_config = MagicMock() + mock_runner.vllm_config.parallel_config = MagicMock() + mock_runner.compilation_config = MagicMock() + mock_runner.compilation_config.pass_config = MagicMock() + mock_runner._pad_for_sequence_parallelism = ( + ModelRunnerFL._pad_for_sequence_parallelism.__get__( + mock_runner, ModelRunnerFL + ) + ) + return mock_runner + + def test_no_padding_when_sp_disabled(self, mock_model_runner): + """SP disabled should return original token count.""" + mock_model_runner.vllm_config.parallel_config.tensor_parallel_size = 4 + mock_model_runner.compilation_config.pass_config.enable_sp = False + + assert mock_model_runner._pad_for_sequence_parallelism(10) == 10 + + def test_no_padding_when_tp_size_1(self, 
mock_model_runner): + """TP size 1 means no parallelism, no padding needed.""" + mock_model_runner.vllm_config.parallel_config.tensor_parallel_size = 1 + mock_model_runner.compilation_config.pass_config.enable_sp = True + + assert mock_model_runner._pad_for_sequence_parallelism(10) == 10 + + @pytest.mark.parametrize("num_tokens,tp_size,expected", [ + (10, 4, 12), # 10 -> ceil to multiple of 4 + (8, 4, 8), # 8 already multiple of 4 + (10, 8, 16), # 10 -> ceil to multiple of 8 + (1, 4, 4), # 1 -> ceil to multiple of 4 + (15, 8, 16), # 15 -> ceil to multiple of 8 + ]) + def test_padding_calculation(self, mock_model_runner, num_tokens, tp_size, expected): + """Verify padding rounds up to next multiple of tp_size.""" + mock_model_runner.vllm_config.parallel_config.tensor_parallel_size = tp_size + mock_model_runner.compilation_config.pass_config.enable_sp = True + + result = mock_model_runner._pad_for_sequence_parallelism(num_tokens) + + assert result == expected + assert result % tp_size == 0 # Must be divisible + + +# ============================================================================= +# Layer 2: _get_positions Routing Tests +# ============================================================================= + +class TestGetPositions: + """Test position retrieval for different position encoding schemes.""" + + @pytest.fixture + def mock_model_runner(self): + """Create mock model runner for position tests.""" + from vllm_fl.worker.model_runner import ModelRunnerFL + + mock_runner = MagicMock(spec=ModelRunnerFL) + + # Standard positions buffer + mock_runner.positions = MagicMock() + mock_runner.positions.gpu = torch.arange(100) + + # MRoPE positions (3D for temporal, height, width) + mock_runner.mrope_positions = MagicMock() + mock_runner.mrope_positions.gpu = torch.arange(300).reshape(3, 100) + + # XDRoPE positions (2D) + mock_runner.xdrope_positions = MagicMock() + mock_runner.xdrope_positions.gpu = torch.arange(200).reshape(2, 100) + + mock_runner.uses_mrope = False + mock_runner.uses_xdrope_dim = 0 + + mock_runner._get_positions = ModelRunnerFL._get_positions.__get__( + mock_runner, ModelRunnerFL + ) + return mock_runner + + def test_standard_positions_with_int(self, mock_model_runner): + """Standard RoPE: integer returns first N positions.""" + result = mock_model_runner._get_positions(10) + torch.testing.assert_close(result, torch.arange(10)) + + def test_standard_positions_with_indices(self, mock_model_runner): + """Standard RoPE: tensor indices for selective position lookup.""" + indices = torch.tensor([0, 5, 10, 15]) + result = mock_model_runner._get_positions(indices) + expected = mock_model_runner.positions.gpu[indices] + torch.testing.assert_close(result, expected) + + def test_mrope_returns_3d_positions(self, mock_model_runner): + """MRoPE (Qwen2-VL): returns [3, num_tokens] positions.""" + mock_model_runner.uses_mrope = True + + result = mock_model_runner._get_positions(10) + + expected = mock_model_runner.mrope_positions.gpu[:, :10] + assert result.shape == (3, 10) + torch.testing.assert_close(result, expected) + + def test_xdrope_returns_2d_positions(self, mock_model_runner): + """XDRoPE: returns [2, num_tokens] positions.""" + mock_model_runner.uses_xdrope_dim = 64 + + result = mock_model_runner._get_positions(10) + + expected = mock_model_runner.xdrope_positions.gpu[:, :10] + assert result.shape == (2, 10) + torch.testing.assert_close(result, expected) diff --git a/tests/unit_tests/worker/test_worker.py b/tests/unit_tests/worker/test_worker.py new file mode 100644 index 
0000000..be12b06 --- /dev/null +++ b/tests/unit_tests/worker/test_worker.py @@ -0,0 +1,97 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +""" +Tests for worker module. + +Note: These tests require vllm >= 0.13.0 with profiler support. +""" + +import pytest + + +def has_vllm_profiler(): + """Check if vllm profiler is available.""" + try: + from vllm.profiler.wrapper import TorchProfilerWrapper + return True + except ImportError: + return False + + +# Skip all tests if vllm profiler is not available +pytestmark = pytest.mark.skipif( + not has_vllm_profiler(), + reason="vllm.profiler.wrapper not available (requires vllm >= 0.13.0)" +) + + +class TestMemorySnapshot: + """Test MemorySnapshot dataclass behavior.""" + + def test_default_values_without_auto_measure(self): + """Test MemorySnapshot initializes with correct default values.""" + from vllm_fl.worker.worker import MemorySnapshot + + snapshot = MemorySnapshot(auto_measure=False) + + assert snapshot.torch_peak == 0 + assert snapshot.free_memory == 0 + assert snapshot.total_memory == 0 + assert snapshot.cuda_memory == 0 + assert snapshot.torch_memory == 0 + assert snapshot.non_torch_memory == 0 + + def test_subtraction_computes_difference(self): + """Test MemorySnapshot subtraction operator computes correct differences.""" + from vllm_fl.worker.worker import MemorySnapshot + + snapshot1 = MemorySnapshot(auto_measure=False) + snapshot1.torch_peak = 1000 + snapshot1.free_memory = 5000 + snapshot1.total_memory = 10000 + snapshot1.cuda_memory = 5000 + snapshot1.torch_memory = 3000 + snapshot1.non_torch_memory = 2000 + snapshot1.timestamp = 10.0 + + snapshot2 = MemorySnapshot(auto_measure=False) + snapshot2.torch_peak = 500 + snapshot2.free_memory = 6000 + snapshot2.total_memory = 10000 + snapshot2.cuda_memory = 4000 + snapshot2.torch_memory = 2000 + snapshot2.non_torch_memory = 2000 + snapshot2.timestamp = 5.0 + + diff = snapshot1 - snapshot2 + + assert diff.torch_peak == 500 + assert diff.free_memory == -1000 + assert diff.cuda_memory == 1000 + assert diff.torch_memory == 1000 + assert diff.timestamp == 5.0 + + +class TestMemoryProfilingResult: + """Test MemoryProfilingResult dataclass behavior.""" + + def test_default_values(self): + """Test MemoryProfilingResult initializes with correct default values.""" + from vllm_fl.worker.worker import MemoryProfilingResult + + result = MemoryProfilingResult() + + assert result.weights_memory == 0 + assert result.torch_peak_increase == 0 + assert result.non_torch_increase == 0 + assert result.non_kv_cache_memory == 0 + assert result.profile_time == 0.0 + + def test_creates_default_snapshots(self): + """Test MemoryProfilingResult creates default snapshot objects.""" + from vllm_fl.worker.worker import MemoryProfilingResult + + result = MemoryProfilingResult() + + assert result.before_profile is not None + assert result.after_profile is not None diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py index 4c5d40f..8592294 100644 --- a/vllm_fl/__init__.py +++ b/vllm_fl/__init__.py @@ -9,6 +9,15 @@ logger = logging.getLogger(__name__) +def __getattr__(name): + if name == "distributed": + import importlib + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + return module + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + def register(): """Register the FL platform.""" diff --git a/vllm_fl/distributed/__init__.py b/vllm_fl/distributed/__init__.py index e69de29..e215455 100644 --- a/vllm_fl/distributed/__init__.py +++ 
b/vllm_fl/distributed/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 BAAI. All rights reserved. + +import importlib + +__all__ = ["communicator"] + + +def __getattr__(name): + if name in __all__: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + return module + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index 8a7c269..231ada6 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -10,6 +10,12 @@ import torch +# Import custom ops to trigger op registration (CUDA only) +try: + import vllm._C # noqa +except ImportError: + pass # NPU or other platforms may not have vllm._C + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger diff --git a/vllm_fl/utils.py b/vllm_fl/utils.py index dc87a19..d28c70f 100644 --- a/vllm_fl/utils.py +++ b/vllm_fl/utils.py @@ -180,12 +180,12 @@ def get_flaggems_all_ops() -> list[str]: Get all FlagGems operator names from flag_gems._FULL_CONFIG. """ try: + # _FULL_CONFIG is a tuple of (op_name, function, ...) tuples; the op name is always first, + # but entries may have either 2 or 3 elements + ops = [entry[0] for entry in flag_gems._FULL_CONFIG] + return ops except Exception: return [] - ops = flag_gems.all_registered_ops() - - return ops # OOT operator names as registered in custom_ops.py (op_name lowercase)
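
TestExecuteModelState above pins down a nine-field NamedTuple contract. For reference, a minimal sketch of a NamedTuple matching that contract; the field names come from the test's expected_fields, while the Optional/Any annotations are assumptions (the concrete types in vllm_fl.worker.model_runner may be narrower).

from typing import Any, NamedTuple, Optional
import torch

class ExecuteModelStateSketch(NamedTuple):
    # Illustrative sketch only; field order mirrors the contract asserted in the test.
    scheduler_output: Any
    logits: Optional[torch.Tensor]
    spec_decode_metadata: Any
    spec_decode_common_attn_metadata: Any
    hidden_states: Optional[torch.Tensor]
    sample_hidden_states: Optional[torch.Tensor]
    aux_hidden_states: Any
    ec_connector_output: Any
    cudagraph_stats: Any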
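
For readers of TestGetCumsumAndArange, here is a minimal NumPy sketch of the cumsum-plus-per-sequence-arange computation those tests check. This is an illustrative reimplementation, not the actual ModelRunnerFL._get_cumsum_and_arange; only the cumsum_dtype parameter and the expected outputs are taken from the tests.

import numpy as np

def cumsum_and_arange_sketch(num_tokens: np.ndarray, cumsum_dtype=None):
    # Cumulative token counts, e.g. [2, 5, 3] -> [2, 7, 10]; used to index
    # into the flattened batch.
    cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
    # Start offset of each sequence inside the flattened batch: [0, 2, 7].
    starts = cu_num_tokens - num_tokens
    # Global positions minus the repeated per-sequence start give the local
    # 0..n-1 indices: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2].
    arange = np.arange(cu_num_tokens[-1]) - np.repeat(starts, num_tokens)
    return cu_num_tokens, arange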
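
Similarly, the TestPadForSequenceParallelism cases reduce to rounding the token count up to the next multiple of tensor_parallel_size when sequence parallelism is enabled. A self-contained sketch of that rule (not the vllm_fl implementation itself; the real method reads tp_size and enable_sp from the configs mocked in the fixture):

def pad_for_sequence_parallelism_sketch(num_tokens: int, tp_size: int, enable_sp: bool) -> int:
    # No padding when SP is disabled or there is only a single TP rank.
    if not enable_sp or tp_size <= 1:
        return num_tokens
    # Round up to the next multiple of tp_size, e.g. 10 -> 12 for tp_size=4.
    return ((num_tokens + tp_size - 1) // tp_size) * tp_size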
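
TestGetPositions exercises a three-way routing over the position buffers. Expressed as a sketch for orientation only: the attribute names (positions, mrope_positions, xdrope_positions, uses_mrope, uses_xdrope_dim) mirror the mocks in the fixture, and the real _get_positions may differ in details.

import torch
from typing import Union

def get_positions_sketch(runner, n: Union[int, torch.Tensor]) -> torch.Tensor:
    # Illustrative sketch only: an int selects the first n positions,
    # a tensor selects specific indices.
    sl = slice(None, n) if isinstance(n, int) else n
    if runner.uses_mrope:
        # MRoPE (e.g. Qwen2-VL): [3, num_tokens] for temporal/height/width axes.
        return runner.mrope_positions.gpu[:, sl]
    if runner.uses_xdrope_dim > 0:
        # XDRoPE: [2, num_tokens].
        return runner.xdrope_positions.gpu[:, sl]
    # Standard RoPE: flat [num_tokens] position ids.
    return runner.positions.gpu[sl]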
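
The module-level __getattr__ hooks added to vllm_fl/__init__.py and vllm_fl/distributed/__init__.py follow the PEP 562 lazy-import pattern, which keeps "import vllm_fl" cheap until distributed functionality is actually touched. A usage sketch, assuming vllm_fl is installed (the communicator module's contents are not part of this diff):

import vllm_fl

# Nothing under vllm_fl.distributed has been imported at this point; the
# module-level __getattr__ defers the import until first attribute access.
dist = vllm_fl.distributed    # triggers __getattr__("distributed") in vllm_fl/__init__.py
comm = dist.communicator      # triggers __getattr__("communicator") in distributed/__init__.py
# Both modules are now cached in their parent's globals(), so later accesses
# are plain attribute lookups with no importlib call.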