
Commit e67b4b4

fix issue with passing msg history on reasking (#346)
fix issue when msg_history is passed as a kwarg instead of via prompt params
1 parent ff186ee commit e67b4b4

5 files changed: +68 -8 lines changed


guardrails/guard.py

Lines changed: 6 additions & 6 deletions
@@ -466,9 +466,9 @@ def _sync_parse(
         """
         with start_action(action_type="guard_parse"):
             runner = Runner(
-                instructions=kwargs.get("instructions", None),
-                prompt=kwargs.get("prompt", None),
-                msg_history=kwargs.get("msg_history", None),
+                instructions=kwargs.pop("instructions", None),
+                prompt=kwargs.pop("prompt", None),
+                msg_history=kwargs.pop("msg_history", None),
                 api=get_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
                 input_schema=None,
                 output_schema=self.output_schema,
@@ -507,9 +507,9 @@ async def _async_parse(
         """
         with start_action(action_type="guard_parse"):
             runner = AsyncRunner(
-                instructions=None,
-                prompt=None,
-                msg_history=None,
+                instructions=kwargs.pop("instructions", None),
+                prompt=kwargs.pop("prompt", None),
+                msg_history=kwargs.pop("msg_history", None),
                 api=get_async_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
                 input_schema=None,
                 output_schema=self.output_schema,
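The get → pop change matters because the very next argument forwards the remaining keyword arguments: api=get_llm_ask(llm_api, *args, **kwargs). With .get(), instructions, prompt, and msg_history stay inside kwargs and are passed along to the LLM wrapper as well. A minimal, self-contained sketch of that Python behaviour (build_runner_args is illustrative only, not the actual guardrails internals):

def build_runner_args(**kwargs):
    # Old behaviour: .get() reads the value but leaves it in kwargs,
    # so the leftover kwargs forwarded to the LLM wrapper still carry it.
    msg_history_via_get = kwargs.get("msg_history", None)
    leftover_after_get = dict(kwargs)

    # New behaviour: .pop() removes the value before kwargs is forwarded on.
    msg_history_via_pop = kwargs.pop("msg_history", None)
    leftover_after_pop = dict(kwargs)

    return leftover_after_get, leftover_after_pop


left_get, left_pop = build_runner_args(
    msg_history=[{"role": "user", "content": "hi"}]
)
assert "msg_history" in left_get      # duplicated: the Runner already received it
assert "msg_history" not in left_pop  # forwarded kwargs no longer carry it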

tests/integration_tests/mock_llm_outputs.py

Lines changed: 4 additions & 0 deletions
@@ -88,6 +88,10 @@ def _invoke_llm(
                 pydantic.MSG_COMPILED_PROMPT_REASK,
                 pydantic.MSG_COMPILED_INSTRUCTIONS_REASK,
             ): pydantic.MSG_HISTORY_LLM_OUTPUT_CORRECT,
+            (
+                string.PARSE_COMPILED_PROMPT_REASK,
+                string.MSG_COMPILED_INSTRUCTIONS_REASK,
+            ): string.MSG_LLM_OUTPUT_CORRECT,
         }
 
         try:
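The mock maps (compiled prompt, compiled instructions) pairs to canned LLM outputs, which is why a new entry for the string re-ask prompt is needed. A rough sketch of that lookup pattern, assuming the mock works essentially as a keyed dictionary (simplified, not the real mock):

mock_llm_responses = {
    ("compiled re-ask prompt", "compiled re-ask instructions"): "canned output",
}

def invoke_mock_llm(prompt: str, instructions: str) -> str:
    # A prompt/instructions pair the test did not anticipate raises KeyError,
    # which is how test_reask_prompt_instructions detects a wrong re-ask prompt.
    return mock_llm_responses[(prompt, instructions)]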

tests/integration_tests/test_assets/string/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
 ]
 MSG_COMPILED_PROMPT_REASK = reader("msg_compiled_prompt_reask.txt")
 MSG_COMPILED_INSTRUCTIONS_REASK = reader("msg_compiled_instructions_reask.txt")
+PARSE_COMPILED_PROMPT_REASK = reader("parse_compiled_prompt_reask.txt")
 
 __all__ = [
     "COMPILED_PROMPT",
tests/integration_tests/test_assets/string/parse_compiled_prompt_reask.txt

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+This was a previous response you generated:
+
+======
+Tomato Cheese Pizza
+======
+
+Generate a new response that corrects your old response such that the following issues are fixed
+- Value Tomato Cheese Pizza should fail.
+
+Here's a description of what I want you to generate: Some description
+
+Your generated response should satisfy the following properties:
+- always_fail
+
+Don't talk; just go.

tests/integration_tests/test_parsing.py

Lines changed: 42 additions & 2 deletions
@@ -1,8 +1,17 @@
+from typing import Dict
+
+import openai
 import pytest
 
 import guardrails as gd
-
-from .mock_llm_outputs import MockArbitraryCallable, MockAsyncArbitraryCallable
+from guardrails import register_validator
+from guardrails.validators import FailResult, ValidationResult
+
+from .mock_llm_outputs import (
+    MockArbitraryCallable,
+    MockAsyncArbitraryCallable,
+    MockOpenAIChatCallable,
+)
 from .test_assets import pydantic
 
 
@@ -80,3 +89,34 @@ async def mock_async_callable(prompt: str):
     assert guard_history[1].prompt == gd.Prompt(pydantic.PARSING_COMPILED_REASK)
     assert guard_history[1].output == pydantic.PARSING_EXPECTED_LLM_OUTPUT
     assert guard_history[1].validated_output == pydantic.PARSING_EXPECTED_OUTPUT
+
+
+def test_reask_prompt_instructions(mocker):
+    """Test that the re-ask prompt and instructions are correct.
+
+    This is done implicitly, since if the incorrect prompt or instructions
+    are used, the mock LLM will raise a KeyError.
+    """
+
+    mocker.patch(
+        "guardrails.llm_providers.OpenAIChatCallable",
+        new=MockOpenAIChatCallable,
+    )
+
+    @register_validator(name="always_fail", data_type="string")
+    def always_fail(value: str, metadata: Dict) -> ValidationResult:
+        return FailResult(error_message=f"Value {value} should fail.")
+
+    guard = gd.Guard.from_string(
+        validators=[(always_fail, "reask")],
+        description="Some description",
+    )
+
+    guard.parse(
+        llm_output="Tomato Cheese Pizza",
+        llm_api=openai.ChatCompletion.create,
+        msg_history=[
+            {"role": "system", "content": "Some content"},
+            {"role": "user", "content": "Some prompt"},
+        ],
+    )
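For reference, always_fail above follows the function-validator contract used throughout the test suite: a callable registered with register_validator that returns a ValidationResult. A minimal sketch of a validator that can either pass or fail (this assumes PassResult is importable from guardrails.validators alongside FailResult):

from typing import Dict

from guardrails import register_validator
from guardrails.validators import FailResult, PassResult, ValidationResult


@register_validator(name="no_pizza", data_type="string")
def no_pizza(value: str, metadata: Dict) -> ValidationResult:
    # Reject any output that mentions pizza; otherwise let it pass.
    if "pizza" in value.lower():
        return FailResult(error_message=f"Value {value} should not mention pizza.")
    return PassResult()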
