Skip to content

Commit feb3d6e

Browse files
Fix desync between app_inference_spec and validator.
1 parent 5b51897 commit feb3d6e

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

validator/main.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,9 @@ def predict_jailbreak(
241241
prompts: List[str],
242242
reduction_function: Optional[Callable] = max,
243243
) -> Union[List[float], List[dict]]:
244+
"""predict_jailbreak will return an array of floats by default, one per prompt.
245+
If reduction_function is set to 'none' it will return a dict with the different
246+
sub-validators and their scores. Useful for debugging and tuning."""
244247
if isinstance(prompts, str):
245248
print("WARN: predict_jailbreak should be called with a list of strings.")
246249
prompts = [prompts, ]
@@ -308,16 +311,7 @@ def _inference_local(self, model_input: List[str]) -> Any:
308311

309312
def _inference_remote(self, model_input: List[str]) -> Any:
310313
# This needs to be kept in-sync with app_inference_spec.
311-
request_body = {
312-
"inputs": [ # Required for legacy reasons.
313-
{
314-
"name": "prompts",
315-
"shape": [len(model_input)],
316-
"data": model_input,
317-
"datatype": "BYTES"
318-
}
319-
]
320-
}
314+
request_body = {"prompts": model_input}
321315
response = self._hub_inference_request(
322316
json.dumps(request_body),
323317
self.validation_endpoint

0 commit comments

Comments
 (0)