Skip to content

Commit

Permalink
feat: DIA-1122: [LSE BE] task-level error handling (#180)
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Bernstein <[email protected]>
Co-authored-by: nik <[email protected]>
Co-authored-by: farioas <[email protected]>
  • Loading branch information
4 people authored Aug 14, 2024
1 parent 15fab7e commit 97efe2c
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions server/handlers/result_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class LSEBatchItem(BaseModel):
task_id: int
# TODO this field no longer populates if there was an error, so validation fails without a default - should probably split this item into 3 different constructors corresponding to new internal adala objects (or just reuse those objects)
output: Optional[str] = None
# TODO handle in DIA-1122
# we don't need to use reserved names anymore here because they're not in a DataFrame, but a structure with proper typing available
error: bool = Field(False, alias="_adala_error")
message: Optional[str] = Field(None, alias="_adala_message")
Expand Down Expand Up @@ -131,6 +130,19 @@ def ready(self):

return self

def prepare_errors_payload(self, error_batch):
transformed_errors = []
for error in error_batch:
error = error.dict()
transformed_error = {
"task_id": error["task_id"],
"message": error["details"] if "details" in error else "",
"error_type": error["message"] if "message" in error else ""
}
transformed_errors.append(transformed_error)

return transformed_errors

def __call__(self, result_batch: list[LSEBatchItem]):
logger.debug(f"\n\nHandler received batch: {result_batch}\n\n")

Expand All @@ -152,9 +164,8 @@ def __call__(self, result_batch: list[LSEBatchItem]):

norm_result_batch.append(LSEBatchItem(**result))

# omit failed tasks for now
# TODO handle in DIA-1122
result_batch = [record for record in norm_result_batch if not record.error]
error_batch = [record for record in norm_result_batch if record.error]

# coerce back to dicts for sending
result_batch = [record.dict() for record in result_batch]
Expand All @@ -172,6 +183,22 @@ def __call__(self, result_batch: list[LSEBatchItem]):
else:
logger.error(f'No valid results to send to LSE for modelrun_id {self.modelrun_id}')

# Send failed predictions back to LSE
if error_batch:
error_batch = self.prepare_errors_payload(error_batch)
self.client.make_request(
"POST",
"/api/model-run/batch-failed-predictions",
data=json.dumps(
{
"modelrun_id": self.modelrun_id,
"failed_predictions": error_batch,
}
)
)
else:
logger.debug(f'No errors to send to LSE for modelrun_id {self.modelrun_id}')


class CSVHandler(ResultHandler):
"""
Expand Down

0 comments on commit 97efe2c

Please sign in to comment.