Commit 0b4a25f

Add an example run and update README.md
1 parent fb014e5 commit 0b4a25f

4 files changed: +4243 -3 lines changed

README.md: +195 -1

@@ -17,7 +17,201 @@ It uses the doc strings, type annotations, and method/function names as prompts
- **Github repository**: <https://github.com/blackhc/llm-strategy/>
- **Documentation** <https://blackhc.github.io/llm-strategy/>

## Research Example

The latest version also includes a package for hyperparameter tracking and collecting traces from LLMs.

This allows, for example, for meta-optimization. See examples/research for a simple implementation using Generics.

You can find an example WandB trace at: https://wandb.ai/blackhc/blackboard-pagi/reports/Meta-Optimization-Example-Trace--Vmlldzo3MDMxODEz?accessToken=p9hubfskmq1z5yj1uz7wx1idh304diiernp7pjlrjrybpaozlwv3dnitjt7vni1j

The prompts showing off the pattern using Generics are straightforward:

```python
# Imports as used in examples/research/meta_optimization.py
from typing import Generic, TypeVar

from pydantic import BaseModel, Field
from pydantic.generics import GenericModel

from llm_strategy.llm_function import llm_explicit_function

T_TaskParameters = TypeVar("T_TaskParameters")
T_TaskResults = TypeVar("T_TaskResults")
T_Hyperparameters = TypeVar("T_Hyperparameters")


class TaskRun(GenericModel, Generic[T_TaskParameters, T_TaskResults, T_Hyperparameters]):
    """
    The task run. This is the 'data' we use to optimize the hyperparameters.
    """

    task_parameters: T_TaskParameters = Field(..., description="The task parameters.")
    hyperparameters: T_Hyperparameters = Field(
        ...,
        description="The hyperparameters used for the task. We optimize these.",
    )
    all_chat_chains: dict = Field(..., description="The chat chains from the task execution.")
    return_value: T_TaskResults | None = Field(
        ..., description="The results of the task. (None for exceptions/failure.)"
    )
    exception: list[str] | str | None = Field(..., description="Exception that occurred during the task execution.")


class TaskReflection(BaseModel):
    """
    The reflections on the task.

    This contains the lessons we learn from each task run to come up with better
    hyperparameters to try.
    """

    feedback: str = Field(
        ...,
        description=(
            "Only look at the final results field. Does its content satisfy the "
            "task description and task parameters? Does it contain all the relevant "
            "information from the all_chains and all_prompts fields? What could be improved "
            "in the results?"
        ),
    )
    evaluation: str = Field(
        ...,
        description=(
            "The evaluation of the outputs given the task. Is the output satisfying? What is wrong? What is missing?"
        ),
    )
    hyperparameter_suggestion: str = Field(
        ...,
        description="How we want to change the hyperparameters to improve the results. What could we try to change?",
    )
    hyperparameter_missing: str = Field(
        ...,
        description=(
            "What hyperparameters are missing to improve the results? What could "
            "be changed that is not exposed via hyperparameters?"
        ),
    )


class TaskInfo(GenericModel, Generic[T_TaskParameters, T_TaskResults, T_Hyperparameters]):
    """
    The task run and the reflection on the experiment.
    """

    task_parameters: T_TaskParameters = Field(..., description="The task parameters.")
    hyperparameters: T_Hyperparameters = Field(
        ...,
        description="The hyperparameters used for the task. We optimize these.",
    )
    reflection: TaskReflection = Field(..., description="The reflection on the task.")


class OptimizationInfo(GenericModel, Generic[T_TaskParameters, T_TaskResults, T_Hyperparameters]):
    """
    The optimization information. This is the data we use to optimize the
    hyperparameters.
    """

    older_task_summary: str | None = Field(
        None,
        description=(
            "A summary of previous experiments and the proposed changes with "
            "the goal of avoiding trying the same changes repeatedly."
        ),
    )
    task_infos: list[TaskInfo[T_TaskParameters, T_TaskResults, T_Hyperparameters]] = Field(
        ..., description="The most recent tasks we have run and our reflections on them."
    )
    best_hyperparameters: T_Hyperparameters = Field(..., description="The best hyperparameters we have found so far.")


class OptimizationStep(GenericModel, Generic[T_TaskParameters, T_TaskResults, T_Hyperparameters]):
    """
    The next optimization step: new hyperparameters we want to try in experiments and new
    task parameters we want to evaluate on, given the previous experiments.
    """

    best_hyperparameters: T_Hyperparameters = Field(
        ...,
        description="The best hyperparameters we have found so far given task_infos and history.",
    )
    suggestion: str = Field(
        ...,
        description=(
            "The suggestions for the next experiments. What could we try to "
            "change? We will try several tasks next and several sets of hyperparameters. "
            "Let's think step by step."
        ),
    )
    task_parameters_suggestions: list[T_TaskParameters] = Field(
        ...,
        description="The task parameters we want to try next.",
        hint_min_items=1,
        hint_max_items=4,
    )
    hyperparameter_suggestions: list[T_Hyperparameters] = Field(
        ...,
        description="The hyperparameters we want to try next.",
        hint_min_items=1,
        hint_max_items=2,
    )


class ImprovementProbability(BaseModel):
    considerations: list[str] = Field(..., description="The considerations for potential improvements.")
    probability: float = Field(..., description="The probability of improvement.")


class LLMOptimizer:
    @llm_explicit_function
    @staticmethod
    def reflect_on_task_run(
        language_model,
        task_run: TaskRun[T_TaskParameters, T_TaskResults, T_Hyperparameters],
    ) -> TaskReflection:
        """
        Reflect on the results given the task parameters and hyperparameters.

        This contains the lessons we learn from each task run to come up with better
        hyperparameters to try.
        """
        raise NotImplementedError()

    @llm_explicit_function
    @staticmethod
    def summarize_optimization_info(
        language_model,
        optimization_info: OptimizationInfo[T_TaskParameters, T_TaskResults, T_Hyperparameters],
    ) -> str:
        """
        Summarize the optimization info. We want to preserve all relevant knowledge for
        improving the hyperparameters in the future. All information from previous
        experiments will be forgotten except for what is in this summary.
        """
        raise NotImplementedError()

    @llm_explicit_function
    @staticmethod
    def suggest_next_optimization_step(
        language_model,
        optimization_info: OptimizationInfo[T_TaskParameters, T_TaskResults, T_Hyperparameters],
    ) -> OptimizationStep[T_TaskParameters, T_TaskResults, T_Hyperparameters]:
        """
        Suggest the next optimization step.
        """
        raise NotImplementedError()

    @llm_explicit_function
    @staticmethod
    def probability_for_improvement(
        language_model,
        optimization_info: OptimizationInfo[T_TaskParameters, T_TaskResults, T_Hyperparameters],
    ) -> ImprovementProbability:
        """
        Return the probability for improvement (between 0 and 1).

        This is your confidence that your next optimization steps will improve the
        hyperparameters given the information provided. If you think that the
        information available is unlikely to lead to better hyperparameters, return 0.
        If you think that the information available is very likely to lead to better
        hyperparameters, return 1. Be concise.
        """
        raise NotImplementedError()
```
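
For concreteness, here is a minimal sketch of how these generic models might be specialized for a toy summarization task. The concrete parameter, result, and hyperparameter models below are made up for illustration, and the final call assumes that a function decorated with `@llm_explicit_function` can be invoked with a chat model as its first `language_model` argument, as the signatures above suggest:

```python
# Hypothetical specialization for illustration only; the summarization task and
# its models are not part of the library.
from langchain.chat_models import ChatOpenAI
from pydantic import BaseModel, Field


class SummarizationParams(BaseModel):
    text: str = Field(..., description="The text to summarize.")


class SummarizationHyperparams(BaseModel):
    num_bullet_points: int = Field(3, description="How many bullet points the summary should contain.")


class SummarizationResult(BaseModel):
    summary: str = Field(..., description="The produced summary.")


# Specialize the generic models for the concrete task.
SummarizationTaskRun = TaskRun[SummarizationParams, SummarizationResult, SummarizationHyperparams]

task_run = SummarizationTaskRun(
    task_parameters=SummarizationParams(text="..."),
    hyperparameters=SummarizationHyperparams(num_bullet_points=3),
    all_chat_chains={},
    return_value=SummarizationResult(summary="..."),
    exception=None,
)

# Assumption: the decorated prompts are called with the chat model as the first
# argument, mirroring their `language_model` parameter.
chat_model = ChatOpenAI()  # configured as in examples/research/meta_optimization.py
reflection = LLMOptimizer.reflect_on_task_run(chat_model, task_run)
print(reflection.hyperparameter_suggestion)
```

The same specialization works for `OptimizationInfo` and `OptimizationStep`, so `suggest_next_optimization_step` can drive the whole optimization loop over arbitrary task and hyperparameter types.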

## Application Example

```python
from dataclasses import dataclass
```
Binary file not shown.

examples/research/meta_optimization.py: +8 -2

```diff
@@ -7,6 +7,7 @@
 from typing import Generic, TypeVar

 import langchain
+import wandb
 from langchain.cache import SQLiteCache
 from langchain.chat_models import ChatOpenAI
 from langchain.chat_models.base import BaseChatModel
@@ -15,8 +16,9 @@
 from openai import OpenAIError
 from pydantic import BaseModel, Field
 from pydantic.generics import GenericModel
+from rich.console import Console
+from rich.traceback import install

-import wandb
 from llm_hyperparameters.track_execution import (
     LangchainInterface,
     TrackedChatModel,
@@ -31,6 +33,9 @@
 from llm_strategy.chat_chain import ChatChain
 from llm_strategy.llm_function import LLMFunction, llm_explicit_function, llm_function

+install(show_locals=True, width=190, console=Console(width=190, color_system="truecolor", stderr=True))
+
+
 langchain.llm_cache = SQLiteCache(".optimization_unit.langchain.db")

 chat_model = ChatOpenAI(
@@ -69,7 +74,7 @@ class TaskRun(GenericModel, Generic[T_TaskParameters, T_TaskResults, T_Hyperparameters]):
     return_value: T_TaskResults | None = Field(
         ..., description="The results of the task. (None for exceptions/failure.)"
     )
-    exception: str | None = Field(..., description="Exception that occurred during the task execution.")
+    exception: list[str] | str | None = Field(..., description="Exception that occurred during the task execution.")


 class TaskReflection(BaseModel):
@@ -408,6 +413,7 @@ def get_json_trace_filename(title: str) -> str:
     TraceViewerIntegration(),
 ]

+
 with wandb_tracer(
     "BBO",
     module_filters="__main__",
```
