diff --git a/client/client.go b/client/client.go index e9b822acc..99b4482fa 100644 --- a/client/client.go +++ b/client/client.go @@ -31,6 +31,7 @@ package client import ( "context" + "crypto/tls" "io" commonpb "go.temporal.io/api/common/v1" @@ -90,6 +91,9 @@ type ( // ConnectionOptions are optional parameters that can be specified in ClientOptions ConnectionOptions = internal.ConnectionOptions + // Credentials are optional credentials that can be specified in ClientOptions. + Credentials = internal.Credentials + // StartWorkflowOptions configuration parameters for starting a workflow execution. StartWorkflowOptions = internal.StartWorkflowOptions @@ -752,3 +756,41 @@ type HistoryJSONOptions struct { func HistoryFromJSON(r io.Reader, options HistoryJSONOptions) (*historypb.History, error) { return internal.HistoryFromJSON(r, options.LastEventID) } + +// NewAPIKeyStaticCredentials creates credentials that can be provided to +// ClientOptions to use a fixed API key. +// +// This is the equivalent of providing a headers provider that sets the +// "Authorization" header with "Bearer " + the given key. This will overwrite +// any "Authorization" header that may be on the context or from existing header +// provider. +// +// Note, this uses a fixed header value for authentication. Many users that want +// to rotate this value without reconnecting should use +// [NewAPIKeyDynamicCredentials]. +func NewAPIKeyStaticCredentials(apiKey string) Credentials { + return internal.NewAPIKeyStaticCredentials(apiKey) +} + +// NewAPIKeyDynamicCredentials creates credentials powered by a callback that +// is invoked on each request. The callback accepts the context that is given by +// the calling user and can return a key or an error. When error is non-nil, the +// client call is failed with that error. When string is non-empty, it is used +// as the API key. When string is empty, nothing is set/overridden. +// +// This is the equivalent of providing a headers provider that returns the +// "Authorization" header with "Bearer " + the given function result. If the +// resulting string is non-empty, it will overwrite any "Authorization" header +// that may be on the context or from existing header provider. +func NewAPIKeyDynamicCredentials(apiKeyCallback func(context.Context) (string, error)) Credentials { + return internal.NewAPIKeyDynamicCredentials(apiKeyCallback) +} + +// NewMTLSCredentials creates credentials that use TLS with the client +// certificate as the given one. If the client options do not already enable +// TLS, this enables it. If the client options' TLS configuration is present and +// already has a client certificate, client creation will fail when applying +// these credentials. +func NewMTLSCredentials(certificate tls.Certificate) Credentials { + return internal.NewMTLSCredentials(certificate) +} diff --git a/internal/client.go b/internal/client.go index 7c286c000..b7a207cb8 100644 --- a/internal/client.go +++ b/internal/client.go @@ -36,6 +36,7 @@ import ( "go.temporal.io/api/operatorservice/v1" "go.temporal.io/api/workflowservice/v1" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "go.temporal.io/sdk/converter" "go.temporal.io/sdk/internal/common/metrics" @@ -421,6 +422,9 @@ type ( // default: default Namespace string + // Optional: Set the credentials for this client. + Credentials Credentials + // Optional: Logger framework can use to log. // default: default logger provided. Logger log.Logger @@ -473,7 +477,6 @@ type ( HeadersProvider interface { GetHeaders(ctx context.Context) (map[string]string, error) } - // TrafficController is getting called in the interceptor chain with API invocation parameters. // Result is either nil if API call is allowed or an error, in which case request would be interrupted and // the error will be propagated back through the interceptor chain. @@ -704,6 +707,13 @@ type ( } ) +// Credentials are optional credentials that can be specified in ClientOptions. +type Credentials interface { + applyToOptions(*ClientOptions) error + // Can return nil to have no interceptor + gRPCInterceptor() grpc.UnaryClientInterceptor +} + // DialClient creates a client and attempts to connect to the server. func DialClient(options ClientOptions) (Client, error) { options.ConnectionOptions.disableEagerConnection = false @@ -753,6 +763,12 @@ func newClient(options ClientOptions, existing *WorkflowClient) (Client, error) options.Logger.Info("No logger configured for temporal client. Created default one.") } + if options.Credentials != nil { + if err := options.Credentials.applyToOptions(&options); err != nil { + return nil, err + } + } + // Dial or use existing connection var connection *grpc.ClientConn var err error @@ -800,6 +816,7 @@ func newDialParameters(options *ClientOptions, excludeInternalFromRetry *atomic. options.HeadersProvider, options.TrafficController, excludeInternalFromRetry, + options.Credentials, ), DefaultServiceConfig: defaultServiceConfig, } @@ -923,3 +940,56 @@ func NewValue(data *commonpb.Payloads) converter.EncodedValue { func NewValues(data *commonpb.Payloads) converter.EncodedValues { return newEncodedValues(data, nil) } + +type apiKeyCredentials func(context.Context) (string, error) + +func NewAPIKeyStaticCredentials(apiKey string) Credentials { + return NewAPIKeyDynamicCredentials(func(ctx context.Context) (string, error) { return apiKey, nil }) +} + +func NewAPIKeyDynamicCredentials(apiKeyCallback func(context.Context) (string, error)) Credentials { + return apiKeyCredentials(apiKeyCallback) +} + +func (apiKeyCredentials) applyToOptions(*ClientOptions) error { return nil } + +func (a apiKeyCredentials) gRPCInterceptor() grpc.UnaryClientInterceptor { return a.gRPCIntercept } + +func (a apiKeyCredentials) gRPCIntercept( + ctx context.Context, + method string, + req any, + reply any, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, +) error { + if apiKey, err := a(ctx); err != nil { + return err + } else if apiKey != "" { + // Do from-add-new instead of append to overwrite anything there + md, _ := metadata.FromOutgoingContext(ctx) + if md == nil { + md = metadata.MD{} + } + md["authorization"] = []string{"Bearer " + apiKey} + ctx = metadata.NewOutgoingContext(ctx, md) + } + return invoker(ctx, method, req, reply, cc, opts...) +} + +type mTLSCredentials tls.Certificate + +func NewMTLSCredentials(certificate tls.Certificate) Credentials { return mTLSCredentials(certificate) } + +func (m mTLSCredentials) applyToOptions(opts *ClientOptions) error { + if opts.ConnectionOptions.TLS == nil { + opts.ConnectionOptions.TLS = &tls.Config{} + } else if len(opts.ConnectionOptions.TLS.Certificates) != 0 { + return fmt.Errorf("cannot apply mTLS credentials, certificates already exist on TLS options") + } + opts.ConnectionOptions.TLS.Certificates = append(opts.ConnectionOptions.TLS.Certificates, tls.Certificate(m)) + return nil +} + +func (mTLSCredentials) gRPCInterceptor() grpc.UnaryClientInterceptor { return nil } diff --git a/internal/grpc_dialer.go b/internal/grpc_dialer.go index c16633849..e59729670 100644 --- a/internal/grpc_dialer.go +++ b/internal/grpc_dialer.go @@ -149,6 +149,7 @@ func requiredInterceptors( headersProvider HeadersProvider, controller TrafficController, excludeInternalFromRetry *atomic.Bool, + credentials Credentials, ) []grpc.UnaryClientInterceptor { interceptors := []grpc.UnaryClientInterceptor{ errorInterceptor, @@ -168,6 +169,13 @@ func requiredInterceptors( if controller != nil { interceptors = append(interceptors, trafficControllerInterceptor(controller)) } + // Add credentials interceptor. This is intentionally added after headers + // provider to overwrite anything set there. + if credentials != nil { + if interceptor := credentials.gRPCInterceptor(); interceptor != nil { + interceptors = append(interceptors, interceptor) + } + } return interceptors } diff --git a/internal/grpc_dialer_test.go b/internal/grpc_dialer_test.go index 2131b815d..557a7ea28 100644 --- a/internal/grpc_dialer_test.go +++ b/internal/grpc_dialer_test.go @@ -26,6 +26,7 @@ package internal import ( "context" + "crypto/tls" "errors" "fmt" "log" @@ -127,13 +128,13 @@ func TestHeadersProvider_Error(t *testing.T) { } func TestHeadersProvider_NotIncludedWhenNil(t *testing.T) { - interceptors := requiredInterceptors(nil, nil, nil, nil) + interceptors := requiredInterceptors(nil, nil, nil, nil, nil) require.Equal(t, 5, len(interceptors)) } func TestHeadersProvider_IncludedWithHeadersProvider(t *testing.T) { interceptors := requiredInterceptors(nil, - authHeadersProvider{token: "test-auth-token"}, nil, nil) + authHeadersProvider{token: "test-auth-token"}, nil, nil, nil) require.Equal(t, 6, len(interceptors)) } @@ -438,12 +439,73 @@ func TestResourceExhaustedCause(t *testing.T) { assert.True(t, foundWithoutCause) } +func TestCredentialsAPIKey(t *testing.T) { + srv, err := startTestGRPCServer() + require.NoError(t, err) + defer srv.Stop() + + // Fixed string + client, err := DialClient(ClientOptions{ + HostPort: srv.addr, + Credentials: NewAPIKeyStaticCredentials("my-api-key"), + }) + require.NoError(t, err) + defer client.Close() + require.Equal( + t, + []string{"Bearer my-api-key"}, + metadata.ValueFromIncomingContext(srv.getSystemInfoRequestContext, "Authorization"), + ) + + // Callback + client, err = DialClient(ClientOptions{ + HostPort: srv.addr, + Credentials: NewAPIKeyDynamicCredentials(func(ctx context.Context) (string, error) { + return "my-callback-api-key", nil + }), + }) + require.NoError(t, err) + defer client.Close() + require.Equal( + t, + []string{"Bearer my-callback-api-key"}, + metadata.ValueFromIncomingContext(srv.getSystemInfoRequestContext, "Authorization"), + ) +} + +func TestCredentialsMTLS(t *testing.T) { + // Just confirming option is set, not full end-to-end mTLS test + + // No TLS set + var clientOptions ClientOptions + creds := NewMTLSCredentials(tls.Certificate{Certificate: [][]byte{[]byte("somedata1")}}) + require.NoError(t, creds.applyToOptions(&clientOptions)) + require.Equal(t, "somedata1", string(clientOptions.ConnectionOptions.TLS.Certificates[0].Certificate[0])) + + // TLS already set + clientOptions = ClientOptions{} + clientOptions.ConnectionOptions.TLS = &tls.Config{ServerName: "my-server-name"} + creds = NewMTLSCredentials(tls.Certificate{Certificate: [][]byte{[]byte("somedata2")}}) + require.NoError(t, creds.applyToOptions(&clientOptions)) + require.Equal(t, "my-server-name", clientOptions.ConnectionOptions.TLS.ServerName) + require.Equal(t, "somedata2", string(clientOptions.ConnectionOptions.TLS.Certificates[0].Certificate[0])) + + // Fail with existing cert + clientOptions = ClientOptions{} + clientOptions.ConnectionOptions.TLS = &tls.Config{ + Certificates: []tls.Certificate{{Certificate: [][]byte{[]byte("somedata3")}}}, + } + creds = NewMTLSCredentials(tls.Certificate{Certificate: [][]byte{[]byte("somedata4")}}) + require.Error(t, creds.applyToOptions(&clientOptions)) +} + type testGRPCServer struct { workflowservice.UnimplementedWorkflowServiceServer *grpc.Server addr string healthServer *health.Server sigWfCount int32 + getSystemInfoRequestContext context.Context getSystemInfoResponse workflowservice.GetSystemInfoResponse getSystemInfoResponseError error signalWorkflowExecutionResponse workflowservice.SignalWorkflowExecutionResponse @@ -500,6 +562,7 @@ func (t *testGRPCServer) GetSystemInfo( ctx context.Context, req *workflowservice.GetSystemInfoRequest, ) (*workflowservice.GetSystemInfoResponse, error) { + t.getSystemInfoRequestContext = ctx return &t.getSystemInfoResponse, t.getSystemInfoResponseError }