
Commit 06fd00f

feat: add groq-hosted models
1 parent 0003552 commit 06fd00f

6 files changed, +157 -5 lines


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ dependencies = [
     "adjustText",
     "markdown",
     "aiofiles",
+    "groq",
 ]
 
 [tool.setuptools]

shinka/llm/client.py

Lines changed: 9 additions & 0 deletions
@@ -11,6 +11,7 @@
     OPENAI_MODELS,
     DEEPSEEK_MODELS,
     GEMINI_MODELS,
+    GROQ_MODELS,
 )
 
 env_path = Path(__file__).parent.parent.parent / ".env"
@@ -78,6 +79,14 @@ def get_client_llm(model_name: str, structured_output: bool = False) -> Tuple[An
             client,
             mode=instructor.Mode.GEMINI_JSON,
         )
+    elif model_name in GROQ_MODELS.keys():
+        import groq
+        # Strip groq/ prefix for API call
+        actual_model_name = model_name.replace("groq/", "", 1)
+        client = groq.Groq(api_key=os.environ["GROQ_API_KEY"])
+        if structured_output:
+            raise NotImplementedError("Structured output not supported for Groq.")
+        return client, actual_model_name
     else:
         raise ValueError(f"Model {model_name} not supported.")
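
A minimal usage sketch for the new branch (assuming GROQ_API_KEY is set in the environment; the model name below is one of the keys added to GROQ_MODELS in this commit):

import os
from shinka.llm.client import get_client_llm

os.environ.setdefault("GROQ_API_KEY", "gsk_placeholder")  # placeholder, not a real key
client, model = get_client_llm("groq/openai/gpt-oss-120b")
print(type(client).__name__)  # Groq
print(model)                  # "openai/gpt-oss-120b" -- groq/ prefix stripped for the API call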

shinka/llm/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,12 +2,14 @@
 from .openai import query_openai
 from .deepseek import query_deepseek
 from .gemini import query_gemini
+from .groq import query_groq
 from .result import QueryResult
 
 __all__ = [
     "query_anthropic",
     "query_openai",
     "query_deepseek",
     "query_gemini",
+    "query_groq",
     "QueryResult",
 ]

shinka/llm/models/groq.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+import backoff
+import groq
+from .pricing import GROQ_MODELS
+from .result import QueryResult
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def backoff_handler(details):
+    exc = details.get("exception")
+    if exc:
+        logger.warning(
+            f"Groq - Retry {details['tries']} due to error: {exc}. Waiting {details['wait']:0.1f}s..."
+        )
+
+
+@backoff.on_exception(
+    backoff.expo,
+    (
+        groq.APIConnectionError,
+        groq.APIStatusError,
+        groq.RateLimitError,
+        groq.APITimeoutError,
+    ),
+    max_tries=5,
+    max_value=20,
+    on_backoff=backoff_handler,
+)
+def query_groq(
+    client,
+    model,
+    msg,
+    system_msg,
+    msg_history,
+    output_model,
+    model_posteriors=None,
+    **kwargs,
+) -> QueryResult:
+    """Query Groq model."""
+    if output_model is not None:
+        raise NotImplementedError("Structured output not supported for Groq.")
+
+    new_msg_history = msg_history + [{"role": "user", "content": msg}]
+    response = client.chat.completions.create(
+        model=model,
+        messages=[
+            {"role": "system", "content": system_msg},
+            *new_msg_history,
+        ],
+        **kwargs,
+        n=1,
+        stop=None,
+    )
+
+    content = response.choices[0].message.content
+    new_msg_history.append({"role": "assistant", "content": content})
+
+    # Add groq/ prefix back for pricing lookup (client.py strips one groq/ prefix)
+    pricing_key = f"groq/{model}"
+    input_cost = GROQ_MODELS[pricing_key]["input_price"] * response.usage.prompt_tokens
+    output_cost = GROQ_MODELS[pricing_key]["output_price"] * response.usage.completion_tokens
+
+    return QueryResult(
+        content=content,
+        msg=msg,
+        system_msg=system_msg,
+        new_msg_history=new_msg_history,
+        model_name=pricing_key,  # Use the full groq/ prefixed name
+        kwargs=kwargs,
+        input_tokens=response.usage.prompt_tokens,
+        output_tokens=response.usage.completion_tokens,
+        cost=input_cost + output_cost,
+        input_cost=input_cost,
+        output_cost=output_cost,
+        thought="",
+        model_posteriors=model_posteriors,
+    )
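
For reference, a minimal call sketch for query_groq (assuming a valid GROQ_API_KEY; temperature and max_tokens below are illustrative kwargs forwarded to chat.completions.create, not defaults set by this commit):

import os
import groq
from shinka.llm.models.groq import query_groq

client = groq.Groq(api_key=os.environ["GROQ_API_KEY"])
result = query_groq(
    client,
    model="openai/gpt-oss-120b",  # prefix already stripped by get_client_llm
    msg="Summarize this commit in one sentence.",
    system_msg="You are a concise assistant.",
    msg_history=[],
    output_model=None,  # anything else raises NotImplementedError
    temperature=1.0,
    max_tokens=512,
)
print(result.content)
print(result.cost)  # computed from GROQ_MODELS pricing via the groq/-prefixed key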

shinka/llm/models/pricing.py

Lines changed: 54 additions & 0 deletions
@@ -143,6 +143,53 @@
     },
 }
 
+GROQ_MODELS = {
+    "groq/openai/gpt-oss-120b": {
+        "input_price": 0.15 / M,
+        "output_price": 0.75 / M,
+    },
+    "groq/openai/gpt-oss-20b": {
+        "input_price": 0.10 / M,
+        "output_price": 0.50 / M,
+    },
+    "groq/groq/compound": {
+        "input_price": 0.15 / M,  # Uses gpt-oss-120b pricing
+        "output_price": 0.75 / M,
+    },
+    "groq/groq/compound-mini": {
+        "input_price": 0.15 / M,  # Uses gpt-oss-120b pricing
+        "output_price": 0.75 / M,
+    },
+    "groq/moonshotai/kimi-k2-instruct-0905": {
+        "input_price": 1.00 / M,
+        "output_price": 3.00 / M,
+    },
+    "groq/llama-3.3-70b-versatile": {
+        "input_price": 0.59 / M,
+        "output_price": 0.79 / M,
+    },
+    "groq/qwen/qwen3-32b": {
+        "input_price": 0.29 / M,
+        "output_price": 0.59 / M,
+    },
+    "groq/llama-3.1-8b-instant": {
+        "input_price": 0.05 / M,
+        "output_price": 0.08 / M,
+    },
+    "groq/deepseek-r1-distill-llama-70b": {
+        "input_price": 0.75 / M,
+        "output_price": 0.99 / M,
+    },
+    "groq/meta-llama/llama-4-scout-17b-16e-instruct": {
+        "input_price": 0.11 / M,
+        "output_price": 0.34 / M,
+    },
+    "groq/meta-llama/llama-4-maverick-17b-128e-instruct": {
+        "input_price": 0.20 / M,
+        "output_price": 0.60 / M,
+    },
+}
+
 BEDROCK_MODELS = {
     "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0": CLAUDE_MODELS[
         "claude-3-5-sonnet-20241022"
@@ -200,3 +247,10 @@
     "bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
     "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
 ]
+
+REASONING_GROQ_MODELS = [
+    "groq/openai/gpt-oss-120b",
+    "groq/openai/gpt-oss-20b",
+    "groq/groq/compound",
+    "groq/groq/compound-mini",
+]
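
As a sanity check on the table, a back-of-the-envelope cost calculation (assuming M is the per-million-token constant defined earlier in pricing.py, i.e. M = 1_000_000):

M = 1_000_000
prompt_tokens, completion_tokens = 12_000, 3_000

# groq/openai/gpt-oss-120b: $0.15 per 1M input tokens, $0.75 per 1M output tokens
input_cost = (0.15 / M) * prompt_tokens       # $0.00180
output_cost = (0.75 / M) * completion_tokens  # $0.00225
print(f"total: ${input_cost + output_cost:.5f}")  # total: $0.00405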

shinka/llm/query.py

Lines changed: 13 additions & 5 deletions
@@ -7,19 +7,22 @@
     OPENAI_MODELS,
     DEEPSEEK_MODELS,
     GEMINI_MODELS,
+    GROQ_MODELS,
     BEDROCK_MODELS,
     REASONING_OAI_MODELS,
     REASONING_CLAUDE_MODELS,
     REASONING_DEEPSEEK_MODELS,
     REASONING_GEMINI_MODELS,
     REASONING_AZURE_MODELS,
     REASONING_BEDROCK_MODELS,
+    REASONING_GROQ_MODELS,
 )
 from .models import (
     query_anthropic,
     query_openai,
     query_deepseek,
     query_gemini,
+    query_groq,
     QueryResult,
 )
 import logging
@@ -119,6 +122,7 @@ def sample_model_kwargs(
         + REASONING_GEMINI_MODELS
         + REASONING_AZURE_MODELS
         + REASONING_BEDROCK_MODELS
+        + REASONING_GROQ_MODELS
     ):
         kwargs_dict["temperature"] = 1.0
     else:
@@ -180,6 +184,7 @@ def sample_model_kwargs(
         or kwargs_dict["model_name"] in REASONING_BEDROCK_MODELS
         or kwargs_dict["model_name"] in DEEPSEEK_MODELS
         or kwargs_dict["model_name"] in REASONING_DEEPSEEK_MODELS
+        or kwargs_dict["model_name"] in GROQ_MODELS
     ):
         kwargs_dict["max_tokens"] = random.choice(max_tokens)
     else:
@@ -198,19 +203,22 @@ def query(
     **kwargs,
 ) -> QueryResult:
     """Query the LLM."""
+    original_model_name = model_name
     client, model_name = get_client_llm(
         model_name, structured_output=output_model is not None
     )
-    if model_name in CLAUDE_MODELS.keys() or "anthropic" in model_name:
+    if original_model_name in CLAUDE_MODELS.keys() or "anthropic" in original_model_name:
         query_fn = query_anthropic
-    elif model_name in OPENAI_MODELS.keys():
+    elif original_model_name in OPENAI_MODELS.keys():
         query_fn = query_openai
-    elif model_name in DEEPSEEK_MODELS.keys():
+    elif original_model_name in DEEPSEEK_MODELS.keys():
         query_fn = query_deepseek
-    elif model_name in GEMINI_MODELS.keys():
+    elif original_model_name in GEMINI_MODELS.keys():
         query_fn = query_gemini
+    elif original_model_name in GROQ_MODELS.keys():
+        query_fn = query_groq
     else:
-        raise ValueError(f"Model {model_name} not supported.")
+        raise ValueError(f"Model {original_model_name} not supported.")
     result = query_fn(
         client,
         model_name,
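
The original_model_name bookkeeping is needed because get_client_llm now returns the Groq model name with the groq/ prefix stripped, while GROQ_MODELS and the dispatch chain are keyed by the prefixed name. A small sketch that mirrors that behavior (string handling only, no API key required; the replace call is what get_client_llm does internally):

from shinka.llm.models.pricing import GROQ_MODELS

original_model_name = "groq/llama-3.3-70b-versatile"
model_name = original_model_name.replace("groq/", "", 1)  # what get_client_llm returns

assert original_model_name in GROQ_MODELS       # dispatch picks query_fn = query_groq
assert model_name == "llama-3.3-70b-versatile"  # passed to client.chat.completions.create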
