55"""
66
77# Standard
8- import os , shutil
9- import yaml
108from uuid import uuid4
9+ import os
10+ import shutil
1111
1212# Third Party
1313from lm_eval .tasks .unitxt import task
14+ import yaml
1415
1516# First Party
1617from instructlab .eval .mmlu import MMLUBranchEvaluator
2021
2122logger = setup_logger (__name__ )
2223
23- TEMP_DIR_PREFIX = 'unitxt_temp'
24+ TEMP_DIR_PREFIX = "unitxt_temp"
25+
2426
2527class UnitxtEvaluator (MMLUBranchEvaluator ):
2628 """
@@ -29,45 +31,50 @@ class UnitxtEvaluator(MMLUBranchEvaluator):
2931 Attributes:
3032 model_path absolute path to or name of a huggingface model
3133 unitxt_recipe unitxt recipe (see unitxt.ai for more information)
32- A Recipe holds a complete specification of a unitxt pipeline
34+ A Recipe holds a complete specification of a unitxt pipeline
3335 Example: card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10
34-
36+
3537 """
38+
3639 name = "unitxt"
40+
3741 def __init__ (
3842 self ,
39- model_path ,
43+ model_path ,
4044 unitxt_recipe : str ,
4145 ):
4246 task = self .assign_task_name ()
4347 tasks_dir = self .assign_tasks_dir (task )
4448 super ().__init__ (
45- model_path = model_path ,
46- tasks_dir = tasks_dir ,
47- tasks = [task ],
48- few_shots = 0
49+ model_path = model_path , tasks_dir = tasks_dir , tasks = [task ], few_shots = 0
4950 )
5051 self .unitxt_recipe = unitxt_recipe
5152
5253 def assign_tasks_dir (self , task ):
53- return f' { TEMP_DIR_PREFIX } _{ task } '
54+ return f" { TEMP_DIR_PREFIX } _{ task } "
5455
5556 def assign_task_name (self ):
5657 return str (uuid4 ())
5758
58- def prepare_unitxt_files (self )-> tuple :
59+ def prepare_unitxt_files (self ) -> tuple :
5960 task = self .tasks [0 ]
60- yaml_file = os .path .join (self .tasks_dir ,f"{ task } .yaml" )
61+ yaml_file = os .path .join (self .tasks_dir , f"{ task } .yaml" )
6162 create_unitxt_pointer (self .tasks_dir )
62- create_unitxt_yaml (yaml_file = yaml_file , unitxt_recipe = self .unitxt_recipe , task_name = task )
63+ create_unitxt_yaml (
64+ yaml_file = yaml_file , unitxt_recipe = self .unitxt_recipe , task_name = task
65+ )
6366
6467 def remove_unitxt_files (self ):
65- if self .tasks_dir .startswith (TEMP_DIR_PREFIX ): #to avoid unintended deletion if this class is inherited
68+ if self .tasks_dir .startswith (
69+ TEMP_DIR_PREFIX
70+ ): # to avoid unintended deletion if this class is inherited
6671 shutil .rmtree (self .tasks_dir )
6772 else :
68- logger .warning (f"unitxt tasks dir did not start with '{ TEMP_DIR_PREFIX } ' and therefor was not deleted" )
73+ logger .warning (
74+ f"unitxt tasks dir did not start with '{ TEMP_DIR_PREFIX } ' and therefor was not deleted"
75+ )
6976
70- def run (self ,server_url : str | None = None ) -> tuple :
77+ def run (self , server_url : str | None = None ) -> tuple :
7178 """
7279 Runs evaluation
7380
@@ -80,13 +87,16 @@ def run(self,server_url: str | None = None) -> tuple:
8087 os .environ ["TOKENIZERS_PARALLELISM" ] = "true"
8188 results = self ._run_mmlu (server_url = server_url , return_all_results = True )
8289 taskname = self .tasks [0 ]
83- global_scores = results [' results' ][taskname ]
84- global_scores .pop (' alias' )
90+ global_scores = results [" results" ][taskname ]
91+ global_scores .pop (" alias" )
8592 try :
86- instances = results [' samples' ][taskname ]
93+ instances = results [" samples" ][taskname ]
8794 instance_scores = {}
88- metrics = [metric .replace ('metrics.' ,'' ) for metric in instances [0 ]['doc' ]['metrics' ]]
89- for i ,instance in enumerate (instances ):
95+ metrics = [
96+ metric .replace ("metrics." , "" )
97+ for metric in instances [0 ]["doc" ]["metrics" ]
98+ ]
99+ for i , instance in enumerate (instances ):
90100 scores = {}
91101 for metric in metrics :
92102 scores [metric ] = instance [metric ][0 ]
@@ -97,23 +107,20 @@ def run(self,server_url: str | None = None) -> tuple:
97107 logger .error (e .__traceback__ )
98108 instance_scores = None
99109 self .remove_unitxt_files ()
100- return global_scores ,instance_scores
110+ return global_scores , instance_scores
101111
102112
103- def create_unitxt_yaml (yaml_file ,unitxt_recipe , task_name ):
104- data = {
105- 'task' : f'{ task_name } ' ,
106- 'include' : 'unitxt' ,
107- 'recipe' : f'{ unitxt_recipe } '
108- }
109- with open (yaml_file , 'w' ) as file :
113+ def create_unitxt_yaml (yaml_file , unitxt_recipe , task_name ):
114+ data = {"task" : f"{ task_name } " , "include" : "unitxt" , "recipe" : f"{ unitxt_recipe } " }
115+ with open (yaml_file , "w" ) as file :
110116 yaml .dump (data , file , default_flow_style = False )
111117 logger .debug (f"task { task } unitxt recipe written to { yaml_file } " )
112118
119+
113120def create_unitxt_pointer (tasks_dir ):
114121 class_line = "class: !function " + task .__file__ .replace ("task.py" , "task.Unitxt" )
115- output_file = os .path .join (tasks_dir ,' unitxt' )
122+ output_file = os .path .join (tasks_dir , " unitxt" )
116123 os .makedirs (os .path .dirname (output_file ), exist_ok = True )
117- with open (output_file , 'w' ) as f :
124+ with open (output_file , "w" ) as f :
118125 f .write (class_line )
119126 logger .debug (f"Unitxt task pointer written to { output_file } " )
0 commit comments