diff --git a/internal/sources/dataplex/dataplex.go b/internal/sources/dataplex/dataplex.go index ac7fc74713b..9ddb05edd11 100644 --- a/internal/sources/dataplex/dataplex.go +++ b/internal/sources/dataplex/dataplex.go @@ -19,6 +19,8 @@ import ( "fmt" dataplexapi "cloud.google.com/go/dataplex/apiv1" + "cloud.google.com/go/dataplex/apiv1/dataplexpb" + "github.com/cenkalti/backoff/v5" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" @@ -121,3 +123,101 @@ func initDataplexConnection( } return client, nil } + +func (s *Source) LookupEntry(ctx context.Context, name string, view int, aspectTypes []string, entry string) (*dataplexpb.Entry, error) { + viewMap := map[int]dataplexpb.EntryView{ + 1: dataplexpb.EntryView_BASIC, + 2: dataplexpb.EntryView_FULL, + 3: dataplexpb.EntryView_CUSTOM, + 4: dataplexpb.EntryView_ALL, + } + req := &dataplexpb.LookupEntryRequest{ + Name: name, + View: viewMap[view], + AspectTypes: aspectTypes, + Entry: entry, + } + result, err := s.CatalogClient().LookupEntry(ctx, req) + if err != nil { + return nil, err + } + return result, nil +} + +func (s *Source) searchRequest(ctx context.Context, query string, pageSize int, orderBy string) (*dataplexapi.SearchEntriesResultIterator, error) { + // Create SearchEntriesRequest with the provided parameters + req := &dataplexpb.SearchEntriesRequest{ + Query: query, + Name: fmt.Sprintf("projects/%s/locations/global", s.ProjectID()), + PageSize: int32(pageSize), + OrderBy: orderBy, + SemanticSearch: true, + } + + // Perform the search using the CatalogClient - this will return an iterator + it := s.CatalogClient().SearchEntries(ctx, req) + if it == nil { + return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.ProjectID()) + } + return it, nil +} + +func (s *Source) SearchAspectTypes(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.AspectType, error) { + q := query + 
" type=projects/dataplex-types/locations/global/entryTypes/aspecttype" + it, err := s.searchRequest(ctx, q, pageSize, orderBy) + if err != nil { + return nil, err + } + + // Iterate through the search results and call GetAspectType for each result using the resource name + var results []*dataplexpb.AspectType + for { + entry, err := it.Next() + if err != nil { + break + } + + // Create an instance of exponential backoff with default values for retrying GetAspectType calls + // InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s + getAspectBackOff := backoff.NewExponentialBackOff() + + resourceName := entry.DataplexEntry.GetEntrySource().Resource + getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{ + Name: resourceName, + } + + operation := func() (*dataplexpb.AspectType, error) { + aspectType, err := s.CatalogClient().GetAspectType(ctx, getAspectTypeReq) + if err != nil { + return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err) + } + return aspectType, nil + } + + // Retry the GetAspectType operation with exponential backoff + aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff)) + if err != nil { + return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err) + } + + results = append(results, aspectType) + } + return results, nil +} + +func (s *Source) SearchEntries(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.SearchEntriesResult, error) { + it, err := s.searchRequest(ctx, query, pageSize, orderBy) + if err != nil { + return nil, err + } + + var results []*dataplexpb.SearchEntriesResult + for { + entry, err := it.Next() + if err != nil { + break + } + results = append(results, entry) + } + return results, nil +} diff --git a/internal/sources/http/http.go b/internal/sources/http/http.go index b4e9fdd9374..238c19b6074 100644 --- a/internal/sources/http/http.go +++ 
b/internal/sources/http/http.go @@ -16,7 +16,9 @@ package http import ( "context" "crypto/tls" + "encoding/json" "fmt" + "io" "net/http" "net/url" "time" @@ -143,3 +145,28 @@ func (s *Source) HttpQueryParams() map[string]string { func (s *Source) Client() *http.Client { return s.client } + +func (s *Source) RunRequest(req *http.Request) (any, error) { + // Make request and fetch response + resp, err := s.Client().Do(req) + if err != nil { + return nil, fmt.Errorf("error making HTTP request: %s", err) + } + defer resp.Body.Close() + + var body []byte + body, err = io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body)) + } + + var data any + if err = json.Unmarshal(body, &data); err != nil { + // if unable to unmarshal data, return result as string. + return string(body), nil + } + return data, nil +} diff --git a/internal/sources/serverlessspark/serverlessspark.go b/internal/sources/serverlessspark/serverlessspark.go index c63adb68637..f6968f9baec 100644 --- a/internal/sources/serverlessspark/serverlessspark.go +++ b/internal/sources/serverlessspark/serverlessspark.go @@ -16,15 +16,21 @@ package serverlessspark import ( "context" + "encoding/json" "fmt" + "time" dataproc "cloud.google.com/go/dataproc/v2/apiv1" + "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" longrunning "cloud.google.com/go/longrunning/autogen" + "cloud.google.com/go/longrunning/autogen/longrunningpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/trace" + "google.golang.org/api/iterator" "google.golang.org/api/option" + "google.golang.org/protobuf/encoding/protojson" ) const SourceKind string = "serverless-spark" @@ -121,3 +127,168 @@ func (s *Source) Close() error { } return nil } + +func (s *Source) CancelOperation(ctx 
context.Context, operation string) (any, error) { + req := &longrunningpb.CancelOperationRequest{ + Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation), + } + client, err := s.GetOperationsClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get operations client: %w", err) + } + err = client.CancelOperation(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to cancel operation: %w", err) + } + return fmt.Sprintf("Cancelled [%s].", operation), nil +} + +func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) { + req := &dataprocpb.CreateBatchRequest{ + Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()), + Batch: batch, + } + + client := s.GetBatchControllerClient() + op, err := client.CreateBatch(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to create batch: %w", err) + } + meta, err := op.Metadata() + if err != nil { + return nil, fmt.Errorf("failed to get create batch op metadata: %w", err) + } + + projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch()) + if err != nil { + return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err) + } + consoleUrl := BatchConsoleURL(projectID, location, batchID) + logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{}) + + wrappedResult := map[string]any{ + "opMetadata": meta, + "consoleUrl": consoleUrl, + "logsUrl": logsUrl, + } + return wrappedResult, nil +} + +// ListBatchesResponse is the response from the list batches API. +type ListBatchesResponse struct { + Batches []Batch `json:"batches"` + NextPageToken string `json:"nextPageToken"` +} + +// Batch represents a single batch job. 
+type Batch struct { + Name string `json:"name"` + UUID string `json:"uuid"` + State string `json:"state"` + Creator string `json:"creator"` + CreateTime string `json:"createTime"` + Operation string `json:"operation"` + ConsoleURL string `json:"consoleUrl"` + LogsURL string `json:"logsUrl"` +} + +func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) { + client := s.GetBatchControllerClient() + parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()) + req := &dataprocpb.ListBatchesRequest{ + Parent: parent, + OrderBy: "create_time desc", + } + + if ps != nil { + req.PageSize = int32(*ps) + } + if pt != "" { + req.PageToken = pt + } + if filter != "" { + req.Filter = filter + } + + it := client.ListBatches(ctx, req) + pager := iterator.NewPager(it, int(req.PageSize), req.PageToken) + + var batchPbs []*dataprocpb.Batch + nextPageToken, err := pager.NextPage(&batchPbs) + if err != nil { + return nil, fmt.Errorf("failed to list batches: %w", err) + } + + batches, err := ToBatches(batchPbs) + if err != nil { + return nil, err + } + + return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil +} + +// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs. 
+func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) { + batches := make([]Batch, 0, len(batchPbs)) + for _, batchPb := range batchPbs { + consoleUrl, err := BatchConsoleURLFromProto(batchPb) + if err != nil { + return nil, fmt.Errorf("error generating console url: %v", err) + } + logsUrl, err := BatchLogsURLFromProto(batchPb) + if err != nil { + return nil, fmt.Errorf("error generating logs url: %v", err) + } + batch := Batch{ + Name: batchPb.Name, + UUID: batchPb.Uuid, + State: batchPb.State.Enum().String(), + Creator: batchPb.Creator, + CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339), + Operation: batchPb.Operation, + ConsoleURL: consoleUrl, + LogsURL: logsUrl, + } + batches = append(batches, batch) + } + return batches, nil +} + +func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) { + client := s.GetBatchControllerClient() + req := &dataprocpb.GetBatchRequest{ + Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name), + } + + batchPb, err := client.GetBatch(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to get batch: %w", err) + } + + jsonBytes, err := protojson.Marshal(batchPb) + if err != nil { + return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err) + } + + var result map[string]any + if err := json.Unmarshal(jsonBytes, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err) + } + + consoleUrl, err := BatchConsoleURLFromProto(batchPb) + if err != nil { + return nil, fmt.Errorf("error generating console url: %v", err) + } + logsUrl, err := BatchLogsURLFromProto(batchPb) + if err != nil { + return nil, fmt.Errorf("error generating logs url: %v", err) + } + + wrappedResult := map[string]any{ + "consoleUrl": consoleUrl, + "logsUrl": logsUrl, + "batch": result, + } + + return wrappedResult, nil +} diff --git a/internal/tools/serverlessspark/common/urls.go b/internal/sources/serverlessspark/url.go similarity 
index 97% rename from internal/tools/serverlessspark/common/urls.go rename to internal/sources/serverlessspark/url.go index 3b522359927..75dcc4e5ae9 100644 --- a/internal/tools/serverlessspark/common/urls.go +++ b/internal/sources/serverlessspark/url.go @@ -1,10 +1,10 @@ -// Copyright 2025 Google LLC +// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common +package serverlessspark import ( "fmt" @@ -23,13 +23,13 @@ import ( "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" ) +var batchFullNameRegex = regexp.MustCompile(`projects/(?P[^/]+)/locations/(?P[^/]+)/batches/(?P[^/]+)`) + const ( logTimeBufferBefore = 1 * time.Minute logTimeBufferAfter = 10 * time.Minute ) -var batchFullNameRegex = regexp.MustCompile(`projects/(?P[^/]+)/locations/(?P[^/]+)/batches/(?P[^/]+)`) - // Extract BatchDetails extracts the project ID, location, and batch ID from a fully qualified batch name. func ExtractBatchDetails(batchName string) (projectID, location, batchID string, err error) { matches := batchFullNameRegex.FindStringSubmatch(batchName) @@ -39,26 +39,6 @@ func ExtractBatchDetails(batchName string) (projectID, location, batchID string, return matches[1], matches[2], matches[3], nil } -// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
-func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) { - projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName()) - if err != nil { - return "", err - } - return BatchConsoleURL(projectID, location, batchID), nil -} - -// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range. -func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) { - projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName()) - if err != nil { - return "", err - } - createTime := batchPb.GetCreateTime().AsTime() - stateTime := batchPb.GetStateTime().AsTime() - return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil -} - // BatchConsoleURL builds a URL to the Google Cloud Console linking to the batch summary page. func BatchConsoleURL(projectID, location, batchID string) string { return fmt.Sprintf("https://console.cloud.google.com/dataproc/batches/%s/%s/summary?project=%s", location, batchID, projectID) @@ -89,3 +69,23 @@ resource.labels.batch_id="%s"` return "https://console.cloud.google.com/logs/viewer?" + v.Encode() } + +// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page. +func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) { + projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName()) + if err != nil { + return "", err + } + return BatchConsoleURL(projectID, location, batchID), nil +} + +// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range. 
+func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) { + projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName()) + if err != nil { + return "", err + } + createTime := batchPb.GetCreateTime().AsTime() + stateTime := batchPb.GetStateTime().AsTime() + return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil +} diff --git a/internal/tools/serverlessspark/common/urls_test.go b/internal/sources/serverlessspark/url_test.go similarity index 86% rename from internal/tools/serverlessspark/common/urls_test.go rename to internal/sources/serverlessspark/url_test.go index c8d9e072006..16ed75738fc 100644 --- a/internal/tools/serverlessspark/common/urls_test.go +++ b/internal/sources/serverlessspark/url_test.go @@ -1,10 +1,10 @@ -// Copyright 2025 Google LLC +// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,19 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License.
-package common +package serverlessspark_test import ( "testing" "time" "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" "google.golang.org/protobuf/types/known/timestamppb" ) func TestExtractBatchDetails_Success(t *testing.T) { batchName := "projects/my-project/locations/us-central1/batches/my-batch" - projectID, location, batchID, err := ExtractBatchDetails(batchName) + projectID, location, batchID, err := serverlessspark.ExtractBatchDetails(batchName) if err != nil { t.Errorf("ExtractBatchDetails() error = %v, want no error", err) return @@ -45,7 +46,7 @@ func TestExtractBatchDetails_Success(t *testing.T) { func TestExtractBatchDetails_Failure(t *testing.T) { batchName := "invalid-name" - _, _, _, err := ExtractBatchDetails(batchName) + _, _, _, err := serverlessspark.ExtractBatchDetails(batchName) wantErr := "failed to parse batch name: invalid-name" if err == nil || err.Error() != wantErr { t.Errorf("ExtractBatchDetails() error = %v, want %v", err, wantErr) @@ -53,7 +54,7 @@ func TestExtractBatchDetails_Failure(t *testing.T) { } func TestBatchConsoleURL(t *testing.T) { - got := BatchConsoleURL("my-project", "us-central1", "my-batch") + got := serverlessspark.BatchConsoleURL("my-project", "us-central1", "my-batch") want := "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project" if got != want { t.Errorf("BatchConsoleURL() = %v, want %v", got, want) @@ -63,7 +64,7 @@ func TestBatchConsoleURL(t *testing.T) { func TestBatchLogsURL(t *testing.T) { startTime := time.Date(2025, 10, 1, 5, 0, 0, 0, time.UTC) endTime := time.Date(2025, 10, 1, 6, 0, 0, 0, time.UTC) - got := BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime) + got := serverlessspark.BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime) want := "https://console.cloud.google.com/logs/viewer?advancedFilter=" + 
"resource.type%3D%22cloud_dataproc_batch%22" + "%0Aresource.labels.project_id%3D%22my-project%22" + @@ -82,7 +83,7 @@ func TestBatchConsoleURLFromProto(t *testing.T) { batchPb := &dataprocpb.Batch{ Name: "projects/my-project/locations/us-central1/batches/my-batch", } - got, err := BatchConsoleURLFromProto(batchPb) + got, err := serverlessspark.BatchConsoleURLFromProto(batchPb) if err != nil { t.Fatalf("BatchConsoleURLFromProto() error = %v", err) } @@ -100,7 +101,7 @@ func TestBatchLogsURLFromProto(t *testing.T) { CreateTime: timestamppb.New(createTime), StateTime: timestamppb.New(stateTime), } - got, err := BatchLogsURLFromProto(batchPb) + got, err := serverlessspark.BatchLogsURLFromProto(batchPb) if err != nil { t.Fatalf("BatchLogsURLFromProto() error = %v", err) } diff --git a/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go b/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go index 39d59fbfdf6..3fed4ef0ccf 100644 --- a/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go +++ b/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go @@ -18,7 +18,6 @@ import ( "context" "fmt" - dataplexapi "cloud.google.com/go/dataplex/apiv1" dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" @@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - CatalogClient() *dataplexapi.CatalogClient + LookupEntry(context.Context, string, int, []string, string) (*dataplexpb.Entry, error) } type Config struct { @@ -118,12 +117,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } paramsMap := params.AsMap() - viewMap := map[int]dataplexpb.EntryView{ - 1: dataplexpb.EntryView_BASIC, - 2: dataplexpb.EntryView_FULL, - 3: dataplexpb.EntryView_CUSTOM, - 4: dataplexpb.EntryView_ALL, - } name, _ := paramsMap["name"].(string) entry, 
_ := paramsMap["entry"].(string) view, _ := paramsMap["view"].(int) @@ -132,19 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("can't convert aspectTypes to array of strings: %s", err) } aspectTypes := aspectTypeSlice.([]string) - - req := &dataplexpb.LookupEntryRequest{ - Name: name, - View: viewMap[view], - AspectTypes: aspectTypes, - Entry: entry, - } - - result, err := source.CatalogClient().LookupEntry(ctx, req) - if err != nil { - return nil, err - } - return result, nil + return source.LookupEntry(ctx, name, view, aspectTypes, entry) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go b/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go index 5f8b304e2b7..214f7396e19 100644 --- a/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go +++ b/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go @@ -18,9 +18,7 @@ import ( "context" "fmt" - dataplexapi "cloud.google.com/go/dataplex/apiv1" - dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb" - "github.com/cenkalti/backoff/v5" + "cloud.google.com/go/dataplex/apiv1/dataplexpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" @@ -45,8 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - CatalogClient() *dataplexapi.CatalogClient - ProjectID() string + SearchAspectTypes(context.Context, string, int, string) ([]*dataplexpb.AspectType, error) } type Config struct { @@ -101,61 +98,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, err } - - // Invoke the tool with the provided parameters 
paramsMap := params.AsMap() query, _ := paramsMap["query"].(string) - pageSize := int32(paramsMap["pageSize"].(int)) + pageSize, _ := paramsMap["pageSize"].(int) orderBy, _ := paramsMap["orderBy"].(string) - - // Create SearchEntriesRequest with the provided parameters - req := &dataplexpb.SearchEntriesRequest{ - Query: query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype", - Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()), - PageSize: pageSize, - OrderBy: orderBy, - SemanticSearch: true, - } - - // Perform the search using the CatalogClient - this will return an iterator - it := source.CatalogClient().SearchEntries(ctx, req) - if it == nil { - return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID()) - } - - // Create an instance of exponential backoff with default values for retrying GetAspectType calls - // InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s - getAspectBackOff := backoff.NewExponentialBackOff() - - // Iterate through the search results and call GetAspectType for each result using the resource name - var results []*dataplexpb.AspectType - for { - entry, err := it.Next() - if err != nil { - break - } - resourceName := entry.DataplexEntry.GetEntrySource().Resource - getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{ - Name: resourceName, - } - - operation := func() (*dataplexpb.AspectType, error) { - aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq) - if err != nil { - return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err) - } - return aspectType, nil - } - - // Retry the GetAspectType operation with exponential backoff - aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff)) - if err != nil { - return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err) - } - - results = append(results, 
aspectType) - } - return results, nil + return source.SearchAspectTypes(ctx, query, pageSize, orderBy) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go b/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go index 4536e265e23..7180848a4ba 100644 --- a/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go +++ b/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go @@ -18,8 +18,7 @@ import ( "context" "fmt" - dataplexapi "cloud.google.com/go/dataplex/apiv1" - dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb" + "cloud.google.com/go/dataplex/apiv1/dataplexpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" @@ -44,8 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - CatalogClient() *dataplexapi.CatalogClient - ProjectID() string + SearchEntries(context.Context, string, int, string) ([]*dataplexpb.SearchEntriesResult, error) } type Config struct { @@ -100,34 +98,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, err } - paramsMap := params.AsMap() query, _ := paramsMap["query"].(string) - pageSize := int32(paramsMap["pageSize"].(int)) + pageSize, _ := paramsMap["pageSize"].(int) orderBy, _ := paramsMap["orderBy"].(string) - - req := &dataplexpb.SearchEntriesRequest{ - Query: query, - Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()), - PageSize: pageSize, - OrderBy: orderBy, - SemanticSearch: true, - } - - it := source.CatalogClient().SearchEntries(ctx, req) - if it == nil { - return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID()) - } - - var results 
[]*dataplexpb.SearchEntriesResult - for { - entry, err := it.Next() - if err != nil { - break - } - results = append(results, entry) - } - return results, nil + return source.SearchEntries(ctx, query, pageSize, orderBy) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/http/http.go b/internal/tools/http/http.go index d66efa2b766..6ba9699e2b0 100644 --- a/internal/tools/http/http.go +++ b/internal/tools/http/http.go @@ -16,9 +16,7 @@ package http import ( "bytes" "context" - "encoding/json" "fmt" - "io" "net/http" "net/url" "slices" @@ -54,7 +52,7 @@ type compatibleSource interface { HttpDefaultHeaders() map[string]string HttpBaseURL() string HttpQueryParams() map[string]string - Client() *http.Client + RunRequest(*http.Request) (any, error) } type Config struct { @@ -259,29 +257,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para for k, v := range allHeaders { req.Header.Set(k, v) } - - // Make request and fetch response - resp, err := source.Client().Do(req) - if err != nil { - return nil, fmt.Errorf("error making HTTP request: %s", err) - } - defer resp.Body.Close() - - var body []byte - body, err = io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - if resp.StatusCode < 200 || resp.StatusCode > 299 { - return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body)) - } - - var data any - if err = json.Unmarshal(body, &data); err != nil { - // if unable to unmarshal data, return result as string. 
- return string(body), nil - } - return data, nil + return source.RunRequest(req) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/serverlessspark/createbatch/config.go b/internal/tools/serverlessspark/createbatch/config.go index 0bb3575a399..bcbf611584e 100644 --- a/internal/tools/serverlessspark/createbatch/config.go +++ b/internal/tools/serverlessspark/createbatch/config.go @@ -19,7 +19,6 @@ import ( "encoding/json" "fmt" - dataproc "cloud.google.com/go/dataproc/v2/apiv1" dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/goccy/go-yaml" "google.golang.org/protobuf/encoding/protojson" @@ -36,9 +35,7 @@ func unmarshalProto(data any, m proto.Message) error { } type compatibleSource interface { - GetBatchControllerClient() *dataproc.BatchControllerClient - GetProject() string - GetLocation() string + CreateBatch(context.Context, *dataprocpb.Batch) (map[string]any, error) } // Config is a common config that can be used with any type of create batch tool. 
However, each tool diff --git a/internal/tools/serverlessspark/createbatch/tool.go b/internal/tools/serverlessspark/createbatch/tool.go index 3839a71a18f..dca7081aa60 100644 --- a/internal/tools/serverlessspark/createbatch/tool.go +++ b/internal/tools/serverlessspark/createbatch/tool.go @@ -16,23 +16,19 @@ package createbatch import ( "context" - "encoding/json" "fmt" - "time" dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" ) type BatchBuilder interface { Parameters() parameters.Parameters - BuildBatch(params parameters.ParamValues) (*dataprocpb.Batch, error) + BuildBatch(parameters.ParamValues) (*dataprocpb.Batch, error) } func NewTool(cfg Config, originalCfg tools.ToolConfig, srcs map[string]sources.Source, builder BatchBuilder) (*Tool, error) { @@ -74,7 +70,6 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par if err != nil { return nil, err } - client := source.GetBatchControllerClient() batch, err := t.Builder.BuildBatch(params) if err != nil { @@ -97,46 +92,7 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par } batch.RuntimeConfig.Version = version } - - req := &dataprocpb.CreateBatchRequest{ - Parent: fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation()), - Batch: batch, - } - - op, err := client.CreateBatch(ctx, req) - if err != nil { - return nil, fmt.Errorf("failed to create batch: %w", err) - } - - meta, err := op.Metadata() - if err != nil { - return nil, fmt.Errorf("failed to get create batch op metadata: %w", err) - } - - jsonBytes, err := 
protojson.Marshal(meta) - if err != nil { - return nil, fmt.Errorf("failed to marshal create batch op metadata to JSON: %w", err) - } - - var result map[string]any - if err := json.Unmarshal(jsonBytes, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal create batch op metadata JSON: %w", err) - } - - projectID, location, batchID, err := common.ExtractBatchDetails(meta.GetBatch()) - if err != nil { - return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err) - } - consoleUrl := common.BatchConsoleURL(projectID, location, batchID) - logsUrl := common.BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{}) - - wrappedResult := map[string]any{ - "opMetadata": meta, - "consoleUrl": consoleUrl, - "logsUrl": logsUrl, - } - - return wrappedResult, nil + return source.CreateBatch(ctx, batch) } func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go index 6d39b077eba..fe072f07607 100644 --- a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go +++ b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go @@ -19,8 +19,7 @@ import ( "fmt" "strings" - longrunning "cloud.google.com/go/longrunning/autogen" - "cloud.google.com/go/longrunning/autogen/longrunningpb" + dataproc "cloud.google.com/go/dataproc/v2/apiv1" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" @@ -45,9 +44,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - GetOperationsClient(context.Context) (*longrunning.OperationsClient, error) - GetProject() 
string - GetLocation() string + GetBatchControllerClient() *dataproc.BatchControllerClient + CancelOperation(context.Context, string) (any, error) } type Config struct { @@ -106,32 +104,15 @@ func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, par if err != nil { return nil, err } - - client, err := source.GetOperationsClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get operations client: %w", err) - } - paramMap := params.AsMap() operation, ok := paramMap["operation"].(string) if !ok { return nil, fmt.Errorf("missing required parameter: operation") } - if strings.Contains(operation, "/") { return nil, fmt.Errorf("operation must be a short operation name without '/': %s", operation) } - - req := &longrunningpb.CancelOperationRequest{ - Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", source.GetProject(), source.GetLocation(), operation), - } - - err = client.CancelOperation(ctx, req) - if err != nil { - return nil, fmt.Errorf("failed to cancel operation: %w", err) - } - - return fmt.Sprintf("Cancelled [%s].", operation), nil + return source.CancelOperation(ctx, operation) } func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go index 23dd23f4bd8..85d1247fffa 100644 --- a/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go +++ b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go @@ -16,19 +16,15 @@ package serverlesssparkgetbatch import ( "context" - "encoding/json" "fmt" "strings" dataproc "cloud.google.com/go/dataproc/v2/apiv1" - "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" 
"github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/protobuf/encoding/protojson" ) const kind = "serverless-spark-get-batch" @@ -49,8 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetBatchControllerClient() *dataproc.BatchControllerClient - GetProject() string - GetLocation() string + GetBatch(context.Context, string) (map[string]any, error) } type Config struct { @@ -109,54 +104,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, err } - - client := source.GetBatchControllerClient() - paramMap := params.AsMap() name, ok := paramMap["name"].(string) if !ok { return nil, fmt.Errorf("missing required parameter: name") } - if strings.Contains(name, "/") { return nil, fmt.Errorf("name must be a short batch name without '/': %s", name) } - - req := &dataprocpb.GetBatchRequest{ - Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", source.GetProject(), source.GetLocation(), name), - } - - batchPb, err := client.GetBatch(ctx, req) - if err != nil { - return nil, fmt.Errorf("failed to get batch: %w", err) - } - - jsonBytes, err := protojson.Marshal(batchPb) - if err != nil { - return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err) - } - - var result map[string]any - if err := json.Unmarshal(jsonBytes, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err) - } - - consoleUrl, err := common.BatchConsoleURLFromProto(batchPb) - if err != nil { - return nil, fmt.Errorf("error generating console url: %v", err) - } - logsUrl, err := common.BatchLogsURLFromProto(batchPb) - if err != nil { - return nil, fmt.Errorf("error generating logs url: %v", err) - } - - wrappedResult := 
map[string]any{ - "consoleUrl": consoleUrl, - "logsUrl": logsUrl, - "batch": result, - } - - return wrappedResult, nil + return source.GetBatch(ctx, name) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { return parameters.ParseParams(t.Parameters, data, claims) diff --git a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go index 9fe4bb43bf0..3757b9a6e65 100644 --- a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go +++ b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go @@ -17,17 +17,13 @@ package serverlesssparklistbatches import ( "context" "fmt" - "time" dataproc "cloud.google.com/go/dataproc/v2/apiv1" - "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/iterator" ) const kind = "serverless-spark-list-batches" @@ -48,8 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetBatchControllerClient() *dataproc.BatchControllerClient - GetProject() string - GetLocation() string + ListBatches(context.Context, *int, string, string) (any, error) } type Config struct { @@ -104,95 +99,24 @@ type Tool struct { Parameters parameters.Parameters } -// ListBatchesResponse is the response from the list batches API. -type ListBatchesResponse struct { - Batches []Batch `json:"batches"` - NextPageToken string `json:"nextPageToken"` -} - -// Batch represents a single batch job. 
-type Batch struct { - Name string `json:"name"` - UUID string `json:"uuid"` - State string `json:"state"` - Creator string `json:"creator"` - CreateTime string `json:"createTime"` - Operation string `json:"operation"` - ConsoleURL string `json:"consoleUrl"` - LogsURL string `json:"logsUrl"` -} - // Invoke executes the tool's operation. func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - - client := source.GetBatchControllerClient() - - parent := fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation()) - req := &dataprocpb.ListBatchesRequest{ - Parent: parent, - OrderBy: "create_time desc", - } - paramMap := params.AsMap() + var pageSize *int if ps, ok := paramMap["pageSize"]; ok && ps != nil { - req.PageSize = int32(ps.(int)) - if (req.PageSize) <= 0 { - return nil, fmt.Errorf("pageSize must be positive: %d", req.PageSize) - } - } - if pt, ok := paramMap["pageToken"]; ok && pt != nil { - req.PageToken = pt.(string) - } - if filter, ok := paramMap["filter"]; ok && filter != nil { - req.Filter = filter.(string) - } - - it := client.ListBatches(ctx, req) - pager := iterator.NewPager(it, int(req.PageSize), req.PageToken) - - var batchPbs []*dataprocpb.Batch - nextPageToken, err := pager.NextPage(&batchPbs) - if err != nil { - return nil, fmt.Errorf("failed to list batches: %w", err) - } - - batches, err := ToBatches(batchPbs) - if err != nil { - return nil, err - } - - return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil -} - -// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs. 
-func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) { - batches := make([]Batch, 0, len(batchPbs)) - for _, batchPb := range batchPbs { - consoleUrl, err := common.BatchConsoleURLFromProto(batchPb) - if err != nil { - return nil, fmt.Errorf("error generating console url: %v", err) - } - logsUrl, err := common.BatchLogsURLFromProto(batchPb) - if err != nil { - return nil, fmt.Errorf("error generating logs url: %v", err) - } - batch := Batch{ - Name: batchPb.Name, - UUID: batchPb.Uuid, - State: batchPb.State.Enum().String(), - Creator: batchPb.Creator, - CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339), - Operation: batchPb.Operation, - ConsoleURL: consoleUrl, - LogsURL: logsUrl, + pageSizeV := ps.(int) + if pageSizeV <= 0 { + return nil, fmt.Errorf("pageSize must be positive: %d", pageSizeV) } - batches = append(batches, batch) + pageSize = &pageSizeV } - return batches, nil + pt, _ := paramMap["pageToken"].(string) + filter, _ := paramMap["filter"].(string) + return source.ListBatches(ctx, pageSize, pt, filter) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/tests/serverlessspark/serverless_spark_integration_test.go b/tests/serverlessspark/serverless_spark_integration_test.go index c2f245dc4fa..12545a87aa5 100644 --- a/tests/serverlessspark/serverless_spark_integration_test.go +++ b/tests/serverlessspark/serverless_spark_integration_test.go @@ -33,8 +33,8 @@ import ( dataproc "cloud.google.com/go/dataproc/v2/apiv1" "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" "github.com/googleapis/genai-toolbox/internal/testutils" - "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" "github.com/googleapis/genai-toolbox/tests" "google.golang.org/api/iterator" "google.golang.org/api/option" @@ -676,7 +676,7 @@ func 
runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct filter string pageSize int numPages int - want []serverlesssparklistbatches.Batch + want []serverlessspark.Batch }{ {name: "one page", pageSize: 2, numPages: 1, want: batch2}, {name: "two pages", pageSize: 1, numPages: 2, want: batch2}, @@ -701,7 +701,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { t.Parallel() - var actual []serverlesssparklistbatches.Batch + var actual []serverlessspark.Batch var pageToken string for i := 0; i < tc.numPages; i++ { request := map[string]any{ @@ -733,7 +733,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct t.Fatalf("unable to find result in response body") } - var listResponse serverlesssparklistbatches.ListBatchesResponse + var listResponse serverlessspark.ListBatchesResponse if err := json.Unmarshal([]byte(result), &listResponse); err != nil { t.Fatalf("error unmarshalling result: %s", err) } @@ -759,7 +759,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct } } -func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlesssparklistbatches.Batch { +func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlessspark.Batch { parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation) req := &dataprocpb.ListBatchesRequest{ Parent: parent, @@ -783,7 +783,7 @@ func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx co if !exact && (len(batchPbs) == 0 || len(batchPbs) > n) { t.Fatalf("expected between 1 and %d batches, got %d", n, len(batchPbs)) } - batches, err := serverlesssparklistbatches.ToBatches(batchPbs) + batches, err := serverlessspark.ToBatches(batchPbs) if err != nil { 
t.Fatalf("failed to convert batches to JSON: %v", err) }