diff --git a/docs/source/translation.rst b/docs/source/translation.rst index 7dd7ef15a..efd646a28 100644 --- a/docs/source/translation.rst +++ b/docs/source/translation.rst @@ -13,6 +13,41 @@ Limitations - If probes or detectors fail to load, you need may need to choose a smaller local translation model or utilize a remote service. - Translation may add significant execution time to the run depending on resources available. +Translation Caching +------------------- + +Garak implements a translation caching system to improve performance and reduce API costs when using translation services. The caching mechanism automatically stores and retrieves translation results to avoid redundant API calls. + +**How it works:** + +- Each combination of translation pair (source language → target language), model type, and model name gets its own cache file +- Cache files are stored in JSON format under the cache directory: ``{cache_dir}/translation/translation_cache_{source_lang}_{target_lang}_{model_type}_{model_name}.json`` +- Translation results are keyed by MD5 hash of the input text for efficient storage and retrieval +- Cache files persist between runs, allowing translations to be reused across multiple garak sessions + +**Benefits:** + +- **Performance**: Significantly reduces translation time for repeated text +- **Cost savings**: Reduces API calls to paid services like DeepL, Google Cloud Translation, and NVIDIA Riva +- **Reliability**: Provides fallback for offline scenarios when cached translations are available +- **Consistency**: Ensures identical translations for the same input text across different runs + +**Cache management:** + +- Cache files are automatically created when translations are performed +- Corrupted cache files are handled gracefully with fallback to empty cache +- Cache files can be manually deleted to force fresh translations +- Cache directory location follows garak's standard cache configuration + +**Supported for all translation services:** + +- Local translation models (Hugging Face) +-
DeepL API +- NVIDIA Riva API +- Google Cloud Translation API + +The caching system is transparent to users and requires no additional configuration. It automatically activates when translation services are used. + Supported translation services ------------------------------ diff --git a/garak/langproviders/base.py b/garak/langproviders/base.py index b35124d0f..f6587fa0d 100644 --- a/garak/langproviders/base.py +++ b/garak/langproviders/base.py @@ -10,8 +10,13 @@ import unicodedata import string import logging +import json +import hashlib +import os +from pathlib import Path from garak.resources.api import nltk from langdetect import detect, DetectorFactory, LangDetectException +from garak import _config _intialized_words = False @@ -128,7 +133,7 @@ def is_meaning_string(text: str) -> bool: # To be `Configurable` the root object must meet the standard type search criteria # { langproviders: # "local": { # model_type -# "language": "-" +# "language": "," # "name": "model/name" # model_name # "hf_args": {} # or any other translator specific values for the model_type # } @@ -136,6 +141,86 @@ def is_meaning_string(text: str) -> bool: from garak.configurable import Configurable +class TranslationCache: + def __init__(self, provider: "LangProvider"): + if not hasattr(provider, "model_type"): + return None # providers without a model_type do not have a cache + + self.source_lang = provider.source_lang + self.target_lang = provider.target_lang + self.model_type = provider.model_type + self.model_name = "default" + if hasattr(provider, "model_name"): + self.model_name = provider.model_name + + cache_dir = _config.transient.cache_dir / "translation" + cache_dir.mkdir(mode=0o740, parents=True, exist_ok=True) + cache_filename = f"translation_cache_{self.source_lang}_{self.target_lang}_{self.model_type}_{self.model_name.replace('/', '_')}.json" + self.cache_file = cache_dir / cache_filename + logging.info(f"Cache file: {self.cache_file}") + self._cache = self._load_cache() + + 
def _load_cache(self) -> dict: + if self.cache_file.exists(): + try: + with open(self.cache_file, "r", encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, IOError) as e: + logging.warning(f"Failed to load translation cache: {e}") + return {} + return {} + + def _save_cache(self): + try: + with open(self.cache_file, "w", encoding="utf-8") as f: + json.dump(self._cache, f, ensure_ascii=False, indent=2) + except IOError as e: + logging.warning(f"Failed to save translation cache: {e}") + + def get_cache_key(self, text: str) -> str: + return hashlib.md5(text.encode("utf-8"), usedforsecurity=False).hexdigest() + + def get(self, text: str) -> str | None: + cache_key = self.get_cache_key(text) + cache_entry = self._cache.get(cache_key) + if cache_entry and isinstance(cache_entry, dict): + return cache_entry.get("translation") + elif isinstance(cache_entry, str): + # Backward compatibility with old format + return cache_entry + return None + + def set(self, text: str, translation: str): + cache_key = self.get_cache_key(text) + self._cache[cache_key] = { + "original": text, + "translation": translation, + "source_lang": self.source_lang, + "target_lang": self.target_lang, + "model_type": self.model_type, + "model_name": self.model_name, + } + self._save_cache() + + def get_cache_entry(self, text: str) -> dict | None: + """Get full cache entry including original text and metadata.""" + cache_key = self.get_cache_key(text) + cache_entry = self._cache.get(cache_key) + if cache_entry and isinstance(cache_entry, dict): + return cache_entry + elif isinstance(cache_entry, str): + # Backward compatibility with old format + return { + "original": text, + "translation": cache_entry, + "source_lang": self.source_lang, + "target_lang": self.target_lang, + "model_type": self.model_type, + "model_name": self.model_name, + } + return None + + class LangProvider(Configurable): """Base class for objects that provision language""" @@ -147,6 +232,9 @@ def 
__init__(self, config_root: dict = {}) -> None: self._validate_env_var() + # Use TranslationCache for caching + self.cache = TranslationCache(self) + self._load_langprovider() def _load_langprovider(self): @@ -155,6 +243,16 @@ def _load_langprovider(self): def _translate(self, text: str) -> str: raise NotImplementedError + def _translate_with_cache(self, text: str) -> str: + """Translate text with caching support.""" + cached_translation = self.cache.get(text) + if cached_translation is not None: + logging.debug(f"Using cached translation for text: {text[:50]}...") + return cached_translation + translation = self._translate_impl(text) + self.cache.set(text, translation) + return translation + def _get_response(self, input_text: str): translated_lines = [] @@ -189,7 +287,7 @@ def _short_sentence_translate(self, line: str) -> str: if needs_translation: cleaned_line = self._clean_line(line) if cleaned_line: - translated_line = self._translate(cleaned_line) + translated_line = self._translate_with_cache(cleaned_line) translated_lines.append(translated_line) return translated_lines @@ -202,7 +300,7 @@ def _long_sentence_translate(self, line: str) -> str: if self._should_skip_line(cleaned_sentence): translated_lines.append(cleaned_sentence) continue - translated_line = self._translate(cleaned_sentence) + translated_line = self._translate_with_cache(cleaned_sentence) translated_lines.append(translated_line) return translated_lines diff --git a/garak/langproviders/local.py b/garak/langproviders/local.py index 1c7f4b783..360eb3e11 100644 --- a/garak/langproviders/local.py +++ b/garak/langproviders/local.py @@ -19,6 +19,7 @@ def _load_langprovider(self): pass def _translate(self, text: str) -> str: + # Return the input unchanged; this method does not go through _translate_with_cache return text def get_text( @@ -110,6 +111,11 @@ def _load_langprovider(self): self.tokenizer = MarianTokenizer.from_pretrained(model_name) def _translate(self, text: str) -> str: + # Use _translate_with_cache to enable caching +
return self._translate_with_cache(text) + + def _translate_impl(self, text: str) -> str: + """Actual translation implementation without caching.""" if "m2m100" in self.model_name: self.tokenizer.src_lang = self.source_lang diff --git a/garak/langproviders/remote.py b/garak/langproviders/remote.py index dc541f2a8..336d02033 100644 --- a/garak/langproviders/remote.py +++ b/garak/langproviders/remote.py @@ -91,6 +91,11 @@ def _load_langprovider(self): # TODO: consider adding a backoff here and determining if a connection needs to be re-established def _translate(self, text: str) -> str: + # Use _translate_with_cache to enable caching + return self._translate_with_cache(text) + + def _translate_impl(self, text: str) -> str: + """Actual translation implementation without caching.""" try: if self.client is None: self._load_langprovider() @@ -152,6 +157,11 @@ def _load_langprovider(self): self._tested = True def _translate(self, text: str) -> str: + # Use _translate_with_cache to enable caching + return self._translate_with_cache(text) + + def _translate_impl(self, text: str) -> str: + """Actual translation implementation without caching.""" try: return self.client.translate_text( text, source_lang=self._source_lang, target_lang=self._target_lang @@ -230,6 +240,11 @@ def _load_langprovider(self): self._tested = True def _translate(self, text: str) -> str: + # Use _translate_with_cache to enable caching + return self._translate_with_cache(text) + + def _translate_impl(self, text: str) -> str: + """Actual translation implementation without caching.""" retry = 5 while retry > 0: try: diff --git a/tests/langservice/test_translation_cache.py b/tests/langservice/test_translation_cache.py new file mode 100644 index 000000000..81dcd12b0 --- /dev/null +++ b/tests/langservice/test_translation_cache.py @@ -0,0 +1,271 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tempfile +import json +import os +from pathlib import Path +from unittest.mock import patch, MagicMock + +from garak.langproviders.base import LangProvider, TranslationCache + + +class TestTranslationCache: + """Test translation caching functionality.""" + + @pytest.fixture + def temp_cache_dir(self): + """Create a temporary cache directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + @pytest.fixture + def mock_config(self): + """Mock configuration for testing.""" + return { + "langproviders": {"passthru": {"language": "en,ja", "model_type": "test"}} + } + + def test_cache_with_different_model_types(self, temp_cache_dir): + """Test cache works with different model types.""" + with patch("garak._config.transient.cache_dir", temp_cache_dir): + config1 = { + "langproviders": { + "local": { + "language": "en,ja", + "model_type": "local", + "name": "test_model", + } + } + } + config2 = { + "langproviders": { + "remote": { + "language": "en,ja", + "model_type": "remote", + "name": "test_model", + } + } + } + + # Create mock LangProvider instances + provider1 = MagicMock() + provider1.source_lang = "en" + provider1.target_lang = "ja" + provider1.model_type = "local" + provider1.model_name = "test_model" + + provider2 = MagicMock() + provider2.source_lang = "en" + provider2.target_lang = "ja" + provider2.model_type = "remote" + provider2.model_name = "test_model" + + cache1 = TranslationCache(provider1) + cache2 = TranslationCache(provider2) + + # Different model types should create different cache files + assert str(cache1.cache_file) != str(cache2.cache_file) + + # Test caching works for both + cache1.set("hello", "こんにちは") + cache2.set("hello", "こんにちは") + + assert cache1.get("hello") == "こんにちは" + assert cache2.get("hello") == "こんにちは" + + def test_cache_stores_original_text(self, temp_cache_dir): + """Test that cache stores original text along with 
translation.""" + with patch("garak._config.transient.cache_dir", temp_cache_dir): + # Create mock LangProvider instance + provider = MagicMock() + provider.source_lang = "en" + provider.target_lang = "ja" + provider.model_type = "local" + provider.model_name = "test_model" + + cache = TranslationCache(provider) + original_text = "Hello world" + translated_text = "こんにちは世界" + + cache.set(original_text, translated_text) + + # Get full cache entry + cache_entry = cache.get_cache_entry(original_text) + assert cache_entry is not None + assert cache_entry["original"] == original_text + assert cache_entry["translation"] == translated_text + assert cache_entry["source_lang"] == "en" + assert cache_entry["target_lang"] == "ja" + assert cache_entry["model_type"] == "local" + assert cache_entry["model_name"] == "test_model" + + def test_remote_translator_cache_initialization(self, temp_cache_dir): + """Test that remote translators work without __init__ methods.""" + with patch("garak._config.transient.cache_dir", temp_cache_dir): + from garak.langproviders.remote import ( + RivaTranslator, + DeeplTranslator, + GoogleTranslator, + ) + + # Test RivaTranslator + config_riva = { + "langproviders": { + "riva": { + "language": "en,ja", + "model_type": "remote.RivaTranslator", + "api_key": "test_key", + } + } + } + + # Mock API key validation and create test subclass + with ( + patch.object(RivaTranslator, "_validate_env_var"), + patch.object(RivaTranslator, "_load_langprovider"), + ): + + class TestRivaTranslator(RivaTranslator): + def __init__(self, config_root={}): + self.language = "en,ja" + self.model_type = "remote.RivaTranslator" + super().__init__(config_root) + + translator_riva = TestRivaTranslator(config_root=config_riva) + + # Check that cache is initialized + assert translator_riva.cache is not None + assert "en_ja" in str(translator_riva.cache.cache_file) + assert "remote.RivaTranslator" in str(translator_riva.cache.cache_file) + + # Test DeeplTranslator + config_deepl 
= { + "langproviders": { + "deepl": { + "language": "en,ja", + "model_type": "remote.DeeplTranslator", + "api_key": "test_key", + } + } + } + + with ( + patch.object(DeeplTranslator, "_validate_env_var"), + patch.object(DeeplTranslator, "_load_langprovider"), + ): + + class TestDeeplTranslator(DeeplTranslator): + def __init__(self, config_root={}): + self.language = "en,ja" + self.model_type = "remote.DeeplTranslator" + super().__init__(config_root) + + translator_deepl = TestDeeplTranslator(config_root=config_deepl) + + assert translator_deepl.cache is not None + assert "en_ja" in str(translator_deepl.cache.cache_file) + assert "remote.DeeplTranslator" in str( + translator_deepl.cache.cache_file + ) + + # Test GoogleTranslator + config_google = { + "langproviders": { + "google": { + "language": "en,ja", + "model_type": "remote.GoogleTranslator", + "api_key": "test_key", + } + } + } + + with ( + patch.object(GoogleTranslator, "_validate_env_var"), + patch.object(GoogleTranslator, "_load_langprovider"), + ): + + class TestGoogleTranslator(GoogleTranslator): + def __init__(self, config_root={}): + self.language = "en,ja" + self.model_type = "remote.GoogleTranslator" + super().__init__(config_root) + + translator_google = TestGoogleTranslator(config_root=config_google) + + assert translator_google.cache is not None + assert "en_ja" in str(translator_google.cache.cache_file) + assert "remote.GoogleTranslator" in str( + translator_google.cache.cache_file + ) + + def test_remote_translator_cache_functionality(self, temp_cache_dir): + """Test that remote translators can use cache functionality.""" + with patch("garak._config.transient.cache_dir", temp_cache_dir): + from garak.langproviders.remote import RivaTranslator + + config = { + "langproviders": { + "riva": { + "language": "en,ja", + "model_type": "remote.RivaTranslator", + "api_key": "test_key", + } + } + } + + with ( + patch.object(RivaTranslator, "_validate_env_var"), + patch.object(RivaTranslator, 
"_load_langprovider"), + ): + + class TestRivaTranslator(RivaTranslator): + def __init__(self, config_root={}): + self.language = "en,ja" + self.model_type = "remote.RivaTranslator" + super().__init__(config_root) + + translator = TestRivaTranslator(config_root=config) + + # Test cache functionality + test_text = "Hello world" + test_translation = "こんにちは世界" + + # Set cache manually + translator.cache.set(test_text, test_translation) + + # Verify cache entry + cache_entry = translator.cache.get_cache_entry(test_text) + assert cache_entry is not None + assert cache_entry["original"] == test_text + assert cache_entry["translation"] == test_translation + assert cache_entry["source_lang"] == "en" + assert cache_entry["target_lang"] == "ja" + assert cache_entry["model_type"] == "remote.RivaTranslator" + + def test_cache_with_custom_model_name(self, temp_cache_dir): + """Test cache works with custom model name.""" + with patch("garak._config.transient.cache_dir", temp_cache_dir): + # Create mock LangProvider instance with custom model_name + provider = MagicMock() + provider.source_lang = "en" + provider.target_lang = "ja" + provider.model_type = "local" + provider.model_name = "custom_model" + + cache = TranslationCache(provider) + + # Verify custom model_name is used + assert cache.model_name == "custom_model" + + # Test cache functionality + test_text = "Hello world" + test_translation = "こんにちは世界" + + cache.set(test_text, test_translation) + + # Verify cache entry includes custom model_name + cache_entry = cache.get_cache_entry(test_text) + assert cache_entry is not None + assert cache_entry["model_name"] == "custom_model" diff --git a/tests/langservice/test_translation_cache_integration.py b/tests/langservice/test_translation_cache_integration.py new file mode 100644 index 000000000..68a801161 --- /dev/null +++ b/tests/langservice/test_translation_cache_integration.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tempfile +import os +from pathlib import Path +from unittest.mock import patch, MagicMock + +from garak.langproviders.base import LangProvider, TranslationCache + + +class TestTranslationCacheIntegration: + """Integration test for translation caching functionality.""" + + @pytest.fixture + def temp_cache_dir(self): + """Create a temporary cache directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + def test_remote_translator_integration(self, temp_cache_dir): + """Test that remote translators work correctly in integration scenarios.""" + with patch("garak._config.transient.cache_dir", temp_cache_dir): + from garak.langproviders.remote import RivaTranslator + + config = { + "langproviders": { + "riva": { + "language": "en,ja", + "model_type": "remote.RivaTranslator", + "api_key": "test_key", + } + } + } + + # Mock API key validation and create test subclass + with ( + patch.object(RivaTranslator, "_validate_env_var"), + patch.object(RivaTranslator, "_load_langprovider"), + ): + + class TestRivaTranslator(RivaTranslator): + def __init__(self, config_root={}): + self.language = "en,ja" + self.model_type = "remote.RivaTranslator" + super().__init__(config_root) + + translator = TestRivaTranslator(config_root=config) + + # Test that translator can be instantiated and has cache + assert translator.cache is not None + assert translator.source_lang == "en" + assert translator.target_lang == "ja" + + # Test that cache file path is correctly generated + cache_file_path = translator.cache.cache_file + assert "en_ja" in str(cache_file_path) + assert "remote.RivaTranslator" in str(cache_file_path) + assert "default" in str(cache_file_path) # Default model_name + + # Test that translator can handle translation requests (mock) + with patch.object( + translator, "_translate_impl", return_value="こんにちは世界" + ): + result = 
translator._translate_with_cache("Hello world") + assert result == "こんにちは世界" + + # Second call should use cache + result2 = translator._translate_with_cache("Hello world") + assert result2 == "こんにちは世界" + + def test_local_translator_integration(self, temp_cache_dir): + """Test that local translators work correctly in integration scenarios (mocked, no Passthru).""" + with patch("garak._config.transient.cache_dir", temp_cache_dir): + # Mock LangProvider subclass + class MockLocalProvider(LangProvider): + def __init__(self): + self.language = "en,ja" + self.model_type = "local" + self.model_name = "test_model" + self.source_lang, self.target_lang = self.language.split(",") + self._validate_env_var = lambda: None + self._load_langprovider = lambda: None + self.cache = TranslationCache(self) + + def _translate(self, text): + return "" + + def _translate_impl(self, text): + return "" + + translator = MockLocalProvider() + + # Test that translator can be instantiated and has cache + assert translator.cache is not None + assert translator.source_lang == "en" + assert translator.target_lang == "ja" + + # Test that cache file path is correctly generated + cache_file_path = translator.cache.cache_file + assert "en_ja" in str(cache_file_path) + assert "local" in str(cache_file_path) + assert "test_model" in str(cache_file_path) + + # Test that translator can handle translation requests (mock) + with patch.object( + translator, "_translate_impl", return_value="こんにちは世界" + ): + result = translator._translate_with_cache("Hello world") + assert result == "こんにちは世界" + + # Second call should use cache + result2 = translator._translate_with_cache("Hello world") + assert result2 == "こんにちは世界" + + def test_cache_persistence_across_sessions(self, temp_cache_dir): + """Test that cache persists across different translator sessions (mocked, no Passthru).""" + with patch("garak._config.transient.cache_dir", temp_cache_dir): + + class MockLocalProvider(LangProvider): + def __init__(self): +
self.language = "en,ja" + self.model_type = "local" + self.model_name = "test_model" + self.source_lang, self.target_lang = self.language.split(",") + self._validate_env_var = lambda: None + self._load_langprovider = lambda: None + self.cache = TranslationCache(self) + + def _translate(self, text): + return "" + + # Create first translator instance + translator1 = MockLocalProvider() + # Set cache entry + test_text = "Hello world" + test_translation = "こんにちは世界" + translator1.cache.set(test_text, test_translation) + # Verify cache entry was saved + cache_entry = translator1.cache.get_cache_entry(test_text) + assert cache_entry is not None + assert cache_entry["translation"] == test_translation + # Create second translator instance (simulating new session) + translator2 = MockLocalProvider() + # Verify cache entry is still available + cached_translation = translator2.cache.get(test_text) + assert cached_translation == test_translation