diff --git a/internal/server/api.go b/internal/server/api.go index 0396d1a58a7..fdb3480e351 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -15,20 +15,11 @@ package server import ( - "encoding/json" - "errors" - "fmt" "net/http" - "strings" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" - "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/util" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/metric" ) // apiRouter creates a router that represents the routes under /api @@ -39,6 +30,12 @@ func apiRouter(s *Server) (chi.Router, error) { r.Use(middleware.StripSlashes) r.Use(render.SetContentType(render.ContentTypeJSON)) + r.Get("/authservice", func(w http.ResponseWriter, r *http.Request) { authServiceListHandler(s, w, r) }) + r.Get("/authservice/{authServiceName}", func(w http.ResponseWriter, r *http.Request) { authServiceGetHandler(s, w, r) }) + + r.Get("/source", func(w http.ResponseWriter, r *http.Request) { sourceListHandler(s, w, r) }) + r.Get("/source/{sourceName}", func(w http.ResponseWriter, r *http.Request) { sourceGetHandler(s, w, r) }) + r.Get("/toolset", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) }) r.Get("/toolset/{toolsetName}", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) }) @@ -49,292 +46,3 @@ func apiRouter(s *Server) (chi.Router, error) { return r, nil } - -// toolsetHandler handles the request for information about a Toolset. -func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) { - ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/toolset/get") - r = r.WithContext(ctx) - - toolsetName := chi.URLParam(r, "toolsetName") - s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) - span.SetAttributes(attribute.String("toolset_name", toolsetName)) - var err error - defer func() { - if err != nil { - span.SetStatus(codes.Error, err.Error()) - } - span.End() - - status := "success" - if err != nil { - status = "error" - } - s.instrumentation.ToolsetGet.Add( - r.Context(), - 1, - metric.WithAttributes(attribute.String("toolbox.name", toolsetName)), - metric.WithAttributes(attribute.String("toolbox.operation.status", status)), - ) - }() - - toolset, ok := s.ResourceMgr.GetToolset(toolsetName) - if !ok { - err = fmt.Errorf("toolset %q does not exist", toolsetName) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) - return - } - render.JSON(w, r, toolset.Manifest) -} - -// toolGetHandler handles requests for a single Tool. -func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) { - ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/get") - r = r.WithContext(ctx) - - toolName := chi.URLParam(r, "toolName") - s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) - span.SetAttributes(attribute.String("tool_name", toolName)) - var err error - defer func() { - if err != nil { - span.SetStatus(codes.Error, err.Error()) - } - span.End() - - status := "success" - if err != nil { - status = "error" - } - s.instrumentation.ToolGet.Add( - r.Context(), - 1, - metric.WithAttributes(attribute.String("toolbox.name", toolName)), - metric.WithAttributes(attribute.String("toolbox.operation.status", status)), - ) - }() - tool, ok := s.ResourceMgr.GetTool(toolName) - if !ok { - err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) - return - } - // TODO: this can be optimized later with some caching - m := tools.ToolsetManifest{ - ServerVersion: s.version, - ToolsManifest: map[string]tools.Manifest{ - toolName: tool.Manifest(), - }, - } - - render.JSON(w, r, m) -} - -// toolInvokeHandler handles the API request to invoke a specific Tool. -func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { - ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/invoke") - r = r.WithContext(ctx) - ctx = util.WithLogger(r.Context(), s.logger) - - toolName := chi.URLParam(r, "toolName") - s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) - span.SetAttributes(attribute.String("tool_name", toolName)) - var err error - defer func() { - if err != nil { - span.SetStatus(codes.Error, err.Error()) - } - span.End() - - status := "success" - if err != nil { - status = "error" - } - s.instrumentation.ToolInvoke.Add( - r.Context(), - 1, - metric.WithAttributes(attribute.String("toolbox.name", toolName)), - metric.WithAttributes(attribute.String("toolbox.operation.status", status)), - ) - }() - - tool, ok := s.ResourceMgr.GetTool(toolName) - if !ok { - err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) - return - } - - // Extract OAuth access token from the "Authorization" header (currently for - // BigQuery end-user credentials usage only) - accessToken := tools.AccessToken(r.Header.Get("Authorization")) - - // Check if this specific tool requires the standard authorization header - clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr) - if err != nil { - errMsg := fmt.Errorf("error during invocation: %w", err) - s.logger.DebugContext(ctx, errMsg.Error()) - _ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound)) - return - } - if clientAuth { - if accessToken == "" { - err = fmt.Errorf("tool requires client authorization but access token is missing from the request header") - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) - return - } - } - - // Tool authentication - // claimsFromAuth maps the name of the authservice to the claims retrieved from it. - claimsFromAuth := make(map[string]map[string]any) - for _, aS := range s.ResourceMgr.GetAuthServiceMap() { - claims, err := aS.GetClaimsFromHeader(ctx, r.Header) - if err != nil { - s.logger.DebugContext(ctx, err.Error()) - continue - } - if claims == nil { - // authService not present in header - continue - } - claimsFromAuth[aS.GetName()] = claims - } - - // Tool authorization check - verifiedAuthServices := make([]string, len(claimsFromAuth)) - i := 0 - for k := range claimsFromAuth { - verifiedAuthServices[i] = k - i++ - } - - // Check if any of the specified auth services is verified - isAuthorized := tool.Authorized(verifiedAuthServices) - if !isAuthorized { - err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers") - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) - return - } - s.logger.DebugContext(ctx, "tool invocation authorized") - - var data map[string]any - if err = util.DecodeJSON(r.Body, &data); err != nil { - render.Status(r, http.StatusBadRequest) - err = fmt.Errorf("request body was invalid JSON: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) - return - } - - params, err := tool.ParseParams(data, claimsFromAuth) - if err != nil { - // If auth error, return 401 - if errors.Is(err, util.ErrUnauthorized) { - s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err)) - _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) - return - } - err = fmt.Errorf("provided parameters were invalid: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) - return - } - s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) - - params, err = tool.EmbedParams(ctx, params, s.ResourceMgr.GetEmbeddingModelMap()) - if err != nil { - err = fmt.Errorf("error embedding parameters: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) - return - } - - res, err := tool.Invoke(ctx, s.ResourceMgr, params, accessToken) - - // Determine what error to return to the users. - if err != nil { - errStr := err.Error() - var statusCode int - - // Upstream API auth error propagation - switch { - case strings.Contains(errStr, "Error 401"): - statusCode = http.StatusUnauthorized - case strings.Contains(errStr, "Error 403"): - statusCode = http.StatusForbidden - } - - if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { - if clientAuth { - // Propagate the original 401/403 error. - s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err)) - _ = render.Render(w, r, newErrResponse(err, statusCode)) - return - } - // ADC lacking permission or credentials configuration error. - internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err) - s.logger.ErrorContext(ctx, internalErr.Error()) - _ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError)) - return - } - err = fmt.Errorf("error while invoking tool: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) - return - } - - resMarshal, err := json.Marshal(res) - if err != nil { - err = fmt.Errorf("unable to marshal result: %w", err) - s.logger.DebugContext(ctx, err.Error()) - _ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError)) - return - } - - _ = render.Render(w, r, &resultResponse{Result: string(resMarshal)}) -} - -var _ render.Renderer = &resultResponse{} // Renderer interface for managing response payloads. - -// resultResponse is the response sent back when the tool was invocated successfully. -type resultResponse struct { - Result string `json:"result"` // result of tool invocation -} - -// Render renders a single payload and respond to the client request. -func (rr resultResponse) Render(w http.ResponseWriter, r *http.Request) error { - render.Status(r, http.StatusOK) - return nil -} - -var _ render.Renderer = &errResponse{} // Renderer interface for managing response payloads. - -// newErrResponse is a helper function initializing an ErrResponse -func newErrResponse(err error, code int) *errResponse { - return &errResponse{ - Err: err, - HTTPStatusCode: code, - - StatusText: http.StatusText(code), - ErrorText: err.Error(), - } -} - -// errResponse is the response sent back when an error has been encountered. -type errResponse struct { - Err error `json:"-"` // low-level runtime error - HTTPStatusCode int `json:"-"` // http response status code - - StatusText string `json:"status"` // user-level status message - ErrorText string `json:"error,omitempty"` // application-level error message, for debugging -} - -func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { - render.Status(r, e.HTTPStatusCode) - return nil -} diff --git a/internal/server/api_authservices.go b/internal/server/api_authservices.go new file mode 100644 index 00000000000..be4bdf5a436 --- /dev/null +++ b/internal/server/api_authservices.go @@ -0,0 +1,147 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "net/http" + "slices" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "github.com/googleapis/genai-toolbox/internal/tools" +) + +type AuthServiceInfo struct { + Name string `json:"name"` + Kind string `json:"kind"` + HeaderName string `json:"headerName"` + Tools []string `json:"tools"` +} + +type AuthServiceListResponse struct { + AuthServices map[string]AuthServiceInfo `json:"authServices"` +} + +// authServiceListHandler handles requests for listing all auth services. +func authServiceListHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/authservice/list") + r = r.WithContext(ctx) + defer span.End() + + authServicesMap := s.ResourceMgr.GetAuthServiceMap() + usageByAuthService := toolsForAuthServices(s.ResourceMgr.GetToolsMap()) + resp := AuthServiceListResponse{ + AuthServices: make(map[string]AuthServiceInfo, len(authServicesMap)), + } + for name, authService := range authServicesMap { + resp.AuthServices[name] = AuthServiceInfo{ + Name: authService.GetName(), + Kind: authService.AuthServiceKind(), + HeaderName: authService.GetName(), + Tools: usageByAuthService[name], + } + } + render.JSON(w, r, resp) +} + +// authServiceGetHandler handles requests for a single auth service. +func authServiceGetHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/authservice/get") + r = r.WithContext(ctx) + defer span.End() + + authServiceName := chi.URLParam(r, "authServiceName") + authService, ok := s.ResourceMgr.GetAuthService(authServiceName) + if !ok { + err := fmt.Errorf("auth service %q does not exist", authServiceName) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + toolsMap := s.ResourceMgr.GetToolsMap() + resp := AuthServiceListResponse{ + AuthServices: map[string]AuthServiceInfo{ + authServiceName: { + Name: authService.GetName(), + Kind: authService.AuthServiceKind(), + HeaderName: authService.GetName(), + Tools: toolsForAuthService(toolsMap, authServiceName), + }, + }, + } + render.JSON(w, r, resp) +} + +func toolsForAuthServices(toolsMap map[string]tools.Tool) map[string][]string { + usage := make(map[string]map[string]bool) + + for toolName, tool := range toolsMap { + manifest := tool.Manifest() + for _, authName := range manifest.AuthRequired { + addAuthServiceUsage(usage, authName, toolName) + } + for _, param := range manifest.Parameters { + for _, authName := range param.AuthServices { + addAuthServiceUsage(usage, authName, toolName) + } + } + } + + out := make(map[string][]string, len(usage)) + for authName, toolsSet := range usage { + toolsList := make([]string, 0, len(toolsSet)) + for toolName := range toolsSet { + toolsList = append(toolsList, toolName) + } + slices.Sort(toolsList) + out[authName] = toolsList + } + return out +} + +func addAuthServiceUsage(usage map[string]map[string]bool, authName, toolName string) { + if authName == "" { + return + } + if usage[authName] == nil { + usage[authName] = make(map[string]bool) + } + usage[authName][toolName] = true +} + +func toolsForAuthService(toolsMap map[string]tools.Tool, authServiceName string) []string { + toolsSet := make(map[string]bool, len(toolsMap)) + for toolName, tool := range toolsMap { + manifest := tool.Manifest() + if slices.Contains(manifest.AuthRequired, authServiceName) { + toolsSet[toolName] = true + continue + } + for _, param := range manifest.Parameters { + if slices.Contains(param.AuthServices, authServiceName) { + toolsSet[toolName] = true + break + } + } + } + + toolsList := make([]string, 0, len(toolsSet)) + for toolName := range toolsSet { + toolsList = append(toolsList, toolName) + } + slices.Sort(toolsList) + return toolsList +} \ No newline at end of file diff --git a/internal/server/api_common.go b/internal/server/api_common.go new file mode 100644 index 00000000000..dcb44dde951 --- /dev/null +++ b/internal/server/api_common.go @@ -0,0 +1,48 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "net/http" + + "github.com/go-chi/render" +) + +var _ render.Renderer = &errResponse{} // Renderer interface for managing response payloads. + +// newErrResponse is a helper function initializing an ErrResponse +func newErrResponse(err error, code int) *errResponse { + return &errResponse{ + Err: err, + HTTPStatusCode: code, + + StatusText: http.StatusText(code), + ErrorText: err.Error(), + } +} + +// errResponse is the response sent back when an error has been encountered. +type errResponse struct { + Err error `json:"-"` // low-level runtime error + HTTPStatusCode int `json:"-"` // http response status code + + StatusText string `json:"status"` // user-level status message + ErrorText string `json:"error,omitempty"` // application-level error message, for debugging +} + +func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { + render.Status(r, e.HTTPStatusCode) + return nil +} diff --git a/internal/server/api_sources.go b/internal/server/api_sources.go new file mode 100644 index 00000000000..c9214333c74 --- /dev/null +++ b/internal/server/api_sources.go @@ -0,0 +1,132 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + yaml "github.com/goccy/go-yaml" +) + +type SourceInfo struct { + Name string `json:"name"` + Kind string `json:"kind"` + Config map[string]any `json:"config,omitempty"` +} + +type SourceListResponse struct { + Sources map[string]SourceInfo `json:"sources"` +} + +// sourceListHandler handles requests for listing all sources. +func sourceListHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/source/list") + r = r.WithContext(ctx) + defer span.End() + + sourcesMap := s.ResourceMgr.GetSourcesMap() + resp := SourceListResponse{ + Sources: make(map[string]SourceInfo, len(sourcesMap)), + } + for name, source := range sourcesMap { + resp.Sources[name] = SourceInfo{ + Name: name, + Kind: source.SourceKind(), + } + } + render.JSON(w, r, resp) +} + +// sourceGetHandler handles requests for a single source. +func sourceGetHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/source/get") + r = r.WithContext(ctx) + defer span.End() + + sourceName := chi.URLParam(r, "sourceName") + source, ok := s.ResourceMgr.GetSource(sourceName) + if !ok { + err := fmt.Errorf("source %q does not exist", sourceName) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + configMap, err := sourceConfigToMap(source.ToConfig()) + if err != nil { + errMsg := fmt.Errorf("unable to serialize source %q config: %w", sourceName, err) + s.logger.DebugContext(ctx, errMsg.Error()) + _ = render.Render(w, r, newErrResponse(errMsg, http.StatusInternalServerError)) + return + } + redactSensitiveValues(configMap) + resp := SourceListResponse{ + Sources: map[string]SourceInfo{ + sourceName: { + Name: sourceName, + Kind: source.SourceKind(), + Config: configMap, + }, + }, + } + render.JSON(w, r, resp) +} + +func sourceConfigToMap(cfg any) (map[string]any, error) { + raw, err := yaml.Marshal(cfg) + if err != nil { + return nil, err + } + var configMap map[string]any + if err := yaml.Unmarshal(raw, &configMap); err != nil { + return nil, err + } + return configMap, nil +} + +func redactSensitiveValues(v any) { + switch typed := v.(type) { + case map[string]any: + for k, val := range typed { + if isSensitiveKey(k) { + typed[k] = "[REDACTED]" + continue + } + redactSensitiveValues(val) + } + case []any: + for i := range typed { + redactSensitiveValues(typed[i]) + } + } +} + +func isSensitiveKey(key string) bool { + lower := strings.ToLower(key) + // Avoid using "key" as a substring match as it is too broad and can match things like "primary_key". + if lower == "key" || lower == "api_key" { + return true + } + sensitive := []string{"password", "secret", "token", "credential"} + for _, keyword := range sensitive { + if strings.Contains(lower, keyword) { + return true + } + } + return false +} diff --git a/internal/server/api_test.go b/internal/server/api_test.go index 49c5206477a..a7d34f0d209 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -23,6 +23,8 @@ import ( "strings" "testing" + "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" ) @@ -295,3 +297,201 @@ func TestToolInvokeEndpoint(t *testing.T) { }) } } + +func TestSourceListEndpoint(t *testing.T) { + sourceA := &MockSource{Name: "source-a", Kind: "postgres"} + sourceB := &MockSource{Name: "source-b", Kind: "mysql"} + sourcesMap := map[string]sources.Source{ + "source-a": sourceA, + "source-b": sourceB, + } + + r, shutdown := setUpServerWithResources(t, "api", sourcesMap, nil, nil, nil, nil, nil) + defer shutdown() + ts := runServer(r, false) + defer ts.Close() + + resp, body, err := runRequest(ts, http.MethodGet, "/source", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + + if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { + t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + + var m SourceListResponse + if err := json.Unmarshal(body, &m); err != nil { + t.Fatalf("unable to parse SourceListResponse: %s", err) + } + + if _, ok := m.Sources["source-a"]; !ok { + t.Fatalf("source-a not found in response") + } + if _, ok := m.Sources["source-b"]; !ok { + t.Fatalf("source-b not found in response") + } +} + +func TestSourceGetEndpoint(t *testing.T) { + sourceA := &MockSource{ + Name: "source-a", + Kind: "postgres", + Host: "127.0.0.1", + Password: "secret", + } + sourcesMap := map[string]sources.Source{ + "source-a": sourceA, + } + + r, shutdown := setUpServerWithResources(t, "api", sourcesMap, nil, nil, nil, nil, nil) + defer shutdown() + ts := runServer(r, false) + defer ts.Close() + + resp, body, err := runRequest(ts, http.MethodGet, "/source/source-a", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + + if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { + t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + + var m SourceListResponse + if err := json.Unmarshal(body, &m); err != nil { + t.Fatalf("unable to parse SourceListResponse: %s", err) + } + + sourceInfo, ok := m.Sources["source-a"] + if !ok { + t.Fatalf("source-a not found in response") + } + if sourceInfo.Config == nil { + t.Fatalf("expected config for source-a, got nil") + } + if host, ok := sourceInfo.Config["host"]; !ok || host != "127.0.0.1" { + t.Fatalf("expected host to be %q, got %v", "127.0.0.1", host) + } + if password, ok := sourceInfo.Config["password"]; !ok || password != "[REDACTED]" { + t.Fatalf("expected password to be redacted, got %v", password) + } + + resp, _, err = runRequest(ts, http.MethodGet, "/source/unknown-source", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("unexpected status code for missing source: want %d, got %d", http.StatusNotFound, resp.StatusCode) + } +} + +func TestAuthServiceListEndpoint(t *testing.T) { + authA := &MockAuthService{Name: "auth-a", Kind: "google"} + authB := &MockAuthService{Name: "auth-b", Kind: "google"} + authMap := map[string]auth.AuthService{ + "auth-a": authA, + "auth-b": authB, + } + toolsMap := map[string]tools.Tool{ + "tool-auth-a": MockTool{Name: "tool-auth-a", AuthRequired: []string{"auth-a"}}, + } + + r, shutdown := setUpServerWithResources(t, "api", nil, authMap, toolsMap, nil, nil, nil) + defer shutdown() + ts := runServer(r, false) + defer ts.Close() + + resp, body, err := runRequest(ts, http.MethodGet, "/authservice", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + + if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { + t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + + var m AuthServiceListResponse + if err := json.Unmarshal(body, &m); err != nil { + t.Fatalf("unable to parse AuthServiceListResponse: %s", err) + } + + authAInfo, ok := m.AuthServices["auth-a"] + if !ok { + t.Fatalf("auth-a not found in response") + } + if authAInfo.HeaderName != "auth-a" { + t.Fatalf("unexpected headerName: want %q, got %q", "auth-a", authAInfo.HeaderName) + } + if len(authAInfo.Tools) != 1 || authAInfo.Tools[0] != "tool-auth-a" { + t.Fatalf("unexpected tools list for auth-a: %v", authAInfo.Tools) + } + + if _, ok := m.AuthServices["auth-b"]; !ok { + t.Fatalf("auth-b not found in response") + } +} + +func TestAuthServiceGetEndpoint(t *testing.T) { + authA := &MockAuthService{Name: "auth-a", Kind: "google"} + authMap := map[string]auth.AuthService{ + "auth-a": authA, + } + toolsMap := map[string]tools.Tool{ + "tool-auth-a": MockTool{Name: "tool-auth-a", AuthRequired: []string{"auth-a"}}, + } + + r, shutdown := setUpServerWithResources(t, "api", nil, authMap, toolsMap, nil, nil, nil) + defer shutdown() + ts := runServer(r, false) + defer ts.Close() + + resp, body, err := runRequest(ts, http.MethodGet, "/authservice/auth-a", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + + if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { + t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) + } + + var m AuthServiceListResponse + if err := json.Unmarshal(body, &m); err != nil { + t.Fatalf("unable to parse AuthServiceListResponse: %s", err) + } + + authAInfo, ok := m.AuthServices["auth-a"] + if !ok { + t.Fatalf("auth-a not found in response") + } + if authAInfo.HeaderName != "auth-a" { + t.Fatalf("unexpected headerName: want %q, got %q", "auth-a", authAInfo.HeaderName) + } + if len(authAInfo.Tools) != 1 || authAInfo.Tools[0] != "tool-auth-a" { + t.Fatalf("unexpected tools list for auth-a: %v", authAInfo.Tools) + } + + resp, _, err = runRequest(ts, http.MethodGet, "/authservice/unknown-auth", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("unexpected status code for missing auth service: want %d, got %d", http.StatusNotFound, resp.StatusCode) + } +} diff --git a/internal/server/api_tools.go b/internal/server/api_tools.go new file mode 100644 index 00000000000..c48e4dded71 --- /dev/null +++ b/internal/server/api_tools.go @@ -0,0 +1,256 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" +) + +// toolGetHandler handles requests for a single Tool. +func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/get") + r = r.WithContext(ctx) + + toolName := chi.URLParam(r, "toolName") + s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + span.SetAttributes(attribute.String("tool_name", toolName)) + var err error + defer func() { + if err != nil { + span.SetStatus(codes.Error, err.Error()) + } + span.End() + + status := "success" + if err != nil { + status = "error" + } + s.instrumentation.ToolGet.Add( + r.Context(), + 1, + metric.WithAttributes(attribute.String("toolbox.name", toolName)), + metric.WithAttributes(attribute.String("toolbox.operation.status", status)), + ) + }() + tool, ok := s.ResourceMgr.GetTool(toolName) + if !ok { + err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + // TODO: this can be optimized later with some caching + m := tools.ToolsetManifest{ + ServerVersion: s.version, + ToolsManifest: map[string]tools.Manifest{ + toolName: tool.Manifest(), + }, + } + + render.JSON(w, r, m) +} + +// toolInvokeHandler handles the API request to invoke a specific Tool. +func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/invoke") + r = r.WithContext(ctx) + ctx = util.WithLogger(r.Context(), s.logger) + + toolName := chi.URLParam(r, "toolName") + s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + span.SetAttributes(attribute.String("tool_name", toolName)) + var err error + defer func() { + if err != nil { + span.SetStatus(codes.Error, err.Error()) + } + span.End() + + status := "success" + if err != nil { + status = "error" + } + s.instrumentation.ToolInvoke.Add( + r.Context(), + 1, + metric.WithAttributes(attribute.String("toolbox.name", toolName)), + metric.WithAttributes(attribute.String("toolbox.operation.status", status)), + ) + }() + + tool, ok := s.ResourceMgr.GetTool(toolName) + if !ok { + err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + + // Extract OAuth access token from the "Authorization" header (currently for + // BigQuery end-user credentials usage only) + accessToken := tools.AccessToken(r.Header.Get("Authorization")) + + // Check if this specific tool requires the standard authorization header + clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + s.logger.DebugContext(ctx, errMsg.Error()) + _ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound)) + return + } + if clientAuth { + if accessToken == "" { + err = fmt.Errorf("tool requires client authorization but access token is missing from the request header") + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) + return + } + } + + // Tool authentication + // claimsFromAuth maps the name of the authservice to the claims retrieved from it. + claimsFromAuth := make(map[string]map[string]any) + for _, aS := range s.ResourceMgr.GetAuthServiceMap() { + claims, err := aS.GetClaimsFromHeader(ctx, r.Header) + if err != nil { + s.logger.DebugContext(ctx, err.Error()) + continue + } + if claims == nil { + // authService not present in header + continue + } + claimsFromAuth[aS.GetName()] = claims + } + + // Tool authorization check + verifiedAuthServices := make([]string, len(claimsFromAuth)) + i := 0 + for k := range claimsFromAuth { + verifiedAuthServices[i] = k + i++ + } + + // Check if any of the specified auth services is verified + isAuthorized := tool.Authorized(verifiedAuthServices) + if !isAuthorized { + err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers") + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) + return + } + s.logger.DebugContext(ctx, "tool invocation authorized") + + var data map[string]any + if err = util.DecodeJSON(r.Body, &data); err != nil { + render.Status(r, http.StatusBadRequest) + err = fmt.Errorf("request body was invalid JSON: %w", err) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) + return + } + + params, err := tool.ParseParams(data, claimsFromAuth) + if err != nil { + // If auth error, return 401 + if errors.Is(err, util.ErrUnauthorized) { + s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err)) + _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized)) + return + } + err = fmt.Errorf("provided parameters were invalid: %w", err) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) + return + } + s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) + + params, err = tool.EmbedParams(ctx, params, s.ResourceMgr.GetEmbeddingModelMap()) + if err != nil { + err = fmt.Errorf("error embedding parameters: %w", err) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) + return + } + + res, err := tool.Invoke(ctx, s.ResourceMgr, params, accessToken) + + // Determine what error to return to the users. + if err != nil { + errStr := err.Error() + var statusCode int + + // Upstream API auth error propagation + switch { + case strings.Contains(errStr, "Error 401"): + statusCode = http.StatusUnauthorized + case strings.Contains(errStr, "Error 403"): + statusCode = http.StatusForbidden + } + + if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { + if clientAuth { + // Propagate the original 401/403 error. + s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err)) + _ = render.Render(w, r, newErrResponse(err, statusCode)) + return + } + // ADC lacking permission or credentials configuration error. + internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err) + s.logger.ErrorContext(ctx, internalErr.Error()) + _ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError)) + return + } + err = fmt.Errorf("error while invoking tool: %w", err) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) + return + } + + resMarshal, err := json.Marshal(res) + if err != nil { + err = fmt.Errorf("unable to marshal result: %w", err) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError)) + return + } + + _ = render.Render(w, r, &resultResponse{Result: string(resMarshal)}) +} + +var _ render.Renderer = &resultResponse{} // Renderer interface for managing response payloads. + +// resultResponse is the response sent back when the tool was invocated successfully. +type resultResponse struct { + Result string `json:"result"` // result of tool invocation +} + +// Render renders a single payload and respond to the client request. +func (rr resultResponse) Render(w http.ResponseWriter, r *http.Request) error { + render.Status(r, http.StatusOK) + return nil +} diff --git a/internal/server/api_toolsets.go b/internal/server/api_toolsets.go new file mode 100644 index 00000000000..16154821ec1 --- /dev/null +++ b/internal/server/api_toolsets.go @@ -0,0 +1,63 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" +) + +// toolsetHandler handles the request for information about a Toolset. +func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) { + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/toolset/get") + r = r.WithContext(ctx) + + toolsetName := chi.URLParam(r, "toolsetName") + s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) + span.SetAttributes(attribute.String("toolset_name", toolsetName)) + var err error + defer func() { + if err != nil { + span.SetStatus(codes.Error, err.Error()) + } + span.End() + + status := "success" + if err != nil { + status = "error" + } + s.instrumentation.ToolsetGet.Add( + r.Context(), + 1, + metric.WithAttributes(attribute.String("toolbox.name", toolsetName)), + metric.WithAttributes(attribute.String("toolbox.operation.status", status)), + ) + }() + + toolset, ok := s.ResourceMgr.GetToolset(toolsetName) + if !ok { + err = fmt.Errorf("toolset %q does not exist", toolsetName) + s.logger.DebugContext(ctx, err.Error()) + _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) + return + } + render.JSON(w, r, toolset.Manifest) +} diff --git a/internal/server/common_test.go b/internal/server/common_test.go index 39aca55be39..48309ee8491 100644 --- a/internal/server/common_test.go +++ b/internal/server/common_test.go @@ -24,21 +24,26 @@ import ( "testing" "github.com/go-chi/chi/v5" + "github.com/googleapis/genai-toolbox/internal/auth" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/resources" + "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/telemetry" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "go.opentelemetry.io/otel/trace" ) // fakeVersionString is used as a temporary version string in tests const fakeVersionString = "0.0.0" var ( - _ tools.Tool = MockTool{} - _ prompts.Prompt = MockPrompt{} + _ tools.Tool = MockTool{} + _ prompts.Prompt = MockPrompt{} + _ sources.Source = &MockSource{} + _ auth.AuthService = &MockAuthService{} ) // MockTool is used to mock tools in tests @@ -46,11 +51,85 @@ type MockTool struct { Name string Description string Params []parameters.Parameter + AuthRequired []string manifest tools.Manifest unauthorized bool requiresClientAuthrorization bool } +// MockSourceConfig is used to mock sources in tests. +type MockSourceConfig struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Host string `yaml:"host"` + Password string `yaml:"password"` +} + +func (c MockSourceConfig) SourceConfigKind() string { + return c.Kind +} + +func (c MockSourceConfig) Initialize(context.Context, trace.Tracer) (sources.Source, error) { + return &MockSource{Name: c.Name, Kind: c.Kind, Host: c.Host, Password: c.Password}, nil +} + +// MockSource is used to mock sources in tests. +type MockSource struct { + Name string + Kind string + Host string + Password string +} + +func (s *MockSource) SourceKind() string { + return s.Kind +} + +func (s *MockSource) ToConfig() sources.SourceConfig { + return MockSourceConfig{ + Name: s.Name, + Kind: s.Kind, + Host: s.Host, + Password: s.Password, + } +} + +// MockAuthServiceConfig is used to mock auth services in tests. +type MockAuthServiceConfig struct { + Name string + Kind string +} + +func (c MockAuthServiceConfig) AuthServiceConfigKind() string { + return c.Kind +} + +func (c MockAuthServiceConfig) Initialize() (auth.AuthService, error) { + return &MockAuthService{Name: c.Name, Kind: c.Kind}, nil +} + +// MockAuthService is used to mock auth services in tests. +type MockAuthService struct { + Name string + Kind string +} + +func (s *MockAuthService) AuthServiceKind() string { + return s.Kind +} + +func (s *MockAuthService) GetName() string { + return s.Name +} + +func (s *MockAuthService) GetClaimsFromHeader(context.Context, http.Header) (map[string]any, error) { + return nil, nil +} + +func (s *MockAuthService) ToConfig() auth.AuthServiceConfig { + return MockAuthServiceConfig{Name: s.Name, Kind: s.Kind} +} + func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) { mock := []any{t.Name} return mock, nil @@ -74,7 +153,7 @@ func (t MockTool) Manifest() tools.Manifest { for _, p := range t.Params { pMs = append(pMs, p.Manifest()) } - return tools.Manifest{Description: t.Description, Parameters: pMs} + return tools.Manifest{Description: t.Description, Parameters: pMs, AuthRequired: t.AuthRequired} } func (t MockTool) Authorized(verifiedAuthServices []string) bool { @@ -262,6 +341,20 @@ func setUpResources(t *testing.T, mockTools []MockTool, mockPrompts []MockPrompt // setUpServer create a new server with tools, toolsets, prompts, and promptsets. func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, toolsets map[string]tools.Toolset, prompts map[string]prompts.Prompt, promptsets map[string]prompts.Promptset) (chi.Router, func()) { + return setUpServerWithResources(t, router, nil, nil, tools, toolsets, prompts, promptsets) +} + +// setUpServerWithResources create a new server with sources, auth services, tools, toolsets, prompts, and promptsets. +func setUpServerWithResources( + t *testing.T, + router string, + sourcesMap map[string]sources.Source, + authServices map[string]auth.AuthService, + tools map[string]tools.Tool, + toolsets map[string]tools.Toolset, + prompts map[string]prompts.Prompt, + promptsets map[string]prompts.Promptset, +) (chi.Router, func()) { ctx, cancel := context.WithCancel(context.Background()) testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info") @@ -281,7 +374,7 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools sseManager := newSseManager(ctx) - resourceManager := resources.NewResourceManager(nil, nil, nil, tools, toolsets, prompts, promptsets) + resourceManager := resources.NewResourceManager(sourcesMap, authServices, nil, tools, toolsets, prompts, promptsets) server := Server{ version: fakeVersionString, diff --git a/internal/server/resources/resources.go b/internal/server/resources/resources.go index b41e160a39b..a7746217aeb 100644 --- a/internal/server/resources/resources.go +++ b/internal/server/resources/resources.go @@ -139,6 +139,16 @@ func (r *ResourceManager) GetEmbeddingModelMap() map[string]embeddingmodels.Embe return copiedMap } +func (r *ResourceManager) GetSourcesMap() map[string]sources.Source { + r.mu.RLock() + defer r.mu.RUnlock() + copiedMap := make(map[string]sources.Source, len(r.sources)) + for k, v := range r.sources { + copiedMap[k] = v + } + return copiedMap +} + func (r *ResourceManager) GetToolsMap() map[string]tools.Tool { r.mu.RLock() defer r.mu.RUnlock() diff --git a/internal/server/static/authservices.html b/internal/server/static/authservices.html new file mode 100644 index 00000000000..f12d48e90fa --- /dev/null +++ b/internal/server/static/authservices.html @@ -0,0 +1,33 @@ + + + + + + Auth Services View + + + + + + + +
+ + + + + + + diff --git a/internal/server/static/js/apiFetch.js b/internal/server/static/js/apiFetch.js new file mode 100644 index 00000000000..9280f7e7153 --- /dev/null +++ b/internal/server/static/js/apiFetch.js @@ -0,0 +1,35 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * Fetches JSON from the provided URL and returns the object found at the given key. + * @param {string} url + * @param {string} key + * @returns {Promise} + */ +export async function fetchJsonObjectByKey(url, key) { + const response = await fetch(url); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + const apiResponse = await response.json(); + if ( + !apiResponse || + typeof apiResponse[key] !== "object" || + apiResponse[key] === null + ) { + throw new Error(`Invalid response format from API for key "${key}".`); + } + return apiResponse[key]; +} diff --git a/internal/server/static/js/authservices.js b/internal/server/static/js/authservices.js new file mode 100644 index 00000000000..93a463a3191 --- /dev/null +++ b/internal/server/static/js/authservices.js @@ -0,0 +1,132 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { fetchAuthServices, fetchAuthService } from "./loadAuthServices.js"; +import { renderAuthServiceDetails } from "./authservicesDisplay.js"; + +/** + * These functions run after the browser finishes loading and parsing HTML structure. + * This ensures that elements can be safely accessed. + */ +document.addEventListener("DOMContentLoaded", () => { + const authServiceDisplayArea = document.getElementById( + "authservice-display-area" + ); + const secondaryPanelContent = document.getElementById( + "secondary-panel-content" + ); + + if (!secondaryPanelContent || !authServiceDisplayArea) { + console.error("Required DOM elements not found."); + return; + } + + loadAuthServices(secondaryPanelContent, authServiceDisplayArea); +}); + +/** + * Fetches the auth services and renders the list. + * @param {!HTMLElement} secondaryPanelContent The element for the auth service list. + * @param {!HTMLElement} authServiceDisplayArea The element for showing auth service details. + * @returns {!Promise} + */ +async function loadAuthServices(secondaryPanelContent, authServiceDisplayArea) { + secondaryPanelContent.innerHTML = "

Fetching auth services...

"; + try { + const services = await fetchAuthServices(); + renderAuthServiceList( + services, + secondaryPanelContent, + authServiceDisplayArea + ); + } catch (error) { + console.error("Failed to load auth services:", error); + secondaryPanelContent.innerHTML = `

Failed to load auth services:

${error}

`; + } +} + +/** + * Renders the list of auth services as buttons. + * @param {!Array<{name: string, kind: string}>} services The auth services to render. + * @param {!HTMLElement} secondaryPanelContent The element for the auth service list. + * @param {!HTMLElement} authServiceDisplayArea The element for showing auth service details. + */ +function renderAuthServiceList( + services, + secondaryPanelContent, + authServiceDisplayArea +) { + secondaryPanelContent.innerHTML = ""; + + if (!Array.isArray(services) || services.length === 0) { + secondaryPanelContent.textContent = "No auth services found."; + return; + } + + const ul = document.createElement("ul"); + services.forEach((service) => { + const li = document.createElement("li"); + const button = document.createElement("button"); + button.textContent = service.name; + button.dataset.authservicename = service.name; + button.classList.add("tool-button"); + button.addEventListener("click", (event) => + handleAuthServiceClick( + event, + secondaryPanelContent, + authServiceDisplayArea + ) + ); + li.appendChild(button); + ul.appendChild(li); + }); + secondaryPanelContent.appendChild(ul); +} + +/** + * Handles the click event on an auth service button. + * @param {!Event} event The click event object. + * @param {!HTMLElement} secondaryPanelContent The element containing the auth service list. + * @param {!HTMLElement} authServiceDisplayArea The element for showing auth service details. + */ +async function handleAuthServiceClick( + event, + secondaryPanelContent, + authServiceDisplayArea +) { + const authServiceName = event.target.dataset.authservicename; + if (!authServiceName) { + return; + } + + const currentActive = secondaryPanelContent.querySelector( + ".tool-button.active" + ); + if (currentActive) { + currentActive.classList.remove("active"); + } + event.target.classList.add("active"); + + authServiceDisplayArea.innerHTML = "

Loading auth service details...

"; + try { + const service = await fetchAuthService(authServiceName); + renderAuthServiceDetails(service, authServiceDisplayArea); + } catch (error) { + console.error( + `Failed to load details for auth service "${authServiceName}":`, + error + ); + authServiceDisplayArea.innerHTML = `

Failed to load details for ${authServiceName}. ${error.message}

`; + } +} diff --git a/internal/server/static/js/authservicesDisplay.js b/internal/server/static/js/authservicesDisplay.js new file mode 100644 index 00000000000..2d89121ec38 --- /dev/null +++ b/internal/server/static/js/authservicesDisplay.js @@ -0,0 +1,60 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * Renders auth service details into the main content area. + * @param {{name: string, kind: string}} service The auth service to render. + * @param {!HTMLElement} container The container to render into. + */ +export function renderAuthServiceDetails(service, container) { + container.innerHTML = ""; + + const wrapper = document.createElement("div"); + wrapper.className = "tool-box"; + + const title = document.createElement("h3"); + title.textContent = service.name || "Unnamed auth service"; + + const kind = document.createElement("p"); + kind.innerHTML = `Kind: ${service.kind || "unknown"}`; + + const headerName = document.createElement("p"); + headerName.innerHTML = `Header: ${ + service.headerName || (service.name ? `${service.name}_token` : "unknown") + }`; + + const toolsTitle = document.createElement("h5"); + toolsTitle.textContent = "Used by tools"; + + const toolsList = document.createElement("ul"); + const tools = Array.isArray(service.tools) ? service.tools : []; + if (tools.length === 0) { + const emptyItem = document.createElement("li"); + emptyItem.textContent = "No tools reference this auth service."; + toolsList.appendChild(emptyItem); + } else { + tools.forEach((toolName) => { + const item = document.createElement("li"); + item.textContent = toolName; + toolsList.appendChild(item); + }); + } + + wrapper.appendChild(title); + wrapper.appendChild(kind); + wrapper.appendChild(headerName); + wrapper.appendChild(toolsTitle); + wrapper.appendChild(toolsList); + container.appendChild(wrapper); +} diff --git a/internal/server/static/js/loadAuthServices.js b/internal/server/static/js/loadAuthServices.js new file mode 100644 index 00000000000..da0f42a0380 --- /dev/null +++ b/internal/server/static/js/loadAuthServices.js @@ -0,0 +1,57 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { fetchJsonObjectByKey } from "./apiFetch.js"; + +/** + * Fetches details for a specific auth service. + * @param {string} authServiceName The name of the auth service to fetch details for. + * @returns {!Promise<{name: string, kind: string}>} + */ +export async function fetchAuthService(authServiceName) { + const authServices = await fetchJsonObjectByKey( + `/api/authservice/${encodeURIComponent(authServiceName)}`, + "authServices" + ); + const service = authServices[authServiceName]; + if (!service) { + throw new Error( + `Auth service "${authServiceName}" not found in API response.` + ); + } + + return { + name: service.name || authServiceName, + kind: service.kind || "", + headerName: service.headerName || "", + tools: Array.isArray(service.tools) ? service.tools : [], + }; +} + +/** + * Fetches the list of auth services from the API. + * @returns {!Promise>} + */ +export async function fetchAuthServices() { + const authServices = await fetchJsonObjectByKey( + "/api/authservice", + "authServices" + ); + return Object.values(authServices).map((service) => ({ + name: service.name || "", + kind: service.kind || "", + headerName: service.headerName || "", + tools: Array.isArray(service.tools) ? service.tools : [], + })); +} diff --git a/internal/server/static/js/loadSources.js b/internal/server/static/js/loadSources.js new file mode 100644 index 00000000000..5c6b4f65bad --- /dev/null +++ b/internal/server/static/js/loadSources.js @@ -0,0 +1,49 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { fetchJsonObjectByKey } from "./apiFetch.js"; + +/** + * Fetches details for a specific source. + * @param {string} sourceName The name of the source to fetch details for. + * @returns {!Promise<{name: string, kind: string}>} + */ +export async function fetchSource(sourceName) { + const sources = await fetchJsonObjectByKey( + `/api/source/${encodeURIComponent(sourceName)}`, + "sources" + ); + const source = sources[sourceName]; + if (!source) { + throw new Error(`Source "${sourceName}" not found in API response.`); + } + + return { + name: source.name || sourceName, + kind: source.kind || "", + config: source.config || {}, + }; +} + +/** + * Fetches the list of sources from the API. + * @returns {!Promise>} + */ +export async function fetchSources() { + const sources = await fetchJsonObjectByKey("/api/source", "sources"); + return Object.values(sources).map((source) => ({ + name: source.name || "", + kind: source.kind || "", + })); +} diff --git a/internal/server/static/js/mainContent.js b/internal/server/static/js/mainContent.js index 1dbe7219f5b..4d7c4dccaca 100644 --- a/internal/server/static/js/mainContent.js +++ b/internal/server/static/js/mainContent.js @@ -64,6 +64,36 @@ function getToolInstructions() { `; } +function getSourceInstructions() { + return ` +
+

Sources

+

To inspect a source, please click on one of your sources to the left.

+

What are Sources?

+

+ Sources define the data backends that tools connect to, such as databases or APIs. + You can define Sources as a map in the sources section of your tools.yaml file. +

+ Sources Documentation +
+ `; +} + +function getAuthServiceInstructions() { + return ` +
+

Auth Services

+

To inspect an auth service, please click on one of your auth services to the left.

+

What are Auth Services?

+

+ Auth services define how incoming requests are authenticated and how claims are extracted for tools. + You can define Auth Services as a map in the authServices section of your tools.yaml file. +

+ Auth Services Documentation +
+ `; +} + function getToolsetInstructions() { return `
@@ -77,4 +107,4 @@ function getToolsetInstructions() { Toolsets Documentation
`; -} \ No newline at end of file +} diff --git a/internal/server/static/js/navbar.js b/internal/server/static/js/navbar.js index b91503cdfb8..c5822efeafc 100644 --- a/internal/server/static/js/navbar.js +++ b/internal/server/static/js/navbar.js @@ -30,8 +30,8 @@ function renderNavbar(containerId, activePath) { App Logo diff --git a/internal/server/static/js/sources.js b/internal/server/static/js/sources.js new file mode 100644 index 00000000000..adfd5d0539c --- /dev/null +++ b/internal/server/static/js/sources.js @@ -0,0 +1,115 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { fetchSources, fetchSource } from "./loadSources.js"; +import { renderSourceDetails } from "./sourcesDisplay.js"; + +/** + * These functions run after the browser finishes loading and parsing HTML structure. + * This ensures that elements can be safely accessed. + */ +document.addEventListener("DOMContentLoaded", () => { + const sourceDisplayArea = document.getElementById("source-display-area"); + const secondaryPanelContent = document.getElementById( + "secondary-panel-content" + ); + + if (!secondaryPanelContent || !sourceDisplayArea) { + console.error("Required DOM elements not found."); + return; + } + + loadSources(secondaryPanelContent, sourceDisplayArea); +}); + +/** + * Fetches the sources and renders the list. + * @param {!HTMLElement} secondaryPanelContent The element for the source list. + * @param {!HTMLElement} sourceDisplayArea The element for showing source details. + * @returns {!Promise} + */ +async function loadSources(secondaryPanelContent, sourceDisplayArea) { + secondaryPanelContent.innerHTML = "

Fetching sources...

"; + try { + const sources = await fetchSources(); + renderSourceList(sources, secondaryPanelContent, sourceDisplayArea); + } catch (error) { + console.error("Failed to load sources:", error); + secondaryPanelContent.innerHTML = `

Failed to load sources:

${error}

`; + } +} + +/** + * Renders the list of sources as buttons. + * @param {!Array<{name: string, kind: string}>} sources The sources to render. + * @param {!HTMLElement} secondaryPanelContent The element for the source list. + * @param {!HTMLElement} sourceDisplayArea The element for showing source details. + */ +function renderSourceList(sources, secondaryPanelContent, sourceDisplayArea) { + secondaryPanelContent.innerHTML = ""; + + if (!Array.isArray(sources) || sources.length === 0) { + secondaryPanelContent.textContent = "No sources found."; + return; + } + + const ul = document.createElement("ul"); + sources.forEach((source) => { + const li = document.createElement("li"); + const button = document.createElement("button"); + button.textContent = source.name; + button.dataset.sourcename = source.name; + button.classList.add("tool-button"); + button.addEventListener("click", (event) => + handleSourceClick(event, secondaryPanelContent, sourceDisplayArea) + ); + li.appendChild(button); + ul.appendChild(li); + }); + secondaryPanelContent.appendChild(ul); +} + +/** + * Handles the click event on a source button. + * @param {!Event} event The click event object. + * @param {!HTMLElement} secondaryPanelContent The element containing the source list. + * @param {!HTMLElement} sourceDisplayArea The element for showing source details. + */ +async function handleSourceClick( + event, + secondaryPanelContent, + sourceDisplayArea +) { + const sourceName = event.target.dataset.sourcename; + if (!sourceName) { + return; + } + + const currentActive = secondaryPanelContent.querySelector( + ".tool-button.active" + ); + if (currentActive) { + currentActive.classList.remove("active"); + } + event.target.classList.add("active"); + + sourceDisplayArea.innerHTML = "

Loading source details...

"; + try { + const source = await fetchSource(sourceName); + renderSourceDetails(source, sourceDisplayArea); + } catch (error) { + console.error(`Failed to load details for source "${sourceName}":`, error); + sourceDisplayArea.innerHTML = `

Failed to load details for ${sourceName}. ${error.message}

`; + } +} diff --git a/internal/server/static/js/sourcesDisplay.js b/internal/server/static/js/sourcesDisplay.js new file mode 100644 index 00000000000..7fa4801a786 --- /dev/null +++ b/internal/server/static/js/sourcesDisplay.js @@ -0,0 +1,87 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * Renders source details into the main content area. + * @param {{name: string, kind: string}} source The source to render. + * @param {!HTMLElement} container The container to render into. + */ +export function renderSourceDetails(source, container) { + container.innerHTML = ""; + + const wrapper = document.createElement("div"); + wrapper.className = "tool-box"; + + const title = document.createElement("h3"); + title.textContent = source.name || "Unnamed source"; + + const kind = document.createElement("p"); + kind.innerHTML = `Kind: ${source.kind || "unknown"}`; + + const summary = document.createElement("p"); + summary.innerHTML = `Type: ${formatSourceType(source.kind)}`; + + const configTitle = document.createElement("h5"); + configTitle.textContent = "Configuration"; + + const configList = document.createElement("ul"); + const configEntries = + source.config && typeof source.config === "object" + ? Object.entries(source.config) + : []; + if (configEntries.length === 0) { + const emptyItem = document.createElement("li"); + emptyItem.textContent = "No configuration details available."; + configList.appendChild(emptyItem); + } else { + configEntries.forEach(([key, value]) => { + const item = document.createElement("li"); + item.textContent = `${key}: ${formatConfigValue(value)}`; + configList.appendChild(item); + }); + } + + wrapper.appendChild(title); + wrapper.appendChild(kind); + wrapper.appendChild(summary); + wrapper.appendChild(configTitle); + wrapper.appendChild(configList); + container.appendChild(wrapper); +} + +function formatConfigValue(value) { + if (value === null || value === undefined) { + return "null"; + } + if (typeof value === "object") { + try { + return JSON.stringify(value); + } catch (e) { + return "[object]"; + } + } + return String(value); +} + +function formatSourceType(kind) { + if (!kind) { + return "Unknown source"; + } + const normalized = String(kind).replace(/[_-]+/g, " ").trim().toLowerCase(); + return `${capitalizeWords(normalized)} source`; +} + +function capitalizeWords(value) { + return value.replace(/\b\w/g, (char) => char.toUpperCase()); +} diff --git a/internal/server/static/sources.html b/internal/server/static/sources.html new file mode 100644 index 00000000000..7de82650e49 --- /dev/null +++ b/internal/server/static/sources.html @@ -0,0 +1,33 @@ + + + + + + Sources View + + + + + + + +
+ + + + + + + diff --git a/internal/server/web.go b/internal/server/web.go index 23f7de06f2b..87e20239247 100644 --- a/internal/server/web.go +++ b/internal/server/web.go @@ -22,6 +22,8 @@ func webRouter() (chi.Router, error) { // direct routes for html pages to provide clean URLs r.Get("/", func(w http.ResponseWriter, r *http.Request) { serveHTML(w, r, "static/index.html") }) + r.Get("/authservices", func(w http.ResponseWriter, r *http.Request) { serveHTML(w, r, "static/authservices.html") }) + r.Get("/sources", func(w http.ResponseWriter, r *http.Request) { serveHTML(w, r, "static/sources.html") }) r.Get("/tools", func(w http.ResponseWriter, r *http.Request) { serveHTML(w, r, "static/tools.html") }) r.Get("/toolsets", func(w http.ResponseWriter, r *http.Request) { serveHTML(w, r, "static/toolsets.html") }) diff --git a/internal/server/web_test.go b/internal/server/web_test.go index 64c173285e7..dcc950b999d 100644 --- a/internal/server/web_test.go +++ b/internal/server/web_test.go @@ -59,6 +59,34 @@ func TestWebEndpoint(t *testing.T) { wantContentType: "text/html", wantPageTitle: "Tools View", }, + { + name: "web auth services page", + path: "/ui/authservices", + wantStatus: http.StatusOK, + wantContentType: "text/html", + wantPageTitle: "Auth Services View", + }, + { + name: "web auth services page with trailing slash", + path: "/ui/authservices/", + wantStatus: http.StatusOK, + wantContentType: "text/html", + wantPageTitle: "Auth Services View", + }, + { + name: "web sources page", + path: "/ui/sources", + wantStatus: http.StatusOK, + wantContentType: "text/html", + wantPageTitle: "Sources View", + }, + { + name: "web sources page with trailing slash", + path: "/ui/sources/", + wantStatus: http.StatusOK, + wantContentType: "text/html", + wantPageTitle: "Sources View", + }, { name: "web toolsets page", path: "/ui/toolsets",