 
 from explainaboard import feature
 from explainaboard.info import BucketPerformance, Performance, SysOutputInfo
-from explainaboard.metric import MetricStats
+import explainaboard.metric
+from explainaboard.metric import Metric
 from explainaboard.processors.processor import Processor
 from explainaboard.processors.processor_registry import register_processor
 from explainaboard.tasks import TaskType
-from explainaboard.utils import bucketing, eval_basic_ner
+from explainaboard.utils import bucketing, span_utils
 from explainaboard.utils.analysis import cap_feature
-from explainaboard.utils.eval_bucket import f1_seqeval_bucket
 from explainaboard.utils.py_utils import sort_dict
 from explainaboard.utils.typing_utils import unwrap
 
@@ -206,7 +206,13 @@ def default_features(cls) -> feature.Features:
 
     @classmethod
     def default_metrics(cls) -> list[str]:
-        return ["f1_seqeval", "recall_seqeval", "precision_seqeval"]
+        return ["F1Score"]
+
+    def _get_true_label(self, data_point: dict):
+        return data_point["true_tags"]
+
+    def _get_predicted_label(self, data_point: dict):
+        return data_point["pred_tags"]
 
     def _get_statistics_resources(
         self, dataset_split: Dataset
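The new `_get_true_label` / `_get_predicted_label` overrides are what let the shared `Processor` machinery pull gold and predicted tag sequences out of each system-output record, which is why the hand-rolled `get_overall_performance` is deleted further down. A minimal sketch of the contract (the dict keys come straight from this diff; the `processor` instance is hypothetical):

    data_point = {"true_tags": ["B-PER", "I-PER", "O"],
                  "pred_tags": ["B-PER", "O", "O"]}
    processor._get_true_label(data_point)       # -> ["B-PER", "I-PER", "O"]
    processor._get_predicted_label(data_point)  # -> ["B-PER", "O", "O"]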
@@ -312,13 +318,11 @@ def _get_fre_rank(self, tokens, statistics):
     # --- End feature functions
 
     # These return none because NER is not yet in the main metric interface
-    def _get_metrics(self, sys_info: SysOutputInfo):
-        return None
-
-    def _gen_metric_stats(
-        self, sys_info: SysOutputInfo, sys_output: list[dict]
-    ) -> Optional[list[MetricStats]]:
-        return None
+    def _get_metrics(self, sys_info: SysOutputInfo) -> list[Metric]:
+        return [
+            getattr(explainaboard.metric, f'BIO{name}')()
+            for name in unwrap(sys_info.metric_names)
+        ]
 
     def _complete_span_features(self, sentence, tags, statistics=None):
 
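With `default_metrics` now returning `["F1Score"]`, the comprehension above resolves each configured name to a class on the `explainaboard.metric` module by prefixing it with `BIO`. A sketch of the lookup, assuming a `BIOF1Score` class exists there as the `f'BIO{name}'` pattern implies:

    import explainaboard.metric

    metric_names = ["F1Score"]                        # from sys_info.metric_names
    metrics = [
        getattr(explainaboard.metric, f'BIO{name}')() # -> BIOF1Score()
        for name in metric_names
    ]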
@@ -328,7 +332,7 @@ def _complete_span_features(self, sentence, tags, statistics=None):
         efre_dic = statistics["efre_dic"] if has_stats else None
 
         span_dics = []
-        chunks = eval_basic_ner.get_chunks(tags)
+        chunks = span_utils.get_spans_from_bio(tags)
         for tag, sid, eid in chunks:
             span_text = ' '.join(sentence[sid:eid])
             # Basic features
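`span_utils.get_spans_from_bio` replaces the old `eval_basic_ner.get_chunks` helper. Judging from the loop above, it yields `(tag, start, end)` tuples with an exclusive end index, since the caller slices `sentence[sid:eid]`. A rough re-implementation of that assumed contract, for illustration only (this is not the library's code):

    def bio_spans(tags):
        """Collect (label, start, end_exclusive) spans from BIO tags."""
        spans, start, label = [], None, None
        for i, tag in enumerate(tags):
            cur = tag.split('-', 1)[-1] if tag != 'O' else None
            begin = tag.startswith('B-') or (cur is not None and cur != label)
            if start is not None and (cur is None or begin):
                spans.append((label, start, i))      # close the open span
                start = None
            if cur is not None and (begin or start is None):
                start, label = i, cur                # open a new span
        if start is not None:
            spans.append((label, start, len(tags)))  # last span runs to the end
        return spans

    # bio_spans(['B-PER', 'I-PER', 'O', 'B-LOC']) -> [('PER', 0, 2), ('LOC', 3, 4)]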
@@ -389,35 +393,8 @@ def _complete_features(
         dict_sysout["pred_entity_info"] = self._complete_span_features(
             tokens, dict_sysout["pred_tags"], statistics=external_stats
         )
-        return None
-
-    def get_overall_performance(
-        self,
-        sys_info: SysOutputInfo,
-        sys_output: list[dict],
-        metric_stats: Any = None,
-    ) -> dict[str, Performance]:
-        """
-        Get the overall performance according to metrics
-        :param sys_info: Information about the system output
-        :param sys_output: The system output itself
-        :return: a dictionary of metrics to overall performance numbers
-        """
-
-        true_tags_list = [x['true_tags'] for x in sys_output]
-        pred_tags_list = [x['pred_tags'] for x in sys_output]
-
-        overall: dict[str, Performance] = {}
-        for metric_name in unwrap(sys_info.metric_names):
-            if not metric_name.endswith('_seqeval'):
-                raise NotImplementedError(f'Unsupported metric {metric_name}')
-            # This gets the appropriate metric from the eval_basic_ner package
-            score_func = getattr(eval_basic_ner, metric_name)
-            overall[metric_name] = Performance(
-                metric_name=metric_name,
-                value=score_func(true_tags_list, pred_tags_list),
-            )
-        return overall
+        # This is not used elsewhere, so just keep it as-is
+        return list()
 
     def _get_span_ids(
         self,
@@ -554,24 +531,24 @@ def get_bucket_cases_ner(
             samples_over_bucket_true[bucket_interval], 'true', sample_dict
         )
 
-        error_case_list = []
+        case_list = []
         for pos, tags in sample_dict.items():
             true_label = tags.get('true', 'O')
             pred_label = tags.get('pred', 'O')
-            if true_label != pred_label:
-                split_pos = pos.split("|||")
-                sent_id = int(split_pos[0])
-                span = split_pos[-1]
-                system_output_id = sys_output[int(sent_id)]["id"]
-                error_case = {
-                    "span": span,
-                    "text": str(system_output_id),
-                    "true_label": true_label,
-                    "predicted_label": pred_label,
-                }
-                error_case_list.append(error_case)
-
-        return error_case_list
+
+            split_pos = pos.split("|||")
+            sent_id = int(split_pos[0])
+            span = split_pos[-1]
+            system_output_id = sys_output[int(sent_id)]["id"]
+            error_case = {
+                "span": span,
+                "text": str(system_output_id),
+                "true_label": true_label,
+                "predicted_label": pred_label,
+            }
+            case_list.append(error_case)
+
+        return case_list
 
     def get_bucket_performance_ner(
         self,
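Two things change above: the list is renamed from `error_case_list` to `case_list`, and the `if true_label != pred_label:` guard is gone, so every extracted span becomes a bucket sample rather than only the mismatched ones (a missing side defaults to `'O'` via `tags.get`). The position key is a `"|||"`-joined string whose first field is the sentence index and whose last field is the span text; a hypothetical key and how the loop unpacks it (the middle fields are assumed and unused here):

    pos = "12|||3|||5|||New York"    # illustrative key
    split_pos = pos.split("|||")
    sent_id = int(split_pos[0])      # -> 12
    span = split_pos[-1]             # -> "New York"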
@@ -593,6 +570,12 @@ def get_bucket_performance_ner(
             bucket performance
         """
 
+        metric_names = unwrap(sys_info.metric_names)
+        bucket_metrics = [
+            getattr(explainaboard.metric, name)(ignore_classes=['O'])
+            for name in metric_names
+        ]
+
         bucket_name_to_performance = {}
         for bucket_interval, spans_true in samples_over_bucket_true.items():
 
@@ -611,29 +594,29 @@ def get_bucket_performance_ner(
                 samples_over_bucket_pred,
             )
 
+            true_labels = [x['true_label'] for x in bucket_samples]
+            pred_labels = [x['predicted_label'] for x in bucket_samples]
+
             bucket_performance = BucketPerformance(
                 bucket_name=bucket_interval,
                 n_samples=len(spans_pred),
                 bucket_samples=bucket_samples,
             )
-            for metric_name in unwrap(sys_info.metric_names):
-                """
-                # Note that: for NER task, the bucket-wise evaluation function is a
-                # little different from overall evaluation function
-                # for overall: f1_seqeval
-                # for bucket: f1_seqeval_bucket
-                """
-                f1, p, r = f1_seqeval_bucket(spans_pred, spans_true)
-                if metric_name == 'f1_seqeval':
-                    my_score = f1
-                elif metric_name == 'precision_seqeval':
-                    my_score = p
-                elif metric_name == 'recall_seqeval':
-                    my_score = r
-                else:
-                    raise NotImplementedError(f'Unsupported metric {metric_name}')
-                # TODO(gneubig): It'd be better to have significance tests here
-                performance = Performance(metric_name=metric_name, value=my_score)
+            for metric in bucket_metrics:
+
+                metric_val = metric.evaluate(
+                    true_labels, pred_labels, conf_value=sys_info.conf_value
+                )
+                conf_low, conf_high = (
+                    metric_val.conf_interval
+                    if metric_val.conf_interval else (None, None)
+                )
+                performance = Performance(
+                    metric_name=metric.name,
+                    value=metric_val.value,
+                    confidence_score_low=conf_low,
+                    confidence_score_high=conf_high,
+                )
                 bucket_performance.performances.append(performance)
 
             bucket_name_to_performance[bucket_interval] = bucket_performance
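The seqeval-specific branching is replaced by metric objects constructed once per call, each evaluated with a confidence level, so buckets now get confidence intervals instead of the old TODO about significance tests. A sketch of how one of these objects might be used standalone, with the class name, `evaluate` signature, and result fields all inferred from this diff rather than verified against the library:

    import explainaboard.metric

    # span-level labels; 'O' is excluded from scoring as in the bucket code above
    metric = getattr(explainaboard.metric, 'F1Score')(ignore_classes=['O'])
    result = metric.evaluate(
        ['PER', 'LOC', 'O'],     # true span labels
        ['PER', 'O', 'O'],       # predicted span labels
        conf_value=0.05,         # assumed: alpha for the confidence interval
    )
    print(metric.name, result.value, result.conf_interval)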
@@ -647,7 +630,7 @@ def get_econ_dic(train_word_sequences, tag_sequences_train, tags):
     Note: when matching, the text span and tag have been lowercased.
     """
     econ_dic = dict()
-    chunks_train = set(eval_basic_ner.get_chunks(tag_sequences_train))
+    chunks_train = set(span_utils.get_spans_from_bio(tag_sequences_train))
 
     # print('tags: ', tags)
     count_idx = 0
@@ -722,7 +705,7 @@ def get_econ_dic(train_word_sequences, tag_sequences_train, tags):
 # Global functions for training set dependent features
 def get_efre_dic(train_word_sequences, tag_sequences_train):
     efre_dic = dict()
-    chunks_train = set(eval_basic_ner.get_chunks(tag_sequences_train))
+    chunks_train = set(span_utils.get_spans_from_bio(tag_sequences_train))
     count_idx = 0
     word_sequences_train_str = ' '.join(train_word_sequences).lower()
     for true_chunk in tqdm(chunks_train):