diff --git a/docs/cookbook/cookbook.md b/docs/cookbook/cookbook.md index 98b816d..48d3524 100644 --- a/docs/cookbook/cookbook.md +++ b/docs/cookbook/cookbook.md @@ -19,6 +19,15 @@ Learn how to extend `genlm-eval` to evaluate models on your own custom domains. [View Example](custom_domains.ipynb) +### Custom Potentials +Learn how to implement custom potentials to encode constraints and evaluate models with constrained generation. This example walks through: + +- Implementing custom potentials to encode constraints +- Creating a model adaptor that uses constrained generation +- Running the evaluation + +[View Example](custom_potentials.ipynb) + ### Domain-Specific Examples #### Pattern Matching diff --git a/docs/cookbook/custom_domains.ipynb b/docs/cookbook/custom_domains.ipynb index 0f5fd5a..623d6ef 100644 --- a/docs/cookbook/custom_domains.ipynb +++ b/docs/cookbook/custom_domains.ipynb @@ -10,7 +10,9 @@ "\n", "1. Define your dataset\n", "2. Implement an evaluator\n", - "3. Implement a model adaptor\n", + "3. Implement a prompt formatter\n", + "4. Implement a model adaptor\n", + "5. Run the evaluation\n", "\n", "The following example demonstrates these steps on the pattern matching domain.\n" ] @@ -47,7 +49,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Given a dataset schema, you can define a dataset by subclassing `Dataset` and implementing an `__iter__` method which yields instances of the schema.\n" + "Given a dataset schema, you can define a dataset by subclassing `Dataset` and implementing an `__iter__` method which yields instances of the schema." ] }, { @@ -84,7 +86,6 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n", "## 2. Implement an evaluator\n", "\n", "An evaluator is the class responsible for scoring model outputs. Subclasses must minimally implement the `evaluate_sample` method which takes an instance and a response and returns an evaluation result." 
@@ -115,35 +116,98 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 3. Implement a model adaptor\n", + "## 3. Implement a prompt formatter\n", "\n", - "A model adaptor is an async callable that takes a `PatternMatchingInstance` and returns a `ModelOutput`. For this example, we'll use a constrained `genlm.control.PromptedLLM` to generate responses." + "A prompt formatter standardizes the input to the model, optionally prepending a system prompt and few-shot examples, and tokenizes it into prompt ids." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/homebrew/Caskroom/miniconda/base/envs/genlm/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "/opt/homebrew/Caskroom/miniconda/base/envs/genlm/lib/python3.11/site-packages/genlm/backend/tokenization/vocab.py:98: UserWarning: Duplicate tokens found in string vocabulary. This may lead to downstream issues with the string vocabulary; we recommend using the byte vocabulary.\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ - "from genlm.control import PromptedLLM, AWRS\n", - "from genlm.eval import ModelOutput, ModelResponse\n", - "from genlm.eval.domains.pattern_matching import (\n", - " default_prompt_formatter,\n", - " PatternPotential,\n", + "from genlm.eval.util import chat_template_messages\n", + "\n", + "\n", + "FEW_SHOT_EXAMPLES = [\n", + " (\"(ab)+\", \"ab\"),\n", + " (\"(ab|cd)+\", \"cd\"),\n", + " (\"[a-z]+\", \"hello\"),\n", + "]\n", + "\n", + "\n", + "SYSTEM_PROMPT = (\n", + " \"You are a helpful assistant that generates strings matching regular expressions. 
\"\n", + " + \"Only output the exact string that matches the regex pattern, nothing more.\"\n", ")\n", "\n", + "\n", + "def default_prompt_formatter(\n", + " tokenizer,\n", + " instance,\n", + " use_chat_format=False,\n", + " system_prompt=SYSTEM_PROMPT,\n", + " few_shot_examples=FEW_SHOT_EXAMPLES,\n", + "):\n", + " \"\"\"Default prompt formatter for pattern matching.\n", + "\n", + " Args:\n", + " tokenizer (Tokenizer): The tokenizer to use.\n", + " instance (PatternMatchingInstance): The instance to format.\n", + " use_chat_format (bool): Whether to use chat format.\n", + " system_prompt (str): The system prompt to use.\n", + " few_shot_examples (list[tuple[str, str]]): The few-shot examples to use. Each example is a tuple of (pattern, response).\n", + "\n", + " Returns:\n", + " (list[int]): The prompt ids.\n", + " \"\"\"\n", + " if use_chat_format:\n", + " return tokenizer.apply_chat_template(\n", + " chat_template_messages(\n", + " system_prompt,\n", + " few_shot_examples,\n", + " instance.pattern,\n", + " ),\n", + " tokenize=True,\n", + " add_generation_prompt=True,\n", + " )\n", + " else:\n", + " return tokenizer.encode(\n", + " (\n", + " system_prompt\n", + " + \"\\n\"\n", + " + \"\\n\".join(\n", + " f\"Pattern: {pattern}\\nOutput: {response}\"\n", + " for pattern, response in few_shot_examples\n", + " )\n", + " + \"\\n\"\n", + " + instance.pattern\n", + " )\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Implement a model adaptor\n", + "\n", + "A model adaptor is an async callable that takes a dataset instance (here, a `PatternMatchingInstance`) and returns a `ModelOutput`. \n", + "For this example, we'll use a `PromptedLLM` that proposes tokens by sampling directly from the LM's distribution.\n", + " \n", + "See [custom_potentials.ipynb](custom_potentials.ipynb) for a tutorial on how to implement custom constraints in genlm-control and evaluate the model."
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from genlm.control import PromptedLLM, direct_token_sampler\n", + "from genlm.eval import ModelOutput, ModelResponse\n", + "\n", "# Load an LLM\n", "LLM = PromptedLLM.from_name(\"gpt2\", eos_tokens=[b\"\\n\", b\"\\n\\n\"])\n", "\n", @@ -154,17 +218,14 @@ " LLM.model.tokenizer, instance, use_chat_format=False\n", " )\n", "\n", - " # Define a potential that ensures the generated text matches the pattern\n", - " potential = PatternPotential(instance.pattern).coerce(LLM, f=b\"\".join)\n", - "\n", - " # Define an adaptive weighted rejection sampler to sample tokens from the constrained model.\n", - " sampler = AWRS(LLM, potential)\n", + " # Load a sampler that proposes tokens by sampling directly from the LM's distribution\n", + " sampler = direct_token_sampler(LLM)\n", "\n", - " # Run SMC to sample sequences from the constrained model.\n", + " # Run SMC with 5 particles and a maximum of 100 tokens\n", " sequences = await sampler.smc(\n", " n_particles=5,\n", - " ess_threshold=0.5,\n", " max_tokens=100,\n", + " ess_threshold=0.0,\n", " )\n", "\n", " # Return the sampled sequences and their probabilities as a ModelOutput.\n", @@ -180,14 +241,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. Run the evaluation\n", + "## 5. 
Run the evaluation\n", "\n", "Using the dataset, evaluator, and model adaptor, we can now run the evaluation:" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -195,12 +256,12 @@ "output_type": "stream", "text": [ "Instance instance_id=0 pattern='xy|xz'\n", - "Mean weighted accuracy (instance): 0.9999999999999999\n", - "Mean weighted accuracy (total): 0.9999999999999999\n", + "Mean weighted accuracy (instance): 0.0\n", + "Mean weighted accuracy (total): 0.0\n", "\n", "Instance instance_id=1 pattern='ab|c(e|f)'\n", - "Mean weighted accuracy (instance): 1.0\n", - "Mean weighted accuracy (total): 1.0\n", + "Mean weighted accuracy (instance): 0.0\n", + "Mean weighted accuracy (total): 0.0\n", "\n" ] } @@ -220,11 +281,18 @@ " # output_dir=\"results\", # uncomment to save results\n", ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "genlm", + "display_name": "gen", "language": "python", "name": "python3" }, @@ -238,7 +306,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.12.11" } }, "nbformat": 4, diff --git a/docs/cookbook/custom_potentials.ipynb b/docs/cookbook/custom_potentials.ipynb new file mode 100644 index 0000000..3063ec4 --- /dev/null +++ b/docs/cookbook/custom_potentials.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Evaluate Models Using Custom Constraints\n", + "\n", + "This notebook shows how to evaluate models using custom constraints implemented with genlm-control. It builds upon the custom domain tutorial in [custom_domains.ipynb](custom_domains.ipynb), which covers creating dataset classes and evaluators.\n", + "\n", + "This tutorial covers:\n", + "\n", + "1. Implementing custom potentials to encode constraints\n", + "2. 
Implementing a model adaptor that uses constrained generation\n", + "3. Running the evaluation\n", + "\n", + "The following example demonstrates these steps on the pattern-matching domain, generating strings that conform to regex pattern specifications.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Implement custom potentials\n", + "\n", + "A potential encodes constraints or preferences by assigning non-negative weights to sequences of tokens. Potentials can be used as components of samplers, proposing new tokens at each step of the generation process, or serve as critics that reweight sequences based on whether they satisfy the encoded constraint.\n", + "\n", + "Each potential has a vocabulary that specifies the set of tokens it operates on. Potentials must implement the `prefix` and `complete` methods, which assign weights to partial and complete sequences, respectively. For a complete guide on implementing potentials, see the [GenLM Control documentation](https://genlm.org/genlm-control/potentials/).\n", + "\n", + "Here we implement a `PatternPotential` that checks whether sequences fully match or remain consistent with the pattern-matching specification."
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "import string\n", + "import regex\n", + "from genlm.control import Potential\n", + "\n", + "\n", + "class PatternPotential(Potential):\n", + " \"\"\"Potential function for regex pattern matching.\"\"\"\n", + "\n", + " def __init__(self, pattern):\n", + " vocab = list(map(ord, string.printable))\n", + " super().__init__(vocab)\n", + " self.r = regex.compile(pattern)\n", + "\n", + " async def complete(self, context):\n", + " text = \"\".join(map(chr, context))\n", + " match = self.r.fullmatch(text) is not None\n", + " return 0.0 if match else float(\"-inf\")\n", + "\n", + " async def prefix(self, context):\n", + " text = \"\".join(map(chr, context))\n", + " m = self.r.match(text, partial=True)\n", + " match = m is not None and m.start() == 0 and m.end() == len(text)\n", + " return 0.0 if match else float(\"-inf\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Implement a model adaptor\n", + "\n", + "A model adaptor is an async callable that takes a `PatternMatchingInstance` and returns a `ModelOutput`. For this example, we'll use a constrained `genlm.control.PromptedLLM` to generate responses." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "from genlm.control import PromptedLLM, AWRS\n", + "from genlm.eval import ModelOutput, ModelResponse\n", + "from genlm.eval.domains.pattern_matching import (\n", + " default_prompt_formatter,\n", + ")\n", + "\n", + "# Load an LLM\n", + "LLM = PromptedLLM.from_name(\"gpt2\", eos_tokens=[b\"\\n\", b\"\\n\\n\"])\n", + "\n", + "\n", + "async def model(instance, output_dir, replicate):\n", + " # Set the prompt for the LLM.\n", + " LLM.prompt_ids = default_prompt_formatter(\n", + " LLM.model.tokenizer, instance, use_chat_format=False\n", + " )\n", + "\n", + " # Define a potential that ensures the generated text matches the pattern\n", + " potential = PatternPotential(instance.pattern).coerce(LLM, f=b\"\".join)\n", + "\n", + " # Define an adaptive weighted rejection sampler to sample tokens from the constrained model.\n", + " sampler = AWRS(LLM, potential)\n", + "\n", + " # Run SMC to sample sequences from the constrained model.\n", + " sequences = await sampler.smc(\n", + " n_particles=5,\n", + " ess_threshold=0.5,\n", + " max_tokens=100,\n", + " )\n", + "\n", + " # Return the sampled sequences and their probabilities as a ModelOutput.\n", + " return ModelOutput(\n", + " responses=[\n", + " ModelResponse(response=sequence, weight=prob)\n", + " for sequence, prob in sequences.decoded_posterior.items()\n", + " ],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Run the evaluation\n", + "\n", + "Using the dataset, evaluator, potential, and model adaptor, we can now run the evaluation:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Instance instance_id=0 pattern='xy|xz'\n", + "Mean weighted accuracy (instance): 1.0\n", + "Mean weighted accuracy (total): 1.0\n", + "\n", + "Instance instance_id=1 pattern='ab|c(e|f)'\n", + "Mean weighted accuracy (instance): 1.0\n", + "Mean weighted accuracy (total): 1.0\n", + "\n" + ] + } + ], + "source": [ + "from genlm.eval import run_evaluation\n", + "from genlm.eval.domains.pattern_matching import (\n", + " PatternMatchingDataset,\n", + " PatternMatchingEvaluator,\n", + ")\n", + "\n", + "dataset = PatternMatchingDataset([r\"xy|xz\", r\"ab|c(e|f)\"])\n", + "evaluator = PatternMatchingEvaluator()\n", + "\n", + "results = await run_evaluation(\n", + " dataset=dataset,\n", + " evaluator=evaluator,\n", + " model=model,\n", + " n_replicates=1,\n", + " verbosity=1,\n", + " # output_dir=\"results\", # uncomment to save results\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "gen", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/mkdocs.yml b/mkdocs.yml index 5cc6ac9..f2bfd30 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -39,6 +39,7 @@ nav: - Text to SQL (Spider): cookbook/domains/spider.ipynb - Molecular Synthesis: cookbook/domains/molecular_synthesis.ipynb - Custom Domains: cookbook/custom_domains.ipynb + - Custom Potentials: cookbook/custom_potentials.ipynb markdown_extensions: - pymdownx.highlight:
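The `prefix`/`complete` contract that the new `PatternPotential` implements can be illustrated outside the notebooks with a stdlib-only sketch. The snippet below is a hypothetical illustration, not part of genlm-control: it specializes the contract to the finite language of the first test pattern, `xy|xz`, where `ACCEPTED`, `complete`, and `prefix` are names introduced here for exposition.

```python
import math

# Hypothetical stdlib-only stand-in for a genlm-control Potential,
# specialized to the pattern "xy|xz" used in the evaluation above.
# Weights are in log space: 0.0 (= log 1) means "allowed" and
# -inf (= log 0) means "ruled out".
ACCEPTED = ("xy", "xz")  # the finite language of the pattern "xy|xz"


def complete(text: str) -> float:
    # Log weight 0.0 iff the whole string matches the pattern.
    return 0.0 if text in ACCEPTED else -math.inf


def prefix(text: str) -> float:
    # Log weight 0.0 iff the string can still be extended to a full match.
    return 0.0 if any(s.startswith(text) for s in ACCEPTED) else -math.inf


print(prefix("x"), complete("x"))   # "x" is a viable prefix but not complete
print(prefix("y"), complete("xy"))  # "y" can never match; "xy" is complete
```

The real `PatternPotential` obtains the same semantics for arbitrary regexes: `complete` via `regex.fullmatch`, and `prefix` via the `regex` module's partial matching (`match(..., partial=True)`).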