 import json
 import os
+import random
 import time
 from pathlib import Path
 from typing import List, Optional
@@ -91,14 +92,22 @@ def get_seed_data(seed: str) -> List[Optional[str]]:
     seed_one = raw_data[0][field]
     console.print(f"[green]{seed_one}\n")

-    seed_data = [line[field] for line in raw_data]
+    try:
+        seed_data = [line[field] for line in raw_data]
+    except KeyError:
+        raise Exception(
+            f"could not find field '{field}' in one or more lines. "
+            "please ensure all lines have the field available"
+        )
+
     return seed_data

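For readers skimming this hunk, here is a minimal sketch of the failure mode the new `try`/`except` turns into a readable error. It assumes a JSONL-style seed file and a hypothetical field name `text`; neither appears in the diff itself.

```python
import json

# Two made-up seed lines, parsed the way a JSONL seed file would be.
lines = ['{"text": "hello world"}', '{"body": "this line lacks the field"}']
raw_data = [json.loads(line) for line in lines]

field = "text"  # hypothetical field name
try:
    seed_data = [row[field] for row in raw_data]
except KeyError:
    # The same KeyError the new guard turns into a readable message.
    print(f"could not find field '{field}' in one or more lines")
```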

 def grade_examples(
     generation_type: GenerateDatasetType,
     description: str,
-    seed_data: str,
+    language: str = None,
+    seed_data: List[str] = [],
     preferences: list[dict] = None,
     num_examples: int = 3,
 ):
@@ -109,11 +118,18 @@ def grade_examples(
     ) as progress:
         task = progress.add_task("generation", total=0)

+        # randomly sample `num_examples` rows from the seed data
+        # for the initial generation
+        seed_data = random.sample(
+            seed_data,
+            min(len(seed_data), num_examples),
+        )
         # Step 4
         examples = client.generate_examples(
             generation_type=generation_type,
             description=description,
             num_examples=num_examples,
+            language=language,
             seed_data=seed_data,
             preferences=preferences,
         )
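A note on the `min(...)` clamp added above: `random.sample` raises `ValueError` when asked for more items than the population holds, so the clamp keeps small seed files working. A standalone sketch with made-up seed rows:

```python
import random

seed_data = ["row one", "row two"]  # made-up seed rows, fewer than num_examples
num_examples = 3

# Without the clamp, random.sample(seed_data, num_examples) would raise
# ValueError: Sample larger than population or is negative.
sampled = random.sample(seed_data, min(len(seed_data), num_examples))
print(sampled)  # at most len(seed_data) rows, in random order
```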
@@ -122,12 +138,12 @@ def grade_examples(
         time.sleep(0.10)

     # Step 5
-    step_message = "🚀 Generated 3 examples. Please grade them!"
+    step_message = f"🚀 Generated {len(examples)} examples. Please grade them!"
     console.print("-" * len(step_message))
     console.print(f"[bold]{step_message}[/bold]")
     console.print("")

-    context = examples["metadata"]["input_text"]
+    context = examples["metadata"]["inputs"]

     # if the generation type is grounded_qa, we will pull the input_text and format the
     # examples as context, question, and answer.
@@ -137,20 +153,29 @@ def grade_examples(
         data = [
             {
                 "id": example["id"],
-                "context": context,
+                "context": example["context"],
                 "question": example["question"],
                 "answer": example["answer"],
             }
             for example in examples["pairs"]
         ]
-    else:
+    elif GenerateDatasetType(generation_type) == GenerateDatasetType.summarization:
         data = [
             {
                 "id": example["id"],
                 "context": context,
                 "summary": example["summary"],
             }
-            for example in examples["data"]
+            for example in examples["pairs"]
+        ]
+    elif GenerateDatasetType(generation_type) == GenerateDatasetType.translation:
+        data = [
+            {
+                "id": example["id"],
+                "text": example["text"],
+                "translation": example["translation"],
+            }
+            for example in examples["pairs"]
         ]

     for idx, datum in enumerate(data):
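The response schema of `client.generate_examples` is not part of this diff; the sketch below only mirrors the shape implied by how the new translation branch consumes it, with every value made up:

```python
# Illustrative only: the shape implied by how `examples` is consumed above.
# Every value here is invented; only the key names come from the diff.
examples = {
    "metadata": {"inputs": "source text the examples were generated from"},
    "pairs": [
        {"id": "ex-1", "text": "Good morning", "translation": "Bonjour"},
        {"id": "ex-2", "text": "Good night", "translation": "Bonne nuit"},
    ],
}

data = [
    {"id": ex["id"], "text": ex["text"], "translation": ex["translation"]}
    for ex in examples["pairs"]
]
print(data)
```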
@@ -168,7 +193,7 @@ def grade_examples(

         # add the grade and the explanation to the example
         datum["grade"] = 1 if is_good else 0
-        datum["explanation"] = explanation
+        datum["feedback"] = explanation

         if idx < len(examples) - 1:
             console.print("👍 Got it! Here's the next one\n")
@@ -254,6 +279,10 @@ def generation_workflow(seed: str = None):
254279 "type" : GenerateDatasetType .summarization ,
255280 "description" : "A dataset that can be used for evaluating model summarization abilties." ,
256281 },
282+ 3 : {
283+ "type" : GenerateDatasetType .translation ,
284+ "description" : "A dataset that can be used for evaluating model translation abilities." ,
285+ },
257286 }
258287
259288 for index , choice in generation_choices .items ():
@@ -276,40 +305,29 @@ def generation_workflow(seed: str = None):
276305 "[bold]Please describe in detail the context of your problem[/bold]"
277306 )
278307
279- # if the seed is not provided, ask the user if they have a seed file
280- seed_data : Optional [ List [ str ]] = get_seed_data ( seed = seed )
281- if seed_data :
282- seed_data = seed_data [ 0 ]
308+ if generation_type == GenerateDatasetType . translation :
309+ language = Prompt . ask (
310+ "[bold]Please tell us what language you want to translate to[/bold]"
311+ )
283312 else :
284- seed_data = None
313+ language = None
314+
315+ # if the seed is not provided, ask the user if they have a seed file
316+ seed_data : List [Optional [str ]] = get_seed_data (seed = seed )
285317
286318 graded_examples = []
287- preferences = []
288319 num_examples = 3
289320 while True :
290321 graded = grade_examples (
291322 generation_type = generation_type ,
292323 description = description ,
293- seed_data = seed_data ,
294324 num_examples = num_examples ,
295- preferences = preferences ,
325+ language = language ,
326+ seed_data = seed_data ,
327+ preferences = graded_examples ,
296328 )
297329 graded_examples .extend (graded )
298330
299- # add the graded examples to the preferences
300- prefs = [
301- {
302- "id" : example ["id" ],
303- "context" : example ["context" ],
304- "question" : example ["question" ],
305- "answer" : example ["answer" ],
306- "grade" : example ["grade" ],
307- "feedback" : example ["explanation" ],
308- }
309- for example in graded_examples
310- ]
311- preferences .extend (prefs )
312-
313331 console .print (
314332 f"You have graded [yellow]{ len (graded_examples )} [yellow] examples."
315333 )
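Taken together, the loop now feeds each round of graded rows straight back in as `preferences` instead of reformatting them first. A compressed sketch of that feedback loop, with `grade_examples` replaced by a hypothetical stub:

```python
# Hypothetical stub standing in for grade_examples; the real function returns
# rows that already carry "grade" and "feedback" keys, as set earlier in this diff.
def grade_examples(preferences, **kwargs):
    return [{"id": f"ex-{len(preferences)}", "grade": 1, "feedback": "looks good"}]

graded_examples = []
for _ in range(2):  # two grading rounds, for illustration only
    graded = grade_examples(preferences=graded_examples)
    graded_examples.extend(graded)

print(f"You have graded {len(graded_examples)} examples.")
```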