diff --git a/benchmarks/table/gemini.py b/benchmarks/table/gemini.py index 5832a90f..560c681e 100644 --- a/benchmarks/table/gemini.py +++ b/benchmarks/table/gemini.py @@ -22,10 +22,16 @@ 3. Output only the HTML for the table, starting with the tag and ending with the
tag. """.strip() +prompt_with_header = prompt.replace( + "Only use , , and , , or
tags. Only use the colspan and rowspan attributes if necessary. Do not use
tags.", + "Only use , , , , and and , but do not use which fintabnet data doesn't + # Fintabnet doesn't use th tags, need to be replaced for fair comparison + marker_table_soup = BeautifulSoup(table_html, 'html.parser') + tbody = marker_table_soup.find('tbody') + if tbody: + tbody.unwrap() + for th_tag in marker_table_soup.find_all('th'): + th_tag.name = 'td' + for br_tag in marker_table_soup.find_all('br'): + br_tag.replace_with(marker_table_soup.new_string('')) + + marker_table_html = str(marker_table_soup) + marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines + return marker_table_html + + def construct_row_result(self, row, gt_table, marker_table, gemini_table, **kwargs): + return { + "marker_table": marker_table, + "gt_table": gt_table, + "gemini_table": gemini_table + } + + def _load_marker_output(self, i, row): + """Opportunity to load marker output from a previous run""" + return None, None + + def _cache_marker_output(self, i, row, marker_json, page_image): + """Opportunity to save marker output, to allow resuming if interrupted""" + pass + + def inference_tables(self, dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool): + + total_unaligned = 0 + results = [] + + iterations = len(dataset) + if max_rows is not None: + iterations = min(max_rows, len(dataset)) + + models = create_model_dict() + config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, + "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True}) + converter = self.get_converter(models, config_parser) + for i in tqdm(range(iterations), desc='Converting Tables'): + try: + row = dataset[i] + + # save progress while running + marker_json, page_image = self._load_marker_output(i, row) + if marker_json is None: + marker_json, page_image = self.extract_tables_from_doc(converter, row) + self._cache_marker_output(i, row, marker_json, page_image) + + gt_tables = self.extract_gt_tables(row) # Already sorted by reading order, which is what marker returns + + if len(marker_json) == 0 or len(gt_tables) == 0: + print(f'No tables detected, skipping...') + total_unaligned += len(gt_tables) continue - if aligned_idx in used_tables: - # Marker table already aligned with another gt table - unaligned_tables.add(table_idx) - continue + marker_tables = self.extract_tables_from_json(marker_json) + marker_table_boxes = [table.bbox for table in marker_tables] + page_bbox = marker_json[0].bbox - # Gt table doesn't align well with any marker table - gt_table_pct = gt_areas[table_idx] / max_area - if not .85 < gt_table_pct < 1.15: - unaligned_tables.add(table_idx) + if len(marker_tables) != len(gt_tables): + print(f'Number of tables do not match, skipping...') + total_unaligned += len(gt_tables) continue - # Marker table doesn't align with gt table - marker_table_pct = marker_areas[aligned_idx] / max_area - if not .85 < marker_table_pct < 1.15: - unaligned_tables.add(table_idx) - continue - - gemini_html = "" - if use_gemini: + table_images = [ + page_image.crop( + PolygonBox.from_bbox(bbox) + .rescale( + (page_bbox[2], page_bbox[3]), (page_image.width, page_image.height) + ).bbox + ) + for bbox + in marker_table_boxes + ] + + # Normalize the bboxes + for bbox in marker_table_boxes: + bbox[0] = bbox[0] / page_bbox[2] + bbox[1] = bbox[1] / page_bbox[3] + bbox[2] = bbox[2] / page_bbox[2] + bbox[3] = bbox[3] / page_bbox[3] + + gt_boxes = [table['normalized_bbox'] for table in gt_tables] + gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes] + marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes] + table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes) + + aligned_tables = [] + used_tables = set() + unaligned_tables = set() + for table_idx, alignment in enumerate(table_alignments): try: - gemini_html = gemini_table_rec(table_images[aligned_idx]) - except Exception as e: - print(f'Gemini failed: {e}') - - aligned_tables.append( - (marker_tables[aligned_idx], gt_tables[table_idx], gemini_html) - ) - used_tables.add(aligned_idx) - - total_unaligned += len(unaligned_tables) - - for marker_table, gt_table, gemini_table in aligned_tables: - gt_table_html = gt_table['html'] - - # marker wraps the table in which fintabnet data doesn't - # Fintabnet doesn't use th tags, need to be replaced for fair comparison - marker_table_html = fix_table_html(marker_table.html) - gemini_table_html = fix_table_html(gemini_table) - - results.append({ - "marker_table": marker_table_html, - "gt_table": gt_table_html, - "gemini_table": gemini_table_html - }) - except pdfium.PdfiumError: - print('Broken PDF, Skipping...') - continue - return results, total_unaligned \ No newline at end of file + max_area = np.max(alignment) + aligned_idx = np.argmax(alignment) + except ValueError: + # No alignment found + unaligned_tables.add(table_idx) + continue + + if max_area <= .01: + # No alignment found + unaligned_tables.add(table_idx) + continue + + if aligned_idx in used_tables: + # Marker table already aligned with another gt table + unaligned_tables.add(table_idx) + continue + + # Gt table doesn't align well with any marker table + gt_table_pct = gt_areas[table_idx] / max_area + if not .85 < gt_table_pct < 1.15: + unaligned_tables.add(table_idx) + continue + + # Marker table doesn't align with gt table + marker_table_pct = marker_areas[aligned_idx] / max_area + if not .85 < marker_table_pct < 1.15: + unaligned_tables.add(table_idx) + continue + + gemini_html = "" + if use_gemini: + try: + gemini_html = self.extract_gemini_tables(row, table_images[aligned_idx]) + except Exception as e: + print(f'Gemini failed: {e}') + + aligned_tables.append( + (marker_tables[aligned_idx], gt_tables[table_idx], gemini_html) + ) + used_tables.add(aligned_idx) + + total_unaligned += len(unaligned_tables) + + for marker_table, gt_table, gemini_table in aligned_tables: + gt_table_html = gt_table['html'] + + marker_table_html = self.fix_table_html(marker_table.html, author='marker') + gemini_table_html = self.fix_table_html(gemini_table, author='gemini') + + results.append(self.construct_row_result(row, gt_table_html, marker_table_html, gemini_table_html)) + except pdfium.PdfiumError: + print('Broken PDF, Skipping...') + continue + return results, total_unaligned \ No newline at end of file diff --git a/benchmarks/table/scoring.py b/benchmarks/table/scoring.py index 940bd6e4..9b33ab85 100644 --- a/benchmarks/table/scoring.py +++ b/benchmarks/table/scoring.py @@ -9,6 +9,9 @@ from collections import deque def wrap_table_html(table_html:str)->str: + if not table_html.startswith('
tags. Only use the colspan and rowspan attributes if necessary. Use
.", 1 +) + + class TableSchema(BaseModel): table_html: str -def gemini_table_rec(image: Image.Image): +def gemini_table_rec(image: Image.Image, prompt=prompt): client = genai.Client( api_key=settings.GOOGLE_API_KEY, http_options={"timeout": 60000} diff --git a/benchmarks/table/inference.py b/benchmarks/table/inference.py index 07e5e92c..d504d0d5 100644 --- a/benchmarks/table/inference.py +++ b/benchmarks/table/inference.py @@ -1,3 +1,5 @@ +from functools import partialmethod +import os from typing import List import numpy as np @@ -17,166 +19,203 @@ from marker.schema.polygon import PolygonBox from marker.util import matrix_intersection_area - -def extract_tables(children: List[JSONBlockOutput]): - tables = [] - for child in children: - if child.block_type == 'Table': - tables.append(child) - elif child.children: - tables.extend(extract_tables(child.children)) - return tables - -def fix_table_html(table_html: str) -> str: - marker_table_soup = BeautifulSoup(table_html, 'html.parser') - tbody = marker_table_soup.find('tbody') - if tbody: - tbody.unwrap() - for th_tag in marker_table_soup.find_all('th'): - th_tag.name = 'td' - for br_tag in marker_table_soup.find_all('br'): - br_tag.replace_with(marker_table_soup.new_string('')) - - marker_table_html = str(marker_table_soup) - marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines - return marker_table_html - - -def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool): - models = create_model_dict() - config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True}) - total_unaligned = 0 - results = [] - - iterations = len(dataset) - if max_rows is not None: - iterations = min(max_rows, len(dataset)) - - for i in tqdm(range(iterations), desc='Converting Tables'): - try: - row = dataset[i] - pdf_binary = base64.b64decode(row['pdf']) - gt_tables = row['tables'] # Already sorted by reading order, which is what marker returns - - # Only use the basic table processors - converter = TableConverter( - config=config_parser.generate_config_dict(), - artifact_dict=models, - processor_list=[ - "marker.processors.table.TableProcessor", - "marker.processors.llm.llm_table.LLMTableProcessor", - ], - renderer=config_parser.get_renderer() - ) - - with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file: +class FinTabNetBenchmark: + def extract_tables_from_json(self, children: List[JSONBlockOutput]) -> List[JSONBlockOutput]: + tables = [] + for child in children: + if child.block_type == 'Table': + tables.append(child) + elif child.children: + tables.extend(self.extract_tables_from_json(child.children)) + return tables + + def get_converter(self, models, config_parser, **kwargs): + # Only use the basic table processors + converter = TableConverter( + config=config_parser.generate_config_dict(), + artifact_dict=models, + processor_list=[ + "marker.processors.table.TableProcessor", + "marker.processors.llm.llm_table.LLMTableProcessor", + ], + renderer=config_parser.get_renderer() + ) + return converter + + def extract_tables_from_doc(self, converter, row): + """Extract table and images from pdf; produce marker_json and page_image""" + + pdf_binary = base64.b64decode(row['pdf']) + + # https://stackoverflow.com/a/23212515 + with tempfile.TemporaryDirectory() as temp_dir: + temp_filepath = os.path.join(temp_dir, 'temp.pdf') + with open(temp_filepath, 'wb') as temp_pdf_file: temp_pdf_file.write(pdf_binary) - temp_pdf_file.seek(0) - marker_json = converter(temp_pdf_file.name).children - - doc = pdfium.PdfDocument(temp_pdf_file.name) - page_image = doc[0].render(scale=96/72).to_pil() - doc.close() - - if len(marker_json) == 0 or len(gt_tables) == 0: - print(f'No tables detected, skipping...') - total_unaligned += len(gt_tables) - continue - - marker_tables = extract_tables(marker_json) - marker_table_boxes = [table.bbox for table in marker_tables] - page_bbox = marker_json[0].bbox - - if len(marker_tables) != len(gt_tables): - print(f'Number of tables do not match, skipping...') - total_unaligned += len(gt_tables) - continue - - table_images = [ - page_image.crop( - PolygonBox.from_bbox(bbox) - .rescale( - (page_bbox[2], page_bbox[3]), (page_image.width, page_image.height) - ).bbox - ) - for bbox - in marker_table_boxes - ] - - # Normalize the bboxes - for bbox in marker_table_boxes: - bbox[0] = bbox[0] / page_bbox[2] - bbox[1] = bbox[1] / page_bbox[3] - bbox[2] = bbox[2] / page_bbox[2] - bbox[3] = bbox[3] / page_bbox[3] - - gt_boxes = [table['normalized_bbox'] for table in gt_tables] - gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes] - marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes] - table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes) - - aligned_tables = [] - used_tables = set() - unaligned_tables = set() - for table_idx, alignment in enumerate(table_alignments): - try: - max_area = np.max(alignment) - aligned_idx = np.argmax(alignment) - except ValueError: - # No alignment found - unaligned_tables.add(table_idx) - continue - - if max_area <= .01: - # No alignment found - unaligned_tables.add(table_idx) + temp_pdf_file.flush() + + marker_json = converter(temp_pdf_file.name).children + + doc = pdfium.PdfDocument(temp_pdf_file.name) + page_image = doc[0].render(scale=96/72).to_pil() + doc.close() + return marker_json, page_image + + def extract_gt_tables(self, row, **kwargs): + return row['tables'] + + def extract_gemini_tables(self, row, image, **kwargs): + return gemini_table_rec(image) + + def fix_table_html(self, table_html: str, author='marker') -> str: + # marker wraps the table in
'): + table_html = f'
{table_html}
' + return f'{table_html}' class TableTree(Tree): diff --git a/benchmarks/table/synthtabnet.py b/benchmarks/table/synthtabnet.py new file mode 100644 index 00000000..9c15120d --- /dev/null +++ b/benchmarks/table/synthtabnet.py @@ -0,0 +1,262 @@ +""" +Detect tables when the pdf is not available +""" + + +from functools import cache, partialmethod +import io +import os +import tempfile +from typing import Tuple + +from bs4 import BeautifulSoup +from tqdm import tqdm +from benchmarks.table.gemini import gemini_table_rec, prompt_with_header +from benchmarks.table.inference import FinTabNetBenchmark +from marker.builders.document import DocumentBuilder +from marker.processors import BaseProcessor +from marker.processors.llm.llm_complex import LLMComplexRegionProcessor +from marker.processors.llm.llm_form import LLMFormProcessor +from marker.processors.llm.llm_table import LLMTableProcessor +from marker.processors.llm.llm_table_merge import LLMTableMergeProcessor +from marker.processors.table import TableProcessor +from marker.converters.table import TableConverter + +import pdftext.schema +from surya.detection import DetectionPredictor +from surya.recognition import RecognitionPredictor +from surya.table_rec import TableRecPredictor + +from marker.providers.registry import provider_from_filepath +from marker.renderers.json import JSONBlockOutput +from tests.utils import convert_to_pdftext + + +class GroundTruthPagesForcer: + def __init__(self): + # debug data + self.forced_pages = None # [convert_to_page(get_raw_test_data(), [0, 0, 612.0, 792.0], 0)] + def __call__(self, filepath): + return self.forced_pages + +class ChangedTableProcessor(TableProcessor): + + + def __init__( + self, + detection_model: DetectionPredictor, + recognition_model: RecognitionPredictor, + table_rec_model: TableRecPredictor, + config=None, + _gt_forcer: GroundTruthPagesForcer=None, + ): + super().__init__( + detection_model=detection_model, + recognition_model=recognition_model, + table_rec_model=table_rec_model, + config=config + ) + self.detection_model = detection_model + self.recognition_model = recognition_model + self.table_rec_model = table_rec_model + self._gt_forcer = _gt_forcer + + def assign_pdftext_lines(self, extract_blocks: list[dict], filepath): + forced_pages = self._gt_forcer(filepath) + self.assign_forced_lines(extract_blocks, forced_pages) + + def assign_ocr_lines(self, ocr_blocks): + forced_pages = self._gt_forcer(None) + self.assign_forced_lines(ocr_blocks, forced_pages) + + def assign_forced_lines(self, extract_blocks: list[dict], forced_pages): + table_inputs = [] + unique_pages = list(set([t["page_id"] for t in extract_blocks])) + if len(unique_pages) == 0: + return + + for page in unique_pages: + tables = [] + img_size = None + for block in extract_blocks: + if block["page_id"] == page: + tables.append(block["table_bbox"]) + img_size = block["img_size"] + + table_inputs.append({ + "tables": tables, + "img_size": img_size + }) + + # NOTE: added this method + def table_output(filepath, table_inputs, page_range=unique_pages): + # mock forced_pages + pages = [] + for i in page_range: + pages.append(forced_pages[i]) + + + from pdftext.extraction import table_cell_text + out_tables = [] + for page, table_input in zip(forced_pages, table_inputs): + tables = table_cell_text(table_input["tables"], page, table_input["img_size"]) + assert len(tables) == len(table_input["tables"]), "Number of tables and table inputs must match" + out_tables.append(tables) + return out_tables + cell_text = table_output(None, table_inputs, page_range=unique_pages) + + assert len(cell_text) == len(unique_pages), "Number of pages and table inputs must match" + + for pidx, (page_tables, pnum) in enumerate(zip(cell_text, unique_pages)): + table_idx = 0 + for block in extract_blocks: + if block["page_id"] == pnum: + block["table_text_lines"] = page_tables[table_idx] + table_idx += 1 + assert table_idx == len(page_tables), "Number of tables and table inputs must match" + + +class ChangedTableConverter(TableConverter): + default_processors: Tuple[BaseProcessor, ...] = ( + ChangedTableProcessor, # NOTE: changed this line + LLMTableProcessor, + LLMTableMergeProcessor, + LLMFormProcessor, + LLMComplexRegionProcessor, + ) + + @cache + def build_document(self, filepath: str): + provider_cls = provider_from_filepath(filepath) + layout_builder = self.resolve_dependencies(self.layout_builder_class) + line_builder = lambda *args, **kwargs: None + ocr_builder = lambda *args, **kwargs: None + document_builder = DocumentBuilder(self.config) + document_builder.disable_ocr = True + with provider_cls(filepath, self.config) as provider: + document = document_builder(provider, layout_builder, line_builder, ocr_builder) + + for page in document.pages: + page.structure = [p for p in page.structure if p.block_type in self.converter_block_types] + + for processor in self.processor_list: + processor(document) + + return document + +class SynthTabNetBenchmark(FinTabNetBenchmark): + gt_forcer: GroundTruthPagesForcer + + def get_converter(self, models, config_parser, **kwargs): + config_parser.cli_options['force_layout_block'] = 'Table' + config_parser.cli_options['disable_tqdm'] = True + config_parser.cli_options['disable_ocr'] = True + config_parser.cli_options['document_ocr_threshold'] = 0.0 + self.gt_forcer = GroundTruthPagesForcer() + + return ChangedTableConverter( + config={ + **config_parser.generate_config_dict(), + "document_ocr_threshold": 0 + # never perform OCR for evaluation: we know the ground truth + }, + artifact_dict={ + '_gt_forcer': self.gt_forcer, + **models + }, + processor_list=[ + "marker.processors.table.TableProcessor", + "marker.processors.llm.llm_table.LLMTableProcessor", + ], + renderer=config_parser.get_renderer() + ) + + def synthtabnet_page_with_gt_words(self, row, page_image): + bboxes = row['word_bboxes'] + words = row['words'] + good = [i for i, word in enumerate(words) if word] # allow only non-empty words + bboxes = [(*bboxes[i], words[i]) for i in good] + + image_bbox = [0, 0, page_image.width, page_image.height] + return convert_to_pdftext(bboxes, image_bbox, 0) + + def extract_tables_from_doc(self, converter, row): + original_tqdm = tqdm.__init__ + + # disabled_tqdm = original_tqdm + def disabled_tqdm(*args, **kwargs): + if not kwargs.get('disable', False): + kwargs['disable'] = True + return original_tqdm(*args, **kwargs) + + # https://stackoverflow.com/a/23212515 + with tempfile.TemporaryDirectory() as temp_dir: + + bytesio = io.BytesIO() + page_image = row['image'] # PIL.Image.Image + page_image.save(bytesio, format="PNG") + + temp_filepath = os.path.join(temp_dir, 'temp.png') + with open(temp_filepath, 'wb') as temp_png_file: + temp_png_file.write(bytesio.getvalue()) + temp_png_file.flush() + + self.gt_forcer.forced_pages = [ + self.synthtabnet_page_with_gt_words(row, page_image) + ] + + tqdm.__init__ = disabled_tqdm # disable + marker_json = converter(temp_png_file.name).children # word bboxes are ingested by way of "gt_forcer" + tqdm.__init__ = original_tqdm # enable + + return marker_json, page_image + + def extract_gt_tables(self, row, **kwargs): + return [{ + 'normalized_bbox': [0, 0, 1, 1], + 'html': row['html'] + }] + + def extract_gemini_tables(self, row, image, **kwargs): + return gemini_table_rec(image, prompt=prompt_with_header) + + def fix_table_html(self, marker_table: str, author='marker'): + if author == 'gemini': + gemini_table = marker_table.replace("\n", " ") + gemini_table = gemini_table.replace("
", " ") + return gemini_table + + marker_table_soup = BeautifulSoup(marker_table, 'html.parser') + # Synthtabnet uses thead and tbody tags + # Marker uses th tags + thead = marker_table_soup.new_tag('thead') + tbody = marker_table_soup.new_tag('tbody') + in_thead = True + + for tr in marker_table_soup.find_all('tr'): + if in_thead and all(th_tag.name == 'th' for th_tag in tr.find_all()): + thead.append(tr) + else: + in_thead = False + tbody.append(tr) + + # create anew + marker_table_soup.clear() + marker_table_soup.append(thead) + marker_table_soup.append(tbody) + + for th_tag in marker_table_soup.find_all('th'): + th_tag.name = 'td' + marker_table_html = str(marker_table_soup) + marker_table_html = marker_table_html.replace("
", " ") # Fintabnet uses spaces instead of newlines + marker_table_html = marker_table_html.replace("\n", " ") + return marker_table_html + + def construct_row_result(self, row, gt_table, marker_table, gemini_table, **kwargs): + return { + "filename": row['filename'], + "dataset_variant": row.get('dataset_variant', row.get('dataset')), + "marker_table": marker_table, + "gt_table": gt_table, + "gemini_table": gemini_table + } \ No newline at end of file diff --git a/benchmarks/table/table.py b/benchmarks/table/table.py index 4e674c28..f0bb0db9 100644 --- a/benchmarks/table/table.py +++ b/benchmarks/table/table.py @@ -14,7 +14,8 @@ from concurrent.futures import ProcessPoolExecutor from marker.settings import settings -from benchmarks.table.inference import inference_tables +from benchmarks.table.inference import FinTabNetBenchmark +from benchmarks.table.synthtabnet import SynthTabNetBenchmark from scoring import wrap_table_html, similarity_eval_html @@ -28,14 +29,16 @@ def update_teds_score(result, prefix: str = "marker"): @click.command(help="Benchmark Table to HTML Conversion") @click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.") -@click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use") +@click.option("--against", type=str, default="fintabnet", help="Dataset to use. Options: fintabnet, synthtabnet") +@click.option("--dataset", type=str, default=None, help="Huggingface dataset to use") @click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process") @click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use") @click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.") @click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.") -@click.option("--use_gemini", is_flag=True, help="Evaluate Gemini for table recognition.") +@click.option("--use_gemini", is_flag=True, help="Evaluate Gemini alone for table recognition.") def main( result_path: str, + against: str, dataset: str, max_rows: int, max_workers: int, @@ -43,16 +46,47 @@ def main( table_rec_batch_size: int | None, use_gemini: bool = False ): + return _process( + result_path, + against, + dataset, + max_rows, + max_workers, + use_llm, + table_rec_batch_size, + use_gemini + ) + +def _process( + result_path: str, + against: str, + dataset: str, + max_rows: int, + max_workers: int, + use_llm: bool, + table_rec_batch_size: int | None, + use_gemini: bool = False +): + """Permit the benchmark to be started from python.""" start = time.time() - + if dataset is None: + if against == 'synthtabnet': + dataset = 'datalab-to/synthtabnet_bench_marker' + else: + dataset = 'datalab-to/fintabnet_bench_marker' dataset = datasets.load_dataset(dataset, split='train') dataset = dataset.shuffle(seed=0) - results, total_unaligned = inference_tables(dataset, use_llm, table_rec_batch_size, max_rows, use_gemini) + if against == 'synthtabnet': + benchmark = SynthTabNetBenchmark() + else: + benchmark = FinTabNetBenchmark() + + results, total_unaligned = benchmark.inference_tables(dataset, use_llm, table_rec_batch_size, max_rows, use_gemini) print(f"Total time: {time.time() - start}.") - print(f"Could not align {total_unaligned} tables from fintabnet.") + print(f"Could not align {total_unaligned} tables from {against}.") with ProcessPoolExecutor(max_workers=max_workers) as executor: marker_results = list( @@ -69,7 +103,7 @@ def main( with ProcessPoolExecutor(max_workers=max_workers) as executor: gemini_results = list( tqdm( - executor.map(update_teds_score, results, repeat("gemini")), desc='Computing Gemini scores', + executor.map(update_teds_score, marker_results, repeat("gemini")), desc='Computing Gemini scores', # append gemini results total=len(results) ) ) @@ -81,9 +115,10 @@ def main( print(table) print("Avg score computed by comparing marker predicted HTML with original HTML") + # gemini_results will contain both marker and gemini scores + final_results = gemini_results if use_gemini else marker_results results = { - "marker": marker_results, - "gemini": gemini_results + "marker": final_results, } out_path = Path(result_path) diff --git a/marker/providers/image.py b/marker/providers/image.py index a9f0c03a..1d094b1e 100644 --- a/marker/providers/image.py +++ b/marker/providers/image.py @@ -46,4 +46,7 @@ def get_page_lines(self, idx: int) -> List[Line]: return self.page_lines[idx] def get_page_refs(self, idx: int) -> List[Reference]: - return [] \ No newline at end of file + return [] + + def __exit__(self, exc_type, exc_val, exc_tb): + pass # images need to survive \ No newline at end of file diff --git a/marker/schema/groups/page.py b/marker/schema/groups/page.py index 9f033051..0c5b4228 100644 --- a/marker/schema/groups/page.py +++ b/marker/schema/groups/page.py @@ -235,16 +235,32 @@ def merge_blocks( if block.block_type not in self.excluded_block_types ] - max_intersections = self.compute_line_block_intersections(valid_blocks, provider_outputs) + # Process Figures and FigureGroup in the 2nd pass, so that text is preferentially assigned to text blocks + figures_pass = [ + i for i, block in enumerate(valid_blocks) + if block.block_type in [BlockTypes.Figure, BlockTypes.FigureGroup] + ] + main_pass = [ + i for i in range(len(valid_blocks)) + if i not in figures_pass + ] # Try to assign lines by intersection - assigned_line_idxs = set() + total_max_intersections = {} # provider_line_idx -> (intersection_area: np.float64, block_id) block_lines = defaultdict(list) - for line_idx, provider_output in enumerate(provider_outputs): - if line_idx in max_intersections: - block_id = max_intersections[line_idx][1] - block_lines[block_id].append((line_idx, provider_output)) - assigned_line_idxs.add(line_idx) + for block_subset_indices in [main_pass, figures_pass]: + block_subset = [valid_blocks[i] for i in block_subset_indices] + max_intersections = self.compute_line_block_intersections(block_subset, provider_outputs) + + # remove already assigned in previous passes + max_intersections = {k: v for k, v in max_intersections.items() if k not in total_max_intersections} + for line_idx, provider_output in enumerate(provider_outputs): + if line_idx in max_intersections: + block_id = max_intersections[line_idx][1] + block_lines[block_id].append((line_idx, provider_output)) + total_max_intersections.update(max_intersections) + assigned_line_idxs = set(total_max_intersections.keys()) + # If no intersection, assign by distance for line_idx in set(provider_line_idxs).difference(assigned_line_idxs): diff --git a/tests/builders/test_overriding.py b/tests/builders/test_overriding.py index 8c960448..22a09e8f 100644 --- a/tests/builders/test_overriding.py +++ b/tests/builders/test_overriding.py @@ -32,17 +32,29 @@ def get_lines(pdf: str, config=None): for block_type, block_cls in config["override_map"].items(): register_block_class(block_type, block_cls) - provider: PdfProvider = setup_pdf_provider(pdf, config) + # provider: PdfProvider = setup_pdf_provider(pdf, config) + provider = PdfProvider(pdf, config) return provider.get_page_lines(0) +@pytest.fixture(scope="function") +@pytest.mark.filename("adversarial.pdf") +def adversarial(temp_pdf): + return temp_pdf -def test_overriding_mp(): +@pytest.fixture(scope="function") +@pytest.mark.filename("adversarial_rot.pdf") +def adversarial_rot(temp_pdf): + return temp_pdf + + +def test_overriding_mp(adversarial, adversarial_rot): config = { "page_range": [0], "override_map": {BlockTypes.Line: NewLine} } - pdf_list = ["adversarial.pdf", "adversarial_rot.pdf"] + # use temp files managed by pytest fixtures + pdf_list = [adversarial.name, adversarial_rot.name] with mp.Pool(processes=2) as pool: results = pool.starmap(get_lines, [(pdf, config) for pdf in pdf_list]) diff --git a/tests/conftest.py b/tests/conftest.py index e4c083c7..005e78f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import os +import uuid from marker.providers.pdf import PdfProvider import tempfile from typing import Dict, Type @@ -86,10 +88,13 @@ def temp_doc(request, pdf_dataset): idx = pdf_dataset['filename'].index(filename) suffix = filename.split(".")[-1] - temp_pdf = tempfile.NamedTemporaryFile(suffix=f".{suffix}") - temp_pdf.write(pdf_dataset['pdf'][idx]) - temp_pdf.flush() - yield temp_pdf + with tempfile.TemporaryDirectory() as temp_dir: + temp_pdf_path = os.path.join(temp_dir, f"temp.{suffix}") # Randomized filename + + with open(temp_pdf_path, "wb") as temp_pdf: + temp_pdf.write(pdf_dataset['pdf'][idx]) + temp_pdf.flush() + yield temp_pdf @pytest.fixture(scope="function") @@ -148,7 +153,9 @@ def temp_image(): img = Image.new("RGB", (512, 512), color="white") draw = ImageDraw.Draw(img) draw.text((10, 10), "Hello, World!", fill="black", font_size=24) - with tempfile.NamedTemporaryFile(suffix=".png") as f: - img.save(f.name) - f.flush() - yield f + + with tempfile.TemporaryDirectory() as temp_dir: + temp_png_path = os.path.join(temp_dir, f"{uuid.uuid4()}.png") # Randomized filename + img.save(temp_png_path) + with open(temp_png_path, "rb") as f: + yield f diff --git a/tests/schema/groups/test_block_merge.py b/tests/schema/groups/test_block_merge.py new file mode 100644 index 00000000..aca94e60 --- /dev/null +++ b/tests/schema/groups/test_block_merge.py @@ -0,0 +1,73 @@ +from surya.layout.schema import LayoutResult + +from marker.builders.document import DocumentBuilder +from marker.builders.layout import LayoutBuilder +from marker.builders.line import LineBuilder +from marker.providers import ProviderOutput +import marker.schema.blocks +from marker.schema.groups.page import PageGroup +from marker.schema.polygon import PolygonBox +from tests.utils import convert_to_provider_output + + +def test_block_assignment(): + page_bbox = [0, 0, 500, 500] + blocks = [ + (100, 0, 200, 500, marker.schema.blocks.Figure), + (200, 0, 300, 500, marker.schema.blocks.Text), + ] + words = [ + (110, 0, 120, 10, "fig"), + (110, 10, 120, 20, "fig"), + (110, 20, 120, 30, "fig"), + + (220, 0, 230, 10, "intext"), + (220, 10, 230, 20, "intext"), + (220, 20, 230, 30, "intext"), + + (150, 30, 250, 40, "both"), + (150, 40, 250, 50, "both"), + (150, 50, 250, 60, "both"), + + (180, 60, 280, 70, "mosttext"), + (180, 70, 280, 80, "mosttext"), + (180, 80, 280, 90, "mosttext"), + + (120, 90, 220, 100, "mostfig"), + (120, 100, 220, 110, "mostfig"), + (120, 110, 220, 120, "mostfig"), + ] + block_ctr = 0 + + def get_counter(): + nonlocal block_ctr + o = block_ctr + block_ctr += 1 + return o + + page_group = PageGroup( + polygon=PolygonBox.from_bbox(page_bbox), + children=[ + block_cls( + polygon=PolygonBox.from_bbox([xmin, ymin, xmax, ymax]), + page_id=0, + block_id=get_counter(), + ) + for xmin, ymin, xmax, ymax, block_cls in blocks + ], + ) + + provider_outputs = [ + convert_to_provider_output( + [word], page_bbox=[0, 0, 500, 500], get_counter=get_counter + ) + for word in words + ] + + assert not page_group.children[0].structure, "figure's structure should begin with nothing in it" + assert not page_group.children[1].structure, "text's structure should begin with nothing in it" + + page_group.merge_blocks(provider_outputs, text_extraction_method="custom") + + assert len(page_group.children[0].structure) == 3, "figure should have just 3 words" + assert len(page_group.children[1].structure) == 12, "text should have the remaining 12 words" diff --git a/tests/utils.py b/tests/utils.py index e5b577b1..34684264 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,15 @@ +from typing import Callable +import marker.providers from marker.providers.pdf import PdfProvider +from marker.schema.polygon import PolygonBox +from marker.schema.text.span import Span +from marker.schema.text.line import Line import tempfile import datasets +import pdftext.schema + def setup_pdf_provider( filename='adversarial.pdf', @@ -17,3 +24,138 @@ def setup_pdf_provider( provider = PdfProvider(temp_pdf.name, config) return provider + + +def convert_to_pdftext( + word_bboxes: list[tuple[float, float, float, float, str]], + page_bbox: tuple, + page_number: int, +): + """Converts word bboxes (xmin, ymin, xmax, ymax, text) into a pdftext Page object""" + blocks = [] + block_lines = [] + + for x0, y0, x1, y1, text in word_bboxes: + word_bbox = pdftext.schema.Bbox(bbox=[x0, y0, x1, y1]) + + # Create Char entries (assuming each character has uniform bbox) + chars = [] + char_width = (x1 - x0) / len(text) + for i, char in enumerate(text): + char_bbox = pdftext.schema.Bbox( + bbox=[x0 + i * char_width, y0, x0 + (i + 1) * char_width, y1] + ) + chars.append( + pdftext.schema.Char( + bbox=char_bbox, + char=char, + rotation=0, + font={"name": "DefaultFont"}, + char_idx=i, + ) + ) + + span = pdftext.schema.Span( + bbox=word_bbox, + text=text, + font={"name": "DefaultFont"}, + chars=chars, + char_start_idx=0, + char_end_idx=len(text) - 1, + rotation=0, + url="", + ) + + line = pdftext.schema.Line(spans=[span], bbox=word_bbox, rotation=0) + + block_lines.append(line) + + block = pdftext.schema.Block(lines=block_lines, bbox=page_bbox, rotation=0) + blocks.append(block) + + page = pdftext.schema.Page( + page=page_number, + bbox=page_bbox, + width=page_bbox[2] - page_bbox[0], + height=page_bbox[3] - page_bbox[1], + blocks=blocks, + rotation=0, + refs=[], + ) + + return page + + +_block_counter = 0 + + +def convert_to_provider_output( + word_bboxes: list[tuple[float, float, float, float, str]], + page_bbox: tuple, + get_counter: Callable[[], int] = None, + page_number: int = 0, + populate_chars=False, +): + """Converts word bboxes (xmin, ymin, xmax, ymax, text) into a marker.providers.ProviderOutput object""" + + if get_counter is None: + + def get_counter(): + global _block_counter + o = _block_counter + _block_counter += 1 + return o + + all_spans = [] + all_chars = [] + min_x = page_bbox[2] + max_x = page_bbox[0] + min_y = page_bbox[3] + max_y = page_bbox[1] + for x0, y0, x1, y1, text in word_bboxes: + word_bbox = PolygonBox.from_bbox([x0, y0, x1, y1]) + + # Create Char entries (assuming each character has uniform bbox) + if populate_chars: + chars = [] + char_width = (x1 - x0) / len(text) + for i, char in enumerate(text): + char_bbox = PolygonBox.from_bbox( + [x0 + i * char_width, y0, x0 + (i + 1) * char_width, y1] + ) + chars.append( + marker.providers.Char(char=char, polygon=char_bbox, char_idx=i) + ) + + span = Span( + polygon=word_bbox, + text=text, + font="DefaultFont", + font_weight=1.0, + font_size=12.0, + minimum_position=0, + maximum_position=len(text) - 1, + formats=["plain"], + page_id=page_number, + block_id=get_counter(), + ) + all_spans.append(span) + if populate_chars: + all_chars.append(chars) + + min_x = min(min_x, x0) + max_x = max(max_x, x1) + min_y = min(min_y, y0) + max_y = max(max_y, y1) + + # line is union of bboxes + line = Line( + spans=[span], + polygon=PolygonBox.from_bbox([min_x, min_y, max_x, max_y]), + page_id=page_number, + block_id=get_counter(), + ) + + return marker.providers.ProviderOutput( + line=line, spans=all_spans, chars=all_chars if populate_chars else None + )