From 6e806d7763d36775ca352163b440e50b9e92a898 Mon Sep 17 00:00:00 2001 From: Mengsheng Wu Date: Fri, 26 Sep 2025 23:08:41 +0800 Subject: [PATCH] feat(completion): enhance completion and chat endpoints with keep-alive functionality --- runner/server/handler/completion.go | 38 ++++++++++++++++++++++++----- runner/server/handler/embedder.go | 1 + runner/server/handler/image.go | 1 + runner/server/service/keepalive.go | 31 ++++++++++++++--------- 4 files changed, 54 insertions(+), 17 deletions(-) diff --git a/runner/server/handler/completion.go b/runner/server/handler/completion.go index 1ccffe66..39d9528d 100644 --- a/runner/server/handler/completion.go +++ b/runner/server/handler/completion.go @@ -13,12 +13,33 @@ import ( "github.com/openai/openai-go" "github.com/openai/openai-go/shared/constant" + "github.com/NexaAI/nexa-sdk/runner/internal/config" "github.com/NexaAI/nexa-sdk/runner/internal/store" "github.com/NexaAI/nexa-sdk/runner/internal/types" nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk" "github.com/NexaAI/nexa-sdk/runner/server/service" ) +type BaseParams struct { + // stream: if false the response will be returned as a single response object, rather than a stream of objects + Stream bool `json:"stream" default:"false"` + // keep_alive: controls how long the model will stay loaded into memory following the request (default: 5m) + KeepAlive *int64 `json:"keep_alive" default:"300"` +} + +// getKeepAliveValue extracts the keepAlive value from BaseParams, using default if not set +func getKeepAliveValue(param BaseParams) int64 { + if param.KeepAlive != nil { + return *param.KeepAlive + } + return config.Get().KeepAlive +} + +type CompletionRequest struct { + BaseParams + openai.CompletionNewParams +} + // @Router /completions [post] // @Summary completion // @Description Legacy completion endpoint for text generation. It is recommended to use the Chat Completions endpoint for new applications. 
@@ -27,16 +48,18 @@ import ( // @Produce json // @Success 200 {object} openai.Completion func Completions(c *gin.Context) { - param := openai.CompletionNewParams{} + param := CompletionRequest{} if err := c.ShouldBindJSON(&param); err != nil { c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()}) return } + slog.Debug("param", "param", param) p, err := service.KeepAliveGet[nexa_sdk.LLM]( string(param.Model), types.ModelParam{NCtx: 4096}, c.GetHeader("Nexa-KeepCache") != "true", + getKeepAliveValue(param.BaseParams), ) if err != nil { c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()}) return } @@ -58,19 +81,18 @@ func Completions(c *gin.Context) { } } -type ChatCompletionNewParams openai.ChatCompletionNewParams - // ChatCompletionRequest defines the request body for the chat completions API. // example: { "model": "nexaml/nexaml-models", "messages": [ { "role": "user", "content": "why is the sky blue?" } ] } type ChatCompletionRequest struct { - Stream bool `json:"stream" default:"false"` EnableThink bool `json:"enable_think" default:"true"` - - ChatCompletionNewParams + BaseParams + openai.ChatCompletionNewParams } var toolCallRegex = regexp.MustCompile(`<tool_call>([\s\S]+)<\/tool_call>` + "|" + "```json([\\s\\S]+)```") + + // @Router /chat/completions [post] // @Summary Creates a model response for the given chat conversation. // @Description This endpoint generates a model response for a given conversation, which can include text and images. It supports both single-turn and multi-turn conversations and can be used for various tasks like question answering, code generation, and function calling. 
@@ -85,6 +107,8 @@ func ChatCompletions(c *gin.Context) { return } + slog.Debug("param", "param", param) + s := store.Get() manifest, err := s.GetManifest(param.Model) if err != nil { @@ -109,6 +133,7 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) { string(param.Model), types.ModelParam{NCtx: 4096}, c.GetHeader("Nexa-KeepCache") != "true", + getKeepAliveValue(param.BaseParams), ) if errors.Is(err, os.ErrNotExist) { c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"}) @@ -276,6 +301,7 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) { string(param.Model), types.ModelParam{NCtx: 4096}, c.GetHeader("Nexa-KeepCache") != "true", + getKeepAliveValue(param.BaseParams), ) if errors.Is(err, os.ErrNotExist) { c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"}) diff --git a/runner/server/handler/embedder.go b/runner/server/handler/embedder.go index 89f43f46..dc8157e4 100644 --- a/runner/server/handler/embedder.go +++ b/runner/server/handler/embedder.go @@ -27,6 +27,7 @@ func Embeddings(c *gin.Context) { string(param.Model), types.ModelParam{}, false, + 300, // default 5 minutes for embedder ) if err != nil { c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()}) diff --git a/runner/server/handler/image.go b/runner/server/handler/image.go index c7564c38..c868ea03 100644 --- a/runner/server/handler/image.go +++ b/runner/server/handler/image.go @@ -57,6 +57,7 @@ func ImageGenerations(c *gin.Context) { param.Model, types.ModelParam{}, c.GetHeader("Nexa-KeepCache") != "true", + 300, // default 5 minutes for image generation ) if err != nil { c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()}) diff --git a/runner/server/service/keepalive.go b/runner/server/service/keepalive.go index b7915c2d..eebeb803 100644 --- a/runner/server/service/keepalive.go +++ b/runner/server/service/keepalive.go @@ -14,8 +14,9 @@ import ( // KeepAliveGet retrieves a model 
from the keepalive cache or creates it if not found // This avoids the overhead of repeatedly loading/unloading models from disk -func KeepAliveGet[T any](name string, param types.ModelParam, reset bool) (*T, error) { - t, err := keepAliveGet[T](name, param, reset) +// keepAlive specifies the timeout in seconds for this specific model instance +func KeepAliveGet[T any](name string, param types.ModelParam, reset bool, keepAlive int64) (*T, error) { + t, err := keepAliveGet[T](name, param, reset, keepAlive) if err != nil { return nil, err } @@ -34,9 +35,10 @@ type keepAliveService struct { // modelKeepInfo holds metadata for a cached model instance type modelKeepInfo struct { - model keepable - param types.ModelParam - lastTime time.Time + model keepable + param types.ModelParam + lastTime time.Time + keepAliveTimeout int64 } // keepable interface defines objects that can be managed by the keepalive service @@ -70,7 +72,12 @@ func (keepAlive *keepAliveService) start() { case <-t.C: keepAlive.Lock() for name, model := range keepAlive.models { - if time.Since(model.lastTime).Milliseconds()/1000 > config.Get().KeepAlive { + // Use the model-specific keepAlive timeout, fallback to global config if not set + timeout := model.keepAliveTimeout + if timeout <= 0 { + timeout = config.Get().KeepAlive + } + if time.Since(model.lastTime).Milliseconds()/1000 > timeout { model.model.Destroy() delete(keepAlive.models, name) } @@ -83,7 +90,7 @@ func (keepAlive *keepAliveService) start() { // keepAliveGet retrieves a cached model or creates a new one if not found // Ensures only one model is kept in memory at a time by clearing others -func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any, error) { +func keepAliveGet[T any](name string, param types.ModelParam, reset bool, keepAliveTimeout int64) (any, error) { keepAlive.Lock() defer keepAlive.Unlock() @@ -102,6 +109,7 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any, 
model.model.Reset() } model.lastTime = time.Now() + model.keepAliveTimeout = keepAliveTimeout return model.model, nil } @@ -127,7 +135,7 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any, break } } - + var t keepable var e error switch reflect.TypeFor[T]() { @@ -188,9 +196,10 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any, return nil, e } model = &modelKeepInfo{ - model: t, - param: param, - lastTime: time.Now(), + model: t, + param: param, + lastTime: time.Now(), + keepAliveTimeout: keepAliveTimeout, } keepAlive.models[name] = model