Skip to content

Commit 88f4b30

Browse files
authored
feat(tools/bigquery): Support end-user credential passthrough on multiple BQ tools (#1314)
Support end-user credential passthrough on BQ Tools that are using clients.
1 parent 19a7fe2 commit 88f4b30

File tree

10 files changed

+535
-145
lines changed

10 files changed

+535
-145
lines changed

internal/sources/bigquery/bigquery.go

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ const SourceKind string = "bigquery"
3535
// validate interface
3636
var _ sources.SourceConfig = Config{}
3737

38-
type BigqueryClientCreator func(tokenString tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
38+
type BigqueryClientCreator func(tokenString tools.AccessToken, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
3939

4040
func init() {
4141
if !sources.Register(SourceKind, newConfig) {
@@ -88,6 +88,8 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
8888
s := &Source{
8989
Name: r.Name,
9090
Kind: SourceKind,
91+
Project: r.Project,
92+
Location: r.Location,
9193
Client: client,
9294
RestService: restService,
9395
TokenSource: tokenSource,
@@ -105,6 +107,8 @@ type Source struct {
105107
// BigQuery Google SQL struct with client
106108
Name string `yaml:"name"`
107109
Kind string `yaml:"kind"`
110+
Project string
111+
Location string
108112
Client *bigqueryapi.Client
109113
RestService *bigqueryrestapi.Service
110114
TokenSource oauth2.TokenSource
@@ -130,6 +134,14 @@ func (s *Source) UseClientAuthorization() bool {
130134
return s.UseClientOAuth
131135
}
132136

137+
func (s *Source) BigQueryProject() string {
138+
return s.Project
139+
}
140+
141+
func (s *Source) BigQueryLocation() string {
142+
return s.Location
143+
}
144+
133145
func (s *Source) BigQueryTokenSource() oauth2.TokenSource {
134146
return s.TokenSource
135147
}
@@ -188,6 +200,7 @@ func initBigQueryConnectionWithOAuthToken(
188200
name string,
189201
userAgent string,
190202
tokenString tools.AccessToken,
203+
wantRestService bool,
191204
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
192205
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
193206
defer span.End()
@@ -204,13 +217,16 @@ func initBigQueryConnectionWithOAuthToken(
204217
}
205218
client.Location = location
206219

207-
// Initialize the low-level BigQuery REST service using the same credentials
208-
restService, err := bigqueryrestapi.NewService(ctx, option.WithUserAgent(userAgent), option.WithTokenSource(ts))
209-
if err != nil {
210-
return nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
220+
if wantRestService {
221+
// Initialize the low-level BigQuery REST service using the same credentials
222+
restService, err := bigqueryrestapi.NewService(ctx, option.WithUserAgent(userAgent), option.WithTokenSource(ts))
223+
if err != nil {
224+
return nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
225+
}
226+
return client, restService, nil
211227
}
212228

213-
return client, restService, nil
229+
return client, nil, nil
214230
}
215231

216232
// newBigQueryClientCreator sets the project parameters for the init helper
@@ -222,13 +238,13 @@ func newBigQueryClientCreator(
222238
project string,
223239
location string,
224240
name string,
225-
) (func(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error), error) {
241+
) (func(tools.AccessToken, bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error), error) {
226242
userAgent, err := util.UserAgentFromContext(ctx)
227243
if err != nil {
228244
return nil, err
229245
}
230246

231-
return func(tokenString tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
232-
return initBigQueryConnectionWithOAuthToken(ctx, tracer, project, location, name, userAgent, tokenString)
247+
return func(tokenString tools.AccessToken, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
248+
return initBigQueryConnectionWithOAuthToken(ctx, tracer, project, location, name, userAgent, tokenString, wantRestService)
233249
}, nil
234250
}

internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
5555
type compatibleSource interface {
5656
BigQueryClient() *bigqueryapi.Client
5757
BigQueryTokenSource() oauth2.TokenSource
58+
BigQueryProject() string
59+
BigQueryLocation() string
5860
GetMaxQueryResultRows() int
61+
UseClientAuthorization() bool
5962
}
6063

6164
type BQTableReference struct {
@@ -145,9 +148,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
145148
t := Tool{
146149
Name: cfg.Name,
147150
Kind: kind,
151+
Project: s.BigQueryProject(),
152+
Location: s.BigQueryLocation(),
148153
Parameters: parameters,
149154
AuthRequired: cfg.AuthRequired,
150155
Client: s.BigQueryClient(),
156+
UseClientOAuth: s.UseClientAuthorization(),
151157
TokenSource: s.BigQueryTokenSource(),
152158
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
153159
mcpManifest: mcpManifest,
@@ -160,10 +166,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
160166
var _ tools.Tool = Tool{}
161167

162168
type Tool struct {
163-
Name string `yaml:"name"`
164-
Kind string `yaml:"kind"`
165-
AuthRequired []string `yaml:"authRequired"`
166-
Parameters tools.Parameters `yaml:"parameters"`
169+
Name string `yaml:"name"`
170+
Kind string `yaml:"kind"`
171+
AuthRequired []string `yaml:"authRequired"`
172+
UseClientOAuth bool `yaml:"useClientOAuth"`
173+
Parameters tools.Parameters `yaml:"parameters"`
174+
175+
Project string
176+
Location string
167177
Client *bigqueryapi.Client
168178
TokenSource oauth2.TokenSource
169179
manifest tools.Manifest
@@ -172,14 +182,25 @@ type Tool struct {
172182
}
173183

174184
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
175-
// Get credentials for the API call
176-
if t.TokenSource == nil {
177-
return nil, fmt.Errorf("authentication error: found credentials but they are missing a valid token source")
178-
}
185+
var tokenStr string
179186

180-
token, err := t.TokenSource.Token()
181-
if err != nil {
182-
return nil, fmt.Errorf("failed to get token from credentials: %w", err)
187+
// Get credentials for the API call
188+
if t.UseClientOAuth {
189+
// Use client-side access token
190+
if accessToken == "" {
191+
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header")
192+
}
193+
tokenStr = string(accessToken)
194+
} else {
195+
// Use ADC
196+
if t.TokenSource == nil {
197+
return nil, fmt.Errorf("ADC is missing a valid token source")
198+
}
199+
token, err := t.TokenSource.Token()
200+
if err != nil {
201+
return nil, fmt.Errorf("failed to get token from ADC: %w", err)
202+
}
203+
tokenStr = token.AccessToken
183204
}
184205

185206
// Extract parameters from the map
@@ -197,15 +218,15 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
197218
}
198219

199220
// Construct URL, headers, and payload
200-
projectID := t.Client.Project()
201-
location := t.Client.Location
221+
projectID := t.Project
222+
location := t.Location
202223
if location == "" {
203224
location = "us"
204225
}
205226
caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1alpha/projects/%s/locations/%s:chat", projectID, location)
206227

207228
headers := map[string]string{
208-
"Authorization": fmt.Sprintf("Bearer %s", token.AccessToken),
229+
"Authorization": fmt.Sprintf("Bearer %s", tokenStr),
209230
"Content-Type": "application/json",
210231
}
211232

@@ -246,7 +267,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
246267
}
247268

248269
func (t Tool) RequiresClientAuthorization() bool {
249-
return false
270+
return t.UseClientOAuth
250271
}
251272

252273
// StreamMessage represents a single message object from the streaming API response.

internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
4848
type compatibleSource interface {
4949
BigQueryClient() *bigqueryapi.Client
5050
BigQueryRestService() *bigqueryrestapi.Service
51+
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
52+
UseClientAuthorization() bool
5153
}
5254

5355
// validate compatible sources are still compatible
@@ -100,14 +102,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
100102

101103
// finish tool setup
102104
t := Tool{
103-
Name: cfg.Name,
104-
Kind: kind,
105-
Parameters: parameters,
106-
AuthRequired: cfg.AuthRequired,
107-
Client: s.BigQueryClient(),
108-
RestService: s.BigQueryRestService(),
109-
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
110-
mcpManifest: mcpManifest,
105+
Name: cfg.Name,
106+
Kind: kind,
107+
Parameters: parameters,
108+
AuthRequired: cfg.AuthRequired,
109+
UseClientOAuth: s.UseClientAuthorization(),
110+
ClientCreator: s.BigQueryClientCreator(),
111+
Client: s.BigQueryClient(),
112+
RestService: s.BigQueryRestService(),
113+
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
114+
mcpManifest: mcpManifest,
111115
}
112116
return t, nil
113117
}
@@ -116,14 +120,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
116120
var _ tools.Tool = Tool{}
117121

118122
type Tool struct {
119-
Name string `yaml:"name"`
120-
Kind string `yaml:"kind"`
121-
AuthRequired []string `yaml:"authRequired"`
122-
Parameters tools.Parameters `yaml:"parameters"`
123-
Client *bigqueryapi.Client
124-
RestService *bigqueryrestapi.Service
125-
manifest tools.Manifest
126-
mcpManifest tools.McpManifest
123+
Name string `yaml:"name"`
124+
Kind string `yaml:"kind"`
125+
AuthRequired []string `yaml:"authRequired"`
126+
UseClientOAuth bool `yaml:"useClientOAuth"`
127+
Parameters tools.Parameters `yaml:"parameters"`
128+
129+
Client *bigqueryapi.Client
130+
RestService *bigqueryrestapi.Service
131+
ClientCreator bigqueryds.BigqueryClientCreator
132+
manifest tools.Manifest
133+
mcpManifest tools.McpManifest
127134
}
128135

129136
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -137,7 +144,19 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
137144
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
138145
}
139146

140-
dryRunJob, err := dryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, sql)
147+
bqClient := t.Client
148+
restService := t.RestService
149+
150+
var err error
151+
// Initialize new client if using user OAuth token
152+
if t.UseClientOAuth {
153+
bqClient, restService, err = t.ClientCreator(accessToken, true)
154+
if err != nil {
155+
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
156+
}
157+
}
158+
159+
dryRunJob, err := dryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql)
141160
if err != nil {
142161
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
143162
}
@@ -156,8 +175,8 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
156175

157176
statementType := dryRunJob.Statistics.Query.StatementType
158177
// JobStatistics.QueryStatistics.StatementType
159-
query := t.Client.Query(sql)
160-
query.Location = t.Client.Location
178+
query := bqClient.Query(sql)
179+
query.Location = bqClient.Location
161180

162181
// Log the query executed for debugging.
163182
logger, err := util.LoggerFromContext(ctx)
@@ -223,7 +242,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
223242
}
224243

225244
func (t Tool) RequiresClientAuthorization() bool {
226-
return false
245+
return t.UseClientOAuth
227246
}
228247

229248
// dryRunQuery performs a dry run of the SQL query to validate it and get metadata.

internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
4848
type compatibleSource interface {
4949
BigQueryClient() *bigqueryapi.Client
5050
BigQueryRestService() *bigqueryrestapi.Service
51+
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
52+
UseClientAuthorization() bool
5153
}
5254

5355
// validate compatible sources are still compatible
@@ -104,14 +106,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
104106

105107
// finish tool setup
106108
t := Tool{
107-
Name: cfg.Name,
108-
Kind: kind,
109-
Parameters: parameters,
110-
AuthRequired: cfg.AuthRequired,
111-
Client: s.BigQueryClient(),
112-
RestService: s.BigQueryRestService(),
113-
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
114-
mcpManifest: mcpManifest,
109+
Name: cfg.Name,
110+
Kind: kind,
111+
Parameters: parameters,
112+
AuthRequired: cfg.AuthRequired,
113+
UseClientOAuth: s.UseClientAuthorization(),
114+
ClientCreator: s.BigQueryClientCreator(),
115+
Client: s.BigQueryClient(),
116+
RestService: s.BigQueryRestService(),
117+
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
118+
mcpManifest: mcpManifest,
115119
}
116120
return t, nil
117121
}
@@ -120,14 +124,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
120124
var _ tools.Tool = Tool{}
121125

122126
type Tool struct {
123-
Name string `yaml:"name"`
124-
Kind string `yaml:"kind"`
125-
AuthRequired []string `yaml:"authRequired"`
126-
Parameters tools.Parameters `yaml:"parameters"`
127-
Client *bigqueryapi.Client
128-
RestService *bigqueryrestapi.Service
129-
manifest tools.Manifest
130-
mcpManifest tools.McpManifest
127+
Name string `yaml:"name"`
128+
Kind string `yaml:"kind"`
129+
AuthRequired []string `yaml:"authRequired"`
130+
UseClientOAuth bool `yaml:"useClientOAuth"`
131+
Parameters tools.Parameters `yaml:"parameters"`
132+
133+
Client *bigqueryapi.Client
134+
RestService *bigqueryrestapi.Service
135+
ClientCreator bigqueryds.BigqueryClientCreator
136+
manifest tools.Manifest
137+
mcpManifest tools.McpManifest
131138
}
132139

133140
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -187,9 +194,20 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
187194
horizon => %d%s)`,
188195
historyDataSource, dataCol, timestampCol, horizon, idColsArg)
189196

197+
bqClient := t.Client
198+
var err error
199+
200+
// Initialize new client if using user OAuth token
201+
if t.UseClientOAuth {
202+
bqClient, _, err = t.ClientCreator(accessToken, false)
203+
if err != nil {
204+
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
205+
}
206+
}
207+
190208
// JobStatistics.QueryStatistics.StatementType
191-
query := t.Client.Query(sql)
192-
query.Location = t.Client.Location
209+
query := bqClient.Query(sql)
210+
query.Location = bqClient.Location
193211

194212
// Log the query executed for debugging.
195213
logger, err := util.LoggerFromContext(ctx)
@@ -247,5 +265,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
247265
}
248266

249267
func (t Tool) RequiresClientAuthorization() bool {
250-
return false
268+
return t.UseClientOAuth
251269
}

0 commit comments

Comments
 (0)