-
Notifications
You must be signed in to change notification settings - Fork 162
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1373 from mito-ds/ai-evals
Ai evals
- Loading branch information
Showing
13 changed files
with
228 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"python.analysis.typeCheckingMode": "basic", | ||
"mypy-type-checker.args": [ | ||
"--config-file=mypy.ini" | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# AI Evals | ||
|
||
## Running the tests | ||
|
||
1. Create a new virtual environment | ||
``` | ||
python -m venv venv | ||
``` | ||
|
||
2. Activate the virtual environment: | ||
``` | ||
source venv/bin/activate | ||
``` | ||
|
||
3. Install the dependencies: | ||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
4. Run the tests from the `mito` folder: | ||
TODO: Improve the running so that we don't have to be in the `mito` folder. | ||
``` | ||
python -m evals.main | ||
``` |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import os | ||
from openai import OpenAI | ||
|
||
def get_open_ai_completion(prompt: str): | ||
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) | ||
|
||
response = client.chat.completions.create( | ||
model="gpt-4", | ||
messages=[{"role": "user", "content": prompt}], | ||
temperature=0.0 | ||
) | ||
|
||
return response.choices[0].message.content |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, List, Literal | ||
|
||
@dataclass | ||
class NotebookState: | ||
"""Represents the state of variables in a notebook at test time""" | ||
global_vars: Dict[str, Any] | ||
cell_contents: List[str] | ||
|
||
|
||
@dataclass | ||
class TestCase: | ||
"""A single test case with input state and expected output""" | ||
name: str | ||
notebook_state: NotebookState | ||
user_input: str | ||
expected_code: str | ||
tags: List[Literal[ | ||
'variable declaration', | ||
'function declaration', | ||
'dataframe transformation' | ||
]] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from typing import List | ||
from evals.ai_api_calls.get_open_ai_completion import get_open_ai_completion | ||
from evals.prompts.simple_prompt import get_simple_prompt | ||
from evals.eval_types import NotebookState, TestCase | ||
from evals.utils import get_globals_to_compare, get_script_from_cells, print_green, print_red | ||
|
||
|
||
EMPTY_NOTEBOOK_STATE: NotebookState = NotebookState( | ||
global_vars={}, | ||
cell_contents=[] | ||
) | ||
|
||
INITIALIZED_VARIABLES_NOTEBOOK_STATE: NotebookState = NotebookState( | ||
global_vars={'x': 1, 'y': 2, 'z': 3}, | ||
cell_contents=['x = 1', 'y = 2', 'z = 3', ''] | ||
) | ||
|
||
|
||
TESTS: List[TestCase] = [ | ||
TestCase( | ||
name="empty_notebook_variable_declaration", | ||
notebook_state=EMPTY_NOTEBOOK_STATE, | ||
user_input="create a variable x and set it equal to 1", | ||
expected_code='x=1', | ||
tags=['variable declaration'] | ||
), | ||
TestCase( | ||
name="empty_notebook_function_declaration", | ||
notebook_state=EMPTY_NOTEBOOK_STATE, | ||
user_input="create a function my_sum that takes two arguments and returns their sum", | ||
expected_code="""def my_sum(a, b): | ||
return a + b""", | ||
tags=['function declaration'] | ||
), | ||
TestCase( | ||
name="initialized_variables_variable_declaration", | ||
notebook_state=INITIALIZED_VARIABLES_NOTEBOOK_STATE, | ||
user_input="create a new variable that is the product of x, y, and z", | ||
expected_code="w = x * y * z", | ||
tags=['variable declaration'] | ||
) | ||
] | ||
|
||
for test in TESTS: | ||
|
||
# Get the script from the cells | ||
current_cell_contents_script = get_script_from_cells(test.notebook_state.cell_contents) | ||
|
||
# Get the expected code script | ||
expected_code = current_cell_contents_script + "\n" + test.expected_code | ||
|
||
# Create the actual code script produced by the LLM | ||
prompt = get_simple_prompt(test.user_input, test.notebook_state) | ||
ai_generated_code = get_open_ai_completion(prompt) | ||
actual_code = current_cell_contents_script + "\n" + ai_generated_code | ||
|
||
# So that we can compare the results of the two scripts, create global context for | ||
# each script. When calling exec, the globals are updated in place. | ||
expected_globals = {} | ||
actual_globals = {} | ||
|
||
exec(expected_code, expected_globals) | ||
exec(actual_code, actual_globals) | ||
|
||
expected_globals = get_globals_to_compare(expected_globals) | ||
actual_globals = get_globals_to_compare(actual_globals) | ||
|
||
# TODO: Add statistics on how many tests pass/fail | ||
|
||
if expected_globals == actual_globals: | ||
print_green(f"Test {test.name} passed") | ||
else: | ||
print_red(f"Test {test.name} failed") | ||
print("Expected globals:") | ||
print(expected_globals) | ||
print("Actual globals:") | ||
print(actual_globals) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
[mypy] | ||
python_version = 3.8 | ||
warn_return_any = False | ||
warn_unused_configs = True | ||
disallow_untyped_defs = False | ||
disallow_incomplete_defs = False | ||
check_untyped_defs = True | ||
disallow_untyped_decorators = False | ||
no_implicit_optional = True | ||
warn_redundant_casts = True | ||
warn_unused_ignores = True | ||
warn_no_return = True | ||
warn_unreachable = True | ||
strict_optional = True | ||
ignore_missing_imports = True | ||
disable_error_code = var-annotated |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from evals.eval_types import NotebookState | ||
|
||
|
||
def get_simple_prompt(user_input: str, notebook_state: NotebookState) -> str: | ||
return f"""You are an expert python programmer. You are given a set of variables, existing code, and a task. | ||
Respond with the python code and nothing else. | ||
<Example> | ||
You have these variables: | ||
{{'x': 1, 'y': 2}} | ||
The current code cell is: | ||
x = 1 | ||
y = 2 | ||
Your job is to: | ||
Create a new variable z that is the sum of x and y | ||
Response: | ||
z = x + y | ||
</Example> | ||
Now complete this task: | ||
You have these variables: | ||
{notebook_state.global_vars} | ||
The current code cell is: | ||
{notebook_state.cell_contents[-1] if len(notebook_state.cell_contents) > 0 else ""} | ||
Your job is to: | ||
{user_input} | ||
Response:""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
mypy>=1.0.0 | ||
types-setuptools | ||
openai>=1.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from typing import List, Dict, Any | ||
|
||
|
||
def get_script_from_cells(cells: List[str]) -> str: | ||
return "\n".join(cells) | ||
|
||
def get_globals_to_compare(globals: Dict[str, Any]) -> Dict[str, Any]: | ||
""" | ||
Globals have a lot of stuff we don't actually care about comparing. | ||
For now, we only care about comparing variables created by the script. | ||
This functionremoves everything else | ||
""" | ||
|
||
globals = {k: v for k, v in globals.items() if k != "__builtins__"} | ||
|
||
# Remove functions from the globals since we don't want to compare them | ||
globals = {k: v for k, v in globals.items() if not callable(v)} | ||
|
||
return globals | ||
|
||
def print_green(text: str): | ||
print("\033[92m", end="") | ||
print(text) | ||
print("\033[0m", end="") | ||
|
||
def print_red(text: str): | ||
print("\033[91m", end="") | ||
print(text) | ||
print("\033[0m", end="") |
This file was deleted.
Oops, something went wrong.