diff --git a/README.md b/README.md
index c272fdbd..c0fac49d 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ See [below](#benchmarks) for detailed speed and accuracy benchmarks, and instruc
 
 ## Hybrid Mode
 
-For the highest accuracy, pass the `--use_llm` flag to use an LLM alongside marker. This will do things like merge tables across pages, handle inline math, format tables properly, and extract values from forms. It can use any gemini or ollama model. By default, it uses `gemini-2.0-flash`. See [below](#llm-services) for details.
+For the highest accuracy, pass the `--use_llm` flag to use an LLM alongside marker. This will do things like merge tables across pages, handle inline math, format tables properly, and extract values from forms. It can use any gemini or ollama model. By default, it uses `gemini-2.5-flash`. See [below](#llm-services) for details.
 
 Here is a table benchmark comparing marker, gemini flash alone, and marker with use_llm:
 
@@ -534,7 +534,7 @@ python benchmarks/table/table.py --max_rows 100
 
 Options:
 
 - `--use_llm` uses an llm with marker to improve accuracy.
-- `--use_gemini` also benchmarks gemini 2.0 flash.
+- `--use_gemini` also benchmarks gemini 2.5 flash.
 
 # How it works
diff --git a/benchmarks/overall/elo.py b/benchmarks/overall/elo.py
index d260ee81..be90794f 100644
--- a/benchmarks/overall/elo.py
+++ b/benchmarks/overall/elo.py
@@ -143,7 +143,7 @@ def llm_response_wrapper(
     )
     try:
         responses = client.models.generate_content(
-            model="gemini-2.0-flash-001",
+            model="gemini-2.5-flash",
             contents=prompt,
             config={
                 "temperature": 0,
diff --git a/benchmarks/overall/scorers/llm.py b/benchmarks/overall/scorers/llm.py
index 00ff4031..7d2cdd57 100644
--- a/benchmarks/overall/scorers/llm.py
+++ b/benchmarks/overall/scorers/llm.py
@@ -142,7 +142,7 @@ def llm_response_wrapper(self, prompt, response_schema, depth=0):
         )
         try:
             responses = client.models.generate_content(
-                model="gemini-2.0-flash-001",
+                model="gemini-2.5-flash",
                 contents=prompt,
                 config={
                     "temperature": 0,
diff --git a/benchmarks/table/gemini.py b/benchmarks/table/gemini.py
index 5832a90f..768662ae 100644
--- a/benchmarks/table/gemini.py
+++ b/benchmarks/table/gemini.py
@@ -35,7 +35,7 @@ def gemini_table_rec(image: Image.Image):
     image.save(image_bytes, format="PNG")
 
     responses = client.models.generate_content(
-        model="gemini-2.0-flash",
+        model="gemini-2.5-flash",
         contents=[types.Part.from_bytes(data=image_bytes.getvalue(), mime_type="image/png"), prompt], # According to gemini docs, it performs better if the image is the first element
         config={
             "temperature": 0,
diff --git a/marker/config/parser.py b/marker/config/parser.py
index ffc3240d..541151c2 100644
--- a/marker/config/parser.py
+++ b/marker/config/parser.py
@@ -13,6 +13,7 @@
 from marker.renderers.markdown import MarkdownRenderer
 from marker.settings import settings
 from marker.util import classes_to_strings, parse_range_str, strings_to_classes
+from marker.services.gemini import GeminiModel
 
 logger = get_logger()
 
@@ -82,6 +83,12 @@ def common_options(fn):
             default=None,
             help="LLM service to use - should be full import path, like marker.services.gemini.GoogleGeminiService",
         )(fn)
+        fn = click.option(
+            "--gemini_model_name",
+            type=str,
+            default=None,
+            help="Name of the Gemini model to use (e.g., gemini-2.5-flash)",
+        )(fn)
         return fn
 
     def generate_config_dict(self) -> Dict[str, any]:
@@ -114,6 +121,13 @@ def generate_config_dict(self) -> Dict[str, any]:
         if settings.GOOGLE_API_KEY:
             config["gemini_api_key"] = settings.GOOGLE_API_KEY
 
+        if self.cli_options.get("gemini_model_name"):
+            try:
+                config["gemini_model_name"] = GeminiModel(self.cli_options["gemini_model_name"])
+            except ValueError:
+                logger.warning(f"Invalid gemini_model_name: {self.cli_options['gemini_model_name']}. Using default.")
+                config["gemini_model_name"] = GeminiModel.DEFAULT
+
         return config
 
     def get_llm_service(self):
diff --git a/marker/services/gemini.py b/marker/services/gemini.py
index 213aebd0..a71d9f03 100644
--- a/marker/services/gemini.py
+++ b/marker/services/gemini.py
@@ -1,7 +1,10 @@
 import json
 import time
 import traceback
+from collections import deque
+from enum import Enum
 from io import BytesIO
+from threading import Lock
 from typing import List, Annotated
 
 import PIL
@@ -17,10 +20,28 @@
 logger = get_logger()
 
 
+class GeminiModel(str, Enum):
+    GEMINI_2_5_PRO = "gemini-2.5-pro"
+    GEMINI_2_5_FLASH = "gemini-2.5-flash"
+    DEFAULT = "gemini-2.5-flash"
+
+
+# Rate limiting settings
+MODEL_LIMITS = {
+    GeminiModel.GEMINI_2_5_PRO: {"rpm": 5},
+    GeminiModel.GEMINI_2_5_FLASH: {"rpm": 10},
+    GeminiModel.DEFAULT: {"rpm": 10},  # Corresponds to gemini-2.5-flash
+}
+
+# Global request tracker and lock
+REQUEST_TIMESTAMPS = {model: deque() for model in GeminiModel}
+RATE_LIMIT_LOCK = Lock()
+
+
 class BaseGeminiService(BaseService):
     gemini_model_name: Annotated[
-        str, "The name of the Google model to use for the service."
-    ] = "gemini-2.0-flash"
+        GeminiModel, "The name of the Google model to use for the service."
+    ] = GeminiModel.DEFAULT
 
     def img_to_bytes(self, img: PIL.Image.Image):
         image_bytes = BytesIO()
@@ -52,6 +73,28 @@ def __call__(
         if timeout is None:
             timeout = self.timeout
 
+        # Proactive rate limiting
+        with RATE_LIMIT_LOCK:
+            model_name = self.gemini_model_name
+            rpm_limit = MODEL_LIMITS.get(model_name, {"rpm": 10})["rpm"]
+            request_history = REQUEST_TIMESTAMPS[model_name]
+
+            current_time = time.time()
+            # Remove timestamps older than 60 seconds
+            while request_history and current_time - request_history[0] > 60:
+                request_history.popleft()
+
+            if len(request_history) >= rpm_limit:
+                wait_time = 60 - (current_time - request_history[0])
+                if wait_time > 0:
+                    logger.warning(
+                        f"RPM limit for {model_name} reached. Waiting for {wait_time:.2f} seconds."
+                    )
+                    time.sleep(wait_time)
+
+            # Record the new request timestamp
+            request_history.append(time.time())
+
         client = self.get_google_client(timeout=timeout)
 
         image_parts = self.format_image_for_llm(image)
diff --git a/marker/services/vertex.py b/marker/services/vertex.py
index f67e2bac..a0cc4567 100644
--- a/marker/services/vertex.py
+++ b/marker/services/vertex.py
@@ -16,7 +16,7 @@ class GoogleVertexService(BaseGeminiService):
     gemini_model_name: Annotated[
         str, "The name of the Google model to use for the service."
-    ] = "gemini-2.0-flash-001"
+    ] = "gemini-2.5-flash"
     vertex_dedicated: Annotated[
         bool, "Whether to use a dedicated Vertex AI instance."
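
The rate limiting added to `marker/services/gemini.py` is a sliding-window limiter: each model keeps a deque of recent request timestamps, timestamps older than 60 seconds are dropped, and a caller sleeps until the oldest request ages out of the window before recording its own. Below is a minimal standalone sketch of that technique, reduced to a single model; the function name `wait_for_slot` and the module-level names are illustrative and not part of marker's API, and the RPM value mirrors the `MODEL_LIMITS` entry for `gemini-2.5-flash` in the diff.

```python
import time
from collections import deque
from threading import Lock

RPM_LIMIT = 10          # requests per minute, matching MODEL_LIMITS for gemini-2.5-flash
_timestamps = deque()   # timestamps of requests made within the last 60 seconds
_lock = Lock()


def wait_for_slot():
    """Block until one more request can be sent without exceeding RPM_LIMIT."""
    with _lock:
        now = time.time()
        # Drop timestamps that have fallen outside the 60-second window
        while _timestamps and now - _timestamps[0] > 60:
            _timestamps.popleft()
        if len(_timestamps) >= RPM_LIMIT:
            # Sleep until the oldest request leaves the window, then take its slot
            time.sleep(max(0.0, 60 - (now - _timestamps[0])))
        # Record this request's timestamp
        _timestamps.append(time.time())


# Usage sketch: call wait_for_slot() immediately before each generate_content call.
```

As in the diff, the wait happens while the lock is held, which keeps the bookkeeping simple at the cost of serializing callers while one of them sleeps.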