Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/shinka_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@
"if not llm_models:\n",
" llm_models = [\"gpt-5-mini\"] # fallback if no keys detected\n",
"\n",
"# pick embedding model based on available keys\n",
"embedding_model_name = \"\"\n",
"if os.getenv(\"GEMINI_API_KEY\"):\n",
" embedding_model_name = \"gemini-embedding-001\"\n",
"elif os.getenv(\"OPENAI_API_KEY\"):\n",
" embedding_model_name = \"text-embedding-3-small\"\n",
"else:\n",
" embedding_model_name = \"text-embedding-3-small\"\n",
"print(f\"✅ Embedding model selected: {embedding_model_name}\")\n",
"\n",
"\n",
"# unique experiment directory\n",
"timestamp = dt.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
"run_tag = f\"{timestamp}_weighted_fast\"\n",
Expand Down Expand Up @@ -271,6 +282,8 @@
" max_novelty_attempts=3,\n",
" # ensemble llm selection among candidates based on past performance\n",
" llm_dynamic_selection=None, # e.g. \"ucb1\"\n",
" # set embedding model\n",
" embedding_model=embedding_model_name,\n",
")\n",
"\n",
"db_config = DatabaseConfig(\n",
Expand All @@ -286,11 +299,13 @@
" enforce_island_separation=True,\n",
" parent_selection_strategy=\"weighted\",\n",
" parent_selection_lambda=10.0,\n",
" \n",
")\n",
"\n",
"job_config = LocalJobConfig(eval_program_path=\"evaluate.py\")\n",
"\n",
"print(\"llm_models:\", llm_models)\n",
"print(\"embedding_model:\", embedding_model_name)\n",
"print(\"results_dir:\", evo_config.results_dir)"
]
},
Expand Down
7 changes: 6 additions & 1 deletion shinka/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,12 @@ def __init__(

# Initialize database and scheduler
db_config.db_path = str(db_path)
self.db = ProgramDatabase(config=db_config)
embedding_model_to_use = (
evo_config.embedding_model or "text-embedding-3-small"
)
self.db = ProgramDatabase(
config=db_config, embedding_model=embedding_model_to_use
)
self.scheduler = JobScheduler(
job_type=evo_config.job_type,
config=job_config, # type: ignore
Expand Down
7 changes: 5 additions & 2 deletions shinka/database/dbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class DatabaseConfig:
# Beam search parent selection parameters
num_beams: int = 5

# Embedding model name
embedding_model: str = "text-embedding-3-small"


def db_retry(max_retries=5, initial_delay=0.1, backoff_factor=2):
"""
Expand Down Expand Up @@ -248,12 +251,12 @@ class ProgramDatabase:
populations, and an archive of elites.
"""

def __init__(self, config: DatabaseConfig, read_only: bool = False):
def __init__(self, config: DatabaseConfig,embedding_model: str = "text-embedding-3-small", read_only: bool = False):
self.config = config
self.conn: Optional[sqlite3.Connection] = None
self.cursor: Optional[sqlite3.Cursor] = None
self.read_only = read_only
self.embedding_client = EmbeddingClient()
self.embedding_client = EmbeddingClient(model_name=embedding_model)

self.last_iteration: int = 0
self.best_program_id: Optional[str] = None
Expand Down
52 changes: 50 additions & 2 deletions shinka/llm/embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import openai
import google.generativeai as genai
import pandas as pd
from typing import Union, List, Optional, Tuple
import numpy as np
Expand All @@ -20,13 +21,23 @@
"azure-text-embedding-3-large",
]

GEMINI_EMBEDDING_MODELS = [
"gemini-embedding-exp-03-07",
"gemini-embedding-001",
]

OPENAI_EMBEDDING_COSTS = {
"text-embedding-3-small": 0.02 / M,
"text-embedding-3-large": 0.13 / M,
}

# Gemini embedding costs (approximate - check current pricing)
GEMINI_EMBEDDING_COSTS = {
"gemini-embedding-exp-03-07": 0.0 / M, # Experimental model, often free
"gemini-embedding-001": 0.0 / M, # Check current pricing
}

def get_client_model(model_name: str) -> tuple[openai.OpenAI, str]:
def get_client_model(model_name: str) -> tuple[Union[openai.OpenAI, str], str]:
if model_name in OPENAI_EMBEDDING_MODELS:
client = openai.OpenAI()
model_to_use = model_name
Expand All @@ -38,6 +49,14 @@ def get_client_model(model_name: str) -> tuple[openai.OpenAI, str]:
api_version=os.getenv("AZURE_API_VERSION"),
azure_endpoint=os.getenv("AZURE_API_ENDPOINT"),
)
elif model_name in GEMINI_EMBEDDING_MODELS:
# Configure Gemini API
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
raise ValueError("GOOGLE_API_KEY environment variable not set for Gemini models")
genai.configure(api_key=api_key)
client = "gemini" # Use string identifier for Gemini
model_to_use = model_name
else:
raise ValueError(f"Invalid embedding model: {model_name}")

Expand All @@ -52,9 +71,10 @@ def __init__(
Initialize the EmbeddingClient.

Args:
model (str): The OpenAI embedding model name to use.
model (str): The OpenAI, Azure, or Gemini embedding model name to use.
"""
self.client, self.model = get_client_model(model_name)
self.model_name = model_name
self.verbose = verbose

def get_embedding(
Expand All @@ -76,6 +96,34 @@ def get_embedding(
single_code = True
else:
single_code = False
# Handle Gemini models
if self.model_name in GEMINI_EMBEDDING_MODELS:
try:
embeddings = []
total_tokens = 0

for text in code:
result = genai.embed_content(
model=f"models/{self.model}",
content=text,
task_type="retrieval_document"
)
embeddings.append(result['embedding'])
total_tokens += len(text.split())

cost = total_tokens * GEMINI_EMBEDDING_COSTS.get(self.model, 0.0)

if single_code:
return embeddings[0] if embeddings else [], cost
else:
return embeddings, cost
except Exception as e:
logger.error(f"Error getting Gemini embedding: {e}")
if single_code:
return [], 0.0
else:
return [[]], 0.0
# Handle OpenAI and Azure models (same interface)
try:
response = self.client.embeddings.create(
model=self.model, input=code, encoding_format="float"
Expand Down