diff --git a/examples/shinka_tutorial.ipynb b/examples/shinka_tutorial.ipynb index 66a71a0..c6d8189 100644 --- a/examples/shinka_tutorial.ipynb +++ b/examples/shinka_tutorial.ipynb @@ -237,6 +237,17 @@ "if not llm_models:\n", " llm_models = [\"gpt-5-mini\"] # fallback if no keys detected\n", "\n", + "# pick embedding model based on available keys\n", + "embedding_model_name = \"\"\n", + "if os.getenv(\"GEMINI_API_KEY\"):\n", + " embedding_model_name = \"gemini-embedding-001\"\n", + "elif os.getenv(\"OPENAI_API_KEY\"):\n", + " embedding_model_name = \"text-embedding-3-small\"\n", + "else:\n", + " embedding_model_name = \"text-embedding-3-small\"\n", + "print(f\"✅ Embedding model selected: {embedding_model_name}\")\n", + "\n", + "\n", "# unique experiment directory\n", "timestamp = dt.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", "run_tag = f\"{timestamp}_weighted_fast\"\n", @@ -271,6 +282,8 @@ " max_novelty_attempts=3,\n", " # ensemble llm selection among candidates based on past performance\n", " llm_dynamic_selection=None, # e.g. 
\"ucb1\"\n", + " # set embedding model\n", + " embedding_model=embedding_model_name,\n", ")\n", "\n", "db_config = DatabaseConfig(\n", @@ -286,11 +299,13 @@ " enforce_island_separation=True,\n", " parent_selection_strategy=\"weighted\",\n", " parent_selection_lambda=10.0,\n", + " \n", ")\n", "\n", "job_config = LocalJobConfig(eval_program_path=\"evaluate.py\")\n", "\n", "print(\"llm_models:\", llm_models)\n", + "print(\"embedding_model:\", embedding_model_name)\n", "print(\"results_dir:\", evo_config.results_dir)" ] }, diff --git a/shinka/core/runner.py b/shinka/core/runner.py index 3c81874..c8c7c43 100644 --- a/shinka/core/runner.py +++ b/shinka/core/runner.py @@ -158,7 +158,12 @@ def __init__( # Initialize database and scheduler db_config.db_path = str(db_path) - self.db = ProgramDatabase(config=db_config) + embedding_model_to_use = ( + evo_config.embedding_model or "text-embedding-3-small" + ) + self.db = ProgramDatabase( + config=db_config, embedding_model=embedding_model_to_use + ) self.scheduler = JobScheduler( job_type=evo_config.job_type, config=job_config, # type: ignore diff --git a/shinka/database/dbase.py b/shinka/database/dbase.py index 69fdf54..c6a2b89 100644 --- a/shinka/database/dbase.py +++ b/shinka/database/dbase.py @@ -82,6 +82,9 @@ class DatabaseConfig: # Beam search parent selection parameters num_beams: int = 5 + # Embedding model name + embedding_model: str = "text-embedding-3-small" + def db_retry(max_retries=5, initial_delay=0.1, backoff_factor=2): """ @@ -248,12 +251,12 @@ class ProgramDatabase: populations, and an archive of elites. 
""" - def __init__(self, config: DatabaseConfig, read_only: bool = False): + def __init__(self, config: DatabaseConfig, read_only: bool = False, embedding_model: str = "text-embedding-3-small"): self.config = config self.conn: Optional[sqlite3.Connection] = None self.cursor: Optional[sqlite3.Cursor] = None self.read_only = read_only - self.embedding_client = EmbeddingClient() + self.embedding_client = EmbeddingClient(model_name=embedding_model) self.last_iteration: int = 0 self.best_program_id: Optional[str] = None diff --git a/shinka/llm/embedding.py b/shinka/llm/embedding.py index a5c6b07..1f2ad49 100644 --- a/shinka/llm/embedding.py +++ b/shinka/llm/embedding.py @@ -1,5 +1,6 @@ import os import openai +import google.generativeai as genai import pandas as pd from typing import Union, List, Optional, Tuple import numpy as np @@ -20,13 +21,23 @@ "azure-text-embedding-3-large", ] +GEMINI_EMBEDDING_MODELS = [ + "gemini-embedding-exp-03-07", + "gemini-embedding-001", +] + OPENAI_EMBEDDING_COSTS = { "text-embedding-3-small": 0.02 / M, "text-embedding-3-large": 0.13 / M, } +# Gemini embedding costs (approximate - check current pricing) +GEMINI_EMBEDDING_COSTS = { + "gemini-embedding-exp-03-07": 0.0 / M, # Experimental model, often free + "gemini-embedding-001": 0.0 / M, # Check current pricing +} -def get_client_model(model_name: str) -> tuple[openai.OpenAI, str]: +def get_client_model(model_name: str) -> tuple[Union[openai.OpenAI, str], str]: if model_name in OPENAI_EMBEDDING_MODELS: client = openai.OpenAI() model_to_use = model_name @@ -38,6 +49,14 @@ def get_client_model(model_name: str) -> tuple[openai.OpenAI, str]: api_version=os.getenv("AZURE_API_VERSION"), azure_endpoint=os.getenv("AZURE_API_ENDPOINT"), ) + elif model_name in GEMINI_EMBEDDING_MODELS: + # Configure Gemini API; accept either GOOGLE_API_KEY or GEMINI_API_KEY + api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") + if not api_key: + raise ValueError("GOOGLE_API_KEY or GEMINI_API_KEY environment variable not set for Gemini models") + genai.configure(api_key=api_key) + client = 
"gemini" # Use string identifier for Gemini + model_to_use = model_name else: raise ValueError(f"Invalid embedding model: {model_name}") @@ -52,9 +71,10 @@ def __init__( Initialize the EmbeddingClient. Args: - model (str): The OpenAI embedding model name to use. + model_name (str): The OpenAI, Azure, or Gemini embedding model name to use. """ self.client, self.model = get_client_model(model_name) + self.model_name = model_name self.verbose = verbose def get_embedding( @@ -76,6 +96,34 @@ def get_embedding( single_code = True else: single_code = False + # Handle Gemini models + if self.model_name in GEMINI_EMBEDDING_MODELS: + try: + embeddings = [] + total_tokens = 0 + + for text in code: + result = genai.embed_content( + model=f"models/{self.model}", + content=text, + task_type="retrieval_document" + ) + embeddings.append(result['embedding']) + total_tokens += len(text.split()) + + cost = total_tokens * GEMINI_EMBEDDING_COSTS.get(self.model, 0.0) + + if single_code: + return embeddings[0] if embeddings else [], cost + else: + return embeddings, cost + except Exception as e: + logger.error(f"Error getting Gemini embedding: {e}") + if single_code: + return [], 0.0 + else: + return [[]], 0.0 + # Handle OpenAI and Azure models (same interface) try: response = self.client.embeddings.create( model=self.model, input=code, encoding_format="float"