diff --git a/benchmarks/table/gemini.py b/benchmarks/table/gemini.py
index 5832a90f..560c681e 100644
--- a/benchmarks/table/gemini.py
+++ b/benchmarks/table/gemini.py
@@ -22,10 +22,16 @@
3. Output only the HTML for the table, starting with the <table> tag.
""".strip()
+prompt_with_header = prompt.replace(
+ "Only use , , and | tags. Only use the colspan and rowspan attributes if necessary. Do not use |
, , or tags.",
+ "Only use , , , , and | tags. Only use the colspan and rowspan attributes if necessary. Use | and , but do not use .", 1
+)
+
+
class TableSchema(BaseModel):
table_html: str
-def gemini_table_rec(image: Image.Image):
+def gemini_table_rec(image: Image.Image, prompt=prompt):
client = genai.Client(
api_key=settings.GOOGLE_API_KEY,
http_options={"timeout": 60000}
diff --git a/benchmarks/table/inference.py b/benchmarks/table/inference.py
index 07e5e92c..d504d0d5 100644
--- a/benchmarks/table/inference.py
+++ b/benchmarks/table/inference.py
@@ -1,3 +1,5 @@
+import os
from typing import List
import numpy as np
@@ -17,166 +19,203 @@
from marker.schema.polygon import PolygonBox
from marker.util import matrix_intersection_area
-
-def extract_tables(children: List[JSONBlockOutput]):
- tables = []
- for child in children:
- if child.block_type == 'Table':
- tables.append(child)
- elif child.children:
- tables.extend(extract_tables(child.children))
- return tables
-
-def fix_table_html(table_html: str) -> str:
- marker_table_soup = BeautifulSoup(table_html, 'html.parser')
- tbody = marker_table_soup.find('tbody')
- if tbody:
- tbody.unwrap()
- for th_tag in marker_table_soup.find_all('th'):
- th_tag.name = 'td'
- for br_tag in marker_table_soup.find_all('br'):
- br_tag.replace_with(marker_table_soup.new_string(''))
-
- marker_table_html = str(marker_table_soup)
- marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
- return marker_table_html
-
-
-def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool):
- models = create_model_dict()
- config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True})
- total_unaligned = 0
- results = []
-
- iterations = len(dataset)
- if max_rows is not None:
- iterations = min(max_rows, len(dataset))
-
- for i in tqdm(range(iterations), desc='Converting Tables'):
- try:
- row = dataset[i]
- pdf_binary = base64.b64decode(row['pdf'])
- gt_tables = row['tables'] # Already sorted by reading order, which is what marker returns
-
- # Only use the basic table processors
- converter = TableConverter(
- config=config_parser.generate_config_dict(),
- artifact_dict=models,
- processor_list=[
- "marker.processors.table.TableProcessor",
- "marker.processors.llm.llm_table.LLMTableProcessor",
- ],
- renderer=config_parser.get_renderer()
- )
-
- with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file:
+class FinTabNetBenchmark:
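+    """Benchmark harness for FinTabNet: convert each pdf with marker, align the
+    predicted tables with the ground-truth tables by bbox overlap, and optionally
+    collect Gemini predictions for the same table crops."""
+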
+ def extract_tables_from_json(self, children: List[JSONBlockOutput]) -> List[JSONBlockOutput]:
+ tables = []
+ for child in children:
+ if child.block_type == 'Table':
+ tables.append(child)
+ elif child.children:
+ tables.extend(self.extract_tables_from_json(child.children))
+ return tables
+
+ def get_converter(self, models, config_parser, **kwargs):
+ # Only use the basic table processors
+ converter = TableConverter(
+ config=config_parser.generate_config_dict(),
+ artifact_dict=models,
+ processor_list=[
+ "marker.processors.table.TableProcessor",
+ "marker.processors.llm.llm_table.LLMTableProcessor",
+ ],
+ renderer=config_parser.get_renderer()
+ )
+ return converter
+
+ def extract_tables_from_doc(self, converter, row):
+ """Extract table and images from pdf; produce marker_json and page_image"""
+
+ pdf_binary = base64.b64decode(row['pdf'])
+
+ # https://stackoverflow.com/a/23212515
+ with tempfile.TemporaryDirectory() as temp_dir:
+ temp_filepath = os.path.join(temp_dir, 'temp.pdf')
+ with open(temp_filepath, 'wb') as temp_pdf_file:
temp_pdf_file.write(pdf_binary)
- temp_pdf_file.seek(0)
- marker_json = converter(temp_pdf_file.name).children
-
- doc = pdfium.PdfDocument(temp_pdf_file.name)
- page_image = doc[0].render(scale=96/72).to_pil()
- doc.close()
-
- if len(marker_json) == 0 or len(gt_tables) == 0:
- print(f'No tables detected, skipping...')
- total_unaligned += len(gt_tables)
- continue
-
- marker_tables = extract_tables(marker_json)
- marker_table_boxes = [table.bbox for table in marker_tables]
- page_bbox = marker_json[0].bbox
-
- if len(marker_tables) != len(gt_tables):
- print(f'Number of tables do not match, skipping...')
- total_unaligned += len(gt_tables)
- continue
-
- table_images = [
- page_image.crop(
- PolygonBox.from_bbox(bbox)
- .rescale(
- (page_bbox[2], page_bbox[3]), (page_image.width, page_image.height)
- ).bbox
- )
- for bbox
- in marker_table_boxes
- ]
-
- # Normalize the bboxes
- for bbox in marker_table_boxes:
- bbox[0] = bbox[0] / page_bbox[2]
- bbox[1] = bbox[1] / page_bbox[3]
- bbox[2] = bbox[2] / page_bbox[2]
- bbox[3] = bbox[3] / page_bbox[3]
-
- gt_boxes = [table['normalized_bbox'] for table in gt_tables]
- gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes]
- marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes]
- table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes)
-
- aligned_tables = []
- used_tables = set()
- unaligned_tables = set()
- for table_idx, alignment in enumerate(table_alignments):
- try:
- max_area = np.max(alignment)
- aligned_idx = np.argmax(alignment)
- except ValueError:
- # No alignment found
- unaligned_tables.add(table_idx)
- continue
-
- if max_area <= .01:
- # No alignment found
- unaligned_tables.add(table_idx)
+ temp_pdf_file.flush()
+
+ marker_json = converter(temp_pdf_file.name).children
+
+ doc = pdfium.PdfDocument(temp_pdf_file.name)
+ page_image = doc[0].render(scale=96/72).to_pil()
+ doc.close()
+ return marker_json, page_image
+
+ def extract_gt_tables(self, row, **kwargs):
+ return row['tables']
+
+ def extract_gemini_tables(self, row, image, **kwargs):
+ return gemini_table_rec(image)
+
+ def fix_table_html(self, table_html: str, author='marker') -> str:
+        # marker wraps the table in <tbody> which fintabnet data doesn't
+ # Fintabnet doesn't use th tags, need to be replaced for fair comparison
+ marker_table_soup = BeautifulSoup(table_html, 'html.parser')
+ tbody = marker_table_soup.find('tbody')
+ if tbody:
+ tbody.unwrap()
+ for th_tag in marker_table_soup.find_all('th'):
+ th_tag.name = 'td'
+ for br_tag in marker_table_soup.find_all('br'):
+ br_tag.replace_with(marker_table_soup.new_string(''))
+
+ marker_table_html = str(marker_table_soup)
+ marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
+ return marker_table_html
+
+ def construct_row_result(self, row, gt_table, marker_table, gemini_table, **kwargs):
+ return {
+ "marker_table": marker_table,
+ "gt_table": gt_table,
+ "gemini_table": gemini_table
+ }
+
+ def _load_marker_output(self, i, row):
+ """Opportunity to load marker output from a previous run"""
+ return None, None
+
+ def _cache_marker_output(self, i, row, marker_json, page_image):
+ """Opportunity to save marker output, to allow resuming if interrupted"""
+ pass
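+
+    # A subclass can persist marker output between runs to allow resuming; a
+    # minimal sketch (the pickle-based cache below is illustrative, not part of
+    # the benchmark):
+    #
+    #     import pickle
+    #
+    #     class ResumableBenchmark(FinTabNetBenchmark):
+    #         def _load_marker_output(self, i, row):
+    #             try:
+    #                 with open(f"cache_{i}.pkl", "rb") as f:
+    #                     return pickle.load(f)
+    #             except FileNotFoundError:
+    #                 return None, None
+    #
+    #         def _cache_marker_output(self, i, row, marker_json, page_image):
+    #             with open(f"cache_{i}.pkl", "wb") as f:
+    #                 pickle.dump((marker_json, page_image), f)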
+
+ def inference_tables(self, dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool):
+
+ total_unaligned = 0
+ results = []
+
+ iterations = len(dataset)
+ if max_rows is not None:
+ iterations = min(max_rows, len(dataset))
+
+ models = create_model_dict()
+ config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm,
+ "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True})
+ converter = self.get_converter(models, config_parser)
+ for i in tqdm(range(iterations), desc='Converting Tables'):
+ try:
+ row = dataset[i]
+
+ # save progress while running
+ marker_json, page_image = self._load_marker_output(i, row)
+ if marker_json is None:
+ marker_json, page_image = self.extract_tables_from_doc(converter, row)
+ self._cache_marker_output(i, row, marker_json, page_image)
+
+ gt_tables = self.extract_gt_tables(row) # Already sorted by reading order, which is what marker returns
+
+ if len(marker_json) == 0 or len(gt_tables) == 0:
+                    print('No tables detected, skipping...')
+                    total_unaligned += len(gt_tables)
continue
- if aligned_idx in used_tables:
- # Marker table already aligned with another gt table
- unaligned_tables.add(table_idx)
- continue
+ marker_tables = self.extract_tables_from_json(marker_json)
+ marker_table_boxes = [table.bbox for table in marker_tables]
+ page_bbox = marker_json[0].bbox
- # Gt table doesn't align well with any marker table
- gt_table_pct = gt_areas[table_idx] / max_area
- if not .85 < gt_table_pct < 1.15:
- unaligned_tables.add(table_idx)
+ if len(marker_tables) != len(gt_tables):
+                    print('Number of tables does not match, skipping...')
+                    total_unaligned += len(gt_tables)
continue
- # Marker table doesn't align with gt table
- marker_table_pct = marker_areas[aligned_idx] / max_area
- if not .85 < marker_table_pct < 1.15:
- unaligned_tables.add(table_idx)
- continue
-
- gemini_html = ""
- if use_gemini:
+ table_images = [
+ page_image.crop(
+ PolygonBox.from_bbox(bbox)
+ .rescale(
+ (page_bbox[2], page_bbox[3]), (page_image.width, page_image.height)
+ ).bbox
+ )
+ for bbox
+ in marker_table_boxes
+ ]
+
+ # Normalize the bboxes
+ for bbox in marker_table_boxes:
+ bbox[0] = bbox[0] / page_bbox[2]
+ bbox[1] = bbox[1] / page_bbox[3]
+ bbox[2] = bbox[2] / page_bbox[2]
+ bbox[3] = bbox[3] / page_bbox[3]
+
+ gt_boxes = [table['normalized_bbox'] for table in gt_tables]
+ gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes]
+ marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes]
+ table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes)
+
+ aligned_tables = []
+ used_tables = set()
+ unaligned_tables = set()
+ for table_idx, alignment in enumerate(table_alignments):
try:
- gemini_html = gemini_table_rec(table_images[aligned_idx])
- except Exception as e:
- print(f'Gemini failed: {e}')
-
- aligned_tables.append(
- (marker_tables[aligned_idx], gt_tables[table_idx], gemini_html)
- )
- used_tables.add(aligned_idx)
-
- total_unaligned += len(unaligned_tables)
-
- for marker_table, gt_table, gemini_table in aligned_tables:
- gt_table_html = gt_table['html']
-
-                # marker wraps the table in <tbody> which fintabnet data doesn't
- # Fintabnet doesn't use th tags, need to be replaced for fair comparison
- marker_table_html = fix_table_html(marker_table.html)
- gemini_table_html = fix_table_html(gemini_table)
-
- results.append({
- "marker_table": marker_table_html,
- "gt_table": gt_table_html,
- "gemini_table": gemini_table_html
- })
- except pdfium.PdfiumError:
- print('Broken PDF, Skipping...')
- continue
- return results, total_unaligned
\ No newline at end of file
+ max_area = np.max(alignment)
+ aligned_idx = np.argmax(alignment)
+ except ValueError:
+ # No alignment found
+ unaligned_tables.add(table_idx)
+ continue
+
+ if max_area <= .01:
+ # No alignment found
+ unaligned_tables.add(table_idx)
+ continue
+
+ if aligned_idx in used_tables:
+ # Marker table already aligned with another gt table
+ unaligned_tables.add(table_idx)
+ continue
+
+ # Gt table doesn't align well with any marker table
+ gt_table_pct = gt_areas[table_idx] / max_area
+ if not .85 < gt_table_pct < 1.15:
+ unaligned_tables.add(table_idx)
+ continue
+
+ # Marker table doesn't align with gt table
+ marker_table_pct = marker_areas[aligned_idx] / max_area
+ if not .85 < marker_table_pct < 1.15:
+ unaligned_tables.add(table_idx)
+ continue
+
+ gemini_html = ""
+ if use_gemini:
+ try:
+ gemini_html = self.extract_gemini_tables(row, table_images[aligned_idx])
+ except Exception as e:
+ print(f'Gemini failed: {e}')
+
+ aligned_tables.append(
+ (marker_tables[aligned_idx], gt_tables[table_idx], gemini_html)
+ )
+ used_tables.add(aligned_idx)
+
+ total_unaligned += len(unaligned_tables)
+
+ for marker_table, gt_table, gemini_table in aligned_tables:
+ gt_table_html = gt_table['html']
+
+ marker_table_html = self.fix_table_html(marker_table.html, author='marker')
+ gemini_table_html = self.fix_table_html(gemini_table, author='gemini')
+
+ results.append(self.construct_row_result(row, gt_table_html, marker_table_html, gemini_table_html))
+ except pdfium.PdfiumError:
+ print('Broken PDF, Skipping...')
+ continue
+ return results, total_unaligned
\ No newline at end of file
diff --git a/benchmarks/table/scoring.py b/benchmarks/table/scoring.py
index 940bd6e4..9b33ab85 100644
--- a/benchmarks/table/scoring.py
+++ b/benchmarks/table/scoring.py
@@ -9,6 +9,9 @@
from collections import deque
def wrap_table_html(table_html:str)->str:
+    if not table_html.startswith('<table>'):
+        table_html = f'<table>{table_html}</table>'
+
    return f'<html><body>{table_html}</body></html>'
class TableTree(Tree):
diff --git a/benchmarks/table/synthtabnet.py b/benchmarks/table/synthtabnet.py
new file mode 100644
index 00000000..9c15120d
--- /dev/null
+++ b/benchmarks/table/synthtabnet.py
@@ -0,0 +1,262 @@
+"""
+Detect tables when the pdf is not available
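+(SynthTabNet provides page images rather than pdfs).
+
+Run through the shared benchmark entry point, e.g.:
+
+    python benchmarks/table/table.py --against synthtabnet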
+"""
+
+
+from functools import cache
+import io
+import os
+import tempfile
+from typing import Tuple
+
+from bs4 import BeautifulSoup
+from tqdm import tqdm
+from benchmarks.table.gemini import gemini_table_rec, prompt_with_header
+from benchmarks.table.inference import FinTabNetBenchmark
+from marker.builders.document import DocumentBuilder
+from marker.processors import BaseProcessor
+from marker.processors.llm.llm_complex import LLMComplexRegionProcessor
+from marker.processors.llm.llm_form import LLMFormProcessor
+from marker.processors.llm.llm_table import LLMTableProcessor
+from marker.processors.llm.llm_table_merge import LLMTableMergeProcessor
+from marker.processors.table import TableProcessor
+from marker.converters.table import TableConverter
+
+import pdftext.schema
+from surya.detection import DetectionPredictor
+from surya.recognition import RecognitionPredictor
+from surya.table_rec import TableRecPredictor
+
+from marker.providers.registry import provider_from_filepath
+from marker.renderers.json import JSONBlockOutput
+from tests.utils import convert_to_pdftext
+
+
+class GroundTruthPagesForcer:
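+    """Callable that returns pre-built pdftext pages, letting the table processor
+    consume ground-truth word boxes instead of extracting text itself."""
+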
+ def __init__(self):
+        self.forced_pages = None  # list of pdftext Pages; set per document before conversion
+ def __call__(self, filepath):
+ return self.forced_pages
+
+class ChangedTableProcessor(TableProcessor):
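+    """TableProcessor that sources table cell text from the forced ground-truth
+    pages instead of running pdftext extraction or OCR."""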
+
+ def __init__(
+ self,
+ detection_model: DetectionPredictor,
+ recognition_model: RecognitionPredictor,
+ table_rec_model: TableRecPredictor,
+ config=None,
+ _gt_forcer: GroundTruthPagesForcer=None,
+ ):
+ super().__init__(
+ detection_model=detection_model,
+ recognition_model=recognition_model,
+ table_rec_model=table_rec_model,
+ config=config
+ )
+ self.detection_model = detection_model
+ self.recognition_model = recognition_model
+ self.table_rec_model = table_rec_model
+ self._gt_forcer = _gt_forcer
+
+ def assign_pdftext_lines(self, extract_blocks: list[dict], filepath):
+ forced_pages = self._gt_forcer(filepath)
+ self.assign_forced_lines(extract_blocks, forced_pages)
+
+ def assign_ocr_lines(self, ocr_blocks):
+ forced_pages = self._gt_forcer(None)
+ self.assign_forced_lines(ocr_blocks, forced_pages)
+
+ def assign_forced_lines(self, extract_blocks: list[dict], forced_pages):
+ table_inputs = []
+ unique_pages = list(set([t["page_id"] for t in extract_blocks]))
+ if len(unique_pages) == 0:
+ return
+
+ for page in unique_pages:
+ tables = []
+ img_size = None
+ for block in extract_blocks:
+ if block["page_id"] == page:
+ tables.append(block["table_bbox"])
+ img_size = block["img_size"]
+
+ table_inputs.append({
+ "tables": tables,
+ "img_size": img_size
+ })
+
+        # NOTE: local stand-in for the pdftext table_output call, reading cell
+        # text from the forced pages
+        def table_output(filepath, table_inputs, page_range=unique_pages):
+            from pdftext.extraction import table_cell_text
+
+            # pick the forced pages corresponding to this page range
+            pages = [forced_pages[i] for i in page_range]
+
+            out_tables = []
+            for page, table_input in zip(pages, table_inputs):
+                tables = table_cell_text(table_input["tables"], page, table_input["img_size"])
+                assert len(tables) == len(table_input["tables"]), "Number of tables and table inputs must match"
+                out_tables.append(tables)
+            return out_tables
+
+        cell_text = table_output(None, table_inputs, page_range=unique_pages)
+
+ assert len(cell_text) == len(unique_pages), "Number of pages and table inputs must match"
+
+ for pidx, (page_tables, pnum) in enumerate(zip(cell_text, unique_pages)):
+ table_idx = 0
+ for block in extract_blocks:
+ if block["page_id"] == pnum:
+ block["table_text_lines"] = page_tables[table_idx]
+ table_idx += 1
+ assert table_idx == len(page_tables), "Number of tables and table inputs must match"
+
+
+class ChangedTableConverter(TableConverter):
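+    """TableConverter variant that swaps in ChangedTableProcessor and skips the
+    line and OCR builders, since ground-truth text is injected directly."""
+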
+ default_processors: Tuple[BaseProcessor, ...] = (
+ ChangedTableProcessor, # NOTE: changed this line
+ LLMTableProcessor,
+ LLMTableMergeProcessor,
+ LLMFormProcessor,
+ LLMComplexRegionProcessor,
+ )
+
+ @cache
+ def build_document(self, filepath: str):
+ provider_cls = provider_from_filepath(filepath)
+ layout_builder = self.resolve_dependencies(self.layout_builder_class)
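+        # replace the line and OCR builders with no-ops; cell text is injected
+        # later by ChangedTableProcessor from the forced ground-truth pages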
+ line_builder = lambda *args, **kwargs: None
+ ocr_builder = lambda *args, **kwargs: None
+ document_builder = DocumentBuilder(self.config)
+ document_builder.disable_ocr = True
+ with provider_cls(filepath, self.config) as provider:
+ document = document_builder(provider, layout_builder, line_builder, ocr_builder)
+
+ for page in document.pages:
+ page.structure = [p for p in page.structure if p.block_type in self.converter_block_types]
+
+ for processor in self.processor_list:
+ processor(document)
+
+ return document
+
+class SynthTabNetBenchmark(FinTabNetBenchmark):
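+    """FinTabNet benchmark adapted to SynthTabNet, where each row carries a page
+    image plus ground-truth words and bboxes instead of a pdf; the words are
+    forced into the pipeline in place of pdf text extraction."""
+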
+ gt_forcer: GroundTruthPagesForcer
+
+ def get_converter(self, models, config_parser, **kwargs):
+ config_parser.cli_options['force_layout_block'] = 'Table'
+ config_parser.cli_options['disable_tqdm'] = True
+ config_parser.cli_options['disable_ocr'] = True
+ config_parser.cli_options['document_ocr_threshold'] = 0.0
+ self.gt_forcer = GroundTruthPagesForcer()
+
+ return ChangedTableConverter(
+ config={
+ **config_parser.generate_config_dict(),
+ "document_ocr_threshold": 0
+ # never perform OCR for evaluation: we know the ground truth
+ },
+ artifact_dict={
+ '_gt_forcer': self.gt_forcer,
+ **models
+ },
+ processor_list=[
+ "marker.processors.table.TableProcessor",
+ "marker.processors.llm.llm_table.LLMTableProcessor",
+ ],
+ renderer=config_parser.get_renderer()
+ )
+
+ def synthtabnet_page_with_gt_words(self, row, page_image):
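+        """Build a pdftext Page from the row's ground-truth word bboxes"""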
+ bboxes = row['word_bboxes']
+ words = row['words']
+ good = [i for i, word in enumerate(words) if word] # allow only non-empty words
+ bboxes = [(*bboxes[i], words[i]) for i in good]
+
+ image_bbox = [0, 0, page_image.width, page_image.height]
+ return convert_to_pdftext(bboxes, image_bbox, 0)
+
+ def extract_tables_from_doc(self, converter, row):
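+        """Convert the page image (no pdf exists); ground-truth words reach the processor via gt_forcer"""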
+        original_tqdm = tqdm.__init__
+
+        def disabled_tqdm(*args, **kwargs):
+            kwargs['disable'] = True  # force progress bars off during conversion
+            return original_tqdm(*args, **kwargs)
+
+ # https://stackoverflow.com/a/23212515
+ with tempfile.TemporaryDirectory() as temp_dir:
+
+ bytesio = io.BytesIO()
+ page_image = row['image'] # PIL.Image.Image
+ page_image.save(bytesio, format="PNG")
+
+ temp_filepath = os.path.join(temp_dir, 'temp.png')
+ with open(temp_filepath, 'wb') as temp_png_file:
+ temp_png_file.write(bytesio.getvalue())
+ temp_png_file.flush()
+
+ self.gt_forcer.forced_pages = [
+ self.synthtabnet_page_with_gt_words(row, page_image)
+ ]
+
+            tqdm.__init__ = disabled_tqdm  # silence marker's internal progress bars
+            try:
+                marker_json = converter(temp_png_file.name).children  # word bboxes are ingested by way of "gt_forcer"
+            finally:
+                tqdm.__init__ = original_tqdm  # restore even if conversion fails
+
+ return marker_json, page_image
+
+ def extract_gt_tables(self, row, **kwargs):
+ return [{
+ 'normalized_bbox': [0, 0, 1, 1],
+ 'html': row['html']
+ }]
+
+ def extract_gemini_tables(self, row, image, **kwargs):
+ return gemini_table_rec(image, prompt=prompt_with_header)
+
+    def fix_table_html(self, table_html: str, author='marker') -> str:
+        if author == 'gemini':
+            gemini_table = table_html.replace("\n", " ")
+            gemini_table = gemini_table.replace("&nbsp;", " ")  # collapse non-breaking spaces
+            return gemini_table
+
+        marker_table_soup = BeautifulSoup(table_html, 'html.parser')
+ # Synthtabnet uses thead and tbody tags
+ # Marker uses th tags
+ thead = marker_table_soup.new_tag('thead')
+ tbody = marker_table_soup.new_tag('tbody')
+ in_thead = True
+
+ for tr in marker_table_soup.find_all('tr'):
+ if in_thead and all(th_tag.name == 'th' for th_tag in tr.find_all()):
+ thead.append(tr)
+ else:
+ in_thead = False
+ tbody.append(tr)
+
+ # create anew
+ marker_table_soup.clear()
+ marker_table_soup.append(thead)
+ marker_table_soup.append(tbody)
+
+ for th_tag in marker_table_soup.find_all('th'):
+ th_tag.name = 'td'
+ marker_table_html = str(marker_table_soup)
+ marker_table_html = marker_table_html.replace(" ", " ") # Fintabnet uses spaces instead of newlines
+ marker_table_html = marker_table_html.replace("\n", " ")
+ return marker_table_html
+
+ def construct_row_result(self, row, gt_table, marker_table, gemini_table, **kwargs):
+ return {
+ "filename": row['filename'],
+ "dataset_variant": row.get('dataset_variant', row.get('dataset')),
+ "marker_table": marker_table,
+ "gt_table": gt_table,
+ "gemini_table": gemini_table
+ }
\ No newline at end of file
diff --git a/benchmarks/table/table.py b/benchmarks/table/table.py
index 4e674c28..f0bb0db9 100644
--- a/benchmarks/table/table.py
+++ b/benchmarks/table/table.py
@@ -14,7 +14,8 @@
from concurrent.futures import ProcessPoolExecutor
from marker.settings import settings
-from benchmarks.table.inference import inference_tables
+from benchmarks.table.inference import FinTabNetBenchmark
+from benchmarks.table.synthtabnet import SynthTabNetBenchmark
from scoring import wrap_table_html, similarity_eval_html
@@ -28,14 +29,16 @@ def update_teds_score(result, prefix: str = "marker"):
@click.command(help="Benchmark Table to HTML Conversion")
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.")
-@click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use")
+@click.option("--against", type=str, default="fintabnet", help="Dataset to use. Options: fintabnet, synthtabnet")
+@click.option("--dataset", type=str, default=None, help="Huggingface dataset to use")
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
-@click.option("--use_gemini", is_flag=True, help="Evaluate Gemini for table recognition.")
+@click.option("--use_gemini", is_flag=True, help="Evaluate Gemini alone for table recognition.")
def main(
result_path: str,
+ against: str,
dataset: str,
max_rows: int,
max_workers: int,
@@ -43,16 +46,47 @@ def main(
table_rec_batch_size: int | None,
use_gemini: bool = False
):
+ return _process(
+ result_path,
+ against,
+ dataset,
+ max_rows,
+ max_workers,
+ use_llm,
+ table_rec_batch_size,
+ use_gemini
+ )
+
+def _process(
+ result_path: str,
+ against: str,
+ dataset: str,
+ max_rows: int,
+ max_workers: int,
+ use_llm: bool,
+ table_rec_batch_size: int | None,
+ use_gemini: bool = False
+):
+ """Permit the benchmark to be started from python."""
start = time.time()
-
+ if dataset is None:
+ if against == 'synthtabnet':
+ dataset = 'datalab-to/synthtabnet_bench_marker'
+ else:
+ dataset = 'datalab-to/fintabnet_bench_marker'
dataset = datasets.load_dataset(dataset, split='train')
dataset = dataset.shuffle(seed=0)
- results, total_unaligned = inference_tables(dataset, use_llm, table_rec_batch_size, max_rows, use_gemini)
+ if against == 'synthtabnet':
+ benchmark = SynthTabNetBenchmark()
+ else:
+ benchmark = FinTabNetBenchmark()
+
+ results, total_unaligned = benchmark.inference_tables(dataset, use_llm, table_rec_batch_size, max_rows, use_gemini)
print(f"Total time: {time.time() - start}.")
- print(f"Could not align {total_unaligned} tables from fintabnet.")
+ print(f"Could not align {total_unaligned} tables from {against}.")
with ProcessPoolExecutor(max_workers=max_workers) as executor:
marker_results = list(
@@ -69,7 +103,7 @@ def main(
with ProcessPoolExecutor(max_workers=max_workers) as executor:
gemini_results = list(
tqdm(
- executor.map(update_teds_score, results, repeat("gemini")), desc='Computing Gemini scores',
+                executor.map(update_teds_score, marker_results, repeat("gemini")), desc='Computing Gemini scores',  # reuse marker_results so each row carries both marker and gemini scores
total=len(results)
)
)
@@ -81,9 +115,10 @@ def main(
print(table)
print("Avg score computed by comparing marker predicted HTML with original HTML")
+ # gemini_results will contain both marker and gemini scores
+ final_results = gemini_results if use_gemini else marker_results
results = {
- "marker": marker_results,
- "gemini": gemini_results
+ "marker": final_results,
}
out_path = Path(result_path)
diff --git a/marker/providers/image.py b/marker/providers/image.py
index a9f0c03a..1d094b1e 100644
--- a/marker/providers/image.py
+++ b/marker/providers/image.py
@@ -46,4 +46,7 @@ def get_page_lines(self, idx: int) -> List[Line]:
return self.page_lines[idx]
def get_page_refs(self, idx: int) -> List[Reference]:
- return []
\ No newline at end of file
+ return []
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+        pass  # keep the underlying image open; it is still needed after the provider exits
\ No newline at end of file
diff --git a/marker/schema/groups/page.py b/marker/schema/groups/page.py
index 9f033051..0c5b4228 100644
--- a/marker/schema/groups/page.py
+++ b/marker/schema/groups/page.py
@@ -235,16 +235,32 @@ def merge_blocks(
if block.block_type not in self.excluded_block_types
]
- max_intersections = self.compute_line_block_intersections(valid_blocks, provider_outputs)
+        # Assign Figure and FigureGroup blocks in a second pass, so that text is preferentially assigned to text blocks
+ figures_pass = [
+ i for i, block in enumerate(valid_blocks)
+ if block.block_type in [BlockTypes.Figure, BlockTypes.FigureGroup]
+ ]
+ main_pass = [
+ i for i in range(len(valid_blocks))
+ if i not in figures_pass
+ ]
# Try to assign lines by intersection
- assigned_line_idxs = set()
+ total_max_intersections = {} # provider_line_idx -> (intersection_area: np.float64, block_id)
block_lines = defaultdict(list)
- for line_idx, provider_output in enumerate(provider_outputs):
- if line_idx in max_intersections:
- block_id = max_intersections[line_idx][1]
- block_lines[block_id].append((line_idx, provider_output))
- assigned_line_idxs.add(line_idx)
+ for block_subset_indices in [main_pass, figures_pass]:
+ block_subset = [valid_blocks[i] for i in block_subset_indices]
+ max_intersections = self.compute_line_block_intersections(block_subset, provider_outputs)
+
+ # remove already assigned in previous passes
+ max_intersections = {k: v for k, v in max_intersections.items() if k not in total_max_intersections}
+ for line_idx, provider_output in enumerate(provider_outputs):
+ if line_idx in max_intersections:
+ block_id = max_intersections[line_idx][1]
+ block_lines[block_id].append((line_idx, provider_output))
+ total_max_intersections.update(max_intersections)
+ assigned_line_idxs = set(total_max_intersections.keys())
+
# If no intersection, assign by distance
for line_idx in set(provider_line_idxs).difference(assigned_line_idxs):
diff --git a/tests/builders/test_overriding.py b/tests/builders/test_overriding.py
index 8c960448..22a09e8f 100644
--- a/tests/builders/test_overriding.py
+++ b/tests/builders/test_overriding.py
@@ -32,17 +32,29 @@ def get_lines(pdf: str, config=None):
for block_type, block_cls in config["override_map"].items():
register_block_class(block_type, block_cls)
- provider: PdfProvider = setup_pdf_provider(pdf, config)
+    provider = PdfProvider(pdf, config)
return provider.get_page_lines(0)
+@pytest.fixture(scope="function")
+@pytest.mark.filename("adversarial.pdf")
+def adversarial(temp_pdf):
+ return temp_pdf
-def test_overriding_mp():
+@pytest.fixture(scope="function")
+@pytest.mark.filename("adversarial_rot.pdf")
+def adversarial_rot(temp_pdf):
+ return temp_pdf
+
+
+def test_overriding_mp(adversarial, adversarial_rot):
config = {
"page_range": [0],
"override_map": {BlockTypes.Line: NewLine}
}
- pdf_list = ["adversarial.pdf", "adversarial_rot.pdf"]
+ # use temp files managed by pytest fixtures
+ pdf_list = [adversarial.name, adversarial_rot.name]
with mp.Pool(processes=2) as pool:
results = pool.starmap(get_lines, [(pdf, config) for pdf in pdf_list])
diff --git a/tests/conftest.py b/tests/conftest.py
index e4c083c7..005e78f1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,5 @@
+import os
+import uuid
from marker.providers.pdf import PdfProvider
import tempfile
from typing import Dict, Type
@@ -86,10 +88,13 @@ def temp_doc(request, pdf_dataset):
idx = pdf_dataset['filename'].index(filename)
suffix = filename.split(".")[-1]
- temp_pdf = tempfile.NamedTemporaryFile(suffix=f".{suffix}")
- temp_pdf.write(pdf_dataset['pdf'][idx])
- temp_pdf.flush()
- yield temp_pdf
+ with tempfile.TemporaryDirectory() as temp_dir:
+        temp_pdf_path = os.path.join(temp_dir, f"temp.{suffix}")  # unique per test via the random temp dir
+
+ with open(temp_pdf_path, "wb") as temp_pdf:
+ temp_pdf.write(pdf_dataset['pdf'][idx])
+ temp_pdf.flush()
+ yield temp_pdf
@pytest.fixture(scope="function")
@@ -148,7 +153,9 @@ def temp_image():
img = Image.new("RGB", (512, 512), color="white")
draw = ImageDraw.Draw(img)
draw.text((10, 10), "Hello, World!", fill="black", font_size=24)
- with tempfile.NamedTemporaryFile(suffix=".png") as f:
- img.save(f.name)
- f.flush()
- yield f
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ temp_png_path = os.path.join(temp_dir, f"{uuid.uuid4()}.png") # Randomized filename
+ img.save(temp_png_path)
+ with open(temp_png_path, "rb") as f:
+ yield f
diff --git a/tests/schema/groups/test_block_merge.py b/tests/schema/groups/test_block_merge.py
new file mode 100644
index 00000000..aca94e60
--- /dev/null
+++ b/tests/schema/groups/test_block_merge.py
@@ -0,0 +1,73 @@
+import marker.schema.blocks
+from marker.schema.groups.page import PageGroup
+from marker.schema.polygon import PolygonBox
+from tests.utils import convert_to_provider_output
+
+
+def test_block_assignment():
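+    """Words overlapping both a Figure and a Text block should be assigned to the
+    Text block: figure blocks only receive lines in the second assignment pass."""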
+ page_bbox = [0, 0, 500, 500]
+ blocks = [
+ (100, 0, 200, 500, marker.schema.blocks.Figure),
+ (200, 0, 300, 500, marker.schema.blocks.Text),
+ ]
+ words = [
+ (110, 0, 120, 10, "fig"),
+ (110, 10, 120, 20, "fig"),
+ (110, 20, 120, 30, "fig"),
+
+ (220, 0, 230, 10, "intext"),
+ (220, 10, 230, 20, "intext"),
+ (220, 20, 230, 30, "intext"),
+
+ (150, 30, 250, 40, "both"),
+ (150, 40, 250, 50, "both"),
+ (150, 50, 250, 60, "both"),
+
+ (180, 60, 280, 70, "mosttext"),
+ (180, 70, 280, 80, "mosttext"),
+ (180, 80, 280, 90, "mosttext"),
+
+ (120, 90, 220, 100, "mostfig"),
+ (120, 100, 220, 110, "mostfig"),
+ (120, 110, 220, 120, "mostfig"),
+ ]
+ block_ctr = 0
+
+ def get_counter():
+ nonlocal block_ctr
+ o = block_ctr
+ block_ctr += 1
+ return o
+
+ page_group = PageGroup(
+ polygon=PolygonBox.from_bbox(page_bbox),
+ children=[
+ block_cls(
+ polygon=PolygonBox.from_bbox([xmin, ymin, xmax, ymax]),
+ page_id=0,
+ block_id=get_counter(),
+ )
+ for xmin, ymin, xmax, ymax, block_cls in blocks
+ ],
+ )
+
+ provider_outputs = [
+ convert_to_provider_output(
+ [word], page_bbox=[0, 0, 500, 500], get_counter=get_counter
+ )
+ for word in words
+ ]
+
+ assert not page_group.children[0].structure, "figure's structure should begin with nothing in it"
+ assert not page_group.children[1].structure, "text's structure should begin with nothing in it"
+
+ page_group.merge_blocks(provider_outputs, text_extraction_method="custom")
+
+ assert len(page_group.children[0].structure) == 3, "figure should have just 3 words"
+ assert len(page_group.children[1].structure) == 12, "text should have the remaining 12 words"
diff --git a/tests/utils.py b/tests/utils.py
index e5b577b1..34684264 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,8 +1,15 @@
+from typing import Callable
+import marker.providers
from marker.providers.pdf import PdfProvider
+from marker.schema.polygon import PolygonBox
+from marker.schema.text.span import Span
+from marker.schema.text.line import Line
import tempfile
import datasets
+import pdftext.schema
+
def setup_pdf_provider(
filename='adversarial.pdf',
@@ -17,3 +24,138 @@ def setup_pdf_provider(
provider = PdfProvider(temp_pdf.name, config)
return provider
+
+
+def convert_to_pdftext(
+ word_bboxes: list[tuple[float, float, float, float, str]],
+ page_bbox: tuple,
+ page_number: int,
+):
+ """Converts word bboxes (xmin, ymin, xmax, ymax, text) into a pdftext Page object"""
+ blocks = []
+ block_lines = []
+
+ for x0, y0, x1, y1, text in word_bboxes:
+ word_bbox = pdftext.schema.Bbox(bbox=[x0, y0, x1, y1])
+
+ # Create Char entries (assuming each character has uniform bbox)
+ chars = []
+ char_width = (x1 - x0) / len(text)
+ for i, char in enumerate(text):
+ char_bbox = pdftext.schema.Bbox(
+ bbox=[x0 + i * char_width, y0, x0 + (i + 1) * char_width, y1]
+ )
+ chars.append(
+ pdftext.schema.Char(
+ bbox=char_bbox,
+ char=char,
+ rotation=0,
+ font={"name": "DefaultFont"},
+ char_idx=i,
+ )
+ )
+
+ span = pdftext.schema.Span(
+ bbox=word_bbox,
+ text=text,
+ font={"name": "DefaultFont"},
+ chars=chars,
+ char_start_idx=0,
+ char_end_idx=len(text) - 1,
+ rotation=0,
+ url="",
+ )
+
+ line = pdftext.schema.Line(spans=[span], bbox=word_bbox, rotation=0)
+
+ block_lines.append(line)
+
+ block = pdftext.schema.Block(lines=block_lines, bbox=page_bbox, rotation=0)
+ blocks.append(block)
+
+ page = pdftext.schema.Page(
+ page=page_number,
+ bbox=page_bbox,
+ width=page_bbox[2] - page_bbox[0],
+ height=page_bbox[3] - page_bbox[1],
+ blocks=blocks,
+ rotation=0,
+ refs=[],
+ )
+
+ return page
+
+
+_block_counter = 0
+
+
+def convert_to_provider_output(
+ word_bboxes: list[tuple[float, float, float, float, str]],
+ page_bbox: tuple,
+ get_counter: Callable[[], int] = None,
+ page_number: int = 0,
+ populate_chars=False,
+):
+ """Converts word bboxes (xmin, ymin, xmax, ymax, text) into a marker.providers.ProviderOutput object"""
+
+ if get_counter is None:
+
+ def get_counter():
+ global _block_counter
+ o = _block_counter
+ _block_counter += 1
+ return o
+
+ all_spans = []
+ all_chars = []
+ min_x = page_bbox[2]
+ max_x = page_bbox[0]
+ min_y = page_bbox[3]
+ max_y = page_bbox[1]
+ for x0, y0, x1, y1, text in word_bboxes:
+ word_bbox = PolygonBox.from_bbox([x0, y0, x1, y1])
+
+ # Create Char entries (assuming each character has uniform bbox)
+ if populate_chars:
+ chars = []
+ char_width = (x1 - x0) / len(text)
+ for i, char in enumerate(text):
+ char_bbox = PolygonBox.from_bbox(
+ [x0 + i * char_width, y0, x0 + (i + 1) * char_width, y1]
+ )
+ chars.append(
+ marker.providers.Char(char=char, polygon=char_bbox, char_idx=i)
+ )
+
+ span = Span(
+ polygon=word_bbox,
+ text=text,
+ font="DefaultFont",
+ font_weight=1.0,
+ font_size=12.0,
+ minimum_position=0,
+ maximum_position=len(text) - 1,
+ formats=["plain"],
+ page_id=page_number,
+ block_id=get_counter(),
+ )
+ all_spans.append(span)
+ if populate_chars:
+ all_chars.append(chars)
+
+ min_x = min(min_x, x0)
+ max_x = max(max_x, x1)
+ min_y = min(min_y, y0)
+ max_y = max(max_y, y1)
+
+ # line is union of bboxes
+ line = Line(
+        spans=all_spans,
+ polygon=PolygonBox.from_bbox([min_x, min_y, max_x, max_y]),
+ page_id=page_number,
+ block_id=get_counter(),
+ )
+
+ return marker.providers.ProviderOutput(
+ line=line, spans=all_spans, chars=all_chars if populate_chars else None
+ )
|