Merge branch 'Future-House:main' into main
taabishm2 committed Sep 14, 2024
2 parents bb47be5 + 49273f2 commit d9258a9
Showing 6 changed files with 381 additions and 10 deletions.
59 changes: 59 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,59 @@
# Contributing to PaperQA

Thank you for your interest in contributing to PaperQA!
Here are some guidelines to help you get started.

## Setting up the development environment

We use [`uv`](https://github.com/astral-sh/uv) for our local development.

1. Install `uv` by following the instructions on the [uv website](https://astral.sh/uv/).
2. Run the following command to install all dependencies and set up the development environment:

```bash
uv sync
```
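Once the environment is synced, project tooling can also be invoked through `uv` without manually activating the virtual environment. A minimal sketch, assuming the tools below are part of the synced dev dependencies (standard `uv` usage, not a project-specific requirement):

```bash
uv run pytest
uv run pre-commit run --all-files
```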

## Installing the package for development

If you prefer to use `pip` for installing the package in development mode, you can do so by running:

```bash
pip install -e .
```

## Running tests and other tooling

Use the following commands:

- Run tests (requires an OpenAI key in your environment):

```bash
pytest
# or for multiprocessing based parallelism
pytest -n auto
```

- Run `pre-commit` for formatting and type checking:

```bash
pre-commit run --all-files
```
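Optionally, the hooks can be installed so they run on every commit. This is standard `pre-commit` behavior rather than anything specific to this repository:

```bash
pre-commit install
```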

- Run `mypy`, `refurb`, or `pylint` directly:

```bash
mypy paperqa
# or
refurb paperqa
# or
pylint paperqa
```

See our GitHub Actions [`tests.yml`](.github/workflows/tests.yml) for further reference.

## Additional resources

For more information on contributing, please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file in the repository.

Happy coding!
8 changes: 6 additions & 2 deletions README.md
@@ -54,14 +54,18 @@ PaperQA2 depends on some awesome libraries/APIs that make our repo possible. Her
7. [pybtex](https://pybtex.org/)
8. [PyMuPDF](https://pymupdf.readthedocs.io/en/latest/)

## Install
## Installation

You can install PaperQA2 via pip:
For a non-development setup,
install PaperQA2 from [PyPI](https://pypi.org/project/paper-qa/):

```bash
pip install paper-qa
```

For development setup,
please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file.

PaperQA2 uses an LLM to operate,
so you'll need to either set an appropriate [API key environment variable][LiteLLM providers] (e.g. `export OPENAI_API_KEY=sk-...`)
or set up an open source LLM server (e.g. using [llamafile](https://github.com/Mozilla-Ocho/llamafile)).
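As a minimal sketch of those two options (the llamafile model name below is only an illustration, not a project recommendation):

```bash
# Option 1: hosted LLM provider, configured via its API key
export OPENAI_API_KEY=sk-...

# Option 2: local OpenAI-compatible server via llamafile (example model file)
./mistral-7b-instruct-v0.2.Q4_0.llamafile
```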
7 changes: 3 additions & 4 deletions paperqa/agents/env.py
@@ -1,8 +1,7 @@
import logging
from typing import cast

from aviary.env import Environment as _Environment
from aviary.env import Frame
from aviary.env import Environment, Frame
from aviary.message import Message
from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage

@@ -91,8 +90,8 @@ def settings_to_tools(
    return tools


class Environment(_Environment[EnvironmentState]):
    """Environment to connect agents with paper-qa."""
class PaperQAEnvironment(Environment[EnvironmentState]):
    """Environment connecting paper-qa's tools with state."""

    def __init__(
        self,
8 changes: 4 additions & 4 deletions paperqa/agents/main.py
@@ -46,7 +46,7 @@
from paperqa.types import Answer
from paperqa.utils import pqa_directory

from .env import Environment
from .env import PaperQAEnvironment
from .helpers import litellm_get_search_query, table_formatter
from .models import AgentStatus, AnswerResponse, QueryRequest, SimpleProfiler
from .search import SearchDocumentStorage, SearchIndex
@@ -235,7 +235,7 @@ async def run_fake_agent(
    ) = None,
    **env_kwargs,
) -> tuple[Answer, AgentStatus]:
    env = Environment(query, docs, **env_kwargs)
    env = PaperQAEnvironment(query, docs, **env_kwargs)
    _, tools = await env.reset()
    if on_env_reset_callback:
        await on_env_reset_callback(env.state)
@@ -281,7 +281,7 @@ async def run_aviary_agent(
    ) = None,
    **env_kwargs,
) -> tuple[Answer, AgentStatus]:
    env = Environment(query, docs, **env_kwargs)
    env = PaperQAEnvironment(query, docs, **env_kwargs)
    done = False

    try:
@@ -345,7 +345,7 @@ async def run_ldp_agent(
    ) = None,
    **env_kwargs,
) -> tuple[Answer, AgentStatus]:
    env = Environment(query, docs, **env_kwargs)
    env = PaperQAEnvironment(query, docs, **env_kwargs)
    done = False

    try:
209 changes: 209 additions & 0 deletions paperqa/agents/task.py
@@ -0,0 +1,209 @@
__all__ = [
    "ENV_NAME",
    "TASK_DATASET_NAME",
    "GradablePaperQAEnvironment",
    "LitQATaskDataset",
    "LitQAv2TaskDataset",
    "LitQAv2TaskSplit",
]

from abc import ABC
from collections.abc import Awaitable, Callable, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, assert_never

from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame, TaskDataset
from aviary.message import Message
from aviary.tools import ToolRequestMessage, ToolResponseMessage

try:
    from ldp.alg.callbacks import ComputeTrajectoryMetricsMixin
except ImportError:

    class ComputeTrajectoryMetricsMixin:  # type: ignore[no-redef]
        """Placeholder for when ldp isn't installed."""


from paperqa.docs import Docs
from paperqa.litqa import (
    DEFAULT_EVAL_MODEL_NAME,
    DEFAULT_LABBENCH_HF_HUB_NAME,
    DEFAULT_REWARD_DISTRIBUTION,
    LitQAEvaluation,
    read_litqa_v2_from_hub,
)
from paperqa.llms import EmbeddingModel, LiteLLMModel, LLMModel
from paperqa.types import Answer

from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
from .models import QueryRequest
from .tools import GenerateAnswer

if TYPE_CHECKING:
    from ldp.data_structures import Trajectory


class GradablePaperQAEnvironment(PaperQAEnvironment):
    """Extended environment that can grade answers."""

    def __init__(
        self,
        query: QueryRequest,
        docs: Docs,
        llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
        summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
        embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
        evaluation_from_answer: (
            Callable[[Answer | str], Awaitable[LitQAEvaluation]] | None
        ) = None,
        rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
        evaluation_callback: Callable[[LitQAEvaluation], Awaitable] | None = None,
        **env_kwargs,
    ):
        super().__init__(
            query, docs, llm_model, summary_llm_model, embedding_model, **env_kwargs
        )
        self._evaluation_from_answer = evaluation_from_answer
        self._evaluation_callback = evaluation_callback
        self._rewards = rewards

    async def step(
        self, action: ToolRequestMessage
    ) -> tuple[list[Message], float, bool, bool]:
        messages, reward, done, truncated = await super().step(action)
        if not done or not self._evaluation_from_answer:
            return messages, reward, done, truncated
        # Filter out non-answer messages (in case parallel tool calls)
        answer_tool_messages = [
            m
            for m in messages
            if isinstance(m, ToolResponseMessage)
            and m.name == GenerateAnswer.gen_answer.__name__
        ]
        if not answer_tool_messages:  # No answer, so no positive reward
            return messages, reward, done, truncated
        if len(answer_tool_messages) != 1:
            raise NotImplementedError(
                f"Expected just one answer message, got {messages}."
            )
        answer = GenerateAnswer.extract_answer_from_message(
            content=answer_tool_messages[0].content
        )
        if not answer:
            return messages, reward, done, truncated
        evaluation = await self._evaluation_from_answer(answer)
        if evaluation_callback := self._evaluation_callback:
            await evaluation_callback(evaluation)
        return messages, reward + self._rewards[evaluation.value], done, truncated

    def export_frame(self) -> Frame:
        raise NotImplementedError("Didn't yet need to export a frame.")


ENV_NAME = "paperqa-local"
ENV_REGISTRY[ENV_NAME] = (
    GradablePaperQAEnvironment.__module__,
    GradablePaperQAEnvironment.__name__,
)


class LitQATaskDataset(
    TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
):
    """
    Abstract base class for a task dataset of LitQA v1 or v2 questions.
    This is an ABC because it's non-specific to a LitQA version.
    Examples include LitQA v1, v2, or a test stub version of LitQA.
    """

    def __init__(
        self,
        base_query_request: QueryRequest,
        rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
        eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
        **env_kwargs,
    ):
        self._base_query_request = base_query_request
        self._rewards = rewards
        self._env_kwargs = env_kwargs
        self._eval_model = eval_model

    def _make_gradable_environment(
        self,
        ideal: str,
        distractors: str | list[str],
        question: str,
        use_unsure: bool = True,
    ) -> GradablePaperQAEnvironment:
        qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
            ideal=ideal,
            distractors=distractors,
            question=question,
            use_unsure=use_unsure,
            eval_model=self._eval_model,
        )
        query_request = self._base_query_request.model_copy()
        query_request.query = qa_prompt
        return GradablePaperQAEnvironment(
            query=query_request,
            evaluation_from_answer=evaluation_from_answer,
            rewards=self._rewards,
            **self._env_kwargs,
        )

    def compute_trajectory_metrics(
        self, trajectories: "Sequence[Trajectory]"
    ) -> dict[str, list[float]]:
        return super().compute_trajectory_metrics(trajectories) | {
            "correct": [
                int(traj.steps[-1].reward == self._rewards[0]) for traj in trajectories
            ],
            "correct_unsure": [
                int(traj.steps[-1].reward in {self._rewards[0], self._rewards[1]})
                for traj in trajectories
            ],
        }


class LitQAv2TaskSplit(StrEnum):
    TRAIN = "train"
    EVAL = "eval"


class LitQAv2TaskDataset(LitQATaskDataset):
    """Task dataset of LitQA v2 questions."""

    def __init__(
        self,
        *args,
        labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
        split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        train_df, eval_df = read_litqa_v2_from_hub(labbench_dataset)
        split = LitQAv2TaskSplit(split)
        if split == LitQAv2TaskSplit.TRAIN:
            self.data = train_df
        elif split == LitQAv2TaskSplit.EVAL:
            self.data = eval_df
        else:
            assert_never(split)

    def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
        return self._make_gradable_environment(
            ideal=self.data.iloc[idx].ideal,
            distractors=self.data.iloc[idx].distractors,
            question=self.data.iloc[idx].question,
        )

    def __len__(self) -> int:
        return len(self.data)


TASK_DATASET_NAME = "litqa-v2"
TASK_DATASET_REGISTRY[TASK_DATASET_NAME] = (
    LitQAv2TaskDataset.__module__,
    LitQAv2TaskDataset.__name__,
)