Skip to content

Commit

Permalink
Fixing LitQATaskDataset deserialization from config (#494)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jamesbraza authored Sep 27, 2024
1 parent 0428dd2 commit b14d094
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
16 changes: 12 additions & 4 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,22 @@ class LitQATaskDataset(

def __init__(
self,
base_query: QueryRequest | None = None,
base_docs: Docs | None = None,
base_query: QueryRequest | dict | None = None,
base_docs: Docs | dict | None = None,
rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
**env_kwargs,
):
self._base_query = base_query or QueryRequest()
self._base_docs = base_docs or Docs()
if base_query is None:
base_query = QueryRequest()
if isinstance(base_query, dict):
base_query = QueryRequest(**base_query)
self._base_query = base_query
if base_docs is None:
base_docs = Docs()
if isinstance(base_docs, dict):
base_docs = Docs(**base_docs)
self._base_docs = base_docs
self._rewards = rewards
self._env_kwargs = env_kwargs
self._eval_model = eval_model
Expand Down
23 changes: 20 additions & 3 deletions tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import patch

import pytest
from aviary.env import TASK_DATASET_REGISTRY, TaskDataset
from aviary.env import TASK_DATASET_REGISTRY, TaskConfig, TaskDataset
from ldp.agent import SimpleAgent
from ldp.alg.callbacks import MeanMetricsCallback
from ldp.alg.runners import Evaluator, EvaluatorConfig
Expand Down Expand Up @@ -78,9 +78,26 @@ def test___len__(
@pytest.mark.asyncio
async def test_evaluation(self, base_query_request: QueryRequest) -> None:
docs = Docs()
dataset = TaskDataset.from_name(
STUB_TASK_DATASET_NAME, base_query=base_query_request, base_docs=docs
# We construct the TaskConfig from a serialized QueryRequest and Docs to confirm
# that everything works as if hydrating from a YAML config file.
task_config = TaskConfig(
name=STUB_TASK_DATASET_NAME,
eval_kwargs={
"base_query": base_query_request.model_dump(
exclude={"id", "settings", "docs_name"}
),
"base_docs": docs.model_dump(
exclude={
"id",
"docnames",
"texts_index",
"index_path",
"deleted_dockeys",
}
),
},
)
dataset = task_config.make_dataset(split="eval") # noqa: FURB184
metrics_callback = MeanMetricsCallback(eval_dataset=dataset)

evaluator = Evaluator(
Expand Down

0 comments on commit b14d094

Please sign in to comment.