1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
"adjustText",
"markdown",
"aiofiles",
"groq",
]

[tool.setuptools]
9 changes: 9 additions & 0 deletions shinka/llm/client.py
@@ -11,6 +11,7 @@
    OPENAI_MODELS,
    DEEPSEEK_MODELS,
    GEMINI_MODELS,
    GROQ_MODELS,
)

env_path = Path(__file__).parent.parent.parent / ".env"
@@ -78,6 +79,14 @@ def get_client_llm(model_name: str, structured_output: bool = False) -> Tuple[An
                client,
                mode=instructor.Mode.GEMINI_JSON,
            )
    elif model_name in GROQ_MODELS.keys():
        import groq
        # Strip groq/ prefix for API call
        actual_model_name = model_name.replace("groq/", "", 1)
        client = groq.Groq(api_key=os.environ["GROQ_API_KEY"])
        if structured_output:
            raise NotImplementedError("Structured output not supported for Groq.")
Author (review comment on the line above): Groq does support structured output.
        return client, actual_model_name
    else:
        raise ValueError(f"Model {model_name} not supported.")

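For reference, a minimal usage sketch (not part of this diff) of how the new branch is expected to behave, assuming GROQ_API_KEY is set: the caller passes the groq/-prefixed key from GROQ_MODELS, and get_client_llm returns a groq.Groq client together with the prefix-stripped name that the Groq API expects.

```python
# Hypothetical usage sketch, not part of this PR.
# Assumes GROQ_API_KEY is set in the environment.
from shinka.llm.client import get_client_llm

client, actual_model_name = get_client_llm("groq/openai/gpt-oss-120b")
print(type(client).__name__)  # Groq
print(actual_model_name)      # "openai/gpt-oss-120b", only the first "groq/" prefix is stripped
```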
2 changes: 2 additions & 0 deletions shinka/llm/models/__init__.py
@@ -2,12 +2,14 @@
from .openai import query_openai
from .deepseek import query_deepseek
from .gemini import query_gemini
from .groq import query_groq
from .result import QueryResult

__all__ = [
"query_anthropic",
"query_openai",
"query_deepseek",
"query_gemini",
"query_groq",
"QueryResult",
]
78 changes: 78 additions & 0 deletions shinka/llm/models/groq.py
@@ -0,0 +1,78 @@
import backoff
import groq
from .pricing import GROQ_MODELS
from .result import QueryResult
import logging

logger = logging.getLogger(__name__)


def backoff_handler(details):
    exc = details.get("exception")
    if exc:
        logger.warning(
            f"Groq - Retry {details['tries']} due to error: {exc}. Waiting {details['wait']:0.1f}s..."
        )


@backoff.on_exception(
    backoff.expo,
    (
        groq.APIConnectionError,
        groq.APIStatusError,
        groq.RateLimitError,
        groq.APITimeoutError,
    ),
    max_tries=5,
    max_value=20,
    on_backoff=backoff_handler,
)
def query_groq(
    client,
    model,
    msg,
    system_msg,
    msg_history,
    output_model,
    model_posteriors=None,
    **kwargs,
) -> QueryResult:
    """Query Groq model."""
    if output_model is not None:
        raise NotImplementedError("Structured output not supported for Groq.")

    new_msg_history = msg_history + [{"role": "user", "content": msg}]
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_msg},
            *new_msg_history,
        ],
        **kwargs,
        n=1,
        stop=None,
    )

    content = response.choices[0].message.content
    new_msg_history.append({"role": "assistant", "content": content})

    # Add groq/ prefix back for pricing lookup (client.py strips one groq/ prefix)
    pricing_key = f"groq/{model}"
    input_cost = GROQ_MODELS[pricing_key]["input_price"] * response.usage.prompt_tokens
    output_cost = GROQ_MODELS[pricing_key]["output_price"] * response.usage.completion_tokens

    return QueryResult(
        content=content,
        msg=msg,
        system_msg=system_msg,
        new_msg_history=new_msg_history,
        model_name=pricing_key,  # Use the full groq/ prefixed name
        kwargs=kwargs,
        input_tokens=response.usage.prompt_tokens,
        output_tokens=response.usage.completion_tokens,
        cost=input_cost + output_cost,
        input_cost=input_cost,
        output_cost=output_cost,
        thought="",
        model_posteriors=model_posteriors,
    )
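For illustration, a hedged sketch of calling query_groq directly (the prompt text and sampling parameters below are assumptions, not part of this PR). The client and the prefix-stripped model name would normally come from get_client_llm; query_groq re-adds the groq/ prefix so the pricing lookup matches the GROQ_MODELS keys.

```python
# Hypothetical usage sketch, not part of this PR.
import os

import groq

from shinka.llm.models.groq import query_groq

client = groq.Groq(api_key=os.environ["GROQ_API_KEY"])
result = query_groq(
    client=client,
    model="openai/gpt-oss-120b",  # prefix-stripped name, as returned by get_client_llm
    msg="Summarize the repository layout in one sentence.",
    system_msg="You are a helpful assistant.",
    msg_history=[],
    output_model=None,  # a non-None value raises NotImplementedError here
    temperature=1.0,
    max_tokens=512,
)
print(result.model_name)  # "groq/openai/gpt-oss-120b"
print(result.cost)        # input_cost + output_cost from the GROQ_MODELS pricing table
```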
54 changes: 54 additions & 0 deletions shinka/llm/models/pricing.py
@@ -143,6 +143,53 @@
    },
}

GROQ_MODELS = {
    "groq/openai/gpt-oss-120b": {
        "input_price": 0.15 / M,
        "output_price": 0.75 / M,
    },
    "groq/openai/gpt-oss-20b": {
        "input_price": 0.10 / M,
        "output_price": 0.50 / M,
    },
    "groq/groq/compound": {
        "input_price": 0.15 / M,  # Uses gpt-oss-120b pricing
        "output_price": 0.75 / M,
    },
    "groq/groq/compound-mini": {
        "input_price": 0.15 / M,  # Uses gpt-oss-120b pricing
        "output_price": 0.75 / M,
    },
    "groq/moonshotai/kimi-k2-instruct-0905": {
        "input_price": 1.00 / M,
        "output_price": 3.00 / M,
    },
    "groq/llama-3.3-70b-versatile": {
        "input_price": 0.59 / M,
        "output_price": 0.79 / M,
    },
    "groq/qwen/qwen3-32b": {
        "input_price": 0.29 / M,
        "output_price": 0.59 / M,
    },
    "groq/llama-3.1-8b-instant": {
        "input_price": 0.05 / M,
        "output_price": 0.08 / M,
    },
    "groq/deepseek-r1-distill-llama-70b": {
        "input_price": 0.75 / M,
        "output_price": 0.99 / M,
    },
    "groq/meta-llama/llama-4-scout-17b-16e-instruct": {
        "input_price": 0.11 / M,
        "output_price": 0.34 / M,
    },
    "groq/meta-llama/llama-4-maverick-17b-128e-instruct": {
        "input_price": 0.20 / M,
        "output_price": 0.60 / M,
    },
}

BEDROCK_MODELS = {
    "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0": CLAUDE_MODELS[
        "claude-3-5-sonnet-20241022"
@@ -200,3 +247,10 @@
"bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
"bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
]

REASONING_GROQ_MODELS = [
"groq/openai/gpt-oss-120b",
"groq/openai/gpt-oss-20b",
"groq/groq/compound",
"groq/groq/compound-mini",
]
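The prices are stored per token (M is assumed to be 1_000_000, matching the other pricing tables in pricing.py, i.e. prices are quoted per million tokens), so a cost estimate is one dictionary lookup and two multiplications:

```python
# Hypothetical cost estimate, not part of this PR.
# Assumes M = 1_000_000, as used by the other pricing tables in pricing.py.
from shinka.llm.models.pricing import GROQ_MODELS

prices = GROQ_MODELS["groq/openai/gpt-oss-120b"]
prompt_tokens, completion_tokens = 10_000, 2_000
cost = (
    prices["input_price"] * prompt_tokens
    + prices["output_price"] * completion_tokens
)
print(f"${cost:.4f}")  # 0.0015 (input) + 0.0015 (output) = $0.0030
```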
18 changes: 13 additions & 5 deletions shinka/llm/query.py
@@ -7,19 +7,22 @@
    OPENAI_MODELS,
    DEEPSEEK_MODELS,
    GEMINI_MODELS,
    GROQ_MODELS,
    BEDROCK_MODELS,
    REASONING_OAI_MODELS,
    REASONING_CLAUDE_MODELS,
    REASONING_DEEPSEEK_MODELS,
    REASONING_GEMINI_MODELS,
    REASONING_AZURE_MODELS,
    REASONING_BEDROCK_MODELS,
    REASONING_GROQ_MODELS,
)
from .models import (
    query_anthropic,
    query_openai,
    query_deepseek,
    query_gemini,
    query_groq,
    QueryResult,
)
import logging
@@ -119,6 +122,7 @@ def sample_model_kwargs(
        + REASONING_GEMINI_MODELS
        + REASONING_AZURE_MODELS
        + REASONING_BEDROCK_MODELS
        + REASONING_GROQ_MODELS
    ):
        kwargs_dict["temperature"] = 1.0
    else:
@@ -180,6 +184,7 @@ def sample_model_kwargs(
        or kwargs_dict["model_name"] in REASONING_BEDROCK_MODELS
        or kwargs_dict["model_name"] in DEEPSEEK_MODELS
        or kwargs_dict["model_name"] in REASONING_DEEPSEEK_MODELS
        or kwargs_dict["model_name"] in GROQ_MODELS
    ):
        kwargs_dict["max_tokens"] = random.choice(max_tokens)
    else:
@@ -198,19 +203,22 @@ def query(
     **kwargs,
 ) -> QueryResult:
     """Query the LLM."""
+    original_model_name = model_name
     client, model_name = get_client_llm(
         model_name, structured_output=output_model is not None
     )
-    if model_name in CLAUDE_MODELS.keys() or "anthropic" in model_name:
+    if original_model_name in CLAUDE_MODELS.keys() or "anthropic" in original_model_name:
         query_fn = query_anthropic
-    elif model_name in OPENAI_MODELS.keys():
+    elif original_model_name in OPENAI_MODELS.keys():
         query_fn = query_openai
-    elif model_name in DEEPSEEK_MODELS.keys():
+    elif original_model_name in DEEPSEEK_MODELS.keys():
         query_fn = query_deepseek
-    elif model_name in GEMINI_MODELS.keys():
+    elif original_model_name in GEMINI_MODELS.keys():
         query_fn = query_gemini
+    elif original_model_name in GROQ_MODELS.keys():
+        query_fn = query_groq
     else:
-        raise ValueError(f"Model {model_name} not supported.")
+        raise ValueError(f"Model {original_model_name} not supported.")
     result = query_fn(
         client,
         model_name,
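A small sketch (hypothetical, not from this PR) of why the dispatch now keys off original_model_name: get_client_llm strips the groq/ prefix for the API call, so the returned model_name no longer matches the GROQ_MODELS keys, while the original name still does.

```python
# Hypothetical illustration of the dispatch fix, not part of this PR.
# Assumes GROQ_API_KEY is set in the environment.
from shinka.llm.client import get_client_llm
from shinka.llm.models import query_groq
from shinka.llm.models.pricing import GROQ_MODELS

original_model_name = "groq/openai/gpt-oss-120b"
client, model_name = get_client_llm(original_model_name)

assert original_model_name in GROQ_MODELS  # "groq/openai/gpt-oss-120b"
assert model_name not in GROQ_MODELS       # "openai/gpt-oss-120b" (prefix stripped)
query_fn = query_groq                      # so routing must use original_model_name
```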