Skip to content

Commit

Permalink
Add support to gemini-2.0-flash-exp model.
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-plus committed Dec 23, 2024
1 parent 616d7a6 commit dd62400
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
7 changes: 5 additions & 2 deletions openlrc/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _get_sleep_time(self, error):

@_register_chatbot
class GeminiBot(ChatBot):
def __init__(self, model_name='gemini-1.5-flash', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8,
def __init__(self, model_name='gemini-2.0-flash-exp', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8,
proxy=None, base_url_config=None, api_key=None):
self.temperature = max(0, min(1, temperature))

Expand All @@ -356,7 +356,7 @@ def __init__(self, model_name='gemini-1.5-flash', temperature=1, top_p=1, retry=
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE
}

if proxy:
Expand Down Expand Up @@ -423,6 +423,9 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis
except (genai.types.BrokenResponseError, genai.types.IncompleteIterationError,
genai.types.StopCandidateException) as e:
logger.warning(f'{type(e).__name__}: {e}. Retry num: {i + 1}.')
except genai.types.generation_types.BlockedPromptException as e:
logger.warning(f'Prompt blocked: {e}.\n Retry in 30s.')
time.sleep(30)

if not response:
raise ChatBotException('Failed to create a chat.')
Expand Down
11 changes: 11 additions & 0 deletions openlrc/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,17 @@ class Models:
context_window=1048576
)

GEMINI_2_0_FLASH_EXP = ModelInfo(
name="gemini-2.0-flash-exp",
provider=ModelProvider.GOOGLE,
input_price=0,
output_price=0,
max_tokens=8192,
context_window=1048576,
vision_support=True,
knowledge_cutoff="Aug 2024"
)

# Third Party Models
DEEPSEEK = ModelInfo(
name="deepseek-chat",
Expand Down
18 changes: 16 additions & 2 deletions openlrc/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from openlrc.agents import ChunkedTranslatorAgent, ContextReviewerAgent
from openlrc.context import TranslationContext, TranslateInfo
from openlrc.exceptions import ChatBotException
from openlrc.logger import logger
from openlrc.models import ModelConfig
from openlrc.prompter import AtomicTranslatePrompter
Expand Down Expand Up @@ -107,12 +108,22 @@ def _translate_chunk(self, translator_agent: ChunkedTranslatorAgent, chunk: List
Returns:
Tuple[List[str], TranslationContext]: Translated texts and updated context.
"""

def handle_translation(agent: ChunkedTranslatorAgent) -> Tuple[List[str], TranslationContext]:
trans, updated_context = agent.translate_chunk(chunk_id, chunk, context)
trans, updated_context = None, None
try:
trans, updated_context = agent.translate_chunk(chunk_id, chunk, context)
except ChatBotException as e:
logger.error(f'Failed to translate chunk {chunk_id}.')

if len(trans) != len(chunk) and agent.info.glossary:
logger.warning(
f'Agent {agent}: Removing glossary for chunk {chunk_id} due to inconsistent translation.')
trans, updated_context = agent.translate_chunk(chunk_id, chunk, context, use_glossary=False)
try:
trans, updated_context = agent.translate_chunk(chunk_id, chunk, context, use_glossary=False)
except ChatBotException as e:
logger.error(f'Failed to translate chunk {chunk_id}.')

return trans, updated_context

if self.use_retry_cnt == 0 or not retry_agent:
Expand All @@ -128,6 +139,9 @@ def handle_translation(agent: ChunkedTranslatorAgent) -> Tuple[List[str], Transl
translated, context = handle_translation(retry_agent)
self.use_retry_cnt -= 1

if not translated:
raise ChatBotException(f'Failed to translate chunk {chunk_id}.')

return translated, context

def translate(self, texts: Union[str, List[str]], src_lang: str, target_lang: str,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_validate_returns_true_when_generated_content_matches_target_language(se
self.assertTrue(result)

def test_validate_returns_false_when_generated_content_not_matches_target_language(self):
validator = AtomicTranslateValidator(target_lang='cn-zh')
validator = AtomicTranslateValidator(target_lang='zh-cn')
user_input = "Hello"
generated_content = "你好"

Expand Down

0 comments on commit dd62400

Please sign in to comment.