# Default extraction prompt used by GLM-OCR information extraction mode.
_DEFAULT_EXTRACTION_PROMPT = "请按下列JSON格式输出图中信息:"


def _json_schema_to_template(schema: Dict[str, Any]) -> Any:
    """Convert a JSON Schema dict to an empty-value template for GLM-OCR.

    Handles ``$defs``/``definitions``, ``$ref``, ``allOf``/``anyOf``/``oneOf``,
    nested objects, and arrays.  All leaf values become ``""``.

    Args:
        schema: A JSON Schema document (already decoded into a ``dict``).

    Returns:
        A nested structure of dicts / single-element lists whose leaves are
        empty strings.

    Notes:
        * A ``$ref`` that points back to a definition currently being
          expanded (self-referential schema) is emitted as ``""`` instead of
          recursing forever.
        * ``"type"`` may be a list such as ``["object", "null"]``; the first
          non-null entry is used.
        * A node without ``"type"`` is treated as an object when it has
          ``"properties"`` and as an array when it has ``"items"``.
        * For ``anyOf``/``oneOf`` the first non-null alternative is used;
          for ``allOf`` the object templates of all members are merged.
    """
    defs = schema.get("$defs", schema.get("definitions", {}))
    if not isinstance(defs, dict):
        defs = {}

    def _is_null(node: Any) -> bool:
        return isinstance(node, dict) and node.get("type") == "null"

    def _resolve(node: Any, active: frozenset) -> Any:
        # ``active`` holds the definition names on the current expansion
        # path; it is what breaks cycles in recursive schemas.
        if not isinstance(node, dict):
            return ""
        if "$ref" in node:
            ref_name = str(node["$ref"]).rsplit("/", 1)[-1]
            if ref_name in defs and ref_name not in active:
                return _resolve(defs[ref_name], active | {ref_name})
            return ""
        return _convert(node, active)

    def _convert(node: Dict[str, Any], active: frozenset) -> Any:
        for key in ("allOf", "anyOf", "oneOf"):
            options = node.get(key)
            if not isinstance(options, list) or not options:
                continue
            if key == "allOf":
                # Every member applies, so object templates are combined.
                parts = [_resolve(option, active) for option in options]
                merged: Dict[str, Any] = {}
                for part in parts:
                    if isinstance(part, dict):
                        merged.update(part)
                return merged if merged else parts[0]
            # Prefer the first alternative that is not the bare null type.
            chosen = next(
                (option for option in options if not _is_null(option)),
                options[0],
            )
            return _resolve(chosen, active)

        schema_type = node.get("type")
        if isinstance(schema_type, list):
            schema_type = next((t for t in schema_type if t != "null"), None)
        if schema_type is None:
            if "properties" in node:
                schema_type = "object"
            elif "items" in node:
                schema_type = "array"

        if schema_type == "object":
            props = node.get("properties", {})
            if not isinstance(props, dict):
                props = {}
            return {name: _resolve(sub, active) for name, sub in props.items()}
        if schema_type == "array":
            items = node.get("items", {})
            if isinstance(items, list):
                # Tuple-style ``items``: use the first entry as the element
                # shape.
                items = items[0] if items else {}
            return [_resolve(items, active)]
        return ""

    return _resolve(schema, frozenset())


def _resolve_schema_template(schema: Any) -> Dict[str, Any]:
    """Normalise *schema* into the empty-value JSON template GLM-OCR expects.

    Accepted inputs:

    * **Raw template dict** - any dict that is not recognised as a JSON
      Schema is returned unchanged (the format shown in the GLM-OCR docs).
    * **JSON Schema dict** - either ``"type": "object"`` together with
      ``"properties"``, or a top-level local ``"$ref"`` (a string starting
      with ``#``); converted with :func:`_json_schema_to_template`.
    * **Class exposing ``model_json_schema()``** (e.g. a Pydantic v2 model
      class) - its schema is generated and then converted.

    Raises:
        TypeError: If *schema* is neither a dict nor such a class.
    """
    # Pydantic v2 model class
    if isinstance(schema, type) and hasattr(schema, "model_json_schema"):
        return _json_schema_to_template(schema.model_json_schema())

    if not isinstance(schema, dict):
        raise TypeError(
            f"schema must be a dict or Pydantic model class, got {type(schema)}"
        )

    is_object_schema = schema.get("type") == "object" and "properties" in schema
    ref = schema.get("$ref")
    is_ref_schema = isinstance(ref, str) and ref.startswith("#")
    if is_object_schema or is_ref_schema:
        return _json_schema_to_template(schema)

    # Already a raw template dict
    return schema


def _parse_json_from_text(text: str) -> Any:
    """Best-effort extraction of a JSON value from *text*.

    Strategies, in order:

    1. ``json.loads`` on the stripped text.
    2. Each Markdown fenced code block (```` ``` ```` or ```` ```json ````).
    3. The first ``{`` or ``[`` position from which a complete JSON value
       can be decoded (JSON embedded in surrounding prose).

    Args:
        text: Model output.  ``bytes``/``bytearray`` are decoded as UTF-8.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If *text* is not textual or no strategy yields valid JSON.
    """
    if isinstance(text, (bytes, bytearray)):
        text = bytes(text).decode("utf-8", errors="replace")
    if not isinstance(text, str):
        # Keep the documented contract: callers only catch ValueError.
        raise ValueError(
            f"Failed to parse extraction response as JSON: {text!r}"
        )

    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Try Markdown ```json ... ``` blocks
    for fenced in re.finditer(
        r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL
    ):
        try:
            return json.loads(fenced.group(1).strip())
        except json.JSONDecodeError:
            continue

    # Fall back to a JSON object/array embedded in free text.
    decoder = json.JSONDecoder()
    for opener in re.finditer(r"[{\[]", text):
        try:
            value, _end = decoder.raw_decode(text, opener.start())
        except json.JSONDecodeError:
            continue
        return value

    raise ValueError(f"Failed to parse extraction response as JSON: {text[:500]}")
+ - ``bytes``: raw file content (image or PDF). + Useful for multipart/form-data uploads. + - ``Path``: a ``pathlib.Path`` to a local file. + stream: If ``True``, yields one :class:`PipelineResult` at a time (avoids holding all results in memory). If ``False``, returns a single result or a list, depending on *images*. @@ -210,23 +312,31 @@ def parse( Returns: - When ``stream=False`` (default): a single ``PipelineResult`` if *images* - is a ``str``, or a ``List[PipelineResult]`` if *images* is a list. + is a single input, or a ``List[PipelineResult]`` if *images* is a list. - When ``stream=True``: a generator that yields one ``PipelineResult`` per input. Example: # Single file — returns one PipelineResult result = parser.parse("image.png") - result.save(output_dir="./output") - # Multiple files — returns a list - results = parser.parse(["img1.png", "doc.pdf"]) + # Path object + result = parser.parse(Path("document.pdf")) + + # Presigned URL + result = parser.parse("https://bucket.s3.amazonaws.com/doc.pdf?X-Amz-...") + + # Raw bytes (e.g. 
def _resolve_inputs(
    self, images: List[Union[str, bytes, Path]]
) -> tuple:
    """Convert bytes/Path inputs to file path strings.

    Bytes-like inputs (``bytes``, ``bytearray``, ``memoryview``) are written
    to ``input_<idx><suffix>`` inside a freshly created temp directory, where
    the suffix comes from ``self._guess_suffix``.  ``Path`` inputs become
    absolute path strings; everything else is passed through ``str()``.

    Args:
        images: The inputs to normalise, in order.

    Returns:
        (resolved_paths, temp_dir) — *temp_dir* is ``None`` when no temp
        files were created; otherwise the caller must clean it up.

    Raises:
        Any error raised while writing a temp file is propagated; the temp
        directory is removed first because the caller never receives its
        path in that case.
    """
    resolved: List[str] = []
    temp_dir: Optional[str] = None

    try:
        for idx, img in enumerate(images):
            if isinstance(img, (bytes, bytearray, memoryview)):
                payload = bytes(img)
                if temp_dir is None:
                    temp_dir = tempfile.mkdtemp(prefix="glmocr_upload_")
                suffix = self._guess_suffix(payload)
                path = os.path.join(temp_dir, f"input_{idx}{suffix}")
                with open(path, "wb") as f:
                    f.write(payload)
                resolved.append(path)
            elif isinstance(img, Path):
                resolved.append(str(img.absolute()))
            else:
                resolved.append(str(img))
    except BaseException:
        # Do not leak the upload directory when resolution fails part-way.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        raise

    return resolved, temp_dir
def extract(
    self,
    images: Union["GlmOcr.InputSource", List["GlmOcr.InputSource"]],
    *,
    schema: Union[Dict[str, Any], type],
    prompt: Optional[str] = None,
    **kwargs: Any,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Extract structured data from documents according to *schema*.

    Uses GLM-OCR's information extraction mode: the model receives the
    document image together with a JSON template and returns a populated
    version of that template.

    Args:
        images: One input or a list of inputs (``str``, ``bytes`` or
            ``Path``).  A leading ``file://`` on a string is stripped
            before the input is handed to the MaaS client.
            NOTE(review): whether the MaaS client accepts ``bytes`` /
            ``Path`` directly is not visible here — confirm.
        schema: Describes the desired output structure.  Accepts:

            - A **dict with empty values** (GLM-OCR native template)::

                {"invoice_no": "", "total": "", "items": [{"desc": "", "qty": ""}]}

            - A **JSON Schema dict**::

                {"type": "object", "properties": {"invoice_no": {"type": "string"}}}

            - A **Pydantic model class** (anything exposing
              ``model_json_schema()``).

        prompt: Custom prompt prefix.  Defaults to
            ``_DEFAULT_EXTRACTION_PROMPT``.
        **kwargs: Extra parameters forwarded to ``self._maas_client.parse``.

    Returns:
        A single parsed JSON value when *images* is a single input, or a
        list with one entry per input when *images* is a list.

    Raises:
        TypeError: If *schema* is not a supported type.
        ValueError: If the model response cannot be parsed as JSON.
        RuntimeError: If used in self-hosted mode (not yet supported).

    Example::

        data = parser.extract("id_card.png", schema={
            "id_number": "",
            "name": "",
            "date_of_birth": "",
        })
    """
    template = _resolve_schema_template(schema)
    prefix = prompt or _DEFAULT_EXTRACTION_PROMPT
    full_prompt = f"{prefix}\n{json.dumps(template, ensure_ascii=False, indent=4)}"

    _single = isinstance(images, (str, bytes, Path))
    if _single:
        images = [images]

    if not self._use_maas:
        raise RuntimeError(
            "extract() currently requires MaaS mode. "
            "Initialize with mode='maas' or set maas.enabled=true."
        )

    results: List[Dict[str, Any]] = []
    for image in images:
        img = image
        if isinstance(img, str) and img.startswith("file://"):
            img = img[7:]

        response = self._maas_client.parse(img, prompt=full_prompt, **kwargs)

        # Prefer ``md_results``; fall back to the chat-completion layout
        # (``choices[0].message.content``), the same fallback the server's
        # MaaS extraction handler applies.
        text = response.get("md_results") or ""
        if not text:
            choices = response.get("choices") or [{}]
            message = choices[0].get("message") or {}
            text = message.get("content") or ""

        results.append(_parse_json_from_text(text))

    return results[0] if _single else results
def _load_url_source(self, url: str) -> List[Image.Image]:
    """Download from HTTP(S) URL and load as image or PDF pages.

    The payload is fetched once with ``self._download_url`` and classified
    with ``self._is_pdf_bytes`` (no Content-Type is available here, so the
    decision rests on the downloaded bytes).  PDF payloads are spilled to a
    temporary ``.pdf`` file, loaded through ``self._load_pdf`` and the file
    is removed afterwards; anything else is opened as a single image.

    Args:
        url: The ``http://`` or ``https://`` URL to fetch.

    Returns:
        The result of ``self._load_pdf`` for PDFs, otherwise a one-element
        list holding the opened image.
    """
    import tempfile as _tempfile

    data = self._download_url(url)

    if not self._is_pdf_bytes(data):
        return [Image.open(BytesIO(data))]

    # delete=False: the file must outlive the ``with`` so it can be reopened
    # by path; it is unlinked explicitly below.
    with _tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
        f.write(data)
        temp_path = f.name
    try:
        return self._load_pdf(temp_path)
    finally:
        try:
            os.unlink(temp_path)
        except OSError:
            # Best-effort cleanup; a leftover temp file must not mask the
            # load result or the original error.
            pass
_FLASK_IMPORT_ERROR = e +from glmocr.api import _resolve_schema_template, _parse_json_from_text, _DEFAULT_EXTRACTION_PROMPT from glmocr.pipeline import Pipeline from glmocr.config import load_config from glmocr.utils.logging import get_logger, configure_logging @@ -53,31 +56,89 @@ def create_app(config: "GlmOcrConfig") -> Flask: app.config["pipeline"] = pipeline app.config["doc_config"] = config + def _build_messages(image_urls: List[str]) -> dict: + """Build pipeline request_data from a list of image URL strings.""" + messages = [{"role": "user", "content": []}] + for image_url in image_urls: + messages[0]["content"].append( + {"type": "image_url", "image_url": {"url": image_url}} + ) + return {"messages": messages} + + def _format_results(results): + """Format pipeline results into a JSON response tuple.""" + if not results: + return jsonify({"json_result": None, "markdown_result": ""}), 200 + if len(results) == 1: + r = results[0] + return ( + jsonify( + { + "json_result": r.json_result, + "markdown_result": r.markdown_result or "", + } + ), + 200, + ) + json_result = [r.json_result for r in results] + markdown_result = "\n\n---\n\n".join( + r.markdown_result or "" for r in results + ) + return ( + jsonify( + { + "json_result": json_result, + "markdown_result": markdown_result, + } + ), + 200, + ) + @app.route("/glmocr/parse", methods=["POST"]) def parse(): """Document parsing endpoint. - Request: + Accepts two content types: + + **application/json**:: + { - "images": ["url1", "url2", ...], # image URLs (http/https/file/data) + "images": ["url1", "url2", ...], # URLs or presigned URLs } - Response: + **multipart/form-data**:: + + files: one or more file uploads (field name ``files``) + urls: one or more URL strings (field name ``urls``) + + Response:: + { "json_result": {...}, "markdown_result": "..." 
} """ - # Validate Content-Type - if request.headers.get("Content-Type") != "application/json": + content_type = (request.content_type or "").split(";")[0].strip().lower() + + if content_type == "multipart/form-data": + return _handle_multipart(pipeline) + elif content_type == "application/json": + return _handle_json(pipeline) + else: return ( jsonify( - {"error": "Invalid Content-Type. Expected 'application/json'."} + { + "error": ( + "Unsupported Content-Type. " + "Expected 'application/json' or 'multipart/form-data'." + ) + } ), 400, ) - # Parse JSON payload + def _handle_json(pipeline): + """Handle application/json requests.""" try: data = request.json except Exception: @@ -90,17 +151,58 @@ def parse(): if not images: return jsonify({"error": "No images provided"}), 400 - # Build pipeline request - messages = [{"role": "user", "content": []}] - for image_url in images: - messages[0]["content"].append( - {"type": "image_url", "image_url": {"url": image_url}} + request_data = _build_messages(images) + + try: + results = list( + pipeline.process( + request_data, + save_layout_visualization=False, + layout_vis_output_dir=None, + ) ) + return _format_results(results) + except Exception as e: + logger.error("Parse error: %s", e) + logger.debug(traceback.format_exc()) + return jsonify({"error": f"Parse error: {str(e)}"}), 500 - request_data = {"messages": messages} + def _handle_multipart(pipeline): + """Handle multipart/form-data requests (file uploads + URLs).""" + from pathlib import Path as _Path + uploaded_files = request.files.getlist("files") + url_values = request.form.getlist("urls") + + if not uploaded_files and not url_values: + return jsonify({"error": "No files or urls provided"}), 400 + + temp_dir = None try: - # Pipeline.process() yields one result per input unit; merge for single response + image_paths: List[str] = [] + + # Save uploaded files to a temp directory + if uploaded_files: + temp_dir = tempfile.mkdtemp(prefix="glmocr_upload_") + for idx, f 
in enumerate(uploaded_files): + filename = f.filename or f"upload_{idx}" + # Sanitise: keep only the basename to prevent path traversal + safe_name = _Path(filename).name or f"upload_{idx}" + save_path = os.path.join(temp_dir, f"{idx}_{safe_name}") + f.save(save_path) + image_paths.append(save_path) + + # Append any URL strings (presigned URLs, etc.) + for url in url_values: + url = url.strip() + if url: + image_paths.append(url) + + if not image_paths: + return jsonify({"error": "No valid files or urls provided"}), 400 + + request_data = _build_messages(image_paths) + results = list( pipeline.process( request_data, @@ -108,41 +210,486 @@ def parse(): layout_vis_output_dir=None, ) ) - if not results: - return ( - jsonify({"json_result": None, "markdown_result": ""}), - 200, + return _format_results(results) + + except Exception as e: + logger.error("Parse error: %s", e) + logger.debug(traceback.format_exc()) + return jsonify({"error": f"Parse error: {str(e)}"}), 500 + finally: + if temp_dir: + shutil.rmtree(temp_dir, ignore_errors=True) + + @app.route("/glmocr/extract", methods=["POST"]) + def extract(): + """Structured information extraction endpoint. + + Accepts a document image and a JSON schema, then returns structured + data matching the schema. The schema can be: + + - An **empty-value template** (GLM-OCR native format) + - A **JSON Schema** (e.g. from Zod's ``zodToJsonSchema()``) + + **application/json**:: + + { + "images": ["url1"], + "schema": {"invoice_no": "", "total": ""}, + "prompt": "..." // optional + } + + **multipart/form-data**:: + + files: file uploads (field name ``files``) + urls: URL strings (field name ``urls``) + schema: JSON string (field name ``schema``, required) + prompt: string (field name ``prompt``, optional) + + Response:: + + { + "data": { ... 
def _handle_extract_json(pipeline):
    """Handle JSON extraction requests.

    Reads ``images`` (string or list of strings), ``schema`` and the
    optional ``prompt`` from the JSON body and dispatches to:

    * ``_handle_schemaless_extract`` when no schema is supplied,
    * ``_handle_extract_maas`` when ``doc_config.pipeline.maas.enabled`` is
      set,
    * ``_handle_extract_selfhosted`` otherwise.

    Returns:
        A ``(response, status)`` tuple.  Malformed payloads (invalid JSON,
        a body that is not a JSON object, missing or non-string images, an
        unusable schema) yield HTTP 400.
    """
    try:
        data = request.json
    except Exception:
        return jsonify({"error": "Invalid JSON payload"}), 400

    # ``null``, arrays and scalars are valid JSON but not a valid request.
    if not isinstance(data, dict):
        return jsonify({"error": "Invalid JSON payload"}), 400

    images = data.get("images", [])
    if isinstance(images, str):
        images = [images]
    schema_raw = data.get("schema")
    prompt_override = data.get("prompt")

    if not images:
        return jsonify({"error": "No images provided"}), 400
    if not isinstance(images, list) or not all(
        isinstance(image, str) for image in images
    ):
        return (
            jsonify({"error": "'images' must be a string or a list of strings"}),
            400,
        )

    if not schema_raw:
        # No schema: parse first, then convert markdown to JSON
        extraction_prompt = prompt_override or _SCHEMALESS_EXTRACTION_PROMPT
        return _handle_schemaless_extract(pipeline, images, extraction_prompt)

    try:
        extraction_prompt = _build_extraction_prompt(schema_raw, prompt_override)
    except (TypeError, ValueError) as e:
        return jsonify({"error": f"Invalid schema: {e}"}), 400

    # Check if this server is backed by a MaaS-enabled GlmOcr
    maas_config = app.config["doc_config"].pipeline.maas
    if maas_config.enabled:
        return _handle_extract_maas(images, extraction_prompt)

    # Self-hosted: inject extraction prompt into pipeline request
    return _handle_extract_selfhosted(pipeline, images, extraction_prompt)
def _is_pdf_source(image_url):
    """Check if a source is a PDF (by extension, media type or magic bytes).

    Rules, by kind of source:

    * ``data:`` URLs - PDF only when the media type is ``application/pdf``.
    * ``http(s)://`` URLs - decided by the URL *path*, so presigned URLs
      such as ``.../doc.pdf?X-Amz-Signature=...`` are recognised and a
      ``.pdf`` that only appears in the query string is not.
    * Local paths (optionally prefixed with ``file://``) - ``.pdf``
      extension, otherwise the first five bytes of an existing file are
      compared with ``%PDF-``.

    Args:
        image_url: Source string taken from the request.

    Returns:
        ``True`` if the source looks like a PDF, ``False`` otherwise
        (including for non-string input and unreadable files).
    """
    from urllib.parse import urlsplit

    if not isinstance(image_url, str):
        return False

    lowered = image_url.lower()
    if lowered.startswith("data:"):
        return lowered.startswith("data:application/pdf")
    if lowered.startswith(("http://", "https://")):
        return urlsplit(lowered).path.endswith(".pdf")

    path = image_url[7:] if lowered.startswith("file://") else image_url
    if path.lower().endswith(".pdf"):
        return True

    # No telling extension: sniff the file header if the file is readable.
    try:
        if os.path.isfile(path):
            with open(path, "rb") as f:
                return f.read(5) == b"%PDF-"
    except (OSError, ValueError):
        pass
    return False
def _extract_from_image(pipeline, image_url, extraction_prompt):
    """Extract structured data from a single image.

    Loads the pages for *image_url*, builds a request for the first page
    with ``task_type="text"``, overwrites every text item of the user
    message with *extraction_prompt*, sends the request through
    ``pipeline.ocr_client.process`` and parses the reply as JSON.

    Args:
        pipeline: Object exposing ``page_loader`` and ``ocr_client``.
        image_url: Source of the image to process.
        extraction_prompt: Full prompt (prefix plus JSON template).

    Returns:
        The parsed JSON value, or ``None`` when no page could be loaded,
        the OCR call did not return status 200, or the reply holds no
        parseable JSON.
    """
    pages = pipeline.page_loader.load_pages([image_url])
    if not pages:
        return None

    req = pipeline.page_loader.build_request_from_image(
        pages[0], task_type="text"
    )
    # Replace the default task prompt with our extraction prompt
    for msg in req.get("messages", []):
        if msg.get("role") != "user" or not isinstance(msg.get("content"), list):
            continue
        for item in msg["content"]:
            if item.get("type") == "text":
                item["text"] = extraction_prompt

    response, status_code = pipeline.ocr_client.process(req)
    if status_code != 200:
        logger.error("OCR request failed (%s): %s", status_code, response)
        return None

    # Tolerate an empty/null ``choices`` list and null message/content so a
    # malformed reply yields ``None`` for this image instead of an exception
    # that fails the whole request.
    choices = response.get("choices") or [{}]
    message = choices[0].get("message") or {}
    content = message.get("content") or response.get("response") or ""

    logger.debug(
        "Self-hosted extract raw content (truncated): %s",
        content[:500],
    )
    try:
        return _parse_json_from_text(content)
    except ValueError as exc:
        logger.warning("Extract JSON parse failed: %s", exc)
        return None
+ """ + # Phase 1: standard parse to get markdown + request_data = _build_messages([image_url]) + try: + results = list( + pipeline.process( + request_data, + save_layout_visualization=False, + layout_vis_output_dir=None, + ) + ) + except Exception as e: + logger.error("PDF parse failed: %s", e) + return None + + if not results: + return None + + # Combine markdown from all results (usually one per PDF) + full_markdown = "\n\n---\n\n".join( + r.markdown_result or "" for r in results + ) + + if not full_markdown.strip(): + logger.warning("PDF parse produced empty markdown") + return None + + logger.debug( + "PDF parse markdown (truncated): %s", full_markdown[:500] + ) + + # Phase 2: send markdown + extraction prompt to VLM + combined_prompt = ( + f"以下是文档内容:\n\n{full_markdown}\n\n{extraction_prompt}" + ) + req = { + "messages": [ + { + "role": "user", + "content": combined_prompt, + } + ], + "temperature": 0.1, + "top_p": pipeline.page_loader.top_p, + "top_k": pipeline.page_loader.top_k, + "repetition_penalty": pipeline.page_loader.repetition_penalty, + } + + response, status_code = pipeline.ocr_client.process(req) + if status_code != 200: + logger.error( + "Extraction VLM call failed (%s): %s", status_code, response + ) + return None + + content = ( + response.get("choices", [{}])[0] + .get("message", {}) + .get("content", "") + ) or response.get("response", "") + + logger.debug( + "PDF extract raw content (truncated): %s", content[:500] + ) + try: + return _parse_json_from_text(content) + except ValueError as exc: + logger.warning("Extract JSON parse failed: %s", exc) + return None + + def _handle_schemaless_extract(pipeline, images, extraction_prompt): + """Extract without schema: parse to markdown, then convert to JSON. + + 1. Run the normal parse pipeline to get markdown. + 2. Send the markdown + extraction prompt to the VLM to produce JSON. 
+ """ + try: + request_data = _build_messages(images) + results = list( + pipeline.process( + request_data, + save_layout_visualization=False, + layout_vis_output_dir=None, ) - # Multiple units: merge json as list, markdown with separator - json_result = [r.json_result for r in results] - markdown_result = "\n\n---\n\n".join( + ) + + if not results: + return jsonify({"data": None}), 200 + + full_markdown = "\n\n---\n\n".join( r.markdown_result or "" for r in results ) - return ( - jsonify( + + if not full_markdown.strip(): + logger.warning("Schemaless extract: parse produced empty markdown") + return jsonify({"data": None}), 200 + + logger.debug( + "Schemaless extract markdown (truncated): %s", + full_markdown[:500], + ) + + # Send markdown + prompt to VLM to convert to JSON + combined_prompt = ( + f"以下是文档内容:\n\n{full_markdown}\n\n{extraction_prompt}" + ) + req = { + "messages": [ { - "json_result": json_result, - "markdown_result": markdown_result, + "role": "user", + "content": combined_prompt, } - ), - 200, + ], + "temperature": 0.1, + "top_p": pipeline.page_loader.top_p, + "top_k": pipeline.page_loader.top_k, + "repetition_penalty": pipeline.page_loader.repetition_penalty, + } + + response, status_code = pipeline.ocr_client.process(req) + if status_code != 200: + logger.error( + "Schemaless extract VLM call failed (%s): %s", + status_code, + response, + ) + return jsonify({"error": "VLM extraction failed"}), 500 + + content = ( + response.get("choices", [{}])[0] + .get("message", {}) + .get("content", "") + ) or response.get("response", "") + + logger.debug( + "Schemaless extract raw content (truncated): %s", + content[:500], ) + try: + data = _parse_json_from_text(content) + except ValueError as exc: + logger.warning("Schemaless extract JSON parse failed: %s", exc) + data = None + + return jsonify({"data": data}), 200 + except Exception as e: - logger.error("Parse error: %s", e) + logger.error("Schemaless extract error: %s", e) 
logger.debug(traceback.format_exc()) - return jsonify({"error": f"Parse error: {str(e)}"}), 500 + return jsonify({"error": f"Extract error: {str(e)}"}), 500 + + def _handle_extract_multipart(pipeline): + """Handle multipart/form-data extraction requests.""" + import json as _json + from pathlib import Path as _Path + + uploaded_files = request.files.getlist("files") + url_values = request.form.getlist("urls") + schema_str = request.form.get("schema") + prompt_override = request.form.get("prompt") + + if not uploaded_files and not url_values: + return jsonify({"error": "No files or urls provided"}), 400 + + if schema_str: + try: + schema_raw = _json.loads(schema_str) + except _json.JSONDecodeError: + return jsonify({"error": "schema must be valid JSON"}), 400 + + try: + extraction_prompt = _build_extraction_prompt(schema_raw, prompt_override) + except (TypeError, ValueError) as e: + return jsonify({"error": f"Invalid schema: {e}"}), 400 + else: + extraction_prompt = prompt_override or _SCHEMALESS_EXTRACTION_PROMPT + + temp_dir = None + try: + image_paths: List[str] = [] + + if uploaded_files: + temp_dir = tempfile.mkdtemp(prefix="glmocr_extract_") + for idx, f in enumerate(uploaded_files): + filename = f.filename or f"upload_{idx}" + safe_name = _Path(filename).name or f"upload_{idx}" + save_path = os.path.join(temp_dir, f"{idx}_{safe_name}") + f.save(save_path) + image_paths.append(save_path) + + for url in url_values: + url = url.strip() + if url: + image_paths.append(url) + + if not image_paths: + return jsonify({"error": "No valid files or urls provided"}), 400 + + if not schema_str: + return _handle_schemaless_extract(pipeline, image_paths, extraction_prompt) + + maas_config = app.config["doc_config"].pipeline.maas + if maas_config.enabled: + return _handle_extract_maas(image_paths, extraction_prompt) + return _handle_extract_selfhosted(pipeline, image_paths, extraction_prompt) + + except Exception as e: + logger.error("Extract error: %s", e) + return 
jsonify({"error": f"Extract error: {str(e)}"}), 500 + finally: + if temp_dir: + shutil.rmtree(temp_dir, ignore_errors=True) @app.route("/health", methods=["GET"]) def health(): @@ -193,7 +740,7 @@ def main(): logger.info( "GlmOcr Server starting on %s:%d...", server_config.host, server_config.port ) - logger.info("API endpoint: /glmocr/parse") + logger.info("API endpoints: /glmocr/parse, /glmocr/extract") logger.info("=" * 60) logger.info("") diff --git a/glmocr/utils/image_utils.py b/glmocr/utils/image_utils.py index 7c77a81..34e0c56 100644 --- a/glmocr/utils/image_utils.py +++ b/glmocr/utils/image_utils.py @@ -123,7 +123,14 @@ def _try_decode_base64_to_image_bytes(s: str) -> bytes | None: if image_source.startswith("file://"): image_source = image_source[7:] - if os.path.isfile(image_source): + if image_source.startswith(("http://", "https://")): + # Remote URL (including presigned URLs) + import requests as _requests + + resp = _requests.get(image_source, timeout=120) + resp.raise_for_status() + image = Image.open(io.BytesIO(resp.content)) + elif os.path.isfile(image_source): # Local file path (PDFs are handled via PageLoader) with open(image_source, "rb") as f: image_data = f.read() diff --git a/pyproject.toml b/pyproject.toml index 0ac865d..df10bd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ layout = [ server = [ "flask>=3.1.0", + "gunicorn>=23.0.0", ] selfhosted = [