diff --git a/internal/server/api.go b/internal/server/api.go index 0396d1a58a7..fdb3480e351 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -15,20 +15,11 @@ package server import ( - "encoding/json" - "errors" - "fmt" "net/http" - "strings" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" - "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/util" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/metric" ) // apiRouter creates a router that represents the routes under /api @@ -39,6 +30,12 @@ func apiRouter(s *Server) (chi.Router, error) { r.Use(middleware.StripSlashes) r.Use(render.SetContentType(render.ContentTypeJSON)) + r.Get("/authservice", func(w http.ResponseWriter, r *http.Request) { authServiceListHandler(s, w, r) }) + r.Get("/authservice/{authServiceName}", func(w http.ResponseWriter, r *http.Request) { authServiceGetHandler(s, w, r) }) + + r.Get("/source", func(w http.ResponseWriter, r *http.Request) { sourceListHandler(s, w, r) }) + r.Get("/source/{sourceName}", func(w http.ResponseWriter, r *http.Request) { sourceGetHandler(s, w, r) }) + r.Get("/toolset", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) }) r.Get("/toolset/{toolsetName}", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) }) @@ -49,292 +46,3 @@ func apiRouter(s *Server) (chi.Router, error) { return r, nil } - -// toolsetHandler handles the request for information about a Toolset. -func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) { - ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/toolset/get") - r = r.WithContext(ctx) - - toolsetName := chi.URLParam(r, "toolsetName") - s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) - span.SetAttributes(attribute.String("toolset_name", toolsetName)) - var err error - defer func() { - if err != nil { - span.SetStatus(codes.Error, err.Error()) - } - span.End() - - status := "success" - if err != nil { - status = "error" - } - s.instrumentation.ToolsetGet.Add( - r.Context(), - 1, - metric.WithAttributes(attribute.String("toolbox.name", toolsetName)), - metric.WithAttributes(attribute.String("toolbox.operation.status", status)), - ) - }() - - toolset, ok := s.ResourceMgr.GetToolset(toolsetName) - if !ok { - err = fmt.Errorf("toolset %q does not exist", toolsetName) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) - return - } - render.JSON(w, r, toolset.Manifest) -} - -// toolGetHandler handles requests for a single Tool. -func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) { - ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/get") - r = r.WithContext(ctx) - - toolName := chi.URLParam(r, "toolName") - s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) - span.SetAttributes(attribute.String("tool_name", toolName)) - var err error - defer func() { - if err != nil { - span.SetStatus(codes.Error, err.Error()) - } - span.End() - - status := "success" - if err != nil { - status = "error" - } - s.instrumentation.ToolGet.Add( - r.Context(), - 1, - metric.WithAttributes(attribute.String("toolbox.name", toolName)), - metric.WithAttributes(attribute.String("toolbox.operation.status", status)), - ) - }() - tool, ok := s.ResourceMgr.GetTool(toolName) - if !ok { - err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) - return - } - // TODO: this can be optimized later with some caching - m := tools.ToolsetManifest{ - ServerVersion: s.version, - ToolsManifest: map[string]tools.Manifest{ - toolName: tool.Manifest(), - }, - } - - render.JSON(w, r, m) -} - -// toolInvokeHandler handles the API request to invoke a specific Tool. -func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { - ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/invoke") - r = r.WithContext(ctx) - ctx = util.WithLogger(r.Context(), s.logger) - - toolName := chi.URLParam(r, "toolName") - s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) - span.SetAttributes(attribute.String("tool_name", toolName)) - var err error - defer func() { - if err != nil { - span.SetStatus(codes.Error, err.Error()) - } - span.End() - - status := "success" - if err != nil { - status = "error" - } - s.instrumentation.ToolInvoke.Add( - r.Context(), - 1, - metric.WithAttributes(attribute.String("toolbox.name", toolName)), - metric.WithAttributes(attribute.String("toolbox.operation.status", status)), - ) - }() - - tool, ok := s.ResourceMgr.GetTool(toolName) - if !ok { - err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) - return - } - - // Extract OAuth access token from the "Authorization" header (currently for - // BigQuery end-user credentials usage only) - accessToken := tools.AccessToken(r.Header.Get("Authorization")) - - // Check if this specific tool requires the standard authorization header - clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr) - if err != nil { - errMsg := fmt.Errorf("error during invocation: %w", err) - s.logger.DebugContext(ctx, errMsg.Error()) - _ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound)) - return - } - if clientAuth { - if accessToken == "" { - err = fmt.Errorf("tool requires client authorization but access token is missing from the request header") - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) - return - } - } - - // Tool authentication - // claimsFromAuth maps the name of the authservice to the claims retrieved from it. - claimsFromAuth := make(map[string]map[string]any) - for _, aS := range s.ResourceMgr.GetAuthServiceMap() { - claims, err := aS.GetClaimsFromHeader(ctx, r.Header) - if err != nil { - s.logger.DebugContext(ctx, err.Error()) - continue - } - if claims == nil { - // authService not present in header - continue - } - claimsFromAuth[aS.GetName()] = claims - } - - // Tool authorization check - verifiedAuthServices := make([]string, len(claimsFromAuth)) - i := 0 - for k := range claimsFromAuth { - verifiedAuthServices[i] = k - i++ - } - - // Check if any of the specified auth services is verified - isAuthorized := tool.Authorized(verifiedAuthServices) - if !isAuthorized { - err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers") - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) - return - } - s.logger.DebugContext(ctx, "tool invocation authorized") - - var data map[string]any - if err = util.DecodeJSON(r.Body, &data); err != nil { - render.Status(r, http.StatusBadRequest) - err = fmt.Errorf("request body was invalid JSON: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) - return - } - - params, err := tool.ParseParams(data, claimsFromAuth) - if err != nil { - // If auth error, return 401 - if errors.Is(err, util.ErrUnauthorized) { - s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err)) - _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) - return - } - err = fmt.Errorf("provided parameters were invalid: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) - return - } - s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) - - params, err = tool.EmbedParams(ctx, params, s.ResourceMgr.GetEmbeddingModelMap()) - if err != nil { - err = fmt.Errorf("error embedding parameters: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) - return - } - - res, err := tool.Invoke(ctx, s.ResourceMgr, params, accessToken) - - // Determine what error to return to the users. - if err != nil { - errStr := err.Error() - var statusCode int - - // Upstream API auth error propagation - switch { - case strings.Contains(errStr, "Error 401"): - statusCode = http.StatusUnauthorized - case strings.Contains(errStr, "Error 403"): - statusCode = http.StatusForbidden - } - - if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { - if clientAuth { - // Propagate the original 401/403 error. - s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err)) - _ = render.Render(w, r, newErrResponse(err, statusCode)) - return - } - // ADC lacking permission or credentials configuration error. - internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err) - s.logger.ErrorContext(ctx, internalErr.Error()) - _ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError)) - return - } - err = fmt.Errorf("error while invoking tool: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) - return - } - - resMarshal, err := json.Marshal(res) - if err != nil { - err = fmt.Errorf("unable to marshal result: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError)) - return - } - - _ = render.Render(w, r, &resultResponse{Result: string(resMarshal)}) -} - -var _ render.Renderer = &resultResponse{} // Renderer interface for managing response payloads. - -// resultResponse is the response sent back when the tool was invocated successfully. -type resultResponse struct { - Result string `json:"result"` // result of tool invocation -} - -// Render renders a single payload and respond to the client request. -func (rr resultResponse) Render(w http.ResponseWriter, r *http.Request) error { - render.Status(r, http.StatusOK) - return nil -} - -var _ render.Renderer = &errResponse{} // Renderer interface for managing response payloads. - -// newErrResponse is a helper function initializing an ErrResponse -func newErrResponse(err error, code int) *errResponse { - return &errResponse{ - Err: err, - HTTPStatusCode: code, - - StatusText: http.StatusText(code), - ErrorText: err.Error(), - } -} - -// errResponse is the response sent back when an error has been encountered. -type errResponse struct { - Err error `json:"-"` // low-level runtime error - HTTPStatusCode int `json:"-"` // http response status code - - StatusText string `json:"status"` // user-level status message - ErrorText string `json:"error,omitempty"` // application-level error message, for debugging -} - -func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { - render.Status(r, e.HTTPStatusCode) - return nil -} diff --git a/internal/server/api_authservices.go b/internal/server/api_authservices.go new file mode 100644 index 00000000000..be4bdf5a436 --- /dev/null +++ b/internal/server/api_authservices.go @@ -0,0 +1,147 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "net/http" + "slices" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "github.com/googleapis/genai-toolbox/internal/tools" +) + +type AuthServiceInfo struct { + Name string `json:"name"` + Kind string `json:"kind"` + HeaderName string `json:"headerName"` + Tools []string `json:"tools"` +} + +type AuthServiceListResponse struct { + AuthServices map[string]AuthServiceInfo `json:"authServices"` +} + +// authServiceListHandler handles requests for listing all auth services. +func authServiceListHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/authservice/list") + r = r.WithContext(ctx) + defer span.End() + + authServicesMap := s.ResourceMgr.GetAuthServiceMap() + usageByAuthService := toolsForAuthServices(s.ResourceMgr.GetToolsMap()) + resp := AuthServiceListResponse{ + AuthServices: make(map[string]AuthServiceInfo, len(authServicesMap)), + } + for name, authService := range authServicesMap { + resp.AuthServices[name] = AuthServiceInfo{ + Name: authService.GetName(), + Kind: authService.AuthServiceKind(), + HeaderName: authService.GetName(), + Tools: usageByAuthService[name], + } + } + render.JSON(w, r, resp) +} + +// authServiceGetHandler handles requests for a single auth service. +func authServiceGetHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/authservice/get") + r = r.WithContext(ctx) + defer span.End() + + authServiceName := chi.URLParam(r, "authServiceName") + authService, ok := s.ResourceMgr.GetAuthService(authServiceName) + if !ok { + err := fmt.Errorf("auth service %q does not exist", authServiceName) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + toolsMap := s.ResourceMgr.GetToolsMap() + resp := AuthServiceListResponse{ + AuthServices: map[string]AuthServiceInfo{ + authServiceName: { + Name: authService.GetName(), + Kind: authService.AuthServiceKind(), + HeaderName: authService.GetName(), + Tools: toolsForAuthService(toolsMap, authServiceName), + }, + }, + } + render.JSON(w, r, resp) +} + +func toolsForAuthServices(toolsMap map[string]tools.Tool) map[string][]string { + usage := make(map[string]map[string]bool) + + for toolName, tool := range toolsMap { + manifest := tool.Manifest() + for _, authName := range manifest.AuthRequired { + addAuthServiceUsage(usage, authName, toolName) + } + for _, param := range manifest.Parameters { + for _, authName := range param.AuthServices { + addAuthServiceUsage(usage, authName, toolName) + } + } + } + + out := make(map[string][]string, len(usage)) + for authName, toolsSet := range usage { + toolsList := make([]string, 0, len(toolsSet)) + for toolName := range toolsSet { + toolsList = append(toolsList, toolName) + } + slices.Sort(toolsList) + out[authName] = toolsList + } + return out +} + +func addAuthServiceUsage(usage map[string]map[string]bool, authName, toolName string) { + if authName == "" { + return + } + if usage[authName] == nil { + usage[authName] = make(map[string]bool) + } + usage[authName][toolName] = true +} + +func toolsForAuthService(toolsMap map[string]tools.Tool, authServiceName string) []string { + toolsSet := make(map[string]bool, len(toolsMap)) + for toolName, tool := range toolsMap { + manifest := tool.Manifest() + if slices.Contains(manifest.AuthRequired, authServiceName) { + toolsSet[toolName] = true + continue + } + for _, param := range manifest.Parameters { + if slices.Contains(param.AuthServices, authServiceName) { + toolsSet[toolName] = true + break + } + } + } + + toolsList := make([]string, 0, len(toolsSet)) + for toolName := range toolsSet { + toolsList = append(toolsList, toolName) + } + slices.Sort(toolsList) + return toolsList +} \ No newline at end of file diff --git a/internal/server/api_common.go b/internal/server/api_common.go new file mode 100644 index 00000000000..dcb44dde951 --- /dev/null +++ b/internal/server/api_common.go @@ -0,0 +1,48 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "net/http" + + "github.com/go-chi/render" +) + +var _ render.Renderer = &errResponse{} // Renderer interface for managing response payloads. + +// newErrResponse is a helper function initializing an ErrResponse +func newErrResponse(err error, code int) *errResponse { + return &errResponse{ + Err: err, + HTTPStatusCode: code, + + StatusText: http.StatusText(code), + ErrorText: err.Error(), + } +} + +// errResponse is the response sent back when an error has been encountered. +type errResponse struct { + Err error `json:"-"` // low-level runtime error + HTTPStatusCode int `json:"-"` // http response status code + + StatusText string `json:"status"` // user-level status message + ErrorText string `json:"error,omitempty"` // application-level error message, for debugging +} + +func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { + render.Status(r, e.HTTPStatusCode) + return nil +} diff --git a/internal/server/api_sources.go b/internal/server/api_sources.go new file mode 100644 index 00000000000..c9214333c74 --- /dev/null +++ b/internal/server/api_sources.go @@ -0,0 +1,132 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + yaml "github.com/goccy/go-yaml" +) + +type SourceInfo struct { + Name string `json:"name"` + Kind string `json:"kind"` + Config map[string]any `json:"config,omitempty"` +} + +type SourceListResponse struct { + Sources map[string]SourceInfo `json:"sources"` +} + +// sourceListHandler handles requests for listing all sources. +func sourceListHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/source/list") + r = r.WithContext(ctx) + defer span.End() + + sourcesMap := s.ResourceMgr.GetSourcesMap() + resp := SourceListResponse{ + Sources: make(map[string]SourceInfo, len(sourcesMap)), + } + for name, source := range sourcesMap { + resp.Sources[name] = SourceInfo{ + Name: name, + Kind: source.SourceKind(), + } + } + render.JSON(w, r, resp) +} + +// sourceGetHandler handles requests for a single source. +func sourceGetHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/source/get") + r = r.WithContext(ctx) + defer span.End() + + sourceName := chi.URLParam(r, "sourceName") + source, ok := s.ResourceMgr.GetSource(sourceName) + if !ok { + err := fmt.Errorf("source %q does not exist", sourceName) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + configMap, err := sourceConfigToMap(source.ToConfig()) + if err != nil { + errMsg := fmt.Errorf("unable to serialize source %q config: %w", sourceName, err) + s.logger.DebugContext(ctx, errMsg.Error()) + _ = render.Render(w, r, newErrResponse(errMsg, http.StatusInternalServerError)) + return + } + redactSensitiveValues(configMap) + resp := SourceListResponse{ + Sources: map[string]SourceInfo{ + sourceName: { + Name: sourceName, + Kind: source.SourceKind(), + Config: configMap, + }, + }, + } + render.JSON(w, r, resp) +} + +func sourceConfigToMap(cfg any) (map[string]any, error) { + raw, err := yaml.Marshal(cfg) + if err != nil { + return nil, err + } + var configMap map[string]any + if err := yaml.Unmarshal(raw, &configMap); err != nil { + return nil, err + } + return configMap, nil +} + +func redactSensitiveValues(v any) { + switch typed := v.(type) { + case map[string]any: + for k, val := range typed { + if isSensitiveKey(k) { + typed[k] = "[REDACTED]" + continue + } + redactSensitiveValues(val) + } + case []any: + for i := range typed { + redactSensitiveValues(typed[i]) + } + } +} + +func isSensitiveKey(key string) bool { + lower := strings.ToLower(key) + // Avoid using "key" as a substring match as it is too broad and can match things like "primary_key". + if lower == "key" || lower == "api_key" { + return true + } + sensitive := []string{"password", "secret", "token", "credential"} + for _, keyword := range sensitive { + if strings.Contains(lower, keyword) { + return true + } + } + return false +} diff --git a/internal/server/api_test.go b/internal/server/api_test.go index 49c5206477a..a7d34f0d209 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -23,6 +23,8 @@ import ( "strings" "testing" + "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" ) @@ -295,3 +297,201 @@ func TestToolInvokeEndpoint(t *testing.T) { }) } } + +func TestSourceListEndpoint(t *testing.T) { + sourceA := &MockSource{Name: "source-a", Kind: "postgres"} + sourceB := &MockSource{Name: "source-b", Kind: "mysql"} + sourcesMap := map[string]sources.Source{ + "source-a": sourceA, + "source-b": sourceB, + } + + r, shutdown := setUpServerWithResources(t, "api", sourcesMap, nil, nil, nil, nil, nil) + defer shutdown() + ts := runServer(r, false) + defer ts.Close() + + resp, body, err := runRequest(ts, http.MethodGet, "/source", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + + if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { + t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + + var m SourceListResponse + if err := json.Unmarshal(body, &m); err != nil { + t.Fatalf("unable to parse SourceListResponse: %s", err) + } + + if _, ok := m.Sources["source-a"]; !ok { + t.Fatalf("source-a not found in response") + } + if _, ok := m.Sources["source-b"]; !ok { + t.Fatalf("source-b not found in response") + } +} + +func TestSourceGetEndpoint(t *testing.T) { + sourceA := &MockSource{ + Name: "source-a", + Kind: "postgres", + Host: "127.0.0.1", + Password: "secret", + } + sourcesMap := map[string]sources.Source{ + "source-a": sourceA, + } + + r, shutdown := setUpServerWithResources(t, "api", sourcesMap, nil, nil, nil, nil, nil) + defer shutdown() + ts := runServer(r, false) + defer ts.Close() + + resp, body, err := runRequest(ts, http.MethodGet, "/source/source-a", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + + if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { + t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + + var m SourceListResponse + if err := json.Unmarshal(body, &m); err != nil { + t.Fatalf("unable to parse SourceListResponse: %s", err) + } + + sourceInfo, ok := m.Sources["source-a"] + if !ok { + t.Fatalf("source-a not found in response") + } + if sourceInfo.Config == nil { + t.Fatalf("expected config for source-a, got nil") + } + if host, ok := sourceInfo.Config["host"]; !ok || host != "127.0.0.1" { + t.Fatalf("expected host to be %q, got %v", "127.0.0.1", host) + } + if password, ok := sourceInfo.Config["password"]; !ok || password != "[REDACTED]" { + t.Fatalf("expected password to be redacted, got %v", password) + } + + resp, _, err = runRequest(ts, http.MethodGet, "/source/unknown-source", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("unexpected status code for missing source: want %d, got %d", http.StatusNotFound, resp.StatusCode) + } +} + +func TestAuthServiceListEndpoint(t *testing.T) { + authA := &MockAuthService{Name: "auth-a", Kind: "google"} + authB := &MockAuthService{Name: "auth-b", Kind: "google"} + authMap := map[string]auth.AuthService{ + "auth-a": authA, + "auth-b": authB, + } + toolsMap := map[string]tools.Tool{ + "tool-auth-a": MockTool{Name: "tool-auth-a", AuthRequired: []string{"auth-a"}}, + } + + r, shutdown := setUpServerWithResources(t, "api", nil, authMap, toolsMap, nil, nil, nil) + defer shutdown() + ts := runServer(r, false) + defer ts.Close() + + resp, body, err := runRequest(ts, http.MethodGet, "/authservice", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + + if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { + t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + + var m AuthServiceListResponse + if err := json.Unmarshal(body, &m); err != nil { + t.Fatalf("unable to parse AuthServiceListResponse: %s", err) + } + + authAInfo, ok := m.AuthServices["auth-a"] + if !ok { + t.Fatalf("auth-a not found in response") + } + if authAInfo.HeaderName != "auth-a" { + t.Fatalf("unexpected headerName: want %q, got %q", "auth-a", authAInfo.HeaderName) + } + if len(authAInfo.Tools) != 1 || authAInfo.Tools[0] != "tool-auth-a" { + t.Fatalf("unexpected tools list for auth-a: %v", authAInfo.Tools) + } + + if _, ok := m.AuthServices["auth-b"]; !ok { + t.Fatalf("auth-b not found in response") + } +} + +func TestAuthServiceGetEndpoint(t *testing.T) { + authA := &MockAuthService{Name: "auth-a", Kind: "google"} + authMap := map[string]auth.AuthService{ + "auth-a": authA, + } + toolsMap := map[string]tools.Tool{ + "tool-auth-a": MockTool{Name: "tool-auth-a", AuthRequired: []string{"auth-a"}}, + } + + r, shutdown := setUpServerWithResources(t, "api", nil, authMap, toolsMap, nil, nil, nil) + defer shutdown() + ts := runServer(r, false) + defer ts.Close() + + resp, body, err := runRequest(ts, http.MethodGet, "/authservice/auth-a", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + + if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { + t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + + var m AuthServiceListResponse + if err := json.Unmarshal(body, &m); err != nil { + t.Fatalf("unable to parse AuthServiceListResponse: %s", err) + } + + authAInfo, ok := m.AuthServices["auth-a"] + if !ok { + t.Fatalf("auth-a not found in response") + } + if authAInfo.HeaderName != "auth-a" { + t.Fatalf("unexpected headerName: want %q, got %q", "auth-a", authAInfo.HeaderName) + } + if len(authAInfo.Tools) != 1 || authAInfo.Tools[0] != "tool-auth-a" { + t.Fatalf("unexpected tools list for auth-a: %v", authAInfo.Tools) + } + + resp, _, err = runRequest(ts, http.MethodGet, "/authservice/unknown-auth", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("unexpected status code for missing auth service: want %d, got %d", http.StatusNotFound, resp.StatusCode) + } +} diff --git a/internal/server/api_tools.go b/internal/server/api_tools.go new file mode 100644 index 00000000000..c48e4dded71 --- /dev/null +++ b/internal/server/api_tools.go @@ -0,0 +1,256 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" +) + +// toolGetHandler handles requests for a single Tool. +func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/get") + r = r.WithContext(ctx) + + toolName := chi.URLParam(r, "toolName") + s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + span.SetAttributes(attribute.String("tool_name", toolName)) + var err error + defer func() { + if err != nil { + span.SetStatus(codes.Error, err.Error()) + } + span.End() + + status := "success" + if err != nil { + status = "error" + } + s.instrumentation.ToolGet.Add( + r.Context(), + 1, + metric.WithAttributes(attribute.String("toolbox.name", toolName)), + metric.WithAttributes(attribute.String("toolbox.operation.status", status)), + ) + }() + tool, ok := s.ResourceMgr.GetTool(toolName) + if !ok { + err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + // TODO: this can be optimized later with some caching + m := tools.ToolsetManifest{ + ServerVersion: s.version, + ToolsManifest: map[string]tools.Manifest{ + toolName: tool.Manifest(), + }, + } + + render.JSON(w, r, m) +} + +// toolInvokeHandler handles the API request to invoke a specific Tool. +func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/invoke") + r = r.WithContext(ctx) + ctx = util.WithLogger(r.Context(), s.logger) + + toolName := chi.URLParam(r, "toolName") + s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + span.SetAttributes(attribute.String("tool_name", toolName)) + var err error + defer func() { + if err != nil { + span.SetStatus(codes.Error, err.Error()) + } + span.End() + + status := "success" + if err != nil { + status = "error" + } + s.instrumentation.ToolInvoke.Add( + r.Context(), + 1, + metric.WithAttributes(attribute.String("toolbox.name", toolName)), + metric.WithAttributes(attribute.String("toolbox.operation.status", status)), + ) + }() + + tool, ok := s.ResourceMgr.GetTool(toolName) + if !ok { + err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + + // Extract OAuth access token from the "Authorization" header (currently for + // BigQuery end-user credentials usage only) + accessToken := tools.AccessToken(r.Header.Get("Authorization")) + + // Check if this specific tool requires the standard authorization header + clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + s.logger.DebugContext(ctx, errMsg.Error()) + _ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound)) + return + } + if clientAuth { + if accessToken == "" { + err = fmt.Errorf("tool requires client authorization but access token is missing from the request header") + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) + return + } + } + + // Tool authentication + // claimsFromAuth maps the name of the authservice to the claims retrieved from it. + claimsFromAuth := make(map[string]map[string]any) + for _, aS := range s.ResourceMgr.GetAuthServiceMap() { + claims, err := aS.GetClaimsFromHeader(ctx, r.Header) + if err != nil { + s.logger.DebugContext(ctx, err.Error()) + continue + } + if claims == nil { + // authService not present in header + continue + } + claimsFromAuth[aS.GetName()] = claims + } + + // Tool authorization check + verifiedAuthServices := make([]string, len(claimsFromAuth)) + i := 0 + for k := range claimsFromAuth { + verifiedAuthServices[i] = k + i++ + } + + // Check if any of the specified auth services is verified + isAuthorized := tool.Authorized(verifiedAuthServices) + if !isAuthorized { + err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers") + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) + return + } + s.logger.DebugContext(ctx, "tool invocation authorized") + + var data map[string]any + if err = util.DecodeJSON(r.Body, &data); err != nil { + render.Status(r, http.StatusBadRequest) + err = fmt.Errorf("request body was invalid JSON: %w", err) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) + return + } + + params, err := tool.ParseParams(data, claimsFromAuth) + if err != nil { + // If auth error, return 401 + if errors.Is(err, util.ErrUnauthorized) { + s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err)) + _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) + return + } + err = fmt.Errorf("provided parameters were invalid: %w", err) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) + return + } + s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) + + params, err = tool.EmbedParams(ctx, params, s.ResourceMgr.GetEmbeddingModelMap()) + if err != nil { + err = fmt.Errorf("error embedding parameters: %w", err) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) + return + } + + res, err := tool.Invoke(ctx, s.ResourceMgr, params, accessToken) + + // Determine what error to return to the users. + if err != nil { + errStr := err.Error() + var statusCode int + + // Upstream API auth error propagation + switch { + case strings.Contains(errStr, "Error 401"): + statusCode = http.StatusUnauthorized + case strings.Contains(errStr, "Error 403"): + statusCode = http.StatusForbidden + } + + if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { + if clientAuth { + // Propagate the original 401/403 error. + s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err)) + _ = render.Render(w, r, newErrResponse(err, statusCode)) + return + } + // ADC lacking permission or credentials configuration error. + internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err) + s.logger.ErrorContext(ctx, internalErr.Error()) + _ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError)) + return + } + err = fmt.Errorf("error while invoking tool: %w", err) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) + return + } + + resMarshal, err := json.Marshal(res) + if err != nil { + err = fmt.Errorf("unable to marshal result: %w", err) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError)) + return + } + + _ = render.Render(w, r, &resultResponse{Result: string(resMarshal)}) +} + +var _ render.Renderer = &resultResponse{} // Renderer interface for managing response payloads. + +// resultResponse is the response sent back when the tool was invocated successfully. +type resultResponse struct { + Result string `json:"result"` // result of tool invocation +} + +// Render renders a single payload and respond to the client request. +func (rr resultResponse) Render(w http.ResponseWriter, r *http.Request) error { + render.Status(r, http.StatusOK) + return nil +} diff --git a/internal/server/api_toolsets.go b/internal/server/api_toolsets.go new file mode 100644 index 00000000000..16154821ec1 --- /dev/null +++ b/internal/server/api_toolsets.go @@ -0,0 +1,63 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" +) + +// toolsetHandler handles the request for information about a Toolset. +func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/toolset/get") + r = r.WithContext(ctx) + + toolsetName := chi.URLParam(r, "toolsetName") + s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) + span.SetAttributes(attribute.String("toolset_name", toolsetName)) + var err error + defer func() { + if err != nil { + span.SetStatus(codes.Error, err.Error()) + } + span.End() + + status := "success" + if err != nil { + status = "error" + } + s.instrumentation.ToolsetGet.Add( + r.Context(), + 1, + metric.WithAttributes(attribute.String("toolbox.name", toolsetName)), + metric.WithAttributes(attribute.String("toolbox.operation.status", status)), + ) + }() + + toolset, ok := s.ResourceMgr.GetToolset(toolsetName) + if !ok { + err = fmt.Errorf("toolset %q does not exist", toolsetName) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + render.JSON(w, r, toolset.Manifest) +} diff --git a/internal/server/common_test.go b/internal/server/common_test.go index 39aca55be39..48309ee8491 100644 --- a/internal/server/common_test.go +++ b/internal/server/common_test.go @@ -24,21 +24,26 @@ import ( "testing" "github.com/go-chi/chi/v5" + "github.com/googleapis/genai-toolbox/internal/auth" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/resources" + "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/telemetry" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "go.opentelemetry.io/otel/trace" ) // fakeVersionString is used as a temporary version string in tests const fakeVersionString = "0.0.0" var ( - _ tools.Tool = MockTool{} - _ prompts.Prompt = MockPrompt{} + _ tools.Tool = MockTool{} + _ prompts.Prompt = MockPrompt{} + _ sources.Source = &MockSource{} + _ auth.AuthService = &MockAuthService{} ) // MockTool is used to mock tools in tests @@ -46,11 +51,85 @@ type MockTool struct { Name string Description string Params []parameters.Parameter + AuthRequired []string manifest tools.Manifest unauthorized bool requiresClientAuthrorization bool } +// MockSourceConfig is used to mock sources in tests. +type MockSourceConfig struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Host string `yaml:"host"` + Password string `yaml:"password"` +} + +func (c MockSourceConfig) SourceConfigKind() string { + return c.Kind +} + +func (c MockSourceConfig) Initialize(context.Context, trace.Tracer) (sources.Source, error) { + return &MockSource{Name: c.Name, Kind: c.Kind, Host: c.Host, Password: c.Password}, nil +} + +// MockSource is used to mock sources in tests. +type MockSource struct { + Name string + Kind string + Host string + Password string +} + +func (s *MockSource) SourceKind() string { + return s.Kind +} + +func (s *MockSource) ToConfig() sources.SourceConfig { + return MockSourceConfig{ + Name: s.Name, + Kind: s.Kind, + Host: s.Host, + Password: s.Password, + } +} + +// MockAuthServiceConfig is used to mock auth services in tests. +type MockAuthServiceConfig struct { + Name string + Kind string +} + +func (c MockAuthServiceConfig) AuthServiceConfigKind() string { + return c.Kind +} + +func (c MockAuthServiceConfig) Initialize() (auth.AuthService, error) { + return &MockAuthService{Name: c.Name, Kind: c.Kind}, nil +} + +// MockAuthService is used to mock auth services in tests. +type MockAuthService struct { + Name string + Kind string +} + +func (s *MockAuthService) AuthServiceKind() string { + return s.Kind +} + +func (s *MockAuthService) GetName() string { + return s.Name +} + +func (s *MockAuthService) GetClaimsFromHeader(context.Context, http.Header) (map[string]any, error) { + return nil, nil +} + +func (s *MockAuthService) ToConfig() auth.AuthServiceConfig { + return MockAuthServiceConfig{Name: s.Name, Kind: s.Kind} +} + func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) { mock := []any{t.Name} return mock, nil @@ -74,7 +153,7 @@ func (t MockTool) Manifest() tools.Manifest { for _, p := range t.Params { pMs = append(pMs, p.Manifest()) } - return tools.Manifest{Description: t.Description, Parameters: pMs} + return tools.Manifest{Description: t.Description, Parameters: pMs, AuthRequired: t.AuthRequired} } func (t MockTool) Authorized(verifiedAuthServices []string) bool { @@ -262,6 +341,20 @@ func setUpResources(t *testing.T, mockTools []MockTool, mockPrompts []MockPrompt // setUpServer create a new server with tools, toolsets, prompts, and promptsets. func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, toolsets map[string]tools.Toolset, prompts map[string]prompts.Prompt, promptsets map[string]prompts.Promptset) (chi.Router, func()) { + return setUpServerWithResources(t, router, nil, nil, tools, toolsets, prompts, promptsets) +} + +// setUpServerWithResources create a new server with sources, auth services, tools, toolsets, prompts, and promptsets. +func setUpServerWithResources( + t *testing.T, + router string, + sourcesMap map[string]sources.Source, + authServices map[string]auth.AuthService, + tools map[string]tools.Tool, + toolsets map[string]tools.Toolset, + prompts map[string]prompts.Prompt, + promptsets map[string]prompts.Promptset, +) (chi.Router, func()) { ctx, cancel := context.WithCancel(context.Background()) testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info") @@ -281,7 +374,7 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools sseManager := newSseManager(ctx) - resourceManager := resources.NewResourceManager(nil, nil, nil, tools, toolsets, prompts, promptsets) + resourceManager := resources.NewResourceManager(sourcesMap, authServices, nil, tools, toolsets, prompts, promptsets) server := Server{ version: fakeVersionString, diff --git a/internal/server/resources/resources.go b/internal/server/resources/resources.go index b41e160a39b..a7746217aeb 100644 --- a/internal/server/resources/resources.go +++ b/internal/server/resources/resources.go @@ -139,6 +139,16 @@ func (r *ResourceManager) GetEmbeddingModelMap() map[string]embeddingmodels.Embe return copiedMap } +func (r *ResourceManager) GetSourcesMap() map[string]sources.Source { + r.mu.RLock() + defer r.mu.RUnlock() + copiedMap := make(map[string]sources.Source, len(r.sources)) + for k, v := range r.sources { + copiedMap[k] = v + } + return copiedMap +} + func (r *ResourceManager) GetToolsMap() map[string]tools.Tool { r.mu.RLock() defer r.mu.RUnlock() diff --git a/internal/server/static/authservices.html b/internal/server/static/authservices.html new file mode 100644 index 00000000000..f12d48e90fa --- /dev/null +++ b/internal/server/static/authservices.html @@ -0,0 +1,33 @@ + + +
+ + +