From 000d8afb0b511d7db7f7d816b078d1f7c7af2d74 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Sat, 28 Feb 2026 14:10:25 +0800 Subject: [PATCH 01/38] Remove support for OCR without layout analysis --- README.md | 3 - README_zh.md | 3 - agent.md | 2 - examples/ollama-deploy/README.md | 3 - glmocr/api.py | 12 +-- glmocr/cli.py | 10 +-- glmocr/config.py | 13 +-- glmocr/config.yaml | 11 +-- glmocr/pipeline/pipeline.py | 137 +++---------------------------- glmocr/tests/test_unit.py | 41 --------- 10 files changed, 20 insertions(+), 215 deletions(-) diff --git a/README.md b/README.md index a840518..aa925da 100644 --- a/README.md +++ b/README.md @@ -247,9 +247,6 @@ pipeline: # Result formatting result_formatter: output_format: both # json, markdown, or both - - # Layout detection (optional) - enable_layout: false ``` See [config.yaml](glmocr/config.yaml) for all options. diff --git a/README_zh.md b/README_zh.md index 8fe64f0..9b2c5dd 100644 --- a/README_zh.md +++ b/README_zh.md @@ -248,9 +248,6 @@ pipeline: # Result formatting result_formatter: output_format: both # json, markdown, or both - - # Layout detection (optional) - enable_layout: false ``` 更多选项请参考 [config.yaml](glmocr/config.yaml)。 diff --git a/agent.md b/agent.md index e375a98..5b6a95a 100644 --- a/agent.md +++ b/agent.md @@ -70,7 +70,6 @@ or in a `.env` file anywhere in the working-directory ancestry. | `GLMOCR_OCR_API_HOST` | `pipeline.ocr_api.api_host` | `localhost` | | `GLMOCR_OCR_API_PORT` | `pipeline.ocr_api.api_port` | `5002` | | `GLMOCR_OCR_MODEL` | `pipeline.ocr_api.model` | `glm-ocr-model` | -| `GLMOCR_ENABLE_LAYOUT` | `pipeline.enable_layout` | `true` / `false` | | `GLMOCR_LOG_LEVEL` | `logging.level` | `DEBUG`, `INFO`, `WARNING`, `ERROR` | ### `.env` File Auto-Loading @@ -102,7 +101,6 @@ with **higher priority**. | `model` | `str` | Model name. | | `mode` | `str` | `"maas"` or `"selfhosted"`. | | `timeout` | `int` | Request timeout in seconds. | -| `enable_layout` | `bool` | Enable layout detection. | | `log_level` | `str` | Logging level. | --- diff --git a/examples/ollama-deploy/README.md b/examples/ollama-deploy/README.md index dafde9b..1f9bae1 100644 --- a/examples/ollama-deploy/README.md +++ b/examples/ollama-deploy/README.md @@ -61,8 +61,6 @@ pipeline: api_path: /api/generate # Use Ollama native endpoint model: glm-ocr:latest # Required: specify model name api_mode: ollama_generate # Required: use Ollama native format - - enable_layout: false # Recommended for initial testing ``` ### Configuration Options Explained @@ -70,7 +68,6 @@ pipeline: - **api_path**: `/api/generate` - Ollama's native endpoint (more stable for vision) - **model**: `glm-ocr:latest` - Model name (required by Ollama) - **api_mode**: `ollama_generate` - Enables Ollama-specific request/response format -- **enable_layout**: `false` - Disable layout detection if dependencies not installed ## Usage diff --git a/glmocr/api.py b/glmocr/api.py index 9ec8257..48fad1e 100644 --- a/glmocr/api.py +++ b/glmocr/api.py @@ -69,7 +69,6 @@ def __init__( model: Optional[str] = None, mode: Optional[str] = None, timeout: Optional[int] = None, - enable_layout: Optional[bool] = None, log_level: Optional[str] = None, # Extra knobs for self-hosted mode & GPU binding ocr_api_host: Optional[str] = None, @@ -90,7 +89,6 @@ def __init__( If *api_key* is provided without an explicit *mode*, mode defaults to ``"maas"``. timeout: Request timeout in seconds. - enable_layout: Whether to run layout detection. log_level: Logging level (DEBUG, INFO, WARNING, ERROR). """ # If user provides api_key but no explicit mode, default to MaaS. @@ -105,7 +103,6 @@ def __init__( model=model, mode=mode, timeout=timeout, - enable_layout=enable_layout, log_level=log_level, ocr_api_host=ocr_api_host, ocr_api_port=ocr_api_port, @@ -128,14 +125,12 @@ def __init__( self._maas_client = MaaSClient(self.config_model.pipeline.maas) self._maas_client.start() - self.enable_layout = True # MaaS always includes layout logger.info("GLM-OCR initialized in MaaS mode (cloud API passthrough)") else: # Self-hosted mode: use full Pipeline from glmocr.pipeline import Pipeline self._pipeline = Pipeline(config=self.config_model.pipeline) - self.enable_layout = self._pipeline.enable_layout self._pipeline.start() logger.info("GLM-OCR initialized in self-hosted mode") @@ -439,7 +434,7 @@ def _parse_selfhosted( request_data = {"messages": messages} layout_vis_dir = None - if self._pipeline.enable_layout and save_layout_visualization: + if save_layout_visualization: layout_vis_dir = tempfile.mkdtemp(prefix="layout_vis_") results = list( @@ -475,7 +470,7 @@ def _stream_parse_selfhosted( request_data = {"messages": messages} layout_vis_dir = None - if self._pipeline.enable_layout and save_layout_visualization: + if save_layout_visualization: layout_vis_dir = tempfile.mkdtemp(prefix="layout_vis_") for result in self._pipeline.process( @@ -596,7 +591,6 @@ def parse( model: Optional[str] = None, mode: Optional[str] = None, timeout: Optional[int] = None, - enable_layout: Optional[bool] = None, log_level: Optional[str] = None, **kwargs: Any, ) -> Union[PipelineResult, List[PipelineResult], Generator[PipelineResult, None, None]]: @@ -637,7 +631,6 @@ def parse( model: Model name. mode: ``"maas"`` or ``"selfhosted"``. timeout: Request timeout in seconds. - enable_layout: Whether to run layout detection. log_level: Logging level. Returns: @@ -661,7 +654,6 @@ def parse( model=model, mode=mode, timeout=timeout, - enable_layout=enable_layout, log_level=log_level, ) as parser: return parser.parse( diff --git a/glmocr/cli.py b/glmocr/cli.py index 6828036..5168be7 100644 --- a/glmocr/cli.py +++ b/glmocr/cli.py @@ -66,9 +66,6 @@ def main(): # Parse all images in a directory glmocr parse ./images/ - # Disable layout detection (OCR-only): set pipeline.enable_layout=false - glmocr parse image.png --config my_config.yaml - # Specify output directory glmocr parse image.png --output ./output/ @@ -99,7 +96,7 @@ def main(): parse_parser.add_argument( "--no-layout-vis", action="store_true", - help="Do not save layout visualization results (only effective when enable_layout=true)", + help="Do not save layout visualization results", ) parse_parser.add_argument( "--config", @@ -145,10 +142,7 @@ def main(): save_layout_vis = not args.no_layout_vis with GlmOcr(config_path=args.config) as glm_parser: - logger.info( - "Using Pipeline (enable_layout=%s)...", - "true" if glm_parser.enable_layout else "false", - ) + logger.info("Using Pipeline...") # Process each file (parse() with str returns a single PipelineResult) total_files = len(image_paths) diff --git a/glmocr/config.py b/glmocr/config.py index 386021a..dcd9d18 100644 --- a/glmocr/config.py +++ b/glmocr/config.py @@ -44,8 +44,6 @@ def _find_dotenv(start: Optional[Path] = None) -> Optional[Path]: "OCR_API_HOST": "pipeline.ocr_api.api_host", "OCR_API_PORT": "pipeline.ocr_api.api_port", "OCR_MODEL": "pipeline.ocr_api.model", - # Layout - "ENABLE_LAYOUT": "pipeline.enable_layout", # Allow overriding which GPU(s) the layout model uses "LAYOUT_CUDA_VISIBLE_DEVICES": "pipeline.layout.cuda_visible_devices", # Logging @@ -192,8 +190,6 @@ class LayoutConfig(_BaseConfig): class PipelineConfig(_BaseConfig): - enable_layout: bool = False - # MaaS mode configuration (Zhipu cloud API passthrough) maas: MaaSApiConfig = Field(default_factory=MaaSApiConfig) @@ -224,11 +220,8 @@ def _set_nested(data: Dict[str, Any], dotted_path: str, value: Any) -> None: def _coerce_env_value(dotted_path: str, raw: str) -> Any: """Coerce a raw environment-variable string to the expected Python type.""" # Boolean fields - if dotted_path in ("pipeline.maas.enabled", "pipeline.enable_layout"): - # Special handling for MODE: "maas" → True, anything else → False - if dotted_path == "pipeline.maas.enabled": - return raw.strip().lower() in ("maas", "true", "1", "yes") - return raw.strip().lower() in ("true", "1", "yes") + if dotted_path == "pipeline.maas.enabled": + return raw.strip().lower() in ("maas", "true", "1", "yes") # Integer fields if dotted_path.endswith((".api_port", ".request_timeout", ".connect_timeout")): return int(raw) @@ -317,7 +310,6 @@ def from_env( * ``model`` – model name * ``mode`` – ``"maas"`` or ``"selfhosted"`` * ``timeout`` – request timeout in seconds - * ``enable_layout`` – whether to run layout detection * ``log_level`` – logging level (DEBUG / INFO / …) Any other keyword is silently ignored so that callers can safely @@ -360,7 +352,6 @@ def from_env( "model": "pipeline.maas.model", "mode": "pipeline.maas.enabled", "timeout": "pipeline.maas.request_timeout", - "enable_layout": "pipeline.enable_layout", "log_level": "logging.level", # Self-hosted OCR API "ocr_api_host": "pipeline.ocr_api.api_host", diff --git a/glmocr/config.yaml b/glmocr/config.yaml index 7315810..d1921e3 100644 --- a/glmocr/config.yaml +++ b/glmocr/config.yaml @@ -147,10 +147,6 @@ pipeline: # Result formatter: post-processing and output formatting result_formatter: - # Filter nested regions (remove overlapping smaller regions) - filter_nested: true - min_overlap_ratio: 0.8 - # Output format: json, markdown, or both output_format: both @@ -178,12 +174,7 @@ pipeline: - seal - formula_number - # Enable layout detection mode - # - true: detect document regions (tables, figures, text blocks) then OCR each - # - false: direct OCR on the whole image - enable_layout: true - - # Layout detection settings (used when enable_layout=true) + # Layout detection settings layout: # PP-DocLayoutV3 model directory # Can be a local folder or a Hugging Face model id diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index eeeb7b5..9850d14 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -21,7 +21,6 @@ from glmocr.parser_result import PipelineResult from glmocr.postprocess import ResultFormatter from glmocr.utils.image_utils import crop_image_region -from glmocr.utils.image_utils import load_image_to_base64 from glmocr.utils.logging import get_logger, get_profiler if TYPE_CHECKING: @@ -90,7 +89,6 @@ def __init__( result_formatter: Optional[ResultFormatter] = None, ): self.config = config - self.enable_layout = config.enable_layout # Unified page loader self.page_loader = PageLoader(config.page_loader) @@ -104,20 +102,20 @@ def __init__( else: self.result_formatter = ResultFormatter(config.result_formatter) - # Layout detector (initialized only when enabled) - if self.enable_layout: - if layout_detector is not None: - self.layout_detector = layout_detector - else: - from glmocr.layout import PPDocLayoutDetector + # Layout detector + if layout_detector is not None: + self.layout_detector = layout_detector + else: + from glmocr.layout import PPDocLayoutDetector - if PPDocLayoutDetector is None: - from glmocr.layout import _raise_layout_import_error + if PPDocLayoutDetector is None: + from glmocr.layout import _raise_layout_import_error - _raise_layout_import_error() + _raise_layout_import_error() - self.layout_detector = PPDocLayoutDetector(config.layout) - self.max_workers = config.max_workers + self.layout_detector = PPDocLayoutDetector(config.layout) + + self.max_workers = config.max_workers self._page_maxsize = getattr(config, "page_maxsize", 100) self._region_maxsize = getattr(config, "region_maxsize", 800) @@ -172,113 +170,6 @@ def process( PipelineResult per input URL (one image or one PDF). """ - if not self.enable_layout: - image_urls = self._extract_image_urls(request_data) - if not image_urls: - request_data = self.page_loader.build_request(request_data) - response, status_code = self.ocr_client.process(request_data) - if status_code != 200: - raise Exception( - f"OCR request failed: {response}, status_code: {status_code}" - ) - content = ( - response.get("choices", [{}])[0] - .get("message", {}) - .get("content", "") - ) - json_result, markdown_result = self.result_formatter.format_ocr_result( - content - ) - yield PipelineResult( - json_result=json_result, - markdown_result=markdown_result, - original_images=[], - layout_vis_dir=layout_vis_output_dir, - ) - return - pages, unit_indices = self.page_loader.load_pages_with_unit_indices( - image_urls - ) - from copy import deepcopy - - base_request_data = deepcopy(request_data) - cleaned_messages = [] - for msg in base_request_data.get("messages", []): - if msg.get("role") != "user": - cleaned_messages.append(msg) - continue - contents = msg.get("content", []) - if isinstance(contents, list): - contents = [c for c in contents if c.get("type") != "image_url"] - cleaned_messages.append({**msg, "content": contents}) - base_request_data["messages"] = cleaned_messages - - num_units = len(image_urls) - original_inputs = [ - (url[7:] if url.startswith("file://") else url) for url in image_urls - ] - unit_contents: Dict[int, List[str]] = {u: [] for u in range(num_units)} - for page_idx, page in enumerate(pages): - u = ( - (unit_indices or [0])[page_idx] - if page_idx < len(unit_indices or []) - else 0 - ) - img_b64 = load_image_to_base64( - page, - t_patch_size=self.page_loader.t_patch_size, - max_pixels=self.page_loader.max_pixels, - image_format=self.page_loader.image_format, - patch_expand_factor=self.page_loader.patch_expand_factor, - min_pixels=self.page_loader.min_pixels, - ) - data_url = f"data:image/{self.page_loader.image_format.lower()};base64,{img_b64}" - per_request = deepcopy(base_request_data) - user_msg = None - for m in per_request.get("messages", []): - if m.get("role") == "user" and isinstance(m.get("content"), list): - user_msg = m - break - if user_msg is None: - per_request.setdefault("messages", []).append( - {"role": "user", "content": []} - ) - user_msg = per_request["messages"][-1] - user_msg["content"].append( - {"type": "image_url", "image_url": {"url": data_url}} - ) - per_request = self.page_loader.build_request(per_request) - response, status_code = self.ocr_client.process(per_request) - if status_code != 200: - raise Exception( - f"OCR request failed: {response}, status_code: {status_code}" - ) - content = ( - response.get("choices", [{}])[0] - .get("message", {}) - .get("content", "") - ) - unit_contents.setdefault(u, []).append(content) - for u in range(num_units): - contents_u = unit_contents.get(u, []) - if len(contents_u) == 1: - ( - json_result, - markdown_result, - ) = self.result_formatter.format_ocr_result(contents_u[0]) - else: - ( - json_result, - markdown_result, - ) = self.result_formatter.format_multi_page_results(contents_u) - yield PipelineResult( - json_result=json_result, - markdown_result=markdown_result, - original_images=[original_inputs[u]], - layout_vis_dir=layout_vis_output_dir, - ) - return - image_urls = self._extract_image_urls(request_data) if not image_urls: request_data = self.page_loader.build_request(request_data) @@ -715,8 +606,7 @@ def _recognize_regions(self, regions: List[tuple]) -> List[tuple]: def start(self): """Start the pipeline.""" logger.info("Starting Pipeline...") - if self.enable_layout: - self.layout_detector.start() + self.layout_detector.start() self.ocr_client.start() logger.info("Pipeline started!") @@ -724,8 +614,7 @@ def stop(self): """Stop the pipeline.""" logger.info("Stopping Pipeline...") self.ocr_client.stop() - if self.enable_layout: - self.layout_detector.stop() + self.layout_detector.stop() logger.info("Pipeline stopped!") def __enter__(self): diff --git a/glmocr/tests/test_unit.py b/glmocr/tests/test_unit.py index 00a1899..c256676 100644 --- a/glmocr/tests/test_unit.py +++ b/glmocr/tests/test_unit.py @@ -216,31 +216,6 @@ def test_parse_result_repr(self): assert "images=2" in repr(result) -class TestPipeline: - """Tests for Pipeline (without starting).""" - - def test_pipeline_init_enable_layout_default(self): - """Default enable_layout behavior (mocked).""" - from glmocr.pipeline import Pipeline - - # Use a mock to avoid heavy dependencies - with patch.object(Pipeline, "__init__", lambda self, config: None): - p = Pipeline.__new__(Pipeline) - p.config = {} - p.enable_layout = p.config.get("enable_layout", True) - assert p.enable_layout is True - - def test_pipeline_init_enable_layout_false(self): - """enable_layout can be disabled (mocked).""" - with patch("glmocr.pipeline.Pipeline.__init__", return_value=None): - from glmocr.pipeline import Pipeline - - p = Pipeline.__new__(Pipeline) - p.config = {"enable_layout": False} - p.enable_layout = p.config.get("enable_layout", True) - assert p.enable_layout is False - - class TestUtils: """Tests for utility functions.""" @@ -538,18 +513,6 @@ def test_mode_case_insensitive(self): assert _coerce_env_value("pipeline.maas.enabled", "MaaS") is True assert _coerce_env_value("pipeline.maas.enabled", "TRUE") is True - def test_enable_layout_true(self): - from glmocr.config import _coerce_env_value - - assert _coerce_env_value("pipeline.enable_layout", "1") is True - assert _coerce_env_value("pipeline.enable_layout", "yes") is True - - def test_enable_layout_false(self): - from glmocr.config import _coerce_env_value - - assert _coerce_env_value("pipeline.enable_layout", "0") is False - assert _coerce_env_value("pipeline.enable_layout", "no") is False - def test_integer_coercion(self): from glmocr.config import _coerce_env_value @@ -955,7 +918,6 @@ def _make_glmocr(self): obj._pipeline = None obj._maas_client = MagicMock() obj.config_model = MagicMock() - obj.enable_layout = True # Mock _parse_maas to return a list of one result obj._parse_maas = MagicMock(return_value=[mock_result]) @@ -1001,7 +963,6 @@ def _make_glmocr_maas(self): obj._pipeline = None obj._maas_client = MagicMock() obj.config_model = MagicMock() - obj.enable_layout = True obj._maas_response_to_pipeline_result = MagicMock( return_value=PipelineResult( json_result=[[{"content": "ok"}]], @@ -1021,7 +982,6 @@ def _make_glmocr_selfhosted(self): obj._maas_client = None obj._pipeline = MagicMock() obj.config_model = MagicMock() - obj.enable_layout = True r1 = PipelineResult( json_result=[], markdown_result="a", original_images=["a.png"] ) @@ -1135,7 +1095,6 @@ def test_explicit_selfhosted_mode(self, monkeypatch): with patch("glmocr.pipeline.Pipeline") as mock_pipeline: mock_pipeline.return_value.start = MagicMock() - mock_pipeline.return_value.enable_layout = False from glmocr.api import GlmOcr parser = GlmOcr(mode="selfhosted") From 796519a9b476160271cf9ad9ad7c4d6a92006554 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Sat, 28 Feb 2026 20:12:02 +0800 Subject: [PATCH 02/38] reconstruct pipeline --- glmocr/pipeline/_common.py | 42 ++ glmocr/pipeline/_state.py | 96 +++++ glmocr/pipeline/_unit_tracker.py | 106 +++++ glmocr/pipeline/_workers.py | 266 ++++++++++++ glmocr/pipeline/pipeline.py | 688 ++++++++----------------------- 5 files changed, 692 insertions(+), 506 deletions(-) create mode 100644 glmocr/pipeline/_common.py create mode 100644 glmocr/pipeline/_state.py create mode 100644 glmocr/pipeline/_unit_tracker.py create mode 100644 glmocr/pipeline/_workers.py diff --git a/glmocr/pipeline/_common.py b/glmocr/pipeline/_common.py new file mode 100644 index 0000000..7b09ba4 --- /dev/null +++ b/glmocr/pipeline/_common.py @@ -0,0 +1,42 @@ +"""Shared helpers for the pipeline package.""" + +from __future__ import annotations + +from typing import Any, Dict, List + +from glmocr.utils.logging import get_logger + +logger = get_logger(__name__) + + +def extract_image_urls(request_data: Dict[str, Any]) -> List[str]: + """Extract image URLs from an OpenAI-style request payload.""" + image_urls: List[str] = [] + for msg in request_data.get("messages", []): + if msg.get("role") == "user": + contents = msg.get("content", []) + if isinstance(contents, list): + for content in contents: + if content.get("type") == "image_url": + image_urls.append(content["image_url"]["url"]) + return image_urls + + +def make_original_inputs(image_urls: List[str]) -> List[str]: + """Strip ``file://`` prefix so that original paths are returned.""" + return [(url[7:] if url.startswith("file://") else url) for url in image_urls] + + +def extract_ocr_content(response: Dict[str, Any]) -> str: + """Pull the content string out of an OpenAI-style OCR response.""" + return ( + response.get("choices", [{}])[0].get("message", {}).get("content", "") + ) + + +# ── Queue message "identifier" field values ────────────────────────── +# Every queue message is a dict with at least an "identifier" key. +IDENTIFIER_IMAGE = "image" +IDENTIFIER_REGION = "region" +IDENTIFIER_DONE = "done" +IDENTIFIER_ERROR = "error" diff --git a/glmocr/pipeline/_state.py b/glmocr/pipeline/_state.py new file mode 100644 index 0000000..9f5514f --- /dev/null +++ b/glmocr/pipeline/_state.py @@ -0,0 +1,96 @@ +"""Shared mutable state for the three-stage async pipeline. + +This object is created once per ``Pipeline.process()`` call and passed to +all three worker threads. It holds the inter-thread queues, accumulated +results, and the (optional) UnitTracker reference. +""" + +from __future__ import annotations + +import queue +import threading +from typing import Any, Dict, List, Optional + +from glmocr.pipeline._unit_tracker import UnitTracker + + +class PipelineState: + """Thread-safe container shared by loader / layout / recognition workers. + + Queues (dict messages flow through these): + page_queue — Stage 1 → Stage 2 + region_queue — Stage 2 → Stage 3 + + Accumulated results (list, not a queue — main thread needs random access): + recognition_results — Stage 3 appends, main thread snapshots + """ + + def __init__( + self, + page_maxsize: int = 100, + region_maxsize: int = 800, + ): + # ── Inter-thread queues ────────────────────────────────────── + self.page_queue: queue.Queue[Dict[str, Any]] = queue.Queue(maxsize=page_maxsize) + self.region_queue: queue.Queue[Dict[str, Any]] = queue.Queue(maxsize=region_maxsize) + + # ── Per-page data (stage 1 & 2 write, main thread reads) ───── + self.images_dict: Dict[int, Any] = {} + self.layout_results_dict: Dict[int, List] = {} + + # ── Counters (stage 1 writes, main thread reads after join) ── + self.num_images_loaded: List[int] = [0] + self.unit_indices_holder: List[Optional[List[int]]] = [None] + + # ── Recognition results (stage 3 appends, main thread reads) ─ + self._recognition_results: List[Dict[str, Any]] = [] + self._results_lock = threading.Lock() + + # ── UnitTracker (set by main thread after stage 1 & 2 join) ── + self._tracker: Optional[UnitTracker] = None + + # ── Exception collection ───────────────────────────────────── + self._exceptions: List[Dict[str, Any]] = [] + self._exception_lock = threading.Lock() + + # ------------------------------------------------------------------ + # Recognition results + # ------------------------------------------------------------------ + + def add_recognition_result(self, page_idx: int, region: Dict) -> None: + """Append a completed region result and notify the tracker.""" + result = {"page_idx": page_idx, "region": region} + with self._results_lock: + self._recognition_results.append(result) + tracker = self._tracker + if tracker is not None: + tracker.on_region_done(page_idx) + + def snapshot_recognition_results(self) -> List[Dict[str, Any]]: + """Return a shallow copy of all results accumulated so far.""" + with self._results_lock: + return list(self._recognition_results) + + # ------------------------------------------------------------------ + # UnitTracker lifecycle + # ------------------------------------------------------------------ + + def set_tracker(self, tracker: UnitTracker) -> None: + self._tracker = tracker + + # ------------------------------------------------------------------ + # Exception handling + # ------------------------------------------------------------------ + + def record_exception(self, source: str, exc: Exception) -> None: + with self._exception_lock: + self._exceptions.append({"source": source, "exception": exc}) + + def raise_if_exceptions(self) -> None: + with self._exception_lock: + if self._exceptions: + raise RuntimeError( + "; ".join( + f"{e['source']}: {e['exception']}" for e in self._exceptions + ) + ) diff --git a/glmocr/pipeline/_unit_tracker.py b/glmocr/pipeline/_unit_tracker.py new file mode 100644 index 0000000..c0e2f48 --- /dev/null +++ b/glmocr/pipeline/_unit_tracker.py @@ -0,0 +1,106 @@ +"""UnitTracker — tracks per-unit (per input URL) region completion. + +One "unit" corresponds to a single input URL (one image file or one PDF). +A PDF unit may span multiple pages, each page having multiple regions. +This tracker counts completed regions and notifies the main thread when +all regions for a unit are done. + +Thread safety: + - ``on_region_done()`` is called from Thread 3 (recognition worker). + - ``wait_next_ready_unit()`` / ``iter_ready_units()`` are called from the main thread. +""" + +from __future__ import annotations + +import queue +import threading +from typing import Dict, List + + +class UnitTracker: + """Tracks region-level completion for each input unit.""" + + def __init__( + self, + num_units: int, + unit_image_indices: List[List[int]], + unit_region_count: List[int], + ): + self._num_units = num_units + self._unit_image_indices = unit_image_indices + self._unit_region_count = unit_region_count + + self._unit_for_image: Dict[int, int] = { + img_idx: u + for u in range(num_units) + for img_idx in unit_image_indices[u] + } + self._done_count: List[int] = [0] * num_units + self._ready_queue: queue.Queue[int] = queue.Queue() + self._notified: set = set() + + self._lock = threading.Lock() + self._notify_lock = threading.Lock() + + # ------------------------------------------------------------------ + # Backfill: handle regions that completed *before* tracker existed + # ------------------------------------------------------------------ + + def backfill(self, done_page_indices: List[int]) -> None: + """Account for regions finished before the tracker was initialised. + + Args: + done_page_indices: ``page_idx`` values of already-completed regions. + + Called once from the main thread right after construction. + """ + for page_idx in done_page_indices: + u = self._unit_for_image.get(page_idx) + if u is not None: + self._done_count[u] += 1 + + for u in range(self._num_units): + if self._done_count[u] >= self._unit_region_count[u]: + self._ready_queue.put(u) + self._notified.add(u) + + # ------------------------------------------------------------------ + # Runtime: called from Thread 3 after each region completes + # ------------------------------------------------------------------ + + def on_region_done(self, page_idx: int) -> None: + """Increment the counter for the unit owning *page_idx*. + + O(1). If the unit reaches its target count, enqueue it. + """ + u = self._unit_for_image.get(page_idx) + if u is None: + return + with self._lock: + self._done_count[u] += 1 + ready = self._done_count[u] >= self._unit_region_count[u] + if ready: + with self._notify_lock: + if u not in self._notified: + self._ready_queue.put(u) + self._notified.add(u) + + # ------------------------------------------------------------------ + # Consumption: called from the main thread + # ------------------------------------------------------------------ + + def wait_next_ready_unit(self) -> int: + """Block until the next unit is ready and return its index.""" + return self._ready_queue.get() + + @property + def num_units(self) -> int: + return self._num_units + + @property + def unit_image_indices(self) -> List[List[int]]: + return self._unit_image_indices + + @property + def unit_region_count(self) -> List[int]: + return self._unit_region_count diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py new file mode 100644 index 0000000..84f3d5c --- /dev/null +++ b/glmocr/pipeline/_workers.py @@ -0,0 +1,266 @@ +"""Thread workers for the three-stage async pipeline. + +Stage 1 (data_loading_worker): Load pages from URLs → page_queue +Stage 2 (layout_worker): Layout detection → region_queue +Stage 3 (recognition_worker): Parallel OCR → recognition_results + +Queue message formats: + + page_queue:: + {"identifier": "image", "page_idx": int, "image": PIL.Image} + {"identifier": "done"} + {"identifier": "error"} + + region_queue:: + {"identifier": "region", "page_idx": int, "cropped_image": PIL.Image, + "region": dict, "task_type": str} + {"identifier": "done"} + {"identifier": "error"} +""" + +from __future__ import annotations + +import queue +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from concurrent.futures import ThreadPoolExecutor, as_completed + +from glmocr.pipeline._common import ( + IDENTIFIER_DONE, + IDENTIFIER_ERROR, + IDENTIFIER_IMAGE, + IDENTIFIER_REGION, +) +from glmocr.pipeline._state import PipelineState +from glmocr.utils.image_utils import crop_image_region +from glmocr.utils.logging import get_logger + +if TYPE_CHECKING: + from glmocr.dataloader import PageLoader + from glmocr.layout.base import BaseLayoutDetector + +logger = get_logger(__name__) + + +# ====================================================================== +# Stage 1: Data Loading +# ====================================================================== + +def data_loading_worker( + state: PipelineState, + page_loader: "PageLoader", + image_urls: List[str], +) -> None: + """Load pages from *image_urls* and push them onto ``state.page_queue``.""" + page_idx = 0 + unit_indices_list: List[int] = [] + try: + for page, unit_idx in page_loader.iter_pages_with_unit_indices(image_urls): + state.images_dict[page_idx] = page + state.page_queue.put({ + "identifier": IDENTIFIER_IMAGE, + "page_idx": page_idx, + "image": page, + }) + unit_indices_list.append(unit_idx) + page_idx += 1 + state.num_images_loaded[0] = page_idx + state.unit_indices_holder[0] = list(unit_indices_list) + state.page_queue.put({"identifier": IDENTIFIER_DONE}) + except Exception as e: + logger.exception("Data loading worker error: %s", e) + state.num_images_loaded[0] = page_idx + state.unit_indices_holder[0] = list(unit_indices_list) + state.record_exception("DataLoadingWorker", e) + state.page_queue.put({"identifier": IDENTIFIER_ERROR}) + + +# ====================================================================== +# Stage 2: Layout Detection +# ====================================================================== + +def layout_worker( + state: PipelineState, + layout_detector: "BaseLayoutDetector", + save_visualization: bool, + vis_output_dir: Optional[str], +) -> None: + """Consume pages, run layout detection in batches, push regions.""" + try: + batch_images: List[Any] = [] + batch_page_indices: List[int] = [] + global_start_idx = 0 + + while True: + try: + msg = state.page_queue.get(timeout=0.01) + except queue.Empty: + continue + + identifier = msg["identifier"] + + if identifier == IDENTIFIER_IMAGE: + batch_images.append(msg["image"]) + batch_page_indices.append(msg["page_idx"]) + if len(batch_images) >= layout_detector.batch_size: + _flush_layout_batch( + state, layout_detector, batch_images, batch_page_indices, + save_visualization, vis_output_dir, global_start_idx, + ) + global_start_idx += len(batch_page_indices) + batch_images, batch_page_indices = [], [] + + elif identifier == IDENTIFIER_DONE: + if batch_images: + _flush_layout_batch( + state, layout_detector, batch_images, batch_page_indices, + save_visualization, vis_output_dir, global_start_idx, + ) + state.region_queue.put({"identifier": IDENTIFIER_DONE}) + break + + elif identifier == IDENTIFIER_ERROR: + state.region_queue.put({"identifier": IDENTIFIER_ERROR}) + break + + except Exception as e: + logger.exception("Layout worker error: %s", e) + state.record_exception("LayoutWorker", e) + state.region_queue.put({"identifier": IDENTIFIER_ERROR}) + + +def _flush_layout_batch( + state: PipelineState, + layout_detector: "BaseLayoutDetector", + batch_images: List[Any], + batch_page_indices: List[int], + save_visualization: bool, + vis_output_dir: Optional[str], + global_start_idx: int, +) -> None: + """Run layout detection on one batch and enqueue the resulting regions.""" + layout_results = layout_detector.process( + batch_images, + save_visualization=save_visualization and vis_output_dir is not None, + visualization_output_dir=vis_output_dir, + global_start_idx=global_start_idx, + ) + for page_idx, image, layout_result in zip( + batch_page_indices, batch_images, layout_results + ): + state.layout_results_dict[page_idx] = layout_result + for region in layout_result: + cropped = crop_image_region(image, region["bbox_2d"], region["polygon"]) + state.region_queue.put({ + "identifier": IDENTIFIER_REGION, + "page_idx": page_idx, + "cropped_image": cropped, + "region": region, + }) + + +# ====================================================================== +# Stage 3: VLM Recognition +# ====================================================================== + +def recognition_worker( + state: PipelineState, + page_loader: "PageLoader", + ocr_client: Any, + max_workers: int, +) -> None: + """Consume regions, run parallel OCR, store results.""" + try: + executor = ThreadPoolExecutor(max_workers=min(max_workers, 128)) + futures: Dict[Any, Dict[str, Any]] = {} + pending_skip: List[Dict[str, Any]] = [] + processing_complete = False + + while True: + _collect_done_futures(futures, state) + + try: + msg = state.region_queue.get(timeout=0.01) + except queue.Empty: + if processing_complete and not futures: + _flush_pending_skips(pending_skip, state) + break + if futures: + _wait_for_any(futures) + continue + + identifier = msg["identifier"] + + if identifier == IDENTIFIER_REGION: + if msg["region"]["task_type"] == "skip": + pending_skip.append(msg) + else: + req = page_loader.build_request_from_image( + msg["cropped_image"], msg["region"]["task_type"], + ) + future = executor.submit(ocr_client.process, req) + futures[future] = msg + + elif identifier == IDENTIFIER_DONE: + processing_complete = True + + elif identifier == IDENTIFIER_ERROR: + break + + for future in as_completed(futures.keys()): + _handle_future_result(future, futures, state) + executor.shutdown(wait=True) + + except Exception as e: + logger.exception("Recognition worker error: %s", e) + state.record_exception("RecognitionWorker", e) + + +# ------------------------------------------------------------------ +# Recognition helpers +# ------------------------------------------------------------------ + +def _collect_done_futures( + futures: Dict[Any, Dict[str, Any]], + state: PipelineState, +) -> None: + for f in list(futures): + if f.done(): + _handle_future_result(f, futures, state) + + +def _handle_future_result( + future: Any, + futures: Dict[Any, Dict[str, Any]], + state: PipelineState, +) -> None: + msg = futures.pop(future) + region = msg["region"] + page_idx = msg["page_idx"] + try: + response, status_code = future.result() + if status_code == 200: + region["content"] = response["choices"][0]["message"]["content"].strip() + else: + region["content"] = "" + except Exception as e: + logger.warning("Recognition failed: %s", e) + region["content"] = "" + state.add_recognition_result(page_idx, region) + + +def _flush_pending_skips( + pending: List[Dict[str, Any]], + state: PipelineState, +) -> None: + for msg in pending: + msg["region"]["content"] = None + state.add_recognition_result(msg["page_idx"], msg["region"]) + + +def _wait_for_any(futures: Dict) -> None: + done_list = [f for f in futures if f.done()] + if not done_list: + try: + next(as_completed(futures.keys(), timeout=0.05)) + except Exception: + pass diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 9850d14..5a8a555 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -1,85 +1,63 @@ """GLM-OCR Pipeline -Unified document parsing pipeline. Async by default: process() yields one result -per input unit (one image or one PDF). No separate sync API. +Three-stage async document parsing pipeline. ``process()`` yields one +``PipelineResult`` per input unit (one image or one PDF). -Extension options: -1. Replace components: custom LayoutDetector / ResultFormatter -2. Inherit: subclass Pipeline and override process() +Stages (all always enabled): + 1. PageLoader — load images / PDF pages + 2. LayoutDetector — detect regions per page + 3. OCRClient — recognise each region via VLM + +Extension points: + * Pass a custom ``layout_detector`` or ``result_formatter`` to the constructor. + * Subclass ``Pipeline`` and override ``process()``. """ from __future__ import annotations -import queue import threading -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Any, Optional, Tuple, List, Generator -from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional from glmocr.dataloader import PageLoader from glmocr.ocr_client import OCRClient from glmocr.parser_result import PipelineResult from glmocr.postprocess import ResultFormatter -from glmocr.utils.image_utils import crop_image_region -from glmocr.utils.logging import get_logger, get_profiler +from glmocr.utils.logging import get_logger + +from glmocr.pipeline._common import extract_image_urls, extract_ocr_content, make_original_inputs +from glmocr.pipeline._state import PipelineState +from glmocr.pipeline._workers import data_loading_worker, layout_worker, recognition_worker +from glmocr.pipeline._unit_tracker import UnitTracker if TYPE_CHECKING: from glmocr.config import PipelineConfig from glmocr.layout.base import BaseLayoutDetector logger = get_logger(__name__) -profiler = get_profiler(__name__) - - -@dataclass -class _AsyncPipelineState: - """Shared state for the 3-thread layout path (loader -> layout -> recognition).""" - - page_queue: queue.Queue - region_queue: queue.Queue - ready_units_queue: queue.Queue - recognition_results: List[Tuple[int, Dict]] - results_lock: threading.Lock - images_dict: Dict[int, Any] - layout_results_dict: Dict[int, List] - num_images_loaded: List[int] - unit_indices_holder: List[Optional[List[int]]] - unit_info_holder: List[Optional[Tuple]] - units_put: set - units_put_lock: threading.Lock - count_lock: threading.Lock - exceptions: List[Tuple[str, Exception]] - exception_lock: threading.Lock class Pipeline: """GLM-OCR pipeline. - Unified processing flow: - 1. PageLoader: load images/PDF into pages - 2. (Optional) LayoutDetector: detect regions - 3. OCRClient: call OCR service - 4. ResultFormatter: format outputs + Processing flow: + 1. PageLoader: load images / PDF into pages + 2. LayoutDetector: detect regions + 3. OCRClient: call OCR service + 4. ResultFormatter: format outputs Args: config: PipelineConfig instance. layout_detector: Custom layout detector (optional). result_formatter: Custom result formatter (optional). - Example: + Example:: + from glmocr.config import load_config cfg = load_config() pipeline = Pipeline(cfg.pipeline) for result in pipeline.process(request_data): result.save(output_dir="./results") - - # Custom components - pipeline = Pipeline( - cfg.pipeline, - layout_detector=MyLayoutDetector(cfg.pipeline.layout), - result_formatter=MyFormatter(cfg.pipeline.result_formatter), - ) """ def __init__( @@ -89,20 +67,13 @@ def __init__( result_formatter: Optional[ResultFormatter] = None, ): self.config = config - - # Unified page loader self.page_loader = PageLoader(config.page_loader) - - # OCR client self.ocr_client = OCRClient(config.ocr_api) + self.result_formatter = ( + result_formatter if result_formatter is not None + else ResultFormatter(config.result_formatter) + ) - # Result formatter - if result_formatter is not None: - self.result_formatter = result_formatter - else: - self.result_formatter = ResultFormatter(config.result_formatter) - - # Layout detector if layout_detector is not None: self.layout_detector = layout_detector else: @@ -110,7 +81,6 @@ def __init__( if PPDocLayoutDetector is None: from glmocr.layout import _raise_layout_import_error - _raise_layout_import_error() self.layout_detector = PPDocLayoutDetector(config.layout) @@ -119,30 +89,9 @@ def __init__( self._page_maxsize = getattr(config, "page_maxsize", 100) self._region_maxsize = getattr(config, "region_maxsize", 800) - def _create_async_pipeline_state( - self, - page_maxsize: Optional[int], - region_maxsize: Optional[int], - ) -> _AsyncPipelineState: - q1 = page_maxsize if page_maxsize is not None else self._page_maxsize - q2 = region_maxsize if region_maxsize is not None else self._region_maxsize - return _AsyncPipelineState( - page_queue=queue.Queue(maxsize=q1), - region_queue=queue.Queue(maxsize=q2), - ready_units_queue=queue.Queue(), - recognition_results=[], - results_lock=threading.Lock(), - images_dict={}, - layout_results_dict={}, - num_images_loaded=[0], - unit_indices_holder=[None], - unit_info_holder=[None], - units_put=set(), - units_put_lock=threading.Lock(), - count_lock=threading.Lock(), - exceptions=[], - exception_lock=threading.Lock(), - ) + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ def process( self, @@ -152,459 +101,80 @@ def process( page_maxsize: Optional[int] = None, region_maxsize: Optional[int] = None, ) -> Generator[PipelineResult, None, None]: - """Process request with async three-stage flow; yield one result per input unit. + """Process a request; yield one ``PipelineResult`` per input unit. - Uses three threads: load pages -> layout detection -> recognition. - Yields PipelineResult as each unit (one image or one PDF) completes. + Uses three threads (load → layout → recognition) with bounded queues + for back-pressure. Args: - request_data: Request payload containing messages. - save_layout_visualization: Whether to save layout visualization. - layout_vis_output_dir: Visualization output directory. - page_maxsize: Max size for page_queue (page-level items). - region_maxsize: Max size for region_queue (region-level items). Should be - larger than page_maxsize since one page yields many regions. - Defaults to page_maxsize * 8 if not set. + request_data: OpenAI-style request payload containing messages. + save_layout_visualization: Save layout visualisation images. + layout_vis_output_dir: Directory for visualisation output. + page_maxsize: Bound for the page queue. + region_maxsize: Bound for the region queue. Yields: - PipelineResult per input URL (one image or one PDF). + One ``PipelineResult`` per input URL (image or PDF). """ + image_urls = extract_image_urls(request_data) - image_urls = self._extract_image_urls(request_data) if not image_urls: - request_data = self.page_loader.build_request(request_data) - response, status_code = self.ocr_client.process(request_data) - if status_code != 200: - raise Exception( - f"OCR request failed: {response}, status_code: {status_code}" - ) - content = ( - response.get("choices", [{}])[0].get("message", {}).get("content", "") - ) - json_result, markdown_result = self.result_formatter.format_ocr_result( - content - ) - yield PipelineResult( - json_result=json_result, - markdown_result=markdown_result, - original_images=[], - layout_vis_dir=layout_vis_output_dir, - ) + yield self._process_passthrough(request_data, layout_vis_output_dir) return - state = self._create_async_pipeline_state(page_maxsize, region_maxsize) - - def data_loading_thread() -> None: - try: - img_idx = 0 - unit_indices_list: List[int] = [] - for page, unit_idx in self.page_loader.iter_pages_with_unit_indices( - image_urls - ): - state.images_dict[img_idx] = page - state.page_queue.put(("image", img_idx, page)) - unit_indices_list.append(unit_idx) - img_idx += 1 - state.num_images_loaded[0] = img_idx - state.unit_indices_holder[0] = list(unit_indices_list) - state.page_queue.put(("done", None, None)) - except Exception as e: - logger.exception("Data loading thread error: %s", e) - state.num_images_loaded[0] = img_idx - state.unit_indices_holder[0] = list(unit_indices_list) - with state.exception_lock: - state.exceptions.append(("DataLoadingThread", e)) - state.page_queue.put(("error", None, None)) - - def layout_detection_thread() -> None: - try: - batch_images: List[Any] = [] - batch_indices: List[int] = [] - loading_complete = False - global_start_idx = 0 - while True: - try: - item_type, img_idx, data = state.page_queue.get(timeout=1) - except queue.Empty: - if loading_complete and batch_images: - self._stream_process_layout_batch( - batch_images, - batch_indices, - state.region_queue, - state.images_dict, - state.layout_results_dict, - save_layout_visualization, - layout_vis_output_dir, - global_start_idx, - ) - global_start_idx += len(batch_indices) - batch_images = [] - batch_indices = [] - continue - if item_type == "image": - batch_images.append(data) - batch_indices.append(img_idx) - if len(batch_images) >= self.layout_detector.batch_size: - self._stream_process_layout_batch( - batch_images, - batch_indices, - state.region_queue, - state.images_dict, - state.layout_results_dict, - save_layout_visualization, - layout_vis_output_dir, - global_start_idx, - ) - global_start_idx += len(batch_indices) - batch_images = [] - batch_indices = [] - elif item_type == "done": - loading_complete = True - if batch_images: - self._stream_process_layout_batch( - batch_images, - batch_indices, - state.region_queue, - state.images_dict, - state.layout_results_dict, - save_layout_visualization, - layout_vis_output_dir, - global_start_idx, - ) - state.region_queue.put(("done", None, None)) - break - elif item_type == "error": - state.region_queue.put(("error", None, None)) - break - except Exception as e: - logger.exception("Layout detection thread error: %s", e) - with state.exception_lock: - state.exceptions.append(("LayoutDetectionThread", e)) - state.region_queue.put(("error", None, None)) - - def maybe_notify_ready_units(img_idx: Optional[int] = None) -> None: - """Notify when a unit is ready. O(1) when img_idx is given.""" - info = state.unit_info_holder[0] - if info is None: - return - if img_idx is not None and len(info) >= 5: - ( - _, - unit_region_count, - unit_for_image, - unit_region_done_count, - c_lock, - ) = info - u = unit_for_image.get(img_idx) - if u is None: - return - with c_lock: - unit_region_done_count[u] += 1 - if unit_region_done_count[u] >= unit_region_count[u]: - with state.units_put_lock: - if u not in state.units_put: - state.ready_units_queue.put(u) - state.units_put.add(u) - return - unit_image_indices, unit_region_count = info[:2] - num_units = len(unit_region_count) - with state.results_lock: - rec = list(state.recognition_results) - with state.units_put_lock: - for u in range(num_units): - if u in state.units_put: - continue - if unit_region_count[u] == 0: - state.ready_units_queue.put(u) - state.units_put.add(u) - continue - count = sum(1 for (i, _) in rec if i in unit_image_indices[u]) - if count >= unit_region_count[u]: - state.ready_units_queue.put(u) - state.units_put.add(u) - - def vlm_recognition_thread() -> None: - try: - executor = ThreadPoolExecutor(max_workers=min(self.max_workers, 128)) - futures: Dict[Any, Tuple[Dict, str, int]] = {} - pending_skip: List[Tuple[Dict, str, int]] = [] - processing_complete = False - while True: - for f in list(futures.keys()): - if f.done(): - info, task, page_idx = futures.pop(f) - try: - response, status_code = f.result() - if status_code == 200: - info["content"] = response["choices"][0]["message"][ - "content" - ].strip() - else: - info["content"] = "" - except Exception as e: - logger.warning("Recognition failed: %s", e) - info["content"] = "" - with state.results_lock: - state.recognition_results.append((page_idx, info)) - maybe_notify_ready_units(page_idx) - try: - item_type, img_idx, data = state.region_queue.get(timeout=0.01) - except queue.Empty: - if processing_complete and len(futures) == 0: - for region, task_type, page_idx in pending_skip: - region["content"] = None - with state.results_lock: - state.recognition_results.append((page_idx, region)) - maybe_notify_ready_units(page_idx) - break - if futures: - done_list = [f for f in futures.keys() if f.done()] - if not done_list: - try: - next(as_completed(futures.keys(), timeout=0.05)) - except Exception: - pass - continue - if item_type == "region": - cropped_image, region, task_type, page_idx = data - if task_type == "skip": - pending_skip.append((region, task_type, page_idx)) - else: - req = self.page_loader.build_request_from_image( - cropped_image, task_type - ) - future = executor.submit(self.ocr_client.process, req) - futures[future] = (region, task_type, page_idx) - elif item_type == "done": - processing_complete = True - elif item_type == "error": - break - if futures: - for future in as_completed(futures.keys()): - info, task, page_idx = futures[future] - try: - response, status_code = future.result() - if status_code == 200: - info["content"] = response["choices"][0]["message"][ - "content" - ].strip() - else: - info["content"] = "" - except Exception as e: - logger.warning("Recognition failed: %s", e) - info["content"] = "" - with state.results_lock: - state.recognition_results.append((page_idx, info)) - maybe_notify_ready_units(page_idx) - executor.shutdown(wait=True) - except Exception as e: - logger.exception("VLM recognition thread error: %s", e) - with state.exception_lock: - state.exceptions.append(("VLMRecognitionThread", e)) - - t1 = threading.Thread(target=data_loading_thread, daemon=True) - t2 = threading.Thread(target=layout_detection_thread, daemon=True) - t3 = threading.Thread(target=vlm_recognition_thread, daemon=True) + state = PipelineState( + page_maxsize=page_maxsize or self._page_maxsize, + region_maxsize=region_maxsize or self._region_maxsize, + ) + + t1 = threading.Thread( + target=data_loading_worker, + args=(state, self.page_loader, image_urls), + daemon=True, + ) + t2 = threading.Thread( + target=layout_worker, + args=(state, self.layout_detector, save_layout_visualization, layout_vis_output_dir), + daemon=True, + ) + t3 = threading.Thread( + target=recognition_worker, + args=(state, self.page_loader, self.ocr_client, self.max_workers), + daemon=True, + ) + t1.start() t2.start() t3.start() + + # Wait for loading & layout to finish so we know the total counts. t1.join() t2.join() num_images = state.num_images_loaded[0] - unit_indices = state.unit_indices_holder[0] num_units = len(image_urls) - original_inputs = [ - (url[7:] if url.startswith("file://") else url) for url in image_urls - ] + original_inputs = make_original_inputs(image_urls) if num_images == 0: - empty_json, empty_md = self.result_formatter.process([]) - for u in range(num_units): - yield PipelineResult( - json_result=empty_json, - markdown_result=empty_md, - original_images=[original_inputs[u]], - layout_vis_dir=layout_vis_output_dir, - ) + yield from self._emit_empty(num_units, original_inputs, layout_vis_output_dir) t3.join() - with state.exception_lock: - if state.exceptions: - raise RuntimeError( - "; ".join(f"{n}: {e}" for n, e in state.exceptions) - ) + state.raise_if_exceptions() return - unit_image_indices: List[List[int]] = [[] for _ in range(num_units)] - for img_idx in range(num_images): - if unit_indices is not None and img_idx < len(unit_indices): - u = unit_indices[img_idx] - if u < num_units: - unit_image_indices[u].append(img_idx) - unit_region_count = [ - sum( - len(state.layout_results_dict.get(i, [])) for i in unit_image_indices[u] - ) - for u in range(num_units) - ] - unit_for_image: Dict[int, int] = { - i: u for u in range(num_units) for i in unit_image_indices[u] - } - unit_region_done_count: List[int] = [0] * num_units - with state.results_lock: - rec_init = list(state.recognition_results) - for i, _ in rec_init: - u = unit_for_image.get(i) - if u is not None: - unit_region_done_count[u] += 1 - state.unit_info_holder[0] = ( - unit_image_indices, - unit_region_count, - unit_for_image, - unit_region_done_count, - state.count_lock, - ) - for u in range(num_units): - if unit_region_done_count[u] >= unit_region_count[u]: - state.ready_units_queue.put(u) - state.units_put.add(u) + tracker = self._build_tracker(state, num_units, num_images) + state.set_tracker(tracker) - emitted: set = set() - while len(emitted) < num_units: - u = state.ready_units_queue.get() - if u in emitted: - continue - with state.results_lock: - rec = list(state.recognition_results) - count = sum(1 for (i, _) in rec if i in unit_image_indices[u]) - if count < unit_region_count[u]: - state.ready_units_queue.put(u) - continue - img_to_idx = {i: k for k, i in enumerate(unit_image_indices[u])} - grouped_u: List[List[Dict]] = [[] for _ in unit_image_indices[u]] - for i, r in rec: - if i in img_to_idx: - grouped_u[img_to_idx[i]].append(r) - json_u, md_u = self.result_formatter.process(grouped_u) - yield PipelineResult( - json_result=json_u, - markdown_result=md_u, - original_images=[original_inputs[u]], - layout_vis_dir=layout_vis_output_dir, - layout_image_indices=unit_image_indices[u], - ) - emitted.add(u) + yield from self._emit_results(state, tracker, original_inputs, layout_vis_output_dir) t3.join() - with state.exception_lock: - if state.exceptions: - raise RuntimeError("; ".join(f"{n}: {e}" for n, e in state.exceptions)) + state.raise_if_exceptions() - def _stream_process_layout_batch( - self, - batch_images: List[Any], - batch_indices: List[int], - region_queue: queue.Queue, - images_dict: Dict[int, Any], - layout_results_dict: Dict[int, List], - save_visualization: bool, - vis_output_dir: Optional[str], - global_start_idx: int, - ) -> None: - """Run layout detection on a batch and push regions to queue2.""" - layout_results = self.layout_detector.process( - batch_images, - save_visualization=save_visualization and vis_output_dir is not None, - visualization_output_dir=vis_output_dir, - global_start_idx=global_start_idx, - ) - for img_idx, image, layout_result in zip( - batch_indices, batch_images, layout_results - ): - layout_results_dict[img_idx] = layout_result - for region in layout_result: - cropped = crop_image_region(image, region["bbox_2d"], region["polygon"]) - region_queue.put( - ( - "region", - img_idx, - (cropped, region, region["task_type"], img_idx), - ) - ) - - def _extract_image_urls(self, request_data: Dict[str, Any]) -> List[str]: - """Extract image URLs from request_data.""" - image_urls = [] - for msg in request_data.get("messages", []): - if msg.get("role") == "user": - contents = msg.get("content", []) - if isinstance(contents, list): - for content in contents: - if content.get("type") == "image_url": - image_urls.append(content["image_url"]["url"]) - return image_urls - - def _prepare_regions(self, pages, layout_results) -> List[tuple]: - """Prepare regions that need recognition.""" - regions = [] - with profiler.measure("crop_regions"): - for page_idx, (page, layouts) in enumerate(zip(pages, layout_results)): - for region in layouts: - cropped = crop_image_region(page, region["bbox_2d"]) - regions.append((cropped, region, region["task_type"], page_idx)) - return regions - - def _recognize_regions(self, regions: List[tuple]) -> List[tuple]: - """Recognize all regions in parallel.""" - results = [] - - # Split skipped regions and regions to process - to_process = [] - for img, info, task, page_idx in regions: - if task == "skip": - info["content"] = None - results.append((page_idx, info)) - else: - to_process.append((img, info, task, page_idx)) - - if not to_process: - return results - - # Build all requests first - request_data_list = [] - with profiler.measure("build_region_requests"): - for img, info, task, page_idx in to_process: - request_data = self.page_loader.build_request_from_image(img, task) - request_data_list.append((request_data, info, task, page_idx)) - - # Run in parallel - with ThreadPoolExecutor( - max_workers=min(self.max_workers, len(to_process)) - ) as executor: - futures = {} - for request_data, info, task, page_idx in request_data_list: - future = executor.submit(self.ocr_client.process, request_data) - futures[future] = (info, task, page_idx) - - for future in as_completed(futures): - info, task, page_idx = futures[future] - try: - response, status_code = future.result() - if status_code == 200: - info["content"] = response["choices"][0]["message"][ - "content" - ].strip() - else: - info["content"] = "" - except Exception as e: - logger.warning("Recognition failed: %s", e) - info["content"] = "" - results.append((page_idx, info)) - - return results + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ def start(self): - """Start the pipeline.""" + """Start the pipeline (layout detector + OCR client).""" logger.info("Starting Pipeline...") self.layout_detector.start() self.ocr_client.start() @@ -623,3 +193,109 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.stop() + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _process_passthrough( + self, + request_data: Dict[str, Any], + layout_vis_output_dir: Optional[str], + ) -> PipelineResult: + """No image URLs — forward the request directly to the OCR API.""" + request_data = self.page_loader.build_request(request_data) + response, status_code = self.ocr_client.process(request_data) + if status_code != 200: + raise Exception( + f"OCR request failed: {response}, status_code: {status_code}" + ) + content = extract_ocr_content(response) + json_result, markdown_result = self.result_formatter.format_ocr_result(content) + return PipelineResult( + json_result=json_result, + markdown_result=markdown_result, + original_images=[], + layout_vis_dir=layout_vis_output_dir, + ) + + @staticmethod + def _build_tracker( + state: PipelineState, + num_units: int, + num_images: int, + ) -> UnitTracker: + """Build and backfill a UnitTracker from the current state.""" + unit_indices = state.unit_indices_holder[0] + unit_image_indices: List[List[int]] = [[] for _ in range(num_units)] + for page_idx in range(num_images): + if unit_indices is not None and page_idx < len(unit_indices): + u = unit_indices[page_idx] + if u < num_units: + unit_image_indices[u].append(page_idx) + + unit_region_count = [ + sum(len(state.layout_results_dict.get(i, [])) for i in unit_image_indices[u]) + for u in range(num_units) + ] + + tracker = UnitTracker(num_units, unit_image_indices, unit_region_count) + already_done = state.snapshot_recognition_results() + tracker.backfill([r["page_idx"] for r in already_done]) + return tracker + + def _emit_empty( + self, + num_units: int, + original_inputs: List[str], + layout_vis_output_dir: Optional[str], + ) -> Generator[PipelineResult, None, None]: + """Yield empty results when no images were loaded.""" + empty_json, empty_md = self.result_formatter.process([]) + for u in range(num_units): + yield PipelineResult( + json_result=empty_json, + markdown_result=empty_md, + original_images=[original_inputs[u]], + layout_vis_dir=layout_vis_output_dir, + ) + + def _emit_results( + self, + state: PipelineState, + tracker: UnitTracker, + original_inputs: List[str], + layout_vis_output_dir: Optional[str], + ) -> Generator[PipelineResult, None, None]: + """Wait for units to complete and yield their formatted results.""" + emitted: set = set() + while len(emitted) < tracker.num_units: + u = tracker.wait_next_ready_unit() + if u in emitted: + continue + + results = state.snapshot_recognition_results() + + page_indices = tracker.unit_image_indices[u] + page_set = set(page_indices) + count = sum(1 for r in results if r["page_idx"] in page_set) + if count < tracker.unit_region_count[u]: + tracker._ready_queue.put(u) + continue + + page_to_pos = {idx: k for k, idx in enumerate(page_indices)} + grouped: List[List[Dict]] = [[] for _ in page_indices] + for r in results: + pos = page_to_pos.get(r["page_idx"]) + if pos is not None: + grouped[pos].append(r["region"]) + + json_u, md_u = self.result_formatter.process(grouped) + yield PipelineResult( + json_result=json_u, + markdown_result=md_u, + original_images=[original_inputs[u]], + layout_vis_dir=layout_vis_output_dir, + layout_image_indices=page_indices, + ) + emitted.add(u) From 0489c36a4d913c12e468a6433177880c5d4fec9a Mon Sep 17 00:00:00 2001 From: xueyadong Date: Mon, 2 Mar 2026 19:55:31 +0800 Subject: [PATCH 03/38] - Add async pipeline support for files in directory via CLI - Implement dynamic registration of Tracker in pipeline for immediate results without waiting for t1 and t2 to finish --- glmocr/api.py | 6 ++ glmocr/cli.py | 127 +++++++++++++++++----------- glmocr/pipeline/_common.py | 1 + glmocr/pipeline/_state.py | 36 +++++++- glmocr/pipeline/_unit_tracker.py | 138 ++++++++++++++++++++----------- glmocr/pipeline/_workers.py | 104 +++++++++++++++++++++-- glmocr/pipeline/pipeline.py | 106 ++++++++++-------------- pyproject.toml | 1 + 8 files changed, 349 insertions(+), 170 deletions(-) diff --git a/glmocr/api.py b/glmocr/api.py index 48fad1e..4f8d8d1 100644 --- a/glmocr/api.py +++ b/glmocr/api.py @@ -524,6 +524,12 @@ def parse_maas( **kwargs, ) + def get_queue_stats(self) -> Optional[Dict[str, int]]: + """Return current pipeline queue sizes, or ``None`` if unavailable.""" + if self._pipeline is not None: + return self._pipeline.get_queue_stats() + return None + def close(self): """Close the parser and release resources.""" if self._pipeline: diff --git a/glmocr/cli.py b/glmocr/cli.py index 5168be7..899a716 100644 --- a/glmocr/cli.py +++ b/glmocr/cli.py @@ -6,10 +6,13 @@ import sys import json import argparse +import threading import traceback from pathlib import Path from typing import List +from tqdm import tqdm + from glmocr.api import GlmOcr from glmocr.utils.logging import get_logger, configure_logging @@ -31,14 +34,12 @@ def load_image_paths(input_path: str) -> List[str]: image_paths = [] if path.is_file(): - # Single file suffix = path.suffix.lower() if suffix in [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".pdf"]: image_paths.append(str(path.absolute())) else: raise ValueError(f"Not Supported Type: {path.suffix}") elif path.is_dir(): - # Directory: find all image and PDF files for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.pdf"]: image_paths.extend([str(p.absolute()) for p in path.glob(ext)]) image_paths.extend([str(p.absolute()) for p in path.glob(ext.upper())]) @@ -53,6 +54,17 @@ def load_image_paths(input_path: str) -> List[str]: return image_paths +def _queue_stats_updater(glm_parser: GlmOcr, pbar: tqdm, stop: threading.Event): + while not stop.wait(0.3): + stats = glm_parser.get_queue_stats() + if stats: + pbar.set_postfix_str( + f"Q1:{stats['page_queue_size']}/{stats['page_queue_maxsize']} " + f"Q2:{stats['region_queue_size']}/{stats['region_queue_maxsize']}", + refresh=True, + ) + + def main(): """CLI entrypoint.""" parser = argparse.ArgumentParser( @@ -76,7 +88,6 @@ def main(): subparsers = parser.add_subparsers(dest="command", help="Command") - # parse command parse_parser = subparsers.add_parser("parse", help="Parse document images") parse_parser.add_argument( "input", type=str, help="Input image file or directory path" @@ -129,66 +140,82 @@ def main(): parser.print_help() sys.exit(1) - # Configure logging configure_logging(level=args.log_level) try: - # Load inputs logger.info("Loading images: %s", args.input) image_paths = load_image_paths(args.input) logger.info("Found %d file(s)", len(image_paths)) - # Use GlmOcr API save_layout_vis = not args.no_layout_vis with GlmOcr(config_path=args.config) as glm_parser: - logger.info("Using Pipeline...") - - # Process each file (parse() with str returns a single PipelineResult) total_files = len(image_paths) - for idx, image_path in enumerate(image_paths, 1): - file_name = Path(image_path).name - logger.info("") - logger.info("=== Parsing: %s (%d/%d) ===", file_name, idx, total_files) - - try: - result = glm_parser.parse( - image_path, save_layout_visualization=save_layout_vis + + pbar = tqdm( + total=total_files, + desc="Parsing", + unit="file", + file=sys.stderr, + dynamic_ncols=True, + ) + + stop_event = threading.Event() + stats_thread = threading.Thread( + target=_queue_stats_updater, + args=(glm_parser, pbar, stop_event), + daemon=True, + ) + stats_thread.start() + + try: + for result in glm_parser.parse( + image_paths, + stream=True, + save_layout_visualization=save_layout_vis, + ): + file_name = ( + Path(result.original_images[0]).name + if result.original_images + else f"unit_{pbar.n + 1}" ) - # Output - if args.stdout: - stem = ( - Path(result.original_images[0]).stem - if result.original_images - else file_name - ) - print(f"\n=== {stem} - JSON Result ===") - print( - json.dumps( - result.json_result, - ensure_ascii=False, - indent=2, + pbar.update(1) + + try: + if args.stdout: + stem = ( + Path(result.original_images[0]).stem + if result.original_images + else file_name + ) + print(f"\n=== {stem} - JSON Result ===") + print( + json.dumps( + result.json_result, + ensure_ascii=False, + indent=2, + ) + if isinstance(result.json_result, (dict, list)) + else result.json_result ) - if isinstance(result.json_result, (dict, list)) - else result.json_result - ) - if result.markdown_result and not args.json_only: - print(f"\n=== {stem} - Markdown Result ===") - print(result.markdown_result) - - # Save to files by default (unless --no-save) - if not args.no_save: - result.save( - output_dir=args.output, - save_layout_visualization=save_layout_vis, - ) - - except Exception as e: - logger.error("Failed: %s: %s", file_name, e) - logger.debug(traceback.format_exc()) - continue - - logger.info("") + if result.markdown_result and not args.json_only: + print(f"\n=== {stem} - Markdown Result ===") + print(result.markdown_result) + + if not args.no_save: + result.save( + output_dir=args.output, + save_layout_visualization=save_layout_vis, + ) + + except Exception as e: + tqdm.write(f"Failed: {file_name}: {e}", file=sys.stderr) + continue + finally: + stop_event.set() + stats_thread.join(timeout=2) + pbar.close() + logger.info("All done!") except KeyboardInterrupt: diff --git a/glmocr/pipeline/_common.py b/glmocr/pipeline/_common.py index 7b09ba4..09e00d1 100644 --- a/glmocr/pipeline/_common.py +++ b/glmocr/pipeline/_common.py @@ -37,6 +37,7 @@ def extract_ocr_content(response: Dict[str, Any]) -> str: # ── Queue message "identifier" field values ────────────────────────── # Every queue message is a dict with at least an "identifier" key. IDENTIFIER_IMAGE = "image" +IDENTIFIER_UNIT_DONE = "unit_done" # t1 → t2: all pages for one input unit are queued IDENTIFIER_REGION = "region" IDENTIFIER_DONE = "done" IDENTIFIER_ERROR = "error" diff --git a/glmocr/pipeline/_state.py b/glmocr/pipeline/_state.py index 9f5514f..27da34f 100644 --- a/glmocr/pipeline/_state.py +++ b/glmocr/pipeline/_state.py @@ -2,7 +2,7 @@ This object is created once per ``Pipeline.process()`` call and passed to all three worker threads. It holds the inter-thread queues, accumulated -results, and the (optional) UnitTracker reference. +results, and the UnitTracker reference. """ from __future__ import annotations @@ -46,13 +46,26 @@ def __init__( self._recognition_results: List[Dict[str, Any]] = [] self._results_lock = threading.Lock() - # ── UnitTracker (set by main thread after stage 1 & 2 join) ── + # ── UnitTracker (set before threads start) ─────────────────── self._tracker: Optional[UnitTracker] = None # ── Exception collection ───────────────────────────────────── self._exceptions: List[Dict[str, Any]] = [] self._exception_lock = threading.Lock() + # ------------------------------------------------------------------ + # Page registration (delegated to tracker) + # ------------------------------------------------------------------ + + def register_page(self, page_idx: int, unit_idx: int) -> None: + """Register a ``page_idx → unit_idx`` mapping in the tracker. + + Called by the data-loading worker (t1) for every loaded page. + """ + tracker = self._tracker + if tracker is not None: + tracker.register_page(page_idx, unit_idx) + # ------------------------------------------------------------------ # Recognition results # ------------------------------------------------------------------ @@ -76,8 +89,24 @@ def snapshot_recognition_results(self) -> List[Dict[str, Any]]: # ------------------------------------------------------------------ def set_tracker(self, tracker: UnitTracker) -> None: + """Attach *tracker* to the shared state. + + Must be called **before** any worker thread is started so that + ``register_page``, ``finalize_unit``, and ``on_region_done`` are + never no-ops. + """ self._tracker = tracker + def finalize_unit(self, unit_idx: int, region_count: int) -> None: + """Delegate to the tracker's ``finalize_unit`` if a tracker is attached. + + Called by the layout worker (t2) after it has processed all pages of + *unit_idx*. + """ + tracker = self._tracker + if tracker is not None: + tracker.finalize_unit(unit_idx, region_count) + # ------------------------------------------------------------------ # Exception handling # ------------------------------------------------------------------ @@ -85,6 +114,9 @@ def set_tracker(self, tracker: UnitTracker) -> None: def record_exception(self, source: str, exc: Exception) -> None: with self._exception_lock: self._exceptions.append({"source": source, "exception": exc}) + tracker = self._tracker + if tracker is not None: + tracker.signal_shutdown() def raise_if_exceptions(self) -> None: with self._exception_lock: diff --git a/glmocr/pipeline/_unit_tracker.py b/glmocr/pipeline/_unit_tracker.py index c0e2f48..266564e 100644 --- a/glmocr/pipeline/_unit_tracker.py +++ b/glmocr/pipeline/_unit_tracker.py @@ -2,95 +2,133 @@ One "unit" corresponds to a single input URL (one image file or one PDF). A PDF unit may span multiple pages, each page having multiple regions. -This tracker counts completed regions and notifies the main thread when -all regions for a unit are done. + +Fully dynamic registration protocol +------------------------------------ +The tracker is created *before* any worker thread starts, knowing only +``num_units``. All other metadata is registered incrementally: + +1. **``register_page(page_idx, unit_idx)``** — called by Stage 1 / t1 for + every page that is successfully loaded. Builds the ``page → unit`` + mapping on the fly. + +2. **``finalize_unit(u, region_count)``** — called by Stage 2 / t2 as soon as + all pages of unit *u* have been layout-detected. At that point the total + region count for the unit is known and the tracker can check whether Stage 3 + has already finished all recognitions, immediately notifying the main thread + if so. + +3. **``on_region_done(page_idx)``** — called by Stage 3 / t3 for every + completed recognition. If the unit has already been finalised (its + ``region_count`` is known) and the counter reaches the target, the unit is + enqueued for the main thread. + +4. **``signal_shutdown()``** — called when an error is recorded; puts a + ``None`` sentinel on the ready queue so the main thread unblocks. Thread safety: + - ``register_page()`` is called from Thread 1 (data-loading worker). + - ``finalize_unit()`` is called from Thread 2 (layout worker). - ``on_region_done()`` is called from Thread 3 (recognition worker). - - ``wait_next_ready_unit()`` / ``iter_ready_units()`` are called from the main thread. + - ``wait_next_ready_unit()`` is called from the main thread. + All mutations are serialised by a single lock. """ from __future__ import annotations import queue import threading -from typing import Dict, List +from typing import Dict, List, Optional class UnitTracker: - """Tracks region-level completion for each input unit.""" - - def __init__( - self, - num_units: int, - unit_image_indices: List[List[int]], - unit_region_count: List[int], - ): + """Tracks region-level completion for each input unit. + + Args: + num_units: Total number of input URLs being processed. + """ + + def __init__(self, num_units: int): self._num_units = num_units - self._unit_image_indices = unit_image_indices - self._unit_region_count = unit_region_count - - self._unit_for_image: Dict[int, int] = { - img_idx: u - for u in range(num_units) - for img_idx in unit_image_indices[u] - } + self._unit_image_indices: List[List[int]] = [[] for _ in range(num_units)] + self._unit_region_count: List[Optional[int]] = [None] * num_units + self._unit_for_image: Dict[int, int] = {} self._done_count: List[int] = [0] * num_units - self._ready_queue: queue.Queue[int] = queue.Queue() + self._ready_queue: queue.Queue[Optional[int]] = queue.Queue() self._notified: set = set() - self._lock = threading.Lock() - self._notify_lock = threading.Lock() # ------------------------------------------------------------------ - # Backfill: handle regions that completed *before* tracker existed + # Phase 1: called from t1 for each loaded page # ------------------------------------------------------------------ - def backfill(self, done_page_indices: List[int]) -> None: - """Account for regions finished before the tracker was initialised. - - Args: - done_page_indices: ``page_idx`` values of already-completed regions. + def register_page(self, page_idx: int, unit_idx: int) -> None: + """Register a ``page_idx → unit_idx`` mapping. - Called once from the main thread right after construction. + Called by the data-loading worker (t1) for every successfully loaded + page, *before* the page is placed on the page queue. This guarantees + that by the time Stage 3 calls ``on_region_done(page_idx)``, the + mapping is already present. """ - for page_idx in done_page_indices: - u = self._unit_for_image.get(page_idx) - if u is not None: - self._done_count[u] += 1 + with self._lock: + self._unit_image_indices[unit_idx].append(page_idx) + self._unit_for_image[page_idx] = unit_idx + + # ------------------------------------------------------------------ + # Phase 2: called from t2 when all pages of a unit are layout-done + # ------------------------------------------------------------------ + + def finalize_unit(self, u: int, region_count: int) -> None: + """Record the total region count for unit *u* and check completion. - for u in range(self._num_units): - if self._done_count[u] >= self._unit_region_count[u]: + Called by the layout worker (t2) immediately after it has processed + all pages belonging to unit *u*. If Stage 3 has already finished all + recognitions for that unit, the unit is enqueued for the main thread. + """ + with self._lock: + self._unit_region_count[u] = region_count + if self._done_count[u] >= region_count and u not in self._notified: self._ready_queue.put(u) self._notified.add(u) # ------------------------------------------------------------------ - # Runtime: called from Thread 3 after each region completes + # Runtime: called from t3 after each region completes # ------------------------------------------------------------------ def on_region_done(self, page_idx: int) -> None: """Increment the counter for the unit owning *page_idx*. - O(1). If the unit reaches its target count, enqueue it. + O(1). If the unit's region_count is known and the counter reaches + the target, the unit is enqueued for the main thread. """ - u = self._unit_for_image.get(page_idx) - if u is None: - return with self._lock: + u = self._unit_for_image.get(page_idx) + if u is None: + return self._done_count[u] += 1 - ready = self._done_count[u] >= self._unit_region_count[u] - if ready: - with self._notify_lock: - if u not in self._notified: - self._ready_queue.put(u) - self._notified.add(u) + rc = self._unit_region_count[u] + if rc is not None and self._done_count[u] >= rc and u not in self._notified: + self._ready_queue.put(u) + self._notified.add(u) + + # ------------------------------------------------------------------ + # Shutdown: wake up blocked main thread on error + # ------------------------------------------------------------------ + + def signal_shutdown(self) -> None: + """Put a ``None`` sentinel on the ready queue to unblock the main thread.""" + self._ready_queue.put(None) # ------------------------------------------------------------------ # Consumption: called from the main thread # ------------------------------------------------------------------ - def wait_next_ready_unit(self) -> int: - """Block until the next unit is ready and return its index.""" + def wait_next_ready_unit(self) -> Optional[int]: + """Block until the next unit is ready and return its index. + + Returns ``None`` when ``signal_shutdown()`` has been called (error + path). + """ return self._ready_queue.get() @property @@ -102,5 +140,5 @@ def unit_image_indices(self) -> List[List[int]]: return self._unit_image_indices @property - def unit_region_count(self) -> List[int]: + def unit_region_count(self) -> List[Optional[int]]: return self._unit_region_count diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py index 84f3d5c..293eabc 100644 --- a/glmocr/pipeline/_workers.py +++ b/glmocr/pipeline/_workers.py @@ -7,7 +7,10 @@ Queue message formats: page_queue:: - {"identifier": "image", "page_idx": int, "image": PIL.Image} + {"identifier": "image", "page_idx": int, "unit_idx": int, + "image": PIL.Image} + {"identifier": "unit_done", "unit_idx": int} ← all pages for this + unit have been queued {"identifier": "done"} {"identifier": "error"} @@ -29,6 +32,7 @@ IDENTIFIER_ERROR, IDENTIFIER_IMAGE, IDENTIFIER_REGION, + IDENTIFIER_UNIT_DONE, ) from glmocr.pipeline._state import PipelineState from glmocr.utils.image_utils import crop_image_region @@ -50,21 +54,63 @@ def data_loading_worker( page_loader: "PageLoader", image_urls: List[str], ) -> None: - """Load pages from *image_urls* and push them onto ``state.page_queue``.""" + """Load pages from *image_urls* and push them onto ``state.page_queue``. + + For each page that is loaded, ``state.register_page()`` is called + **before** the page message is enqueued, so that the tracker's + ``page → unit`` mapping is always available by the time Stage 3 calls + ``on_region_done``. + + After all pages for each input URL (unit) have been enqueued, a + ``IDENTIFIER_UNIT_DONE`` sentinel is sent so that the layout worker can + call ``state.finalize_unit()`` without waiting for *all* units to finish. + + Units that produce zero pages (e.g. broken URLs) still receive a + ``UNIT_DONE`` sentinel so the tracker can finalise them with + ``region_count=0``. + """ + num_units = len(image_urls) page_idx = 0 unit_indices_list: List[int] = [] + prev_unit_idx: Optional[int] = None + sent_unit_done: set = set() try: for page, unit_idx in page_loader.iter_pages_with_unit_indices(image_urls): + if prev_unit_idx is not None and unit_idx != prev_unit_idx: + state.page_queue.put({ + "identifier": IDENTIFIER_UNIT_DONE, + "unit_idx": prev_unit_idx, + }) + sent_unit_done.add(prev_unit_idx) + + state.register_page(page_idx, unit_idx) state.images_dict[page_idx] = page state.page_queue.put({ "identifier": IDENTIFIER_IMAGE, "page_idx": page_idx, + "unit_idx": unit_idx, "image": page, }) unit_indices_list.append(unit_idx) page_idx += 1 state.num_images_loaded[0] = page_idx state.unit_indices_holder[0] = list(unit_indices_list) + prev_unit_idx = unit_idx + + if prev_unit_idx is not None: + state.page_queue.put({ + "identifier": IDENTIFIER_UNIT_DONE, + "unit_idx": prev_unit_idx, + }) + sent_unit_done.add(prev_unit_idx) + + for u in range(num_units): + if u not in sent_unit_done: + state.page_queue.put({ + "identifier": IDENTIFIER_UNIT_DONE, + "unit_idx": u, + }) + state.page_queue.put({"identifier": IDENTIFIER_DONE}) except Exception as e: logger.exception("Data loading worker error: %s", e) @@ -84,12 +130,23 @@ def layout_worker( save_visualization: bool, vis_output_dir: Optional[str], ) -> None: - """Consume pages, run layout detection in batches, push regions.""" + """Consume pages, run layout detection in batches, push regions. + + When a ``IDENTIFIER_UNIT_DONE`` sentinel arrives from Stage 1, the + current batch is flushed immediately (it contains the last pages for + that unit) and ``state.finalize_unit()`` is called with the total + region count for that unit. This lets the main thread start emitting + results for the completed unit without waiting for later units. + """ try: batch_images: List[Any] = [] batch_page_indices: List[int] = [] + batch_unit_indices: List[int] = [] global_start_idx = 0 + # page_indices seen so far per unit, used to compute region counts. + unit_page_indices: Dict[int, List[int]] = {} + while True: try: msg = state.page_queue.get(timeout=0.01) @@ -99,15 +156,46 @@ def layout_worker( identifier = msg["identifier"] if identifier == IDENTIFIER_IMAGE: + unit_idx = msg["unit_idx"] batch_images.append(msg["image"]) batch_page_indices.append(msg["page_idx"]) + batch_unit_indices.append(unit_idx) + if unit_idx not in unit_page_indices: + unit_page_indices[unit_idx] = [] + unit_page_indices[unit_idx].append(msg["page_idx"]) + if len(batch_images) >= layout_detector.batch_size: _flush_layout_batch( state, layout_detector, batch_images, batch_page_indices, save_visualization, vis_output_dir, global_start_idx, ) global_start_idx += len(batch_page_indices) - batch_images, batch_page_indices = [], [] + batch_images, batch_page_indices, batch_unit_indices = [], [], [] + + elif identifier == IDENTIFIER_UNIT_DONE: + unit_idx = msg["unit_idx"] + # Flush any remaining pages in the batch (they all belong to + # this unit since t1 sends UNIT_DONE after the last page). + if batch_images: + _flush_layout_batch( + state, layout_detector, batch_images, batch_page_indices, + save_visualization, vis_output_dir, global_start_idx, + ) + global_start_idx += len(batch_page_indices) + batch_images, batch_page_indices, batch_unit_indices = [], [], [] + + # All pages for this unit have been layout-detected; compute + # total region count and tell the tracker. + pages_for_unit = unit_page_indices.get(unit_idx, []) + region_count = sum( + len(state.layout_results_dict.get(pi, [])) + for pi in pages_for_unit + ) + state.finalize_unit(unit_idx, region_count) + logger.debug( + "Unit %d finalised: %d pages, %d regions", + unit_idx, len(pages_for_unit), region_count, + ) elif identifier == IDENTIFIER_DONE: if batch_images: @@ -170,7 +258,8 @@ def recognition_worker( ) -> None: """Consume regions, run parallel OCR, store results.""" try: - executor = ThreadPoolExecutor(max_workers=min(max_workers, 128)) + concurrency = min(max_workers, 128) + executor = ThreadPoolExecutor(max_workers=concurrency) futures: Dict[Any, Dict[str, Any]] = {} pending_skip: List[Dict[str, Any]] = [] processing_complete = False @@ -178,6 +267,11 @@ def recognition_worker( while True: _collect_done_futures(futures, state) + if len(futures) >= concurrency: + _wait_for_any(futures) + _collect_done_futures(futures, state) + continue + try: msg = state.region_queue.get(timeout=0.01) except queue.Empty: diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 5a8a555..8fb0f8c 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -88,6 +88,7 @@ def __init__( self.max_workers = config.max_workers self._page_maxsize = getattr(config, "page_maxsize", 100) self._region_maxsize = getattr(config, "region_maxsize", 800) + self._current_state: Optional[PipelineState] = None # ------------------------------------------------------------------ # Public API @@ -122,10 +123,17 @@ def process( yield self._process_passthrough(request_data, layout_vis_output_dir) return + num_units = len(image_urls) + original_inputs = make_original_inputs(image_urls) + state = PipelineState( page_maxsize=page_maxsize or self._page_maxsize, region_maxsize=region_maxsize or self._region_maxsize, ) + self._current_state = state + + tracker = UnitTracker(num_units) + state.set_tracker(tracker) t1 = threading.Thread( target=data_loading_worker, @@ -147,27 +155,26 @@ def process( t2.start() t3.start() - # Wait for loading & layout to finish so we know the total counts. - t1.join() - t2.join() - - num_images = state.num_images_loaded[0] - num_units = len(image_urls) - original_inputs = make_original_inputs(image_urls) - - if num_images == 0: - yield from self._emit_empty(num_units, original_inputs, layout_vis_output_dir) + try: + yield from self._emit_results(state, tracker, original_inputs, layout_vis_output_dir) + t1.join() + t2.join() t3.join() state.raise_if_exceptions() - return - - tracker = self._build_tracker(state, num_units, num_images) - state.set_tracker(tracker) - - yield from self._emit_results(state, tracker, original_inputs, layout_vis_output_dir) - - t3.join() - state.raise_if_exceptions() + finally: + self._current_state = None + + def get_queue_stats(self) -> Optional[Dict[str, int]]: + """Return current queue sizes, or ``None`` if no processing is active.""" + state = self._current_state + if state is None: + return None + return { + "page_queue_size": state.page_queue.qsize(), + "page_queue_maxsize": state.page_queue.maxsize, + "region_queue_size": state.region_queue.qsize(), + "region_queue_maxsize": state.region_queue.maxsize, + } # ------------------------------------------------------------------ # Lifecycle @@ -219,47 +226,6 @@ def _process_passthrough( layout_vis_dir=layout_vis_output_dir, ) - @staticmethod - def _build_tracker( - state: PipelineState, - num_units: int, - num_images: int, - ) -> UnitTracker: - """Build and backfill a UnitTracker from the current state.""" - unit_indices = state.unit_indices_holder[0] - unit_image_indices: List[List[int]] = [[] for _ in range(num_units)] - for page_idx in range(num_images): - if unit_indices is not None and page_idx < len(unit_indices): - u = unit_indices[page_idx] - if u < num_units: - unit_image_indices[u].append(page_idx) - - unit_region_count = [ - sum(len(state.layout_results_dict.get(i, [])) for i in unit_image_indices[u]) - for u in range(num_units) - ] - - tracker = UnitTracker(num_units, unit_image_indices, unit_region_count) - already_done = state.snapshot_recognition_results() - tracker.backfill([r["page_idx"] for r in already_done]) - return tracker - - def _emit_empty( - self, - num_units: int, - original_inputs: List[str], - layout_vis_output_dir: Optional[str], - ) -> Generator[PipelineResult, None, None]: - """Yield empty results when no images were loaded.""" - empty_json, empty_md = self.result_formatter.process([]) - for u in range(num_units): - yield PipelineResult( - json_result=empty_json, - markdown_result=empty_md, - original_images=[original_inputs[u]], - layout_vis_dir=layout_vis_output_dir, - ) - def _emit_results( self, state: PipelineState, @@ -267,19 +233,33 @@ def _emit_results( original_inputs: List[str], layout_vis_output_dir: Optional[str], ) -> Generator[PipelineResult, None, None]: - """Wait for units to complete and yield their formatted results.""" + """Wait for units to complete and yield their formatted results. + + A unit enters the ready queue when: + - ``finalize_unit`` has been called (region count is known), AND + - all its regions have been recognised (``on_region_done`` counter + reached the target). + + ``None`` from the ready queue signals a pipeline error (shutdown). + """ emitted: set = set() while len(emitted) < tracker.num_units: u = tracker.wait_next_ready_unit() + if u is None: + break if u in emitted: continue - results = state.snapshot_recognition_results() + region_count = tracker.unit_region_count[u] + if region_count is None: + tracker._ready_queue.put(u) + continue + results = state.snapshot_recognition_results() page_indices = tracker.unit_image_indices[u] page_set = set(page_indices) count = sum(1 for r in results if r["page_idx"] in page_set) - if count < tracker.unit_region_count[u]: + if count < region_count: tracker._ready_queue.put(u) continue diff --git a/pyproject.toml b/pyproject.toml index 97f0fef..efa5bc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "PyYAML>=6.0.0", "portalocker>=2.8.2", "python-dotenv>=0.21.0", + "tqdm>=4.62.0", # Layout detection "torch>=2.0.0", From 6dc8d8cced0cc3ad5db51ec9b662fcefd31cf661 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Tue, 3 Mar 2026 15:02:32 +0800 Subject: [PATCH 04/38] Fix a blocking bug; remove redundant PIL format conversions; optimize memory release performance --- glmocr/dataloader/page_loader.py | 42 ++++++++++++++------------------ glmocr/layout/layout_detector.py | 14 ++++------- glmocr/pipeline/_state.py | 9 ++++--- glmocr/pipeline/_workers.py | 22 ++++++----------- glmocr/pipeline/pipeline.py | 18 ++++++-------- 5 files changed, 43 insertions(+), 62 deletions(-) diff --git a/glmocr/dataloader/page_loader.py b/glmocr/dataloader/page_loader.py index 61362af..d4bf72f 100644 --- a/glmocr/dataloader/page_loader.py +++ b/glmocr/dataloader/page_loader.py @@ -336,34 +336,28 @@ def build_request_from_image( if not str(prompt_text).strip(): prompt_text = self.default_prompt - # Convert to RGB - if image.mode != "RGB": - image = image.convert("RGB") - - # Encode image - buffered = BytesIO() - image.save(buffered, format=self.image_format) - img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") - - original_msg = { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/{self.image_format.lower()};base64,{img_base64}" - }, - }, - ], - } + encoded_image = load_image_to_base64( + image, + t_patch_size=self.t_patch_size, + max_pixels=self.max_pixels, + image_format=self.image_format, + patch_expand_factor=self.patch_expand_factor, + min_pixels=self.min_pixels, + ) + content: list = [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/{self.image_format.lower()};base64,{encoded_image}" + }, + }, + ] if prompt_text: - original_msg["content"].append({"type": "text", "text": prompt_text}) - - processed_msg = self._process_msg_standard(original_msg) + content.append({"type": "text", "text": prompt_text}) return { - "messages": [processed_msg], + "messages": [{"role": "user", "content": content}], "max_tokens": self.max_tokens, "temperature": self.temperature, "top_p": self.top_p, diff --git a/glmocr/layout/layout_detector.py b/glmocr/layout/layout_detector.py index 0a551f2..a616a38 100644 --- a/glmocr/layout/layout_detector.py +++ b/glmocr/layout/layout_detector.py @@ -232,13 +232,10 @@ def process( raise RuntimeError("Layout detector not started. Call start() first.") num_images = len(images) - image_batch = [] - for image in images: - image_width, image_height = image.size - image_array = np.array(image.convert("RGB")) - image_batch.append((image_array, image_width, image_height)) - - pil_images = [Image.fromarray(img[0]) for img in image_batch] + pil_images = [ + img.convert("RGB") if img.mode != "RGB" else img + for img in images + ] all_paddle_format_results = [] for chunk_start in range(0, num_images, self.batch_size): @@ -326,8 +323,7 @@ def process( all_results = [] for img_idx, paddle_results in enumerate(all_paddle_format_results): - image_width = image_batch[img_idx][1] - image_height = image_batch[img_idx][2] + image_width, image_height = pil_images[img_idx].size results = [] valid_index = 0 for item in paddle_results: diff --git a/glmocr/pipeline/_state.py b/glmocr/pipeline/_state.py index 27da34f..4264d1d 100644 --- a/glmocr/pipeline/_state.py +++ b/glmocr/pipeline/_state.py @@ -44,6 +44,7 @@ def __init__( # ── Recognition results (stage 3 appends, main thread reads) ─ self._recognition_results: List[Dict[str, Any]] = [] + self._results_by_page: Dict[int, List[Dict]] = {} self._results_lock = threading.Lock() # ── UnitTracker (set before threads start) ─────────────────── @@ -75,14 +76,16 @@ def add_recognition_result(self, page_idx: int, region: Dict) -> None: result = {"page_idx": page_idx, "region": region} with self._results_lock: self._recognition_results.append(result) + self._results_by_page.setdefault(page_idx, []).append(region) tracker = self._tracker if tracker is not None: tracker.on_region_done(page_idx) - def snapshot_recognition_results(self) -> List[Dict[str, Any]]: - """Return a shallow copy of all results accumulated so far.""" + def get_grouped_results(self, page_indices: List[int]) -> List[List[Dict]]: + """Return recognition results grouped by page for the given indices. + """ with self._results_lock: - return list(self._recognition_results) + return [list(self._results_by_page.get(pi, [])) for pi in page_indices] # ------------------------------------------------------------------ # UnitTracker lifecycle diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py index 293eabc..c0d0ba9 100644 --- a/glmocr/pipeline/_workers.py +++ b/glmocr/pipeline/_workers.py @@ -170,22 +170,22 @@ def layout_worker( save_visualization, vis_output_dir, global_start_idx, ) global_start_idx += len(batch_page_indices) + for pi in batch_page_indices: + state.images_dict.pop(pi, None) batch_images, batch_page_indices, batch_unit_indices = [], [], [] elif identifier == IDENTIFIER_UNIT_DONE: unit_idx = msg["unit_idx"] - # Flush any remaining pages in the batch (they all belong to - # this unit since t1 sends UNIT_DONE after the last page). if batch_images: _flush_layout_batch( state, layout_detector, batch_images, batch_page_indices, save_visualization, vis_output_dir, global_start_idx, ) global_start_idx += len(batch_page_indices) + for pi in batch_page_indices: + state.images_dict.pop(pi, None) batch_images, batch_page_indices, batch_unit_indices = [], [], [] - # All pages for this unit have been layout-detected; compute - # total region count and tell the tracker. pages_for_unit = unit_page_indices.get(unit_idx, []) region_count = sum( len(state.layout_results_dict.get(pi, [])) @@ -261,7 +261,6 @@ def recognition_worker( concurrency = min(max_workers, 128) executor = ThreadPoolExecutor(max_workers=concurrency) futures: Dict[Any, Dict[str, Any]] = {} - pending_skip: List[Dict[str, Any]] = [] processing_complete = False while True: @@ -276,7 +275,6 @@ def recognition_worker( msg = state.region_queue.get(timeout=0.01) except queue.Empty: if processing_complete and not futures: - _flush_pending_skips(pending_skip, state) break if futures: _wait_for_any(futures) @@ -286,11 +284,13 @@ def recognition_worker( if identifier == IDENTIFIER_REGION: if msg["region"]["task_type"] == "skip": - pending_skip.append(msg) + msg["region"]["content"] = None + state.add_recognition_result(msg["page_idx"], msg["region"]) else: req = page_loader.build_request_from_image( msg["cropped_image"], msg["region"]["task_type"], ) + del msg["cropped_image"] future = executor.submit(ocr_client.process, req) futures[future] = msg @@ -342,14 +342,6 @@ def _handle_future_result( state.add_recognition_result(page_idx, region) -def _flush_pending_skips( - pending: List[Dict[str, Any]], - state: PipelineState, -) -> None: - for msg in pending: - msg["region"]["content"] = None - state.add_recognition_result(msg["page_idx"], msg["region"]) - def _wait_for_any(futures: Dict) -> None: done_list = [f for f in futures if f.done()] diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 8fb0f8c..9020e09 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -15,6 +15,7 @@ from __future__ import annotations +import time import threading from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional @@ -253,23 +254,18 @@ def _emit_results( region_count = tracker.unit_region_count[u] if region_count is None: tracker._ready_queue.put(u) + time.sleep(0.05) continue - results = state.snapshot_recognition_results() page_indices = tracker.unit_image_indices[u] - page_set = set(page_indices) - count = sum(1 for r in results if r["page_idx"] in page_set) - if count < region_count: + grouped = state.get_grouped_results(page_indices) + + total = sum(len(g) for g in grouped) + if total < region_count: tracker._ready_queue.put(u) + time.sleep(0.05) continue - page_to_pos = {idx: k for k, idx in enumerate(page_indices)} - grouped: List[List[Dict]] = [[] for _ in page_indices] - for r in results: - pos = page_to_pos.get(r["page_idx"]) - if pos is not None: - grouped[pos].append(r["region"]) - json_u, md_u = self.result_formatter.process(grouped) yield PipelineResult( json_result=json_u, From 59e6ba646440671f5d384af56e5caedf397a497c Mon Sep 17 00:00:00 2001 From: xueyadong Date: Tue, 3 Mar 2026 17:16:46 +0800 Subject: [PATCH 05/38] Added shutdown event handling to safely stop processing and drain queues --- glmocr/pipeline/_common.py | 3 +- glmocr/pipeline/_state.py | 39 +++++++++++++ glmocr/pipeline/_workers.py | 111 ++++++++++++++++++++---------------- glmocr/pipeline/pipeline.py | 10 ++-- 4 files changed, 108 insertions(+), 55 deletions(-) diff --git a/glmocr/pipeline/_common.py b/glmocr/pipeline/_common.py index 09e00d1..fe2f532 100644 --- a/glmocr/pipeline/_common.py +++ b/glmocr/pipeline/_common.py @@ -35,9 +35,8 @@ def extract_ocr_content(response: Dict[str, Any]) -> str: # ── Queue message "identifier" field values ────────────────────────── -# Every queue message is a dict with at least an "identifier" key. +# Every queue message is a dict with an "identifier" key. IDENTIFIER_IMAGE = "image" IDENTIFIER_UNIT_DONE = "unit_done" # t1 → t2: all pages for one input unit are queued IDENTIFIER_REGION = "region" IDENTIFIER_DONE = "done" -IDENTIFIER_ERROR = "error" diff --git a/glmocr/pipeline/_state.py b/glmocr/pipeline/_state.py index 4264d1d..2f2de12 100644 --- a/glmocr/pipeline/_state.py +++ b/glmocr/pipeline/_state.py @@ -54,6 +54,44 @@ def __init__( self._exceptions: List[Dict[str, Any]] = [] self._exception_lock = threading.Lock() + # ── Shutdown coordination ───────────────────────────────────── + self._shutdown_event = threading.Event() + + # ------------------------------------------------------------------ + # Shutdown helpers + # ------------------------------------------------------------------ + + @property + def is_shutdown(self) -> bool: + return self._shutdown_event.is_set() + + def request_shutdown(self) -> None: + """Signal all workers to stop processing.""" + self._shutdown_event.set() + tracker = self._tracker + if tracker is not None: + tracker.signal_shutdown() + + def safe_put(self, q: queue.Queue, msg: Dict[str, Any], + timeout: float = 0.5) -> bool: + """Put *msg* on *q*, returning ``False`` if shutdown was requested.""" + while not self._shutdown_event.is_set(): + try: + q.put(msg, timeout=timeout) + return True + except queue.Full: + continue + return False + + @staticmethod + def drain_queue(q: queue.Queue) -> None: + """Drain all items from *q* to unblock any blocked producers.""" + while True: + try: + q.get_nowait() + except queue.Empty: + break + # ------------------------------------------------------------------ # Page registration (delegated to tracker) # ------------------------------------------------------------------ @@ -117,6 +155,7 @@ def finalize_unit(self, unit_idx: int, region_count: int) -> None: def record_exception(self, source: str, exc: Exception) -> None: with self._exception_lock: self._exceptions.append({"source": source, "exception": exc}) + self._shutdown_event.set() tracker = self._tracker if tracker is not None: tracker.signal_shutdown() diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py index c0d0ba9..01836bf 100644 --- a/glmocr/pipeline/_workers.py +++ b/glmocr/pipeline/_workers.py @@ -4,32 +4,30 @@ Stage 2 (layout_worker): Layout detection → region_queue Stage 3 (recognition_worker): Parallel OCR → recognition_results -Queue message formats: - - page_queue:: - {"identifier": "image", "page_idx": int, "unit_idx": int, - "image": PIL.Image} - {"identifier": "unit_done", "unit_idx": int} ← all pages for this - unit have been queued - {"identifier": "done"} - {"identifier": "error"} - - region_queue:: - {"identifier": "region", "page_idx": int, "cropped_image": PIL.Image, - "region": dict, "task_type": str} - {"identifier": "done"} - {"identifier": "error"} +Queue message formats +--------------------- +page_queue:: + + {"identifier": "image", "page_idx": int, "unit_idx": int, + "image": PIL.Image} + {"identifier": "unit_done", "unit_idx": int} + {"identifier": "done"} + +region_queue:: + + {"identifier": "region", "page_idx": int, "cropped_image": PIL.Image, + "region": dict} + {"identifier": "done"} """ from __future__ import annotations import queue -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional from concurrent.futures import ThreadPoolExecutor, as_completed from glmocr.pipeline._common import ( IDENTIFIER_DONE, - IDENTIFIER_ERROR, IDENTIFIER_IMAGE, IDENTIFIER_REGION, IDENTIFIER_UNIT_DONE, @@ -76,48 +74,53 @@ def data_loading_worker( sent_unit_done: set = set() try: for page, unit_idx in page_loader.iter_pages_with_unit_indices(image_urls): + if state.is_shutdown: + break + if prev_unit_idx is not None and unit_idx != prev_unit_idx: - state.page_queue.put({ + if not state.safe_put(state.page_queue, { "identifier": IDENTIFIER_UNIT_DONE, "unit_idx": prev_unit_idx, - }) + }): + break sent_unit_done.add(prev_unit_idx) state.register_page(page_idx, unit_idx) state.images_dict[page_idx] = page - state.page_queue.put({ + if not state.safe_put(state.page_queue, { "identifier": IDENTIFIER_IMAGE, "page_idx": page_idx, "unit_idx": unit_idx, "image": page, - }) + }): + break unit_indices_list.append(unit_idx) page_idx += 1 state.num_images_loaded[0] = page_idx state.unit_indices_holder[0] = list(unit_indices_list) prev_unit_idx = unit_idx - if prev_unit_idx is not None: - state.page_queue.put({ - "identifier": IDENTIFIER_UNIT_DONE, - "unit_idx": prev_unit_idx, - }) - sent_unit_done.add(prev_unit_idx) - - for u in range(num_units): - if u not in sent_unit_done: - state.page_queue.put({ + if not state.is_shutdown: + if prev_unit_idx is not None: + state.safe_put(state.page_queue, { "identifier": IDENTIFIER_UNIT_DONE, - "unit_idx": u, + "unit_idx": prev_unit_idx, }) + sent_unit_done.add(prev_unit_idx) - state.page_queue.put({"identifier": IDENTIFIER_DONE}) + for u in range(num_units): + if u not in sent_unit_done: + state.safe_put(state.page_queue, { + "identifier": IDENTIFIER_UNIT_DONE, + "unit_idx": u, + }) + + state.safe_put(state.page_queue, {"identifier": IDENTIFIER_DONE}) except Exception as e: logger.exception("Data loading worker error: %s", e) state.num_images_loaded[0] = page_idx state.unit_indices_holder[0] = list(unit_indices_list) state.record_exception("DataLoadingWorker", e) - state.page_queue.put({"identifier": IDENTIFIER_ERROR}) # ====================================================================== @@ -144,10 +147,12 @@ def layout_worker( batch_unit_indices: List[int] = [] global_start_idx = 0 - # page_indices seen so far per unit, used to compute region counts. unit_page_indices: Dict[int, List[int]] = {} while True: + if state.is_shutdown: + break + try: msg = state.page_queue.get(timeout=0.01) except queue.Empty: @@ -203,17 +208,14 @@ def layout_worker( state, layout_detector, batch_images, batch_page_indices, save_visualization, vis_output_dir, global_start_idx, ) - state.region_queue.put({"identifier": IDENTIFIER_DONE}) - break - - elif identifier == IDENTIFIER_ERROR: - state.region_queue.put({"identifier": IDENTIFIER_ERROR}) + state.safe_put(state.region_queue, {"identifier": IDENTIFIER_DONE}) break except Exception as e: logger.exception("Layout worker error: %s", e) state.record_exception("LayoutWorker", e) - state.region_queue.put({"identifier": IDENTIFIER_ERROR}) + finally: + state.drain_queue(state.page_queue) def _flush_layout_batch( @@ -238,12 +240,13 @@ def _flush_layout_batch( state.layout_results_dict[page_idx] = layout_result for region in layout_result: cropped = crop_image_region(image, region["bbox_2d"], region["polygon"]) - state.region_queue.put({ + if not state.safe_put(state.region_queue, { "identifier": IDENTIFIER_REGION, "page_idx": page_idx, "cropped_image": cropped, "region": region, - }) + }): + return # ====================================================================== @@ -257,6 +260,7 @@ def recognition_worker( max_workers: int, ) -> None: """Consume regions, run parallel OCR, store results.""" + executor = None try: concurrency = min(max_workers, 128) executor = ThreadPoolExecutor(max_workers=concurrency) @@ -264,6 +268,9 @@ def recognition_worker( processing_complete = False while True: + if state.is_shutdown: + break + _collect_done_futures(futures, state) if len(futures) >= concurrency: @@ -297,16 +304,22 @@ def recognition_worker( elif identifier == IDENTIFIER_DONE: processing_complete = True - elif identifier == IDENTIFIER_ERROR: - break - - for future in as_completed(futures.keys()): - _handle_future_result(future, futures, state) - executor.shutdown(wait=True) + if not state.is_shutdown: + for future in as_completed(futures.keys()): + _handle_future_result(future, futures, state) + executor.shutdown(wait=True) + else: + for f in list(futures): + f.cancel() + executor.shutdown(wait=False) except Exception as e: logger.exception("Recognition worker error: %s", e) state.record_exception("RecognitionWorker", e) + if executor is not None: + executor.shutdown(wait=False) + finally: + state.drain_queue(state.region_queue) # ------------------------------------------------------------------ diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 9020e09..81e25f4 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -158,13 +158,15 @@ def process( try: yield from self._emit_results(state, tracker, original_inputs, layout_vis_output_dir) - t1.join() - t2.join() - t3.join() - state.raise_if_exceptions() finally: + state.request_shutdown() + t1.join(timeout=10) + t2.join(timeout=10) + t3.join(timeout=10) self._current_state = None + state.raise_if_exceptions() + def get_queue_stats(self) -> Optional[Dict[str, int]]: """Return current queue sizes, or ``None`` if no processing is active.""" state = self._current_state From 56c40a88b2b2a1fe9df624a1078a2ce95adca3cd Mon Sep 17 00:00:00 2001 From: xueyadong Date: Tue, 3 Mar 2026 18:37:08 +0800 Subject: [PATCH 06/38] support image / PDF bytes input --- glmocr/api.py | 265 ++++++++++++++++++++++++++++---------------------- 1 file changed, 148 insertions(+), 117 deletions(-) diff --git a/glmocr/api.py b/glmocr/api.py index 4f8d8d1..7aa92af 100644 --- a/glmocr/api.py +++ b/glmocr/api.py @@ -8,17 +8,23 @@ 2. Self-hosted Mode (maas.enabled=false): Uses local vLLM/SGLang service. Requires GPU; SDK handles layout detection, parallel OCR, etc. +Supported input types: file paths (``str``), ``pathlib.Path``, raw ``bytes`` +(image or PDF content), and URLs (file://, http://, data:). + Agent-friendly usage:: - # Only needs GLMOCR_API_KEY in environment (or pass api_key directly) from glmocr import GlmOcr parser = GlmOcr(api_key="sk-xxx", mode="maas") - results = parser.parse("document.png") - print(results[0].to_dict()) + result = parser.parse("document.png") + result = parser.parse(open("doc.pdf", "rb").read()) # bytes + print(result.to_dict()) """ +import os import re +import shutil +import tempfile from typing import Any, Dict, Generator, List, Literal, Optional, Union, overload from pathlib import Path @@ -50,8 +56,10 @@ class GlmOcr: # --- Classic: YAML-based --- parser = glmocr.GlmOcr(config_path="config.yaml") - # --- Parse --- - results = parser.parse("image.png") + # --- Parse (paths, bytes, or mixed) --- + result = parser.parse("image.png") + result = parser.parse(open("doc.pdf", "rb").read()) + results = parser.parse(["img.png", pdf_bytes]) for r in results: print(r.markdown_result) print(r.to_dict()) # structured, JSON-serialisable @@ -118,6 +126,7 @@ def __init__( self._use_maas = self.config_model.pipeline.maas.enabled self._pipeline = None self._maas_client = None + self._session_temp_dir: Optional[str] = None if self._use_maas: # MaaS mode: use MaaSClient for direct API passthrough @@ -134,10 +143,78 @@ def __init__( self._pipeline.start() logger.info("GLM-OCR initialized in self-hosted mode") + # ------------------------------------------------------------------ + # Input normalisation helpers + # ------------------------------------------------------------------ + + def _get_temp_dir(self) -> str: + if self._session_temp_dir is None: + self._session_temp_dir = tempfile.mkdtemp(prefix="glmocr_") + return self._session_temp_dir + + @staticmethod + def _detect_suffix(data: bytes) -> str: + """Detect file extension from magic bytes.""" + if data[:5] == b"%PDF-": + return ".pdf" + if data[:8] == b"\x89PNG\r\n\x1a\n": + return ".png" + if data[:2] == b"\xff\xd8": + return ".jpg" + if data[:4] == b"GIF8": + return ".gif" + if len(data) > 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP": + return ".webp" + if data[:2] == b"BM": + return ".bmp" + return ".png" + + def _bytes_to_temp_file(self, data: bytes) -> str: + """Write *data* to a temp file and return the path. + + The file lives in ``_session_temp_dir`` and is cleaned up by + ``close()``. + """ + suffix = self._detect_suffix(data) + fd, path = tempfile.mkstemp(suffix=suffix, dir=self._get_temp_dir()) + try: + os.write(fd, data) + finally: + os.close(fd) + return path + + def _to_url(self, image: Union[str, bytes, Path]) -> str: + """Convert any supported input to a ``file://`` or ``data:`` URL.""" + if isinstance(image, bytes): + return f"file://{self._bytes_to_temp_file(image)}" + if isinstance(image, Path): + return f"file://{image.absolute()}" + if isinstance(image, str): + if image.startswith(("http://", "https://", "data:", "file://")): + return image + return f"file://{Path(image).absolute()}" + raise TypeError(f"Unsupported image type: {type(image)}") + + @staticmethod + def _maas_source(image: Union[str, bytes, Path]): + """Return ``(source, display_name)`` suitable for the MaaS client.""" + if isinstance(image, bytes): + return image, "" + if isinstance(image, Path): + return str(image), str(image) + if isinstance(image, str) and image.startswith("file://"): + p = image[7:] + return p, p + return image, str(image) + + # ------------------------------------------------------------------ + # parse() and overloads + # ------------------------------------------------------------------ + @overload def parse( self, - images: str, + images: Union[str, bytes, Path], *, stream: Literal[False] = ..., save_layout_visualization: bool = ..., @@ -148,7 +225,7 @@ def parse( @overload def parse( self, - images: List[str], + images: List[Union[str, bytes, Path]], *, stream: Literal[False] = ..., save_layout_visualization: bool = ..., @@ -159,7 +236,7 @@ def parse( @overload def parse( self, - images: Union[str, List[str]], + images: Union[str, bytes, Path, List[Union[str, bytes, Path]]], *, stream: Literal[True], save_layout_visualization: bool = ..., @@ -169,7 +246,7 @@ def parse( def parse( self, - images: Union[str, List[str]], + images: Union[str, bytes, Path, List[Union[str, bytes, Path]]], *, stream: bool = False, save_layout_visualization: bool = True, @@ -177,39 +254,36 @@ def parse( ) -> Union[ PipelineResult, List[PipelineResult], Generator[PipelineResult, None, None] ]: - """Predict / parse images or documents. + """Parse images or documents. - Supports local paths and URLs (file://, http://, https://, data:). - Supports image files (jpg, png, bmp, gif, webp) and PDF files. + Accepts file paths, URLs, ``pathlib.Path`` objects, or raw ``bytes`` + (image or PDF content). Format is auto-detected from magic bytes. Args: - images: Image path/URL — a single ``str`` or a ``list`` of strings. - stream: If ``True``, yields one :class:`PipelineResult` at a time (avoids - holding all results in memory). If ``False``, returns a single result - or a list, depending on *images*. + images: A single input or a list. Each element can be: + + * ``str`` – local path or URL (file://, http://, data:) + * ``bytes`` – raw image / PDF bytes + * ``Path`` – ``pathlib.Path`` to a file + + stream: If ``True``, yields one :class:`PipelineResult` at a time. save_layout_visualization: Whether to save layout visualization artifacts. **kwargs: Additional parameters for MaaS mode (return_crop_images, need_layout_visualization, start_page_id, end_page_id, etc.) Returns: - - When ``stream=False`` (default): a single ``PipelineResult`` if *images* - is a ``str``, or a ``List[PipelineResult]`` if *images* is a list. - - When ``stream=True``: a generator that yields one ``PipelineResult`` - per input. + - ``stream=False``, single input → ``PipelineResult`` + - ``stream=False``, list input → ``List[PipelineResult]`` + - ``stream=True`` → ``Generator[PipelineResult, ...]`` - Example: - # Single file — returns one PipelineResult - result = parser.parse("image.png") - result.save(output_dir="./output") + Examples:: - # Multiple files — returns a list - results = parser.parse(["img1.png", "doc.pdf"]) - - # Stream to avoid large in-memory results - for r in parser.parse(["a.pdf", "b.pdf"], stream=True): - r.save(output_dir="./output") + result = parser.parse("image.png") + result = parser.parse(Path("image.png")) + result = parser.parse(open("image.png", "rb").read()) + results = parser.parse(["img1.png", pdf_bytes]) """ - _single = isinstance(images, str) + _single = isinstance(images, (str, bytes, Path)) if _single: images = [images] @@ -225,7 +299,7 @@ def parse( def _parse_stream( self, - images: List[str], + images: List[Union[str, bytes, Path]], save_layout_visualization: bool = True, **kwargs: Any, ) -> Generator[PipelineResult, None, None]: @@ -234,19 +308,17 @@ def _parse_stream( if save_layout_visualization: kwargs.setdefault("need_layout_visualization", True) for image in images: - img = image - if img.startswith("file://"): - img = img[7:] + source, display = self._maas_source(image) try: - response = self._maas_client.parse(img, **kwargs) - result = self._maas_response_to_pipeline_result(response, img) + response = self._maas_client.parse(source, **kwargs) + result = self._maas_response_to_pipeline_result(response, display) yield result except Exception as e: - logger.error("MaaS API error for %s: %s", img, e) + logger.error("MaaS API error for %s: %s", display, e) result = PipelineResult( json_result=[], markdown_result="", - original_images=[img], + original_images=[display], ) result._error = str(e) yield result @@ -259,33 +331,28 @@ def _parse_stream( def _parse_maas( self, - images: List[str], + images: List[Union[str, bytes, Path]], save_layout_visualization: bool = True, - **kwargs, + **kwargs: Any, ) -> List[PipelineResult]: """Parse using MaaS API (passthrough mode).""" results = [] - # Map save_layout_visualization to MaaS API parameter if save_layout_visualization: kwargs.setdefault("need_layout_visualization", True) for image in images: - # Resolve file:// URLs to actual paths - if image.startswith("file://"): - image = image[7:] - + source, display = self._maas_source(image) try: - response = self._maas_client.parse(image, **kwargs) - result = self._maas_response_to_pipeline_result(response, image) + response = self._maas_client.parse(source, **kwargs) + result = self._maas_response_to_pipeline_result(response, display) results.append(result) except Exception as e: - logger.error("MaaS API error for %s: %s", image, e) - # Return an error result + logger.error("MaaS API error for %s: %s", display, e) result = PipelineResult( json_result=[], markdown_result="", - original_images=[image], + original_images=[display], ) result._error = str(e) results.append(result) @@ -414,24 +481,25 @@ def _maas_response_to_pipeline_result( return result + def _build_selfhosted_request( + self, images: List[Union[str, bytes, Path]], + ) -> Dict[str, Any]: + """Build OpenAI-style request from mixed inputs.""" + messages: List[Dict[str, Any]] = [{"role": "user", "content": []}] + for image in images: + url = self._to_url(image) + messages[0]["content"].append( + {"type": "image_url", "image_url": {"url": url}} + ) + return {"messages": messages} + def _parse_selfhosted( self, - images: List[str], + images: List[Union[str, bytes, Path]], save_layout_visualization: bool = True, ) -> List[PipelineResult]: """Parse using self-hosted vLLM/SGLang pipeline.""" - import tempfile - - messages = [{"role": "user", "content": []}] - for image in images: - if image.startswith(("http://", "https://", "data:", "file://")): - url = image - else: - url = f"file://{Path(image).absolute()}" - messages[0]["content"].append( - {"type": "image_url", "image_url": {"url": url}} - ) - request_data = {"messages": messages} + request_data = self._build_selfhosted_request(images) layout_vis_dir = None if save_layout_visualization: @@ -448,26 +516,11 @@ def _parse_selfhosted( def _stream_parse_selfhosted( self, - images: List[str], + images: List[Union[str, bytes, Path]], save_layout_visualization: bool = True, ) -> Generator[PipelineResult, None, None]: - """Streaming variant of self-hosted parse(). - - Wraps ``Pipeline.process(...)`` and yields results as soon as they - become available from the async pipeline. - """ - import tempfile - - messages = [{"role": "user", "content": []}] - for image in images: - if image.startswith(("http://", "https://", "data:", "file://")): - url = image - else: - url = f"file://{Path(image).absolute()}" - messages[0]["content"].append( - {"type": "image_url", "image_url": {"url": url}} - ) - request_data = {"messages": messages} + """Streaming variant of self-hosted parse().""" + request_data = self._build_selfhosted_request(images) layout_vis_dir = None if save_layout_visualization: @@ -538,6 +591,9 @@ def close(self): if self._maas_client: self._maas_client.stop() self._maas_client = None + if self._session_temp_dir: + shutil.rmtree(self._session_temp_dir, ignore_errors=True) + self._session_temp_dir = None def __enter__(self): """Context manager entry.""" @@ -558,7 +614,7 @@ def __del__(self): # Convenience function @overload def parse( - images: str, + images: Union[str, bytes, Path], config_path: Optional[str] = ..., save_layout_visualization: bool = ..., ) -> PipelineResult: @@ -567,7 +623,7 @@ def parse( @overload def parse( - images: List[str], + images: List[Union[str, bytes, Path]], config_path: Optional[str] = ..., save_layout_visualization: bool = ..., ) -> List[PipelineResult]: @@ -576,7 +632,7 @@ def parse( @overload def parse( - images: Union[str, List[str]], + images: Union[str, bytes, Path, List[Union[str, bytes, Path]]], config_path: Optional[str] = ..., save_layout_visualization: bool = ..., *, @@ -587,7 +643,7 @@ def parse( def parse( - images: Union[str, List[str]], + images: Union[str, bytes, Path, List[Union[str, bytes, Path]]], config_path: Optional[str] = None, save_layout_visualization: bool = True, *, @@ -603,55 +659,30 @@ def parse( """Convenience function: parse images or documents in one call. Creates a :class:`GlmOcr` instance, runs parsing, and cleans up. - All keyword arguments are forwarded to the ``GlmOcr`` constructor. Examples:: import glmocr - # Minimal – only needs GLMOCR_API_KEY env var - results = glmocr.parse("image.png") - - # Explicit API key - results = glmocr.parse("image.png", api_key="sk-xxx") + result = glmocr.parse("image.png") + result = glmocr.parse(open("doc.pdf", "rb").read()) + results = glmocr.parse(["img.png", pdf_bytes]) - # Self-hosted mode - results = glmocr.parse("image.png", mode="selfhosted") - - # Stream to avoid large in-memory results for r in glmocr.parse(["a.pdf", "b.pdf"], stream=True): r.save(output_dir="./output") - The return type mirrors the input type and stream: - - ``str``, stream=False → ``PipelineResult`` - - ``List[str]``, stream=False → ``List[PipelineResult]`` - - ``stream=True`` → ``Generator[PipelineResult, None, None]`` - Args: - images: Image path or URL (single ``str`` or ``List[str]``). + images: Single input or list. Each element can be ``str`` (path/URL), + ``bytes`` (raw image/PDF), or ``pathlib.Path``. config_path: Config file path. save_layout_visualization: Whether to save layout visualization. - stream: If ``True``, returns a generator that yields one result at a time. + stream: If ``True``, returns a generator. api_key: API key. api_url: MaaS API endpoint URL. model: Model name. mode: ``"maas"`` or ``"selfhosted"``. timeout: Request timeout in seconds. log_level: Logging level. - - Returns: - A single ``PipelineResult``, a list, or a generator, depending on input and stream. - - Example: - result = parse("image.png") - result.save(output_dir="./output") - - results = parse(["img1.png", "doc.pdf"]) - for r in results: - r.save(output_dir="./output") - - for r in parse(["a.pdf", "b.pdf"], stream=True): - r.save(output_dir="./output") """ with GlmOcr( config_path=config_path, From f55d36a3750455cc26fcc5f116c6f9a27ab2fb29 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Wed, 4 Mar 2026 14:13:27 +0800 Subject: [PATCH 07/38] add image path to json result for cropped images --- glmocr/parser_result/base.py | 94 ++++++++++++++++++++++++++++++------ 1 file changed, 78 insertions(+), 16 deletions(-) diff --git a/glmocr/parser_result/base.py b/glmocr/parser_result/base.py index f6cfa3e..4eff3f5 100644 --- a/glmocr/parser_result/base.py +++ b/glmocr/parser_result/base.py @@ -5,14 +5,15 @@ from __future__ import annotations +import copy import json import traceback from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from glmocr.utils.logging import get_logger -from glmocr.utils.markdown_utils import crop_and_replace_images +from glmocr.utils.markdown_utils import crop_and_replace_images, extract_image_refs logger = get_logger(__name__) @@ -58,6 +59,60 @@ def save( """Save result to disk. Subclasses implement layout vis etc.""" pass + @staticmethod + def _build_image_path_map( + markdown_text: str, image_prefix: str = "cropped" + ) -> Dict[Tuple[int, ...], str]: + """Build a mapping from (page_idx, *bbox) to the relative image path. + + The mapping is derived purely from the markdown image references so + it stays in sync with what ``crop_and_replace_images`` will produce, + without performing any file I/O here. + """ + mapping: Dict[Tuple[int, ...], str] = {} + refs = extract_image_refs(markdown_text) + for idx, (page_idx, bbox, _) in enumerate(refs): + key = (page_idx, *bbox) + rel = f"imgs/{image_prefix}_page{page_idx}_idx{idx}.jpg" + mapping[key] = rel + return mapping + + @staticmethod + def _annotate_json_image_paths( + json_data: Any, + image_path_map: Dict[Tuple[int, ...], str], + ) -> Any: + """Return a deep-copied json_data with ``image_path`` added to image regions. + + ``json_data`` is expected to be a list-of-pages (list of lists of region + dicts). For every region whose ``label`` is ``"image"``, the relative + path is looked up by ``(page_idx, *bbox_2d)`` and written into the copy. + The original ``json_data`` is never mutated. + """ + if not image_path_map or not isinstance(json_data, list): + return json_data + + result = [] + for page_idx, page in enumerate(json_data): + if not isinstance(page, list): + result.append(page) + continue + page_copy = [] + for region in page: + if not isinstance(region, dict) or region.get("label") != "image": + page_copy.append(region) + continue + bbox = region.get("bbox_2d") + region_copy = copy.copy(region) + if bbox: + key = (page_idx, *bbox) + rel = image_path_map.get(key) + if rel: + region_copy["image_path"] = rel + page_copy.append(region_copy) + result.append(page_copy) + return result + def _save_json_and_markdown(self, output_dir: Union[str, Path]) -> None: """Save JSON and Markdown to output_dir (by first image name or 'result').""" output_dir = Path(output_dir).absolute() @@ -70,28 +125,35 @@ def _save_json_and_markdown(self, output_dir: Union[str, Path]) -> None: output_path.mkdir(parents=True, exist_ok=True) base_name = output_path.name - # JSON + # Build image_path_map from markdown refs so JSON can reference the + # same filenames that crop_and_replace_images will produce below. + image_path_map: Dict[Tuple[int, ...], str] = {} + if self.markdown_result and self.original_images: + image_path_map = self._build_image_path_map( + self.markdown_result, image_prefix="cropped" + ) + + # JSON — annotate image regions with their relative image_path json_file = output_path / f"{base_name}.json" try: - if isinstance(self.json_result, (dict, list)): - with open(json_file, "w", encoding="utf-8") as f: - json.dump(self.json_result, f, ensure_ascii=False, indent=2) - elif isinstance(self.json_result, str): + json_data = self.json_result + if isinstance(json_data, str): try: - data = json.loads(self.json_result) - with open(json_file, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) + json_data = json.loads(json_data) except json.JSONDecodeError: - with open(json_file, "w", encoding="utf-8") as f: - f.write(self.json_result) - else: - with open(json_file, "w", encoding="utf-8") as f: - json.dump(self.json_result, f, ensure_ascii=False, indent=2) + pass + if isinstance(json_data, list): + json_data = self._annotate_json_image_paths(json_data, image_path_map) + with open(json_file, "w", encoding="utf-8") as f: + if isinstance(json_data, (dict, list)): + json.dump(json_data, f, ensure_ascii=False, indent=2) + else: + f.write(str(json_data)) except Exception as e: logger.warning("Failed to save JSON: %s", e) traceback.print_exc() - # Markdown (with image crop/replace if original_images) + # Markdown — crop image regions and replace bbox tags with file paths if self.markdown_result and self.markdown_result.strip(): md_text = self.markdown_result if self.original_images: From 73ef92de37f9e1876b818c946de4a6c7ef257770 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Wed, 4 Mar 2026 21:14:47 +0800 Subject: [PATCH 08/38] Fix a layout visualization file naming bug --- glmocr/parser_result/pipeline_result.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/glmocr/parser_result/pipeline_result.py b/glmocr/parser_result/pipeline_result.py index 058abd8..27b7b26 100644 --- a/glmocr/parser_result/pipeline_result.py +++ b/glmocr/parser_result/pipeline_result.py @@ -2,7 +2,6 @@ from __future__ import annotations -import re import shutil from pathlib import Path from typing import List, Optional, Union @@ -87,20 +86,12 @@ def save( layout_files.extend(sorted(temp_layout_path.glob("layout_page*.png"))) stem = Path(self.original_images[0]).stem if self.original_images else "result" - for layout_file in layout_files: - m = re.match( - r"layout_page(\d+)\.(jpg|png)$", - layout_file.name, - re.IGNORECASE, - ) - if m: - idx_str, ext = m.group(1), m.group(2).lower() - else: - idx_str, ext = "0", layout_file.suffix.lstrip(".") or "jpg" + for local_idx, layout_file in enumerate(layout_files): + ext = layout_file.suffix.lstrip(".").lower() or "jpg" new_name = ( f"{stem}.{ext}" if len(layout_files) == 1 - else f"{stem}_page{idx_str}.{ext}" + else f"{stem}_page{local_idx}.{ext}" ) target_file = target_dir / new_name shutil.move(str(layout_file), str(target_file)) From eb2ebe373b67647b9dedc1f61b75a90c4706fb9d Mon Sep 17 00:00:00 2001 From: xueyadong Date: Thu, 5 Mar 2026 04:20:40 +0800 Subject: [PATCH 09/38] add fallback handling to page loader --- glmocr/dataloader/page_loader.py | 7 +++++-- glmocr/utils/image_utils.py | 12 +++++++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/glmocr/dataloader/page_loader.py b/glmocr/dataloader/page_loader.py index d4bf72f..82ae1ae 100644 --- a/glmocr/dataloader/page_loader.py +++ b/glmocr/dataloader/page_loader.py @@ -157,8 +157,11 @@ def iter_pages_with_unit_indices(self, sources: Union[str, List[str]]): if isinstance(sources, str): sources = [sources] for unit_idx, source in enumerate(sources): - for page in self._iter_source(source): - yield page, unit_idx + try: + for page in self._iter_source(source): + yield page, unit_idx + except Exception as e: + logger.warning("Skipping source '%s' (unit %d): %s", source, unit_idx, e) def _iter_source(self, source: str): """Yield pages from a single source one at a time.""" diff --git a/glmocr/utils/image_utils.py b/glmocr/utils/image_utils.py index e228cbb..82f64fd 100644 --- a/glmocr/utils/image_utils.py +++ b/glmocr/utils/image_utils.py @@ -9,6 +9,10 @@ import numpy as np from PIL import Image +from glmocr.utils.logging import get_logger + +logger = get_logger(__name__) + def smart_resize( t: int, @@ -379,12 +383,18 @@ def pdf_to_images_pil_iter( if end_page_id >= page_count: end_page_id = page_count - 1 for i in range(start_page_id, end_page_id + 1): - page = pdf[i] + try: + page = pdf[i] + except Exception as e: + logger.warning("Skipping page %d of '%s': %s", i, pdf_path, e) + continue try: image, _ = _page_to_image( page, dpi=dpi, max_width_or_height=max_width_or_height ) yield image + except Exception as e: + logger.warning("Skipping page %d of '%s' (render failed): %s", i, pdf_path, e) finally: page.close() finally: From c7104f35f88c7cd2e2bf9e18b8d8413b13798f76 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Thu, 5 Mar 2026 15:28:34 +0800 Subject: [PATCH 10/38] change PDF renderer to PyMuPDF; harden error handling --- glmocr/dataloader/page_loader.py | 13 +---- glmocr/pipeline/_workers.py | 38 ++++++++++---- glmocr/tests/test_integration.py | 9 ---- glmocr/tests/test_unit.py | 25 +-------- glmocr/utils/image_utils.py | 87 ++++++++++++-------------------- glmocr/utils/markdown_utils.py | 6 --- pyproject.toml | 6 +-- 7 files changed, 64 insertions(+), 120 deletions(-) diff --git a/glmocr/dataloader/page_loader.py b/glmocr/dataloader/page_loader.py index 82ae1ae..ec2b36d 100644 --- a/glmocr/dataloader/page_loader.py +++ b/glmocr/dataloader/page_loader.py @@ -26,7 +26,6 @@ load_image_to_base64, pdf_to_images_pil, pdf_to_images_pil_iter, - PYPDFIUM2_AVAILABLE, ) from glmocr.utils.logging import get_logger, get_profiler @@ -86,7 +85,7 @@ def __init__(self, config: "PageLoaderConfig"): # Default OCR instruction (used when user provides images without text) self.default_prompt = config.default_prompt - # PDF-to-image parameters (pypdfium2 only) + # PDF-to-image parameters self.pdf_dpi = config.pdf_dpi self.pdf_max_pages = config.pdf_max_pages self.pdf_verbose = config.pdf_verbose @@ -189,10 +188,6 @@ def _compute_end_page(self) -> Optional[int]: def _iter_pdf(self, file_path: str): """Yield PDF pages one at a time (streaming).""" - if not PYPDFIUM2_AVAILABLE: - raise RuntimeError( - "PDF support requires pypdfium2. Install: pip install pypdfium2" - ) end_page = self._compute_end_page() for image in pdf_to_images_pil_iter( file_path, @@ -243,11 +238,7 @@ def _load_image(self, source: str) -> Image.Image: raise RuntimeError(f"Error loading image '{source}': {e}") def _load_pdf(self, file_path: str) -> List[Image.Image]: - """Load all pages from a PDF file using pypdfium2 (required).""" - if not PYPDFIUM2_AVAILABLE: - raise RuntimeError( - "PDF support requires pypdfium2. Install: pip install pypdfium2" - ) + """Load all pages from a PDF file.""" t0 = time.perf_counter() end_page = self._compute_end_page() pages = pdf_to_images_pil( diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py index 01836bf..cf36356 100644 --- a/glmocr/pipeline/_workers.py +++ b/glmocr/pipeline/_workers.py @@ -228,18 +228,37 @@ def _flush_layout_batch( global_start_idx: int, ) -> None: """Run layout detection on one batch and enqueue the resulting regions.""" - layout_results = layout_detector.process( - batch_images, - save_visualization=save_visualization and vis_output_dir is not None, - visualization_output_dir=vis_output_dir, - global_start_idx=global_start_idx, - ) + try: + layout_results = layout_detector.process( + batch_images, + save_visualization=save_visualization and vis_output_dir is not None, + visualization_output_dir=vis_output_dir, + global_start_idx=global_start_idx, + ) + except Exception as e: + logger.warning( + "Layout detection failed for pages %s, skipping batch: %s", + batch_page_indices, e, + ) + for page_idx in batch_page_indices: + state.layout_results_dict[page_idx] = [] + return + for page_idx, image, layout_result in zip( batch_page_indices, batch_images, layout_results ): state.layout_results_dict[page_idx] = layout_result for region in layout_result: - cropped = crop_image_region(image, region["bbox_2d"], region["polygon"]) + try: + cropped = crop_image_region(image, region["bbox_2d"], region["polygon"]) + except Exception as e: + logger.warning( + "Failed to crop region on page %d (bbox=%s), skipping: %s", + page_idx, region.get("bbox_2d"), e, + ) + region["content"] = "" + state.add_recognition_result(page_idx, region) + continue if not state.safe_put(state.region_queue, { "identifier": IDENTIFIER_REGION, "page_idx": page_idx, @@ -346,11 +365,12 @@ def _handle_future_result( try: response, status_code = future.result() if status_code == 200: - region["content"] = response["choices"][0]["message"]["content"].strip() + content = response["choices"][0]["message"]["content"] + region["content"] = content.strip() if content else "" else: region["content"] = "" except Exception as e: - logger.warning("Recognition failed: %s", e) + logger.warning("Recognition failed for page %d: %s", page_idx, e) region["content"] = "" state.add_recognition_result(page_idx, region) diff --git a/glmocr/tests/test_integration.py b/glmocr/tests/test_integration.py index aca160c..1bac198 100644 --- a/glmocr/tests/test_integration.py +++ b/glmocr/tests/test_integration.py @@ -119,15 +119,6 @@ def test_parse_pdf_file_uri(self, server_url, timeout_seconds, sample_pdf_path): if sample_pdf_path is None: pytest.skip("No sample PDF available") - # Dependency check: pypdfium2 - try: - from glmocr.utils.image_utils import PYPDFIUM2_AVAILABLE - except Exception: - PYPDFIUM2_AVAILABLE = False - - if not PYPDFIUM2_AVAILABLE: - pytest.skip("pypdfium2 is not installed") - pdf_uri = f"file://{sample_pdf_path.resolve()}" payload = {"images": [pdf_uri]} resp = requests.post( diff --git a/glmocr/tests/test_unit.py b/glmocr/tests/test_unit.py index c256676..8d1d1c9 100644 --- a/glmocr/tests/test_unit.py +++ b/glmocr/tests/test_unit.py @@ -50,25 +50,10 @@ def test_pageloader_with_config(self): assert loader.max_tokens == 8192 assert loader.image_format == "PNG" - def test_pageloader_load_pdf_requires_pypdfium2(self): - """Gives a clear error when pypdfium2 is unavailable.""" - from glmocr.dataloader import PageLoader - from glmocr.config import PageLoaderConfig - - loader = PageLoader(PageLoaderConfig()) - with patch("glmocr.dataloader.page_loader.PYPDFIUM2_AVAILABLE", False): - with pytest.raises(RuntimeError) as exc: - loader._load_pdf("dummy.pdf") - assert "pypdfium2" in str(exc.value).lower() - def test_pageloader_load_pdf_pages(self): - """Expands a PDF into page images (requires pypdfium2).""" + """Expands a PDF into page images.""" from glmocr.config import PageLoaderConfig from glmocr.dataloader import PageLoader - from glmocr.utils.image_utils import PYPDFIUM2_AVAILABLE - - if not PYPDFIUM2_AVAILABLE: - pytest.skip("pypdfium2 is not installed") repo_root = Path(__file__).resolve().parents[2] source_dir = repo_root / "examples" / "source" @@ -87,12 +72,8 @@ def test_pageloader_load_pdf_pages(self): def test_pageloader_load_pdf_via_file_uri(self): """Parses PDF file:// URIs correctly.""" from glmocr.dataloader import PageLoader - from glmocr.utils.image_utils import PYPDFIUM2_AVAILABLE from glmocr.config import PageLoaderConfig - if not PYPDFIUM2_AVAILABLE: - pytest.skip("pypdfium2 is not installed") - repo_root = Path(__file__).resolve().parents[2] source_dir = repo_root / "examples" / "source" sample_pdf = next( @@ -111,10 +92,6 @@ def test_iter_pages_with_unit_indices_pdf_and_multi_source(self): """Streaming: pages yielded incrementally; unit indices correct for multi-source.""" from glmocr.config import PageLoaderConfig from glmocr.dataloader import PageLoader - from glmocr.utils.image_utils import PYPDFIUM2_AVAILABLE - - if not PYPDFIUM2_AVAILABLE: - pytest.skip("pypdfium2 is not installed") repo_root = Path(__file__).resolve().parents[2] source_dir = repo_root / "examples" / "source" diff --git a/glmocr/utils/image_utils.py b/glmocr/utils/image_utils.py index 82f64fd..2773a6d 100644 --- a/glmocr/utils/image_utils.py +++ b/glmocr/utils/image_utils.py @@ -261,39 +261,31 @@ def image_tensor_to_base64(image_tensor, image_format): # ----------------------------------------------------------------------------- -# PDF rendering via pypdfium2 +# PDF rendering via PyMuPDF (fitz) # ----------------------------------------------------------------------------- -try: - import pypdfium2 as _pdfium # noqa: F401 +import fitz - PYPDFIUM2_AVAILABLE = True -except ImportError: - PYPDFIUM2_AVAILABLE = False - -def _page_to_image(page, dpi: int = 200, max_width_or_height: int = 3500): - """Render a PDF page to PIL Image (pypdfium2). +def _render_page_to_pil(page, dpi: int = 200, max_width_or_height: int = 3500): + """Render a PDF page to PIL Image via PyMuPDF. Args: - page: pypdfium2 PdfPage. + page: fitz.Page object. dpi: Render DPI. - max_width_or_height: Max width or height. + max_width_or_height: Cap on the longer side in pixels. Returns: (PIL.Image, scale_factor) """ scale = dpi / 72.0 - width, height = page.get_size() - long_side_length = max(width, height) - if (long_side_length * scale) > max_width_or_height: - scale = max_width_or_height / long_side_length - bitmap = page.render(scale=scale) - image = bitmap.to_pil() - try: - bitmap.close() - except Exception: - pass + rect = page.rect + long_side_pt = max(rect.width, rect.height) + if long_side_pt * scale > max_width_or_height: + scale = max_width_or_height / long_side_pt + mat = fitz.Matrix(scale, scale) + pix = page.get_pixmap(matrix=mat, alpha=False) + image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples) return image, scale @@ -304,7 +296,7 @@ def pdf_to_images_pil( start_page_id: int = 0, end_page_id: int = None, ) -> list: - """Convert PDF to list of PIL Images using pypdfium2 (single-process). + """Convert PDF to list of PIL Images. Args: pdf_path: PDF file path. @@ -316,34 +308,25 @@ def pdf_to_images_pil( Returns: List of PIL.Image. """ - if not PYPDFIUM2_AVAILABLE: - raise ImportError( - "PDF support requires pypdfium2. Install with: pip install pypdfium2" - ) - import pypdfium2 as pdfium - - pdf = None + doc = None try: - pdf = pdfium.PdfDocument(pdf_path) - page_count = len(pdf) + doc = fitz.open(pdf_path) + page_count = doc.page_count if end_page_id is None or end_page_id < 0: end_page_id = page_count - 1 if end_page_id >= page_count: end_page_id = page_count - 1 images = [] for i in range(start_page_id, end_page_id + 1): - page = pdf[i] - try: - image, _ = _page_to_image( - page, dpi=dpi, max_width_or_height=max_width_or_height - ) - images.append(image) - finally: - page.close() + page = doc.load_page(i) + image, _ = _render_page_to_pil( + page, dpi=dpi, max_width_or_height=max_width_or_height + ) + images.append(image) return images finally: - if pdf is not None: - pdf.close() + if doc is not None: + doc.close() def pdf_to_images_pil_iter( @@ -368,35 +351,27 @@ def pdf_to_images_pil_iter( Yields: PIL.Image per page. """ - if not PYPDFIUM2_AVAILABLE: - raise ImportError( - "PDF support requires pypdfium2. Install with: pip install pypdfium2" - ) - import pypdfium2 as pdfium - - pdf = None + doc = None try: - pdf = pdfium.PdfDocument(pdf_path) - page_count = len(pdf) + doc = fitz.open(pdf_path) + page_count = doc.page_count if end_page_id is None or end_page_id < 0: end_page_id = page_count - 1 if end_page_id >= page_count: end_page_id = page_count - 1 for i in range(start_page_id, end_page_id + 1): try: - page = pdf[i] + page = doc.load_page(i) except Exception as e: logger.warning("Skipping page %d of '%s': %s", i, pdf_path, e) continue try: - image, _ = _page_to_image( + image, _ = _render_page_to_pil( page, dpi=dpi, max_width_or_height=max_width_or_height ) yield image except Exception as e: logger.warning("Skipping page %d of '%s' (render failed): %s", i, pdf_path, e) - finally: - page.close() finally: - if pdf is not None: - pdf.close() + if doc is not None: + doc.close() diff --git a/glmocr/utils/markdown_utils.py b/glmocr/utils/markdown_utils.py index d7f3a4d..061ba90 100644 --- a/glmocr/utils/markdown_utils.py +++ b/glmocr/utils/markdown_utils.py @@ -9,7 +9,6 @@ from glmocr.utils.image_utils import ( crop_image_region, pdf_to_images_pil, - PYPDFIUM2_AVAILABLE, ) from glmocr.utils.logging import get_logger @@ -81,11 +80,6 @@ def crop_and_replace_images( suffix = path.suffix.lower() if suffix == ".pdf": - # PDF: convert to images (pypdfium2 only) - if not PYPDFIUM2_AVAILABLE: - raise RuntimeError( - "PDF support requires pypdfium2. Install: pip install pypdfium2" - ) try: pdf_images = pdf_to_images_pil( img_path, dpi=200, max_width_or_height=3500 diff --git a/pyproject.toml b/pyproject.toml index efa5bc1..ea49e4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "opencv-python>=4.8.0", # PDF support - "pypdfium2>=5.3.0", + "pymupdf>=1.24.0", # Flask server "flask>=2.3.0", @@ -60,10 +60,6 @@ layout = [ "opencv-python>=4.8.0", ] -pdf = [ - "pypdfium2>=5.3.0", -] - server = [ "flask>=2.3.0", ] From f6b67c496bfc443b8850445215dbf2a007353bec Mon Sep 17 00:00:00 2001 From: xueyadong Date: Thu, 5 Mar 2026 18:45:39 +0800 Subject: [PATCH 11/38] simplify image region save flow to reduce IO --- glmocr/api.py | 14 +- glmocr/parser_result/base.py | 101 +++--------- glmocr/parser_result/pipeline_result.py | 4 + glmocr/pipeline/_state.py | 33 ++++ glmocr/pipeline/_workers.py | 5 + glmocr/pipeline/pipeline.py | 6 +- glmocr/postprocess/result_formatter.py | 40 ++++- glmocr/utils/markdown_utils.py | 196 ++++++++++-------------- 8 files changed, 192 insertions(+), 207 deletions(-) diff --git a/glmocr/api.py b/glmocr/api.py index 7aa92af..4d175a1 100644 --- a/glmocr/api.py +++ b/glmocr/api.py @@ -31,6 +31,7 @@ from glmocr.config import load_config from glmocr.parser_result import PipelineResult from glmocr.utils.logging import get_logger, ensure_logging_configured +from glmocr.utils.markdown_utils import resolve_image_regions logger = get_logger(__name__) @@ -364,8 +365,8 @@ def _parse_maas( # ------------------------------------------------------------------ # The MaaS API returns bbox_2d in **absolute pixel coordinates** of # its own internal rendering (e.g. 2040×2640 for a letter-sized PDF - # page). The rest of the SDK (self-hosted pipeline, crop_image_region, - # crop_and_replace_images) uses **normalised 0-1000 coordinates**. + # page). The rest of the SDK (self-hosted pipeline, + # resolve_image_regions) uses **normalised 0-1000 coordinates**. # # To keep everything consistent we convert here, right after receiving # the MaaS response, so that json_result and markdown_result always @@ -398,8 +399,8 @@ def _normalise_markdown_bboxes( pages_info: List[Dict[str, int]], ) -> str: """Replace absolute-pixel bbox values in Markdown image refs with - normalised 0-1000 values so that ``crop_and_replace_images`` crops - from the correct region. + normalised 0-1000 values so that the result formatter resolves + the correct region. """ if not pages_info or not markdown: return markdown @@ -466,11 +467,16 @@ def _maas_response_to_pipeline_result( pages_info, ) + json_result, markdown_result, image_files = resolve_image_regions( + json_result, markdown_result, source, + ) + # Create PipelineResult result = PipelineResult( json_result=json_result, markdown_result=markdown_result, original_images=[source], + image_files=image_files or None, ) # Store additional MaaS response data diff --git a/glmocr/parser_result/base.py b/glmocr/parser_result/base.py index 4eff3f5..6d5a406 100644 --- a/glmocr/parser_result/base.py +++ b/glmocr/parser_result/base.py @@ -5,15 +5,13 @@ from __future__ import annotations -import copy import json import traceback from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union from glmocr.utils.logging import get_logger -from glmocr.utils.markdown_utils import crop_and_replace_images, extract_image_refs logger = get_logger(__name__) @@ -29,6 +27,7 @@ def __init__( json_result: Union[str, dict, list], markdown_result: Optional[str] = None, original_images: Optional[List[str]] = None, + image_files: Optional[Dict[str, Any]] = None, ): """Initialize. @@ -36,6 +35,8 @@ def __init__( json_result: JSON result (string, dict, or list). markdown_result: Markdown result (optional). original_images: Original image paths. + image_files: Mapping of ``filename`` → PIL Image for image-type + regions, to be saved under ``imgs/`` during :meth:`save`. """ if isinstance(json_result, str): try: @@ -49,6 +50,7 @@ def __init__( self.original_images = [ str(Path(p).absolute()) for p in (original_images or []) ] + self.image_files = image_files @abstractmethod def save( @@ -59,60 +61,6 @@ def save( """Save result to disk. Subclasses implement layout vis etc.""" pass - @staticmethod - def _build_image_path_map( - markdown_text: str, image_prefix: str = "cropped" - ) -> Dict[Tuple[int, ...], str]: - """Build a mapping from (page_idx, *bbox) to the relative image path. - - The mapping is derived purely from the markdown image references so - it stays in sync with what ``crop_and_replace_images`` will produce, - without performing any file I/O here. - """ - mapping: Dict[Tuple[int, ...], str] = {} - refs = extract_image_refs(markdown_text) - for idx, (page_idx, bbox, _) in enumerate(refs): - key = (page_idx, *bbox) - rel = f"imgs/{image_prefix}_page{page_idx}_idx{idx}.jpg" - mapping[key] = rel - return mapping - - @staticmethod - def _annotate_json_image_paths( - json_data: Any, - image_path_map: Dict[Tuple[int, ...], str], - ) -> Any: - """Return a deep-copied json_data with ``image_path`` added to image regions. - - ``json_data`` is expected to be a list-of-pages (list of lists of region - dicts). For every region whose ``label`` is ``"image"``, the relative - path is looked up by ``(page_idx, *bbox_2d)`` and written into the copy. - The original ``json_data`` is never mutated. - """ - if not image_path_map or not isinstance(json_data, list): - return json_data - - result = [] - for page_idx, page in enumerate(json_data): - if not isinstance(page, list): - result.append(page) - continue - page_copy = [] - for region in page: - if not isinstance(region, dict) or region.get("label") != "image": - page_copy.append(region) - continue - bbox = region.get("bbox_2d") - region_copy = copy.copy(region) - if bbox: - key = (page_idx, *bbox) - rel = image_path_map.get(key) - if rel: - region_copy["image_path"] = rel - page_copy.append(region_copy) - result.append(page_copy) - return result - def _save_json_and_markdown(self, output_dir: Union[str, Path]) -> None: """Save JSON and Markdown to output_dir (by first image name or 'result').""" output_dir = Path(output_dir).absolute() @@ -125,15 +73,7 @@ def _save_json_and_markdown(self, output_dir: Union[str, Path]) -> None: output_path.mkdir(parents=True, exist_ok=True) base_name = output_path.name - # Build image_path_map from markdown refs so JSON can reference the - # same filenames that crop_and_replace_images will produce below. - image_path_map: Dict[Tuple[int, ...], str] = {} - if self.markdown_result and self.original_images: - image_path_map = self._build_image_path_map( - self.markdown_result, image_prefix="cropped" - ) - - # JSON — annotate image regions with their relative image_path + # JSON json_file = output_path / f"{base_name}.json" try: json_data = self.json_result @@ -142,8 +82,6 @@ def _save_json_and_markdown(self, output_dir: Union[str, Path]) -> None: json_data = json.loads(json_data) except json.JSONDecodeError: pass - if isinstance(json_data, list): - json_data = self._annotate_json_image_paths(json_data, image_path_map) with open(json_file, "w", encoding="utf-8") as f: if isinstance(json_data, (dict, list)): json.dump(json_data, f, ensure_ascii=False, indent=2) @@ -153,23 +91,22 @@ def _save_json_and_markdown(self, output_dir: Union[str, Path]) -> None: logger.warning("Failed to save JSON: %s", e) traceback.print_exc() - # Markdown — crop image regions and replace bbox tags with file paths + # Markdown if self.markdown_result and self.markdown_result.strip(): - md_text = self.markdown_result - if self.original_images: - try: - imgs_dir = output_path / "imgs" - md_text, _ = crop_and_replace_images( - md_text, - self.original_images, - imgs_dir, - image_prefix="cropped", - ) - except Exception as e: - logger.warning("Failed to process image regions: %s", e) md_file = output_path / f"{base_name}.md" with open(md_file, "w", encoding="utf-8") as f: - f.write(md_text) + f.write(self.markdown_result) + + # Image files produced by the result formatter + if self.image_files: + imgs_dir = output_path / "imgs" + imgs_dir.mkdir(parents=True, exist_ok=True) + for filename, img in self.image_files.items(): + try: + img.save(imgs_dir / filename, quality=95) + except Exception as e: + logger.warning("Failed to save image %s: %s", filename, e) + self.image_files = None def to_dict(self) -> dict: """Return a JSON-serialisable dict of the result. diff --git a/glmocr/parser_result/pipeline_result.py b/glmocr/parser_result/pipeline_result.py index 27b7b26..1f02488 100644 --- a/glmocr/parser_result/pipeline_result.py +++ b/glmocr/parser_result/pipeline_result.py @@ -26,6 +26,7 @@ def __init__( original_images: List[str], layout_vis_dir: Optional[str] = None, layout_image_indices: Optional[List[int]] = None, + image_files: Optional[dict] = None, ): """Initialize. @@ -36,11 +37,14 @@ def __init__( layout_vis_dir: Temp dir with layout_page{N}.jpg (optional). layout_image_indices: Indices of layout pages belonging to this unit; None means all files in layout_vis_dir belong to this unit. + image_files: Mapping of ``filename`` → PIL Image for image-type + regions; saved directly to ``imgs/`` during :meth:`save`. """ super().__init__( json_result=json_result, markdown_result=markdown_result, original_images=original_images, + image_files=image_files, ) self.layout_vis_dir = layout_vis_dir self.layout_image_indices = layout_image_indices diff --git a/glmocr/pipeline/_state.py b/glmocr/pipeline/_state.py index 2f2de12..7906f5b 100644 --- a/glmocr/pipeline/_state.py +++ b/glmocr/pipeline/_state.py @@ -47,6 +47,10 @@ def __init__( self._results_by_page: Dict[int, List[Dict]] = {} self._results_lock = threading.Lock() + # ── Pre-cropped images for image-type regions ───────────────── + self._image_region_store: Dict[int, Dict[tuple, Any]] = {} + self._image_store_lock = threading.Lock() + # ── UnitTracker (set before threads start) ─────────────────── self._tracker: Optional[UnitTracker] = None @@ -125,6 +129,35 @@ def get_grouped_results(self, page_indices: List[int]) -> List[List[Dict]]: with self._results_lock: return [list(self._results_by_page.get(pi, [])) for pi in page_indices] + # ------------------------------------------------------------------ + # Pre-cropped image store (for image-type regions) + # ------------------------------------------------------------------ + + def store_cropped_image(self, page_idx: int, bbox: list, image: Any) -> None: + """Store a pre-cropped image for an image-type (skip) region. + + Called by the recognition worker for regions with ``task_type == "skip"``. + """ + key = tuple(bbox) + with self._image_store_lock: + self._image_region_store.setdefault(page_idx, {})[key] = image + + def collect_cropped_images_for_unit( + self, page_indices: List[int] + ) -> Dict[tuple, Any]: + """Collect pre-cropped images for one unit, re-keyed by local page index. + + Returns a dict mapping ``(local_page_idx, *bbox)`` → PIL Image. + Consumed entries are removed from the store to free memory. + """ + result: Dict[tuple, Any] = {} + with self._image_store_lock: + for local_idx, global_idx in enumerate(page_indices): + page_store = self._image_region_store.pop(global_idx, {}) + for bbox_key, img in page_store.items(): + result[(local_idx, *bbox_key)] = img + return result + # ------------------------------------------------------------------ # UnitTracker lifecycle # ------------------------------------------------------------------ diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py index cf36356..79694e8 100644 --- a/glmocr/pipeline/_workers.py +++ b/glmocr/pipeline/_workers.py @@ -311,6 +311,11 @@ def recognition_worker( if identifier == IDENTIFIER_REGION: if msg["region"]["task_type"] == "skip": msg["region"]["content"] = None + bbox = msg["region"].get("bbox_2d") + if bbox and "cropped_image" in msg: + state.store_cropped_image( + msg["page_idx"], bbox, msg["cropped_image"] + ) state.add_recognition_result(msg["page_idx"], msg["region"]) else: req = page_loader.build_request_from_image( diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 81e25f4..5fb7ec2 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -268,12 +268,16 @@ def _emit_results( time.sleep(0.05) continue - json_u, md_u = self.result_formatter.process(grouped) + cropped_images = state.collect_cropped_images_for_unit(page_indices) + json_u, md_u, image_files = self.result_formatter.process( + grouped, cropped_images=cropped_images or None, + ) yield PipelineResult( json_result=json_u, markdown_result=md_u, original_images=[original_inputs[u]], layout_vis_dir=layout_vis_output_dir, layout_image_indices=page_indices, + image_files=image_files or None, ) emitted.add(u) diff --git a/glmocr/postprocess/result_formatter.py b/glmocr/postprocess/result_formatter.py index 393958c..da228e8 100644 --- a/glmocr/postprocess/result_formatter.py +++ b/glmocr/postprocess/result_formatter.py @@ -43,7 +43,7 @@ class ResultFormatter(BasePostProcessor): formatter = ResultFormatter(ResultFormatterConfig()) # Layout mode: process grouped results - json_str, md_str = formatter.process(grouped_results) + json_str, md_str, image_files = formatter.process(grouped_results) # OCR-only mode: format a single output json_str, md_str = formatter.format_ocr_result(content) @@ -132,14 +132,25 @@ def format_multi_page_results(self, contents: List[str]) -> Tuple[str, str]: # Layout mode # ========================================================================= - def process(self, grouped_results: List[List[Dict]]) -> Tuple[str, str]: + def process( + self, + grouped_results: List[List[Dict]], + cropped_images: Dict[tuple, Any] | None = None, + image_prefix: str = "cropped", + ) -> Tuple[str, str, Dict[str, Any]]: """Process grouped results in layout mode. Args: grouped_results: Region recognition results grouped by page. + cropped_images: Pre-cropped PIL images keyed by + ``(local_page_idx, *bbox)``; when provided, image regions + are resolved to final file paths directly in the markdown + and JSON output. + image_prefix: Filename prefix for saved images. Returns: - (json_str, markdown_str) + (json_str, markdown_str, image_files) where *image_files* maps + ``filename`` → PIL Image for the caller to persist. """ json_final_results = [] @@ -190,7 +201,9 @@ def process(self, grouped_results: List[List[Dict]]) -> Tuple[str, str]: json_final_results.append(json_page_results) - # Generate markdown results + # Generate markdown results and resolve image regions + image_files: Dict[str, Any] = {} + image_counter = 0 with profiler.measure("generate_markdown"): markdown_final_results = [] for page_idx, json_page_results in enumerate(json_final_results): @@ -198,9 +211,22 @@ def process(self, grouped_results: List[List[Dict]]) -> Tuple[str, str]: for result in json_page_results: content = result["content"] if result["label"] == "image": - markdown_page_results.append( - f"![](page={page_idx},bbox={result.get('bbox_2d', [])})" + bbox = result.get("bbox_2d", []) + key = (page_idx, *bbox) if bbox else None + img = ( + cropped_images.get(key) + if cropped_images and key + else None ) + if img is not None: + filename = f"{image_prefix}_page{page_idx}_idx{image_counter}.jpg" + rel_path = f"imgs/{filename}" + image_files[filename] = img + result["image_path"] = rel_path + markdown_page_results.append( + f"![Image {page_idx}-{image_counter}]({rel_path})" + ) + image_counter += 1 elif content: markdown_page_results.append(content) markdown_final_results.append("\n\n".join(markdown_page_results)) @@ -209,7 +235,7 @@ def process(self, grouped_results: List[List[Dict]]) -> Tuple[str, str]: json_str = json.dumps(json_final_results, ensure_ascii=False) markdown_str = "\n\n".join(markdown_final_results) - return json_str, markdown_str + return json_str, markdown_str, image_files # ========================================================================= # Content handling diff --git a/glmocr/utils/markdown_utils.py b/glmocr/utils/markdown_utils.py index 061ba90..9c36a1d 100644 --- a/glmocr/utils/markdown_utils.py +++ b/glmocr/utils/markdown_utils.py @@ -1,136 +1,106 @@ -"""Markdown processing utilities.""" +"""Markdown processing utilities for image region resolution.""" + +from __future__ import annotations -import re -import ast from pathlib import Path -from typing import List, Tuple +from typing import Any, Dict, List, Tuple from PIL import Image -from glmocr.utils.image_utils import ( - crop_image_region, - pdf_to_images_pil, -) +from glmocr.utils.image_utils import crop_image_region, pdf_to_images_pil from glmocr.utils.logging import get_logger logger = get_logger(__name__) -def extract_image_refs(markdown_text: str) -> List[Tuple[int, List[int], str]]: - """Extract image references from Markdown. - - Args: - markdown_text: Markdown text. - - Returns: - List of (page_idx, bbox, original_tag). - """ - # Pattern: ![](page=0,bbox=[57, 199, 884, 444]) - pattern = r"!\[\]\(page=(\d+),bbox=(\[[\d,\s]+\])\)" - matches = re.finditer(pattern, markdown_text) - - image_refs = [] - for match in matches: - page_idx = int(match.group(1)) - bbox_str = match.group(2) - # Parse bbox string safely - try: - bbox = ast.literal_eval(bbox_str) - if not isinstance(bbox, list) or len(bbox) != 4: - raise ValueError(f"Invalid bbox format: {bbox_str}") - except (ValueError, SyntaxError) as e: - logger.warning("Cannot parse bbox %s: %s", bbox_str, e) - continue - original_tag = match.group(0) - image_refs.append((page_idx, bbox, original_tag)) - - return image_refs - +def resolve_image_regions( + json_result: list, + markdown_result: str, + source: str, + image_prefix: str = "cropped", +) -> Tuple[list, str, Dict[str, Any]]: + """Crop image regions from the original file, resolve markdown and JSON paths. -def crop_and_replace_images( - markdown_text: str, - original_images: List[str], - output_dir: Path, - image_prefix: str = "image", -) -> Tuple[str, List[str]]: - """Crop referenced image regions and replace Markdown tags. + For results where image regions only have bbox references (e.g. MaaS), + this function loads the original file, crops each image region, and + produces the ``image_files`` dict that ``PipelineResult.save()`` persists + to disk. Args: - markdown_text: Source Markdown. - original_images: Original image paths. - output_dir: Output directory. + json_result: List-of-pages recognition results (list of lists of + region dicts). + markdown_result: Markdown text potentially containing + ``![](page=N,bbox=[...])`` placeholders. + source: Path to the original image or PDF file. image_prefix: Filename prefix for cropped images. Returns: - (updated_markdown, saved_image_paths) + (updated_json_result, updated_markdown_result, image_files) """ - # Ensure output directory exists - output_dir.mkdir(parents=True, exist_ok=True) - - # Extract image references - image_refs = extract_image_refs(markdown_text) - - if not image_refs: - # No image references - return markdown_text, [] - - # Load originals (supports PDFs) - loaded_images = [] - for img_path in original_images: - path = Path(img_path) - suffix = path.suffix.lower() - - if suffix == ".pdf": - try: - pdf_images = pdf_to_images_pil( - img_path, dpi=200, max_width_or_height=3500 - ) - loaded_images.extend(pdf_images) - except Exception as e: - raise RuntimeError(f"Failed to convert PDF to images: {e}") from e - else: - # Normal image file - img = Image.open(img_path) + has_images = any( + r.get("label") == "image" + for page in json_result if isinstance(page, list) + for r in page if isinstance(r, dict) + ) + if not has_images: + return json_result, markdown_result, {} + + path = Path(source) + loaded_images: list = [] + try: + if path.suffix.lower() == ".pdf" and path.is_file(): + loaded_images = pdf_to_images_pil( + str(path), dpi=200, max_width_or_height=3500, + ) + elif path.is_file(): + img = Image.open(str(path)) if img.mode != "RGB": img = img.convert("RGB") loaded_images.append(img) + except Exception as e: + logger.warning("Cannot load source %s for image cropping: %s", source, e) + return json_result, markdown_result, {} - # Process each reference - result_markdown = markdown_text - saved_image_paths = [] - - for idx, (page_idx, bbox, original_tag) in enumerate(image_refs): - # Validate page index - if page_idx < 0 or page_idx >= len(loaded_images): - logger.warning( - "page_idx %d out of range (total %d images), skipping", - page_idx, - len(loaded_images), - ) - continue - - # Crop from original - original_image = loaded_images[page_idx] - try: - cropped_image = crop_image_region(original_image, bbox) + if not loaded_images: + return json_result, markdown_result, {} - # Output filename format: image_page0_idx0.jpg - image_filename = f"{image_prefix}_page{page_idx}_idx{idx}.jpg" - image_path = output_dir / image_filename + image_files: Dict[str, Any] = {} + image_counter = 0 + updated_json: List[list] = [] - # Save cropped image - cropped_image.save(image_path, quality=95) - saved_image_paths.append(str(image_path)) - - # Replace Markdown image tag with a relative path (imgs/filename) - relative_path = f"imgs/{image_filename}" - new_tag = f"![Image {page_idx}-{idx}]({relative_path})" - result_markdown = result_markdown.replace(original_tag, new_tag, 1) - - except Exception as e: - logger.warning( - "Failed to crop image (page=%d, bbox=%s): %s", page_idx, bbox, e - ) - # Keep original tag on failure + for page_idx, page in enumerate(json_result): + if not isinstance(page, list): + updated_json.append(page) continue - - return result_markdown, saved_image_paths + page_copy = [] + for region in page: + if ( + not isinstance(region, dict) + or region.get("label") != "image" + or page_idx >= len(loaded_images) + ): + page_copy.append(region) + continue + + bbox = region.get("bbox_2d") + region_copy = dict(region) + if bbox: + try: + cropped = crop_image_region(loaded_images[page_idx], bbox) + filename = f"{image_prefix}_page{page_idx}_idx{image_counter}.jpg" + rel_path = f"imgs/{filename}" + image_files[filename] = cropped + region_copy["image_path"] = rel_path + + old_tag = f"![](page={page_idx},bbox={bbox})" + new_tag = f"![Image {page_idx}-{image_counter}]({rel_path})" + markdown_result = markdown_result.replace(old_tag, new_tag, 1) + image_counter += 1 + except Exception as e: + logger.warning( + "Failed to crop image (page=%d, bbox=%s): %s", + page_idx, bbox, e, + ) + page_copy.append(region_copy) + updated_json.append(page_copy) + + return updated_json, markdown_result, image_files From eebf13a80b05d6cf6f3d61141be6a1cd25f940f5 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Fri, 6 Mar 2026 11:48:52 +0800 Subject: [PATCH 12/38] save raw output json file from recognition model --- glmocr/parser_result/base.py | 13 +++++++++++++ glmocr/parser_result/pipeline_result.py | 3 +++ glmocr/pipeline/pipeline.py | 26 +++++++++++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/glmocr/parser_result/base.py b/glmocr/parser_result/base.py index 6d5a406..8be772e 100644 --- a/glmocr/parser_result/base.py +++ b/glmocr/parser_result/base.py @@ -28,6 +28,7 @@ def __init__( markdown_result: Optional[str] = None, original_images: Optional[List[str]] = None, image_files: Optional[Dict[str, Any]] = None, + raw_json_result: Optional[list] = None, ): """Initialize. @@ -37,6 +38,8 @@ def __init__( original_images: Original image paths. image_files: Mapping of ``filename`` → PIL Image for image-type regions, to be saved under ``imgs/`` during :meth:`save`. + raw_json_result: Raw model output before post-processing; + saved as ``{name}_model.json`` alongside the final result. """ if isinstance(json_result, str): try: @@ -51,6 +54,7 @@ def __init__( str(Path(p).absolute()) for p in (original_images or []) ] self.image_files = image_files + self.raw_json_result = raw_json_result @abstractmethod def save( @@ -91,6 +95,15 @@ def _save_json_and_markdown(self, output_dir: Union[str, Path]) -> None: logger.warning("Failed to save JSON: %s", e) traceback.print_exc() + # Raw model output (before post-processing) + if self.raw_json_result is not None: + raw_file = output_path / f"{base_name}_model.json" + try: + with open(raw_file, "w", encoding="utf-8") as f: + json.dump(self.raw_json_result, f, ensure_ascii=False, indent=2) + except Exception as e: + logger.warning("Failed to save raw JSON: %s", e) + # Markdown if self.markdown_result and self.markdown_result.strip(): md_file = output_path / f"{base_name}.md" diff --git a/glmocr/parser_result/pipeline_result.py b/glmocr/parser_result/pipeline_result.py index 1f02488..93f00c8 100644 --- a/glmocr/parser_result/pipeline_result.py +++ b/glmocr/parser_result/pipeline_result.py @@ -27,6 +27,7 @@ def __init__( layout_vis_dir: Optional[str] = None, layout_image_indices: Optional[List[int]] = None, image_files: Optional[dict] = None, + raw_json_result: Optional[list] = None, ): """Initialize. @@ -39,12 +40,14 @@ def __init__( None means all files in layout_vis_dir belong to this unit. image_files: Mapping of ``filename`` → PIL Image for image-type regions; saved directly to ``imgs/`` during :meth:`save`. + raw_json_result: Raw model output before post-processing (optional). """ super().__init__( json_result=json_result, markdown_result=markdown_result, original_images=original_images, image_files=image_files, + raw_json_result=raw_json_result, ) self.layout_vis_dir = layout_vis_dir self.layout_image_indices = layout_image_indices diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 5fb7ec2..e46113d 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -208,6 +208,30 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Private helpers # ------------------------------------------------------------------ + @staticmethod + def _build_raw_json(grouped_results: List[List[Dict]]) -> list: + """Build a raw JSON snapshot from grouped recognition results. + + Same structure as the final JSON (list of pages, each a list of region + dicts) but with the original model output before any post-processing. + """ + raw = [] + for page_results in grouped_results: + sorted_results = sorted( + page_results, key=lambda x: x.get("index", 0) + ) + raw.append([ + { + "index": i, + "label": r.get("label", "text"), + "content": r.get("content", ""), + "bbox_2d": r.get("bbox_2d"), + "polygon": r.get("polygon"), + } + for i, r in enumerate(sorted_results) + ]) + return raw + def _process_passthrough( self, request_data: Dict[str, Any], @@ -269,6 +293,7 @@ def _emit_results( continue cropped_images = state.collect_cropped_images_for_unit(page_indices) + raw_json = self._build_raw_json(grouped) json_u, md_u, image_files = self.result_formatter.process( grouped, cropped_images=cropped_images or None, ) @@ -279,5 +304,6 @@ def _emit_results( layout_vis_dir=layout_vis_output_dir, layout_image_indices=page_indices, image_files=image_files or None, + raw_json_result=raw_json, ) emitted.add(u) From 5dd277858142ca8810f0385e879a63108172f164 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Fri, 6 Mar 2026 13:21:02 +0800 Subject: [PATCH 13/38] add an argument to control whether to use polygon property in layout detector --- glmocr/config.py | 1 + glmocr/config.yaml | 10 +++++++++- glmocr/layout/layout_detector.py | 2 ++ glmocr/pipeline/_workers.py | 9 ++++++++- glmocr/pipeline/pipeline.py | 2 +- 5 files changed, 21 insertions(+), 3 deletions(-) diff --git a/glmocr/config.py b/glmocr/config.py index dcd9d18..74c36a1 100644 --- a/glmocr/config.py +++ b/glmocr/config.py @@ -187,6 +187,7 @@ class LayoutConfig(_BaseConfig): layout_unclip_ratio: Optional[Any] = None layout_merge_bboxes_mode: Union[str, Dict[int, str]] = "large" label_task_mapping: Optional[Dict[str, Any]] = None + use_polygon: bool = False class PipelineConfig(_BaseConfig): diff --git a/glmocr/config.yaml b/glmocr/config.yaml index d1921e3..68fdc1f 100644 --- a/glmocr/config.yaml +++ b/glmocr/config.yaml @@ -140,7 +140,7 @@ pipeline: table: "Table Recognition:" formula: "Formula Recognition:" - # PDF processing (pypdfium2 only) + # PDF processing pdf_dpi: 200 pdf_max_pages: null # null = no limit pdf_verbose: false @@ -196,6 +196,14 @@ pipeline: cuda_visible_devices: "0" # img_size: null # resize input (optional) + # Use polygon masks for region cropping and visualization. + # When true, regions are cropped using the polygon outline from layout + # detection (more precise, masks out content outside the polygon), + # recommended for documents with rotating or staggered layouts. + # When false, regions are cropped using the bounding box only (faster, simpler), + # recommended for regular documents without rotating. + use_polygon: false + # Post-processing layout_nms: true layout_unclip_ratio: diff --git a/glmocr/layout/layout_detector.py b/glmocr/layout/layout_detector.py index a616a38..2584fa1 100644 --- a/glmocr/layout/layout_detector.py +++ b/glmocr/layout/layout_detector.py @@ -216,6 +216,7 @@ def process( save_visualization: bool = False, visualization_output_dir: Optional[str] = None, global_start_idx: int = 0, + use_polygon: bool = False, ) -> List[List[Dict]]: """Batch-detect layout regions in-process. @@ -318,6 +319,7 @@ def process( show_label=True, show_score=True, show_index=True, + use_polygon=use_polygon, ) saved_vis_paths.append(str(save_path)) diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py index 79694e8..9d14221 100644 --- a/glmocr/pipeline/_workers.py +++ b/glmocr/pipeline/_workers.py @@ -132,6 +132,7 @@ def layout_worker( layout_detector: "BaseLayoutDetector", save_visualization: bool, vis_output_dir: Optional[str], + use_polygon: bool = False, ) -> None: """Consume pages, run layout detection in batches, push regions. @@ -173,6 +174,7 @@ def layout_worker( _flush_layout_batch( state, layout_detector, batch_images, batch_page_indices, save_visualization, vis_output_dir, global_start_idx, + use_polygon=use_polygon, ) global_start_idx += len(batch_page_indices) for pi in batch_page_indices: @@ -185,6 +187,7 @@ def layout_worker( _flush_layout_batch( state, layout_detector, batch_images, batch_page_indices, save_visualization, vis_output_dir, global_start_idx, + use_polygon=use_polygon, ) global_start_idx += len(batch_page_indices) for pi in batch_page_indices: @@ -207,6 +210,7 @@ def layout_worker( _flush_layout_batch( state, layout_detector, batch_images, batch_page_indices, save_visualization, vis_output_dir, global_start_idx, + use_polygon=use_polygon, ) state.safe_put(state.region_queue, {"identifier": IDENTIFIER_DONE}) break @@ -226,6 +230,7 @@ def _flush_layout_batch( save_visualization: bool, vis_output_dir: Optional[str], global_start_idx: int, + use_polygon: bool = False, ) -> None: """Run layout detection on one batch and enqueue the resulting regions.""" try: @@ -234,6 +239,7 @@ def _flush_layout_batch( save_visualization=save_visualization and vis_output_dir is not None, visualization_output_dir=vis_output_dir, global_start_idx=global_start_idx, + use_polygon=use_polygon, ) except Exception as e: logger.warning( @@ -250,7 +256,8 @@ def _flush_layout_batch( state.layout_results_dict[page_idx] = layout_result for region in layout_result: try: - cropped = crop_image_region(image, region["bbox_2d"], region["polygon"]) + polygon = region.get("polygon") if use_polygon else None + cropped = crop_image_region(image, region["bbox_2d"], polygon) except Exception as e: logger.warning( "Failed to crop region on page %d (bbox=%s), skipping: %s", diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index e46113d..b5a1e6e 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -143,7 +143,7 @@ def process( ) t2 = threading.Thread( target=layout_worker, - args=(state, self.layout_detector, save_layout_visualization, layout_vis_output_dir), + args=(state, self.layout_detector, save_layout_visualization, layout_vis_output_dir, self.config.layout.use_polygon), daemon=True, ) t3 = threading.Thread( From ffa8f534b21ba931464251ecdc199aa994de200e Mon Sep 17 00:00:00 2001 From: xueyadong Date: Tue, 10 Mar 2026 11:56:43 +0800 Subject: [PATCH 14/38] support load image/PDF files from a directory recursively & update default config --- glmocr/cli.py | 56 ++++++++++++++++++++++++++++++-------------- glmocr/config.yaml | 10 ++++---- glmocr/ocr_client.py | 2 +- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/glmocr/cli.py b/glmocr/cli.py index 899a716..d50c31c 100644 --- a/glmocr/cli.py +++ b/glmocr/cli.py @@ -9,7 +9,7 @@ import threading import traceback from pathlib import Path -from typing import List +from typing import List, Optional, Tuple from tqdm import tqdm @@ -19,39 +19,54 @@ logger = get_logger(__name__) -def load_image_paths(input_path: str) -> List[str]: - """Load image paths from a file or directory. +_SUPPORTED_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".pdf"} - PDF files are included as inputs (they will be expanded into page images later). + +def load_image_paths(input_path: str) -> Tuple[List[str], Optional[str]]: + """Load image paths from a file or directory (recursively). + + When *input_path* is a directory the search is recursive — all supported + image/PDF files in nested subdirectories are collected. Args: input_path: Input path (file or directory). Returns: - List[str]: Image/PDF file paths. + A tuple ``(image_paths, input_root)``. + *input_root* is the absolute directory path when the input is a + directory (``None`` when it is a single file). It is used by the + caller to compute relative paths so that the output preserves the + original directory hierarchy. """ path = Path(input_path) - image_paths = [] if path.is_file(): - suffix = path.suffix.lower() - if suffix in [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".pdf"]: - image_paths.append(str(path.absolute())) - else: + if path.suffix.lower() not in _SUPPORTED_SUFFIXES: raise ValueError(f"Not Supported Type: {path.suffix}") - elif path.is_dir(): + return [str(path.absolute())], None + + if path.is_dir(): + seen: set = set() + image_paths: List[str] = [] for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.pdf"]: - image_paths.extend([str(p.absolute()) for p in path.glob(ext)]) - image_paths.extend([str(p.absolute()) for p in path.glob(ext.upper())]) + for p in path.rglob(ext): + abs_p = str(p.absolute()) + if abs_p not in seen: + seen.add(abs_p) + image_paths.append(abs_p) + for p in path.rglob(ext.upper()): + abs_p = str(p.absolute()) + if abs_p not in seen: + seen.add(abs_p) + image_paths.append(abs_p) image_paths.sort() if not image_paths: raise ValueError( f"Cannot find image or PDF files in directory: {input_path}" ) - else: - raise ValueError(f"Path does not exist: {input_path}") + return image_paths, str(path.absolute()) - return image_paths + raise ValueError(f"Path does not exist: {input_path}") def _queue_stats_updater(glm_parser: GlmOcr, pbar: tqdm, stop: threading.Event): @@ -144,7 +159,7 @@ def main(): try: logger.info("Loading images: %s", args.input) - image_paths = load_image_paths(args.input) + image_paths, input_root = load_image_paths(args.input) logger.info("Found %d file(s)", len(image_paths)) save_layout_vis = not args.no_layout_vis @@ -203,8 +218,13 @@ def main(): print(result.markdown_result) if not args.no_save: + save_dir = args.output + if input_root and result.original_images: + rel = Path(result.original_images[0]).parent.relative_to(input_root) + if str(rel) != ".": + save_dir = str(Path(args.output) / rel) result.save( - output_dir=args.output, + output_dir=save_dir, save_layout_visualization=save_layout_vis, ) diff --git a/glmocr/config.yaml b/glmocr/config.yaml index 68fdc1f..33c68f2 100644 --- a/glmocr/config.yaml +++ b/glmocr/config.yaml @@ -101,15 +101,15 @@ pipeline: retry_status_codes: [429, 500, 502, 503, 504] # HTTP connection pool size (default 128). Set >= max_workers to avoid - # "Connection pool is full" when layout mode runs concurrent requests. + # "Connection pool is full" when runs concurrent requests. connection_pool_size: 128 - # Maximum parallel workers for region recognition (layout mode) + # Maximum parallel workers for region recognition # Lower values to reduce 503 errors on busy OCR servers - max_workers: 32 + max_workers: 64 # Queue sizes page_maxsize: 100 - region_maxsize: 800 + region_maxsize: 2000 # Page loader: handles image/PDF loading and API request building page_loader: @@ -134,7 +134,7 @@ pipeline: Preserve the original layout (headings/paragraphs/tables/formulas). Do not fabricate content that does not exist in the image. - # Task-specific prompts (used in layout mode) + # Task-specific prompts task_prompt_mapping: text: "Text Recognition:" table: "Table Recognition:" diff --git a/glmocr/ocr_client.py b/glmocr/ocr_client.py index 4fcc747..57c8b9c 100644 --- a/glmocr/ocr_client.py +++ b/glmocr/ocr_client.py @@ -321,7 +321,7 @@ def process(self, request_data: Dict) -> Tuple[Dict, int]: "error": f"Invalid OpenAI API response format: {str(e)}" }, 500 - return {"choices": [{"message": {"content": output.strip()}}]}, 200 + return {"choices": [{"message": {"content": (output or "").strip()}}]}, 200 status = int(response.status_code) body_preview = (response.text or "")[:500] From 1d7a43a4a1821b6a4c8476ea8f7432c0c5227dae Mon Sep 17 00:00:00 2001 From: xueyadong Date: Tue, 10 Mar 2026 11:56:33 +0000 Subject: [PATCH 15/38] Implement safe extraction of polygon points in layout detector to handle empty mask crops and prevent crashes in cv2.resize --- glmocr/layout/layout_detector.py | 72 ++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/glmocr/layout/layout_detector.py b/glmocr/layout/layout_detector.py index 2584fa1..0762db4 100644 --- a/glmocr/layout/layout_detector.py +++ b/glmocr/layout/layout_detector.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import TYPE_CHECKING, List, Dict, Optional +import cv2 import torch import numpy as np from PIL import Image @@ -77,6 +78,55 @@ def start(self): self._model = self._model.to(self._device) if self.id2label is None: self.id2label = self._model.config.id2label + + # Patch upstream _extract_polygon_points_by_masks to guard against + # empty mask crops that crash cv2.resize with !ssize.empty(). + def _safe_extract(boxes, masks, scale_ratio): + scale_w, scale_h = scale_ratio[0] / 4, scale_ratio[1] / 4 + mask_h, mask_w = masks.shape[1:] + polygon_points = [] + + for i in range(len(boxes)): + x_min, y_min, x_max, y_max = boxes[i].astype(np.int32) + box_w, box_h = x_max - x_min, y_max - y_min + rect = np.array( + [[x_min, y_min], [x_max, y_min], + [x_max, y_max], [x_min, y_max]], + dtype=np.float32, + ) + + if box_w <= 0 or box_h <= 0: + polygon_points.append(rect) + continue + + x_start = int(round((x_min * scale_w).item())) + x_end = int(round((x_max * scale_w).item())) + x_start, x_end = np.clip([x_start, x_end], 0, mask_w) + y_start = int(round((y_min * scale_h).item())) + y_end = int(round((y_max * scale_h).item())) + y_start, y_end = np.clip([y_start, y_end], 0, mask_h) + + cropped_mask = masks[i, y_start:y_end, x_start:x_end] + if cropped_mask.size == 0: + polygon_points.append(rect) + continue + + resized = cv2.resize( + cropped_mask.astype(np.uint8), + (box_w, box_h), + interpolation=cv2.INTER_NEAREST, + ) + polygon = self._image_processor._mask2polygon(resized) + if polygon is not None and len(polygon) < 4: + polygon_points.append(rect) + continue + if polygon is not None and len(polygon) > 0: + polygon = polygon + np.array([x_min, y_min]) + polygon_points.append(polygon) + + return polygon_points + + self._image_processor._extract_polygon_points_by_masks = _safe_extract logger.debug(f"PP-DocLayoutV3 loaded on device: {self._device}") def stop(self): @@ -252,28 +302,6 @@ def process( target_sizes = torch.tensor( [img.size[::-1] for img in chunk_pil], device=self._device ) - try: - if hasattr(outputs, "pred_boxes") and outputs.pred_boxes is not None: - pred_boxes = outputs.pred_boxes - if hasattr(outputs, "out_masks") and outputs.out_masks is not None: - mask_h, mask_w = outputs.out_masks.shape[-2:] - else: - mask_h, mask_w = 200, 200 - min_norm_w = 1.0 / mask_w - min_norm_h = 1.0 / mask_h - box_wh = pred_boxes[..., 2:4] - valid_mask = (box_wh[..., 0] > min_norm_w) & ( - box_wh[..., 1] > min_norm_h - ) - if hasattr(outputs, "logits") and outputs.logits is not None: - invalid_mask = ~valid_mask - if invalid_mask.any(): - outputs.logits.masked_fill_( - invalid_mask.unsqueeze(-1), -100.0 - ) - except Exception as e: - logger.warning("Pre-filter failed (%s), continuing...", e) - if self.threshold_by_class: # Use the lowest threshold (per-class or global fallback) # so post-processing doesn't discard valid detections early. From 9ce61f78d29f3ba2f178e6603e97bea4361c35ee Mon Sep 17 00:00:00 2001 From: xueyadong Date: Tue, 10 Mar 2026 12:57:32 +0000 Subject: [PATCH 16/38] Removed temporary directory usage for layout visualizations and updated related methods to return visualization images directly --- glmocr/api.py | 12 ----- glmocr/layout/layout_detector.py | 33 ++++++------- glmocr/parser_result/pipeline_result.py | 61 +++++++------------------ glmocr/pipeline/_state.py | 3 ++ glmocr/pipeline/_workers.py | 15 +++--- glmocr/pipeline/pipeline.py | 24 +++++----- glmocr/server.py | 1 - 7 files changed, 52 insertions(+), 97 deletions(-) diff --git a/glmocr/api.py b/glmocr/api.py index 4d175a1..6455818 100644 --- a/glmocr/api.py +++ b/glmocr/api.py @@ -506,16 +506,10 @@ def _parse_selfhosted( ) -> List[PipelineResult]: """Parse using self-hosted vLLM/SGLang pipeline.""" request_data = self._build_selfhosted_request(images) - - layout_vis_dir = None - if save_layout_visualization: - layout_vis_dir = tempfile.mkdtemp(prefix="layout_vis_") - results = list( self._pipeline.process( request_data, save_layout_visualization=save_layout_visualization, - layout_vis_output_dir=layout_vis_dir, ) ) return results @@ -527,15 +521,9 @@ def _stream_parse_selfhosted( ) -> Generator[PipelineResult, None, None]: """Streaming variant of self-hosted parse().""" request_data = self._build_selfhosted_request(images) - - layout_vis_dir = None - if save_layout_visualization: - layout_vis_dir = tempfile.mkdtemp(prefix="layout_vis_") - for result in self._pipeline.process( request_data, save_layout_visualization=save_layout_visualization, - layout_vis_output_dir=layout_vis_dir, ): yield result diff --git a/glmocr/layout/layout_detector.py b/glmocr/layout/layout_detector.py index 0762db4..d0607b3 100644 --- a/glmocr/layout/layout_detector.py +++ b/glmocr/layout/layout_detector.py @@ -18,7 +18,7 @@ from glmocr.layout.base import BaseLayoutDetector from glmocr.utils.layout_postprocess_utils import apply_layout_postprocess from glmocr.utils.logging import get_logger -from glmocr.utils.visualization_utils import save_layout_visualization +from glmocr.utils.visualization_utils import draw_layout_boxes if TYPE_CHECKING: from glmocr.config import LayoutConfig @@ -264,20 +264,22 @@ def process( self, images: List[Image.Image], save_visualization: bool = False, - visualization_output_dir: Optional[str] = None, global_start_idx: int = 0, use_polygon: bool = False, - ) -> List[List[Dict]]: + ) -> tuple: """Batch-detect layout regions in-process. Args: images: List of PIL Images. - save_visualization: Whether to also save visualization. - visualization_output_dir: Where to save visualization outputs. - global_start_idx: Start index for visualization filenames (layout_page{N}). + save_visualization: Whether to generate visualization images. + global_start_idx: Start index for visualization page numbering. + use_polygon: Use polygon masks for visualization and cropping. Returns: - List[List[Dict]]: Detection results per image. + Tuple of (results, vis_images) where *results* is + ``List[List[Dict]]`` and *vis_images* is + ``Dict[int, PIL.Image.Image]`` mapping global page index to + the rendered layout visualization (empty dict when disabled). """ if self._model is None: raise RuntimeError("Layout detector not started. Call start() first.") @@ -332,24 +334,15 @@ def process( del inputs, outputs, raw_results torch.cuda.empty_cache() - saved_vis_paths = [] - if save_visualization and visualization_output_dir: - vis_output_path = Path(visualization_output_dir) - vis_output_path.mkdir(parents=True, exist_ok=True) + vis_images: Dict[int, Image.Image] = {} + if save_visualization: for img_idx, img_results in enumerate(all_paddle_format_results): vis_img = np.array(pil_images[img_idx]) - save_filename = f"layout_page{global_start_idx + img_idx}.jpg" - save_path = vis_output_path / save_filename - save_layout_visualization( + vis_images[global_start_idx + img_idx] = draw_layout_boxes( image=vis_img, boxes=img_results, - save_path=str(save_path), - show_label=True, - show_score=True, - show_index=True, use_polygon=use_polygon, ) - saved_vis_paths.append(str(save_path)) all_results = [] for img_idx, paddle_results in enumerate(all_paddle_format_results): @@ -396,4 +389,4 @@ def process( valid_index += 1 all_results.append(results) - return all_results + return all_results, vis_images diff --git a/glmocr/parser_result/pipeline_result.py b/glmocr/parser_result/pipeline_result.py index 93f00c8..c088937 100644 --- a/glmocr/parser_result/pipeline_result.py +++ b/glmocr/parser_result/pipeline_result.py @@ -2,9 +2,8 @@ from __future__ import annotations -import shutil from pathlib import Path -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union from glmocr.utils.logging import get_logger @@ -24,10 +23,9 @@ def __init__( json_result: Union[str, dict, list], markdown_result: Optional[str], original_images: List[str], - layout_vis_dir: Optional[str] = None, - layout_image_indices: Optional[List[int]] = None, image_files: Optional[dict] = None, raw_json_result: Optional[list] = None, + layout_vis_images: Optional[Dict[int, Any]] = None, ): """Initialize. @@ -35,12 +33,11 @@ def __init__( json_result: JSON result (string, dict, or list). markdown_result: Markdown result. original_images: Original image paths for this unit. - layout_vis_dir: Temp dir with layout_page{N}.jpg (optional). - layout_image_indices: Indices of layout pages belonging to this unit; - None means all files in layout_vis_dir belong to this unit. image_files: Mapping of ``filename`` → PIL Image for image-type regions; saved directly to ``imgs/`` during :meth:`save`. raw_json_result: Raw model output before post-processing (optional). + layout_vis_images: Mapping of ``page_idx`` → PIL Image for layout + visualization; saved to ``layout_vis/`` during :meth:`save`. """ super().__init__( json_result=json_result, @@ -49,9 +46,7 @@ def __init__( image_files=image_files, raw_json_result=raw_json_result, ) - self.layout_vis_dir = layout_vis_dir - self.layout_image_indices = layout_image_indices - self._layout_vis_saved = False + self.layout_vis_images = layout_vis_images def save( self, @@ -61,15 +56,7 @@ def save( """Save JSON, Markdown, and optionally layout visualization.""" self._save_json_and_markdown(output_dir) - if ( - not save_layout_visualization - or not self.layout_vis_dir - or self._layout_vis_saved - ): - return - - temp_layout_path = Path(self.layout_vis_dir) - if not temp_layout_path.exists(): + if not save_layout_visualization or not self.layout_vis_images: return if self.original_images: @@ -80,34 +67,18 @@ def save( target_dir.mkdir(parents=True, exist_ok=True) - if self.layout_image_indices is not None: - layout_files = [] - for idx in self.layout_image_indices: - for ext in (".jpg", ".png"): - p = temp_layout_path / f"layout_page{idx}{ext}" - if p.exists(): - layout_files.append(p) - break - else: - layout_files = sorted(temp_layout_path.glob("layout_page*.jpg")) - layout_files.extend(sorted(temp_layout_path.glob("layout_page*.png"))) - + vis_items = sorted(self.layout_vis_images.items()) stem = Path(self.original_images[0]).stem if self.original_images else "result" - for local_idx, layout_file in enumerate(layout_files): - ext = layout_file.suffix.lstrip(".").lower() or "jpg" - new_name = ( - f"{stem}.{ext}" - if len(layout_files) == 1 - else f"{stem}_page{local_idx}.{ext}" + for local_idx, (_page_idx, vis_img) in enumerate(vis_items): + name = ( + f"{stem}.jpg" + if len(vis_items) == 1 + else f"{stem}_page{local_idx}.jpg" ) - target_file = target_dir / new_name - shutil.move(str(layout_file), str(target_file)) - - if self.layout_image_indices is None: try: - temp_layout_path.rmdir() - except Exception: - pass + vis_img.save(target_dir / name, quality=95) + except Exception as e: + logger.warning("Failed to save layout vis %s: %s", name, e) - self._layout_vis_saved = True + self.layout_vis_images = None logger.debug("Layout visualization saved to %s", target_dir) diff --git a/glmocr/pipeline/_state.py b/glmocr/pipeline/_state.py index 7906f5b..551d7e1 100644 --- a/glmocr/pipeline/_state.py +++ b/glmocr/pipeline/_state.py @@ -51,6 +51,9 @@ def __init__( self._image_region_store: Dict[int, Dict[tuple, Any]] = {} self._image_store_lock = threading.Lock() + # ── Layout visualization images (page_idx → PIL Image) ──────── + self.layout_vis_images: Dict[int, Any] = {} + # ── UnitTracker (set before threads start) ─────────────────── self._tracker: Optional[UnitTracker] = None diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py index 9d14221..93a5803 100644 --- a/glmocr/pipeline/_workers.py +++ b/glmocr/pipeline/_workers.py @@ -131,7 +131,6 @@ def layout_worker( state: PipelineState, layout_detector: "BaseLayoutDetector", save_visualization: bool, - vis_output_dir: Optional[str], use_polygon: bool = False, ) -> None: """Consume pages, run layout detection in batches, push regions. @@ -173,7 +172,7 @@ def layout_worker( if len(batch_images) >= layout_detector.batch_size: _flush_layout_batch( state, layout_detector, batch_images, batch_page_indices, - save_visualization, vis_output_dir, global_start_idx, + save_visualization, global_start_idx, use_polygon=use_polygon, ) global_start_idx += len(batch_page_indices) @@ -186,7 +185,7 @@ def layout_worker( if batch_images: _flush_layout_batch( state, layout_detector, batch_images, batch_page_indices, - save_visualization, vis_output_dir, global_start_idx, + save_visualization, global_start_idx, use_polygon=use_polygon, ) global_start_idx += len(batch_page_indices) @@ -209,7 +208,7 @@ def layout_worker( if batch_images: _flush_layout_batch( state, layout_detector, batch_images, batch_page_indices, - save_visualization, vis_output_dir, global_start_idx, + save_visualization, global_start_idx, use_polygon=use_polygon, ) state.safe_put(state.region_queue, {"identifier": IDENTIFIER_DONE}) @@ -228,19 +227,19 @@ def _flush_layout_batch( batch_images: List[Any], batch_page_indices: List[int], save_visualization: bool, - vis_output_dir: Optional[str], global_start_idx: int, use_polygon: bool = False, ) -> None: """Run layout detection on one batch and enqueue the resulting regions.""" try: - layout_results = layout_detector.process( + layout_results, vis_images = layout_detector.process( batch_images, - save_visualization=save_visualization and vis_output_dir is not None, - visualization_output_dir=vis_output_dir, + save_visualization=save_visualization, global_start_idx=global_start_idx, use_polygon=use_polygon, ) + if vis_images: + state.layout_vis_images.update(vis_images) except Exception as e: logger.warning( "Layout detection failed for pages %s, skipping batch: %s", diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index b5a1e6e..4f14c2d 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -99,7 +99,6 @@ def process( self, request_data: Dict[str, Any], save_layout_visualization: bool = False, - layout_vis_output_dir: Optional[str] = None, page_maxsize: Optional[int] = None, region_maxsize: Optional[int] = None, ) -> Generator[PipelineResult, None, None]: @@ -110,8 +109,7 @@ def process( Args: request_data: OpenAI-style request payload containing messages. - save_layout_visualization: Save layout visualisation images. - layout_vis_output_dir: Directory for visualisation output. + save_layout_visualization: Generate layout visualisation images. page_maxsize: Bound for the page queue. region_maxsize: Bound for the region queue. @@ -121,7 +119,7 @@ def process( image_urls = extract_image_urls(request_data) if not image_urls: - yield self._process_passthrough(request_data, layout_vis_output_dir) + yield self._process_passthrough(request_data) return num_units = len(image_urls) @@ -143,7 +141,7 @@ def process( ) t2 = threading.Thread( target=layout_worker, - args=(state, self.layout_detector, save_layout_visualization, layout_vis_output_dir, self.config.layout.use_polygon), + args=(state, self.layout_detector, save_layout_visualization, self.config.layout.use_polygon), daemon=True, ) t3 = threading.Thread( @@ -157,7 +155,7 @@ def process( t3.start() try: - yield from self._emit_results(state, tracker, original_inputs, layout_vis_output_dir) + yield from self._emit_results(state, tracker, original_inputs) finally: state.request_shutdown() t1.join(timeout=10) @@ -235,7 +233,6 @@ def _build_raw_json(grouped_results: List[List[Dict]]) -> list: def _process_passthrough( self, request_data: Dict[str, Any], - layout_vis_output_dir: Optional[str], ) -> PipelineResult: """No image URLs — forward the request directly to the OCR API.""" request_data = self.page_loader.build_request(request_data) @@ -250,7 +247,6 @@ def _process_passthrough( json_result=json_result, markdown_result=markdown_result, original_images=[], - layout_vis_dir=layout_vis_output_dir, ) def _emit_results( @@ -258,7 +254,6 @@ def _emit_results( state: PipelineState, tracker: UnitTracker, original_inputs: List[str], - layout_vis_output_dir: Optional[str], ) -> Generator[PipelineResult, None, None]: """Wait for units to complete and yield their formatted results. @@ -297,13 +292,20 @@ def _emit_results( json_u, md_u, image_files = self.result_formatter.process( grouped, cropped_images=cropped_images or None, ) + + # Collect layout visualization images for this unit + vis_images = {} + for pi in page_indices: + img = state.layout_vis_images.pop(pi, None) + if img is not None: + vis_images[pi] = img + yield PipelineResult( json_result=json_u, markdown_result=md_u, original_images=[original_inputs[u]], - layout_vis_dir=layout_vis_output_dir, - layout_image_indices=page_indices, image_files=image_files or None, raw_json_result=raw_json, + layout_vis_images=vis_images or None, ) emitted.add(u) diff --git a/glmocr/server.py b/glmocr/server.py index 2a0f79e..8a65509 100644 --- a/glmocr/server.py +++ b/glmocr/server.py @@ -91,7 +91,6 @@ def parse(): pipeline.process( request_data, save_layout_visualization=False, - layout_vis_output_dir=None, ) ) if not results: From 961abcd0e5a5fc3d3aa1dbcab7ba0bb31076eda4 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Wed, 11 Mar 2026 07:20:34 +0000 Subject: [PATCH 17/38] Enhance configuration flexibility by adding CLI `--set` option for overriding config values --- README.md | 12 ++++++++++++ README_zh.md | 12 ++++++++++++ glmocr/api.py | 2 ++ glmocr/cli.py | 32 +++++++++++++++++++++++++++++++- glmocr/config.py | 24 ++++++++++++++++++------ 5 files changed, 75 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index aa925da..bef4c16 100644 --- a/README.md +++ b/README.md @@ -168,6 +168,10 @@ glmocr parse examples/source/code.png --config my_config.yaml # Enable debug logging with profiling glmocr parse examples/source/code.png --log-level DEBUG + +# Override any config value via --set (dotted path, repeatable) +glmocr parse examples/source/code.png --set pipeline.ocr_api.api_port 8080 +glmocr parse examples/source/ --set pipeline.layout.use_polygon true --set logging.level DEBUG ``` #### Python API @@ -213,6 +217,14 @@ Semantics: ### Configuration +Configuration priority (highest to lowest): + +1. CLI `--set` overrides +2. Python API keyword arguments +3. `GLMOCR_*` environment variables / `.env` file +4. YAML config file +5. Built-in defaults + Full configuration in `glmocr/config.yaml`: ```yaml diff --git a/README_zh.md b/README_zh.md index 9b2c5dd..8ba00b3 100644 --- a/README_zh.md +++ b/README_zh.md @@ -169,6 +169,10 @@ glmocr parse examples/source/code.png --config my_config.yaml # 开启 debug 日志(包含 profiling) glmocr parse examples/source/code.png --log-level DEBUG + +# 通过 --set 覆盖任意配置项(使用 dotted path,可多次使用) +glmocr parse examples/source/code.png --set pipeline.ocr_api.api_port 8080 +glmocr parse examples/source/ --set pipeline.layout.use_polygon true --set logging.level DEBUG ``` #### Python API @@ -214,6 +218,14 @@ curl -X POST http://localhost:5002/glmocr/parse \ ### 配置 +配置加载优先级(从高到低): + +1. CLI `--set` 参数 +2. Python API 关键字参数 +3. `GLMOCR_*` 环境变量 / `.env` 文件 +4. YAML 配置文件 +5. 内置默认值 + 完整配置见 `glmocr/config.yaml`: ```yaml diff --git a/glmocr/api.py b/glmocr/api.py index 6455818..14701c9 100644 --- a/glmocr/api.py +++ b/glmocr/api.py @@ -83,6 +83,7 @@ def __init__( ocr_api_host: Optional[str] = None, ocr_api_port: Optional[int] = None, cuda_visible_devices: Optional[str] = None, + **kwargs: Any, ): """Initialize GlmOcr. @@ -116,6 +117,7 @@ def __init__( ocr_api_host=ocr_api_host, ocr_api_port=ocr_api_port, cuda_visible_devices=cuda_visible_devices, + **kwargs, ) # Apply logging config for API/SDK usage. ensure_logging_configured( diff --git a/glmocr/cli.py b/glmocr/cli.py index d50c31c..a155c32 100644 --- a/glmocr/cli.py +++ b/glmocr/cli.py @@ -80,6 +80,18 @@ def _queue_stats_updater(glm_parser: GlmOcr, pbar: tqdm, stop: threading.Event): ) +def _auto_coerce(raw: str): + """Coerce a CLI string to a Python scalar. + """ + if raw.lower() in ("true", "yes"): + return True + if raw.lower() in ("false", "no"): + return False + if raw.lower() in ("null", "none", "~"): + return None + return raw + + def main(): """CLI entrypoint.""" parser = argparse.ArgumentParser( @@ -98,6 +110,10 @@ def main(): # Specify config file glmocr parse image.png --config config.yaml + + # Override config values via --set + glmocr parse image.png --set pipeline.ocr_api.api_port 8080 + glmocr parse image.png --set pipeline.layout.use_polygon true --set pipeline.maas.enabled false """, ) @@ -148,6 +164,15 @@ def main(): choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Log level (default: INFO)", ) + parse_parser.add_argument( + "--set", + nargs=2, + action="append", + metavar=("KEY", "VALUE"), + dest="config_overrides", + help="Override a config value using dotted path, e.g. " + "--set pipeline.ocr_api.api_port 8080", + ) args = parser.parse_args() @@ -164,7 +189,12 @@ def main(): save_layout_vis = not args.no_layout_vis - with GlmOcr(config_path=args.config) as glm_parser: + # Build dotted-path overrides from --set KEY VALUE pairs + dotted_overrides: dict = {} + for key, value in (args.config_overrides or []): + dotted_overrides[key] = _auto_coerce(value) + + with GlmOcr(config_path=args.config, _dotted=dotted_overrides) as glm_parser: total_files = len(image_paths) pbar = tqdm( diff --git a/glmocr/config.py b/glmocr/config.py index 74c36a1..e450645 100644 --- a/glmocr/config.py +++ b/glmocr/config.py @@ -296,7 +296,13 @@ def from_env( config_path: Optional[Union[str, Path]] = None, **overrides: Any, ) -> "GlmOcrConfig": - """Build config with priority: *overrides* > env-vars > YAML > defaults. + """Build config with layered priority (highest → lowest): + + 1. CLI ``--set`` overrides (``_dotted`` dict) + 2. Keyword overrides (``api_key``, ``mode``, …) + 3. ``GLMOCR_*`` environment variables / ``.env`` file + 4. YAML config file + 5. Built-in defaults This is the **agent-friendly** entry-point. An agent (or any programmatic caller) can configure the SDK entirely through keyword @@ -329,7 +335,8 @@ def from_env( # With a custom YAML base cfg = GlmOcrConfig.from_env(config_path="my.yaml", api_key="sk") """ - # 1. YAML baseline + # --- Priority (applied in order, later wins): --- + # 1. YAML baseline (lowest) yaml_path = Path(config_path or cls.default_path()) if yaml_path.exists(): data: Dict[str, Any] = ( @@ -341,12 +348,12 @@ def from_env( raise FileNotFoundError(f"Config file not found: {yaml_path}") data = {} - # 2. Environment variable overrides + # 2. Environment variable overrides (.env + GLMOCR_*) env_data = _collect_env_overrides() if env_data: _deep_merge(data, env_data) - # 3. Keyword overrides (flat convenience names → nested paths) + # 3. Keyword overrides (Python API convenience names) _KW_MAP = { "api_key": "pipeline.maas.api_key", "api_url": "pipeline.maas.api_url", @@ -365,6 +372,10 @@ def from_env( raw = overrides[kw] _set_nested(data, dotted, _coerce_env_value(dotted, str(raw))) + # 4. CLI --set overrides (highest priority) + for dotted, value in overrides.get("_dotted", {}).items(): + _set_nested(data, dotted, value) + return cls.model_validate(data) def to_dict(self) -> Dict[str, Any]: @@ -375,11 +386,12 @@ def load_config( path: Optional[Union[str, Path]] = None, **overrides: Any, ) -> GlmOcrConfig: - """Load config with priority: *overrides* > env-vars > YAML > defaults. + """Load config with priority: CLI --set > keyword > env-vars > YAML > defaults. This is a drop-in replacement for the old ``load_config(path)``. When called without arguments it behaves exactly as before (YAML only). When keyword overrides or ``GLMOCR_*`` env-vars are present they take - precedence. + precedence. CLI ``--set`` overrides (passed via ``_dotted``) have the + highest priority. """ return GlmOcrConfig.from_env(config_path=path, **overrides) From 79404cd3409b91cf4b4657760b89f722e08bf116 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Wed, 11 Mar 2026 07:20:53 +0000 Subject: [PATCH 18/38] Update default output directory from './results' to './output' in save methods across parser result classes and pipeline --- glmocr/parser_result/base.py | 2 +- glmocr/parser_result/pipeline_result.py | 2 +- glmocr/pipeline/pipeline.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/glmocr/parser_result/base.py b/glmocr/parser_result/base.py index 8be772e..d853fd1 100644 --- a/glmocr/parser_result/base.py +++ b/glmocr/parser_result/base.py @@ -59,7 +59,7 @@ def __init__( @abstractmethod def save( self, - output_dir: Union[str, Path] = "./results", + output_dir: Union[str, Path] = "./output", save_layout_visualization: bool = True, ) -> None: """Save result to disk. Subclasses implement layout vis etc.""" diff --git a/glmocr/parser_result/pipeline_result.py b/glmocr/parser_result/pipeline_result.py index c088937..605c1d4 100644 --- a/glmocr/parser_result/pipeline_result.py +++ b/glmocr/parser_result/pipeline_result.py @@ -50,7 +50,7 @@ def __init__( def save( self, - output_dir: Union[str, Path] = "./results", + output_dir: Union[str, Path] = "./output", save_layout_visualization: bool = True, ) -> None: """Save JSON, Markdown, and optionally layout visualization.""" diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 4f14c2d..286db59 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -58,7 +58,7 @@ class Pipeline: cfg = load_config() pipeline = Pipeline(cfg.pipeline) for result in pipeline.process(request_data): - result.save(output_dir="./results") + result.save(output_dir="./output") """ def __init__( From da003dd74db7b7ad0c49972c9f43ddc66312dbdf Mon Sep 17 00:00:00 2001 From: xueyadong Date: Wed, 11 Mar 2026 08:23:15 +0000 Subject: [PATCH 19/38] Update methods to load images and PDFs from various input types, remove tempfile dependence --- glmocr/api.py | 63 ++++++---------------------- glmocr/dataloader/page_loader.py | 72 +++++++++++++++++++++++++------- glmocr/pipeline/_common.py | 32 +++++++++----- glmocr/pipeline/_workers.py | 11 +++-- glmocr/pipeline/pipeline.py | 12 +++--- glmocr/utils/image_utils.py | 24 +++++++---- 6 files changed, 119 insertions(+), 95 deletions(-) diff --git a/glmocr/api.py b/glmocr/api.py index 14701c9..098baa4 100644 --- a/glmocr/api.py +++ b/glmocr/api.py @@ -21,10 +21,7 @@ print(result.to_dict()) """ -import os import re -import shutil -import tempfile from typing import Any, Dict, Generator, List, Literal, Optional, Union, overload from pathlib import Path @@ -129,7 +126,6 @@ def __init__( self._use_maas = self.config_model.pipeline.maas.enabled self._pipeline = None self._maas_client = None - self._session_temp_dir: Optional[str] = None if self._use_maas: # MaaS mode: use MaaSClient for direct API passthrough @@ -150,46 +146,9 @@ def __init__( # Input normalisation helpers # ------------------------------------------------------------------ - def _get_temp_dir(self) -> str: - if self._session_temp_dir is None: - self._session_temp_dir = tempfile.mkdtemp(prefix="glmocr_") - return self._session_temp_dir - @staticmethod - def _detect_suffix(data: bytes) -> str: - """Detect file extension from magic bytes.""" - if data[:5] == b"%PDF-": - return ".pdf" - if data[:8] == b"\x89PNG\r\n\x1a\n": - return ".png" - if data[:2] == b"\xff\xd8": - return ".jpg" - if data[:4] == b"GIF8": - return ".gif" - if len(data) > 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP": - return ".webp" - if data[:2] == b"BM": - return ".bmp" - return ".png" - - def _bytes_to_temp_file(self, data: bytes) -> str: - """Write *data* to a temp file and return the path. - - The file lives in ``_session_temp_dir`` and is cleaned up by - ``close()``. - """ - suffix = self._detect_suffix(data) - fd, path = tempfile.mkstemp(suffix=suffix, dir=self._get_temp_dir()) - try: - os.write(fd, data) - finally: - os.close(fd) - return path - - def _to_url(self, image: Union[str, bytes, Path]) -> str: - """Convert any supported input to a ``file://`` or ``data:`` URL.""" - if isinstance(image, bytes): - return f"file://{self._bytes_to_temp_file(image)}" + def _to_url(image: Union[str, Path]) -> str: + """Convert a path/URL to a ``file://`` URL.""" if isinstance(image, Path): return f"file://{image.absolute()}" if isinstance(image, str): @@ -492,13 +451,18 @@ def _maas_response_to_pipeline_result( def _build_selfhosted_request( self, images: List[Union[str, bytes, Path]], ) -> Dict[str, Any]: - """Build OpenAI-style request from mixed inputs.""" + """Build request from mixed inputs (paths, URLs, or raw bytes).""" messages: List[Dict[str, Any]] = [{"role": "user", "content": []}] for image in images: - url = self._to_url(image) - messages[0]["content"].append( - {"type": "image_url", "image_url": {"url": url}} - ) + if isinstance(image, bytes): + messages[0]["content"].append( + {"type": "image_bytes", "data": image} + ) + else: + url = self._to_url(image) + messages[0]["content"].append( + {"type": "image_url", "image_url": {"url": url}} + ) return {"messages": messages} def _parse_selfhosted( @@ -587,9 +551,6 @@ def close(self): if self._maas_client: self._maas_client.stop() self._maas_client = None - if self._session_temp_dir: - shutil.rmtree(self._session_temp_dir, ignore_errors=True) - self._session_temp_dir = None def __enter__(self): """Context manager entry.""" diff --git a/glmocr/dataloader/page_loader.py b/glmocr/dataloader/page_loader.py index ec2b36d..80682d7 100644 --- a/glmocr/dataloader/page_loader.py +++ b/glmocr/dataloader/page_loader.py @@ -94,18 +94,19 @@ def __init__(self, config: "PageLoaderConfig"): # Page loading # ========================================================================= - def load_pages(self, sources: Union[str, List[str]]) -> List[Image.Image]: + def load_pages(self, sources: Union[str, bytes, List[Union[str, bytes]]]) -> List[Image.Image]: """Load sources into a list of PIL Images. - Supports image files and PDFs (PDFs are expanded into multiple pages). + Supports image files, PDFs, and raw bytes (PDFs are expanded into + multiple pages). Args: - sources: Single path/URL or a list. + sources: Single path/URL/bytes or a list. Returns: List[PIL.Image.Image] """ - if isinstance(sources, str): + if isinstance(sources, (str, bytes)): sources = [sources] all_pages = [] @@ -116,21 +117,21 @@ def load_pages(self, sources: Union[str, List[str]]) -> List[Image.Image]: return all_pages def load_pages_with_unit_indices( - self, sources: Union[str, List[str]] + self, sources: Union[str, bytes, List[Union[str, bytes]]] ) -> Tuple[List[Image.Image], List[int]]: """Load sources into pages and return unit index per page. - Each input URL is one "unit". For a PDF, all its pages share the same + Each input is one "unit". For a PDF, all its pages share the same unit index. Used by streaming mode to yield one result per input unit. Args: - sources: Single path/URL or a list. + sources: Single path/URL/bytes or a list. Returns: (all_pages, unit_indices) where unit_indices[i] is the unit index - of page i (i.e. which input URL it came from). + of page i (i.e. which input it came from). """ - if isinstance(sources, str): + if isinstance(sources, (str, bytes)): sources = [sources] all_pages: List[Image.Image] = [] @@ -141,29 +142,36 @@ def load_pages_with_unit_indices( unit_indices.extend([unit_idx] * len(pages)) return all_pages, unit_indices - def iter_pages_with_unit_indices(self, sources: Union[str, List[str]]): + def iter_pages_with_unit_indices(self, sources: Union[str, bytes, List[Union[str, bytes]]]): """Stream pages one at a time with unit index per page. Yields (page, unit_idx) so the pipeline can enqueue each page as soon as it is rendered (e.g. PDF: render one page → yield → next page). Args: - sources: Single path/URL or a list. + sources: Single path/URL/bytes or a list. Yields: (PIL.Image, unit_idx) for each page. """ - if isinstance(sources, str): + if isinstance(sources, (str, bytes)): sources = [sources] for unit_idx, source in enumerate(sources): try: for page in self._iter_source(source): yield page, unit_idx except Exception as e: - logger.warning("Skipping source '%s' (unit %d): %s", source, unit_idx, e) + logger.warning("Skipping source (unit %d): %s", unit_idx, e) - def _iter_source(self, source: str): + def _iter_source(self, source: Union[str, bytes]): """Yield pages from a single source one at a time.""" + if isinstance(source, bytes): + if source[:5] == b"%PDF-": + yield from self._iter_pdf_bytes(source) + else: + yield Image.open(BytesIO(source)) + return + if source.startswith("file://"): file_path = source[7:] else: @@ -198,11 +206,45 @@ def _iter_pdf(self, file_path: str): ): yield image - def _load_source(self, source: str) -> List[Image.Image]: + def _iter_pdf_bytes(self, data: bytes): + """Yield PDF pages from raw bytes one at a time.""" + end_page = self._compute_end_page() + for image in pdf_to_images_pil_iter( + data, + dpi=self.pdf_dpi, + max_width_or_height=3500, + start_page_id=0, + end_page_id=end_page, + ): + yield image + + def _load_pdf_bytes(self, data: bytes) -> List[Image.Image]: + """Load all pages from PDF bytes.""" + t0 = time.perf_counter() + end_page = self._compute_end_page() + pages = pdf_to_images_pil( + data, + dpi=self.pdf_dpi, + max_width_or_height=3500, + start_page_id=0, + end_page_id=end_page, + ) + profiler.log( + "pdf_to_images_pil()", + (time.perf_counter() - t0) * 1000, + ) + return pages + + def _load_source(self, source: Union[str, bytes]) -> List[Image.Image]: """Load a single source and return a list of pages. PDFs return all pages; images return a single-page list. """ + if isinstance(source, bytes): + if source[:5] == b"%PDF-": + return self._load_pdf_bytes(source) + return [Image.open(BytesIO(source))] + if source.startswith("file://"): file_path = source[7:] else: diff --git a/glmocr/pipeline/_common.py b/glmocr/pipeline/_common.py index fe2f532..7155607 100644 --- a/glmocr/pipeline/_common.py +++ b/glmocr/pipeline/_common.py @@ -2,29 +2,39 @@ from __future__ import annotations -from typing import Any, Dict, List +from typing import Any, Dict, List, Union from glmocr.utils.logging import get_logger logger = get_logger(__name__) -def extract_image_urls(request_data: Dict[str, Any]) -> List[str]: - """Extract image URLs from an OpenAI-style request payload.""" - image_urls: List[str] = [] +def extract_image_sources(request_data: Dict[str, Any]) -> List[Union[str, bytes]]: + """Extract image sources (URLs or raw bytes) from a request payload.""" + sources: List[Union[str, bytes]] = [] for msg in request_data.get("messages", []): if msg.get("role") == "user": contents = msg.get("content", []) if isinstance(contents, list): for content in contents: if content.get("type") == "image_url": - image_urls.append(content["image_url"]["url"]) - return image_urls - - -def make_original_inputs(image_urls: List[str]) -> List[str]: - """Strip ``file://`` prefix so that original paths are returned.""" - return [(url[7:] if url.startswith("file://") else url) for url in image_urls] + sources.append(content["image_url"]["url"]) + elif content.get("type") == "image_bytes": + sources.append(content["data"]) + return sources + + +def make_original_inputs(sources: List[Union[str, bytes]]) -> List[str]: + """Return display-friendly names for each input source.""" + results: List[str] = [] + for i, src in enumerate(sources): + if isinstance(src, bytes): + results.append(f"document_{i}") + elif src.startswith("file://"): + results.append(src[7:]) + else: + results.append(src) + return results def extract_ocr_content(response: Dict[str, Any]) -> str: diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py index 93a5803..2339a77 100644 --- a/glmocr/pipeline/_workers.py +++ b/glmocr/pipeline/_workers.py @@ -50,9 +50,12 @@ def data_loading_worker( state: PipelineState, page_loader: "PageLoader", - image_urls: List[str], + image_sources: List[Any], ) -> None: - """Load pages from *image_urls* and push them onto ``state.page_queue``. + """Load pages from *image_sources* and push them onto ``state.page_queue``. + + *image_sources* may contain file paths (str), ``file://`` URLs, or raw + ``bytes`` (image / PDF content). For each page that is loaded, ``state.register_page()`` is called **before** the page message is enqueued, so that the tracker's @@ -67,13 +70,13 @@ def data_loading_worker( ``UNIT_DONE`` sentinel so the tracker can finalise them with ``region_count=0``. """ - num_units = len(image_urls) + num_units = len(image_sources) page_idx = 0 unit_indices_list: List[int] = [] prev_unit_idx: Optional[int] = None sent_unit_done: set = set() try: - for page, unit_idx in page_loader.iter_pages_with_unit_indices(image_urls): + for page, unit_idx in page_loader.iter_pages_with_unit_indices(image_sources): if state.is_shutdown: break diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 286db59..f7c0f28 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -25,7 +25,7 @@ from glmocr.postprocess import ResultFormatter from glmocr.utils.logging import get_logger -from glmocr.pipeline._common import extract_image_urls, extract_ocr_content, make_original_inputs +from glmocr.pipeline._common import extract_image_sources, extract_ocr_content, make_original_inputs from glmocr.pipeline._state import PipelineState from glmocr.pipeline._workers import data_loading_worker, layout_worker, recognition_worker from glmocr.pipeline._unit_tracker import UnitTracker @@ -116,14 +116,14 @@ def process( Yields: One ``PipelineResult`` per input URL (image or PDF). """ - image_urls = extract_image_urls(request_data) + image_sources = extract_image_sources(request_data) - if not image_urls: + if not image_sources: yield self._process_passthrough(request_data) return - num_units = len(image_urls) - original_inputs = make_original_inputs(image_urls) + num_units = len(image_sources) + original_inputs = make_original_inputs(image_sources) state = PipelineState( page_maxsize=page_maxsize or self._page_maxsize, @@ -136,7 +136,7 @@ def process( t1 = threading.Thread( target=data_loading_worker, - args=(state, self.page_loader, image_urls), + args=(state, self.page_loader, image_sources), daemon=True, ) t2 = threading.Thread( diff --git a/glmocr/utils/image_utils.py b/glmocr/utils/image_utils.py index 2773a6d..65b404a 100644 --- a/glmocr/utils/image_utils.py +++ b/glmocr/utils/image_utils.py @@ -289,8 +289,15 @@ def _render_page_to_pil(page, dpi: int = 200, max_width_or_height: int = 3500): return image, scale +def _open_pdf(source): + """Open a PDF from a file path (str) or raw bytes.""" + if isinstance(source, bytes): + return fitz.open(stream=source, filetype="pdf") + return fitz.open(source) + + def pdf_to_images_pil( - pdf_path: str, + source, dpi: int = 200, max_width_or_height: int = 3500, start_page_id: int = 0, @@ -299,7 +306,7 @@ def pdf_to_images_pil( """Convert PDF to list of PIL Images. Args: - pdf_path: PDF file path. + source: PDF file path (str) or raw PDF bytes. dpi: Render DPI. max_width_or_height: Max width or height. start_page_id: Start page index (0-based). @@ -310,7 +317,7 @@ def pdf_to_images_pil( """ doc = None try: - doc = fitz.open(pdf_path) + doc = _open_pdf(source) page_count = doc.page_count if end_page_id is None or end_page_id < 0: end_page_id = page_count - 1 @@ -330,7 +337,7 @@ def pdf_to_images_pil( def pdf_to_images_pil_iter( - pdf_path: str, + source, dpi: int = 200, max_width_or_height: int = 3500, start_page_id: int = 0, @@ -342,7 +349,7 @@ def pdf_to_images_pil_iter( downstream can start processing before the whole PDF is loaded. Args: - pdf_path: PDF file path. + source: PDF file path (str) or raw PDF bytes. dpi: Render DPI. max_width_or_height: Max width or height. start_page_id: Start page index (0-based). @@ -352,8 +359,9 @@ def pdf_to_images_pil_iter( PIL.Image per page. """ doc = None + label = source if isinstance(source, str) else "" try: - doc = fitz.open(pdf_path) + doc = _open_pdf(source) page_count = doc.page_count if end_page_id is None or end_page_id < 0: end_page_id = page_count - 1 @@ -363,7 +371,7 @@ def pdf_to_images_pil_iter( try: page = doc.load_page(i) except Exception as e: - logger.warning("Skipping page %d of '%s': %s", i, pdf_path, e) + logger.warning("Skipping page %d of '%s': %s", i, label, e) continue try: image, _ = _render_page_to_pil( @@ -371,7 +379,7 @@ def pdf_to_images_pil_iter( ) yield image except Exception as e: - logger.warning("Skipping page %d of '%s' (render failed): %s", i, pdf_path, e) + logger.warning("Skipping page %d of '%s' (render failed): %s", i, label, e) finally: if doc is not None: doc.close() From 83e4c719472b198ebef0e19eb4150d50085a05aa Mon Sep 17 00:00:00 2001 From: xueyadong Date: Thu, 12 Mar 2026 09:09:39 +0000 Subject: [PATCH 20/38] Improve recognition result postprocess --- glmocr/postprocess/result_formatter.py | 41 ++++++++++++++++++-------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/glmocr/postprocess/result_formatter.py b/glmocr/postprocess/result_formatter.py index da228e8..ed9011d 100644 --- a/glmocr/postprocess/result_formatter.py +++ b/glmocr/postprocess/result_formatter.py @@ -267,7 +267,18 @@ def _format_content(self, content: Any, label: str, native_label: str) -> str: if content is None: return content - content = self._clean_content(str(content)) + if label == "table": + if content.startswith(""): + content = content.strip() + else: + content = self._clean_content(str(content)) + elif label == "formula": + if content.startswith("$$") and content.endswith("$$"): + content = content.strip() + else: + content = self._clean_content(str(content)) + else: + content = self._clean_content(str(content)) # Title formatting if native_label == "doc_title": @@ -283,20 +294,26 @@ def _format_content(self, content: Any, label: str, native_label: str) -> str: # Formula formatting if label == "formula": - if content.startswith("$$") and content.endswith("$$"): - content = content[2:-2].strip() - content = "$$\n" + content + "\n$$" - elif content.startswith("\\[") and content.endswith("\\]"): - content = content[2:-2].strip() - content = "$$\n" + content + "\n$$" - elif content.startswith("\\(") and content.endswith("\\)"): - content = content[2:-2].strip() - content = "$$\n" + content + "\n$$" - else: - content = "$$\n" + content + "\n$$" + if ( + content.startswith("$$") + or content.startswith("\\[") + or content.startswith("\\(") + ): + content = content[2:].strip() + if ( + content.endswith("$$") + or content.endswith("\\]") + or content.endswith("\\)") + ): + content = content[:-2].strip() + content = "$$\n" + content + "\n$$" # Text formatting if label == "text": + # Code blocks + if content.startswith("```") and (not content.endswith("```")): + content = content + "\n```" + # Bullet points if ( content.startswith("·") From 28f108b3750866e03cbd17291f504f492ef8945a Mon Sep 17 00:00:00 2001 From: xueyadong Date: Fri, 13 Mar 2026 08:16:05 +0000 Subject: [PATCH 21/38] Add multi-GPU deployment support for GLM-OCR --- examples/multi-gpu-deploy/README.md | 120 +++++ examples/multi-gpu-deploy/README_zh.md | 120 +++++ examples/multi-gpu-deploy/coordinator.py | 621 +++++++++++++++++++++++ examples/multi-gpu-deploy/engine.py | 170 +++++++ examples/multi-gpu-deploy/gpu_utils.py | 145 ++++++ examples/multi-gpu-deploy/launch.py | 141 +++++ examples/multi-gpu-deploy/worker.py | 104 ++++ 7 files changed, 1421 insertions(+) create mode 100644 examples/multi-gpu-deploy/README.md create mode 100644 examples/multi-gpu-deploy/README_zh.md create mode 100644 examples/multi-gpu-deploy/coordinator.py create mode 100644 examples/multi-gpu-deploy/engine.py create mode 100644 examples/multi-gpu-deploy/gpu_utils.py create mode 100644 examples/multi-gpu-deploy/launch.py create mode 100644 examples/multi-gpu-deploy/worker.py diff --git a/examples/multi-gpu-deploy/README.md b/examples/multi-gpu-deploy/README.md new file mode 100644 index 0000000..df39fa0 --- /dev/null +++ b/examples/multi-gpu-deploy/README.md @@ -0,0 +1,120 @@ +# Multi-GPU Deployment for GLM-OCR + +Automatically launch sglang/vLLM inference services across multiple GPUs, distribute image files evenly, and run the GLM-OCR pipeline in parallel for maximum throughput. + +Each GPU hosts both an inference server (sglang or vLLM) and a layout detection model, forming a self-contained processing unit with zero cross-GPU communication. + +## Features + +- **Auto GPU detection** — discovers all available GPUs and filters by free VRAM +- **Dynamic port allocation** — automatically skips occupied ports +- **Fault tolerance** — failed GPUs are skipped, files are redistributed to healthy GPUs +- **Global progress bar** — real-time `tqdm` progress across all GPUs +- **Graceful shutdown** — `Ctrl+C` cleanly terminates all subprocesses; double `Ctrl+C` force-kills +- **Centralized logging** — all engine/worker logs saved under `logs//` +- **Speculative decoding** — MTP enabled by default for both sglang and vLLM + +## Quick Start + +```bash +# Use all available GPUs with sglang (default) +python examples/multi-gpu-deploy/launch.py -i ./images -o ./output -m /path/to/GLM-OCR + +# Specify GPUs and use vLLM +python examples/multi-gpu-deploy/launch.py -i ./images -o ./output --engine vllm --gpus 0,1,2,3 + +# Custom model path and VRAM threshold +python examples/multi-gpu-deploy/launch.py -i ./images -o ./output -m /path/to/GLM-OCR --min-free-mb 20000 +``` + +## Parameters + +| Parameter | Default | Description | +|---|---|---| +| `-i`, `--input` | *required* | Input image file or directory (recursive) | +| `-o`, `--output` | `./output` | Output directory for results | +| `-m`, `--model` | `zai-org/GLM-OCR` | Model name or local path | +| `--engine` | `sglang` | Inference engine: `sglang` or `vllm` | +| `--gpus` | `auto` | GPU IDs (comma-separated) or `auto` for all available | +| `--base-port` | `8080` | Base port for engine services | +| `--min-free-mb` | `16000` | Minimum free GPU memory in MB to use a GPU | +| `--timeout` | `600` | Engine startup timeout in seconds | +| `--engine-args` | *none* | Extra arguments passed to the engine | +| `-c`, `--config` | *none* | Path to a custom glmocr config YAML | +| `--log-level` | `WARNING` | Log level for worker processes | + + +## Examples + +### Basic usage + +```bash +python examples/multi-gpu-deploy/launch.py -i /data/documents -o /data/results +``` + +### Use vLLM with specific GPUs + +```bash +python examples/multi-gpu-deploy/launch.py \ + -i /data/documents \ + -o /data/results \ + --engine vllm \ + --gpus 0,2,4,6 +``` + +### Custom engine arguments + +```bash +# sglang with custom memory fraction +python examples/multi-gpu-deploy/launch.py \ + -i /data/documents \ + -o /data/results \ + --engine-args "--mem-fraction-static 0.85" +``` + +### Custom config YAML + +```bash +python examples/multi-gpu-deploy/launch.py \ + -i /data/documents \ + -o /data/results \ + --config my_config.yaml +``` + +## Logs + +All logs are saved under `logs//`: + +| File | Content | +|---|---| +| `main.log` | Coordinator stdout/stderr | +| `engine_gpu_port

.log` | Engine service output for each GPU | +| `worker_gpu.log` | Worker process output for each GPU | +| `failed_files.json` | Aggregated list of failed files (if any) | + +## Troubleshooting + +**Q: Some ports are occupied, will it still work?** + +Yes. The launcher automatically scans for available ports starting from `--base-port` and skips any that are in use. + +**Q: A GPU runs out of memory mid-processing. What happens?** + +The worker on that GPU will fail, but other GPUs continue processing. Failed files are logged in `failed_files.json` for later re-processing. + +**Q: How do I re-run only the failed files?** + +Copy the failed files to a directory and run the launcher again pointing to that directory. + +## File Structure + +``` +examples/multi-gpu-deploy/ +├── launch.py # Entry point and CLI argument parser +├── coordinator.py # Orchestration: GPU detection, engine/worker lifecycle +├── engine.py # Engine service management and progress tracking +├── worker.py # Worker process: GLM-OCR pipeline execution +├── gpu_utils.py # GPU detection, port checking, file sharding +├── README.md # This file (English) +└── README_zh.md # Chinese documentation +``` diff --git a/examples/multi-gpu-deploy/README_zh.md b/examples/multi-gpu-deploy/README_zh.md new file mode 100644 index 0000000..5aeb2bb --- /dev/null +++ b/examples/multi-gpu-deploy/README_zh.md @@ -0,0 +1,120 @@ +# GLM-OCR 多卡并行部署 + +自动在多张 GPU 上启动 sglang/vLLM 推理服务,均匀分配图像文件,并行运行 GLM-OCR 流水线以获得最大吞吐量。 + +每张 GPU 同时承载推理服务(sglang 或 vLLM)和版面检测模型,形成独立的处理单元,GPU 之间零通信开销。 + +## 特性 + +- **自动检测 GPU** — 自动发现所有可用 GPU,按空闲显存过滤 +- **动态端口分配** — 自动跳过已被占用的端口 +- **容错机制** — 失败的 GPU 自动跳过,文件重新分配到健康的 GPU 上 +- **全局进度条** — 实时 `tqdm` 进度展示,汇总所有 GPU 的处理进度 +- **优雅退出** — `Ctrl+C` 清理所有子进程;双击 `Ctrl+C` 强制终止 +- **集中日志** — 所有引擎/Worker 日志保存在 `logs/<时间戳>/` 目录下 +- **投机解码** — sglang 和 vLLM 均默认启用 MTP(多 Token 预测) + +## 快速开始 + +```bash +# 使用所有可用 GPU,默认 sglang 引擎 +python examples/multi-gpu-deploy/launch.py -i ./images -o ./output -m /path/to/GLM-OCR + +# 指定 GPU 并使用 vLLM +python examples/multi-gpu-deploy/launch.py -i ./images -o ./output --engine vllm --gpus 0,1,2,3 + +# 自定义模型路径和显存阈值 +python examples/multi-gpu-deploy/launch.py -i ./images -o ./output -m /path/to/GLM-OCR --min-free-mb 20000 +``` + +## 参数说明 + +| 参数 | 默认值 | 说明 | +|---|---|---| +| `-i`, `--input` | *必填* | 输入图像文件或目录(支持递归扫描) | +| `-o`, `--output` | `./output` | 输出结果目录 | +| `-m`, `--model` | `zai-org/GLM-OCR` | 模型名称或本地路径 | +| `--engine` | `sglang` | 推理引擎:`sglang` 或 `vllm` | +| `--gpus` | `auto` | GPU 编号(逗号分隔)或 `auto` 自动检测 | +| `--base-port` | `8080` | 推理服务起始端口 | +| `--min-free-mb` | `16000` | 使用 GPU 所需的最小空闲显存(MB) | +| `--timeout` | `600` | 推理服务启动超时时间(秒) | +| `--engine-args` | *无* | 传递给推理引擎的额外参数 | +| `-c`, `--config` | *无* | 自定义 glmocr 配置 YAML 路径 | +| `--log-level` | `WARNING` | Worker 进程的日志级别 | + + +## 使用示例 + +### 基本用法 + +```bash +python examples/multi-gpu-deploy/launch.py -i /data/documents -o /data/results +``` + +### 使用 vLLM 并指定 GPU + +```bash +python examples/multi-gpu-deploy/launch.py \ + -i /data/documents \ + -o /data/results \ + --engine vllm \ + --gpus 0,2,4,6 +``` + +### 自定义引擎参数 + +```bash +# sglang 设置显存占用比例 +python examples/multi-gpu-deploy/launch.py \ + -i /data/documents \ + -o /data/results \ + --engine-args "--mem-fraction-static 0.85" +``` + +### 使用自定义配置文件 + +```bash +python examples/multi-gpu-deploy/launch.py \ + -i /data/documents \ + -o /data/results \ + --config my_config.yaml +``` + +## 日志 + +所有日志保存在 `logs/<时间戳>/` 目录下: + +| 文件 | 内容 | +|---|---| +| `main.log` | 协调器主进程的 stdout/stderr | +| `engine_gpu_port

.log` | 各 GPU 的推理引擎输出 | +| `worker_gpu.log` | 各 GPU 的 Worker 进程输出 | +| `failed_files.json` | 汇总的失败文件列表(如有) | + +## 常见问题 + +**Q:某些端口被占用了,还能正常工作吗?** + +可以。启动器会从 `--base-port` 开始自动扫描可用端口,跳过所有已被占用的端口。 + +**Q:某张 GPU 在处理过程中显存不足怎么办?** + +该 GPU 上的 Worker 会失败,但其他 GPU 继续处理。失败的文件会记录在 `failed_files.json` 中,方便后续重新处理。 + +**Q:如何只重跑失败的文件?** + +将失败的文件复制到一个目录中,然后重新运行启动器指向该目录即可。 + +## 文件结构 + +``` +examples/multi-gpu-deploy/ +├── launch.py # 入口文件与命令行参数解析 +├── coordinator.py # 编排器:GPU 检测、引擎/Worker 生命周期管理 +├── engine.py # 推理引擎管理与进度追踪 +├── worker.py # Worker 进程:GLM-OCR 流水线执行 +├── gpu_utils.py # GPU 检测、端口检查、文件分片 +├── README.md # 英文文档 +└── README_zh.md # 本文件(中文文档) +``` diff --git a/examples/multi-gpu-deploy/coordinator.py b/examples/multi-gpu-deploy/coordinator.py new file mode 100644 index 0000000..1124e45 --- /dev/null +++ b/examples/multi-gpu-deploy/coordinator.py @@ -0,0 +1,621 @@ +"""Coordinator — orchestrates multi-GPU engine startup, file sharding, +worker launching, progress monitoring, and graceful shutdown.""" + +import io +import os +import sys +import json +import time +import signal +import tempfile +import subprocess +from pathlib import Path +import concurrent.futures +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +from engine import read_progress, start_engine, wait_for_service +from gpu_utils import ( + _print_err, + get_gpu_info, + shard_files, + collect_files, + find_available_ports, + filter_available_gpus, +) + +_PACKAGE_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class _TeeWriter: + """Write to both a terminal stream and a log file simultaneously.""" + + def __init__(self, terminal: io.TextIOBase, log_file: io.TextIOBase): + self._terminal = terminal + self._log_file = log_file + + def write(self, data: str) -> int: + self._terminal.write(data) + self._log_file.write(data) + self._log_file.flush() + return len(data) + + def flush(self) -> None: + self._terminal.flush() + self._log_file.flush() + + def fileno(self) -> int: + return self._terminal.fileno() + + def isatty(self) -> bool: + return self._terminal.isatty() + + +class Coordinator: + """Orchestrates multi-GPU OCR processing.""" + + def __init__(self, args): + self.args = args + self.engine_procs: Dict[int, subprocess.Popen] = {} + self.worker_procs: Dict[int, subprocess.Popen] = {} + self.progress_files: Dict[int, str] = {} + self.file_handles: List[Any] = [] + self.tmp_dir = tempfile.mkdtemp(prefix="glmocr_mgpu_") + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.log_dir = Path("logs") / timestamp + self.log_dir.mkdir(parents=True, exist_ok=True) + + self._shutdown = False + self._start_time = time.time() + self._orig_stdout: Any = None + self._orig_stderr: Any = None + self._main_log_fh: Any = None + + # ------------------------------------------------------------------ + # Public entry + # ------------------------------------------------------------------ + + def run(self) -> None: + self._install_tee() + + old_sigint = signal.getsignal(signal.SIGINT) + old_sigterm = signal.getsignal(signal.SIGTERM) + signal.signal(signal.SIGINT, self._on_signal) + signal.signal(signal.SIGTERM, self._on_signal) + + try: + self._run_impl() + finally: + signal.signal(signal.SIGINT, old_sigint) + signal.signal(signal.SIGTERM, old_sigterm) + self._cleanup() + self._uninstall_tee() + + def _install_tee(self) -> None: + """Redirect stdout and stderr so all output also goes to main.log. + + The original stderr is saved as ``_orig_stderr`` and passed to tqdm + directly, so progress-bar control characters never reach the log. + """ + log_path = self.log_dir / "main.log" + self._main_log_fh = open(log_path, "w") + self._orig_stdout = sys.stdout + self._orig_stderr = sys.stderr + sys.stdout = _TeeWriter(self._orig_stdout, self._main_log_fh) + sys.stderr = _TeeWriter(self._orig_stderr, self._main_log_fh) + + def _uninstall_tee(self) -> None: + """Restore original stdout/stderr.""" + if self._orig_stdout is not None: + sys.stdout = self._orig_stdout + if self._orig_stderr is not None: + sys.stderr = self._orig_stderr + if self._main_log_fh is not None: + try: + self._main_log_fh.close() + except Exception: + pass + + # ------------------------------------------------------------------ + # Signal handling + # ------------------------------------------------------------------ + + def _on_signal(self, signum, frame): + if self._shutdown: + _print_err("\n[FORCE] Second signal received, force killing...") + self._force_kill_all() + os._exit(1) + _print_err( + f"\n[INFO] Signal {signum} received, shutting down gracefully..." + ) + self._shutdown = True + + # ------------------------------------------------------------------ + # Main pipeline + # ------------------------------------------------------------------ + + def _run_impl(self) -> None: + self._step1_detect_gpus() + if self._shutdown: + return + + available = self._available_gpus + n_gpus = len(available) + + total_files, files, shards, input_root = self._step2_collect_files( + n_gpus + ) + if self._shutdown: + return + + gpu_port_map = self._step3_start_engines(available) + if self._shutdown: + return + + ready_pairs, shards = self._step4_wait_services( + available, gpu_port_map, shards, total_files + ) + if self._shutdown or not ready_pairs: + return + + self._step5_start_workers(ready_pairs, shards, input_root) + if self._shutdown: + return + + self._monitor_progress(total_files) + self._print_summary(total_files) + + # ------------------------------------------------------------------ + # Step 1 — Detect GPUs + # ------------------------------------------------------------------ + + def _step1_detect_gpus(self) -> None: + print("=" * 60) + print(" GLM-OCR Multi-GPU Launcher") + print("=" * 60) + + gpus = get_gpu_info() + if not gpus: + _print_err("[ERROR] No GPUs found.") + sys.exit(1) + + print(f"\n[1/5] Detected {len(gpus)} GPU(s):") + for g in gpus: + pct = g["used_mb"] / max(g["total_mb"], 1) + bar_len = int(pct * 20) + bar = "\u2588" * bar_len + "\u2591" * (20 - bar_len) + print( + f" GPU {g['id']:>1}: {g['name']:<26} " + f"[{bar}] {g['used_mb']:>5}/{g['total_mb']}MB " + f"(free: {g['free_mb']}MB)" + ) + + gpu_ids = None + if self.args.gpus != "auto": + gpu_ids = [int(x.strip()) for x in self.args.gpus.split(",")] + + available = filter_available_gpus(gpus, self.args.min_free_mb, gpu_ids) + if not available: + _print_err( + f"\n[ERROR] No GPUs have >= {self.args.min_free_mb}MB " + "free memory." + ) + sys.exit(1) + + print( + f"\n Using {len(available)} GPU(s): " + f"{[g['id'] for g in available]}" + ) + self._available_gpus = available + + # ------------------------------------------------------------------ + # Step 2 — Collect and shard files + # ------------------------------------------------------------------ + + def _step2_collect_files( + self, n_gpus: int + ) -> Tuple[int, List[str], List[List[str]], Optional[str]]: + print(f"\n[2/5] Scanning: {self.args.input}") + files = collect_files(self.args.input) + total_files = len(files) + print(f" Found {total_files} file(s)") + + shards = shard_files(files, n_gpus) + for gpu, shard in zip(self._available_gpus, shards): + print(f" GPU {gpu['id']}: {len(shard)} files") + + input_path = Path(self.args.input) + input_root = ( + str(input_path.absolute()) if input_path.is_dir() else None + ) + + self.log_dir.mkdir(parents=True, exist_ok=True) + return total_files, files, shards, input_root + + # ------------------------------------------------------------------ + # Step 3 — Start engine services + # ------------------------------------------------------------------ + + def _step3_start_engines(self, available: List[Dict]) -> Dict[int, int]: + print(f"\n[3/5] Starting {self.args.engine} services...") + + gpu_port_map: Dict[int, int] = {} + + ports = find_available_ports(self.args.base_port, len(available)) + if len(ports) < len(available): + _print_err( + f" [WARN] Only found {len(ports)} available ports " + f"(need {len(available)}). Some GPUs will be skipped." + ) + + for i, gpu in enumerate(available): + if self._shutdown: + break + if i >= len(ports): + _print_err(f" [SKIP] GPU {gpu['id']}: no available port") + continue + port = ports[i] + gpu_id = gpu["id"] + + proc, log_path, log_fh = start_engine( + engine=self.args.engine, + model=self.args.model, + gpu_id=gpu_id, + port=port, + extra_args=self.args.engine_args or "", + log_dir=str(self.log_dir), + ) + self.engine_procs[gpu_id] = proc + self.file_handles.append(log_fh) + gpu_port_map[gpu_id] = port + print( + f" GPU {gpu_id} -> port {port} " + f"(pid {proc.pid}, log: {log_path.name})" + ) + + return gpu_port_map + + # ------------------------------------------------------------------ + # Step 4 — Wait for services to be ready + # ------------------------------------------------------------------ + + def _step4_wait_services( + self, + available: List[Dict], + gpu_port_map: Dict[int, int], + shards: List[List[str]], + total_files: int, + ) -> Tuple[List[Tuple[int, int]], List[List[str]]]: + print( + f"\n[4/5] Waiting for services to be ready " + f"(timeout: {self.args.timeout}s)..." + ) + + ready_pairs: List[Tuple[int, int]] = [] + ready_shard_indices: List[int] = [] + + future_map: Dict[ + concurrent.futures.Future, Tuple[int, int, int] + ] = {} + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(available) + ) as executor: + for i, gpu in enumerate(available): + gpu_id = gpu["id"] + if gpu_id not in gpu_port_map: + continue + port = gpu_port_map[gpu_id] + proc = self.engine_procs[gpu_id] + future = executor.submit( + wait_for_service, port, proc, self.args.timeout + ) + future_map[future] = (i, gpu_id, port) + + for future in concurrent.futures.as_completed(future_map): + if self._shutdown: + break + i, gpu_id, port = future_map[future] + success, elapsed = future.result() + if success: + print( + f" GPU {gpu_id} (port {port}): " + f"Ready ({elapsed}s)" + ) + ready_pairs.append((gpu_id, port)) + ready_shard_indices.append(i) + else: + proc = self.engine_procs[gpu_id] + if proc.poll() is not None: + print( + f" GPU {gpu_id} (port {port}): CRASHED " + f"(exit={proc.returncode}, {elapsed}s)" + ) + else: + print( + f" GPU {gpu_id} (port {port}): " + f"TIMEOUT ({elapsed}s)" + ) + self._kill_proc(proc) + + if not ready_pairs: + _print_err( + "\n[ERROR] No engine services started successfully!\n" + f" Check logs: {self.log_dir}" + ) + return [], [] + + n_ready = len(ready_pairs) + n_total = len(available) + if n_ready < n_total: + all_files: List[str] = [] + for shard in shards: + all_files.extend(shard) + shards = shard_files(all_files, n_ready) + print( + f"\n {n_ready}/{n_total} GPUs ready. " + f"Redistributed {total_files} files across " + f"{n_ready} GPU(s)." + ) + else: + shards = [shards[i] for i in ready_shard_indices] + + return ready_pairs, shards + + # ------------------------------------------------------------------ + # Step 5 — Start workers + # ------------------------------------------------------------------ + + def _step5_start_workers( + self, + ready_pairs: List[Tuple[int, int]], + shards: List[List[str]], + input_root: Optional[str], + ) -> None: + print("\n[5/5] Starting workers...") + + entry_point = os.path.join(_PACKAGE_DIR, "launch.py") + + for (gpu_id, port), shard in zip(ready_pairs, shards): + if self._shutdown: + break + + filelist_path = os.path.join( + self.tmp_dir, f"shard_gpu{gpu_id}.json" + ) + with open(filelist_path, "w") as f: + json.dump(shard, f) + + progress_path = os.path.join( + self.tmp_dir, f"progress_gpu{gpu_id}.json" + ) + self.progress_files[gpu_id] = progress_path + + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + + worker_cmd = [ + sys.executable, + entry_point, + "--worker", + "--gpu-id", + str(gpu_id), + "--port", + str(port), + "--filelist", + filelist_path, + "--output", + self.args.output, + "--progress-file", + progress_path, + "--log-level", + self.args.log_level or "WARNING", + ] + if input_root: + worker_cmd.extend(["--input-root", input_root]) + if self.args.config: + worker_cmd.extend(["--config", self.args.config]) + + worker_log = self.log_dir / f"worker_gpu{gpu_id}.log" + wfh = open(worker_log, "w") + self.file_handles.append(wfh) + + proc = subprocess.Popen( + worker_cmd, + env=env, + stdout=wfh, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + self.worker_procs[gpu_id] = proc + print( + f" GPU {gpu_id}: {len(shard)} files " + f"-> worker pid {proc.pid}" + ) + + # ------------------------------------------------------------------ + # Progress monitoring + # ------------------------------------------------------------------ + + def _monitor_progress(self, total_files: int) -> None: + print(f"\n{'=' * 60}") + + try: + from tqdm import tqdm + + pbar: Any = tqdm( + total=total_files, + desc="Total", + unit="file", + file=self._orig_stderr or sys.stderr, + dynamic_ncols=True, + ) + except ImportError: + pbar = None + + last_total = 0 + + while not self._shutdown: + all_done = True + total_completed = 0 + total_failed = 0 + gpu_display: Dict[int, str] = {} + + for gpu_id, proc in self.worker_procs.items(): + prog = read_progress( + self.progress_files.get(gpu_id, "") + ) + if prog: + total_completed += prog["completed"] + total_failed += prog["failed"] + gpu_display[gpu_id] = ( + f"{prog['completed']}/{prog['total']}" + ) + status = prog["status"] + else: + gpu_display[gpu_id] = "init" + status = "init" + + alive = proc.poll() is None + done_status = status in ("done", "done_with_errors") + errored = status.startswith("error") + + if alive and not done_status and not errored: + all_done = False + elif not alive and not done_status: + gpu_display[gpu_id] += f"(exit:{proc.returncode})" + + delta = total_completed - last_total + if pbar and delta > 0: + pbar.update(delta) + last_total = total_completed + + if pbar: + parts = [ + f"G{gid}:{s}" + for gid, s in sorted(gpu_display.items()) + ] + pbar.set_postfix_str(" ".join(parts), refresh=True) + + if all_done: + total_completed, total_failed = self._aggregate_progress() + delta = total_completed - last_total + if pbar and delta > 0: + pbar.update(delta) + break + + time.sleep(1) + + if pbar: + pbar.close() + + # ------------------------------------------------------------------ + # Summary + # ------------------------------------------------------------------ + + def _print_summary(self, total_files: int) -> None: + total_completed, total_failed = self._aggregate_progress() + elapsed = int(time.time() - self._start_time) + mins, secs = divmod(elapsed, 60) + + print(f"\n{'=' * 60}") + print(" Summary") + print(f"{'=' * 60}") + + for gpu_id in sorted(self.progress_files.keys()): + prog = read_progress(self.progress_files[gpu_id]) + if prog: + print( + f" GPU {gpu_id}: " + f"{prog['completed']}/{prog['total']} done, " + f"{prog['failed']} failed [{prog['status']}]" + ) + + print( + f"\n Total: {total_completed}/{total_files} completed, " + f"{total_failed} failed" + ) + print(f" Time: {mins}m {secs}s") + print(f" Output: {self.args.output}") + print(f" Logs: {self.log_dir}") + + if total_failed > 0: + self._report_failures() + + def _report_failures(self) -> None: + all_failed: List[Dict] = [] + for gpu_id in self.progress_files: + fp = self.progress_files[gpu_id].replace( + ".json", "_failed.json" + ) + if os.path.exists(fp): + try: + with open(fp) as f: + all_failed.extend(json.load(f)) + except Exception: + pass + if all_failed: + summary = self.log_dir / "failed_files.json" + with open(summary, "w") as f: + json.dump(all_failed, f, ensure_ascii=False, indent=2) + print(f"\n Failed files: {summary}") + launch = os.path.join(_PACKAGE_DIR, "launch.py") + print( + " Re-run with just the failed files:\n" + f" python {launch} " + f"-i -o {self.args.output}" + ) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _aggregate_progress(self) -> Tuple[int, int]: + total_completed = 0 + total_failed = 0 + for gpu_id in self.progress_files: + prog = read_progress(self.progress_files[gpu_id]) + if prog: + total_completed += prog["completed"] + total_failed += prog["failed"] + return total_completed, total_failed + + def _kill_proc(self, proc: subprocess.Popen) -> None: + if proc.poll() is not None: + return + try: + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + try: + proc.wait(timeout=15) + except subprocess.TimeoutExpired: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + proc.wait(timeout=5) + except (ProcessLookupError, OSError, PermissionError): + pass + + def _cleanup(self) -> None: + _print_err("\n[INFO] Cleaning up subprocesses...") + + for gpu_id, proc in self.worker_procs.items(): + self._kill_proc(proc) + + for gpu_id, proc in self.engine_procs.items(): + self._kill_proc(proc) + + for fh in self.file_handles: + try: + fh.close() + except Exception: + pass + + _print_err("[INFO] All processes stopped.") + + def _force_kill_all(self) -> None: + for proc in list(self.worker_procs.values()) + list( + self.engine_procs.values() + ): + try: + if proc.poll() is None: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + except Exception: + pass diff --git a/examples/multi-gpu-deploy/engine.py b/examples/multi-gpu-deploy/engine.py new file mode 100644 index 0000000..c2a589b --- /dev/null +++ b/examples/multi-gpu-deploy/engine.py @@ -0,0 +1,170 @@ +"""Engine service management and progress tracking.""" + +import os +import sys +import json +import time +import shlex +import subprocess +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +# ========================================================================= +# Engine Service Management +# ========================================================================= + +def build_engine_cmd( + engine: str, + model: str, + port: int, + extra_args: str = "", +) -> List[str]: + """Build command to start sglang or vLLM service. + + Default speculative-decoding flags are included for each engine so that + GLM-OCR runs with MTP (multi-token prediction) out of the box. Pass + ``extra_args`` to override or extend these defaults. + """ + if engine == "sglang": + cmd = [ + sys.executable, + "-m", + "sglang.launch_server", + "--model", + model, + "--port", + str(port), + "--log-level", + "warning", + "--speculative-algorithm", + "NEXTN", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--served-model-name", + "glm-ocr", + ] + elif engine == "vllm": + cmd = [ + "vllm", + "serve", + model, + "--port", + str(port), + "--allowed-local-media-path", + "/", + "--speculative-config", + '{"method": "mtp", "num_speculative_tokens": 1}', + "--served-model-name", + "glm-ocr", + ] + else: + raise ValueError(f"Unknown engine: {engine}") + + if extra_args: + cmd.extend(shlex.split(extra_args)) + return cmd + + +def start_engine( + engine: str, + model: str, + gpu_id: int, + port: int, + extra_args: str = "", + log_dir: str = "/tmp", + engine_log_level: str = "warning", +) -> Tuple[subprocess.Popen, Path, Any]: + """Start an engine service on a specific GPU. + + Returns (process, log_path, log_file_handle). + """ + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + if engine == "vllm": + env["VLLM_LOGGING_LEVEL"] = engine_log_level.upper() + + cmd = build_engine_cmd(engine, model, port, extra_args) + log_path = Path(log_dir) / f"engine_gpu{gpu_id}_port{port}.log" + log_fh = open(log_path, "w") + + proc = subprocess.Popen( + cmd, + env=env, + stdout=log_fh, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + return proc, log_path, log_fh + + +def wait_for_service( + port: int, + proc: subprocess.Popen, + timeout: int = 600, + interval: int = 5, +) -> Tuple[bool, int]: + """Wait for a service to become ready by polling /v1/models. + + Returns (success, elapsed_seconds). + """ + import urllib.request + import urllib.error + + url = f"http://127.0.0.1:{port}/v1/models" + start = time.time() + + while time.time() - start < timeout: + if proc.poll() is not None: + return False, int(time.time() - start) + try: + req = urllib.request.Request(url, method="GET") + with urllib.request.urlopen(req, timeout=5) as resp: + if resp.status == 200: + return True, int(time.time() - start) + except Exception: + pass + time.sleep(interval) + + return False, int(time.time() - start) + + +# ========================================================================= +# Progress Tracking +# ========================================================================= + +def write_progress( + path: str, + completed: int, + total: int, + failed: int = 0, + status: str = "running", +) -> None: + """Atomically write progress to a JSON file.""" + data = { + "completed": completed, + "total": total, + "failed": failed, + "status": status, + "timestamp": time.time(), + } + tmp = path + ".tmp" + try: + with open(tmp, "w") as f: + json.dump(data, f) + os.replace(tmp, path) + except OSError: + pass + + +def read_progress(path: str) -> Optional[Dict]: + """Read progress from a JSON file.""" + try: + with open(path, "r") as f: + return json.load(f) + except (FileNotFoundError, json.JSONDecodeError, OSError): + return None diff --git a/examples/multi-gpu-deploy/gpu_utils.py b/examples/multi-gpu-deploy/gpu_utils.py new file mode 100644 index 0000000..d492f40 --- /dev/null +++ b/examples/multi-gpu-deploy/gpu_utils.py @@ -0,0 +1,145 @@ +"""GPU detection, port checking, file collection, and sharding utilities.""" + +import os +import sys +import socket +import subprocess +from pathlib import Path +from typing import Any, Dict, List, Optional + +SUPPORTED_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".pdf"} + +DEFAULT_BASE_PORT = 8080 +DEFAULT_MIN_FREE_MB = 16000 + + +def _print_err(*args, **kwargs): + kwargs.setdefault("file", sys.stderr) + print(*args, **kwargs) + + +# ========================================================================= +# GPU Detection +# ========================================================================= + +def get_gpu_info() -> List[Dict[str, Any]]: + """Query GPU information via nvidia-smi.""" + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=index,name,memory.total,memory.free,memory.used", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + return [] + gpus = [] + for line in result.stdout.strip().split("\n"): + parts = [p.strip() for p in line.split(",")] + if len(parts) >= 5: + gpus.append( + { + "id": int(parts[0]), + "name": parts[1], + "total_mb": int(parts[2]), + "free_mb": int(parts[3]), + "used_mb": int(parts[4]), + } + ) + return gpus + except FileNotFoundError: + _print_err("[ERROR] nvidia-smi not found. Is the NVIDIA driver installed?") + return [] + except Exception as e: + _print_err(f"[ERROR] Failed to query GPU info: {e}") + return [] + + +def filter_available_gpus( + gpus: List[Dict], + min_free_mb: int, + gpu_ids: Optional[List[int]] = None, +) -> List[Dict]: + """Filter GPUs that have enough free VRAM.""" + available = [] + for gpu in gpus: + if gpu_ids is not None and gpu["id"] not in gpu_ids: + continue + if gpu["free_mb"] >= min_free_mb: + available.append(gpu) + else: + _print_err( + f" [SKIP] GPU {gpu['id']} ({gpu['name']}): " + f"{gpu['free_mb']}MB free < {min_free_mb}MB required" + ) + return available + + +# ========================================================================= +# Port Checking +# ========================================================================= + +def is_port_available(port: int) -> bool: + """Check if a TCP port is available for binding.""" + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("127.0.0.1", port)) + return True + except OSError: + return False + + +def find_available_ports(base_port: int, count: int) -> List[int]: + """Find *count* available ports starting from *base_port*. + + Skips any port that is already in use. + """ + ports: List[int] = [] + port = base_port + max_port = base_port + count * 10 + while len(ports) < count and port < max_port: + if is_port_available(port): + ports.append(port) + else: + _print_err(f" [SKIP] Port {port} is occupied, trying next...") + port += 1 + return ports + + +# ========================================================================= +# File Collection and Sharding +# ========================================================================= + +def collect_files(input_path: str) -> List[str]: + """Collect all supported image/PDF files from input path (recursive).""" + path = Path(input_path) + if path.is_file(): + if path.suffix.lower() in SUPPORTED_SUFFIXES: + return [str(path.absolute())] + raise ValueError(f"Unsupported file type: {path.suffix}") + if path.is_dir(): + seen: set = set() + files: List[str] = [] + for p in sorted(path.rglob("*")): + if p.is_file() and p.suffix.lower() in SUPPORTED_SUFFIXES: + abs_p = str(p.absolute()) + if abs_p not in seen: + seen.add(abs_p) + files.append(abs_p) + if not files: + raise ValueError(f"No image/PDF files found in: {input_path}") + return files + raise ValueError(f"Path does not exist: {input_path}") + + +def shard_files(files: List[str], n_shards: int) -> List[List[str]]: + """Distribute files across shards using round-robin.""" + shards: List[List[str]] = [[] for _ in range(n_shards)] + for i, f in enumerate(files): + shards[i % n_shards].append(f) + return shards diff --git a/examples/multi-gpu-deploy/launch.py b/examples/multi-gpu-deploy/launch.py new file mode 100644 index 0000000..6edca85 --- /dev/null +++ b/examples/multi-gpu-deploy/launch.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +""" +Multi-GPU Launcher for GLM-OCR + +Automatically launches sglang/vLLM services across multiple GPUs, distributes +files evenly, and runs the GLM-OCR pipeline in parallel for maximum throughput. + +Each GPU hosts both a sglang/vLLM inference server and a layout detection model, +forming a self-contained processing unit with zero cross-GPU communication. + +Usage: + python examples/multi-gpu-deploy/launch.py -i ./images -o ./output + python examples/multi-gpu-deploy/launch.py -i ./docs -o ./results --engine vllm --gpus 0,1,2,3 + python examples/multi-gpu-deploy/launch.py -i ./pdfs -o ./out --engine-args "--mem-fraction-static 0.85" +""" + +import sys +import argparse +from pathlib import Path + +from gpu_utils import DEFAULT_BASE_PORT, DEFAULT_MIN_FREE_MB, _print_err + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Multi-GPU launcher for GLM-OCR", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python examples/multi-gpu-deploy/launch.py -i ./images -o ./output + python examples/multi-gpu-deploy/launch.py -i ./docs -o ./results --engine vllm --gpus 0,1,2,3 + python examples/multi-gpu-deploy/launch.py -i ./pdfs -o ./out --engine-args "--mem-fraction-static 0.85" + python examples/multi-gpu-deploy/launch.py -i ./imgs -o ./out --min-free-mb 20000 --timeout 900 + """, + ) + + parser.add_argument( + "--worker", action="store_true", help=argparse.SUPPRESS + ) + parser.add_argument("--gpu-id", type=int, help=argparse.SUPPRESS) + parser.add_argument("--port", type=int, help=argparse.SUPPRESS) + parser.add_argument("--filelist", type=str, help=argparse.SUPPRESS) + parser.add_argument("--progress-file", type=str, help=argparse.SUPPRESS) + parser.add_argument( + "--input-root", type=str, default=None, help=argparse.SUPPRESS + ) + + parser.add_argument( + "--input", "-i", type=str, help="Input image file or directory" + ) + parser.add_argument( + "--output", + "-o", + type=str, + default="./output", + help="Output directory (default: ./output)", + ) + parser.add_argument( + "--model", + "-m", + type=str, + default="zai-org/GLM-OCR", + help="Model name or path (default: zai-org/GLM-OCR)", + ) + parser.add_argument( + "--engine", + type=str, + default="sglang", + choices=["sglang", "vllm"], + help="Inference engine (default: sglang)", + ) + parser.add_argument( + "--gpus", + type=str, + default="auto", + help="GPU IDs, comma-separated, or 'auto' (default: auto)", + ) + parser.add_argument( + "--base-port", + type=int, + default=DEFAULT_BASE_PORT, + help=f"Base port for engine services (default: {DEFAULT_BASE_PORT})", + ) + parser.add_argument( + "--min-free-mb", + type=int, + default=DEFAULT_MIN_FREE_MB, + help=f"Min free GPU memory in MB (default: {DEFAULT_MIN_FREE_MB})", + ) + parser.add_argument( + "--timeout", + type=int, + default=600, + help="Engine startup timeout in seconds (default: 600)", + ) + parser.add_argument( + "--engine-args", + type=str, + default=None, + help='Extra args for engine ' + '(e.g. "--mem-fraction-static 0.85")', + ) + parser.add_argument( + "--config", + "-c", + type=str, + default=None, + help="Path to glmocr config YAML", + ) + parser.add_argument( + "--log-level", + type=str, + default="WARNING", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Log level for workers (default: WARNING)", + ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + if args.worker: + from worker import run_worker + + run_worker(args) + else: + if not args.input: + _print_err("Error: --input/-i is required") + sys.exit(1) + Path(args.output).mkdir(parents=True, exist_ok=True) + + from coordinator import Coordinator + + coordinator = Coordinator(args) + coordinator.run() + + +if __name__ == "__main__": + main() diff --git a/examples/multi-gpu-deploy/worker.py b/examples/multi-gpu-deploy/worker.py new file mode 100644 index 0000000..5936e79 --- /dev/null +++ b/examples/multi-gpu-deploy/worker.py @@ -0,0 +1,104 @@ +"""Worker process — runs inside a subprocess with CUDA_VISIBLE_DEVICES +already set to a single GPU.""" + +import sys +import json +from pathlib import Path +from typing import Any, Dict, List + +from gpu_utils import _print_err +from engine import write_progress + + +def run_worker(args) -> None: + """Process a shard of files using the GLM-OCR pipeline. + + ``cuda_visible_devices="0"`` always refers to the intended physical GPU + because the parent process restricts visibility via CUDA_VISIBLE_DEVICES. + """ + with open(args.filelist, "r") as f: + files = json.load(f) + + if not files: + write_progress(args.progress_file, 0, 0, status="done") + return + + total = len(files) + completed = 0 + failed = 0 + failed_files: List[Dict[str, str]] = [] + + write_progress(args.progress_file, 0, total, 0, "loading_model") + + try: + from glmocr.api import GlmOcr + from glmocr.utils.logging import configure_logging + + configure_logging(level=args.log_level or "WARNING") + + glm_kwargs: Dict[str, Any] = { + "ocr_api_port": args.port, + "cuda_visible_devices": "0", + } + if args.config: + glm_kwargs["config_path"] = args.config + + with GlmOcr(**glm_kwargs) as parser: + write_progress(args.progress_file, 0, total, 0, "running") + + for result in parser.parse(files, stream=True): + completed += 1 + + try: + save_dir = args.output + if args.input_root and result.original_images: + try: + rel = Path( + result.original_images[0] + ).parent.relative_to(args.input_root) + if str(rel) != ".": + save_dir = str(Path(args.output) / rel) + except ValueError: + pass + + result.save(output_dir=save_dir) + except Exception as e: + failed += 1 + src = ( + result.original_images[0] + if result.original_images + else "unknown" + ) + failed_files.append({"file": src, "error": str(e)}) + + write_progress( + args.progress_file, completed, total, failed, "running" + ) + + except Exception as e: + import traceback + + _print_err(f"[GPU {args.gpu_id}] Worker error: {e}") + traceback.print_exc(file=sys.stderr) + write_progress( + args.progress_file, completed, total, failed, f"error: {e}" + ) + _save_failed_list(args.progress_file, failed_files) + return + + status = "done" if failed == 0 else "done_with_errors" + write_progress(args.progress_file, completed, total, failed, status) + _save_failed_list(args.progress_file, failed_files) + + +def _save_failed_list( + progress_file: str, failed_files: List[Dict[str, str]] +) -> None: + if not failed_files: + return + path = progress_file.replace(".json", "_failed.json") + try: + with open(path, "w") as f: + json.dump(failed_files, f, ensure_ascii=False, indent=2) + except OSError: + pass From 12390895c999f884c9c30949d903209f5736402f Mon Sep 17 00:00:00 2001 From: xueyadong Date: Wed, 18 Mar 2026 04:53:29 +0000 Subject: [PATCH 22/38] Refactor multi-GPU deployment to eliminate tempfile usage and direct output to log directory --- examples/multi-gpu-deploy/coordinator.py | 10 +++------- examples/multi-gpu-deploy/engine.py | 5 ++++- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/multi-gpu-deploy/coordinator.py b/examples/multi-gpu-deploy/coordinator.py index 1124e45..cf4988a 100644 --- a/examples/multi-gpu-deploy/coordinator.py +++ b/examples/multi-gpu-deploy/coordinator.py @@ -7,7 +7,6 @@ import json import time import signal -import tempfile import subprocess from pathlib import Path import concurrent.futures @@ -60,7 +59,6 @@ def __init__(self, args): self.worker_procs: Dict[int, subprocess.Popen] = {} self.progress_files: Dict[int, str] = {} self.file_handles: List[Any] = [] - self.tmp_dir = tempfile.mkdtemp(prefix="glmocr_mgpu_") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") self.log_dir = Path("logs") / timestamp @@ -380,14 +378,12 @@ def _step5_start_workers( if self._shutdown: break - filelist_path = os.path.join( - self.tmp_dir, f"shard_gpu{gpu_id}.json" - ) + filelist_path = str(self.log_dir / f"shard_gpu{gpu_id}.json") with open(filelist_path, "w") as f: json.dump(shard, f) - progress_path = os.path.join( - self.tmp_dir, f"progress_gpu{gpu_id}.json" + progress_path = str( + self.log_dir / f"progress_gpu{gpu_id}.json" ) self.progress_files[gpu_id] = progress_path diff --git a/examples/multi-gpu-deploy/engine.py b/examples/multi-gpu-deploy/engine.py index c2a589b..fa6da4b 100644 --- a/examples/multi-gpu-deploy/engine.py +++ b/examples/multi-gpu-deploy/engine.py @@ -116,6 +116,9 @@ def wait_for_service( import urllib.error url = f"http://127.0.0.1:{port}/v1/models" + # Bypass any HTTP proxy for localhost connections + no_proxy_handler = urllib.request.ProxyHandler({}) + opener = urllib.request.build_opener(no_proxy_handler) start = time.time() while time.time() - start < timeout: @@ -123,7 +126,7 @@ def wait_for_service( return False, int(time.time() - start) try: req = urllib.request.Request(url, method="GET") - with urllib.request.urlopen(req, timeout=5) as resp: + with opener.open(req, timeout=5) as resp: if resp.status == 200: return True, int(time.time() - start) except Exception: From f19f404421207661b012d9703076ebcd17211b23 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Wed, 18 Mar 2026 13:12:42 +0000 Subject: [PATCH 23/38] Update error handling in recognition process to log failures and set content to None --- glmocr/pipeline/_workers.py | 7 +++++-- glmocr/postprocess/result_formatter.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py index 2339a77..6e0dbca 100644 --- a/glmocr/pipeline/_workers.py +++ b/glmocr/pipeline/_workers.py @@ -382,10 +382,13 @@ def _handle_future_result( content = response["choices"][0]["message"]["content"] region["content"] = content.strip() if content else "" else: - region["content"] = "" + logger.warning( + "Recognition failed for page %d: HTTP %s", page_idx, status_code + ) + region["content"] = None except Exception as e: logger.warning("Recognition failed for page %d: %s", page_idx, e) - region["content"] = "" + region["content"] = None state.add_recognition_result(page_idx, region) diff --git a/glmocr/postprocess/result_formatter.py b/glmocr/postprocess/result_formatter.py index ed9011d..96f5b2c 100644 --- a/glmocr/postprocess/result_formatter.py +++ b/glmocr/postprocess/result_formatter.py @@ -177,9 +177,9 @@ def process( result["native_label"], ) - # Skip empty content (after formatting) + # Skip empty or failed content (after formatting) content = result.get("content") - if isinstance(content, str) and content.strip() == "": + if content is None or (isinstance(content, str) and content.strip() == ""): continue # Update index From 88fc499fc3e08b6eef85b52d5d30988ec65d540f Mon Sep 17 00:00:00 2001 From: xueyadong Date: Wed, 18 Mar 2026 13:25:44 +0000 Subject: [PATCH 24/38] Add health monitoring to OCR pipeline with a watchdog thread and socket connectivity check --- glmocr/ocr_client.py | 9 +++++++++ glmocr/pipeline/pipeline.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/glmocr/ocr_client.py b/glmocr/ocr_client.py index 57c8b9c..ed3e32d 100644 --- a/glmocr/ocr_client.py +++ b/glmocr/ocr_client.py @@ -119,6 +119,15 @@ def start(self): if self._session is None: self._session = self._make_session() + def is_alive(self, timeout: float = 5.0) -> bool: + """Quick socket-level check whether the API port is still reachable.""" + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(timeout) + return sock.connect_ex((self.api_host, self.api_port)) == 0 + except Exception: + return False + def stop(self): """No-op: this client does not manage server lifecycle.""" logger.debug("API recognizer does not manage server lifecycle.") diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index f7c0f28..35163b9 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -154,6 +154,13 @@ def process( t2.start() t3.start() + t_watchdog = threading.Thread( + target=self._health_watchdog, + args=(state,), + daemon=True, + ) + t_watchdog.start() + try: yield from self._emit_results(state, tracker, original_inputs) finally: @@ -161,6 +168,7 @@ def process( t1.join(timeout=10) t2.join(timeout=10) t3.join(timeout=10) + t_watchdog.join(timeout=5) self._current_state = None state.raise_if_exceptions() @@ -202,6 +210,35 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.stop() + # ------------------------------------------------------------------ + # Health watchdog + # ------------------------------------------------------------------ + + def _health_watchdog( + self, + state: PipelineState, + check_interval: float = 5.0, + ) -> None: + """Daemon thread that monitors OCR service liveness. + + Periodically probes the API port via socket. On the first + failure the pipeline is shut down immediately so that workers + stop instead of accumulating failed requests. + """ + while not state.is_shutdown: + state._shutdown_event.wait(check_interval) + if state.is_shutdown: + break + + if not self.ocr_client.is_alive(): + error = RuntimeError( + f"OCR service at {self.ocr_client.api_host}:{self.ocr_client.api_port} " + f"is no longer available" + ) + logger.error("%s", error) + state.record_exception("HealthWatchdog", error) + break + # ------------------------------------------------------------------ # Private helpers # ------------------------------------------------------------------ From 0d51f41a16d60544090852892877f606cfd60ffa Mon Sep 17 00:00:00 2001 From: xueyadong Date: Wed, 18 Mar 2026 13:27:28 +0000 Subject: [PATCH 25/38] Add engine health checks in multi-GPU coordinator to monitor and handle crashed processes --- examples/multi-gpu-deploy/coordinator.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/examples/multi-gpu-deploy/coordinator.py b/examples/multi-gpu-deploy/coordinator.py index cf4988a..a04f620 100644 --- a/examples/multi-gpu-deploy/coordinator.py +++ b/examples/multi-gpu-deploy/coordinator.py @@ -450,8 +450,11 @@ def _monitor_progress(self, total_files: int) -> None: pbar = None last_total = 0 + dead_engines: set = set() while not self._shutdown: + self._check_engines(dead_engines) + all_done = True total_completed = 0 total_failed = 0 @@ -505,6 +508,22 @@ def _monitor_progress(self, total_files: int) -> None: if pbar: pbar.close() + def _check_engines(self, dead_engines: set) -> None: + """Check engine processes and kill workers whose engine has died.""" + for gpu_id, proc in self.engine_procs.items(): + if gpu_id in dead_engines: + continue + if proc.poll() is not None: + dead_engines.add(gpu_id) + print( + f"\n [ERROR] Engine on GPU {gpu_id} crashed " + f"(exit code: {proc.returncode}). " + f"Killing worker for GPU {gpu_id}..." + ) + worker = self.worker_procs.get(gpu_id) + if worker and worker.poll() is None: + self._kill_proc(worker) + # ------------------------------------------------------------------ # Summary # ------------------------------------------------------------------ From 978564a8854ec791da15e8e239663638e853e755 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Thu, 19 Mar 2026 12:30:19 +0000 Subject: [PATCH 26/38] Refactor content validation in ResultFormatter to skip non-image labels before checking for empty content --- glmocr/postprocess/result_formatter.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/glmocr/postprocess/result_formatter.py b/glmocr/postprocess/result_formatter.py index 96f5b2c..d265ced 100644 --- a/glmocr/postprocess/result_formatter.py +++ b/glmocr/postprocess/result_formatter.py @@ -178,9 +178,10 @@ def process( ) # Skip empty or failed content (after formatting) - content = result.get("content") - if content is None or (isinstance(content, str) and content.strip() == ""): - continue + if result["label"] != "image": + content = result.get("content") + if content is None or (isinstance(content, str) and content.strip() == ""): + continue # Update index result["index"] = valid_idx From 507c2d1548645fcc8639ce47970c8e9603a363ed Mon Sep 17 00:00:00 2001 From: xueyadong Date: Fri, 20 Mar 2026 08:14:36 +0000 Subject: [PATCH 27/38] Add inline formula normalization in result formatting process --- glmocr/postprocess/result_formatter.py | 3 ++ glmocr/utils/result_postprocess_utils.py | 40 ++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/glmocr/postprocess/result_formatter.py b/glmocr/postprocess/result_formatter.py index d265ced..f0de515 100644 --- a/glmocr/postprocess/result_formatter.py +++ b/glmocr/postprocess/result_formatter.py @@ -23,6 +23,7 @@ from glmocr.utils.result_postprocess_utils import ( clean_repeated_content, clean_formula_number, + normalize_inline_formula, ) if TYPE_CHECKING: @@ -261,6 +262,8 @@ def _clean_content(self, content: str) -> str: if len(content) >= 2048: content = clean_repeated_content(content) + content = normalize_inline_formula(content) + return content.strip() def _format_content(self, content: Any, label: str, native_label: str) -> str: diff --git a/glmocr/utils/result_postprocess_utils.py b/glmocr/utils/result_postprocess_utils.py index 34255a3..c990aa0 100644 --- a/glmocr/utils/result_postprocess_utils.py +++ b/glmocr/utils/result_postprocess_utils.py @@ -113,3 +113,43 @@ def clean_formula_number(number_content: str) -> str: elif number_clean.startswith("(") and number_clean.endswith(")"): number_clean = number_clean[1:-1] return number_clean + +def normalize_inline_formula(content: str) -> str: + """Normalize inline formula spacing. + + ``$ x $`` → ``$x$``, and ensure a space between surrounding text + and the ``$...$`` delimiter. + """ + INLINE_FORMULA_RE = re.compile( + r"(? Date: Fri, 20 Mar 2026 08:21:15 +0000 Subject: [PATCH 28/38] Refactor result yielding in Pipeline to maintain original input order --- glmocr/pipeline/pipeline.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 35163b9..8b80ec7 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -292,21 +292,30 @@ def _emit_results( tracker: UnitTracker, original_inputs: List[str], ) -> Generator[PipelineResult, None, None]: - """Wait for units to complete and yield their formatted results. + """Wait for units to complete and yield their formatted results + **in the original input order**. - A unit enters the ready queue when: - - ``finalize_unit`` has been called (region count is known), AND - - all its regions have been recognised (``on_region_done`` counter - reached the target). + Units may complete in arbitrary order; finished results are buffered + and yielded sequentially (unit 0 first, then 1, 2, …). ``None`` from the ready queue signals a pipeline error (shutdown). """ - emitted: set = set() - while len(emitted) < tracker.num_units: + pending: Dict[int, PipelineResult] = {} + built: set = set() + next_to_emit = 0 + num_units = tracker.num_units + + while next_to_emit < num_units: + while next_to_emit in pending: + yield pending.pop(next_to_emit) + next_to_emit += 1 + if next_to_emit >= num_units: + break + u = tracker.wait_next_ready_unit() if u is None: break - if u in emitted: + if u in built: continue region_count = tracker.unit_region_count[u] @@ -330,14 +339,13 @@ def _emit_results( grouped, cropped_images=cropped_images or None, ) - # Collect layout visualization images for this unit vis_images = {} for pi in page_indices: img = state.layout_vis_images.pop(pi, None) if img is not None: vis_images[pi] = img - yield PipelineResult( + pending[u] = PipelineResult( json_result=json_u, markdown_result=md_u, original_images=[original_inputs[u]], @@ -345,4 +353,4 @@ def _emit_results( raw_json_result=raw_json, layout_vis_images=vis_images or None, ) - emitted.add(u) + built.add(u) From b3874ad26bd6429f24f283df56f927c53a24454e Mon Sep 17 00:00:00 2001 From: xueyadong Date: Fri, 20 Mar 2026 09:28:44 +0000 Subject: [PATCH 29/38] Add --no-save option to multi-GPU deployment for optional result file writing --- examples/multi-gpu-deploy/coordinator.py | 2 + examples/multi-gpu-deploy/launch.py | 5 +++ examples/multi-gpu-deploy/worker.py | 48 +++++++++++++----------- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/examples/multi-gpu-deploy/coordinator.py b/examples/multi-gpu-deploy/coordinator.py index a04f620..d6c88fd 100644 --- a/examples/multi-gpu-deploy/coordinator.py +++ b/examples/multi-gpu-deploy/coordinator.py @@ -411,6 +411,8 @@ def _step5_start_workers( worker_cmd.extend(["--input-root", input_root]) if self.args.config: worker_cmd.extend(["--config", self.args.config]) + if getattr(self.args, "no_save", False): + worker_cmd.append("--no-save") worker_log = self.log_dir / f"worker_gpu{gpu_id}.log" wfh = open(worker_log, "w") diff --git a/examples/multi-gpu-deploy/launch.py b/examples/multi-gpu-deploy/launch.py index 6edca85..6cf2952 100644 --- a/examples/multi-gpu-deploy/launch.py +++ b/examples/multi-gpu-deploy/launch.py @@ -114,6 +114,11 @@ def parse_args() -> argparse.Namespace: choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Log level for workers (default: WARNING)", ) + parser.add_argument( + "--no-save", + action="store_true", + help="Do not write any result files (useful for benchmarking / stress tests)", + ) return parser.parse_args() diff --git a/examples/multi-gpu-deploy/worker.py b/examples/multi-gpu-deploy/worker.py index 5936e79..7f769d5 100644 --- a/examples/multi-gpu-deploy/worker.py +++ b/examples/multi-gpu-deploy/worker.py @@ -46,30 +46,33 @@ def run_worker(args) -> None: with GlmOcr(**glm_kwargs) as parser: write_progress(args.progress_file, 0, total, 0, "running") + no_save = getattr(args, "no_save", False) + for result in parser.parse(files, stream=True): completed += 1 - try: - save_dir = args.output - if args.input_root and result.original_images: - try: - rel = Path( - result.original_images[0] - ).parent.relative_to(args.input_root) - if str(rel) != ".": - save_dir = str(Path(args.output) / rel) - except ValueError: - pass - - result.save(output_dir=save_dir) - except Exception as e: - failed += 1 - src = ( - result.original_images[0] - if result.original_images - else "unknown" - ) - failed_files.append({"file": src, "error": str(e)}) + if not no_save: + try: + save_dir = args.output + if args.input_root and result.original_images: + try: + rel = Path( + result.original_images[0] + ).parent.relative_to(args.input_root) + if str(rel) != ".": + save_dir = str(Path(args.output) / rel) + except ValueError: + pass + + result.save(output_dir=save_dir) + except Exception as e: + failed += 1 + src = ( + result.original_images[0] + if result.original_images + else "unknown" + ) + failed_files.append({"file": src, "error": str(e)}) write_progress( args.progress_file, completed, total, failed, "running" @@ -88,7 +91,8 @@ def run_worker(args) -> None: status = "done" if failed == 0 else "done_with_errors" write_progress(args.progress_file, completed, total, failed, status) - _save_failed_list(args.progress_file, failed_files) + if not getattr(args, "no_save", False): + _save_failed_list(args.progress_file, failed_files) def _save_failed_list( From cee8fa0c01583d963146352f6f406d1d3b04047f Mon Sep 17 00:00:00 2001 From: xueyadong Date: Fri, 20 Mar 2026 09:53:38 +0000 Subject: [PATCH 30/38] Update multi-GPU deployment to pass log directory and engine log level for improved logging --- examples/multi-gpu-deploy/coordinator.py | 2 +- examples/multi-gpu-deploy/engine.py | 3 ++- examples/multi-gpu-deploy/launch.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/multi-gpu-deploy/coordinator.py b/examples/multi-gpu-deploy/coordinator.py index d6c88fd..63c50b7 100644 --- a/examples/multi-gpu-deploy/coordinator.py +++ b/examples/multi-gpu-deploy/coordinator.py @@ -261,8 +261,8 @@ def _step3_start_engines(self, available: List[Dict]) -> Dict[int, int]: model=self.args.model, gpu_id=gpu_id, port=port, - extra_args=self.args.engine_args or "", log_dir=str(self.log_dir), + extra_args=self.args.engine_args or "", ) self.engine_procs[gpu_id] = proc self.file_handles.append(log_fh) diff --git a/examples/multi-gpu-deploy/engine.py b/examples/multi-gpu-deploy/engine.py index fa6da4b..160a929 100644 --- a/examples/multi-gpu-deploy/engine.py +++ b/examples/multi-gpu-deploy/engine.py @@ -75,8 +75,8 @@ def start_engine( model: str, gpu_id: int, port: int, + log_dir: str, extra_args: str = "", - log_dir: str = "/tmp", engine_log_level: str = "warning", ) -> Tuple[subprocess.Popen, Path, Any]: """Start an engine service on a specific GPU. @@ -87,6 +87,7 @@ def start_engine( env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) if engine == "vllm": env["VLLM_LOGGING_LEVEL"] = engine_log_level.upper() + env["UVICORN_LOG_LEVEL"] = engine_log_level.lower() cmd = build_engine_cmd(engine, model, port, extra_args) log_path = Path(log_dir) / f"engine_gpu{gpu_id}_port{port}.log" diff --git a/examples/multi-gpu-deploy/launch.py b/examples/multi-gpu-deploy/launch.py index 6cf2952..81c2b78 100644 --- a/examples/multi-gpu-deploy/launch.py +++ b/examples/multi-gpu-deploy/launch.py @@ -117,7 +117,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--no-save", action="store_true", - help="Do not write any result files (useful for benchmarking / stress tests)", + help="Do not write any result files", ) return parser.parse_args() From 33b78f77f7df4ef804258ee055dca95e6c0c924b Mon Sep 17 00:00:00 2001 From: xueyadong Date: Fri, 20 Mar 2026 10:01:04 +0000 Subject: [PATCH 31/38] Add engine log level parameter to build_engine_cmd for enhanced logging control --- examples/multi-gpu-deploy/engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/multi-gpu-deploy/engine.py b/examples/multi-gpu-deploy/engine.py index 160a929..9344a6e 100644 --- a/examples/multi-gpu-deploy/engine.py +++ b/examples/multi-gpu-deploy/engine.py @@ -19,6 +19,7 @@ def build_engine_cmd( model: str, port: int, extra_args: str = "", + engine_log_level: str = "warning", ) -> List[str]: """Build command to start sglang or vLLM service. @@ -61,6 +62,8 @@ def build_engine_cmd( '{"method": "mtp", "num_speculative_tokens": 1}', "--served-model-name", "glm-ocr", + "--uvicorn-log-level", + engine_log_level.lower(), ] else: raise ValueError(f"Unknown engine: {engine}") @@ -87,9 +90,8 @@ def start_engine( env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) if engine == "vllm": env["VLLM_LOGGING_LEVEL"] = engine_log_level.upper() - env["UVICORN_LOG_LEVEL"] = engine_log_level.lower() - cmd = build_engine_cmd(engine, model, port, extra_args) + cmd = build_engine_cmd(engine, model, port, extra_args, engine_log_level) log_path = Path(log_dir) / f"engine_gpu{gpu_id}_port{port}.log" log_fh = open(log_path, "w") From 9b83925784973af380e5543cb5e1a549eb24146b Mon Sep 17 00:00:00 2001 From: xueyadong Date: Tue, 24 Mar 2026 07:29:07 +0000 Subject: [PATCH 32/38] Enhance memory management in PipelineState by adding release_unit_data method to free per-page data after processing --- examples/multi-gpu-deploy/worker.py | 2 +- glmocr/pipeline/_state.py | 15 ++++++++++++--- glmocr/pipeline/pipeline.py | 2 ++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/examples/multi-gpu-deploy/worker.py b/examples/multi-gpu-deploy/worker.py index 7f769d5..19c40b7 100644 --- a/examples/multi-gpu-deploy/worker.py +++ b/examples/multi-gpu-deploy/worker.py @@ -48,7 +48,7 @@ def run_worker(args) -> None: no_save = getattr(args, "no_save", False) - for result in parser.parse(files, stream=True): + for result in parser.parse(files, stream=True, save_layout_visualization=not no_save): completed += 1 if not no_save: diff --git a/glmocr/pipeline/_state.py b/glmocr/pipeline/_state.py index 551d7e1..fbab39a 100644 --- a/glmocr/pipeline/_state.py +++ b/glmocr/pipeline/_state.py @@ -43,7 +43,6 @@ def __init__( self.unit_indices_holder: List[Optional[List[int]]] = [None] # ── Recognition results (stage 3 appends, main thread reads) ─ - self._recognition_results: List[Dict[str, Any]] = [] self._results_by_page: Dict[int, List[Dict]] = {} self._results_lock = threading.Lock() @@ -118,9 +117,7 @@ def register_page(self, page_idx: int, unit_idx: int) -> None: def add_recognition_result(self, page_idx: int, region: Dict) -> None: """Append a completed region result and notify the tracker.""" - result = {"page_idx": page_idx, "region": region} with self._results_lock: - self._recognition_results.append(result) self._results_by_page.setdefault(page_idx, []).append(region) tracker = self._tracker if tracker is not None: @@ -132,6 +129,18 @@ def get_grouped_results(self, page_indices: List[int]) -> List[List[Dict]]: with self._results_lock: return [list(self._results_by_page.get(pi, [])) for pi in page_indices] + def release_unit_data(self, page_indices: List[int]) -> None: + """Release per-page data for a unit after it has been emitted. + + Frees recognition results and layout results so that memory is not + held for the lifetime of the entire process() call. + """ + with self._results_lock: + for pi in page_indices: + self._results_by_page.pop(pi, None) + for pi in page_indices: + self.layout_results_dict.pop(pi, None) + # ------------------------------------------------------------------ # Pre-cropped image store (for image-type regions) # ------------------------------------------------------------------ diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 8b80ec7..8887897 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -345,6 +345,8 @@ def _emit_results( if img is not None: vis_images[pi] = img + state.release_unit_data(page_indices) + pending[u] = PipelineResult( json_result=json_u, markdown_result=md_u, From 2ec08098877a0cb8edc9dce376a854668bc6d4cd Mon Sep 17 00:00:00 2001 From: xueyadong Date: Thu, 26 Mar 2026 08:27:27 +0000 Subject: [PATCH 33/38] Add post-processing configuration options to ResultFormatter for merging formulas, text blocks, and formatting bullet points --- glmocr/config.py | 3 +++ glmocr/config.yaml | 6 +++++- glmocr/postprocess/result_formatter.py | 12 +++++++++--- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/glmocr/config.py b/glmocr/config.py index e450645..95bf855 100644 --- a/glmocr/config.py +++ b/glmocr/config.py @@ -172,6 +172,9 @@ class ResultFormatterConfig(_BaseConfig): filter_nested: bool = True min_overlap_ratio: float = 0.8 output_format: str = "both" # json | markdown | both + enable_merge_formula_numbers: bool = True + enable_merge_text_blocks: bool = True + enable_format_bullet_points: bool = True label_visualization_mapping: Dict[str, Any] = Field(default_factory=dict) diff --git a/glmocr/config.yaml b/glmocr/config.yaml index 33c68f2..a34e5a3 100644 --- a/glmocr/config.yaml +++ b/glmocr/config.yaml @@ -106,7 +106,7 @@ pipeline: # Maximum parallel workers for region recognition # Lower values to reduce 503 errors on busy OCR servers - max_workers: 64 + max_workers: 32 # Queue sizes page_maxsize: 100 region_maxsize: 2000 @@ -149,6 +149,10 @@ pipeline: result_formatter: # Output format: json, markdown, or both output_format: both + # Post-process switches + enable_merge_formula_numbers: true + enable_merge_text_blocks: true + enable_format_bullet_points: true # Label to visualization category mapping (for layout visualization) label_visualization_mapping: diff --git a/glmocr/postprocess/result_formatter.py b/glmocr/postprocess/result_formatter.py index f0de515..6d0a759 100644 --- a/glmocr/postprocess/result_formatter.py +++ b/glmocr/postprocess/result_formatter.py @@ -63,6 +63,9 @@ def __init__(self, config: "ResultFormatterConfig"): # Output format self.output_format = config.output_format + self.enable_merge_formula_numbers = config.enable_merge_formula_numbers + self.enable_merge_text_blocks = config.enable_merge_text_blocks + self.enable_format_bullet_points = config.enable_format_bullet_points # ========================================================================= # OCR-only mode @@ -193,13 +196,16 @@ def process( json_page_results.append(result) # Merge formula with formula_number - json_page_results = self._merge_formula_numbers(json_page_results) + if self.enable_merge_formula_numbers: + json_page_results = self._merge_formula_numbers(json_page_results) # Merge hyphenated text blocks - json_page_results = self._merge_text_blocks(json_page_results) + if self.enable_merge_text_blocks: + json_page_results = self._merge_text_blocks(json_page_results) # Format bullet points - json_page_results = self._format_bullet_points(json_page_results) + if self.enable_format_bullet_points: + json_page_results = self._format_bullet_points(json_page_results) json_final_results.append(json_page_results) From 4252a3c458245e47765bca1ab2ad751f4343bca4 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Thu, 26 Mar 2026 09:57:49 +0000 Subject: [PATCH 34/38] Update default configuration parameters --- README.md | 8 ++++---- README_zh.md | 8 ++++---- glmocr/config.py | 18 ++++++------------ glmocr/config.yaml | 18 ++++++------------ glmocr/dataloader/page_loader.py | 23 ----------------------- glmocr/ocr_client.py | 4 ---- 6 files changed, 20 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index bef4c16..b248b5d 100644 --- a/README.md +++ b/README.md @@ -245,13 +245,13 @@ pipeline: api_host: localhost api_port: 8080 api_key: null # or set API_KEY env var - connect_timeout: 300 - request_timeout: 300 + connect_timeout: 30 + request_timeout: 120 # Page loader settings page_loader: - max_tokens: 16384 - temperature: 0.01 + max_tokens: 8192 + temperature: 0.0 image_format: JPEG min_pixels: 12544 max_pixels: 71372800 diff --git a/README_zh.md b/README_zh.md index 8ba00b3..488e350 100644 --- a/README_zh.md +++ b/README_zh.md @@ -246,13 +246,13 @@ pipeline: api_host: localhost api_port: 8080 api_key: null # or set API_KEY env var - connect_timeout: 300 - request_timeout: 300 + connect_timeout: 30 + request_timeout: 120 # Page loader settings page_loader: - max_tokens: 16384 - temperature: 0.01 + max_tokens: 8192 + temperature: 0.0 image_format: JPEG min_pixels: 12544 max_pixels: 71372800 diff --git a/glmocr/config.py b/glmocr/config.py index 95bf855..f6ebc68 100644 --- a/glmocr/config.py +++ b/glmocr/config.py @@ -74,10 +74,9 @@ class OCRApiConfig(_BaseConfig): api_scheme: Optional[str] = None api_path: str = "/v1/chat/completions" api_url: Optional[str] = None - model: Optional[str] = None # Optional model name (required by Ollama/MLX) api_key: Optional[str] = None - # Model name included in API requests. + # Model name included in API requests (required by Ollama/MLX). model: Optional[str] = None headers: Dict[str, str] = Field(default_factory=dict) verify_ssl: bool = False @@ -86,8 +85,8 @@ class OCRApiConfig(_BaseConfig): # Use "ollama_generate" for Ollama's native /api/generate endpoint api_mode: str = "openai" - connect_timeout: int = 300 - request_timeout: int = 300 + connect_timeout: int = 30 + request_timeout: int = 120 # Retry behavior (for transient upstream failures like 429/5xx) retry_max_attempts: int = 2 # total attempts = 1 + retry_max_attempts @@ -143,8 +142,8 @@ class MaaSApiConfig(_BaseConfig): class PageLoaderConfig(_BaseConfig): - max_tokens: int = 16384 - temperature: float = 0.01 + max_tokens: int = 8192 + temperature: float = 0.0 top_p: float = 0.00001 top_k: int = 1 repetition_penalty: float = 1.1 @@ -156,11 +155,6 @@ class PageLoaderConfig(_BaseConfig): min_pixels: int = 112 * 112 max_pixels: int = 14 * 14 * 4 * 1280 - default_prompt: str = ( - "Recognize the text in the image and output in Markdown format. " - "Preserve the original layout (headings/paragraphs/tables/formulas). " - "Do not fabricate content that does not exist in the image." - ) task_prompt_mapping: Optional[Dict[str, str]] = None pdf_dpi: int = 200 @@ -180,7 +174,7 @@ class ResultFormatterConfig(_BaseConfig): class LayoutConfig(_BaseConfig): model_dir: Optional[str] = None - threshold: float = 0.4 + threshold: float = 0.3 threshold_by_class: Optional[Dict[Union[int, str], float]] = None batch_size: int = 8 workers: int = 1 diff --git a/glmocr/config.yaml b/glmocr/config.yaml index a34e5a3..ec06e31 100644 --- a/glmocr/config.yaml +++ b/glmocr/config.yaml @@ -114,10 +114,10 @@ pipeline: # Page loader: handles image/PDF loading and API request building page_loader: # Generation parameters - max_tokens: 4096 - temperature: 0.8 - top_p: 0.9 - top_k: 50 + max_tokens: 8192 + temperature: 0.0 + top_p: 0.00001 + top_k: 1 repetition_penalty: 1.1 # Image processing @@ -128,12 +128,6 @@ pipeline: min_pixels: 12544 # 112 * 112 max_pixels: 71372800 # 14 * 14 * 4 * 1280 - # Default prompt for OCR (used when no custom prompt provided) - default_prompt: > - Recognize the text in the image and output in Markdown format. - Preserve the original layout (headings/paragraphs/tables/formulas). - Do not fabricate content that does not exist in the image. - # Task-specific prompts task_prompt_mapping: text: "Text Recognition:" @@ -183,7 +177,7 @@ pipeline: # PP-DocLayoutV3 model directory # Can be a local folder or a Hugging Face model id # (Use *_safetensors for Transformers; PaddlePaddle/PP-DocLayoutV3 is a PaddleOCR export) - model_dir: PaddlePaddle/PP-DocLayoutV3_safetensors + model_dir: /workspace/ocr_document_data_cloud/ckpt/opensource/PP-DocLayoutV3_safetensors # Detection threshold threshold: 0.3 @@ -197,7 +191,7 @@ pipeline: # batch_size: max images per model forward pass (reduce to 1 if OOM) batch_size: 1 workers: 1 - cuda_visible_devices: "0" + cuda_visible_devices: "5" # img_size: null # resize input (optional) # Use polygon masks for region cropping and visualization. diff --git a/glmocr/dataloader/page_loader.py b/glmocr/dataloader/page_loader.py index 80682d7..afc4b2e 100644 --- a/glmocr/dataloader/page_loader.py +++ b/glmocr/dataloader/page_loader.py @@ -82,9 +82,6 @@ def __init__(self, config: "PageLoaderConfig"): # Task prompt mapping self.task_prompt_mapping = config.task_prompt_mapping - # Default OCR instruction (used when user provides images without text) - self.default_prompt = config.default_prompt - # PDF-to-image parameters self.pdf_dpi = config.pdf_dpi self.pdf_max_pages = config.pdf_max_pages @@ -329,24 +326,6 @@ def build_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: if msg["role"] in ("system", "assistant", "tool"): processed_messages.append(msg) elif msg["role"] in ("user", "observation"): - # If user provides images but no text, inject the default OCR instruction - if isinstance(msg.get("content"), list): - has_image = any( - c.get("type") == "image_url" for c in msg["content"] - ) - has_text = any( - c.get("type") == "text" and str(c.get("text", "")).strip() - for c in msg["content"] - ) - if has_image and not has_text: - msg = { - **msg, - "content": [ - *msg["content"], - {"type": "text", "text": self.default_prompt}, - ], - } - processed_messages.append(self._process_msg_standard(msg)) else: raise ValueError(f"{msg['role']} is not a valid role for a message.") @@ -369,8 +348,6 @@ def build_request_from_image( prompt_text = "" if self.task_prompt_mapping: prompt_text = self.task_prompt_mapping.get(task_type, "") - if not str(prompt_text).strip(): - prompt_text = self.default_prompt encoded_image = load_image_to_base64( image, diff --git a/glmocr/ocr_client.py b/glmocr/ocr_client.py index ed3e32d..62742d4 100644 --- a/glmocr/ocr_client.py +++ b/glmocr/ocr_client.py @@ -276,10 +276,6 @@ def process(self, request_data: Dict) -> Tuple[Dict, int]: if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" - # Inject model if configured - if self.model and "model" not in request_data: - request_data["model"] = self.model - total_attempts = int(self.retry_max_attempts) + 1 last_error: Optional[str] = None From cf323152e26482609e2f5f4ed541039cdf0a4cf4 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Thu, 26 Mar 2026 12:12:32 +0000 Subject: [PATCH 35/38] Add preserve_order parameter to GlmOcr and Pipeline classes for consistent output order --- glmocr/api.py | 24 ++++++++++++++++++++++-- glmocr/cli.py | 1 + glmocr/pipeline/pipeline.py | 34 ++++++++++++++++++++++------------ 3 files changed, 45 insertions(+), 14 deletions(-) diff --git a/glmocr/api.py b/glmocr/api.py index 098baa4..292bb31 100644 --- a/glmocr/api.py +++ b/glmocr/api.py @@ -212,6 +212,7 @@ def parse( *, stream: bool = False, save_layout_visualization: bool = True, + preserve_order: bool = True, **kwargs: Any, ) -> Union[ PipelineResult, List[PipelineResult], Generator[PipelineResult, None, None] @@ -230,6 +231,7 @@ def parse( stream: If ``True``, yields one :class:`PipelineResult` at a time. save_layout_visualization: Whether to save layout visualization artifacts. + preserve_order: Whether to keep output order consistent with input order. **kwargs: Additional parameters for MaaS mode (return_crop_images, need_layout_visualization, start_page_id, end_page_id, etc.) @@ -250,12 +252,21 @@ def parse( images = [images] if stream: - return self._parse_stream(images, save_layout_visualization, **kwargs) + return self._parse_stream( + images, + save_layout_visualization, + preserve_order=preserve_order, + **kwargs, + ) if self._use_maas: result_list = self._parse_maas(images, save_layout_visualization, **kwargs) else: - result_list = self._parse_selfhosted(images, save_layout_visualization) + result_list = self._parse_selfhosted( + images, + save_layout_visualization, + preserve_order=preserve_order, + ) return result_list[0] if _single else result_list @@ -263,6 +274,7 @@ def _parse_stream( self, images: List[Union[str, bytes, Path]], save_layout_visualization: bool = True, + preserve_order: bool = True, **kwargs: Any, ) -> Generator[PipelineResult, None, None]: """Internal: yield one PipelineResult per input. Used by parse(stream=True).""" @@ -288,6 +300,7 @@ def _parse_stream( for result in self._stream_parse_selfhosted( images, save_layout_visualization=save_layout_visualization, + preserve_order=preserve_order, ): yield result @@ -469,6 +482,7 @@ def _parse_selfhosted( self, images: List[Union[str, bytes, Path]], save_layout_visualization: bool = True, + preserve_order: bool = True, ) -> List[PipelineResult]: """Parse using self-hosted vLLM/SGLang pipeline.""" request_data = self._build_selfhosted_request(images) @@ -476,6 +490,7 @@ def _parse_selfhosted( self._pipeline.process( request_data, save_layout_visualization=save_layout_visualization, + preserve_order=preserve_order, ) ) return results @@ -484,12 +499,14 @@ def _stream_parse_selfhosted( self, images: List[Union[str, bytes, Path]], save_layout_visualization: bool = True, + preserve_order: bool = True, ) -> Generator[PipelineResult, None, None]: """Streaming variant of self-hosted parse().""" request_data = self._build_selfhosted_request(images) for result in self._pipeline.process( request_data, save_layout_visualization=save_layout_visualization, + preserve_order=preserve_order, ): yield result @@ -605,6 +622,7 @@ def parse( save_layout_visualization: bool = True, *, stream: bool = False, + preserve_order: bool = True, api_key: Optional[str] = None, api_url: Optional[str] = None, model: Optional[str] = None, @@ -634,6 +652,7 @@ def parse( config_path: Config file path. save_layout_visualization: Whether to save layout visualization. stream: If ``True``, returns a generator. + preserve_order: Whether to keep output order consistent with input order. api_key: API key. api_url: MaaS API endpoint URL. model: Model name. @@ -654,5 +673,6 @@ def parse( images, stream=stream, save_layout_visualization=save_layout_visualization, + preserve_order=preserve_order, **kwargs, ) diff --git a/glmocr/cli.py b/glmocr/cli.py index a155c32..5cbcb88 100644 --- a/glmocr/cli.py +++ b/glmocr/cli.py @@ -218,6 +218,7 @@ def main(): image_paths, stream=True, save_layout_visualization=save_layout_vis, + preserve_order=False, ): file_name = ( Path(result.original_images[0]).name diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 8887897..4674739 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -101,6 +101,7 @@ def process( save_layout_visualization: bool = False, page_maxsize: Optional[int] = None, region_maxsize: Optional[int] = None, + preserve_order: bool = True, ) -> Generator[PipelineResult, None, None]: """Process a request; yield one ``PipelineResult`` per input unit. @@ -112,6 +113,7 @@ def process( save_layout_visualization: Generate layout visualisation images. page_maxsize: Bound for the page queue. region_maxsize: Bound for the region queue. + preserve_order: Whether to emit results in input order. Yields: One ``PipelineResult`` per input URL (image or PDF). @@ -162,7 +164,9 @@ def process( t_watchdog.start() try: - yield from self._emit_results(state, tracker, original_inputs) + yield from self._emit_results( + state, tracker, original_inputs, preserve_order=preserve_order + ) finally: state.request_shutdown() t1.join(timeout=10) @@ -291,12 +295,13 @@ def _emit_results( state: PipelineState, tracker: UnitTracker, original_inputs: List[str], + preserve_order: bool = True, ) -> Generator[PipelineResult, None, None]: - """Wait for units to complete and yield their formatted results - **in the original input order**. + """Wait for units to complete and yield formatted results. - Units may complete in arbitrary order; finished results are buffered - and yielded sequentially (unit 0 first, then 1, 2, …). + When ``preserve_order`` is True, units may complete in arbitrary order + but are buffered and yielded sequentially (unit 0, 1, 2, ...). + When ``preserve_order`` is False, each ready unit is yielded immediately. ``None`` from the ready queue signals a pipeline error (shutdown). """ @@ -305,12 +310,13 @@ def _emit_results( next_to_emit = 0 num_units = tracker.num_units - while next_to_emit < num_units: - while next_to_emit in pending: - yield pending.pop(next_to_emit) - next_to_emit += 1 - if next_to_emit >= num_units: - break + while (next_to_emit < num_units) if preserve_order else (len(built) < num_units): + if preserve_order: + while next_to_emit in pending: + yield pending.pop(next_to_emit) + next_to_emit += 1 + if next_to_emit >= num_units: + break u = tracker.wait_next_ready_unit() if u is None: @@ -347,7 +353,7 @@ def _emit_results( state.release_unit_data(page_indices) - pending[u] = PipelineResult( + result = PipelineResult( json_result=json_u, markdown_result=md_u, original_images=[original_inputs[u]], @@ -356,3 +362,7 @@ def _emit_results( layout_vis_images=vis_images or None, ) built.add(u) + if preserve_order: + pending[u] = result + else: + yield result From 52687c237032aab4ff99438811fbc05eb587c4a8 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Thu, 26 Mar 2026 12:51:07 +0000 Subject: [PATCH 36/38] Refactor code for improved readability and consistency across multiple files --- glmocr/api.py | 11 +-- glmocr/cli.py | 11 +-- glmocr/config.yaml | 2 +- glmocr/dataloader/page_loader.py | 8 +- glmocr/layout/layout_detector.py | 11 +-- glmocr/ocr_client.py | 4 +- glmocr/parser_result/pipeline_result.py | 4 +- glmocr/pipeline/_common.py | 6 +- glmocr/pipeline/_state.py | 12 +-- glmocr/pipeline/_workers.py | 109 +++++++++++++++-------- glmocr/pipeline/pipeline.py | 54 +++++++---- glmocr/postprocess/result_formatter.py | 20 +++-- glmocr/utils/image_utils.py | 9 +- glmocr/utils/markdown_utils.py | 14 ++- glmocr/utils/result_postprocess_utils.py | 7 +- 15 files changed, 173 insertions(+), 109 deletions(-) diff --git a/glmocr/api.py b/glmocr/api.py index 292bb31..0c316b8 100644 --- a/glmocr/api.py +++ b/glmocr/api.py @@ -442,7 +442,9 @@ def _maas_response_to_pipeline_result( ) json_result, markdown_result, image_files = resolve_image_regions( - json_result, markdown_result, source, + json_result, + markdown_result, + source, ) # Create PipelineResult @@ -462,15 +464,14 @@ def _maas_response_to_pipeline_result( return result def _build_selfhosted_request( - self, images: List[Union[str, bytes, Path]], + self, + images: List[Union[str, bytes, Path]], ) -> Dict[str, Any]: """Build request from mixed inputs (paths, URLs, or raw bytes).""" messages: List[Dict[str, Any]] = [{"role": "user", "content": []}] for image in images: if isinstance(image, bytes): - messages[0]["content"].append( - {"type": "image_bytes", "data": image} - ) + messages[0]["content"].append({"type": "image_bytes", "data": image}) else: url = self._to_url(image) messages[0]["content"].append( diff --git a/glmocr/cli.py b/glmocr/cli.py index 5cbcb88..2adce01 100644 --- a/glmocr/cli.py +++ b/glmocr/cli.py @@ -81,8 +81,7 @@ def _queue_stats_updater(glm_parser: GlmOcr, pbar: tqdm, stop: threading.Event): def _auto_coerce(raw: str): - """Coerce a CLI string to a Python scalar. - """ + """Coerce a CLI string to a Python scalar.""" if raw.lower() in ("true", "yes"): return True if raw.lower() in ("false", "no"): @@ -171,7 +170,7 @@ def main(): metavar=("KEY", "VALUE"), dest="config_overrides", help="Override a config value using dotted path, e.g. " - "--set pipeline.ocr_api.api_port 8080", + "--set pipeline.ocr_api.api_port 8080", ) args = parser.parse_args() @@ -191,7 +190,7 @@ def main(): # Build dotted-path overrides from --set KEY VALUE pairs dotted_overrides: dict = {} - for key, value in (args.config_overrides or []): + for key, value in args.config_overrides or []: dotted_overrides[key] = _auto_coerce(value) with GlmOcr(config_path=args.config, _dotted=dotted_overrides) as glm_parser: @@ -251,7 +250,9 @@ def main(): if not args.no_save: save_dir = args.output if input_root and result.original_images: - rel = Path(result.original_images[0]).parent.relative_to(input_root) + rel = Path( + result.original_images[0] + ).parent.relative_to(input_root) if str(rel) != ".": save_dir = str(Path(args.output) / rel) result.save( diff --git a/glmocr/config.yaml b/glmocr/config.yaml index ec06e31..9081d57 100644 --- a/glmocr/config.yaml +++ b/glmocr/config.yaml @@ -196,7 +196,7 @@ pipeline: # Use polygon masks for region cropping and visualization. # When true, regions are cropped using the polygon outline from layout - # detection (more precise, masks out content outside the polygon), + # detection (more precise, masks out content outside the polygon), # recommended for documents with rotating or staggered layouts. # When false, regions are cropped using the bounding box only (faster, simpler), # recommended for regular documents without rotating. diff --git a/glmocr/dataloader/page_loader.py b/glmocr/dataloader/page_loader.py index afc4b2e..f4745a5 100644 --- a/glmocr/dataloader/page_loader.py +++ b/glmocr/dataloader/page_loader.py @@ -91,7 +91,9 @@ def __init__(self, config: "PageLoaderConfig"): # Page loading # ========================================================================= - def load_pages(self, sources: Union[str, bytes, List[Union[str, bytes]]]) -> List[Image.Image]: + def load_pages( + self, sources: Union[str, bytes, List[Union[str, bytes]]] + ) -> List[Image.Image]: """Load sources into a list of PIL Images. Supports image files, PDFs, and raw bytes (PDFs are expanded into @@ -139,7 +141,9 @@ def load_pages_with_unit_indices( unit_indices.extend([unit_idx] * len(pages)) return all_pages, unit_indices - def iter_pages_with_unit_indices(self, sources: Union[str, bytes, List[Union[str, bytes]]]): + def iter_pages_with_unit_indices( + self, sources: Union[str, bytes, List[Union[str, bytes]]] + ): """Stream pages one at a time with unit index per page. Yields (page, unit_idx) so the pipeline can enqueue each page as soon diff --git a/glmocr/layout/layout_detector.py b/glmocr/layout/layout_detector.py index d0607b3..691cc03 100644 --- a/glmocr/layout/layout_detector.py +++ b/glmocr/layout/layout_detector.py @@ -3,8 +3,7 @@ from __future__ import annotations -from pathlib import Path -from typing import TYPE_CHECKING, List, Dict, Optional +from typing import TYPE_CHECKING, List, Dict import cv2 import torch @@ -78,7 +77,7 @@ def start(self): self._model = self._model.to(self._device) if self.id2label is None: self.id2label = self._model.config.id2label - + # Patch upstream _extract_polygon_points_by_masks to guard against # empty mask crops that crash cv2.resize with !ssize.empty(). def _safe_extract(boxes, masks, scale_ratio): @@ -90,8 +89,7 @@ def _safe_extract(boxes, masks, scale_ratio): x_min, y_min, x_max, y_max = boxes[i].astype(np.int32) box_w, box_h = x_max - x_min, y_max - y_min rect = np.array( - [[x_min, y_min], [x_max, y_min], - [x_max, y_max], [x_min, y_max]], + [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]], dtype=np.float32, ) @@ -286,8 +284,7 @@ def process( num_images = len(images) pil_images = [ - img.convert("RGB") if img.mode != "RGB" else img - for img in images + img.convert("RGB") if img.mode != "RGB" else img for img in images ] all_paddle_format_results = [] diff --git a/glmocr/ocr_client.py b/glmocr/ocr_client.py index 62742d4..b376773 100644 --- a/glmocr/ocr_client.py +++ b/glmocr/ocr_client.py @@ -326,7 +326,9 @@ def process(self, request_data: Dict) -> Tuple[Dict, int]: "error": f"Invalid OpenAI API response format: {str(e)}" }, 500 - return {"choices": [{"message": {"content": (output or "").strip()}}]}, 200 + return { + "choices": [{"message": {"content": (output or "").strip()}}] + }, 200 status = int(response.status_code) body_preview = (response.text or "")[:500] diff --git a/glmocr/parser_result/pipeline_result.py b/glmocr/parser_result/pipeline_result.py index 605c1d4..8f821f2 100644 --- a/glmocr/parser_result/pipeline_result.py +++ b/glmocr/parser_result/pipeline_result.py @@ -71,9 +71,7 @@ def save( stem = Path(self.original_images[0]).stem if self.original_images else "result" for local_idx, (_page_idx, vis_img) in enumerate(vis_items): name = ( - f"{stem}.jpg" - if len(vis_items) == 1 - else f"{stem}_page{local_idx}.jpg" + f"{stem}.jpg" if len(vis_items) == 1 else f"{stem}_page{local_idx}.jpg" ) try: vis_img.save(target_dir / name, quality=95) diff --git a/glmocr/pipeline/_common.py b/glmocr/pipeline/_common.py index 7155607..8eaa166 100644 --- a/glmocr/pipeline/_common.py +++ b/glmocr/pipeline/_common.py @@ -39,14 +39,12 @@ def make_original_inputs(sources: List[Union[str, bytes]]) -> List[str]: def extract_ocr_content(response: Dict[str, Any]) -> str: """Pull the content string out of an OpenAI-style OCR response.""" - return ( - response.get("choices", [{}])[0].get("message", {}).get("content", "") - ) + return response.get("choices", [{}])[0].get("message", {}).get("content", "") # ── Queue message "identifier" field values ────────────────────────── # Every queue message is a dict with an "identifier" key. IDENTIFIER_IMAGE = "image" -IDENTIFIER_UNIT_DONE = "unit_done" # t1 → t2: all pages for one input unit are queued +IDENTIFIER_UNIT_DONE = "unit_done" # t1 → t2: all pages for one input unit are queued IDENTIFIER_REGION = "region" IDENTIFIER_DONE = "done" diff --git a/glmocr/pipeline/_state.py b/glmocr/pipeline/_state.py index fbab39a..83bed3e 100644 --- a/glmocr/pipeline/_state.py +++ b/glmocr/pipeline/_state.py @@ -32,7 +32,9 @@ def __init__( ): # ── Inter-thread queues ────────────────────────────────────── self.page_queue: queue.Queue[Dict[str, Any]] = queue.Queue(maxsize=page_maxsize) - self.region_queue: queue.Queue[Dict[str, Any]] = queue.Queue(maxsize=region_maxsize) + self.region_queue: queue.Queue[Dict[str, Any]] = queue.Queue( + maxsize=region_maxsize + ) # ── Per-page data (stage 1 & 2 write, main thread reads) ───── self.images_dict: Dict[int, Any] = {} @@ -78,8 +80,9 @@ def request_shutdown(self) -> None: if tracker is not None: tracker.signal_shutdown() - def safe_put(self, q: queue.Queue, msg: Dict[str, Any], - timeout: float = 0.5) -> bool: + def safe_put( + self, q: queue.Queue, msg: Dict[str, Any], timeout: float = 0.5 + ) -> bool: """Put *msg* on *q*, returning ``False`` if shutdown was requested.""" while not self._shutdown_event.is_set(): try: @@ -124,8 +127,7 @@ def add_recognition_result(self, page_idx: int, region: Dict) -> None: tracker.on_region_done(page_idx) def get_grouped_results(self, page_indices: List[int]) -> List[List[Dict]]: - """Return recognition results grouped by page for the given indices. - """ + """Return recognition results grouped by page for the given indices.""" with self._results_lock: return [list(self._results_by_page.get(pi, [])) for pi in page_indices] diff --git a/glmocr/pipeline/_workers.py b/glmocr/pipeline/_workers.py index 6e0dbca..3eaf846 100644 --- a/glmocr/pipeline/_workers.py +++ b/glmocr/pipeline/_workers.py @@ -47,6 +47,7 @@ # Stage 1: Data Loading # ====================================================================== + def data_loading_worker( state: PipelineState, page_loader: "PageLoader", @@ -81,21 +82,27 @@ def data_loading_worker( break if prev_unit_idx is not None and unit_idx != prev_unit_idx: - if not state.safe_put(state.page_queue, { - "identifier": IDENTIFIER_UNIT_DONE, - "unit_idx": prev_unit_idx, - }): + if not state.safe_put( + state.page_queue, + { + "identifier": IDENTIFIER_UNIT_DONE, + "unit_idx": prev_unit_idx, + }, + ): break sent_unit_done.add(prev_unit_idx) state.register_page(page_idx, unit_idx) state.images_dict[page_idx] = page - if not state.safe_put(state.page_queue, { - "identifier": IDENTIFIER_IMAGE, - "page_idx": page_idx, - "unit_idx": unit_idx, - "image": page, - }): + if not state.safe_put( + state.page_queue, + { + "identifier": IDENTIFIER_IMAGE, + "page_idx": page_idx, + "unit_idx": unit_idx, + "image": page, + }, + ): break unit_indices_list.append(unit_idx) page_idx += 1 @@ -105,18 +112,24 @@ def data_loading_worker( if not state.is_shutdown: if prev_unit_idx is not None: - state.safe_put(state.page_queue, { - "identifier": IDENTIFIER_UNIT_DONE, - "unit_idx": prev_unit_idx, - }) + state.safe_put( + state.page_queue, + { + "identifier": IDENTIFIER_UNIT_DONE, + "unit_idx": prev_unit_idx, + }, + ) sent_unit_done.add(prev_unit_idx) for u in range(num_units): if u not in sent_unit_done: - state.safe_put(state.page_queue, { - "identifier": IDENTIFIER_UNIT_DONE, - "unit_idx": u, - }) + state.safe_put( + state.page_queue, + { + "identifier": IDENTIFIER_UNIT_DONE, + "unit_idx": u, + }, + ) state.safe_put(state.page_queue, {"identifier": IDENTIFIER_DONE}) except Exception as e: @@ -130,6 +143,7 @@ def data_loading_worker( # Stage 2: Layout Detection # ====================================================================== + def layout_worker( state: PipelineState, layout_detector: "BaseLayoutDetector", @@ -174,8 +188,12 @@ def layout_worker( if len(batch_images) >= layout_detector.batch_size: _flush_layout_batch( - state, layout_detector, batch_images, batch_page_indices, - save_visualization, global_start_idx, + state, + layout_detector, + batch_images, + batch_page_indices, + save_visualization, + global_start_idx, use_polygon=use_polygon, ) global_start_idx += len(batch_page_indices) @@ -187,8 +205,12 @@ def layout_worker( unit_idx = msg["unit_idx"] if batch_images: _flush_layout_batch( - state, layout_detector, batch_images, batch_page_indices, - save_visualization, global_start_idx, + state, + layout_detector, + batch_images, + batch_page_indices, + save_visualization, + global_start_idx, use_polygon=use_polygon, ) global_start_idx += len(batch_page_indices) @@ -198,20 +220,25 @@ def layout_worker( pages_for_unit = unit_page_indices.get(unit_idx, []) region_count = sum( - len(state.layout_results_dict.get(pi, [])) - for pi in pages_for_unit + len(state.layout_results_dict.get(pi, [])) for pi in pages_for_unit ) state.finalize_unit(unit_idx, region_count) logger.debug( "Unit %d finalised: %d pages, %d regions", - unit_idx, len(pages_for_unit), region_count, + unit_idx, + len(pages_for_unit), + region_count, ) elif identifier == IDENTIFIER_DONE: if batch_images: _flush_layout_batch( - state, layout_detector, batch_images, batch_page_indices, - save_visualization, global_start_idx, + state, + layout_detector, + batch_images, + batch_page_indices, + save_visualization, + global_start_idx, use_polygon=use_polygon, ) state.safe_put(state.region_queue, {"identifier": IDENTIFIER_DONE}) @@ -246,7 +273,8 @@ def _flush_layout_batch( except Exception as e: logger.warning( "Layout detection failed for pages %s, skipping batch: %s", - batch_page_indices, e, + batch_page_indices, + e, ) for page_idx in batch_page_indices: state.layout_results_dict[page_idx] = [] @@ -263,17 +291,22 @@ def _flush_layout_batch( except Exception as e: logger.warning( "Failed to crop region on page %d (bbox=%s), skipping: %s", - page_idx, region.get("bbox_2d"), e, + page_idx, + region.get("bbox_2d"), + e, ) region["content"] = "" state.add_recognition_result(page_idx, region) continue - if not state.safe_put(state.region_queue, { - "identifier": IDENTIFIER_REGION, - "page_idx": page_idx, - "cropped_image": cropped, - "region": region, - }): + if not state.safe_put( + state.region_queue, + { + "identifier": IDENTIFIER_REGION, + "page_idx": page_idx, + "cropped_image": cropped, + "region": region, + }, + ): return @@ -281,6 +314,7 @@ def _flush_layout_batch( # Stage 3: VLM Recognition # ====================================================================== + def recognition_worker( state: PipelineState, page_loader: "PageLoader", @@ -328,7 +362,8 @@ def recognition_worker( state.add_recognition_result(msg["page_idx"], msg["region"]) else: req = page_loader.build_request_from_image( - msg["cropped_image"], msg["region"]["task_type"], + msg["cropped_image"], + msg["region"]["task_type"], ) del msg["cropped_image"] future = executor.submit(ocr_client.process, req) @@ -359,6 +394,7 @@ def recognition_worker( # Recognition helpers # ------------------------------------------------------------------ + def _collect_done_futures( futures: Dict[Any, Dict[str, Any]], state: PipelineState, @@ -392,7 +428,6 @@ def _handle_future_result( state.add_recognition_result(page_idx, region) - def _wait_for_any(futures: Dict) -> None: done_list = [f for f in futures if f.done()] if not done_list: diff --git a/glmocr/pipeline/pipeline.py b/glmocr/pipeline/pipeline.py index 4674739..af70c21 100644 --- a/glmocr/pipeline/pipeline.py +++ b/glmocr/pipeline/pipeline.py @@ -25,9 +25,17 @@ from glmocr.postprocess import ResultFormatter from glmocr.utils.logging import get_logger -from glmocr.pipeline._common import extract_image_sources, extract_ocr_content, make_original_inputs +from glmocr.pipeline._common import ( + extract_image_sources, + extract_ocr_content, + make_original_inputs, +) from glmocr.pipeline._state import PipelineState -from glmocr.pipeline._workers import data_loading_worker, layout_worker, recognition_worker +from glmocr.pipeline._workers import ( + data_loading_worker, + layout_worker, + recognition_worker, +) from glmocr.pipeline._unit_tracker import UnitTracker if TYPE_CHECKING: @@ -71,7 +79,8 @@ def __init__( self.page_loader = PageLoader(config.page_loader) self.ocr_client = OCRClient(config.ocr_api) self.result_formatter = ( - result_formatter if result_formatter is not None + result_formatter + if result_formatter is not None else ResultFormatter(config.result_formatter) ) @@ -82,6 +91,7 @@ def __init__( if PPDocLayoutDetector is None: from glmocr.layout import _raise_layout_import_error + _raise_layout_import_error() self.layout_detector = PPDocLayoutDetector(config.layout) @@ -143,7 +153,12 @@ def process( ) t2 = threading.Thread( target=layout_worker, - args=(state, self.layout_detector, save_layout_visualization, self.config.layout.use_polygon), + args=( + state, + self.layout_detector, + save_layout_visualization, + self.config.layout.use_polygon, + ), daemon=True, ) t3 = threading.Thread( @@ -256,19 +271,19 @@ def _build_raw_json(grouped_results: List[List[Dict]]) -> list: """ raw = [] for page_results in grouped_results: - sorted_results = sorted( - page_results, key=lambda x: x.get("index", 0) + sorted_results = sorted(page_results, key=lambda x: x.get("index", 0)) + raw.append( + [ + { + "index": i, + "label": r.get("label", "text"), + "content": r.get("content", ""), + "bbox_2d": r.get("bbox_2d"), + "polygon": r.get("polygon"), + } + for i, r in enumerate(sorted_results) + ] ) - raw.append([ - { - "index": i, - "label": r.get("label", "text"), - "content": r.get("content", ""), - "bbox_2d": r.get("bbox_2d"), - "polygon": r.get("polygon"), - } - for i, r in enumerate(sorted_results) - ]) return raw def _process_passthrough( @@ -310,7 +325,9 @@ def _emit_results( next_to_emit = 0 num_units = tracker.num_units - while (next_to_emit < num_units) if preserve_order else (len(built) < num_units): + while ( + (next_to_emit < num_units) if preserve_order else (len(built) < num_units) + ): if preserve_order: while next_to_emit in pending: yield pending.pop(next_to_emit) @@ -342,7 +359,8 @@ def _emit_results( cropped_images = state.collect_cropped_images_for_unit(page_indices) raw_json = self._build_raw_json(grouped) json_u, md_u, image_files = self.result_formatter.process( - grouped, cropped_images=cropped_images or None, + grouped, + cropped_images=cropped_images or None, ) vis_images = {} diff --git a/glmocr/postprocess/result_formatter.py b/glmocr/postprocess/result_formatter.py index 6d0a759..1786320 100644 --- a/glmocr/postprocess/result_formatter.py +++ b/glmocr/postprocess/result_formatter.py @@ -184,7 +184,9 @@ def process( # Skip empty or failed content (after formatting) if result["label"] != "image": content = result.get("content") - if content is None or (isinstance(content, str) and content.strip() == ""): + if content is None or ( + isinstance(content, str) and content.strip() == "" + ): continue # Update index @@ -222,12 +224,12 @@ def process( bbox = result.get("bbox_2d", []) key = (page_idx, *bbox) if bbox else None img = ( - cropped_images.get(key) - if cropped_images and key - else None + cropped_images.get(key) if cropped_images and key else None ) if img is not None: - filename = f"{image_prefix}_page{page_idx}_idx{image_counter}.jpg" + filename = ( + f"{image_prefix}_page{page_idx}_idx{image_counter}.jpg" + ) rel_path = f"imgs/{filename}" image_files[filename] = img result["image_path"] = rel_path @@ -305,14 +307,14 @@ def _format_content(self, content: Any, label: str, native_label: str) -> str: # Formula formatting if label == "formula": if ( - content.startswith("$$") - or content.startswith("\\[") + content.startswith("$$") + or content.startswith("\\[") or content.startswith("\\(") ): content = content[2:].strip() if ( - content.endswith("$$") - or content.endswith("\\]") + content.endswith("$$") + or content.endswith("\\]") or content.endswith("\\)") ): content = content[:-2].strip() diff --git a/glmocr/utils/image_utils.py b/glmocr/utils/image_utils.py index 65b404a..502dddd 100644 --- a/glmocr/utils/image_utils.py +++ b/glmocr/utils/image_utils.py @@ -2,8 +2,9 @@ import io import cv2 -import base64 +import fitz import math +import base64 from io import BytesIO import numpy as np @@ -264,8 +265,6 @@ def image_tensor_to_base64(image_tensor, image_format): # PDF rendering via PyMuPDF (fitz) # ----------------------------------------------------------------------------- -import fitz - def _render_page_to_pil(page, dpi: int = 200, max_width_or_height: int = 3500): """Render a PDF page to PIL Image via PyMuPDF. @@ -379,7 +378,9 @@ def pdf_to_images_pil_iter( ) yield image except Exception as e: - logger.warning("Skipping page %d of '%s' (render failed): %s", i, label, e) + logger.warning( + "Skipping page %d of '%s' (render failed): %s", i, label, e + ) finally: if doc is not None: doc.close() diff --git a/glmocr/utils/markdown_utils.py b/glmocr/utils/markdown_utils.py index 9c36a1d..065bf35 100644 --- a/glmocr/utils/markdown_utils.py +++ b/glmocr/utils/markdown_utils.py @@ -38,8 +38,10 @@ def resolve_image_regions( """ has_images = any( r.get("label") == "image" - for page in json_result if isinstance(page, list) - for r in page if isinstance(r, dict) + for page in json_result + if isinstance(page, list) + for r in page + if isinstance(r, dict) ) if not has_images: return json_result, markdown_result, {} @@ -49,7 +51,9 @@ def resolve_image_regions( try: if path.suffix.lower() == ".pdf" and path.is_file(): loaded_images = pdf_to_images_pil( - str(path), dpi=200, max_width_or_height=3500, + str(path), + dpi=200, + max_width_or_height=3500, ) elif path.is_file(): img = Image.open(str(path)) @@ -98,7 +102,9 @@ def resolve_image_regions( except Exception as e: logger.warning( "Failed to crop image (page=%d, bbox=%s): %s", - page_idx, bbox, e, + page_idx, + bbox, + e, ) page_copy.append(region_copy) updated_json.append(page_copy) diff --git a/glmocr/utils/result_postprocess_utils.py b/glmocr/utils/result_postprocess_utils.py index c990aa0..2f57677 100644 --- a/glmocr/utils/result_postprocess_utils.py +++ b/glmocr/utils/result_postprocess_utils.py @@ -114,15 +114,14 @@ def clean_formula_number(number_content: str) -> str: number_clean = number_clean[1:-1] return number_clean + def normalize_inline_formula(content: str) -> str: """Normalize inline formula spacing. ``$ x $`` → ``$x$``, and ensure a space between surrounding text and the ``$...$`` delimiter. """ - INLINE_FORMULA_RE = re.compile( - r"(? str: return content parts.append(content[last_end:]) - return "".join(parts) \ No newline at end of file + return "".join(parts) From b9cc4cb31cf526823bf24fda2de495e0edd3dbd3 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Fri, 27 Mar 2026 09:09:30 +0000 Subject: [PATCH 37/38] Passed the pre-commit code check --- glmocr/cli.py | 4 ++-- glmocr/parser_result/pipeline_result.py | 6 +----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/glmocr/cli.py b/glmocr/cli.py index bf13369..da2846f 100644 --- a/glmocr/cli.py +++ b/glmocr/cli.py @@ -227,7 +227,7 @@ def main(): metavar=("KEY", "VALUE"), dest="config_overrides", help="Override a config value using dotted path, e.g. " - "--set pipeline.ocr_api.api_port 8080", + "--set pipeline.ocr_api.api_port 8080", ) args = parser.parse_args() @@ -246,7 +246,7 @@ def main(): save_layout_vis = not args.no_layout_vis dotted_overrides: dict = {} - for key, value in (args.config_overrides or []): + for key, value in args.config_overrides or []: dotted_overrides[key] = _auto_coerce(value) with GlmOcr( diff --git a/glmocr/parser_result/pipeline_result.py b/glmocr/parser_result/pipeline_result.py index 3a549ab..800084c 100644 --- a/glmocr/parser_result/pipeline_result.py +++ b/glmocr/parser_result/pipeline_result.py @@ -74,11 +74,7 @@ def save( ) single = len(self.layout_vis_images) == 1 for page_idx, vis_img in self.layout_vis_images.items(): - name = ( - f"{stem_name}.jpg" - if single - else f"{stem_name}_page{page_idx}.jpg" - ) + name = f"{stem_name}.jpg" if single else f"{stem_name}_page{page_idx}.jpg" try: vis_img.save(target_dir / name, quality=95) except Exception as e: From 77a743021a118bbad2c5074010614e1e531d7df5 Mon Sep 17 00:00:00 2001 From: xueyadong Date: Fri, 27 Mar 2026 09:45:47 +0000 Subject: [PATCH 38/38] Add preserve_order argument to stream parsing test --- glmocr/tests/test_unit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/glmocr/tests/test_unit.py b/glmocr/tests/test_unit.py index c544d26..62ba50d 100644 --- a/glmocr/tests/test_unit.py +++ b/glmocr/tests/test_unit.py @@ -1248,6 +1248,7 @@ def test_parse_stream_selfhosted_delegates(self): parser._stream_parse_selfhosted.assert_called_once_with( ["a.png", "b.png"], save_layout_visualization=False, + preserve_order=True, )