Skip to content

Commit 94a33b4

Browse files
committed
judge detectors conversation refactor (NVIDIA#1346)
2 parents 711718c + 259b998 commit 94a33b4

File tree

4 files changed

+42
-12
lines changed

4 files changed

+42
-12
lines changed

garak/attempt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def last_message(self, role=None) -> Message:
129129
raise ValueError("No messages available")
130130
if not role:
131131
return self.turns[-1].content
132-
for idx in range(len(self.turns), 0):
132+
for idx in range(len(self.turns) - 1, -1, -1):
133133
if role == self.turns[idx].role:
134134
return self.turns[idx].content
135135
raise ValueError(f"No messages for role: {role}")

garak/resources/red_team/evaluation.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,11 @@ def judge_score(self, attack_prompt_list, target_response_list):
122122
Conversation(
123123
[
124124
Turn(
125-
"user",
126-
Message(
127-
self._create_conv(get_evaluator_prompt(prompt, response))
128-
),
125+
role=turn["role"],
126+
content=Message(turn["content"]),
127+
)
128+
for turn in self._create_conv(
129+
get_evaluator_prompt(prompt, response)
129130
)
130131
]
131132
)
@@ -142,13 +143,12 @@ def on_topic_score(self, attempt_list):
142143
Conversation(
143144
[
144145
Turn(
145-
"user",
146-
Message(
147-
self._create_conv(
148-
get_evaluator_prompt_on_topic(prompt),
149-
system_prompt=self.system_prompt_on_topic,
150-
)
151-
),
146+
role=turn["role"],
147+
content=Message(turn["content"]),
148+
)
149+
for turn in self._create_conv(
150+
get_evaluator_prompt_on_topic(prompt),
151+
system_prompt=self.system_prompt_on_topic,
152152
)
153153
]
154154
)

tests/detectors/test_detectors_judge.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,11 @@ def test_klass_detect(judge_config, mocker, klass, response, res_val):
6969
a.outputs = [Message("fake")]
7070
res = d.detect(a)
7171
mock_generate.assert_called_once()
72+
assert (
73+
len(mock_generate.call_args_list[0].args[0].turns) > 1
74+
), "the conversation passed to the judge model must have more than one turn, the system prompt and the evaluated value"
75+
assert isinstance(
76+
mock_generate.call_args_list[0].args[0].last_message("system"), Message
77+
), "a system message must be found in the conversation passed to the judge model"
7278
assert len(a.all_outputs) == len(res)
7379
assert [res_val * len(a.all_outputs)] == res

tests/test_attempt.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,30 @@ def test_conversation_internal_serialize():
105105
assert src_conv == dest
106106

107107

108+
def test_last_message():
109+
test_system_msg = garak.attempt.Message("the system is under control")
110+
test_user_msg = garak.attempt.Message(
111+
"But the point is, if you lie all the time, nobody's going to believe you, even when you're telling the truth."
112+
)
113+
test_assistant_msg = garak.attempt.Message("AI does not understand")
114+
test_user_msg_2 = garak.attempt.Message("That figures")
115+
116+
turns = [
117+
garak.attempt.Turn("system", test_system_msg),
118+
garak.attempt.Turn("user", test_user_msg),
119+
garak.attempt.Turn("assistant", test_assistant_msg),
120+
]
121+
conv = garak.attempt.Conversation(turns)
122+
assert conv.last_message() == test_assistant_msg
123+
assert conv.last_message("system") == test_system_msg
124+
assert conv.last_message("user") == test_user_msg
125+
126+
new_turn = garak.attempt.Turn("user", test_user_msg_2)
127+
conv.turns.append(new_turn)
128+
assert conv.last_message("user") == test_user_msg_2
129+
assert conv.last_message() == test_user_msg_2
130+
131+
108132
##########################
109133
# Test Attempt LifeCycle #
110134
##########################

0 commit comments

Comments
 (0)