Merged
9 changes: 9 additions & 0 deletions docs/cookbook/cookbook.md
@@ -19,6 +19,15 @@ Learn how to extend `genlm-eval` to evaluate models on your own custom domains.

[View Example](custom_domains.ipynb)

### Custom Potentials
Learn how to implement custom potentials to encode constraints and evaluate models with constrained generation. This example walks through:

- Implementing custom potentials to encode constraints
- Creating a model adaptor that uses constrained generation
- Running the evaluation

[View Example](custom_potentials.ipynb)
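The constraint-encoding idea above can be sketched in plain Python. This is a minimal, hypothetical `PatternPotential`; the real genlm-control base class and its method signatures may differ, and prefix scoring (needed for pruning partial sequences during generation) is omitted:

```python
import math
import re


class PatternPotential:
    """Hypothetical sketch of a regex potential: assigns log-weight 0.0
    to byte sequences that fully match the pattern, -inf otherwise."""

    def __init__(self, pattern: str):
        self.regex = re.compile(pattern)

    def complete(self, context: bytes) -> float:
        # Score a finished sequence: 0.0 (weight 1) on a match, -inf on a miss.
        text = context.decode("utf-8", errors="ignore")
        return 0.0 if self.regex.fullmatch(text) else -math.inf


# The potential accepts "abab" for pattern (ab)+ but rejects "aba".
p = PatternPotential("(ab)+")
```

In genlm-control, a potential along these lines would also score prefixes so a constrained sampler can discard doomed partial sequences early.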

### Domain-Specific Examples

#### Pattern Matching
137 changes: 99 additions & 38 deletions docs/cookbook/custom_domains.ipynb
@@ -10,7 +10,9 @@
"\n",
"1. Define your dataset\n",
"2. Implement an evaluator\n",
"3. Implement a model adaptor\n",
"3. Implement a prompt formatter\n",
"4. Implement a model adaptor\n",
"5. Run the evaluation\n",
"\n",
"The following example demonstrates these steps on the pattern matching domain.\n"
]
@@ -26,7 +28,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -47,12 +49,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Given a dataset schema, you can define a dataset by subclassing `Dataset` and implementing an `__iter__` method which yields instances of the schema.\n"
"Given a dataset schema, you can define a dataset by subclassing `Dataset` and implementing an `__iter__` method which yields instances of the schema."
]
},
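The dataset pattern described above can be sketched as follows; the names are illustrative, and the genlm-eval `Dataset` base class is left out for brevity:

```python
from dataclasses import dataclass


@dataclass
class PatternMatchingInstance:
    # Schema for a single evaluation instance (illustrative).
    instance_id: int
    pattern: str


class PatternMatchingDataset:
    """Sketch of a Dataset subclass: __iter__ yields schema instances."""

    def __init__(self, patterns):
        self.patterns = patterns

    def __iter__(self):
        for i, pattern in enumerate(self.patterns):
            yield PatternMatchingInstance(instance_id=i, pattern=pattern)


dataset = PatternMatchingDataset(["xy|xz", "ab|c(e|f)"])
```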
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -84,15 +86,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## 2. Implement an evaluator\n",
"\n",
"An evaluator is the class responsible for scoring model outputs. Subclasses must minimally implement the `evaluate_sample` method which takes an instance and a response and returns an evaluation result."
]
},
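A minimal evaluator along these lines might look like this; the result type and the evaluator's base class are simplified stand-ins for the genlm-eval ones:

```python
import re
from dataclasses import dataclass


@dataclass
class EvaluationResult:
    # Simplified result type (illustrative).
    score: float


class PatternMatchingEvaluator:
    """Sketch: score 1.0 if the response fully matches the instance's regex."""

    def evaluate_sample(self, instance, response: str) -> EvaluationResult:
        matched = re.fullmatch(instance.pattern, response.strip()) is not None
        return EvaluationResult(score=1.0 if matched else 0.0)
```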
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -115,35 +116,98 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Implement a model adaptor\n",
"## 3. Implement a prompt formatter\n",
"\n",
"A model adaptor is an async callable that takes a `PatternMatchingInstance` and returns a `ModelOutput`. For this example, we'll use a constrained `genlm.control.PromptedLLM` to generate responses."
"A prompt formatter tokenizes and standardizes the input to the model by optionally adding a system prompt and few-shot examples for the evaluation."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniconda/base/envs/genlm/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/opt/homebrew/Caskroom/miniconda/base/envs/genlm/lib/python3.11/site-packages/genlm/backend/tokenization/vocab.py:98: UserWarning: Duplicate tokens found in string vocabulary. This may lead to downstream issues with the string vocabulary; we recommend using the byte vocabulary.\n",
" warnings.warn(\n"
]
}
],
"outputs": [],
"source": [
"from genlm.control import PromptedLLM, AWRS\n",
"from genlm.eval import ModelOutput, ModelResponse\n",
"from genlm.eval.domains.pattern_matching import (\n",
" default_prompt_formatter,\n",
" PatternPotential,\n",
"from genlm.eval.util import chat_template_messages\n",
"\n",
"\n",
"FEW_SHOT_EXAMPLES = [\n",
" (\"(ab)+\", \"ab\"),\n",
" (\"(ab|cd)+\", \"cd\"),\n",
" (\"[a-z]+\", \"hello\"),\n",
"]\n",
"\n",
"\n",
"SYSTEM_PROMPT = (\n",
" \"You are a helpful assistant that generates strings matching regular expressions. \"\n",
" + \"Only output the exact string that matches the regex pattern, nothing more.\"\n",
")\n",
"\n",
"\n",
"def default_prompt_formatter(\n",
" tokenizer,\n",
" instance,\n",
" use_chat_format=False,\n",
" system_prompt=SYSTEM_PROMPT,\n",
" few_shot_examples=FEW_SHOT_EXAMPLES,\n",
"):\n",
" \"\"\"Default prompt formatter for pattern matching.\n",
"\n",
" Args:\n",
" tokenizer (Tokenizer): The tokenizer to use.\n",
" instance (PatternMatchingInstance): The instance to format.\n",
" use_chat_format (bool): Whether to use chat format.\n",
" system_prompt (str): The system prompt to use.\n",
" few_shot_examples (list[tuple[str, str]]): The few shot examples to use. Each example is a tuple of (pattern, response).\n",
"\n",
" Returns:\n",
" (list[int]): The prompt ids.\n",
" \"\"\"\n",
" if use_chat_format:\n",
" return tokenizer.apply_chat_template(\n",
" chat_template_messages(\n",
" system_prompt,\n",
" few_shot_examples,\n",
" instance.pattern,\n",
" ),\n",
" tokenize=True,\n",
" add_generation_prompt=True,\n",
" )\n",
" else:\n",
" return tokenizer.encode(\n",
" (\n",
" system_prompt\n",
" + \"\\n\"\n",
" + \"\\n\".join(\n",
" f\"Pattern: {input}\\nOutput: {output}\"\n",
" for input, output in few_shot_examples\n",
" )\n",
" + \"\\n\"\n",
" + instance.pattern\n",
" )\n",
" )"
]
},
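For a concrete sense of the non-chat branch above, here is the plain-text prompt it builds for pattern `xy|xz`, with the final `tokenizer.encode` step dropped so the string itself is visible:

```python
FEW_SHOT_EXAMPLES = [
    ("(ab)+", "ab"),
    ("(ab|cd)+", "cd"),
    ("[a-z]+", "hello"),
]

SYSTEM_PROMPT = (
    "You are a helpful assistant that generates strings matching regular expressions. "
    "Only output the exact string that matches the regex pattern, nothing more."
)


def build_plain_prompt(pattern: str) -> str:
    # Mirrors the else-branch of default_prompt_formatter, minus tokenizer.encode.
    return (
        SYSTEM_PROMPT
        + "\n"
        + "\n".join(
            f"Pattern: {inp}\nOutput: {out}" for inp, out in FEW_SHOT_EXAMPLES
        )
        + "\n"
        + pattern
    )


prompt = build_plain_prompt("xy|xz")
```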
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Implement a model adaptor\n",
"\n",
"A model adaptor is an async callable that takes a dataset instance (here, a `PatternMatchingInstance`) and returns a `ModelOutput`. \n",
"For this example, we'll use a `PromptedLLM` that proposes tokens by sampling directly from the LM's distribution.\n",
" \n",
"See [custom_potentials.ipynb](custom_potentials.ipynb) for a tutorial on how to implement custom constraints in genlm-control and evaluate the model."
]
},
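The adaptor contract can be sketched without genlm-control by stubbing the sampler. Here `fake_sample` is a hypothetical stand-in for running `sampler.smc(...)`, and the output types are simplified versions of genlm-eval's `ModelOutput` and `ModelResponse`:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class ModelResponse:
    response: str
    weight: float


@dataclass
class ModelOutput:
    responses: list


async def fake_sample(instance):
    # Stand-in for SMC sampling: returns weighted (text, weight) pairs.
    return [("xy", 0.6), ("xz", 0.4)]


async def model_adaptor(instance):
    """Async callable: dataset instance -> ModelOutput (sketch)."""
    weighted = await fake_sample(instance)
    return ModelOutput(
        responses=[ModelResponse(response=r, weight=w) for r, w in weighted]
    )


class _Instance:
    pattern = "xy|xz"


output = asyncio.run(model_adaptor(_Instance()))
```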
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from genlm.control import PromptedLLM, direct_token_sampler\n",
"from genlm.eval import ModelOutput, ModelResponse\n",
"\n",
"# Load an LLM\n",
"LLM = PromptedLLM.from_name(\"gpt2\", eos_tokens=[b\"\\n\", b\"\\n\\n\"])\n",
"\n",
@@ -154,13 +218,10 @@
" LLM.model.tokenizer, instance, use_chat_format=False\n",
" )\n",
"\n",
" # Define a potential that ensures the generated text matches the pattern\n",
" potential = PatternPotential(instance.pattern).coerce(LLM, f=b\"\".join)\n",
"\n",
" # Define an adaptive weighted rejection sampler to sample tokens from the constrained model.\n",
" sampler = AWRS(LLM, potential)\n",
" # Load a sampler that proposes tokens by sampling directly from the LM's distribution\n",
" sampler = direct_token_sampler(LLM)\n",
"\n",
" # Run SMC to sample sequences from the constrained model.\n",
" # Run SMC with 5 particles, a maximum of 25 tokens, and an ESS threshold of 0.5\n",
" sequences = await sampler.smc(\n",
" n_particles=5,\n",
" ess_threshold=0.5,\n",
Member (review comment): ess_threshold should be zero if we are directly sampling from LLM

@@ -180,27 +241,27 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Run the evaluation\n",
"## 5. Run the evaluation\n",
"\n",
"Using the dataset, evaluator, and model adaptor, we can now run the evaluation:"
]
},
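The "mean weighted accuracy" printed by the evaluation is, in essence, the weight-normalized average of per-response correctness scores. A self-contained sketch of that computation (names illustrative):

```python
import re


def weighted_accuracy(pattern, weighted_responses):
    """Weight-normalized accuracy over (response, weight) pairs (sketch)."""
    total = sum(w for _, w in weighted_responses)
    if total == 0:
        return 0.0
    return sum(
        # A response scores 1.0 if it fully matches the pattern, else 0.0.
        w * (1.0 if re.fullmatch(pattern, r) else 0.0)
        for r, w in weighted_responses
    ) / total


acc = weighted_accuracy("xy|xz", [("xy", 0.6), ("ab", 0.4)])
```

With these inputs, only the 0.6-weight response matches, so the accuracy is 0.6.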
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Instance instance_id=0 pattern='xy|xz'\n",
"Mean weighted accuracy (instance): 0.9999999999999999\n",
"Mean weighted accuracy (total): 0.9999999999999999\n",
"Mean weighted accuracy (instance): 0.0\n",
"Mean weighted accuracy (total): 0.0\n",
"\n",
"Instance instance_id=1 pattern='ab|c(e|f)'\n",
"Mean weighted accuracy (instance): 1.0\n",
"Mean weighted accuracy (total): 1.0\n",
"Mean weighted accuracy (instance): 0.0\n",
"Mean weighted accuracy (total): 0.0\n",
"\n"
]
}