diff --git a/openlrc/chatbot.py b/openlrc/chatbot.py
index acecda0..e550c7f 100644
--- a/openlrc/chatbot.py
+++ b/openlrc/chatbot.py
@@ -341,7 +341,7 @@ def _get_sleep_time(self, error):
 
 @_register_chatbot
 class GeminiBot(ChatBot):
-    def __init__(self, model_name='gemini-1.5-flash', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8,
+    def __init__(self, model_name='gemini-2.0-flash-exp', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8,
                  proxy=None, base_url_config=None, api_key=None):
         self.temperature = max(0, min(1, temperature))
 
@@ -356,7 +356,7 @@ def __init__(self, model_name='gemini-1.5-flash', temperature=1, top_p=1, retry=
             HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
-            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE
         }
 
         if proxy:
@@ -423,6 +423,9 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis
             except (genai.types.BrokenResponseError, genai.types.IncompleteIterationError,
                     genai.types.StopCandidateException) as e:
                 logger.warning(f'{type(e).__name__}: {e}. Retry num: {i + 1}.')
+            except genai.types.generation_types.BlockedPromptException as e:
+                logger.warning(f'Prompt blocked: {e}.\n Retry in 30s.')
+                time.sleep(30)
 
         if not response:
             raise ChatBotException('Failed to create a chat.')
diff --git a/openlrc/models.py b/openlrc/models.py
index 23568ef..97ed1f6 100644
--- a/openlrc/models.py
+++ b/openlrc/models.py
@@ -211,6 +211,17 @@ class Models:
         context_window=1048576
     )
 
+    GEMINI_2_0_FLASH_EXP = ModelInfo(
+        name="gemini-2.0-flash-exp",
+        provider=ModelProvider.GOOGLE,
+        input_price=0,
+        output_price=0,
+        max_tokens=8192,
+        context_window=1048576,
+        vision_support=True,
+        knowledge_cutoff="Aug 2024"
+    )
+
     # Third Party Models
     DEEPSEEK = ModelInfo(
         name="deepseek-chat",
diff --git a/openlrc/translate.py b/openlrc/translate.py
index 1ce053e..3d0c659 100644
--- a/openlrc/translate.py
+++ b/openlrc/translate.py
@@ -13,6 +13,7 @@
 from openlrc.agents import ChunkedTranslatorAgent, ContextReviewerAgent
 from openlrc.context import TranslationContext, TranslateInfo
+from openlrc.exceptions import ChatBotException
 from openlrc.logger import logger
 from openlrc.models import ModelConfig
 from openlrc.prompter import AtomicTranslatePrompter
@@ -107,12 +108,22 @@ def _translate_chunk(self, translator_agent: ChunkedTranslatorAgent, chunk: List
         Returns:
             Tuple[List[str], TranslationContext]: Translated texts and updated context.
         """
+
         def handle_translation(agent: ChunkedTranslatorAgent) -> Tuple[List[str], TranslationContext]:
-            trans, updated_context = agent.translate_chunk(chunk_id, chunk, context)
+            trans, updated_context = None, None
+            try:
+                trans, updated_context = agent.translate_chunk(chunk_id, chunk, context)
+            except ChatBotException as e:
+                logger.error(f'Failed to translate chunk {chunk_id}.')
+
             if len(trans) != len(chunk) and agent.info.glossary:
                 logger.warning(
                     f'Agent {agent}: Removing glossary for chunk {chunk_id} due to inconsistent translation.')
-                trans, updated_context = agent.translate_chunk(chunk_id, chunk, context, use_glossary=False)
+                try:
+                    trans, updated_context = agent.translate_chunk(chunk_id, chunk, context, use_glossary=False)
+                except ChatBotException as e:
+                    logger.error(f'Failed to translate chunk {chunk_id}.')
+
             return trans, updated_context
 
         if self.use_retry_cnt == 0 or not retry_agent:
@@ -128,6 +139,9 @@ def handle_translation(agent: ChunkedTranslatorAgent) -> Tuple[List[str], Transl
             translated, context = handle_translation(retry_agent)
             self.use_retry_cnt -= 1
 
+        if not translated:
+            raise ChatBotException(f'Failed to translate chunk {chunk_id}.')
+
         return translated, context
 
     def translate(self, texts: Union[str, List[str]], src_lang: str, target_lang: str,
diff --git a/tests/test_validators.py b/tests/test_validators.py
index 3f0791e..e556214 100644
--- a/tests/test_validators.py
+++ b/tests/test_validators.py
@@ -60,7 +60,7 @@ def test_validate_returns_true_when_generated_content_matches_target_language(se
         self.assertTrue(result)
 
     def test_validate_returns_false_when_generated_content_not_matches_target_language(self):
-        validator = AtomicTranslateValidator(target_lang='cn-zh')
+        validator = AtomicTranslateValidator(target_lang='zh-cn')
         user_input = "Hello"
         generated_content = "你好"