4
4
import pandas as pd
5
5
import tempfile
6
6
import base64
7
+ import numpy as np
7
8
from tqdm import tqdm
8
9
import torch .distributed as dist
9
10
from ..image_base import ImageBaseDataset
10
11
from ...smp import *
12
+ # from ..utils import get_intermediate_file_path, load, dump
11
13
12
14
13
15
class OmniDocBench (ImageBaseDataset ):
@@ -28,7 +30,7 @@ class OmniDocBench(ImageBaseDataset):
28
30
29
31
2. Mathematical Formula Processing:
30
32
- Convert all mathematical formulas to LaTeX format.
31
- - Enclose inline formulas with \( \). For example: This is an inline formula \( E = mc^2 \)
33
+ # - Enclose inline formulas with \( \). For example: This is an inline formula \( E = mc^2 \)
32
34
- Enclose block formulas with \\[ \\]. For example: \[ \frac{-b \pm \sqrt{b^2 - 4ac}}{2a} \]
33
35
34
36
3. Table Processing:
@@ -75,9 +77,6 @@ def __init__(self,
75
77
tsv_path ,
76
78
match_method :str = 'quick_match' ,
77
79
filter_types :dict = None ):
78
- self .result_foler = '../../../outputs/OmniDocBench'
79
- if not os .path .exists (self .result_foler ):
80
- os .makedirs (self .result_foler )
81
80
self .eval_file = eval_file
82
81
self .match_method = match_method
83
82
self .references = []
@@ -374,17 +373,18 @@ def process_generated_metric_results(self,samples,save_name:str='end2end_quick_m
374
373
'group' :group_result ,
375
374
'page' :page_result
376
375
}
377
- if not os .path .exists ('./output/OmniDocBench' ):
378
- os .makedirs ('./output/OmniDocBench' )
379
376
if isinstance (cur_samples ,list ):
380
377
saved_samples = cur_samples
381
378
else :
382
379
saved_samples = cur_samples .samples
383
- with open (os .path .join (self .result_foler ,f'{ save_name } _result.josn' ),'w' ,encoding = 'utf-8' ) as f :
384
- json .dump (saved_samples ,f ,indent = 4 ,ensure_ascii = False )
380
+ # NOTE: The original code has a bug here, it will overwrite the result file in each iteration.
381
+ # I will fix it by adding element to the filename.
382
+ # NOTE: Fixed typo .josn -> .json
383
+ result_file = get_intermediate_file_path (self .eval_file , f'_{ save_name } _{ element } _result' , 'json' )
384
+ dump (saved_samples , result_file )
385
385
386
- with open ( os . path . join ( self .result_foler , f' { save_name } _metric_result.json' ), 'w' , encoding = 'utf-8' ) as f :
387
- json . dump (result_all ,f , indent = 4 , ensure_ascii = False )
386
+ metric_result_file = get_intermediate_file_path ( self .eval_file , f'_ { save_name } _metric_result' , 'json' )
387
+ dump (result_all , metric_result_file )
388
388
389
389
dict_list = []
390
390
save_dict = {}
@@ -409,20 +409,20 @@ def process_generated_metric_results(self,samples,save_name:str='end2end_quick_m
409
409
dict_list .append (save_dict )
410
410
df = pd .DataFrame (dict_list ,index = ['end2end' ,]).round (3 )
411
411
412
- with open (os .path .join (self .result_foler ,'End2End_Evaluation.json' ),'w' ,encoding = 'utf-8' ) as f :
413
- json .dump (result_all ,f ,indent = 4 ,ensure_ascii = False )
414
- df .to_csv (os .path .join (self .result_foler ,'overall.csv' ))
415
- over_all_path = os .path .join (self .result_foler ,'End2End_Evaluation.json' )
416
- print (f"The save path of overall.csv is :{ over_all_path } " )
412
+ e2e_eval_file = get_intermediate_file_path (self .eval_file , '_End2End_Evaluation' , 'json' )
413
+ dump (result_all , e2e_eval_file )
414
+
415
+ overall_file = get_intermediate_file_path (self .eval_file , '_overall' )
416
+ dump (df , overall_file )
417
+
418
+ print (f"The save path of End2End_Evaluation is: { e2e_eval_file } " )
419
+ print (f"The save path of overall metrics is: { overall_file } " )
417
420
return df
418
421
419
422
420
423
class table_evalutor ():
421
424
def __init__ (self ,eval_file ,tsv_path ):
422
-
423
- self .result_foler = '../../../outputs/OmniDocBench'
424
- if not os .path .exists (self .result_foler ):
425
- os .makedirs (self .result_foler )
425
+ self .eval_file = eval_file
426
426
gt_key = 'html'
427
427
pred_key = 'pred'
428
428
self .category_filter = 'table'
@@ -434,8 +434,8 @@ def load_data(self,eval_file,gt_file,pred_key,gt_key):
434
434
from .data_preprocess import clean_string , normalized_formula , textblock2unicode , normalized_table
435
435
samples = []
436
436
preds = []
437
- predictions = pd . read_excel (eval_file )['prediction' ].tolist ()
438
- gt_samples = pd . read_csv (gt_file , sep = ' \t ' )['answer' ].tolist ()
437
+ predictions = load (eval_file )['prediction' ].tolist ()
438
+ gt_samples = load (gt_file )['answer' ].tolist ()
439
439
load_success ,load_fail = 0 ,0
440
440
for i ,gt_sample in tqdm (enumerate (gt_samples ),desc = 'Loading data' ):
441
441
try :
@@ -533,8 +533,8 @@ def process_generated_metric_results(self,save_name:str='OmniDocBench_table'):
533
533
'page' :page_result
534
534
}
535
535
536
- with open ( os . path . join ( self .result_foler , f' { save_name } _metric_result.json' ), 'w' , encoding = 'utf-8' ) as f :
537
- json . dump (result_all ,f , indent = 4 , ensure_ascii = False )
536
+ metric_result_file = get_intermediate_file_path ( self .eval_file , f'_ { save_name } _metric_result' , 'json' )
537
+ dump (result_all , metric_result_file )
538
538
539
539
dict_list = []
540
540
dict_list .append (result_all ["group" ]["TEDS" ])
@@ -545,10 +545,7 @@ def process_generated_metric_results(self,save_name:str='OmniDocBench_table'):
545
545
selected_columns = df4 [["language: table_en" , "language: table_simplified_chinese" , "language: table_en_ch_mixed" , "line: full_line" , "line: less_line" , "line: fewer_line" , "line: wireless_line" ,
546
546
"with_span: True" , "with_span: False" , "include_equation: True" , "include_equation: False" , "include_background: True" , "include_background: False" , "table_layout: vertical" , "table_layout: horizontal" ]]
547
547
548
- selected_columns .to_csv (os .path .join (self .result_foler ,'table_attribute.csv' ))
549
- table_attribute_path = os .path .join (self .result_foler ,'table_attribute.csv' )
550
- print (f'The save path of table_attribute.csv is :{ table_attribute_path } ' )
551
- selected_columns
552
-
553
-
548
+ table_attr_file = get_intermediate_file_path (self .eval_file , '_table_attribute' )
549
+ dump (selected_columns , table_attr_file )
550
+ print (f'The save path of table_attribute is :{ table_attr_file } ' )
554
551
return selected_columns
0 commit comments