Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions marker/schema/groups/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,32 @@ def merge_blocks(
if block.block_type not in self.excluded_block_types
]

max_intersections = self.compute_line_block_intersections(valid_blocks, provider_outputs)
# Process Figures and FigureGroup in the 2nd pass, so that text is preferentially assigned to text blocks
figures_pass = [
i for i, block in enumerate(valid_blocks)
if block.block_type in [BlockTypes.Figure, BlockTypes.FigureGroup]
]
main_pass = [
i for i in range(len(valid_blocks))
if i not in figures_pass
]

# Try to assign lines by intersection
assigned_line_idxs = set()
total_max_intersections = {} # provider_line_idx -> (intersection_area: np.float64, block_id)
block_lines = defaultdict(list)
for line_idx, provider_output in enumerate(provider_outputs):
if line_idx in max_intersections:
block_id = max_intersections[line_idx][1]
block_lines[block_id].append((line_idx, provider_output))
assigned_line_idxs.add(line_idx)
for block_subset_indices in [main_pass, figures_pass]:
block_subset = [valid_blocks[i] for i in block_subset_indices]
max_intersections = self.compute_line_block_intersections(block_subset, provider_outputs)

# remove already assigned in previous passes
max_intersections = {k: v for k, v in max_intersections.items() if k not in total_max_intersections}
for line_idx, provider_output in enumerate(provider_outputs):
if line_idx in max_intersections:
block_id = max_intersections[line_idx][1]
block_lines[block_id].append((line_idx, provider_output))
total_max_intersections.update(max_intersections)
assigned_line_idxs = set(total_max_intersections.keys())


# If no intersection, assign by distance
for line_idx in set(provider_line_idxs).difference(assigned_line_idxs):
Expand Down
18 changes: 15 additions & 3 deletions tests/builders/test_overriding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,29 @@ def get_lines(pdf: str, config=None):
for block_type, block_cls in config["override_map"].items():
register_block_class(block_type, block_cls)

provider: PdfProvider = setup_pdf_provider(pdf, config)
# provider: PdfProvider = setup_pdf_provider(pdf, config)
provider = PdfProvider(pdf, config)
return provider.get_page_lines(0)

@pytest.fixture(scope="function")
@pytest.mark.filename("adversarial.pdf")
def adversarial(temp_pdf):
return temp_pdf

def test_overriding_mp():
@pytest.fixture(scope="function")
@pytest.mark.filename("adversarial_rot.pdf")
def adversarial_rot(temp_pdf):
return temp_pdf


def test_overriding_mp(adversarial, adversarial_rot):
config = {
"page_range": [0],
"override_map": {BlockTypes.Line: NewLine}
}

pdf_list = ["adversarial.pdf", "adversarial_rot.pdf"]
# use temp files managed by pytest fixtures
pdf_list = [adversarial.name, adversarial_rot.name]

with mp.Pool(processes=2) as pool:
results = pool.starmap(get_lines, [(pdf, config) for pdf in pdf_list])
Expand Down
23 changes: 15 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import uuid
from marker.providers.pdf import PdfProvider
import tempfile
from typing import Dict, Type
Expand Down Expand Up @@ -86,10 +88,13 @@ def temp_doc(request, pdf_dataset):
idx = pdf_dataset['filename'].index(filename)
suffix = filename.split(".")[-1]

temp_pdf = tempfile.NamedTemporaryFile(suffix=f".{suffix}")
temp_pdf.write(pdf_dataset['pdf'][idx])
temp_pdf.flush()
yield temp_pdf
with tempfile.TemporaryDirectory() as temp_dir:
temp_pdf_path = os.path.join(temp_dir, f"temp.{suffix}") # Randomized filename

with open(temp_pdf_path, "wb") as temp_pdf:
temp_pdf.write(pdf_dataset['pdf'][idx])
temp_pdf.flush()
yield temp_pdf


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -148,7 +153,9 @@ def temp_image():
img = Image.new("RGB", (512, 512), color="white")
draw = ImageDraw.Draw(img)
draw.text((10, 10), "Hello, World!", fill="black")
with tempfile.NamedTemporaryFile(suffix=".png") as f:
img.save(f.name)
f.flush()
yield f

with tempfile.TemporaryDirectory() as temp_dir:
temp_png_path = os.path.join(temp_dir, f"{uuid.uuid4()}.png") # Randomized filename
img.save(temp_png_path)
with open(temp_png_path, "rb") as f:
yield f
66 changes: 66 additions & 0 deletions tests/schema/groups/test_block_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from surya.layout.schema import LayoutResult

from marker.builders.document import DocumentBuilder
from marker.builders.layout import LayoutBuilder
from marker.builders.line import LineBuilder
from marker.providers import ProviderOutput
import marker.schema.blocks
from marker.schema.groups.page import PageGroup
from marker.schema.polygon import PolygonBox
from tests.utils import convert_to_provider_output


def test_block_assignment():
page_bbox = [0, 0, 500, 500]
blocks = [
(100, 0, 200, 500, marker.schema.blocks.Figure),
(200, 0, 300, 500, marker.schema.blocks.Text)
]
words = [
(110, 0, 120, 10, "fig"),
(110, 10, 120, 20, "fig"),
(110, 20, 120, 30, "fig"),

(220, 0, 230, 10, "intext"),
(220, 10, 230, 20, "intext"),
(220, 20, 230, 30, "intext"),

(150, 30, 250, 40, "both"),
(150, 40, 250, 50, "both"),
(150, 50, 250, 60, "both"),

(180, 60, 280, 70, "mosttext"),
(180, 70, 280, 80, "mosttext"),
(180, 80, 280, 90, "mosttext"),

(120, 90, 220, 100, "mostfig"),
(120, 100, 220, 110, "mostfig"),
(120, 110, 220, 120, "mostfig"),
]
block_ctr = 0
def get_counter():
nonlocal block_ctr
o = block_ctr
block_ctr += 1
return o
page_group = PageGroup(polygon=PolygonBox.from_bbox(page_bbox),
children=[
block_cls(polygon=PolygonBox.from_bbox([xmin, ymin, xmax, ymax]),
page_id=0,
block_id=get_counter())
for xmin, ymin, xmax, ymax, block_cls in blocks
])

provider_outputs = [
convert_to_provider_output([word], page_bbox=[0, 0, 500, 500], get_counter=get_counter) for word in words
]


assert not page_group.children[0].structure, "figure's structure should begin with nothing in it"
assert not page_group.children[1].structure, "text's structure should begin with nothing in it"

page_group.merge_blocks(provider_outputs, text_extraction_method='custom')

assert len(page_group.children[0].structure) == 3, "figure should have just 3 words"
assert len(page_group.children[1].structure) == 12, "text should have the remaining 12 words"

139 changes: 139 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import marker.providers
from marker.providers.pdf import PdfProvider
from marker.schema.polygon import PolygonBox
from marker.schema.text.span import Span as MarkerSpan
from marker.schema.text.line import Line as MarkerLine
import tempfile

import datasets

import pdftext.schema

def setup_pdf_provider(
filename='adversarial.pdf',
Expand All @@ -17,3 +22,137 @@ def setup_pdf_provider(

provider = PdfProvider(temp_pdf.name, config)
return provider





def convert_to_pdftext(word_bboxes: list[tuple[float, float, float, float, str]], page_bbox: tuple, page_number: int):
"""Converts word bboxes (xmin, ymin, xmax, ymax) and text into a pdftext Page object"""
blocks = []
block_lines = []

for x0, y0, x1, y1, text in word_bboxes:
word_bbox = pdftext.schema.Bbox(bbox=[x0, y0, x1, y1])

# Create Char entries (assuming each character has uniform bbox)
chars = []
char_width = (x1 - x0) / len(text)
for i, char in enumerate(text):
char_bbox = pdftext.schema.Bbox(bbox=[x0 + i * char_width, y0, x0 + (i + 1) * char_width, y1])
chars.append(pdftext.schema.Char(
bbox=char_bbox,
char=char,
rotation=0,
font={"name": "DefaultFont"},
char_idx=i
))

span = pdftext.schema.Span(
bbox=word_bbox,
text=text,
font={"name": "DefaultFont"},
chars=chars,
char_start_idx=0,
char_end_idx=len(text) - 1,
rotation=0,
url=""
)

line = pdftext.schema.Line(
spans=[span],
bbox=word_bbox,
rotation=0
)

block_lines.append(line)

block = pdftext.schema.Block(
lines=block_lines,
bbox=page_bbox,
rotation=0
)
blocks.append(block)

page = pdftext.schema.Page(
page=page_number,
bbox=page_bbox,
width=page_bbox[2] - page_bbox[0],
height=page_bbox[3] - page_bbox[1],
blocks=blocks,
rotation=0,
refs=[]
)

return page

_block_counter = 0
def convert_to_provider_output(word_bboxes: list[tuple[float, float, float, float, str]], page_bbox: tuple, page_number: int=0, populate_chars=False,
get_counter=None):
"""Converts word bboxes (xmin, ymin, xmax, ymax, text) and text into a marker.providers.ProviderOutput object"""

if get_counter is None:
def get_counter():
global _block_counter
o = _block_counter
_block_counter += 1
return o

all_spans = []
all_chars = []
min_x = page_bbox[2]
max_x = page_bbox[0]
min_y = page_bbox[3]
max_y = page_bbox[1]
for x0, y0, x1, y1, text in word_bboxes:
word_bbox = PolygonBox.from_bbox([x0, y0, x1, y1])

# Create Char entries (assuming each character has uniform bbox)
if populate_chars:
chars = []
char_width = (x1 - x0) / len(text)
for i, char in enumerate(text):
char_bbox = PolygonBox.from_bbox([x0 + i * char_width, y0, x0 + (i + 1) * char_width, y1])
chars.append(marker.providers.Char(
char=char,
polygon=char_bbox,
char_idx=i
))

span = MarkerSpan(
polygon=word_bbox,
text=text,
font="DefaultFont",
font_weight=1.0,
font_size=12.0,
minimum_position=0,
maximum_position=len(text) - 1,
formats=['plain'],
page_id=page_number,
block_id=get_counter(),
# rotation=0,
)
all_spans.append(span)
if populate_chars:
all_chars.append(chars)

min_x = min(min_x, x0)
max_x = max(max_x, x1)
min_y = min(min_y, y0)
max_y = max(max_y, y1)

# line is union of bboxes
line = MarkerLine(
spans=[span],
polygon=PolygonBox.from_bbox([min_x, min_y, max_x, max_y]),
page_id=page_number,
block_id=get_counter(),
# rotation=0
)


return marker.providers.ProviderOutput(
line=line,
spans=all_spans,
chars=all_chars if populate_chars else None
)