38 changes: 32 additions & 6 deletions runner/server/handler/completion.go
@@ -13,12 +13,33 @@ import (
"github.com/openai/openai-go"
"github.com/openai/openai-go/shared/constant"

"github.com/NexaAI/nexa-sdk/runner/internal/config"
"github.com/NexaAI/nexa-sdk/runner/internal/store"
"github.com/NexaAI/nexa-sdk/runner/internal/types"
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
"github.com/NexaAI/nexa-sdk/runner/server/service"
)

type BaseParams struct {
// stream: if false the response will be returned as a single response object, rather than a stream of objects
Stream bool `json:"stream" default:"false"`
// keep_alive: controls how long the model will stay loaded into memory following the request (default: 5m)
KeepAlive *int64 `json:"keep_alive" default:"300"`
}

// getKeepAliveValue extracts the keep_alive timeout from BaseParams, falling back to the configured default when the field is unset
func getKeepAliveValue(param BaseParams) int64 {
if param.KeepAlive != nil {
return *param.KeepAlive
}
return config.Get().KeepAlive
}

type CompletionRequest struct {
BaseParams
openai.CompletionNewParams
}

// @Router /completions [post]
// @Summary completion
// @Description Legacy completion endpoint for text generation. It is recommended to use the Chat Completions endpoint for new applications.
@@ -27,16 +48,18 @@
// @Produce json
// @Success 200 {object} openai.Completion
func Completions(c *gin.Context) {
param := openai.CompletionNewParams{}
param := CompletionRequest{}
if err := c.ShouldBindJSON(&param); err != nil {
c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()})
return
}
slog.Debug("param", "param", param)

p, err := service.KeepAliveGet[nexa_sdk.LLM](
string(param.Model),
types.ModelParam{NCtx: 4096},
c.GetHeader("Nexa-KeepCache") != "true",
getKeepAliveValue(param.BaseParams),
)
if err != nil {
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})
@@ -58,19 +81,18 @@ func Completions(c *gin.Context) {
}
}

type ChatCompletionNewParams openai.ChatCompletionNewParams

// ChatCompletionRequest defines the request body for the chat completions API.
// example: { "model": "nexaml/nexaml-models", "messages": [ { "role": "user", "content": "why is the sky blue?" } ] }
type ChatCompletionRequest struct {
Stream bool `json:"stream" default:"false"`
EnableThink bool `json:"enable_think" default:"true"`

ChatCompletionNewParams
BaseParams
openai.ChatCompletionNewParams
}

var toolCallRegex = regexp.MustCompile(`<tool_call>([\s\S]+)<\/tool_call>` + "|" + "```json([\\s\\S]+)```")

// @Router /chat/completions [post]
// @Summary Creates a model response for the given chat conversation.
// @Description This endpoint generates a model response for a given conversation, which can include text and images. It supports both single-turn and multi-turn conversations and can be used for various tasks like question answering, code generation, and function calling.
@@ -85,6 +107,8 @@ func ChatCompletions(c *gin.Context) {
return
}

slog.Debug("param", "param", param)

s := store.Get()
manifest, err := s.GetManifest(param.Model)
if err != nil {
@@ -109,6 +133,7 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {
string(param.Model),
types.ModelParam{NCtx: 4096},
c.GetHeader("Nexa-KeepCache") != "true",
getKeepAliveValue(param.BaseParams),
)
if errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"})
@@ -276,6 +301,7 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {
string(param.Model),
types.ModelParam{NCtx: 4096},
c.GetHeader("Nexa-KeepCache") != "true",
getKeepAliveValue(param.BaseParams),
)
if errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"})
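A note on why `BaseParams.KeepAlive` is a pointer rather than a plain `int64`: with a value type, an omitted `keep_alive` field and an explicit `"keep_alive": 0` would be indistinguishable after JSON unmarshalling, so the nil check is what lets the handler fall back to the server-wide default. A minimal, self-contained sketch of that behavior (the `defaultKeepAlive` constant is a stand-in for `config.Get().KeepAlive`):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// BaseParams mirrors the struct added in completion.go: a nil KeepAlive
// means the field was absent from the request body.
type BaseParams struct {
	Stream    bool   `json:"stream"`
	KeepAlive *int64 `json:"keep_alive"`
}

// defaultKeepAlive is a hypothetical stand-in for config.Get().KeepAlive.
const defaultKeepAlive int64 = 300

func getKeepAliveValue(p BaseParams) int64 {
	if p.KeepAlive != nil {
		return *p.KeepAlive
	}
	return defaultKeepAlive
}

func main() {
	var explicit, omitted BaseParams
	_ = json.Unmarshal([]byte(`{"keep_alive": 60}`), &explicit)
	_ = json.Unmarshal([]byte(`{"stream": true}`), &omitted)

	fmt.Println(getKeepAliveValue(explicit)) // 60: the client override wins
	fmt.Println(getKeepAliveValue(omitted))  // 300: falls back to the default
}
```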
1 change: 1 addition & 0 deletions runner/server/handler/embedder.go
@@ -27,6 +27,7 @@ func Embeddings(c *gin.Context) {
string(param.Model),
types.ModelParam{},
false,
300, // default 5 minutes for embedder
)
if err != nil {
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})
1 change: 1 addition & 0 deletions runner/server/handler/image.go
@@ -57,6 +57,7 @@ func ImageGenerations(c *gin.Context) {
param.Model,
types.ModelParam{},
c.GetHeader("Nexa-KeepCache") != "true",
300, // default 5 minutes for image generation
)
if err != nil {
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})
31 changes: 20 additions & 11 deletions runner/server/service/keepalive.go
@@ -14,8 +14,9 @@ import (

// KeepAliveGet retrieves a model from the keepalive cache or creates it if not found
// This avoids the overhead of repeatedly loading/unloading models from disk
func KeepAliveGet[T any](name string, param types.ModelParam, reset bool) (*T, error) {
t, err := keepAliveGet[T](name, param, reset)
// keepAlive specifies the timeout in seconds for this specific model instance
func KeepAliveGet[T any](name string, param types.ModelParam, reset bool, keepAlive int64) (*T, error) {
t, err := keepAliveGet[T](name, param, reset, keepAlive)
if err != nil {
return nil, err
}
@@ -34,9 +35,10 @@ type keepAliveService struct {

// modelKeepInfo holds metadata for a cached model instance
type modelKeepInfo struct {
model keepable
param types.ModelParam
lastTime time.Time
model keepable
param types.ModelParam
lastTime time.Time
keepAliveTimeout int64
}

// keepable interface defines objects that can be managed by the keepalive service
@@ -70,7 +72,12 @@ func (keepAlive *keepAliveService) start() {
case <-t.C:
keepAlive.Lock()
for name, model := range keepAlive.models {
if time.Since(model.lastTime).Milliseconds()/1000 > config.Get().KeepAlive {
// Use the model-specific keepAlive timeout, falling back to the global config value if unset
timeout := model.keepAliveTimeout
if timeout <= 0 {
timeout = config.Get().KeepAlive
}
if time.Since(model.lastTime).Milliseconds()/1000 > timeout {
model.model.Destroy()
delete(keepAlive.models, name)
}
@@ -83,7 +90,7 @@

// keepAliveGet retrieves a cached model or creates a new one if not found
// Ensures only one model is kept in memory at a time by clearing others
func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any, error) {
func keepAliveGet[T any](name string, param types.ModelParam, reset bool, keepAliveTimeout int64) (any, error) {
keepAlive.Lock()
defer keepAlive.Unlock()

@@ -102,6 +109,7 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
model.model.Reset()
}
model.lastTime = time.Now()
model.keepAliveTimeout = keepAliveTimeout
return model.model, nil
}

@@ -127,7 +135,7 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
break
}
}

var t keepable
var e error
switch reflect.TypeFor[T]() {
@@ -188,9 +196,10 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
return nil, e
}
model = &modelKeepInfo{
model: t,
param: param,
lastTime: time.Now(),
model: t,
param: param,
lastTime: time.Now(),
keepAliveTimeout: keepAliveTimeout,
}
keepAlive.models[name] = model

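The reaper's fallback rule above is the subtle part of this change: a cached entry whose per-model timeout is zero or negative is still governed by the global config value, so a missing or bogus `keep_alive` never pins a model forever. A self-contained sketch of that expiry predicate, with `300` as an assumed stand-in for `config.Get().KeepAlive`:

```go
package main

import (
	"fmt"
	"time"
)

// expired mirrors the reaper's check in keepalive.go: a non-positive
// per-model timeout falls back to the global default before comparing
// against the seconds elapsed since the model was last used.
func expired(lastTime time.Time, perModelTimeout, globalTimeout int64) bool {
	timeout := perModelTimeout
	if timeout <= 0 {
		timeout = globalTimeout
	}
	return time.Since(lastTime).Milliseconds()/1000 > timeout
}

func main() {
	tenMinutesAgo := time.Now().Add(-10 * time.Minute)

	fmt.Println(expired(tenMinutesAgo, 0, 300))    // true: falls back to the 5m global default
	fmt.Println(expired(tenMinutesAgo, 3600, 300)) // false: a 1h per-request override keeps the model loaded
}
```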