Skip to content

Commit 1eb2037

Browse files
committed
refactor: improve text line sorting algorithm with row grouping and center alignment
- Refactored `sort_text_lines` to sort text lines in reading order (top-to-bottom, left-to-right) with a more robust method.
- Introduced helper functions to:
  - Compute the vertical center of a line (`_center_y`)
  - Group lines into rows: a line joins the current row when its vertical center is within the tolerance of the row's first line (`_group_lines_into_rows`)
  - Sort the lines within each row horizontally and flatten the result (`_sort_and_flatten_rows`)
- Added `_calculate_row_tolerance`, which derives the row tolerance dynamically as the median line height multiplied by a tolerance factor
- Defined the `DEFAULT_ROW_TOLERANCE_FACTOR` constant (0.5) for configurable grouping sensitivity
- Improved overall code style, spacing, and docstrings for clarity and maintainability
1 parent 8b0969a commit 1eb2037

File tree

1 file changed

+100
-34
lines changed

1 file changed

+100
-34
lines changed

marker/util.py

Lines changed: 100 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,28 @@
1010

1111
from marker.schema.polygon import PolygonBox
1212
from marker.settings import settings
13+
from statistics import median
1314

15+
DEFAULT_ROW_TOLERANCE_FACTOR = 0.5
1416
OPENING_TAG_REGEX = re.compile(r"<((?:math|i|b))(?:\s+[^>]*)?>")
1517
CLOSING_TAG_REGEX = re.compile(r"</((?:math|i|b))>")
1618
TAG_MAPPING = {
17-
'i': 'italic',
18-
'b': 'bold',
19-
'math': 'math',
20-
'mark': 'highlight',
21-
'sub': 'subscript',
22-
'sup': 'superscript',
23-
'small': 'small',
24-
'u': 'underline',
25-
'code': 'code'
19+
"i": "italic",
20+
"b": "bold",
21+
"math": "math",
22+
"mark": "highlight",
23+
"sub": "subscript",
24+
"sup": "superscript",
25+
"small": "small",
26+
"u": "underline",
27+
"code": "code",
2628
}
2729

30+
2831
def strings_to_classes(items: List[str]) -> List[type]:
2932
classes = []
3033
for item in items:
31-
module_name, class_name = item.rsplit('.', 1)
34+
module_name, class_name = item.rsplit(".", 1)
3235
module = import_module(module_name)
3336
classes.append(getattr(module, class_name))
3437
return classes
@@ -52,7 +55,9 @@ def verify_config_keys(obj):
5255
if value is None:
5356
none_vals += f"{attr_name}, "
5457

55-
assert len(none_vals) == 0, f"In order to use {obj.__class__.__name__}, you must set the configuration values `{none_vals}`."
58+
assert len(none_vals) == 0, (
59+
f"In order to use {obj.__class__.__name__}, you must set the configuration values `{none_vals}`."
60+
)
5661

5762

5863
def assign_config(cls, config: BaseModel | dict | None):
@@ -122,7 +127,7 @@ def matrix_distance(boxes1: List[List[float]], boxes2: List[List[float]]) -> np.
122127
boxes1 = np.array(boxes1) # Shape: (N, 4)
123128
boxes2 = np.array(boxes2) # Shape: (M, 4)
124129

125-
boxes1_centers = (boxes1[:, :2] + boxes1[:, 2:]) / 2 # Shape: (M, 2)
130+
boxes1_centers = (boxes1[:, :2] + boxes1[:, 2:]) / 2 # Shape: (N, 2)
126131
boxes2_centers = (boxes2[:, :2] + boxes2[:, 2:]) / 2 # Shape: (M, 2)
127132

128133
boxes1_centers = boxes1_centers[:, np.newaxis, :] # Shape: (N, 1, 2)
@@ -132,67 +137,128 @@ def matrix_distance(boxes1: List[List[float]], boxes2: List[List[float]]) -> np.
132137
return distances
133138

134139

135-
def sort_text_lines(lines: List[PolygonBox], tolerance=1.25):
136-
# Sorts in reading order. Not 100% accurate, this should only
137-
# be used as a starting point for more advanced sorting.
138-
vertical_groups = {}
139-
for line in lines:
140-
group_key = round(line.bbox[1] / tolerance) * tolerance
141-
if group_key not in vertical_groups:
142-
vertical_groups[group_key] = []
143-
vertical_groups[group_key].append(line)
140+
def sort_text_lines(lines: List[PolygonBox]) -> List[PolygonBox]:
141+
"""
142+
Sort text lines in reading order (top-to-bottom, left-to-right).
143+
144+
Groups lines into rows based on vertical proximity, then sorts each row horizontally.
145+
146+
Args:
147+
lines: List of PolygonBox objects representing text lines
148+
149+
Returns:
150+
List of PolygonBox objects sorted in reading order
151+
"""
152+
if not lines:
153+
return []
154+
155+
# Calculate row grouping tolerance based on median line height
156+
row_tolerance = _calculate_row_tolerance(lines)
157+
158+
# Group lines into rows based on vertical position
159+
rows = _group_lines_into_rows(lines, row_tolerance)
160+
161+
# Sort each row horizontally and flatten
162+
return _sort_and_flatten_rows(rows)
163+
164+
165+
def _calculate_row_tolerance(lines: List[PolygonBox]) -> float:
166+
"""Calculate vertical tolerance for grouping lines into rows."""
167+
line_heights = [line.height for line in lines]
168+
return median(line_heights) * DEFAULT_ROW_TOLERANCE_FACTOR
169+
170+
171+
def _center_y(line: PolygonBox) -> float:
172+
"""Calculate the vertical center of a line."""
173+
return (line.bbox[1] + line.bbox[3]) / 2 if line.bbox else 0.0
174+
175+
176+
def _group_lines_into_rows(lines: List[PolygonBox], tolerance: float) -> List[List[PolygonBox]]:
177+
"""Group lines into rows based on vertical proximity."""
178+
# Sort lines by vertical position
179+
sorted_lines = sorted(lines, key=lambda line: _center_y(line))
180+
181+
if not sorted_lines:
182+
return []
144183

145-
# Sort each group horizontally and flatten the groups into a single list
146-
sorted_lines = []
147-
for _, group in sorted(vertical_groups.items()):
148-
sorted_group = sorted(group, key=lambda x: x.bbox[0])
149-
sorted_lines.extend(sorted_group)
184+
rows: List[List[PolygonBox]] = []
185+
current_row: List[PolygonBox] = [sorted_lines[0]]
186+
current_row_y: float = _center_y(sorted_lines[0])
187+
188+
for line in sorted_lines[1:]:
189+
if abs(_center_y(line) - current_row_y) <= tolerance:
190+
# Line belongs to current row
191+
current_row.append(line)
192+
else:
193+
# Start new row
194+
rows.append(current_row)
195+
current_row = [line]
196+
current_row_y = _center_y(line)
197+
198+
# The last row
199+
if current_row:
200+
rows.append(current_row)
201+
202+
return rows
203+
204+
205+
def _sort_and_flatten_rows(rows: List[List[PolygonBox]]) -> List[PolygonBox]:
206+
"""Sort each row horizontally and flatten into a single list."""
207+
sorted_lines: List[PolygonBox] = []
208+
209+
for row in rows:
210+
# Sort row by horizontal position (left to right)
211+
sorted_row: List[PolygonBox] = sorted(row, key=lambda line: line.bbox[0])
212+
sorted_lines.extend(sorted_row)
150213

151214
return sorted_lines
152215

216+
153217
def download_font():
154218
if not os.path.exists(settings.FONT_PATH):
155219
os.makedirs(os.path.dirname(settings.FONT_PATH), exist_ok=True)
156220
font_dl_path = f"{settings.ARTIFACT_URL}/{settings.FONT_NAME}"
157-
with requests.get(font_dl_path, stream=True) as r, open(settings.FONT_PATH, 'wb') as f:
221+
with requests.get(font_dl_path, stream=True) as r, open(settings.FONT_PATH, "wb") as f:
158222
r.raise_for_status()
159223
for chunk in r.iter_content(chunk_size=8192):
160224
f.write(chunk)
161225

226+
162227
def get_opening_tag_type(tag):
163228
"""
164229
Determines if a tag is an opening tag and extracts the tag type.
165-
230+
166231
Args:
167232
tag (str): The tag string to analyze.
168233
169234
Returns:
170235
tuple: (is_opening_tag (bool), tag_type (str or None))
171236
"""
172237
match = OPENING_TAG_REGEX.match(tag)
173-
238+
174239
if match:
175240
tag_type = match.group(1)
176241
if tag_type in TAG_MAPPING:
177242
return True, TAG_MAPPING[tag_type]
178-
243+
179244
return False, None
180245

246+
181247
def get_closing_tag_type(tag):
182248
"""
183249
Determines if a tag is a closing tag and extracts the tag type.
184-
250+
185251
Args:
186252
tag (str): The tag string to analyze.
187253
188254
Returns:
189255
tuple: (is_closing_tag (bool), tag_type (str or None))
190256
"""
191257
match = CLOSING_TAG_REGEX.match(tag)
192-
258+
193259
if match:
194260
tag_type = match.group(1)
195261
if tag_type in TAG_MAPPING:
196262
return True, TAG_MAPPING[tag_type]
197-
198-
return False, None
263+
264+
return False, None

0 commit comments

Comments
 (0)