|
4 | 4 | # pylint: disable=unused-argument |
5 | 5 |
|
6 | 6 | import re |
7 | | -from typing import Any, Optional |
| 7 | +from collections.abc import Callable |
| 8 | +from typing import Any |
8 | 9 |
|
9 | 10 | import pytest |
10 | 11 | from fastapi import HTTPException, status |
|
17 | 18 | AUTH_DISABLED, |
18 | 19 | _build_instructions, |
19 | 20 | _get_default_model_id, |
| 21 | + _get_prompt_template, |
20 | 22 | _get_rh_identity_context, |
21 | 23 | infer_endpoint, |
22 | 24 | retrieve_simple_response, |
|
39 | 41 | MOCK_AUTH: AuthTuple = ("mock_user_id", "mock_username", False, "mock_token") |
40 | 42 |
|
41 | 43 |
|
| 44 | +@pytest.fixture(autouse=True) |
| 45 | +def _clear_prompt_template_cache() -> None: |
| 46 | + """Clear the lru_cache on _get_prompt_template between tests.""" |
| 47 | + _get_prompt_template.cache_clear() |
| 48 | + |
| 49 | + |
| 50 | +@pytest.fixture(name="mock_custom_prompt") |
| 51 | +def mock_custom_prompt_fixture(mocker: MockerFixture) -> Callable[[str], None]: |
| 52 | + """Factory fixture that patches configuration with a custom system prompt.""" |
| 53 | + |
| 54 | + def _set(prompt: str) -> None: |
| 55 | + mock_customization = mocker.Mock() |
| 56 | + mock_customization.system_prompt = prompt |
| 57 | + mock_config = mocker.Mock() |
| 58 | + mock_config.customization = mock_customization |
| 59 | + mocker.patch("app.endpoints.rlsapi_v1.configuration", mock_config) |
| 60 | + |
| 61 | + return _set |
| 62 | + |
| 63 | + |
42 | 64 | def _create_mock_request(mocker: MockerFixture, rh_identity: Any = None) -> Any: |
43 | 65 | """Create a mock FastAPI Request with optional RH Identity data.""" |
44 | 66 | mock_request = mocker.Mock() |
@@ -140,93 +162,124 @@ def mock_generic_runtime_error_fixture(mocker: MockerFixture) -> None: |
140 | 162 | # --- Test _build_instructions --- |
141 | 163 |
|
142 | 164 |
|
143 | | -@pytest.mark.parametrize( |
144 | | - ("systeminfo_kwargs", "expected_contains", "expected_not_contains"), |
145 | | - [ |
146 | | - pytest.param( |
147 | | - {"os": "RHEL", "version": "9.3", "arch": "x86_64"}, |
148 | | - ["OS: RHEL", "Version: 9.3", "Architecture: x86_64"], |
149 | | - [], |
150 | | - id="full_systeminfo", |
151 | | - ), |
152 | | - pytest.param( |
153 | | - {"os": "RHEL", "version": "", "arch": ""}, |
154 | | - ["OS: RHEL"], |
155 | | - ["Version:", "Architecture:"], |
156 | | - id="partial_systeminfo", |
157 | | - ), |
158 | | - pytest.param( |
159 | | - {}, |
160 | | - [constants.DEFAULT_SYSTEM_PROMPT], |
161 | | - ["OS:", "Version:", "Architecture:"], |
162 | | - id="empty_systeminfo", |
163 | | - ), |
164 | | - ], |
165 | | -) |
166 | | -def test_build_instructions( |
167 | | - systeminfo_kwargs: dict[str, str], |
168 | | - expected_contains: list[str], |
169 | | - expected_not_contains: list[str], |
170 | | -) -> None: |
171 | | - """Test _build_instructions includes date and system info.""" |
172 | | - systeminfo = RlsapiV1SystemInfo(**systeminfo_kwargs) |
| 165 | +def test_build_instructions_default_prompt_passes_through() -> None: |
| 166 | + """Test _build_instructions returns default prompt unchanged when no template vars.""" |
| 167 | + systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64") |
173 | 168 | result = _build_instructions(systeminfo) |
174 | 169 |
|
175 | | - assert re.search(r"Today's date: \w+ \d{2}, \d{4}", result) |
176 | | - for expected in expected_contains: |
177 | | - assert expected in result |
178 | | - for not_expected in expected_not_contains: |
179 | | - assert not_expected not in result |
180 | | - |
181 | | - |
182 | | -# --- Test _build_instructions with customization.system_prompt --- |
| 170 | + assert result == constants.DEFAULT_SYSTEM_PROMPT |
183 | 171 |
|
184 | 172 |
|
185 | | -@pytest.mark.parametrize( |
186 | | - ("custom_prompt", "expected_prompt"), |
187 | | - [ |
188 | | - pytest.param( |
189 | | - "You are a RHEL expert.", |
190 | | - "You are a RHEL expert.", |
191 | | - id="customization_system_prompt_set", |
192 | | - ), |
193 | | - pytest.param( |
194 | | - None, |
195 | | - constants.DEFAULT_SYSTEM_PROMPT, |
196 | | - id="customization_system_prompt_none", |
197 | | - ), |
198 | | - ], |
199 | | -) |
200 | | -def test_build_instructions_with_customization( |
201 | | - mocker: MockerFixture, |
202 | | - custom_prompt: Optional[str], |
203 | | - expected_prompt: str, |
204 | | -) -> None: |
205 | | - """Test _build_instructions uses customization.system_prompt when set.""" |
| 173 | +def test_build_instructions_with_customization(mocker: MockerFixture) -> None: |
| 174 | + """Test _build_instructions uses customization.system_prompt with template vars.""" |
| 175 | + template = "Expert assistant.\n\nDate: {{ date }}\nOS: {{ os }}" |
206 | 176 | mock_customization = mocker.Mock() |
207 | | - mock_customization.system_prompt = custom_prompt |
| 177 | + mock_customization.system_prompt = template |
208 | 178 | mock_config = mocker.Mock() |
209 | 179 | mock_config.customization = mock_customization |
210 | 180 | mocker.patch("app.endpoints.rlsapi_v1.configuration", mock_config) |
211 | 181 |
|
212 | 182 | systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64") |
213 | 183 | result = _build_instructions(systeminfo) |
214 | 184 |
|
215 | | - assert expected_prompt in result |
| 185 | + assert "Expert assistant." in result |
216 | 186 | assert "OS: RHEL" in result |
| 187 | + assert re.search(r"Date: \w+ \d{2}, \d{4}", result) |
217 | 188 |
|
218 | 189 |
|
219 | 190 | def test_build_instructions_no_customization(mocker: MockerFixture) -> None: |
220 | | - """Test _build_instructions falls back when customization is None.""" |
| 191 | + """Test _build_instructions falls back to DEFAULT_SYSTEM_PROMPT.""" |
221 | 192 | mock_config = mocker.Mock() |
222 | 193 | mock_config.customization = None |
223 | 194 | mocker.patch("app.endpoints.rlsapi_v1.configuration", mock_config) |
224 | 195 |
|
225 | 196 | systeminfo = RlsapiV1SystemInfo() |
226 | 197 | result = _build_instructions(systeminfo) |
227 | 198 |
|
228 | | - assert result.startswith(constants.DEFAULT_SYSTEM_PROMPT) |
229 | | - assert re.search(r"Today's date: \w+ \d{2}, \d{4}", result) |
| 199 | + assert result == constants.DEFAULT_SYSTEM_PROMPT |
| 200 | + |
| 201 | + |
| 202 | +# --- Test Jinja2 template rendering --- |
| 203 | + |
| 204 | + |
| 205 | +def test_build_instructions_renders_jinja2_template( |
| 206 | + mock_custom_prompt: Callable[[str], None], |
| 207 | +) -> None: |
| 208 | + """Test _build_instructions renders Jinja2 template variables instead of appending.""" |
| 209 | + mock_custom_prompt( |
| 210 | + "You are an assistant.\n\nDate: {{ date }}\nOS: {{ os }} {{ version }} ({{ arch }})" |
| 211 | + ) |
| 212 | + |
| 213 | + systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64") |
| 214 | + result = _build_instructions(systeminfo) |
| 215 | + |
| 216 | + assert "OS: RHEL 9.3 (x86_64)" in result |
| 217 | + assert re.search(r"Date: \w+ \d{2}, \d{4}", result) |
| 218 | + assert "Today's date:" not in result |
| 219 | + assert "User's system:" not in result |
| 220 | + |
| 221 | + |
| 222 | +def test_build_instructions_jinja2_none_values_render_empty( |
| 223 | + mock_custom_prompt: Callable[[str], None], |
| 224 | +) -> None: |
| 225 | + """Test that None system info values render as empty strings, not 'None'.""" |
| 226 | + mock_custom_prompt("Assistant.\nOS={{ os }} VER={{ version }} ARCH={{ arch }}") |
| 227 | + |
| 228 | + systeminfo = RlsapiV1SystemInfo() |
| 229 | + result = _build_instructions(systeminfo) |
| 230 | + |
| 231 | + assert "None" not in result |
| 232 | + assert "OS= VER= ARCH=" in result |
| 233 | + |
| 234 | + |
| 235 | +def test_build_instructions_jinja2_conditionals( |
| 236 | + mock_custom_prompt: Callable[[str], None], |
| 237 | +) -> None: |
| 238 | + """Test that Jinja2 conditionals work in system prompt templates.""" |
| 239 | + mock_custom_prompt( |
| 240 | + "Assistant.{% if os %} OS: {{ os }}{% endif %}" |
| 241 | + "{% if version %} VER: {{ version }}{% endif %}" |
| 242 | + ) |
| 243 | + |
| 244 | + systeminfo = RlsapiV1SystemInfo(os="RHEL") |
| 245 | + result = _build_instructions(systeminfo) |
| 246 | + |
| 247 | + assert "OS: RHEL" in result |
| 248 | + assert "VER:" not in result |
| 249 | + |
| 250 | + |
| 251 | +def test_build_instructions_plain_prompt_passes_through( |
| 252 | + mock_custom_prompt: Callable[[str], None], |
| 253 | +) -> None: |
| 254 | + """Test that prompts without Jinja2 syntax pass through unchanged.""" |
| 255 | + plain_prompt = "You are an expert RHEL assistant." |
| 256 | + mock_custom_prompt(plain_prompt) |
| 257 | + |
| 258 | + systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64") |
| 259 | + result = _build_instructions(systeminfo) |
| 260 | + |
| 261 | + assert result == plain_prompt |
| 262 | + |
| 263 | + |
| 264 | +@pytest.mark.parametrize( |
| 265 | + "bad_template", |
| 266 | + [ |
| 267 | + pytest.param("Hello {{ unclosed", id="unclosed_variable"), |
| 268 | + pytest.param("{% if %}", id="if_without_condition"), |
| 269 | + pytest.param("{% endfor %}", id="endfor_without_for"), |
| 270 | + ], |
| 271 | +) |
| 272 | +def test_build_instructions_malformed_template_raises_value_error( |
| 273 | + mock_custom_prompt: Callable[[str], None], |
| 274 | + bad_template: str, |
| 275 | +) -> None: |
| 276 | + """Test that invalid Jinja2 syntax in system prompt raises ValueError.""" |
| 277 | + mock_custom_prompt(bad_template) |
| 278 | + |
| 279 | + systeminfo = RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64") |
| 280 | + |
| 281 | + with pytest.raises(ValueError, match="invalid Jinja2 syntax"): |
| 282 | + _build_instructions(systeminfo) |
230 | 283 |
|
231 | 284 |
|
232 | 285 | # --- Test _get_default_model_id --- |
|
0 commit comments