Merged
9 changes: 9 additions & 0 deletions docs/cookbook/cookbook.md
@@ -19,6 +19,15 @@ Learn how to extend `genlm-eval` to evaluate models on your own custom domains.

[View Example](custom_domains.ipynb)

### Custom Potentials
Learn how to implement custom potentials to encode constraints and evaluate models with constrained generation. This example walks through:

- Implementing custom potentials to encode constraints
- Creating a model adaptor that uses constrained generation
- Running the evaluation

[View Example](custom_potentials.ipynb)

### Domain-Specific Examples

#### Pattern Matching
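The Custom Potentials entry added above centers on encoding constraints as potentials, i.e. functions that assign log-weights to candidate strings. As a library-agnostic sketch of that idea (the class and method names here are hypothetical and simpler than genlm-control's actual `Potential` interface), a regex constraint might be scored like this:

```python
import math
import re


class RegexPotentialSketch:
    """Hypothetical stand-in (not genlm-control's actual Potential API)
    illustrating the idea: a constraint maps strings to log-weights,
    where 0.0 (= log 1) keeps a sequence and -inf rejects it."""

    def __init__(self, pattern):
        self.pattern = re.compile(pattern)

    def complete(self, text):
        # Score a finished sequence: keep it only if the whole
        # string matches the regex.
        return 0.0 if self.pattern.fullmatch(text) else -math.inf


potential = RegexPotentialSketch("(ab)+")
print(potential.complete("abab"))  # 0.0
print(potential.complete("abc"))   # -inf
```

See the linked custom_potentials.ipynb for the real implementation, which also scores prefixes so that constraints can steer generation token by token.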
142 changes: 105 additions & 37 deletions docs/cookbook/custom_domains.ipynb
@@ -10,7 +10,9 @@
"\n",
"1. Define your dataset\n",
"2. Implement an evaluator\n",
"3. Implement a prompt formatter\n",
"4. Implement a model adaptor\n",
"5. Run the evaluation\n",
"\n",
"The following example demonstrates these steps on the pattern matching domain.\n"
]
@@ -47,7 +49,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Given a dataset schema, you can define a dataset by subclassing `Dataset` and implementing an `__iter__` method which yields instances of the schema."
]
},
{
@@ -84,7 +86,6 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Implement an evaluator\n",
"\n",
"An evaluator is the class responsible for scoring model outputs. Subclasses must minimally implement the `evaluate_sample` method which takes an instance and a response and returns an evaluation result."
@@ -115,35 +116,98 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Implement a prompt formatter\n",
"\n",
"A prompt formatter standardizes the model's input: it renders each instance into a prompt, optionally prepending a system prompt and few-shot examples, and tokenizes the result into token IDs."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from genlm.eval.util import chat_template_messages\n",
"\n",
"\n",
"FEW_SHOT_EXAMPLES = [\n",
" (\"(ab)+\", \"ab\"),\n",
" (\"(ab|cd)+\", \"cd\"),\n",
" (\"[a-z]+\", \"hello\"),\n",
"]\n",
"\n",
"\n",
"SYSTEM_PROMPT = (\n",
" \"You are a helpful assistant that generates strings matching regular expressions. \"\n",
" + \"Only output the exact string that matches the regex pattern, nothing more.\"\n",
")\n",
"\n",
"\n",
"def default_prompt_formatter(\n",
" tokenizer,\n",
" instance,\n",
" use_chat_format=False,\n",
" system_prompt=SYSTEM_PROMPT,\n",
" few_shot_examples=FEW_SHOT_EXAMPLES,\n",
"):\n",
" \"\"\"Default prompt formatter for pattern matching.\n",
"\n",
" Args:\n",
" tokenizer (Tokenizer): The tokenizer to use.\n",
" instance (PatternMatchingInstance): The instance to format.\n",
" use_chat_format (bool): Whether to use chat format.\n",
" system_prompt (str): The system prompt to use.\n",
" few_shot_examples (list[tuple[str, str]]): The few shot examples to use. Each example is a tuple of (pattern, response).\n",
"\n",
" Returns:\n",
" (list[int]): The prompt ids.\n",
" \"\"\"\n",
" if use_chat_format:\n",
" return tokenizer.apply_chat_template(\n",
" chat_template_messages(\n",
" system_prompt,\n",
" few_shot_examples,\n",
" instance.pattern,\n",
" ),\n",
" tokenize=True,\n",
" add_generation_prompt=True,\n",
" )\n",
" else:\n",
" return tokenizer.encode(\n",
" (\n",
" system_prompt\n",
" + \"\\n\"\n",
" + \"\\n\".join(\n",
" f\"Pattern: {input}\\nOutput: {output}\"\n",
" for input, output in few_shot_examples\n",
" )\n",
" + \"\\n\"\n",
" + instance.pattern\n",
" )\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Implement a model adaptor\n",
"\n",
"A model adaptor is an async callable that takes a dataset instance (here, a `PatternMatchingInstance`) and returns a `ModelOutput`. \n",
"For this example, we'll use a `PromptedLLM` that proposes tokens by sampling directly from the LM's distribution.\n",
" \n",
"See [custom_potentials.ipynb](custom_potentials.ipynb) for a tutorial on how to implement custom constraints in genlm-control and evaluate the model."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from genlm.control import PromptedLLM, direct_token_sampler\n",
"from genlm.eval import ModelOutput, ModelResponse\n",
"\n",
"# Load an LLM\n",
"LLM = PromptedLLM.from_name(\"gpt2\", eos_tokens=[b\"\\n\", b\"\\n\\n\"])\n",
"\n",
@@ -154,17 +218,14 @@
" LLM.model.tokenizer, instance, use_chat_format=False\n",
" )\n",
"\n",
" # Load a sampler that proposes tokens by sampling directly from the LM's distribution\n",
" sampler = direct_token_sampler(LLM)\n",
"\n",
"    # Run SMC with 5 particles and a maximum of 100 tokens\n",
"    sequences = await sampler.smc(\n",
"        n_particles=5,\n",
"        max_tokens=100,\n",
"        ess_threshold=0.0,\n",
"    )\n",
"\n",
" # Return the sampled sequences and their probabilities as a ModelOutput.\n",
Expand All @@ -180,27 +241,27 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Run the evaluation\n",
"\n",
"Using the dataset, evaluator, and model adaptor, we can now run the evaluation:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Instance instance_id=0 pattern='xy|xz'\n",
"Mean weighted accuracy (instance): 0.0\n",
"Mean weighted accuracy (total): 0.0\n",
"\n",
"Instance instance_id=1 pattern='ab|c(e|f)'\n",
"Mean weighted accuracy (instance): 0.0\n",
"Mean weighted accuracy (total): 0.0\n",
"\n"
]
}
@@ -220,11 +281,18 @@
" # output_dir=\"results\", # uncomment to save results\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "gen",
"language": "python",
"name": "python3"
},
@@ -238,7 +306,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
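As a sanity check on the prompt formatter introduced in this diff, its non-chat branch can be reproduced without a tokenizer. This sketch duplicates the notebook's constants and prints the exact prompt string the model is asked to complete (the `format_prompt` helper name is ours):

```python
# Constants copied from the notebook's default_prompt_formatter cell.
SYSTEM_PROMPT = (
    "You are a helpful assistant that generates strings matching regular expressions. "
    "Only output the exact string that matches the regex pattern, nothing more."
)

FEW_SHOT_EXAMPLES = [
    ("(ab)+", "ab"),
    ("(ab|cd)+", "cd"),
    ("[a-z]+", "hello"),
]


def format_prompt(pattern: str) -> str:
    # System prompt, then one "Pattern:/Output:" pair per few-shot
    # example, then the target pattern awaiting the model's output.
    few_shot = "\n".join(
        f"Pattern: {p}\nOutput: {o}" for p, o in FEW_SHOT_EXAMPLES
    )
    return SYSTEM_PROMPT + "\n" + few_shot + "\n" + pattern


print(format_prompt("xy|xz"))
```

The chat-format branch feeds the same pieces through `tokenizer.apply_chat_template` instead, so the model-specific chat markup is handled by the tokenizer rather than by hand.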