Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ go get github.com/go-playground/validator/v10

#### 前端依赖
```bash
npm install @yokowu/modelkit-ui
npm install @ctzhian/modelkit
# 或
yarn add @yokowu/modelkit-ui
yarn add @ctzhian/modelkit
```

### 2. 实现接口
Expand Down
27 changes: 25 additions & 2 deletions domain/model.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,41 @@
package domain

import "github.com/chaitin/ModelKit/v2/consts"
import (
"github.com/chaitin/ModelKit/v2/consts"
"github.com/cloudwego/eino-ext/libs/acl/openai"
)

// ModelMetadata describes a single model entry: identity/provider info,
// the credentials and endpoint needed to call it, and optional advanced
// sampling parameters. Pointer fields are optional; nil means "use the
// provider default". Per-field notes flag providers that ignore a parameter.
type ModelMetadata struct {
// Basic parameters
ModelName string `json:"id"` // model identifier (serialized as "id")
Object string `json:"object"` // always "model"
Created int `json:"created"` // creation timestamp
Provider consts.ModelProvider `json:"provider"` // model provider
ModelType consts.ModelType `json:"model_type"` // model type

// API-call parameters
BaseURL string `json:"base_url"`
APIKey string `json:"api_key"`
APIHeader string `json:"api_header"`
APIVersion string `json:"api_version"` // for azure openai
// Advanced parameters
// Maximum number of tokens to generate; optional, defaults to the model's maximum. Not supported by Ollama.
MaxTokens *int `json:"max_tokens"`
// Sampling temperature; prefer setting either this or TopP, not both. Range 0-2; higher is more random. Optional, default 1.0.
Temperature *float32 `json:"temperature"`
// Nucleus-sampling diversity control; prefer setting either this or Temperature. Range 0-1; lower is more focused. Optional, default 1.0.
TopP *float32 `json:"top_p"`
// Sequences at which the API stops generating, e.g. []string{"\n", "User:"}. Optional.
Stop []string `json:"stop"`
// Presence penalty for repetition; range -2 to 2, positive values encourage new topics. Optional, default 0. Not supported by Gemini.
PresencePenalty *float32 `json:"presence_penalty"`
// Desired response format for structured output. Optional. Not supported by DeepSeek, Gemini, or Ollama.
ResponseFormat *openai.ChatCompletionResponseFormat `json:"response_format"`
// Seed for deterministic sampling / reproducible results. Optional. Not supported by DeepSeek or Gemini.
Seed *int `json:"seed"`
// Frequency penalty for repetition; range -2 to 2, positive values reduce repetition. Optional, default 0. Not supported by Gemini.
FrequencyPenalty *float32 `json:"frequency_penalty"`
// Biases the likelihood of specific tokens: map of token ID to bias (-100 to 100). Optional. Not supported by DeepSeek, Gemini, or Ollama.
LogitBias map[string]int `json:"logit_bias"`
}

var Models []ModelMetadata
Expand Down
98 changes: 91 additions & 7 deletions usecase/modelkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,46 @@ func (m *ModelKit) CheckModel(ctx context.Context, req *domain.CheckModelReq) (*
func (m *ModelKit) GetChatModel(ctx context.Context, model *domain.ModelMetadata) (model.BaseChatModel, error) {
// config chat model
modelProvider := model.Provider

// 使用高级参数中的温度值,如果没有设置则使用默认值0.0
var temperature float32 = 0.0
if model.Temperature != nil {
temperature = *model.Temperature
}

config := &openai.ChatModelConfig{
APIKey: model.APIKey,
BaseURL: model.BaseURL,
Model: string(model.ModelName),
Temperature: &temperature,
}

// 添加高级参数支持
if model.MaxTokens != nil {
config.MaxTokens = model.MaxTokens
}
if model.TopP != nil {
config.TopP = model.TopP
}
if len(model.Stop) > 0 {
config.Stop = model.Stop
}
if model.PresencePenalty != nil {
config.PresencePenalty = model.PresencePenalty
}
if model.FrequencyPenalty != nil {
config.FrequencyPenalty = model.FrequencyPenalty
}
if model.ResponseFormat != nil {
config.ResponseFormat = model.ResponseFormat
}
if model.Seed != nil {
config.Seed = model.Seed
}
if model.LogitBias != nil {
config.LogitBias = model.LogitBias
}

if modelProvider == consts.ModelProviderAzureOpenAI {
config.ByAzure = true
config.APIVersion = model.APIVersion
Expand All @@ -341,12 +374,32 @@ func (m *ModelKit) GetChatModel(ctx context.Context, model *domain.ModelMetadata

switch modelProvider {
case consts.ModelProviderDeepSeek:
chatModel, err := deepseek.NewChatModel(ctx, &deepseek.ChatModelConfig{
deepseekConfig := &deepseek.ChatModelConfig{
BaseURL: model.BaseURL,
APIKey: model.APIKey,
Model: model.ModelName,
Temperature: temperature,
})
}

// 添加 DeepSeek 支持的高级参数
if model.MaxTokens != nil {
deepseekConfig.MaxTokens = *model.MaxTokens
}
if model.TopP != nil {
deepseekConfig.TopP = *model.TopP
}
if len(model.Stop) > 0 {
deepseekConfig.Stop = model.Stop
}
if model.PresencePenalty != nil {
deepseekConfig.PresencePenalty = *model.PresencePenalty
}
if model.FrequencyPenalty != nil {
deepseekConfig.FrequencyPenalty = *model.FrequencyPenalty
}
// ResponseFormat, Seed, LogitBias 在 DeepSeek 配置中不支持,跳过

chatModel, err := deepseek.NewChatModel(ctx, deepseekConfig)
if err != nil {
return nil, err
}
Expand All @@ -359,14 +412,26 @@ func (m *ModelKit) GetChatModel(ctx context.Context, model *domain.ModelMetadata
return nil, err
}

chatModel, err := gemini.NewChatModel(ctx, &gemini.Config{
geminiConfig := &gemini.Config{
Client: client,
Model: model.ModelName,
ThinkingConfig: &genai.ThinkingConfig{
IncludeThoughts: true,
ThinkingBudget: nil,
},
})
}

// 添加 Gemini 支持的高级参数
if model.MaxTokens != nil {
geminiConfig.MaxTokens = model.MaxTokens
}
if model.Temperature != nil {
geminiConfig.Temperature = model.Temperature
}
if model.TopP != nil {
geminiConfig.TopP = model.TopP
}
chatModel, err := gemini.NewChatModel(ctx, geminiConfig)
if err != nil {
return nil, err
}
Expand All @@ -385,13 +450,32 @@ func (m *ModelKit) GetChatModel(ctx context.Context, model *domain.ModelMetadata
return nil, err
}

ollamaOptions := &api.Options{
Temperature: temperature,
}

// 添加 Ollama 支持的高级参数
if model.TopP != nil {
ollamaOptions.TopP = *model.TopP
}
if len(model.Stop) > 0 {
ollamaOptions.Stop = model.Stop
}
if model.PresencePenalty != nil {
ollamaOptions.PresencePenalty = *model.PresencePenalty
}
if model.FrequencyPenalty != nil {
ollamaOptions.FrequencyPenalty = *model.FrequencyPenalty
}
if model.Seed != nil {
ollamaOptions.Seed = *model.Seed
}

chatModel, err := ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
BaseURL: baseUrl,
Timeout: config.Timeout,
Model: config.Model,
Options: &api.Options{
Temperature: temperature,
},
Options: ollamaOptions,
})
if err != nil {
return nil, err
Expand Down
Loading