
Commit da28f0f

Fix ups for seed data + translation (#52)
* Fix ups for seed data + translation
* Fix scheduler url
1 parent 79c3f94 · commit da28f0f

File tree

3 files changed: +58 -41 lines

    quotientai/_enums.py
    quotientai/cli/generate/dataset.py
    quotientai/client.py


quotientai/_enums.py

Lines changed: 1 addition & 0 deletions
@@ -8,3 +8,4 @@ class GenerateDatasetType(Enum):
 
     grounded_qa: str = "grounded-question-answering"
     summarization: str = "summarization"
+    translation: str = "translation"
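The enum change is the smallest piece of the commit but it drives the rest: the CLI gains a translation generation type, and (as the client.py diff below shows) the scheduler route is now built from the enum's value. A minimal stand-in sketch, with the member annotations omitted and a placeholder scheduler host rather than the real ALB URL:

from enum import Enum

# stand-in for quotientai/_enums.py after this commit (values copied from the diff)
class GenerateDatasetType(Enum):
    grounded_qa = "grounded-question-answering"
    summarization = "summarization"
    translation = "translation"  # new in this commit

# client.py (below) now derives the endpoint path from the enum value
scheduler_url = "http://eval-scheduler.example"  # placeholder host
url = f"{scheduler_url}/generate/examples/{GenerateDatasetType.translation.value}"
print(url)  # http://eval-scheduler.example/generate/examples/translation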

quotientai/cli/generate/dataset.py

Lines changed: 48 additions & 30 deletions
@@ -1,5 +1,6 @@
 import json
 import os
+import random
 import time
 from pathlib import Path
 from typing import List, Optional
@@ -91,14 +92,22 @@ def get_seed_data(seed: str) -> List[Optional[str]]:
     seed_one = raw_data[0][field]
     console.print(f"[green] {seed_one}\n")
 
-    seed_data = [line[field] for line in raw_data]
+    try:
+        seed_data = [line[field] for line in raw_data]
+    except KeyError:
+        raise Exception(
+            f"could not find field '{field}' in one or more lines. "
+            "please ensure all lines have the field available"
+        )
+
     return seed_data
 
 
 def grade_examples(
     generation_type: GenerateDatasetType,
     description: str,
-    seed_data: str,
+    language: str = None,
+    seed_data: List[str] = [],
     preferences: list[dict] = None,
     num_examples: int = 3,
 ):
@@ -109,11 +118,18 @@ def grade_examples(
     ) as progress:
         task = progress.add_task("generation", total=0)
 
+        # randomly sample `num_examples` number of rows for
+        # initial generation
+        seed_data = random.sample(
+            seed_data,
+            min(len(seed_data), num_examples),
+        )
         # Step 4
         examples = client.generate_examples(
             generation_type=generation_type,
             description=description,
             num_examples=num_examples,
+            language=language,
             seed_data=seed_data,
             preferences=preferences,
         )
@@ -122,12 +138,12 @@ def grade_examples(
         time.sleep(0.10)
 
     # Step 5
-    step_message = "🚀 Generated 3 examples. Please grade them!"
+    step_message = f"🚀 Generated {len(examples)} examples. Please grade them!"
     console.print("-" * len(step_message))
     console.print(f"[bold]{step_message}[/bold]")
     console.print("")
 
-    context = examples["metadata"]["input_text"]
+    context = examples["metadata"]["inputs"]
 
     # if the generation type is grounded_qa, we will use the pull input_text and format the
     # examples as context, question, and answer.
@@ -137,20 +153,29 @@ def grade_examples(
         data = [
             {
                 "id": example["id"],
-                "context": context,
+                "context": example["context"],
                 "question": example["question"],
                 "answer": example["answer"],
             }
             for example in examples["pairs"]
         ]
-    else:
+    elif GenerateDatasetType(generation_type) == GenerateDatasetType.summarization:
         data = [
             {
                 "id": example["id"],
                 "context": context,
                 "summary": example["summary"],
             }
-            for example in examples["data"]
+            for example in examples["pairs"]
+        ]
+    elif GenerateDatasetType(generation_type) == GenerateDatasetType.translation:
+        data = [
+            {
+                "id": example["id"],
+                "text": example["text"],
+                "translation": example["translation"],
+            }
+            for example in examples["pairs"]
         ]
 
     for idx, datum in enumerate(data):
@@ -168,7 +193,7 @@ def grade_examples(
 
         # add the grade and the explanation to the example
         datum["grade"] = 1 if is_good else 0
-        datum["explanation"] = explanation
+        datum["feedback"] = explanation
 
         if idx < len(examples) - 1:
             console.print("👍 Got it! Here's the next one\n")
@@ -254,6 +279,10 @@ def generation_workflow(seed: str = None):
             "type": GenerateDatasetType.summarization,
             "description": "A dataset that can be used for evaluating model summarization abilties.",
         },
+        3: {
+            "type": GenerateDatasetType.translation,
+            "description": "A dataset that can be used for evaluating model translation abilities.",
+        },
     }
 
     for index, choice in generation_choices.items():
@@ -276,40 +305,29 @@ def generation_workflow(seed: str = None):
         "[bold]Please describe in detail the context of your problem[/bold]"
     )
 
-    # if the seed is not provided, ask the user if they have a seed file
-    seed_data: Optional[List[str]] = get_seed_data(seed=seed)
-    if seed_data:
-        seed_data = seed_data[0]
+    if generation_type == GenerateDatasetType.translation:
+        language = Prompt.ask(
+            "[bold]Please tell us what language you want to translate to[/bold]"
+        )
     else:
-        seed_data = None
+        language = None
+
+    # if the seed is not provided, ask the user if they have a seed file
+    seed_data: List[Optional[str]] = get_seed_data(seed=seed)
 
     graded_examples = []
-    preferences = []
     num_examples = 3
     while True:
         graded = grade_examples(
             generation_type=generation_type,
             description=description,
-            seed_data=seed_data,
             num_examples=num_examples,
-            preferences=preferences,
+            language=language,
+            seed_data=seed_data,
+            preferences=graded_examples,
        )
         graded_examples.extend(graded)
 
-        # add the graded examples to the preferences
-        prefs = [
-            {
-                "id": example["id"],
-                "context": example["context"],
-                "question": example["question"],
-                "answer": example["answer"],
-                "grade": example["grade"],
-                "feedback": example["explanation"],
-            }
-            for example in graded_examples
-        ]
-        preferences.extend(prefs)
-
         console.print(
             f"You have graded [yellow]{len(graded_examples)}[yellow] examples."
         )
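Two behavioural notes on this file: the graded examples list now doubles as the preferences payload on later loop iterations (the old hand-built prefs list is gone), and grade_examples now samples the seed rows before the first generation call. random.sample raises ValueError when asked for more items than the population contains, so capping the sample size with min(len(seed_data), num_examples) lets short or empty seed files pass through cleanly. A standalone sketch of that guard (sample_seed_rows is a hypothetical helper, not part of the CLI):

import random

def sample_seed_rows(seed_data, num_examples=3):
    # mirror the guard added in grade_examples: never request more
    # rows than the seed data actually contains
    return random.sample(seed_data, min(len(seed_data), num_examples))

print(sample_seed_rows(["row-1", "row-2"]))       # only 2 rows exist: both come back, shuffled
print(sample_seed_rows(["a", "b", "c", "d"]))     # 4 rows: a random 3 come back
print(sample_seed_rows([]))                       # no seed data: returns []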

quotientai/client.py

Lines changed: 9 additions & 11 deletions
@@ -35,9 +35,7 @@ def __init__(self):
         # Base URL for the Supabase project
         self.supabase_url = "https://hhqppcqltklzfpggdocb.supabase.co"
 
-        self.eval_scheduler_url = (
-            "http://eval-scheduler-alb-887401167.us-east-2.elb.amazonaws.com"
-        )
+        self.eval_scheduler_url = "http://eval-scheduler-alb-887401167.us-east-2.elb.amazonaws.com"
 
         self.supaclient = SyncPostgrestClient(
             self.supabase_url + "/rest/v1", headers={"apiKey": self.public_api_key}
@@ -1128,29 +1126,28 @@ def generate_examples(
         generation_type: GenerateDatasetType,
         description: str,
         num_examples: int = 3,
-        seed_data: str = None,
+        language: str = None,
+        seed_data: List[str] = [],
         preferences: List[dict] = None,
     ) -> List[str]:
         try:
-            url = f"{self.eval_scheduler_url}/generate/examples"
+            url = f"{self.eval_scheduler_url}/generate/examples/{generation_type.value}"
 
             headers = {
                 "Authorization": f"Bearer {self.api_key}",
             }
-            params = {
-                "generation_type": generation_type.value,
-            }
-
             data = {
                 "inputs": seed_data,
                 "description": description,
                 "num_examples": num_examples,
                 "preferences": preferences,
             }
+            if generation_type == GenerateDatasetType.translation:
+                data["language"] = language
+
             response = requests.post(
                 url,
                 headers=headers,
-                params=params,
                 json=data,
             )
             result = response.json()
@@ -1176,6 +1173,7 @@ def generate_dataset(
         description: str,
         num_examples: int = 3,
         seed_data: str = None,
+        language: str = None,
         preferences: List[dict] = None,
     ) -> List[str]:
         try:
@@ -1194,10 +1192,10 @@ def generate_dataset(
                 "num_examples": num_examples,
                 "preferences": preferences,
             }
+
             response = requests.post(
                 url,
                 headers=headers,
-                params=params,
                 json=data,
             )
             result = response.json()
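Net effect on the wire: the generation type moves from a query parameter into the URL path, and language is attached to the JSON body only for translation requests. A sketch of the request generate_examples now issues, using placeholder host, token, and payload values (only the URL shape and body keys come from the diff above):

import requests

scheduler_url = "http://eval-scheduler.example"   # placeholder, not the real ALB host
generation_type_value = "translation"             # GenerateDatasetType.translation.value

# the type is now a path segment; there is no params= argument anymore
url = f"{scheduler_url}/generate/examples/{generation_type_value}"
headers = {"Authorization": "Bearer <api-key>"}   # placeholder token

data = {
    "inputs": ["Bonjour tout le monde."],         # seed_data is sent as "inputs"
    "description": "A dataset for evaluating model translation abilities.",
    "num_examples": 3,
    "preferences": [],
}
# only translation requests carry a language, mirroring the diff above
data["language"] = "English"

response = requests.post(url, headers=headers, json=data)
print(response.status_code)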
