Skip to content

Commit 28ba6d4

Browse files
committed
fix: add more test coverage and remove redundant command acknowledge messages in telemetry manager
1 parent a4b0c0e commit 28ba6d4

File tree

3 files changed

+689
-1
lines changed

3 files changed

+689
-1
lines changed

tests/config/test_user_config.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
TokenizerConfig,
1515
UserConfig,
1616
)
17-
from aiperf.common.enums import EndpointType
17+
from aiperf.common.enums import EndpointType, GPUTelemetryMode
1818
from aiperf.common.enums.timing_enums import TimingMode
1919

2020

@@ -201,3 +201,151 @@ def test_compute_artifact_directory(
201201

202202
artifact_dir = config._compute_artifact_directory()
203203
assert artifact_dir == Path(expected_dir)
204+
205+
206+
@pytest.mark.parametrize(
    "gpu_telemetry_input,expected_mode,expected_urls",
    [
        # No telemetry configured -> summary mode, no collector URLs.
        ([], GPUTelemetryMode.SUMMARY, []),
        # Dashboard keyword only -> realtime dashboard, no collector URLs.
        (["dashboard"], GPUTelemetryMode.REALTIME_DASHBOARD, []),
        # URLs only (no dashboard keyword) -> summary mode with URLs.
        (
            ["http://node1:9401/metrics"],
            GPUTelemetryMode.SUMMARY,
            ["http://node1:9401/metrics"],
        ),
        # Dashboard keyword mixed with a URL.
        (
            ["dashboard", "http://node1:9401/metrics"],
            GPUTelemetryMode.REALTIME_DASHBOARD,
            ["http://node1:9401/metrics"],
        ),
        # Multiple URLs, no dashboard keyword.
        (
            ["http://node1:9401/metrics", "http://node2:9401/metrics"],
            GPUTelemetryMode.SUMMARY,
            ["http://node1:9401/metrics", "http://node2:9401/metrics"],
        ),
        # Dashboard keyword plus multiple URLs.
        (
            [
                "dashboard",
                "http://node1:9401/metrics",
                "http://node2:9401/metrics",
            ],
            GPUTelemetryMode.REALTIME_DASHBOARD,
            ["http://node1:9401/metrics", "http://node2:9401/metrics"],
        ),
    ],
)
def test_parse_gpu_telemetry_config(gpu_telemetry_input, expected_mode, expected_urls):
    """Test parsing of the raw ``gpu_telemetry`` list into mode and URLs.

    ``UserConfig`` is expected to split the user-supplied list into a
    ``gpu_telemetry_mode`` (dashboard keyword present or not) and a
    ``gpu_telemetry_urls`` list (the http/https entries).
    """
    config = UserConfig(
        endpoint=EndpointConfig(
            model_names=["test-model"],
            type=EndpointType.CHAT,
            custom_endpoint="test",
        ),
        gpu_telemetry=gpu_telemetry_input,
    )

    assert config.gpu_telemetry_mode == expected_mode
    assert config.gpu_telemetry_urls == expected_urls
256+
257+
258+
def test_parse_gpu_telemetry_config_with_defaults():
    """Test that ``gpu_telemetry_mode`` and ``gpu_telemetry_urls`` default correctly.

    When ``gpu_telemetry`` is not supplied at all, the parsed fields should
    fall back to summary mode with an empty URL list.
    """
    config = UserConfig(
        endpoint=EndpointConfig(
            model_names=["test-model"],
            type=EndpointType.CHAT,
            custom_endpoint="test",
        )
    )

    # Should have default values
    assert config.gpu_telemetry_mode == GPUTelemetryMode.SUMMARY
    assert config.gpu_telemetry_urls == []
271+
272+
273+
def test_parse_gpu_telemetry_config_preserves_existing_fields():
    """Test that parsing GPU telemetry config doesn't affect other fields.

    The telemetry parser must only populate the telemetry-related fields;
    the endpoint configuration passed in must come through untouched.
    """
    config = UserConfig(
        endpoint=EndpointConfig(
            model_names=["test-model"],
            type=EndpointType.CHAT,
            custom_endpoint="test",
            streaming=True,
        ),
        gpu_telemetry=["dashboard", "http://custom:9401/metrics"],
    )

    # Telemetry fields should be set
    assert config.gpu_telemetry_mode == GPUTelemetryMode.REALTIME_DASHBOARD
    assert config.gpu_telemetry_urls == ["http://custom:9401/metrics"]

    # Other fields should be unchanged
    assert config.endpoint.streaming is True
    assert config.endpoint.model_names == ["test-model"]
292+
293+
294+
def test_gpu_telemetry_urls_extraction():
    """Test that only http/https URLs are extracted from the ``gpu_telemetry`` list.

    Keyword entries such as ``dashboard`` and ``summary`` must not leak into
    ``gpu_telemetry_urls``.
    """
    config = UserConfig(
        endpoint=EndpointConfig(
            model_names=["test-model"],
            type=EndpointType.CHAT,
            custom_endpoint="test",
        ),
        gpu_telemetry=[
            "dashboard",  # Not a URL
            "http://node1:9401/metrics",  # Valid URL
            "https://node2:9401/metrics",  # Valid URL
            "summary",  # Not a URL
        ],
    )

    # Should extract only http/https URLs
    assert len(config.gpu_telemetry_urls) == 2
    assert "http://node1:9401/metrics" in config.gpu_telemetry_urls
    assert "https://node2:9401/metrics" in config.gpu_telemetry_urls
    assert "dashboard" not in config.gpu_telemetry_urls
    assert "summary" not in config.gpu_telemetry_urls
316+
317+
318+
def test_gpu_telemetry_mode_detection():
    """Test that dashboard mode is detected regardless of its list position."""

    def _make_config(gpu_telemetry):
        # Helper: build a minimal UserConfig with the given telemetry list.
        return UserConfig(
            endpoint=EndpointConfig(
                model_names=["test-model"],
                type=EndpointType.CHAT,
                custom_endpoint="test",
            ),
            gpu_telemetry=gpu_telemetry,
        )

    # Dashboard at beginning
    config1 = _make_config(["dashboard", "http://node1:9401/metrics"])
    assert config1.gpu_telemetry_mode == GPUTelemetryMode.REALTIME_DASHBOARD

    # Dashboard at end
    config2 = _make_config(["http://node1:9401/metrics", "dashboard"])
    assert config2.gpu_telemetry_mode == GPUTelemetryMode.REALTIME_DASHBOARD

    # No dashboard keyword -> summary mode
    config3 = _make_config(["http://node1:9401/metrics"])
    assert config3.gpu_telemetry_mode == GPUTelemetryMode.SUMMARY
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import asyncio
5+
from unittest.mock import AsyncMock, MagicMock, patch
6+
7+
import pytest
8+
9+
from aiperf.common.config import ServiceConfig
10+
from aiperf.common.hooks import AIPerfHook
11+
from aiperf.common.messages import RealtimeTelemetryMetricsMessage
12+
from aiperf.common.mixins.realtime_telemetry_metrics_mixin import (
13+
RealtimeTelemetryMetricsMixin,
14+
)
15+
from aiperf.common.models import MetricResult
16+
17+
18+
class TestRealtimeTelemetryMetricsMixin:
    """Test suite for RealtimeTelemetryMetricsMixin functionality."""

    @pytest.fixture
    def mocked_mixin(self):
        """Create a RealtimeTelemetryMetricsMixin instance with mocked dependencies.

        ``MessageBusClientMixin.__init__`` is patched out so no real message-bus
        setup occurs; the attributes the parent would normally set are assigned
        by hand afterwards.
        """
        service_config = ServiceConfig()
        mock_controller = MagicMock()

        # Mock the MessageBusClientMixin.__init__ to avoid initialization issues
        with patch(
            "aiperf.common.mixins.message_bus_mixin.MessageBusClientMixin.__init__",
            return_value=None,
        ):
            mixin = RealtimeTelemetryMetricsMixin(
                service_config=service_config, controller=mock_controller
            )

        # Manually set attributes that would be set by parent __init__
        mixin._controller = mock_controller
        mixin._telemetry_metrics = []
        mixin.run_hooks = AsyncMock()

        return mixin

    def test_mixin_initialization(self, mocked_mixin):
        """Test that mixin initializes with correct attributes."""
        assert hasattr(mocked_mixin, "_controller")
        assert hasattr(mocked_mixin, "_telemetry_metrics")
        assert hasattr(mocked_mixin, "_telemetry_metrics_lock")
        assert mocked_mixin._telemetry_metrics == []

    @pytest.mark.asyncio
    async def test_on_realtime_telemetry_metrics_stores_metrics(self, mocked_mixin):
        """Test that telemetry metrics are stored when a message is received."""
        metrics = [
            MetricResult(tag="gpu_util", header="GPU Utilization", unit="%", avg=75.0),
            MetricResult(
                tag="gpu_memory", header="GPU Memory Used", unit="GB", avg=8.5
            ),
        ]

        message = RealtimeTelemetryMetricsMessage(
            service_id="records_manager", metrics=metrics
        )

        await mocked_mixin._on_realtime_telemetry_metrics(message)

        # Verify metrics were stored
        assert mocked_mixin._telemetry_metrics == metrics

    @pytest.mark.asyncio
    async def test_on_realtime_telemetry_metrics_triggers_hook(self, mocked_mixin):
        """Test that receiving telemetry metrics triggers the appropriate hook."""
        metrics = [
            MetricResult(tag="gpu_util", header="GPU Utilization", unit="%", avg=75.0)
        ]

        message = RealtimeTelemetryMetricsMessage(
            service_id="records_manager", metrics=metrics
        )

        await mocked_mixin._on_realtime_telemetry_metrics(message)

        # Verify hook was triggered with correct arguments
        mocked_mixin.run_hooks.assert_called_once_with(
            AIPerfHook.ON_REALTIME_TELEMETRY_METRICS, metrics=metrics
        )

    @pytest.mark.asyncio
    async def test_on_realtime_telemetry_metrics_replaces_previous_metrics(
        self, mocked_mixin
    ):
        """Test that new metrics replace previous metrics (not append)."""
        # Set initial metrics
        initial_metrics = [
            MetricResult(tag="old_metric", header="Old Metric", unit="ms", avg=10.0)
        ]
        mocked_mixin._telemetry_metrics = initial_metrics

        # Receive new metrics
        new_metrics = [
            MetricResult(tag="new_metric", header="New Metric", unit="%", avg=50.0)
        ]
        message = RealtimeTelemetryMetricsMessage(
            service_id="records_manager", metrics=new_metrics
        )

        await mocked_mixin._on_realtime_telemetry_metrics(message)

        # Verify old metrics were replaced, not appended
        assert mocked_mixin._telemetry_metrics == new_metrics
        assert len(mocked_mixin._telemetry_metrics) == 1

    @pytest.mark.asyncio
    async def test_on_realtime_telemetry_metrics_with_empty_list(self, mocked_mixin):
        """Test that receiving an empty metrics list is handled correctly."""
        message = RealtimeTelemetryMetricsMessage(
            service_id="records_manager", metrics=[]
        )

        await mocked_mixin._on_realtime_telemetry_metrics(message)

        # Should store empty list and still trigger hook
        assert mocked_mixin._telemetry_metrics == []
        mocked_mixin.run_hooks.assert_called_once()

    @pytest.mark.asyncio
    async def test_concurrent_access_with_lock(self, mocked_mixin):
        """Test that the lock protects concurrent access to telemetry metrics."""
        # Track lock acquisition order
        lock_acquired_order = []

        async def acquire_lock_and_update(metrics_value, delay):
            """Helper to simulate concurrent updates under the lock."""
            async with mocked_mixin._telemetry_metrics_lock:
                lock_acquired_order.append(metrics_value)
                await asyncio.sleep(delay)
                mocked_mixin._telemetry_metrics = [
                    MetricResult(
                        tag=f"metric_{metrics_value}",
                        header=f"Metric {metrics_value}",
                        unit="ms",
                        avg=float(metrics_value),
                    )
                ]

        # Start two concurrent operations
        await asyncio.gather(
            acquire_lock_and_update(1, 0.01), acquire_lock_and_update(2, 0.005)
        )

        # Both should have acquired the lock (order doesn't matter for this test)
        assert len(lock_acquired_order) == 2
        assert set(lock_acquired_order) == {1, 2}

        # Final value should be from the last completed operation
        assert len(mocked_mixin._telemetry_metrics) == 1

    @pytest.mark.asyncio
    async def test_multiple_metrics_handling(self, mocked_mixin):
        """Test handling of a message carrying multiple metrics."""
        metrics = [
            MetricResult(
                tag=f"metric_{i}", header=f"Metric {i}", unit="ms", avg=float(i)
            )
            for i in range(10)
        ]

        message = RealtimeTelemetryMetricsMessage(
            service_id="records_manager", metrics=metrics
        )

        await mocked_mixin._on_realtime_telemetry_metrics(message)

        # All metrics should be stored
        assert len(mocked_mixin._telemetry_metrics) == 10
        assert mocked_mixin._telemetry_metrics == metrics

    @pytest.mark.asyncio
    async def test_integration_with_controller(self):
        """Test that the mixin keeps a working reference to its controller."""
        service_config = ServiceConfig()
        mock_controller = MagicMock()
        mock_controller.some_method = MagicMock(return_value="test_value")

        with patch(
            "aiperf.common.mixins.message_bus_mixin.MessageBusClientMixin.__init__",
            return_value=None,
        ):
            mixin = RealtimeTelemetryMetricsMixin(
                service_config=service_config, controller=mock_controller
            )

        # Verify controller is accessible
        assert mixin._controller == mock_controller
        assert mixin._controller.some_method() == "test_value"

0 commit comments

Comments
 (0)