Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/internal/service/openai_codex_transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if v, ok := reqBody["model"].(string); ok {
model = v
}
normalizedModel := normalizeCodexModel(model)
normalizedModel := strings.TrimSpace(model)
if normalizedModel != "" {
if model != normalizedModel {
reqBody["model"] = normalizedModel
Expand Down
29 changes: 29 additions & 0 deletions backend/internal/service/openai_codex_transform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
"gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex-spark": "gpt-5.3-codex",
"gpt 5.3 codex spark": "gpt-5.3-codex",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
"gpt 5.3 codex": "gpt-5.3-codex",
Expand All @@ -256,6 +257,34 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
}
}

// TestApplyCodexOAuthTransform_PreservesBareSparkModel checks that an exact
// spark model name survives the OAuth transform unchanged (both in the request
// body and in the reported NormalizedModel) while the transform still writes
// store=false into the body.
func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) {
	body := map[string]any{
		"model": "gpt-5.3-codex-spark",
		"input": []any{},
	}

	res := applyCodexOAuthTransform(body, false, false)

	require.Equal(t, "gpt-5.3-codex-spark", body["model"])
	require.Equal(t, "gpt-5.3-codex-spark", res.NormalizedModel)

	// The transform must have injected a boolean store=false flag.
	storeFlag, isBool := body["store"].(bool)
	require.True(t, isBool)
	require.False(t, storeFlag)
}

// TestApplyCodexOAuthTransform_TrimmedModelWithoutPolicyRewrite checks that a
// model name with surrounding whitespace is trimmed by the OAuth transform —
// the trimmed value lands in the request body and NormalizedModel, and the
// transform reports Modified=true — without any further model rewriting.
func TestApplyCodexOAuthTransform_TrimmedModelWithoutPolicyRewrite(t *testing.T) {
	body := map[string]any{
		"model": " gpt-5.3-codex-spark ",
		"input": []any{},
	}

	res := applyCodexOAuthTransform(body, false, false)

	require.Equal(t, "gpt-5.3-codex-spark", body["model"])
	require.Equal(t, "gpt-5.3-codex-spark", res.NormalizedModel)
	require.True(t, res.Modified)
}

func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
// Codex CLI 场景:已有 instructions 时不修改

Expand Down
8 changes: 4 additions & 4 deletions backend/internal/service/openai_compat_prompt_cache_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
const compatPromptCacheKeyPrefix = "compat_cc_"

func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
switch normalizeCodexModel(strings.TrimSpace(model)) {
case "gpt-5.4", "gpt-5.3-codex":
switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) {
case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
return true
default:
return false
Expand All @@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod
return ""
}

normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel))
normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel))
if normalizedModel == "" {
normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model))
normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model))
}
if normalizedModel == "" {
normalizedModel = strings.TrimSpace(req.Model)
Expand Down
15 changes: 15 additions & 0 deletions backend/internal/service/openai_compat_prompt_cache_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) {
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex"))
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark"))
require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o"))
}

Expand Down Expand Up @@ -62,3 +63,17 @@ func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) {
k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4")
require.NotEqual(t, k1, k2, "different first user messages should yield different keys")
}

// TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily checks that the
// compat prompt-cache key derived for a spark-family model is non-empty and
// identical whether the mapped model is the bare name or a padded,
// provider-prefixed alias of the same family.
func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) {
	req := &apicompat.ChatCompletionsRequest{
		Model: "gpt-5.3-codex-spark",
		Messages: []apicompat.ChatMessage{
			{Role: "user", Content: mustRawJSON(t, `"Question A"`)},
		},
	}

	direct := deriveCompatPromptCacheKey(req, "gpt-5.3-codex-spark")
	aliased := deriveCompatPromptCacheKey(req, " openai/gpt-5.3-codex-spark ")

	require.NotEmpty(t, direct)
	require.Equal(t, direct, aliased, "resolved spark family should derive a stable compat cache key")
}
33 changes: 20 additions & 13 deletions backend/internal/service/openai_gateway_chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(

// 2. Resolve model mapping early so compat prompt_cache_key injection can
// derive a stable seed from the final upstream model family.
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel)

promptCacheKey = strings.TrimSpace(promptCacheKey)
compatPromptCacheInjected := false
if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(mappedModel) {
promptCacheKey = deriveCompatPromptCacheKey(&chatReq, mappedModel)
if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel)
compatPromptCacheInjected = promptCacheKey != ""
}

Expand All @@ -60,12 +61,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
if err != nil {
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
responsesReq.Model = mappedModel
responsesReq.Model = upstreamModel

logFields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel),
zap.Bool("stream", clientStream),
}
if compatPromptCacheInjected {
Expand All @@ -88,6 +90,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
codexResult := applyCodexOAuthTransform(reqBody, false, false)
if codexResult.NormalizedModel != "" {
upstreamModel = codexResult.NormalizedModel
}
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" {
Expand Down Expand Up @@ -180,9 +185,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
var result *OpenAIForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime)
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime)
} else {
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
}

// Propagate ServiceTier and ReasoningEffort to result for billing
Expand Down Expand Up @@ -224,7 +229,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
billingModel string,
upstreamModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
Expand Down Expand Up @@ -295,8 +301,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
UpstreamModel: mappedModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
Expand All @@ -308,7 +314,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
billingModel string,
upstreamModel string,
includeUsage bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
Expand Down Expand Up @@ -343,8 +350,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
UpstreamModel: mappedModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Expand Down
29 changes: 18 additions & 11 deletions backend/internal/service/openai_gateway_messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}

// 3. Model mapping
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
responsesReq.Model = mappedModel
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := resolveOpenAIUpstreamModel(billingModel)
responsesReq.Model = upstreamModel

logger.L().Debug("openai messages: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel),
zap.Bool("stream", isStream),
)

Expand All @@ -81,6 +83,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
codexResult := applyCodexOAuthTransform(reqBody, false, false)
if codexResult.NormalizedModel != "" {
upstreamModel = codexResult.NormalizedModel
}
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" {
Expand Down Expand Up @@ -181,10 +186,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
var result *OpenAIForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime)
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
} else {
// Client wants JSON: buffer the streaming response and assemble a JSON reply.
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
}

// Propagate ServiceTier and ReasoningEffort to result for billing
Expand Down Expand Up @@ -229,7 +234,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
billingModel string,
upstreamModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
Expand Down Expand Up @@ -302,8 +308,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
UpstreamModel: mappedModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
Expand All @@ -318,7 +324,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
billingModel string,
upstreamModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
Expand Down Expand Up @@ -351,8 +358,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
UpstreamModel: mappedModel,
BillingModel: billingModel,
UpstreamModel: upstreamModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Expand Down
40 changes: 20 additions & 20 deletions backend/internal/service/openai_gateway_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1814,29 +1814,29 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}

// 对所有请求执行模型映射(包含 Codex CLI)。
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
reqBody["model"] = mappedModel
billingModel := account.GetMappedModel(reqModel)
if billingModel != reqModel {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, billingModel, account.Name, isCodexCLI)
reqBody["model"] = billingModel
bodyModified = true
markPatchSet("model", mappedModel)
markPatchSet("model", billingModel)
}
upstreamModel := billingModel

// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if model, ok := reqBody["model"].(string); ok {
normalizedModel := normalizeCodexModel(model)
if normalizedModel != "" && normalizedModel != model {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, normalizedModel, account.Name, account.Type, isCodexCLI)
reqBody["model"] = normalizedModel
mappedModel = normalizedModel
upstreamModel = resolveOpenAIUpstreamModel(model)
if upstreamModel != "" && upstreamModel != model {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, upstreamModel, account.Name, account.Type, isCodexCLI)
reqBody["model"] = upstreamModel
bodyModified = true
markPatchSet("model", normalizedModel)
markPatchSet("model", upstreamModel)
}

// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
// 确保高版本模型向低版本模型映射不报错
if !SupportsVerbosity(normalizedModel) {
if !SupportsVerbosity(upstreamModel) {
if text, ok := reqBody["text"].(map[string]any); ok {
delete(text, "verbosity")
}
Expand All @@ -1860,7 +1860,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch()
}
if codexResult.NormalizedModel != "" {
mappedModel = codexResult.NormalizedModel
upstreamModel = codexResult.NormalizedModel
}
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
Expand Down Expand Up @@ -1977,7 +1977,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
"forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v",
account.ID,
account.Type,
mappedModel,
upstreamModel,
reqStream,
hasPreviousResponseID,
)
Expand Down Expand Up @@ -2066,7 +2066,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
isCodexCLI,
reqStream,
originalModel,
mappedModel,
upstreamModel,
startTime,
attempt,
wsLastFailureReason,
Expand Down Expand Up @@ -2167,7 +2167,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
firstTokenMs,
wsAttempts,
)
wsResult.UpstreamModel = mappedModel
wsResult.UpstreamModel = upstreamModel
return wsResult, nil
}
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
Expand Down Expand Up @@ -2272,14 +2272,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel)
if err != nil {
return nil, err
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
if err != nil {
return nil, err
}
Expand All @@ -2303,7 +2303,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
UpstreamModel: mappedModel,
UpstreamModel: upstreamModel,
ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
Expand Down
Loading
Loading