Commit c256bd7

Auto-merge PR #114 (feature/evaluator-tool) into integration for testing
2 parents e12aa4e + 720504d commit c256bd7

File tree

1 file changed (+133, -118)


akd/tools/evaluator/base_evaluator.py

Lines changed: 133 additions & 118 deletions
@@ -1,6 +1,6 @@
 from collections import deque
 from enum import Enum
-from typing import Iterable, List, Optional, Set
+from typing import Iterable, List, Optional, Set, Tuple
 
 from deepeval.metrics import BaseMetric
 from deepeval.test_case import LLMTestCase, LLMTestCaseParams
@@ -202,54 +202,134 @@ def _post_init(self) -> None:
 
     async def _arun(self, params: LLMEvaluatorInputSchema) -> LLMEvaluatorOutputSchema:
         self.metrics = params.metrics or self.metrics
-        test_case_results = []
-        all_scores = []
 
         # Case 1: No search_results: simple evaluation
         if not params.search_results:
-            # Only enforce when user didn't supply search_results; the model_validator
-            # already enforced base presence, but this ensures metric-specific needs.
-
-            available_params = self._present_params_available(params)
-
-            problems = []
-            for m in self.metrics:
-                required = self._collect_metric_required_params(m.value)
-                if not required:
-                    # Could not introspect; optionally warn instead of failing hard.
-                    # problems.append(f"Warning: could not determine required params for metric '{getattr(m, 'name', type(m).__name__)}'; skipping strict check.")
-                    continue
-
-                missing = required - available_params
-                if missing:
-                    m_name = getattr(m, "name", type(m).__name__)
-                    problems.append(
-                        f"Metric '{m_name}' requires {self._pretty_missing(missing)}, "
-                        f"but they are missing from the provided input.",
-                    )
-
-            if problems:
-                raise ValueError(
-                    "Input does not satisfy the selected metrics:\n- "
-                    + "\n- ".join(problems)
-                    + "\n\nSupply the missing fields, remove the offending metrics, or provide a 'search_results'.",
+            all_scores, test_case_results = self._process_simple_inputs(params)
+
+        # Case 2: With search_results: possibly multiple LLMTestCases in future
+        else:
+            all_scores, test_case_results = self._process_search_results(params)
+
+        avg_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
+        return LLMEvaluatorOutputSchema(
+            score=avg_score,
+            test_case_results=test_case_results,
+            success=avg_score >= self.threshold,
+        )
+
+    def _process_simple_inputs(
+        self,
+        params,
+    ) -> Tuple[List[float], List[TestCaseEvaluationResult]]:
+        test_case_results = []
+        all_scores = []
+
+        # Only enforce when user didn't supply search_results; the model_validator
+        # already enforced base presence, but this ensures metric-specific needs.
+
+        available_params = self._present_params_available(params)
+
+        problems = []
+        for m in self.metrics:
+            required = self._collect_metric_required_params(m.value)
+            if not required:
+                # Could not introspect; optionally warn instead of failing hard.
+                # problems.append(f"Warning: could not determine required params for metric '{getattr(m, 'name', type(m).__name__)}'; skipping strict check.")
+                continue
+
+            missing = required - available_params
+            if missing:
+                m_name = getattr(m, "name", type(m).__name__)
+                problems.append(
+                    f"Metric '{m_name}' requires {self._pretty_missing(missing)}, "
+                    f"but they are missing from the provided input.",
+                )
+
+        if problems:
+            raise ValueError(
+                "Input does not satisfy the selected metrics:\n- "
+                + "\n- ".join(problems)
+                + "\n\nSupply the missing fields, remove the offending metrics, or provide a 'search_results'.",
+            )
+
+        params_dict = params.model_dump(exclude_none=True)
+        filtered_params = {
+            self.field_mapping[key]: value
+            for key, value in params_dict.items()
+            if key in self.field_mapping
+        }
+        # for some reason LLMTestCase requires 'input' and 'actual output', even if unused by the metrics
+        filtered_params.setdefault("input", "")
+        filtered_params.setdefault("actual_output", "")
+        test_case = LLMTestCase(**filtered_params)
+        metric_evals = []
+
+        for m in self.metrics:
+            metric = m.value
+            metric.measure(test_case)
+            metric_evals.append(
+                SingleEvaluationOutputSchema(
+                    score=metric.score,
+                    reason=metric.reason,
+                    metric=metric.name,
+                ),
+            )
+            all_scores.append(metric.score)
+
+        test_case_results.append(
+            TestCaseEvaluationResult(
+                test_case_index=0,
+                test_case_input=test_case.input,
+                actual_output=test_case.actual_output,
+                retrieval_context=test_case.retrieval_context or [],
+                metric_evaluations=metric_evals,
+            ),
+        )
+
+        return all_scores, test_case_results
+
+    def _process_search_results(
+        self,
+        params,
+    ) -> Tuple[List[float], List[TestCaseEvaluationResult]]:
+        test_case_results = []
+        all_scores = []
+
+        # --- Temporary warnings for unsupported metrics (TODO: design new metrics for SearchResultItem) ---
+        unsupported_metrics = [
+            m for m in self.metrics if m is not EvalMetricDefinition.STRUCTURE
+        ]
+        if unsupported_metrics:
+            metric_names = [
+                getattr(m, "name", type(m).__name__) for m in unsupported_metrics
+            ]
+            warning_msg = (
+                f"Warning: The following metrics are not fully supported with search_results "
+                f"and will be skipped or may behave unexpectedly: {', '.join(metric_names)}"
+            )
+            logger.warning(warning_msg)
+
+        # --- Structure-based case ---
+        if EvalMetricDefinition.STRUCTURE in self.metrics or any(
+            ["structure" in m.name.lower() for m in self.metrics],
+        ):
+            # for now, the STRUCTURE metric is the only suitable one for search_results
+            for sr in params.search_results:
+                sections_test_case = LLMTestCase(
+                    input=sr.query,
+                    actual_output="",
+                    retrieval_context=[sr.content],
+                    expected_output=params.reference,
                 )
+                metric_evals = []
 
-            params_dict = params.model_dump(exclude_none=True)
-            filtered_params = {
-                self.field_mapping[key]: value
-                for key, value in params_dict.items()
-                if key in self.field_mapping
-            }
-            # for some reason LLMTestCase requires 'input' and 'actual output', even if unused by the metrics
-            filtered_params.setdefault("input", "")
-            filtered_params.setdefault("actual_output", "")
-            test_case = LLMTestCase(**filtered_params)
-            metric_evals = []
-
-            for m in self.metrics:
-                metric = m.value
-                metric.measure(test_case)
+                # metric = EvalMetricDefinition.STRUCTURE.value
+                metric = next(
+                    (m.value for m in self.metrics if "structure" in m.name.lower()),
+                    None,
+                )
+                metric.measure(sections_test_case)
                 metric_evals.append(
                     SingleEvaluationOutputSchema(
                         score=metric.score,
@@ -259,82 +339,17 @@ async def _arun(self, params: LLMEvaluatorInputSchema) -> LLMEvaluatorOutputSchema:
                 )
                 all_scores.append(metric.score)
 
-            test_case_results.append(
-                TestCaseEvaluationResult(
-                    test_case_index=0,
-                    test_case_input=test_case.input,
-                    actual_output=test_case.actual_output,
-                    retrieval_context=test_case.retrieval_context or [],
-                    metric_evaluations=metric_evals,
-                ),
-            )
-
-        # Case 2: With search_results: possibly multiple LLMTestCases in future
-        else:
-            # --- Temporary warnings for unsupported metrics (TODO: design new metrics for SearchResultItem) ---
-            unsupported_metrics = [
-                m for m in self.metrics if m is not EvalMetricDefinition.STRUCTURE
-            ]
-            if unsupported_metrics:
-                metric_names = [
-                    getattr(m, "name", type(m).__name__) for m in unsupported_metrics
-                ]
-                warning_msg = (
-                    f"Warning: The following metrics are not fully supported with search_results "
-                    f"and will be skipped or may behave unexpectedly: {', '.join(metric_names)}"
+                test_case_results.append(
+                    TestCaseEvaluationResult(
+                        test_case_index=len(test_case_results),
+                        test_case_input=sections_test_case.input,
+                        actual_output=sections_test_case.actual_output,
+                        retrieval_context=sections_test_case.retrieval_context or [],
+                        metric_evaluations=metric_evals,
+                    ),
                 )
-                logger.warning(warning_msg)
-
-            # --- Structure-based case ---
-            if EvalMetricDefinition.STRUCTURE in self.metrics or any(
-                ["structure" in m.name.lower() for m in self.metrics],
-            ):
-                # for now, the STRUCTURE metric is the only suitable one for search_results
-                for sr in params.search_results:
-                    sections_test_case = LLMTestCase(
-                        input=sr.query,
-                        actual_output="",
-                        retrieval_context=[sr.content],
-                        expected_output=params.reference,
-                    )
-                    metric_evals = []
-
-                    # metric = EvalMetricDefinition.STRUCTURE.value
-                    metric = next(
-                        (
-                            m.value
-                            for m in self.metrics
-                            if "structure" in m.name.lower()
-                        ),
-                        None,
-                    )
-                    metric.measure(sections_test_case)
-                    metric_evals.append(
-                        SingleEvaluationOutputSchema(
-                            score=metric.score,
-                            reason=metric.reason,
-                            metric=metric.name,
-                        ),
-                    )
-                    all_scores.append(metric.score)
-
-                    test_case_results.append(
-                        TestCaseEvaluationResult(
-                            test_case_index=len(test_case_results),
-                            test_case_input=sections_test_case.input,
-                            actual_output=sections_test_case.actual_output,
-                            retrieval_context=sections_test_case.retrieval_context
-                            or [],
-                            metric_evaluations=metric_evals,
-                        ),
-                    )
 
-        avg_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
-        return LLMEvaluatorOutputSchema(
-            score=avg_score,
-            test_case_results=test_case_results,
-            success=avg_score >= self.threshold,
-        )
+        return all_scores, test_case_results
 
     def _reverse_field_mapping(self, field_mapping: dict[str, str]) -> dict[str, str]:
         # e.g. {"input":"input","output":"actual_output"} -> {"input":"input","actual_output":"output"}
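
Both new helpers follow the same deepeval pattern that the original _arun used inline: build an LLMTestCase, call measure() on each metric, and copy score and reason into SingleEvaluationOutputSchema. The sketch below shows that pattern on its own; AnswerRelevancyMetric and the example strings are illustrative choices for demonstration, not part of this commit.

# Illustrative sketch (not from the diff) of the measure()/score/reason flow
# used by _process_simple_inputs and _process_search_results.
from deepeval.metrics import AnswerRelevancyMetric  # arbitrary metric choice for the example
from deepeval.test_case import LLMTestCase

# LLMTestCase always needs 'input' and 'actual_output', even if a metric
# ignores them -- the same reason the helper calls setdefault("input", "").
test_case = LLMTestCase(
    input="What does the evaluator tool report?",
    actual_output="It averages per-metric scores across all test cases.",
    retrieval_context=["The evaluator averages metric scores per test case."],
)

metric = AnswerRelevancyMetric(threshold=0.5)
metric.measure(test_case)           # synchronous scoring, as in the helpers
print(metric.score, metric.reason)  # the fields collected into SingleEvaluationOutputSchema

In _arun, the per-metric scores gathered this way are averaged, and success is set by comparing that average against self.threshold.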
