Skip to content

Commit 3af612c

Browse files
authored
Make dataset generation work with preferences & seed file (#49)
* pass along description * Start adding preferences * Add human preferences to ds generation * Remove rich prompt typo * Fix up seed file reading * Refactor * Add method for generating full dataset * Ask how many dataset examples * Fix typo & comment out client method
1 parent e84081c commit 3af612c

File tree

2 files changed

+143
-77
lines changed

2 files changed

+143
-77
lines changed

quotientai/cli/generate/dataset.py

Lines changed: 96 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
11
import json
2+
import os
23
import time
34

5+
from pathlib import Path
6+
from typing import List, Optional
7+
48
from quotientai._enums import GenerateDatasetType
59
from quotientai.client import QuotientClient
610
from rich import print
711
from rich.console import Console
812
from rich.panel import Panel
913
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
1014
from rich.prompt import Confirm, IntPrompt, Prompt
15+
from rich.table import Table
1116

1217
console = Console()
18+
client = QuotientClient()
1319

14-
from rich.table import Table
1520

16-
17-
def view_examples(graded_examples):
21+
def show_graded_examples(graded_examples):
1822
table = Table(show_header=True, header_style="bold magenta")
1923
table.add_column("ID")
2024
table.add_column("Context")
@@ -36,55 +40,59 @@ def view_examples(graded_examples):
3640
console.print(table)
3741

3842

39-
def get_seed_data(seed: str):
43+
def get_seed_data(seed: str) -> List[Optional[str]]:
4044
if seed is None:
4145
seed_path = Confirm.ask(
42-
"Do you have a seed file (.jsonl) with examples to assist the creation of the dataset?",
46+
"Do you have a seed file ([green].jsonl[/green]) with examples to assist the creation of the dataset?",
4347
)
4448
if not seed_path:
45-
seed_data = "Here is some fake data REPLACE ME"
4649
console.print(
4750
"No problem! We'll generate some examples for you, and you can grade them\n"
4851
)
49-
return
52+
return []
5053
else:
51-
valid_format = False
52-
while not valid_format:
53-
filepath = Prompt.ask("Please provide the path to the seed file.")
54-
55-
if filepath.endswith(".jsonl") or filepath.endswith(".jsonlines"):
56-
valid_format = True
54+
valid_file = False
55+
while not valid_file:
56+
filepath = Prompt.ask("Please provide the path to the seed file")
57+
58+
is_valid_format = filepath.endswith(".jsonl") or filepath.endswith(".jsonlines")
59+
is_valid_path = os.path.exists(filepath)
60+
if is_valid_format and is_valid_path:
61+
valid_file = True
5762
else:
5863
console.print(
59-
"The seed file should be in the .jsonl format. Please provide a valid file."
64+
f"[yellow]Please provide a valid .jsonl file. Got {filepath}"
6065
)
6166
else:
6267
filepath = seed
6368

6469
try:
6570
with open(filepath, "r") as file:
66-
seed_data = file.readlines()
67-
seed_data = [json.loads(seed) for seed in seed_data]
71+
raw_data = file.readlines()
72+
raw_data = [json.loads(line) for line in raw_data]
6873
except FileNotFoundError:
6974
console.print("The file could not be found. Please provide a valid file.")
7075

71-
valid_field = False
72-
# check that we can get the field name by looking at the first example
73-
while not valid_field:
74-
field = Prompt.ask(
75-
"Please indicate the field in the JSONL file that contain an example to use as a seed."
76+
valid_field = False
77+
# check that we can get the field name by looking at the first example
78+
while not valid_field:
79+
available_fields = list(raw_data[0].keys())
80+
field = Prompt.ask(
81+
"Please indicate the field in the file that contains examples to use as a seed. "
82+
f"Available fields: [magenta]{available_fields}"
83+
)
84+
if field not in raw_data[0]:
85+
console.print(
86+
f"The field '{field}' is not present in the seed file."
7687
)
77-
if field not in seed_data[0]:
78-
console.print(
79-
f"The field '{field}' is not present in the seed file. Please provide a valid field."
80-
)
81-
else:
82-
valid_field = True
88+
else:
89+
valid_field = True
8390

84-
console.print("Here is an example from the seed file:")
85-
seed_one = seed_data[0][field]
86-
console.print(seed_one)
91+
console.print("\nHere is an example from the seed file:")
92+
seed_one = raw_data[0][field]
93+
console.print(f"[green] {seed_one}\n")
8794

95+
seed_data = [line[field] for line in raw_data]
8896
return seed_data
8997

9098

@@ -103,7 +111,6 @@ def grade_examples(
103111
task = progress.add_task("generation", total=0)
104112

105113
# Step 4
106-
client = QuotientClient()
107114
examples = client.generate_examples(
108115
generation_type=generation_type,
109116
description=description,
@@ -173,6 +180,36 @@ def grade_examples(
173180
console.print()
174181
return data
175182

183+
def select_next_action():
184+
next_action_choices = {
185+
1: {
186+
"type": "Generate more examples",
187+
"description": "Continue grading more examples.",
188+
},
189+
2: {
190+
"type": "View graded examples",
191+
"description": "View the graded examples.",
192+
},
193+
3: {
194+
"type": "Stop grading and generate the dataset",
195+
"description": "Stop grading and generate the dataset.",
196+
},
197+
}
198+
199+
console.print("What would you like to do next?")
200+
for index, choice in next_action_choices.items():
201+
console.print(
202+
f"[magenta]{index}[/magenta]. {choice['type']}: [white]{choice['description']}[/white]",
203+
style="yellow",
204+
)
205+
206+
console.print()
207+
next_action = IntPrompt.ask(
208+
"Choose an option",
209+
choices=[str(index) for index in next_action_choices.keys()],
210+
)
211+
return next_action
212+
176213

177214
def generation_workflow(seed: str = None):
178215
"""
@@ -236,11 +273,13 @@ def generation_workflow(seed: str = None):
236273
)
237274

238275
description = Prompt.ask(
239-
"[bold]Please describe in detail what the context is like[/bold]"
276+
"[bold]Please describe in detail the context of your problem[/bold]"
240277
)
241278

242279
# if the seed is not provided, ask the user if they have a seed file
243-
seed_data = get_seed_data(seed=seed)
280+
seed_data: Optional[List[str]] = get_seed_data(seed=seed)
281+
if seed_data:
282+
seed_data = seed_data[0]
244283

245284
graded_examples = []
246285
preferences = []
@@ -276,33 +315,7 @@ def generation_workflow(seed: str = None):
276315
f"For better results, we recommend grading [red]5 to 10[/red] examples.\n"
277316
)
278317

279-
next_action_choices = {
280-
1: {
281-
"type": "Generate more examples",
282-
"description": "Continue grading more examples.",
283-
},
284-
2: {
285-
"type": "View graded examples",
286-
"description": "View the graded examples.",
287-
},
288-
3: {
289-
"type": "Stop grading and generate the dataset",
290-
"description": "Stop grading and generate the dataset.",
291-
},
292-
}
293-
294-
console.print("What would you like to do next?")
295-
for index, choice in next_action_choices.items():
296-
console.print(
297-
f"[magenta]{index}[/magenta]. {choice['type']}: [white]{choice['description']}[/white]",
298-
style="yellow",
299-
)
300-
301-
console.print()
302-
next_action = IntPrompt.ask(
303-
"Choose an option",
304-
choices=[str(index) for index in next_action_choices.keys()],
305-
)
318+
next_action = select_next_action()
306319

307320
if next_action == 1:
308321
# Generate more examples
@@ -312,26 +325,32 @@ def generation_workflow(seed: str = None):
312325
)
313326
continue
314327
elif next_action == 2:
315-
view_examples(graded_examples)
316-
if Confirm.ask("Would you like to continue grading more examples?"):
317-
num_examples = IntPrompt.ask(
318-
"How many more examples would you like to generate?",
319-
default=3,
320-
min_value=3,
321-
max_value=10,
322-
)
323-
continue
324-
else:
325-
# Stop grading and generate the dataset
326-
console.print()
327-
console.print(
328-
"[bold]🧪 We will now generate a dataset using the graded examples as a seed.[/bold]"
329-
)
330-
return
328+
show_graded_examples(graded_examples)
331329
else:
332330
# Stop grading and generate the dataset
333331
console.print()
332+
console.print("Sweet!")
333+
num_dataset_examples = IntPrompt.ask(
334+
"How many examples would you like to generate for your dataset? [magenta](Max: 1000)[/magenta]",
335+
)
336+
console.print(
337+
f"[bold]🧪 We will now generate a dataset with {num_dataset_examples} examples, using the graded examples as a seed...[/bold]\n"
338+
)
339+
time.sleep(5)
340+
# client.generate_dataset(
341+
# generation_type=generation_type,
342+
# description=description,
343+
# num_examples=num_examples,
344+
# seed_data=seed_data,
345+
# preferences=preferences,
346+
# )
347+
console.print(
348+
"[green][bold]🚀 Dataset request submitted! "
349+
"You will soon receive an email with your downloadable dataset![/bold][/green]"
350+
)
334351
console.print(
335-
"[bold]🧪 We will now generate a dataset using the graded examples as a seed.[/bold]"
352+
"[yellow]Note: If you see the email in your spam folder please "
353+
"let us know at [red][email protected][/red][/yellow]"
336354
)
355+
time.sleep(0.5)
337356
return

quotientai/client.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,3 +1166,50 @@ def generate_examples(
11661166

11671167
except Exception as e:
11681168
raise QuotientAIException(f"Failed to generate examples: {str(e)}") from e
1169+
1170+
@require_api_key
1171+
def generate_dataset(
1172+
self,
1173+
generation_type: GenerateDatasetType,
1174+
description: str,
1175+
num_examples: int = 3,
1176+
seed_data: str = None,
1177+
preferences: List[dict] = None,
1178+
) -> List[str]:
1179+
try:
1180+
url = f"{self.eval_scheduler_url}/generate/dataset"
1181+
1182+
headers = {
1183+
"Authorization": f"Bearer {self.api_key}",
1184+
}
1185+
params = {
1186+
"generation_type": generation_type.value,
1187+
}
1188+
1189+
data = {
1190+
"inputs": seed_data,
1191+
"description": description,
1192+
"num_examples": num_examples,
1193+
"preferences": preferences,
1194+
}
1195+
response = requests.post(
1196+
url,
1197+
headers=headers,
1198+
params=params,
1199+
json=data,
1200+
)
1201+
result = response.json()
1202+
if response.status_code != 200:
1203+
if "detail" in result:
1204+
raise FastAPIError(response.status_code, result["detail"])
1205+
else:
1206+
response.raise_for_status()
1207+
1208+
return result
1209+
except FastAPIError as fast_err:
1210+
raise QuotientAIException(
1211+
f"Failed to generate dataset: {fast_err.status_code} {fast_err.detail}"
1212+
) from fast_err
1213+
1214+
except Exception as e:
1215+
raise QuotientAIException(f"Failed to generate dataset: {str(e)}") from e

0 commit comments

Comments
 (0)