Skip to content

Commit 8ad504c

Browse files
authored
Merge pull request #1379 from mito-ds/improve-ai-chat-prompt
Improve ai chat prompt
2 parents c1724d0 + a7a0367 commit 8ad504c

11 files changed

+284
-33
lines changed

evals/ai_api_calls/get_open_ai_completion.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import Any, Dict, Optional
23
from openai import OpenAI
34

45
def get_open_ai_completion(prompt: str):
@@ -9,5 +10,22 @@ def get_open_ai_completion(prompt: str):
910
messages=[{"role": "user", "content": prompt}],
1011
temperature=0.0
1112
)
13+
14+
response_content = response.choices[0].message.content
15+
return get_code_block_from_message(response_content)
16+
17+
18+
19+
def get_code_block_from_message(message: str) -> str:
20+
"""
21+
Extract the first code block from a message. A code block is a block of
22+
text that starts with ```python and ends with ```.
23+
"""
24+
print(f"Message: {message}")
25+
26+
# If ```python is not part of the message, then we assume that the
27+
# entire message is the code block
28+
if "```python" not in message:
29+
return message
1230

13-
return response.choices[0].message.content
31+
return message.split('```python')[1].split('```')[0]

evals/main.py

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
# Create the actual code script produced by the LLM
6161
prompt = prompt_generator.get_prompt(test.user_input, test.notebook_state)
6262
ai_generated_code = get_open_ai_completion(prompt)
63+
print(f"AI generated code:\n{ai_generated_code}")
6364
actual_code = current_cell_contents_script + "\n" + ai_generated_code
6465

6566
# So that we can compare the results of the two scripts, create global context for

evals/notebook_states.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@
8383
'excel_transactions': pd.DataFrame({'Transaction ID': [12975, 16889, 57686, 53403, 42699], 'Share Quantity': [20, 25, 24, 22, 40]}),
8484
'excel_transactions': pd.DataFrame({'Transaction ID': [12975, 16889, 57686, 53403, 42699], 'Share Quantity': [20, 25, 24, 22, 0]})},
8585
cell_contents=["""import pandas as pd
86-
excel_transactions = pd.read_excel('evals/data/simple_recon/transactions_excel.csv')
87-
eagle_transactions = pd.read_excel('evals/data/simple_recon/transactions_eagle.csv')
86+
excel_transactions = pd.read_csv('evals/data/simple_recon/transactions_excel.csv')
87+
eagle_transactions = pd.read_csv('evals/data/simple_recon/transactions_eagle.csv')
8888
""", '']
8989
)
9090

evals/prompts/__init__.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from evals.prompts.single_shot_prompt import single_shot_prompt_generator
22
from evals.prompts.multi_shot_prompt import multi_shot_prompt_generator
3-
3+
from evals.prompts.production_prompt_v1 import production_prompt_v1_generator
4+
from evals.prompts.production_prompt_v2 import production_prompt_v2_generator
45
PROMPT_GENERATORS = [
5-
single_shot_prompt_generator,
6-
multi_shot_prompt_generator
7-
]
6+
#single_shot_prompt_generator,
7+
#multi_shot_prompt_generator,
8+
production_prompt_v1_generator,
9+
production_prompt_v2_generator
10+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from evals.eval_types import NotebookState, PromptGenerator
2+
3+
__all__ = ['multi_shot_pandas_focussed_prompt']
4+
5+
class _MultiShotPandasFocussedPrompt(PromptGenerator):
6+
prompt_name = "multi_shot_pandas_focussed_prompt"
7+
8+
def get_prompt(self, user_input: str, notebook_state: NotebookState) -> str:
9+
10+
return f"""You are an expert python programmer writing a script in a Jupyter notebook. You are given a set of variables, existing code, and a task.
11+
12+
Respond with the updated active code cell and a short explanation of the changes you made.
13+
14+
When responding:
15+
- Do not use the word "I"
16+
- Do not recreate variables that already exist
17+
- Keep as much of the original code as possible
18+
19+
<Example 1>
20+
21+
Defined Variables:
22+
{{
23+
'loan_multiplier': 1.5,
24+
'sales_df': pd.DataFrame({{
25+
'transaction_date': ['2024-01-02', '2024-01-02', '2024-01-02', '2024-01-02', '2024-01-03'],
26+
'price_per_unit': [10, 9.99, 13.99, 21.00, 100],
27+
'units_sold': [1, 2, 1, 4, 5],
28+
'total_price': [10, 19.98, 13.99, 84.00, 500]
29+
}})
30+
}}
31+
32+
Code in the active code cell:
33+
```python
34+
import pandas as pd
35+
sales_df = pd.read_csv('./sales.csv')
36+
```
37+
38+
Your task: convert the transaction_date column to datetime and then multiply the total_price column by the sales_multiplier.
39+
40+
Output:
41+
42+
```python
43+
import pandas as pd
44+
sales_df = pd.read_csv('./sales.csv')
45+
sales_df['transaction_date'] = pd.to_datetime(sales_df['transaction_date'])
46+
sales_df['total_price'] = sales_df['total_price'] * sales_multiplier
47+
```
48+
49+
Converted the `transaction_date` column to datetime using the built-in pd.to_datetime function and multiplied the `total_price` column by the `sales_multiplier` variable.
50+
</Example 1>
51+
52+
<Example 2>
53+
Defined Variables:
54+
{{
55+
'df': pd.DataFrame({{
56+
'id': ['id-49830', 'id-39301', 'id-85011', 'id-51892', 'id-99111'],
57+
'name': ['Tamir', 'Aaron', 'Grace', 'Nawaz', 'Julia'],
58+
'age': [29, 31, 26, 21, 30],
59+
'dob': ['1994-06-15', '1992-03-27', '1997-04-11', '2002-07-05', '1993-08-22'],
60+
'city': ['San Francisco', 'New York', 'Los Angeles', 'Chicago', 'Houston'],
61+
'state': ['CA', 'NY', 'CA', 'IL', 'TX'],
62+
'zip': ['94103', '10001', '90038', '60611', '77002'],
63+
'start_date': ['2024-01-01', '2024-01-01', '2024-01-01', '2024-01-01', '2024-01-01'],
64+
'department': ['Engineering', 'Sales', 'Marketing', 'Operations', 'Finance'],
65+
'salary': ['$100,000', '$50,000', '$60,000', '$55,000', '$70,000']
66+
}})
67+
}}
68+
69+
Code in the active code cell:
70+
```python
71+
72+
```
73+
74+
Your task: Calculate the weekly salary for each employee.
75+
76+
Output:
77+
78+
```python
79+
df['salary'] = df['salary'].str[1:].replace(',', '', regex=True).astype('float')
80+
df['weekly_salary'] = df['salary'] / 52
81+
```
82+
83+
Remove the `$` and `,` from the `salary` in order to convert it to a float. Then, divide the salary by 52 to get the weekly salary.
84+
</Example 2>
85+
86+
Defined Variables:
87+
{notebook_state.global_vars}
88+
89+
Code in the active code cell:
90+
91+
```python
92+
{notebook_state.cell_contents[-1] if len(notebook_state.cell_contents) > 0 else ""}
93+
```
94+
95+
Your task: ${user_input}"""
96+
97+
multi_shot_pandas_focussed_prompt = _MultiShotPandasFocussedPrompt()

evals/prompts/multi_shot_prompt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class _MultiShotPromptGenerator(PromptGenerator):
88
def get_prompt(self, user_input: str, notebook_state: NotebookState) -> str:
99
return f"""You are an expert python programmer. You are given a set of variables, existing code, and a task.
1010
11-
Respond with the python code and nothing else.
11+
Respond with the python code that starts with ```python and ends with ```. Do not return anything else.
1212
1313
<Example 1>
1414
You have these variables:

evals/prompts/production_prompt_v1.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from evals.eval_types import NotebookState, PromptGenerator
2+
3+
__all__ = ['production_prompt_v1_generator']
4+
5+
class _ProductionPromptV1(PromptGenerator):
6+
prompt_name = "production_prompt_v1"
7+
8+
def get_prompt(self, user_input: str, notebook_state: NotebookState) -> str:
9+
10+
return f"""You have access to the following variables:
11+
12+
{notebook_state.global_vars}
13+
14+
Complete the task below. Decide what variables to use and what changes you need to make to the active code cell. Only return the full new active code cell and a concise explanation of the changes you made.
15+
16+
<Reminders>
17+
18+
Do not:
19+
- Use the word "I"
20+
- Include multiple approaches in your response
21+
- Recreate variables that already exist
22+
23+
Do:
24+
- Use the variables that you have access to
25+
- Keep as much of the original code as possible
26+
- Ask for more context if you need it.
27+
28+
</Reminders>
29+
30+
<Example>
31+
32+
Code in the active code cell:
33+
34+
```python
35+
import pandas as pd
36+
loans_df = pd.read_csv('./loans.csv')
37+
```
38+
39+
Your task: convert the issue_date column to datetime.
40+
41+
Output:
42+
43+
```python
44+
import pandas as pd
45+
loans_df = pd.read_csv('./loans.csv')
46+
loans_df['issue_date'] = pd.to_datetime(loans_df['issue_date'])
47+
```
48+
49+
Use the pd.to_datetime function to convert the issue_date column to datetime.
50+
51+
</Example>
52+
53+
Code in the active code cell:
54+
55+
```python
56+
{notebook_state.cell_contents[-1] if len(notebook_state.cell_contents) > 0 else ""}
57+
```
58+
59+
Your task: ${user_input}"""
60+
61+
production_prompt_v1_generator = _ProductionPromptV1()

evals/prompts/production_prompt_v2.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from evals.eval_types import NotebookState, PromptGenerator
2+
3+
__all__ = ['production_prompt_v2_generator']
4+
5+
class _ProductionPromptV2(PromptGenerator):
6+
prompt_name = "production_prompt_v2"
7+
8+
def get_prompt(self, user_input: str, notebook_state: NotebookState) -> str:
9+
10+
return f"""You are an expert python programmer writing a script in a Jupyter notebook. You are given a set of variables, existing code, and a task.
11+
12+
Respond with the updated active code cell and a short explanation of the changes you made.
13+
14+
When responding:
15+
- Do not use the word "I"
16+
- Do not recreate variables that already exist
17+
- Keep as much of the original code as possible
18+
19+
<Example>
20+
21+
Defined Variables:
22+
{{
23+
'loan_multiplier': 1.5,
24+
'sales_df': pd.DataFrame({{
25+
'transaction_date': ['2024-01-02', '2024-01-02', '2024-01-02', '2024-01-02', '2024-01-03'],
26+
'price_per_unit': [10, 9.99, 13.99, 21.00, 100],
27+
'units_sold': [1, 2, 1, 4, 5],
28+
'total_price': [10, 19.98, 13.99, 84.00, 500]
29+
}})
30+
}}
31+
32+
Code in the active code cell:
33+
```python
34+
import pandas as pd
35+
sales_df = pd.read_csv('./sales.csv')
36+
```
37+
38+
Your task: convert the transaction_date column to datetime and then multiply the total_price column by the sales_multiplier.
39+
40+
Output:
41+
42+
```python
43+
import pandas as pd
44+
sales_df = pd.read_csv('./sales.csv')
45+
sales_df['transaction_date'] = pd.to_datetime(sales_df['transaction_date'])
46+
sales_df['total_price'] = sales_df['total_price'] * sales_multiplier
47+
```
48+
49+
Converted the `transaction_date` column to datetime using the built-in pd.to_datetime function and multiplied the `total_price` column by the `sales_multiplier` variable.
50+
51+
</Example>
52+
53+
Defined Variables:
54+
{notebook_state.global_vars}
55+
56+
Code in the active code cell:
57+
58+
```python
59+
{notebook_state.cell_contents[-1] if len(notebook_state.cell_contents) > 0 else ""}
60+
```
61+
62+
Your task: ${user_input}"""
63+
64+
production_prompt_v2_generator = _ProductionPromptV2()

evals/prompts/single_shot_prompt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class _SingleShotPromptGenerator(PromptGenerator):
88
def get_prompt(self, user_input: str, notebook_state: NotebookState) -> str:
99
return f"""You are an expert python programmer. You are given a set of variables, existing code, and a task.
1010
11-
Respond with the python code and nothing else.
11+
Respond with the python code that starts with ```python and ends with ```. Do not return anything else.
1212
1313
<Example>
1414
You have these variables:

evals/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,4 @@ def are_globals_equal(globals1: Dict[str, Any], globals2: Dict[str, Any]) -> boo
8686
if var_one != var_two:
8787
return False
8888

89-
return True
89+
return True

0 commit comments

Comments
 (0)