diff --git a/glmocr/config.py b/glmocr/config.py index 789e206..00211a9 100644 --- a/glmocr/config.py +++ b/glmocr/config.py @@ -198,6 +198,7 @@ class LayoutConfig(_BaseConfig): layout_merge_bboxes_mode: Union[str, Dict[int, str]] = "large" label_task_mapping: Optional[Dict[str, Any]] = None use_polygon: bool = False + id2label: Optional[Dict[Union[int, str], str]] = None @field_validator("device") @classmethod diff --git a/glmocr/layout/layout_detector.py b/glmocr/layout/layout_detector.py index c4f1680..f5e96ed 100644 --- a/glmocr/layout/layout_detector.py +++ b/glmocr/layout/layout_detector.py @@ -77,8 +77,12 @@ def start(self): self._device = "cpu" self._model = self._model.to(self._device) if self.id2label is None: - self.id2label = self._model.config.id2label - + self.id2label = getattr(self._model.config, "id2label", None) + if self.id2label is None: + raise RuntimeError( + "Missing id2label in both layout config and model config; " + "please set pipeline.layout.id2label." + ) # Patch upstream _extract_polygon_points_by_masks to guard against # empty mask crops that crash cv2.resize with !ssize.empty(). def _safe_extract(boxes, masks, scale_ratio): @@ -126,6 +130,11 @@ def _safe_extract(boxes, masks, scale_ratio): return polygon_points self._image_processor._extract_polygon_points_by_masks = _safe_extract + if self.label_task_mapping is None: + logger.warning( + "layout.label_task_mapping is missing; defaulting all labels to text" + ) + self.label_task_mapping = {"text": list(self.id2label.values())} logger.debug(f"PP-DocLayoutV3 loaded on device: {self._device}") def stop(self): diff --git a/glmocr/tests/test_unit.py b/glmocr/tests/test_unit.py index 62ba50d..6a01f93 100644 --- a/glmocr/tests/test_unit.py +++ b/glmocr/tests/test_unit.py @@ -208,6 +208,66 @@ def test_detector_device_selection_auto_cuda(self): assert det._device == "cuda:1" + def test_detector_defaults_label_task_mapping_from_model_id2label(self): + """Missing label_task_mapping falls back to a text bucket from id2label.""" + self._require_layout_runtime() + from glmocr.config import LayoutConfig + from glmocr.layout.layout_detector import PPDocLayoutDetector + + cfg = LayoutConfig(model_dir="dummy", device="cpu") + det = PPDocLayoutDetector(cfg) + + mock_model = MagicMock() + mock_model.to = MagicMock(return_value=mock_model) + mock_model.eval = MagicMock() + mock_model.config = MagicMock() + mock_model.config.id2label = {0: "text", 1: "title"} + mock_proc = MagicMock() + + with ( + patch( + "glmocr.layout.layout_detector.PPDocLayoutV3ForObjectDetection.from_pretrained", + return_value=mock_model, + ), + patch( + "glmocr.layout.layout_detector.PPDocLayoutV3ImageProcessorFast.from_pretrained", + return_value=mock_proc, + ), + ): + det.start() + + assert det.id2label == {0: "text", 1: "title"} + assert det.label_task_mapping == {"text": ["text", "title"]} + + def test_detector_raises_when_id2label_missing_everywhere(self): + """Missing id2label in both config and model config raises a clear error.""" + self._require_layout_runtime() + from glmocr.config import LayoutConfig + from glmocr.layout.layout_detector import PPDocLayoutDetector + + cfg = LayoutConfig(model_dir="dummy", device="cpu") + det = PPDocLayoutDetector(cfg) + + mock_model = MagicMock() + mock_model.to = MagicMock(return_value=mock_model) + mock_model.eval = MagicMock() + mock_model.config = MagicMock() + mock_model.config.id2label = None + mock_proc = MagicMock() + + with ( + patch( + "glmocr.layout.layout_detector.PPDocLayoutV3ForObjectDetection.from_pretrained", + return_value=mock_model, + ), + patch( + "glmocr.layout.layout_detector.PPDocLayoutV3ImageProcessorFast.from_pretrained", + return_value=mock_proc, + ), + pytest.raises(RuntimeError, match="Missing id2label"), + ): + det.start() + class TestPageLoader: """Tests for PageLoader.""" diff --git a/pyproject.toml b/pyproject.toml index 9c37852..63e5dc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ layout = [ "transformers>=5.3.0", "sentencepiece>=0.2.0", "accelerate>=1.13.0", + "opencv-python>=4.8.0", ] server = [ @@ -55,6 +56,7 @@ selfhosted = [ "transformers>=5.3.0", "sentencepiece>=0.2.0", "accelerate>=1.13.0", + "opencv-python>=4.8.0", "pypdfium2>=5.6.0", ] @@ -64,6 +66,7 @@ all = [ "transformers>=5.3.0", "sentencepiece>=0.2.0", "accelerate>=1.13.0", + "opencv-python>=4.8.0", "pypdfium2>=5.6.0", "flask>=3.1.0", ]