Skip to content

Commit 10d3fab

Browse files
authored
add model to self-hosted mode (#162)
Signed-off-by: JaredforReal <w13431838023@gmail.com>
1 parent 8d41946 commit 10d3fab

2 files changed

Lines changed: 27 additions & 1 deletion

File tree

glmocr/config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,6 @@ def from_env(
430430
_KW_MAP = {
431431
"api_key": "pipeline.maas.api_key",
432432
"api_url": "pipeline.maas.api_url",
433-
"model": "pipeline.maas.model",
434433
"mode": "pipeline.maas.enabled",
435434
"timeout": "pipeline.maas.request_timeout",
436435
"enable_layout": "pipeline.enable_layout",
@@ -442,6 +441,15 @@ def from_env(
442441
"cuda_visible_devices": "pipeline.layout.cuda_visible_devices",
443442
"layout_device": "pipeline.layout.device",
444443
}
444+
445+
# `model` is shared by both MaaS and self-hosted modes.
446+
# Keep MaaS behavior while also forwarding it to OCR API so that
447+
# `GlmOcr(mode="selfhosted", model="...")` works as expected.
448+
if "model" in overrides and overrides["model"] is not None:
449+
model_value = str(overrides["model"])
450+
_set_nested(data, "pipeline.maas.model", model_value)
451+
_set_nested(data, "pipeline.ocr_api.model", model_value)
452+
445453
for kw, dotted in _KW_MAP.items():
446454
if kw in overrides and overrides[kw] is not None:
447455
raw = overrides[kw]

glmocr/tests/test_unit.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,6 +1353,24 @@ def test_explicit_selfhosted_mode(self, monkeypatch):
13531353
assert parser._use_maas is False
13541354
parser.close()
13551355

1356+
def test_selfhosted_model_kwarg_is_forwarded_to_ocr_api(self, monkeypatch):
1357+
"""model=... should configure self-hosted OCR request model."""
1358+
from glmocr.config import _ENV_MAP, ENV_PREFIX
1359+
1360+
for suffix in _ENV_MAP:
1361+
monkeypatch.delenv(f"{ENV_PREFIX}{suffix}", raising=False)
1362+
monkeypatch.setattr("glmocr.config._find_dotenv", lambda: None)
1363+
1364+
with patch("glmocr.pipeline.Pipeline") as mock_pipeline:
1365+
mock_pipeline.return_value.start = MagicMock()
1366+
mock_pipeline.return_value.enable_layout = False
1367+
from glmocr.api import GlmOcr
1368+
1369+
parser = GlmOcr(mode="selfhosted", model="glm-ocr")
1370+
assert parser._use_maas is False
1371+
assert parser.config_model.pipeline.ocr_api.model == "glm-ocr"
1372+
parser.close()
1373+
13561374

13571375
class TestOCRClientOllamaConfig:
13581376
"""Tests for OCRClient initialization with Ollama api_mode."""

0 commit comments

Comments
 (0)