diff --git a/api/grpc/token/v1/token.pb.go b/api/grpc/token/v1/token.pb.go index 937ee5fee4..683a114c03 100644 --- a/api/grpc/token/v1/token.pb.go +++ b/api/grpc/token/v1/token.pb.go @@ -144,6 +144,50 @@ func (x *RefreshReq) GetVerified() bool { return false } +type RevokeReq struct { + state protoimpl.MessageState `protogen:"open.v1"` + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RevokeReq) Reset() { + *x = RevokeReq{} + mi := &file_token_v1_token_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RevokeReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RevokeReq) ProtoMessage() {} + +func (x *RevokeReq) ProtoReflect() protoreflect.Message { + mi := &file_token_v1_token_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RevokeReq.ProtoReflect.Descriptor instead. +func (*RevokeReq) Descriptor() ([]byte, []int) { + return file_token_v1_token_proto_rawDescGZIP(), []int{2} +} + +func (x *RevokeReq) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + // If a token is not carrying any information itself, the type // field can be used to determine how to validate the token. // Also, different tokens can be encoded in different ways. @@ -158,7 +202,7 @@ type Token struct { func (x *Token) Reset() { *x = Token{} - mi := &file_token_v1_token_proto_msgTypes[2] + mi := &file_token_v1_token_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -170,7 +214,7 @@ func (x *Token) String() string { func (*Token) ProtoMessage() {} func (x *Token) ProtoReflect() protoreflect.Message { - mi := &file_token_v1_token_proto_msgTypes[2] + mi := &file_token_v1_token_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -183,7 +227,7 @@ func (x *Token) ProtoReflect() protoreflect.Message { // Deprecated: Use Token.ProtoReflect.Descriptor instead. func (*Token) Descriptor() ([]byte, []int) { - return file_token_v1_token_proto_rawDescGZIP(), []int{2} + return file_token_v1_token_proto_rawDescGZIP(), []int{3} } func (x *Token) GetAccessToken() string { @@ -207,6 +251,42 @@ func (x *Token) GetAccessType() string { return "" } +type RevokeRes struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RevokeRes) Reset() { + *x = RevokeRes{} + mi := &file_token_v1_token_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RevokeRes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RevokeRes) ProtoMessage() {} + +func (x *RevokeRes) ProtoReflect() protoreflect.Message { + mi := &file_token_v1_token_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RevokeRes.ProtoReflect.Descriptor instead. +func (*RevokeRes) Descriptor() ([]byte, []int) { + return file_token_v1_token_proto_rawDescGZIP(), []int{4} +} + var File_token_v1_token_proto protoreflect.FileDescriptor const file_token_v1_token_proto_rawDesc = "" + @@ -220,16 +300,20 @@ const file_token_v1_token_proto_rawDesc = "" + "\n" + "RefreshReq\x12#\n" + "\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\x12\x1a\n" + - "\bverified\x18\x02 \x01(\bR\bverified\"\x87\x01\n" + + "\bverified\x18\x02 \x01(\bR\bverified\"!\n" + + "\tRevokeReq\x12\x14\n" + + "\x05token\x18\x01 \x01(\tR\x05token\"\x87\x01\n" + "\x05Token\x12!\n" + "\faccess_token\x18\x01 \x01(\tR\vaccessToken\x12(\n" + "\rrefresh_token\x18\x02 \x01(\tH\x00R\frefreshToken\x88\x01\x01\x12\x1f\n" + "\vaccess_type\x18\x03 \x01(\tR\n" + "accessTypeB\x10\n" + - "\x0e_refresh_token2r\n" + + "\x0e_refresh_token\"\v\n" + + "\tRevokeRes2\xa8\x01\n" + "\fTokenService\x12.\n" + "\x05Issue\x12\x12.token.v1.IssueReq\x1a\x0f.token.v1.Token\"\x00\x122\n" + - "\aRefresh\x12\x14.token.v1.RefreshReq\x1a\x0f.token.v1.Token\"\x00B.Z,github.com/absmach/supermq/api/grpc/token/v1b\x06proto3" + "\aRefresh\x12\x14.token.v1.RefreshReq\x1a\x0f.token.v1.Token\"\x00\x124\n" + + "\x06Revoke\x12\x13.token.v1.RevokeReq\x1a\x13.token.v1.RevokeRes\"\x00B.Z,github.com/absmach/supermq/api/grpc/token/v1b\x06proto3" var ( file_token_v1_token_proto_rawDescOnce sync.Once @@ -243,19 +327,23 @@ func file_token_v1_token_proto_rawDescGZIP() []byte { return file_token_v1_token_proto_rawDescData } -var file_token_v1_token_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_token_v1_token_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_token_v1_token_proto_goTypes = []any{ (*IssueReq)(nil), // 0: token.v1.IssueReq (*RefreshReq)(nil), // 1: token.v1.RefreshReq - (*Token)(nil), // 2: token.v1.Token + (*RevokeReq)(nil), // 2: token.v1.RevokeReq + (*Token)(nil), // 3: token.v1.Token + (*RevokeRes)(nil), // 4: token.v1.RevokeRes } var file_token_v1_token_proto_depIdxs = []int32{ 0, // 0: token.v1.TokenService.Issue:input_type -> token.v1.IssueReq 1, // 1: token.v1.TokenService.Refresh:input_type -> token.v1.RefreshReq - 2, // 2: token.v1.TokenService.Issue:output_type -> token.v1.Token - 2, // 3: token.v1.TokenService.Refresh:output_type -> token.v1.Token - 2, // [2:4] is the sub-list for method output_type - 0, // [0:2] is the sub-list for method input_type + 2, // 2: token.v1.TokenService.Revoke:input_type -> token.v1.RevokeReq + 3, // 3: token.v1.TokenService.Issue:output_type -> token.v1.Token + 3, // 4: token.v1.TokenService.Refresh:output_type -> token.v1.Token + 4, // 5: token.v1.TokenService.Revoke:output_type -> token.v1.RevokeRes + 3, // [3:6] is the sub-list for method output_type + 0, // [0:3] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -266,14 +354,14 @@ func file_token_v1_token_proto_init() { if File_token_v1_token_proto != nil { return } - file_token_v1_token_proto_msgTypes[2].OneofWrappers = []any{} + file_token_v1_token_proto_msgTypes[3].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_token_v1_token_proto_rawDesc), len(file_token_v1_token_proto_rawDesc)), NumEnums: 0, - NumMessages: 3, + NumMessages: 5, NumExtensions: 0, NumServices: 1, }, diff --git a/api/grpc/token/v1/token_grpc.pb.go b/api/grpc/token/v1/token_grpc.pb.go index 70ac6a7609..98a7a00c4a 100644 --- a/api/grpc/token/v1/token_grpc.pb.go +++ b/api/grpc/token/v1/token_grpc.pb.go @@ -24,6 +24,7 @@ const _ = grpc.SupportPackageIsVersion9 const ( TokenService_Issue_FullMethodName = "/token.v1.TokenService/Issue" TokenService_Refresh_FullMethodName = "/token.v1.TokenService/Refresh" + TokenService_Revoke_FullMethodName = "/token.v1.TokenService/Revoke" ) // TokenServiceClient is the client API for TokenService service. @@ -32,6 +33,7 @@ const ( type TokenServiceClient interface { Issue(ctx context.Context, in *IssueReq, opts ...grpc.CallOption) (*Token, error) Refresh(ctx context.Context, in *RefreshReq, opts ...grpc.CallOption) (*Token, error) + Revoke(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*RevokeRes, error) } type tokenServiceClient struct { @@ -62,12 +64,23 @@ func (c *tokenServiceClient) Refresh(ctx context.Context, in *RefreshReq, opts . return out, nil } +func (c *tokenServiceClient) Revoke(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*RevokeRes, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(RevokeRes) + err := c.cc.Invoke(ctx, TokenService_Revoke_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + // TokenServiceServer is the server API for TokenService service. // All implementations must embed UnimplementedTokenServiceServer // for forward compatibility. type TokenServiceServer interface { Issue(context.Context, *IssueReq) (*Token, error) Refresh(context.Context, *RefreshReq) (*Token, error) + Revoke(context.Context, *RevokeReq) (*RevokeRes, error) mustEmbedUnimplementedTokenServiceServer() } @@ -84,6 +97,9 @@ func (UnimplementedTokenServiceServer) Issue(context.Context, *IssueReq) (*Token func (UnimplementedTokenServiceServer) Refresh(context.Context, *RefreshReq) (*Token, error) { return nil, status.Errorf(codes.Unimplemented, "method Refresh not implemented") } +func (UnimplementedTokenServiceServer) Revoke(context.Context, *RevokeReq) (*RevokeRes, error) { + return nil, status.Errorf(codes.Unimplemented, "method Revoke not implemented") +} func (UnimplementedTokenServiceServer) mustEmbedUnimplementedTokenServiceServer() {} func (UnimplementedTokenServiceServer) testEmbeddedByValue() {} @@ -141,6 +157,24 @@ func _TokenService_Refresh_Handler(srv interface{}, ctx context.Context, dec fun return interceptor(ctx, in, info, handler) } +func _TokenService_Revoke_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RevokeReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TokenServiceServer).Revoke(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: TokenService_Revoke_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TokenServiceServer).Revoke(ctx, req.(*RevokeReq)) + } + return interceptor(ctx, in, info, handler) +} + // TokenService_ServiceDesc is the grpc.ServiceDesc for TokenService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -156,6 +190,10 @@ var TokenService_ServiceDesc = grpc.ServiceDesc{ MethodName: "Refresh", Handler: _TokenService_Refresh_Handler, }, + { + MethodName: "Revoke", + Handler: _TokenService_Revoke_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "token/v1/token.proto", diff --git a/auth/api/grpc/token/client.go b/auth/api/grpc/token/client.go index 25c3bf62fe..96cbdb2d10 100644 --- a/auth/api/grpc/token/client.go +++ b/auth/api/grpc/token/client.go @@ -20,6 +20,7 @@ const tokenSvcName = "token.v1.TokenService" type tokenGrpcClient struct { issue endpoint.Endpoint refresh endpoint.Endpoint + revoke endpoint.Endpoint timeout time.Duration } @@ -44,6 +45,14 @@ func NewTokenClient(conn *grpc.ClientConn, timeout time.Duration) grpcTokenV1.To decodeRefreshResponse, grpcTokenV1.Token{}, ).Endpoint(), + revoke: kitgrpc.NewClient( + conn, + tokenSvcName, + "Revoke", + encodeRevokeRequest, + decodeRevokeResponse, + grpcTokenV1.RevokeRes{}, + ).Endpoint(), timeout: timeout, } } @@ -97,3 +106,23 @@ func encodeRefreshRequest(_ context.Context, grpcReq any) (any, error) { func decodeRefreshResponse(_ context.Context, grpcRes any) (any, error) { return grpcRes, nil } + +func (client tokenGrpcClient) Revoke(ctx context.Context, req *grpcTokenV1.RevokeReq, _ ...grpc.CallOption) (*grpcTokenV1.RevokeRes, error) { + ctx, cancel := context.WithTimeout(ctx, client.timeout) + defer cancel() + + res, err := client.revoke(ctx, revokeReq{token: req.GetToken()}) + if err != nil { + return &grpcTokenV1.RevokeRes{}, grpcapi.DecodeError(err) + } + return res.(*grpcTokenV1.RevokeRes), nil +} + +func encodeRevokeRequest(_ context.Context, grpcReq any) (any, error) { + req := grpcReq.(revokeReq) + return &grpcTokenV1.RevokeReq{Token: req.token}, nil +} + +func decodeRevokeResponse(_ context.Context, grpcRes any) (any, error) { + return grpcRes, nil +} diff --git a/auth/api/grpc/token/endpoint.go b/auth/api/grpc/token/endpoint.go index b03e42ae53..dc8cca3611 100644 --- a/auth/api/grpc/token/endpoint.go +++ b/auth/api/grpc/token/endpoint.go @@ -56,3 +56,18 @@ func refreshEndpoint(svc auth.Service) endpoint.Endpoint { return ret, nil } } + +func revokeEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(revokeReq) + if err := req.validate(); err != nil { + return nil, err + } + err := svc.RevokeToken(ctx, req.token) + if err != nil { + return nil, err + } + + return nil, nil + } +} diff --git a/auth/api/grpc/token/endpoint_test.go b/auth/api/grpc/token/endpoint_test.go index be1820e42c..08142d880a 100644 --- a/auth/api/grpc/token/endpoint_test.go +++ b/auth/api/grpc/token/endpoint_test.go @@ -24,24 +24,9 @@ import ( ) const ( - port = 8082 - secret = "secret" - email = "test@example.com" - id = "testID" - clientsType = "clients" - usersType = "users" - description = "Description" - groupName = "smqx" - adminPermission = "admin" - - authoritiesObj = "authorities" - memberRelation = "member" - loginDuration = 30 * time.Minute - refreshDuration = 24 * time.Hour - invalidDuration = 7 * 24 * time.Hour - validToken = "valid" - inValidToken = "invalid" - validPolicy = "valid" + port = 8082 + validToken = "valid" + inValidToken = "invalid" ) var ( @@ -63,9 +48,9 @@ func startGRPCServer(svc auth.Service, port int) *grpc.Server { func TestIssue(t *testing.T) { conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) - defer conn.Close() assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) grpcClient := grpcapi.NewTokenClient(conn, time.Second) + defer conn.Close() cases := []struct { desc string @@ -127,9 +112,9 @@ func TestIssue(t *testing.T) { func TestRefresh(t *testing.T) { conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) - defer conn.Close() assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) grpcClient := grpcapi.NewTokenClient(conn, time.Second) + defer conn.Close() cases := []struct { desc string @@ -167,3 +152,44 @@ func TestRefresh(t *testing.T) { svcCall.Unset() } } + +func TestRevoke(t *testing.T) { + conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) + grpcClient := grpcapi.NewTokenClient(conn, time.Second) + defer conn.Close() + + cases := []struct { + desc string + token string + err error + }{ + { + desc: "revoke token with valid token", + token: validToken, + err: nil, + }, + { + desc: "revoke token with invalid token", + token: inValidToken, + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke token with empty token", + token: "", + err: apiutil.ErrMissingSecret, + }, + { + desc: "revoke already revoked token", + token: validToken, + err: svcerr.ErrConflict, + }, + } + + for _, tc := range cases { + svcCall := svc.On("RevokeToken", mock.Anything, tc.token).Return(tc.err) + _, err := grpcClient.Revoke(context.Background(), &grpcTokenV1.RevokeReq{Token: tc.token}) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + svcCall.Unset() + } +} diff --git a/auth/api/grpc/token/requests.go b/auth/api/grpc/token/requests.go index a5ab3c0949..a87d2b83e4 100644 --- a/auth/api/grpc/token/requests.go +++ b/auth/api/grpc/token/requests.go @@ -38,3 +38,15 @@ func (req refreshReq) validate() error { return nil } + +type revokeReq struct { + token string +} + +func (req revokeReq) validate() error { + if req.token == "" { + return apiutil.ErrMissingSecret + } + + return nil +} diff --git a/auth/api/grpc/token/server.go b/auth/api/grpc/token/server.go index 319e46e6ec..7e5263994d 100644 --- a/auth/api/grpc/token/server.go +++ b/auth/api/grpc/token/server.go @@ -18,6 +18,7 @@ type tokenGrpcServer struct { grpcTokenV1.UnimplementedTokenServiceServer issue kitgrpc.Handler refresh kitgrpc.Handler + revoke kitgrpc.Handler } // NewAuthServer returns new AuthnServiceServer instance. @@ -33,6 +34,11 @@ func NewTokenServer(svc auth.Service) grpcTokenV1.TokenServiceServer { decodeRefreshRequest, encodeIssueResponse, ), + revoke: kitgrpc.NewServer( + (revokeEndpoint(svc)), + decodeRevokeRequest, + encodeRevokeResponse, + ), } } @@ -76,3 +82,20 @@ func encodeIssueResponse(_ context.Context, grpcRes any) (any, error) { AccessType: res.accessType, }, nil } + +func (s *tokenGrpcServer) Revoke(ctx context.Context, req *grpcTokenV1.RevokeReq) (*grpcTokenV1.RevokeRes, error) { + _, res, err := s.revoke.ServeGRPC(ctx, req) + if err != nil { + return nil, grpcapi.EncodeError(err) + } + return res.(*grpcTokenV1.RevokeRes), nil +} + +func decodeRevokeRequest(_ context.Context, grpcReq any) (any, error) { + req := grpcReq.(*grpcTokenV1.RevokeReq) + return revokeReq{token: req.GetToken()}, nil +} + +func encodeRevokeResponse(_ context.Context, grpcRes any) (any, error) { + return &grpcTokenV1.RevokeRes{}, nil +} diff --git a/auth/cache/doc.go b/auth/cache/doc.go index 42396c9830..571a973d0f 100644 --- a/auth/cache/doc.go +++ b/auth/cache/doc.go @@ -1,4 +1,6 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 +// Package cache contains the domain concept definitions needed to +// support SuperMQ auth cache service functionality. package cache diff --git a/auth/cache/tokens.go b/auth/cache/tokens.go new file mode 100644 index 0000000000..9a9c10ed76 --- /dev/null +++ b/auth/cache/tokens.go @@ -0,0 +1,126 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "fmt" + "time" + + "github.com/absmach/supermq/auth" + "github.com/redis/go-redis/v9" +) + +const ( + defDuration = 15 * time.Minute + refreshPrefix = "refresh_tokens:" +) + +var _ auth.TokensCache = (*tokensCache)(nil) + +type tokensCache struct { + client *redis.Client + keyDuration time.Duration +} + +// NewTokensCache returns redis auth cache implementation. +func NewTokensCache(client *redis.Client, duration time.Duration) auth.TokensCache { + if duration == 0 { + duration = defDuration + } + return &tokensCache{ + client: client, + keyDuration: duration, + } +} + +// SaveActive saves an active refresh token ID for a user with TTL. +func (tc *tokensCache) SaveActive(ctx context.Context, userID, tokenID string, ttl time.Duration) error { + pipe := tc.client.TxPipeline() + + pipe.Set(ctx, tc.tokenKey(tokenID), userID, ttl) + pipe.SAdd(ctx, tc.userTokensKey(userID), tokenID) + + _, err := pipe.Exec(ctx) + + return err +} + +// IsActive checks if the token ID is active for the given user. +func (tc *tokensCache) IsActive(ctx context.Context, tokenID string) (bool, error) { + count, err := tc.client.Exists(ctx, tc.tokenKey(tokenID)).Result() + if err != nil { + return false, err + } + return count > 0, nil +} + +func (tc *tokensCache) ListUserTokens(ctx context.Context, userID string) ([]string, error) { + key := tc.userTokensKey(userID) + tokenIDs, err := tc.client.SMembers(ctx, key).Result() + if err != nil { + return nil, err + } + + if len(tokenIDs) == 0 { + return nil, nil + } + + valid := make([]string, 0, len(tokenIDs)) + pipe := tc.client.Pipeline() + + existsCmds := make(map[string]*redis.IntCmd, len(tokenIDs)) + for _, tokenID := range tokenIDs { + existsCmds[tokenID] = pipe.Exists(ctx, tc.tokenKey(tokenID)) + } + + _, err = pipe.Exec(ctx) + if err != nil { + return nil, err + } + + cleanup := tc.client.Pipeline() + for tokenID, cmd := range existsCmds { + if cmd.Val() == 1 { + valid = append(valid, tokenID) + } else { + cleanup.SRem(ctx, key, tokenID) + } + } + + _, err = cleanup.Exec(ctx) + if err != nil { + return nil, err + } + + return valid, nil +} + +// RemoveActive removes an active refresh token ID for a user. +func (tc *tokensCache) RemoveActive(ctx context.Context, tokenID string) error { + tokenKey := tc.tokenKey(tokenID) + + userID, err := tc.client.Get(ctx, tokenKey).Result() + if err == redis.Nil { + return nil + } + if err != nil { + return err + } + + pipe := tc.client.TxPipeline() + pipe.Del(ctx, tokenKey) + pipe.SRem(ctx, tc.userTokensKey(userID), tokenID) + + _, err = pipe.Exec(ctx) + return err +} + +func (tc *tokensCache) tokenKey(tokenID string) string { + return fmt.Sprintf("%s:token:%s", refreshPrefix, tokenID) +} + +func (tc *tokensCache) userTokensKey(userID string) string { + return fmt.Sprintf("%s:user_tokens:%s", refreshPrefix, userID) +} diff --git a/auth/cache/tokens_test.go b/auth/cache/tokens_test.go new file mode 100644 index 0000000000..39c372bef1 --- /dev/null +++ b/auth/cache/tokens_test.go @@ -0,0 +1,262 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/auth/cache" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/errors" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +var ( + storeClient *redis.Client + storeURL string +) + +func TestMain(m *testing.M) { + code := testsutil.RunRedisTest(m, &storeClient, &storeURL) + os.Exit(code) +} + +func setupRedisTokensClient() auth.TokensCache { + return cache.NewTokensCache(storeClient, 10*time.Minute) +} + +func TestTokenSave(t *testing.T) { + storeClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + userID := testsutil.GenerateUUID(t) + tokenID := testsutil.GenerateUUID(t) + + cases := []struct { + desc string + userID string + tokenID string + ttl time.Duration + err error + }{ + { + desc: "Save active token", + userID: userID, + tokenID: tokenID, + ttl: 10 * time.Minute, + err: nil, + }, + { + desc: "Save already cached token", + userID: userID, + tokenID: tokenID, + ttl: 10 * time.Minute, + err: nil, + }, + { + desc: "Save another token for same user", + userID: userID, + tokenID: testsutil.GenerateUUID(t), + ttl: 10 * time.Minute, + err: nil, + }, + { + desc: "Save token with empty id", + userID: userID, + tokenID: "", + ttl: 10 * time.Minute, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.SaveActive(context.Background(), tc.userID, tc.tokenID, tc.ttl) + if err == nil { + ok, err := tokensCache.IsActive(context.Background(), tc.tokenID) + assert.NoError(t, err) + assert.True(t, ok) + } + assert.True(t, errors.Contains(err, tc.err)) + }) + } +} + +func TestTokenContains(t *testing.T) { + storeClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + userID := testsutil.GenerateUUID(t) + tokenID := testsutil.GenerateUUID(t) + + err := tokensCache.SaveActive(context.Background(), userID, tokenID, 10*time.Minute) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + + cases := []struct { + desc string + userID string + tokenID string + ok bool + }{ + { + desc: "IsActive for existing token", + userID: userID, + tokenID: tokenID, + ok: true, + }, + { + desc: "IsActive for non existing token", + userID: userID, + tokenID: testsutil.GenerateUUID(t), + }, + { + desc: "IsActive with empty token id", + userID: userID, + tokenID: "", + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ok, err := tokensCache.IsActive(context.Background(), tc.tokenID) + if tc.ok { + assert.NoError(t, err) + } + assert.Equal(t, tc.ok, ok) + }) + } +} + +func TestTokenRemove(t *testing.T) { + storeClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + userID := testsutil.GenerateUUID(t) + num := 10 + var tokenIDs []string + for range num { + tokenID := testsutil.GenerateUUID(t) + err := tokensCache.SaveActive(context.Background(), userID, tokenID, 10*time.Minute) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + tokenIDs = append(tokenIDs, tokenID) + } + + cases := []struct { + desc string + userID string + tokenID string + err error + }{ + { + desc: "Remove an existing token from cache", + userID: userID, + tokenID: tokenIDs[0], + err: nil, + }, + { + desc: "Remove token with empty id from cache", + userID: userID, + tokenID: "", + err: nil, + }, + { + desc: "Remove non existing id from cache", + userID: userID, + tokenID: testsutil.GenerateUUID(t), + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.RemoveActive(context.Background(), tc.tokenID) + assert.True(t, errors.Contains(err, tc.err)) + if err == nil { + ok, err := tokensCache.IsActive(context.Background(), tc.tokenID) + assert.NoError(t, err) + assert.False(t, ok) + } + }) + } +} + +func TestListUserTokens(t *testing.T) { + storeClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + userID := testsutil.GenerateUUID(t) + userID2 := testsutil.GenerateUUID(t) + num := 5 + var tokenIDs []string + + for range num { + tokenID := testsutil.GenerateUUID(t) + err := tokensCache.SaveActive(context.Background(), userID, tokenID, 10*time.Minute) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + tokenIDs = append(tokenIDs, tokenID) + } + + tokenID2 := testsutil.GenerateUUID(t) + err := tokensCache.SaveActive(context.Background(), userID2, tokenID2, 10*time.Minute) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + + cases := []struct { + desc string + userID string + expectedCount int + expectedTokens []string + err error + }{ + { + desc: "List all tokens for user with multiple tokens", + userID: userID, + expectedCount: num, + expectedTokens: tokenIDs, + err: nil, + }, + { + desc: "List tokens for user with single token", + userID: userID2, + expectedCount: 1, + expectedTokens: []string{tokenID2}, + err: nil, + }, + { + desc: "List tokens for user with no tokens", + userID: testsutil.GenerateUUID(t), + expectedCount: 0, + expectedTokens: nil, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tokens, err := tokensCache.ListUserTokens(context.Background(), tc.userID) + assert.True(t, errors.Contains(err, tc.err)) + assert.Equal(t, tc.expectedCount, len(tokens)) + if tc.expectedTokens != nil { + assert.ElementsMatch(t, tc.expectedTokens, tokens) + } + }) + } + + t.Run("Cleanup expired tokens from list", func(t *testing.T) { + // Remove one token directly from Redis to simulate expiration + err := tokensCache.RemoveActive(context.Background(), tokenIDs[0]) + assert.NoError(t, err) + + // List should now return only valid tokens + tokens, err := tokensCache.ListUserTokens(context.Background(), userID) + assert.NoError(t, err) + assert.Equal(t, num-1, len(tokens)) + assert.NotContains(t, tokens, tokenIDs[0]) + }) +} diff --git a/auth/key_manager.go b/auth/key_manager.go index 73c74db82f..9a90388142 100644 --- a/auth/key_manager.go +++ b/auth/key_manager.go @@ -5,13 +5,16 @@ package auth import ( "context" - "errors" + "time" + + "github.com/absmach/supermq/pkg/errors" ) var ( ErrUnsupportedKeyAlgorithm = errors.New("unsupported key algorithm") ErrInvalidSymmetricKey = errors.New("invalid symmetric key") ErrPublicKeysNotSupported = errors.New("public keys not supported for symmetric algorithm") + ErrRevokedToken = errors.NewAuthNError("token is revoked") ) // PublicKeyInfo represents a public key for external distribution via JWKS. @@ -33,7 +36,8 @@ type PublicKeyInfo struct { // Implementations manage underlying cryptographic operations and key distribution. type Tokenizer interface { // Issue creates a signed token string from the given key claims. - Issue(key Key) (token string, err error) + // For RefreshKey types, the token ID is stored as active in the cache. + Issue(ctx context.Context, key Key) (token string, err error) // Parse verifies and parses a token string (JWT or PAT), returning the extracted claims. // For PAT tokens (prefix "pat"), returns a Key with Type set to PersonalAccessToken. @@ -43,6 +47,24 @@ type Tokenizer interface { // RetrieveJWKS returns public keys for distribution via JWKS endpoint. // Returns ErrPublicKeysNotSupported for symmetric tokenizers (HMAC). RetrieveJWKS() ([]PublicKeyInfo, error) + + // Revoke revokes a refresh token. + Revoke(ctx context.Context, token string) error +} + +// TokensCache represents a cache repository for managing active refresh tokens per user. +type TokensCache interface { + // SaveActive saves an active refresh token ID for a user with TTL. + SaveActive(ctx context.Context, userID, tokenID string, ttl time.Duration) error + + // IsActive checks if the token ID is active for the given user. + IsActive(ctx context.Context, tokenID string) (bool, error) + + // ListUserTokens lists all active token IDs for a given user. + ListUserTokens(ctx context.Context, userID string) ([]string, error) + + // Remove removes an active refresh token ID for a user. + RemoveActive(ctx context.Context, tokenID string) error } // IsSymmetricAlgorithm determines if the given algorithm is symmetric (HMAC-based). diff --git a/auth/middleware/logging.go b/auth/middleware/logging.go index b2b8df685b..c256f08137 100644 --- a/auth/middleware/logging.go +++ b/auth/middleware/logging.go @@ -46,6 +46,22 @@ func (lm *loggingMiddleware) Issue(ctx context.Context, token string, key auth.K return lm.svc.Issue(ctx, token, key) } +func (lm *loggingMiddleware) RevokeToken(ctx context.Context, token string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Revoke token failed to complete successfully", args...) + return + } + lm.logger.Info("Revoke token completed successfully", args...) + }(time.Now()) + + return lm.svc.RevokeToken(ctx, token) +} + func (lm *loggingMiddleware) Revoke(ctx context.Context, token, id string) (err error) { defer func(begin time.Time) { args := []any{ diff --git a/auth/middleware/metrics.go b/auth/middleware/metrics.go index e8c2560bc0..aec040cc5a 100644 --- a/auth/middleware/metrics.go +++ b/auth/middleware/metrics.go @@ -40,6 +40,15 @@ func (ms *metricsMiddleware) Issue(ctx context.Context, token string, key auth.K return ms.svc.Issue(ctx, token, key) } +func (ms *metricsMiddleware) RevokeToken(ctx context.Context, token string) error { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_token").Add(1) + ms.latency.With("method", "revoke_token").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RevokeToken(ctx, token) +} + func (ms *metricsMiddleware) Revoke(ctx context.Context, token, id string) error { defer func(begin time.Time) { ms.counter.With("method", "revoke_key").Add(1) diff --git a/auth/middleware/tracing.go b/auth/middleware/tracing.go index 3855b1b1ff..a6f18058cb 100644 --- a/auth/middleware/tracing.go +++ b/auth/middleware/tracing.go @@ -36,6 +36,13 @@ func (tm *tracingMiddleware) Issue(ctx context.Context, token string, key auth.K return tm.svc.Issue(ctx, token, key) } +func (tm *tracingMiddleware) RevokeToken(ctx context.Context, token string) error { + ctx, span := tm.tracer.Start(ctx, "revoke_token") + defer span.End() + + return tm.svc.RevokeToken(ctx, token) +} + func (tm *tracingMiddleware) Revoke(ctx context.Context, token, id string) error { ctx, span := tm.tracer.Start(ctx, "revoke", trace.WithAttributes( attribute.String("id", id), diff --git a/auth/mocks/service.go b/auth/mocks/service.go index 4482b6df73..101d3abce6 100644 --- a/auth/mocks/service.go +++ b/auth/mocks/service.go @@ -1344,6 +1344,63 @@ func (_c *Service_RevokePATSecret_Call) RunAndReturn(run func(ctx context.Contex return _c } +// RevokeToken provides a mock function for the type Service +func (_mock *Service) RevokeToken(ctx context.Context, token string) error { + ret := _mock.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for RevokeToken") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, token) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RevokeToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeToken' +type Service_RevokeToken_Call struct { + *mock.Call +} + +// RevokeToken is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *Service_Expecter) RevokeToken(ctx interface{}, token interface{}) *Service_RevokeToken_Call { + return &Service_RevokeToken_Call{Call: _e.mock.On("RevokeToken", ctx, token)} +} + +func (_c *Service_RevokeToken_Call) Run(run func(ctx context.Context, token string)) *Service_RevokeToken_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_RevokeToken_Call) Return(err error) *Service_RevokeToken_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RevokeToken_Call) RunAndReturn(run func(ctx context.Context, token string) error) *Service_RevokeToken_Call { + _c.Call.Return(run) + return _c +} + // UpdatePATDescription provides a mock function for the type Service func (_mock *Service) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) { ret := _mock.Called(ctx, token, patID, description) diff --git a/auth/mocks/token_client.go b/auth/mocks/token_client.go index 025065092f..756512216a 100644 --- a/auth/mocks/token_client.go +++ b/auth/mocks/token_client.go @@ -208,3 +208,86 @@ func (_c *TokenServiceClient_Refresh_Call) RunAndReturn(run func(ctx context.Con _c.Call.Return(run) return _c } + +// Revoke provides a mock function for the type TokenServiceClient +func (_mock *TokenServiceClient) Revoke(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*v1.RevokeRes, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(ctx, in, opts) + } else { + tmpRet = _mock.Called(ctx, in) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for Revoke") + } + + var r0 *v1.RevokeRes + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) (*v1.RevokeRes, error)); ok { + return returnFunc(ctx, in, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) *v1.RevokeRes); ok { + r0 = returnFunc(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.RevokeRes) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) error); ok { + r1 = returnFunc(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// TokenServiceClient_Revoke_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Revoke' +type TokenServiceClient_Revoke_Call struct { + *mock.Call +} + +// Revoke is a helper method to define mock.On call +// - ctx context.Context +// - in *v1.RevokeReq +// - opts ...grpc.CallOption +func (_e *TokenServiceClient_Expecter) Revoke(ctx interface{}, in interface{}, opts ...interface{}) *TokenServiceClient_Revoke_Call { + return &TokenServiceClient_Revoke_Call{Call: _e.mock.On("Revoke", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *TokenServiceClient_Revoke_Call) Run(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption)) *TokenServiceClient_Revoke_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *v1.RevokeReq + if args[1] != nil { + arg1 = args[1].(*v1.RevokeReq) + } + var arg2 []grpc.CallOption + var variadicArgs []grpc.CallOption + if len(args) > 2 { + variadicArgs = args[2].([]grpc.CallOption) + } + arg2 = variadicArgs + run( + arg0, + arg1, + arg2..., + ) + }) + return _c +} + +func (_c *TokenServiceClient_Revoke_Call) Return(revokeRes *v1.RevokeRes, err error) *TokenServiceClient_Revoke_Call { + _c.Call.Return(revokeRes, err) + return _c +} + +func (_c *TokenServiceClient_Revoke_Call) RunAndReturn(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*v1.RevokeRes, error)) *TokenServiceClient_Revoke_Call { + _c.Call.Return(run) + return _c +} diff --git a/auth/mocks/tokenizer.go b/auth/mocks/tokenizer.go index f4b5c786f9..b9b66d2e38 100644 --- a/auth/mocks/tokenizer.go +++ b/auth/mocks/tokenizer.go @@ -43,8 +43,8 @@ func (_m *Tokenizer) EXPECT() *Tokenizer_Expecter { } // Issue provides a mock function for the type Tokenizer -func (_mock *Tokenizer) Issue(key auth.Key) (string, error) { - ret := _mock.Called(key) +func (_mock *Tokenizer) Issue(ctx context.Context, key auth.Key) (string, error) { + ret := _mock.Called(ctx, key) if len(ret) == 0 { panic("no return value specified for Issue") @@ -52,16 +52,16 @@ func (_mock *Tokenizer) Issue(key auth.Key) (string, error) { var r0 string var r1 error - if returnFunc, ok := ret.Get(0).(func(auth.Key) (string, error)); ok { - return returnFunc(key) + if returnFunc, ok := ret.Get(0).(func(context.Context, auth.Key) (string, error)); ok { + return returnFunc(ctx, key) } - if returnFunc, ok := ret.Get(0).(func(auth.Key) string); ok { - r0 = returnFunc(key) + if returnFunc, ok := ret.Get(0).(func(context.Context, auth.Key) string); ok { + r0 = returnFunc(ctx, key) } else { r0 = ret.Get(0).(string) } - if returnFunc, ok := ret.Get(1).(func(auth.Key) error); ok { - r1 = returnFunc(key) + if returnFunc, ok := ret.Get(1).(func(context.Context, auth.Key) error); ok { + r1 = returnFunc(ctx, key) } else { r1 = ret.Error(1) } @@ -74,19 +74,25 @@ type Tokenizer_Issue_Call struct { } // Issue is a helper method to define mock.On call +// - ctx context.Context // - key auth.Key -func (_e *Tokenizer_Expecter) Issue(key interface{}) *Tokenizer_Issue_Call { - return &Tokenizer_Issue_Call{Call: _e.mock.On("Issue", key)} +func (_e *Tokenizer_Expecter) Issue(ctx interface{}, key interface{}) *Tokenizer_Issue_Call { + return &Tokenizer_Issue_Call{Call: _e.mock.On("Issue", ctx, key)} } -func (_c *Tokenizer_Issue_Call) Run(run func(key auth.Key)) *Tokenizer_Issue_Call { +func (_c *Tokenizer_Issue_Call) Run(run func(ctx context.Context, key auth.Key)) *Tokenizer_Issue_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 auth.Key + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(auth.Key) + arg0 = args[0].(context.Context) + } + var arg1 auth.Key + if args[1] != nil { + arg1 = args[1].(auth.Key) } run( arg0, + arg1, ) }) return _c @@ -97,7 +103,7 @@ func (_c *Tokenizer_Issue_Call) Return(token string, err error) *Tokenizer_Issue return _c } -func (_c *Tokenizer_Issue_Call) RunAndReturn(run func(key auth.Key) (string, error)) *Tokenizer_Issue_Call { +func (_c *Tokenizer_Issue_Call) RunAndReturn(run func(ctx context.Context, key auth.Key) (string, error)) *Tokenizer_Issue_Call { _c.Call.Return(run) return _c } @@ -222,3 +228,60 @@ func (_c *Tokenizer_RetrieveJWKS_Call) RunAndReturn(run func() ([]auth.PublicKey _c.Call.Return(run) return _c } + +// Revoke provides a mock function for the type Tokenizer +func (_mock *Tokenizer) Revoke(ctx context.Context, token string) error { + ret := _mock.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for Revoke") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, token) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Tokenizer_Revoke_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Revoke' +type Tokenizer_Revoke_Call struct { + *mock.Call +} + +// Revoke is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *Tokenizer_Expecter) Revoke(ctx interface{}, token interface{}) *Tokenizer_Revoke_Call { + return &Tokenizer_Revoke_Call{Call: _e.mock.On("Revoke", ctx, token)} +} + +func (_c *Tokenizer_Revoke_Call) Run(run func(ctx context.Context, token string)) *Tokenizer_Revoke_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Tokenizer_Revoke_Call) Return(err error) *Tokenizer_Revoke_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Tokenizer_Revoke_Call) RunAndReturn(run func(ctx context.Context, token string) error) *Tokenizer_Revoke_Call { + _c.Call.Return(run) + return _c +} diff --git a/auth/mocks/tokens_cache.go b/auth/mocks/tokens_cache.go new file mode 100644 index 0000000000..df3eddb784 --- /dev/null +++ b/auth/mocks/tokens_cache.go @@ -0,0 +1,303 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + "time" + + mock "github.com/stretchr/testify/mock" +) + +// NewTokensCache creates a new instance of TokensCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokensCache(t interface { + mock.TestingT + Cleanup(func()) +}) *TokensCache { + mock := &TokensCache{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// TokensCache is an autogenerated mock type for the TokensCache type +type TokensCache struct { + mock.Mock +} + +type TokensCache_Expecter struct { + mock *mock.Mock +} + +func (_m *TokensCache) EXPECT() *TokensCache_Expecter { + return &TokensCache_Expecter{mock: &_m.Mock} +} + +// IsActive provides a mock function for the type TokensCache +func (_mock *TokensCache) IsActive(ctx context.Context, tokenID string) (bool, error) { + ret := _mock.Called(ctx, tokenID) + + if len(ret) == 0 { + panic("no return value specified for IsActive") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok { + return returnFunc(ctx, tokenID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = returnFunc(ctx, tokenID) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, tokenID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// TokensCache_IsActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsActive' +type TokensCache_IsActive_Call struct { + *mock.Call +} + +// IsActive is a helper method to define mock.On call +// - ctx context.Context +// - tokenID string +func (_e *TokensCache_Expecter) IsActive(ctx interface{}, tokenID interface{}) *TokensCache_IsActive_Call { + return &TokensCache_IsActive_Call{Call: _e.mock.On("IsActive", ctx, tokenID)} +} + +func (_c *TokensCache_IsActive_Call) Run(run func(ctx context.Context, tokenID string)) *TokensCache_IsActive_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *TokensCache_IsActive_Call) Return(b bool, err error) *TokensCache_IsActive_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *TokensCache_IsActive_Call) RunAndReturn(run func(ctx context.Context, tokenID string) (bool, error)) *TokensCache_IsActive_Call { + _c.Call.Return(run) + return _c +} + +// ListUserTokens provides a mock function for the type TokensCache +func (_mock *TokensCache) ListUserTokens(ctx context.Context, userID string) ([]string, error) { + ret := _mock.Called(ctx, userID) + + if len(ret) == 0 { + panic("no return value specified for ListUserTokens") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]string, error)); ok { + return returnFunc(ctx, userID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) []string); ok { + r0 = returnFunc(ctx, userID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, userID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// TokensCache_ListUserTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserTokens' +type TokensCache_ListUserTokens_Call struct { + *mock.Call +} + +// ListUserTokens is a helper method to define mock.On call +// - ctx context.Context +// - userID string +func (_e *TokensCache_Expecter) ListUserTokens(ctx interface{}, userID interface{}) *TokensCache_ListUserTokens_Call { + return &TokensCache_ListUserTokens_Call{Call: _e.mock.On("ListUserTokens", ctx, userID)} +} + +func (_c *TokensCache_ListUserTokens_Call) Run(run func(ctx context.Context, userID string)) *TokensCache_ListUserTokens_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *TokensCache_ListUserTokens_Call) Return(strings []string, err error) *TokensCache_ListUserTokens_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *TokensCache_ListUserTokens_Call) RunAndReturn(run func(ctx context.Context, userID string) ([]string, error)) *TokensCache_ListUserTokens_Call { + _c.Call.Return(run) + return _c +} + +// RemoveActive provides a mock function for the type TokensCache +func (_mock *TokensCache) RemoveActive(ctx context.Context, tokenID string) error { + ret := _mock.Called(ctx, tokenID) + + if len(ret) == 0 { + panic("no return value specified for RemoveActive") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, tokenID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// TokensCache_RemoveActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveActive' +type TokensCache_RemoveActive_Call struct { + *mock.Call +} + +// RemoveActive is a helper method to define mock.On call +// - ctx context.Context +// - tokenID string +func (_e *TokensCache_Expecter) RemoveActive(ctx interface{}, tokenID interface{}) *TokensCache_RemoveActive_Call { + return &TokensCache_RemoveActive_Call{Call: _e.mock.On("RemoveActive", ctx, tokenID)} +} + +func (_c *TokensCache_RemoveActive_Call) Run(run func(ctx context.Context, tokenID string)) *TokensCache_RemoveActive_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *TokensCache_RemoveActive_Call) Return(err error) *TokensCache_RemoveActive_Call { + _c.Call.Return(err) + return _c +} + +func (_c *TokensCache_RemoveActive_Call) RunAndReturn(run func(ctx context.Context, tokenID string) error) *TokensCache_RemoveActive_Call { + _c.Call.Return(run) + return _c +} + +// SaveActive provides a mock function for the type TokensCache +func (_mock *TokensCache) SaveActive(ctx context.Context, userID string, tokenID string, ttl time.Duration) error { + ret := _mock.Called(ctx, userID, tokenID, ttl) + + if len(ret) == 0 { + panic("no return value specified for SaveActive") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) error); ok { + r0 = returnFunc(ctx, userID, tokenID, ttl) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// TokensCache_SaveActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveActive' +type TokensCache_SaveActive_Call struct { + *mock.Call +} + +// SaveActive is a helper method to define mock.On call +// - ctx context.Context +// - userID string +// - tokenID string +// - ttl time.Duration +func (_e *TokensCache_Expecter) SaveActive(ctx interface{}, userID interface{}, tokenID interface{}, ttl interface{}) *TokensCache_SaveActive_Call { + return &TokensCache_SaveActive_Call{Call: _e.mock.On("SaveActive", ctx, userID, tokenID, ttl)} +} + +func (_c *TokensCache_SaveActive_Call) Run(run func(ctx context.Context, userID string, tokenID string, ttl time.Duration)) *TokensCache_SaveActive_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 time.Duration + if args[3] != nil { + arg3 = args[3].(time.Duration) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *TokensCache_SaveActive_Call) Return(err error) *TokensCache_SaveActive_Call { + _c.Call.Return(err) + return _c +} + +func (_c *TokensCache_SaveActive_Call) RunAndReturn(run func(ctx context.Context, userID string, tokenID string, ttl time.Duration) error) *TokensCache_SaveActive_Call { + _c.Call.Return(run) + return _c +} diff --git a/auth/mocks/tokens_repository.go b/auth/mocks/tokens_repository.go new file mode 100644 index 0000000000..7f69607349 --- /dev/null +++ b/auth/mocks/tokens_repository.go @@ -0,0 +1,156 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + mock "github.com/stretchr/testify/mock" +) + +// NewTokensRepository creates a new instance of TokensRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokensRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *TokensRepository { + mock := &TokensRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// TokensRepository is an autogenerated mock type for the TokensRepository type +type TokensRepository struct { + mock.Mock +} + +type TokensRepository_Expecter struct { + mock *mock.Mock +} + +func (_m *TokensRepository) EXPECT() *TokensRepository_Expecter { + return &TokensRepository_Expecter{mock: &_m.Mock} +} + +// Contains provides a mock function for the type TokensRepository +func (_mock *TokensRepository) Contains(ctx context.Context, id string) bool { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Contains") + } + + var r0 bool + if returnFunc, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Get(0).(bool) + } + return r0 +} + +// TokensRepository_Contains_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Contains' +type TokensRepository_Contains_Call struct { + *mock.Call +} + +// Contains is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *TokensRepository_Expecter) Contains(ctx interface{}, id interface{}) *TokensRepository_Contains_Call { + return &TokensRepository_Contains_Call{Call: _e.mock.On("Contains", ctx, id)} +} + +func (_c *TokensRepository_Contains_Call) Run(run func(ctx context.Context, id string)) *TokensRepository_Contains_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *TokensRepository_Contains_Call) Return(b bool) *TokensRepository_Contains_Call { + _c.Call.Return(b) + return _c +} + +func (_c *TokensRepository_Contains_Call) RunAndReturn(run func(ctx context.Context, id string) bool) *TokensRepository_Contains_Call { + _c.Call.Return(run) + return _c +} + +// Save provides a mock function for the type TokensRepository +func (_mock *TokensRepository) Save(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// TokensRepository_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save' +type TokensRepository_Save_Call struct { + *mock.Call +} + +// Save is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *TokensRepository_Expecter) Save(ctx interface{}, id interface{}) *TokensRepository_Save_Call { + return &TokensRepository_Save_Call{Call: _e.mock.On("Save", ctx, id)} +} + +func (_c *TokensRepository_Save_Call) Run(run func(ctx context.Context, id string)) *TokensRepository_Save_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *TokensRepository_Save_Call) Return(err error) *TokensRepository_Save_Call { + _c.Call.Return(err) + return _c +} + +func (_c *TokensRepository_Save_Call) RunAndReturn(run func(ctx context.Context, id string) error) *TokensRepository_Save_Call { + _c.Call.Return(run) + return _c +} diff --git a/auth/postgres/errors.go b/auth/postgres/errors.go new file mode 100644 index 0000000000..45477426e6 --- /dev/null +++ b/auth/postgres/errors.go @@ -0,0 +1,24 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import "github.com/absmach/supermq/pkg/errors" + +var _ errors.Mapper = (*duplicateErrors)(nil) + +type duplicateErrors struct{} + +// GetError maps constraint names to known errors. +func (d duplicateErrors) GetError(constraint string) (error, bool) { + switch constraint { + case "revoked_tokens_pkey": + return errors.NewRequestError("revoked token already exists"), true + default: + return nil, false + } +} + +func NewDuplicateErrors() errors.Mapper { + return duplicateErrors{} +} diff --git a/auth/service.go b/auth/service.go index b88a8efaaa..dd57112734 100644 --- a/auth/service.go +++ b/auth/service.go @@ -72,6 +72,9 @@ type Authn interface { // Issue issues a new Key, returning its token value alongside. Issue(ctx context.Context, token string, key Key) (Token, error) + // RevokeToken revokes the token. + RevokeToken(ctx context.Context, token string) error + // Revoke removes the Key with the provided id that is // issued by the user identified by the provided key. Revoke(ctx context.Context, token, id string) error @@ -148,6 +151,10 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err } } +func (svc service) RevokeToken(ctx context.Context, token string) error { + return svc.tokenizer.Revoke(ctx, token) +} + func (svc service) Revoke(ctx context.Context, token, id string) error { issuerID, _, err := svc.authenticate(ctx, token) if err != nil { @@ -287,7 +294,7 @@ func (svc service) tmpKey(ctx context.Context, duration time.Duration, key Key) if err := svc.checkUserRole(ctx, key); err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } - value, err := svc.tokenizer.Issue(key) + value, err := svc.tokenizer.Issue(ctx, key) if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } @@ -304,14 +311,19 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) { return Token{}, errors.Wrap(errIssueUser, err) } - access, err := svc.tokenizer.Issue(key) + access, err := svc.tokenizer.Issue(ctx, key) if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } key.ExpiresAt = time.Now().UTC().Add(svc.refreshDuration) key.Type = RefreshKey - refresh, err := svc.tokenizer.Issue(key) + id, err := svc.idProvider.ID() + if err != nil { + return Token{}, errors.Wrap(errIssueTmp, err) + } + key.ID = id + refresh, err := svc.tokenizer.Issue(ctx, key) if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } @@ -328,7 +340,7 @@ func (svc service) invitationKey(ctx context.Context, key Key) (Token, error) { return Token{}, errors.Wrap(errIssueTmp, err) } - access, err := svc.tokenizer.Issue(key) + access, err := svc.tokenizer.Issue(ctx, key) if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } @@ -354,14 +366,14 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token key.Role = k.Role key.ExpiresAt = time.Now().UTC().Add(svc.loginDuration) - access, err := svc.tokenizer.Issue(key) + access, err := svc.tokenizer.Issue(ctx, key) if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } - key.ExpiresAt = time.Now().UTC().Add(svc.refreshDuration) + key.ExpiresAt = k.ExpiresAt key.Type = RefreshKey - refresh, err := svc.tokenizer.Issue(key) + refresh, err := svc.tokenizer.Issue(ctx, key) if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } @@ -437,7 +449,7 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e return Token{}, errors.Wrap(errIssueUser, err) } - tkn, err := svc.tokenizer.Issue(key) + tkn, err := svc.tokenizer.Issue(ctx, key) if err != nil { return Token{}, errors.Wrap(errIssueUser, err) } diff --git a/auth/service_test.go b/auth/service_test.go index f93ade619f..dea2d6e3e4 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -133,7 +133,7 @@ func TestIssue(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.tokenizerErr) + tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.tokenizerErr) policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ Subject: tc.key.Subject, SubjectType: policies.UserType, @@ -172,7 +172,7 @@ func TestIssue(t *testing.T) { } for _, tc := range cases2 { t.Run(tc.desc, func(t *testing.T) { - tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.tokenizerErr) + tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.tokenizerErr) repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr) policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ Subject: tc.key.Subject, @@ -265,7 +265,7 @@ func TestIssue(t *testing.T) { } for _, tc := range cases3 { t.Run(tc.desc, func(t *testing.T) { - tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.issueErr) + tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.issueErr) tokenizerCall1 := tokenizer.On("Parse", mock.Anything, tc.token).Return(tc.parseRes, tc.parseErr) repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr) policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ @@ -346,8 +346,9 @@ func TestIssue(t *testing.T) { } for _, tc := range cases4 { t.Run(tc.desc, func(t *testing.T) { - tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.issueErr) + tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.issueErr) tokenizerCall1 := tokenizer.On("Parse", mock.Anything, tc.token).Return(tc.parseRes, tc.parseErr) + tokenizerCall2 := tokenizer.On("Revoke", mock.Anything, tc.token).Return(tc.parseErr) policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ Subject: tc.key.Subject, SubjectType: policies.UserType, @@ -359,6 +360,7 @@ func TestIssue(t *testing.T) { assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) tokenizerCall.Unset() tokenizerCall1.Unset() + tokenizerCall2.Unset() policyCall.Unset() }) } diff --git a/auth/tokenizer/asymmetric/rotation_test.go b/auth/tokenizer/asymmetric/rotation_test.go index 64547620aa..c60e47aa5e 100644 --- a/auth/tokenizer/asymmetric/rotation_test.go +++ b/auth/tokenizer/asymmetric/rotation_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/auth/mocks" "github.com/absmach/supermq/auth/tokenizer/asymmetric" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -46,7 +47,8 @@ func TestTwoKeyRotation(t *testing.T) { saveKey(t, retiringPriv, retiringKeyPath) idProvider := &incrementingIDProvider{} - tokenizer, err := asymmetric.NewTokenizer(activeKeyPath, retiringKeyPath, idProvider, newTestLogger()) + cache := new(mocks.TokensCache) + tokenizer, err := asymmetric.NewTokenizer(activeKeyPath, retiringKeyPath, idProvider, cache, newTestLogger()) require.NoError(t, err) testKey := auth.Key{ @@ -59,7 +61,7 @@ func TestTwoKeyRotation(t *testing.T) { Verified: true, } - token, err := tokenizer.Issue(testKey) + token, err := tokenizer.Issue(context.Background(), testKey) require.NoError(t, err) assert.NotEmpty(t, token) @@ -88,7 +90,8 @@ func TestSingleKeyMode(t *testing.T) { saveKey(t, privateKey, keyPath) idProvider := &mockIDProvider{id: "single-id"} - tokenizer, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger()) + cache := new(mocks.TokensCache) + tokenizer, err := asymmetric.NewTokenizer(keyPath, "", idProvider, cache, newTestLogger()) require.NoError(t, err) testKey := auth.Key{ @@ -100,7 +103,7 @@ func TestSingleKeyMode(t *testing.T) { ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), } - token, err := tokenizer.Issue(testKey) + token, err := tokenizer.Issue(context.Background(), testKey) require.NoError(t, err) _, err = tokenizer.Parse(context.Background(), token) @@ -123,7 +126,8 @@ func TestMissingRetiringKey(t *testing.T) { retiringKeyPath := filepath.Join(tmpDir, "nonexistent.key") idProvider := &mockIDProvider{id: "test-id"} - tokenizer, err := asymmetric.NewTokenizer(activeKeyPath, retiringKeyPath, idProvider, newTestLogger()) + cache := new(mocks.TokensCache) + tokenizer, err := asymmetric.NewTokenizer(activeKeyPath, retiringKeyPath, idProvider, cache, newTestLogger()) require.NoError(t, err, "Should succeed even if retiring key is missing") testKey := auth.Key{ @@ -135,7 +139,7 @@ func TestMissingRetiringKey(t *testing.T) { ExpiresAt: time.Now().Add(1 * time.Hour).UTC(), } - token, err := tokenizer.Issue(testKey) + token, err := tokenizer.Issue(context.Background(), testKey) require.NoError(t, err) _, err = tokenizer.Parse(context.Background(), token) diff --git a/auth/tokenizer/asymmetric/tokenizer.go b/auth/tokenizer/asymmetric/tokenizer.go index ac8d168ece..591dc3dbe6 100644 --- a/auth/tokenizer/asymmetric/tokenizer.go +++ b/auth/tokenizer/asymmetric/tokenizer.go @@ -13,11 +13,13 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/absmach/supermq" "github.com/absmach/supermq/auth" smqjwt "github.com/absmach/supermq/auth/tokenizer/util" "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" @@ -47,6 +49,7 @@ type keyPair struct { type tokenizer struct { activeKey *keyPair retiringKey *keyPair // Optional, for key rotation grace period + cache auth.TokensCache } var _ auth.Tokenizer = (*tokenizer)(nil) @@ -56,7 +59,7 @@ var _ auth.Tokenizer = (*tokenizer)(nil) // If retiringKeyPath is provided but the file doesn't exist or is invalid, a warning is logged // but the tokenizer is still created with just the active key. // Key IDs are derived from filenames to ensure consistency across multiple service instances. -func NewTokenizer(activeKeyPath, retiringKeyPath string, idProvider supermq.IDProvider, logger *slog.Logger) (auth.Tokenizer, error) { +func NewTokenizer(activeKeyPath, retiringKeyPath string, idProvider supermq.IDProvider, cache auth.TokensCache, logger *slog.Logger) (auth.Tokenizer, error) { activeKID := keyIDFromPath(activeKeyPath) activePrivateJwk, activePublicJwk, err := loadKeyPair(activeKeyPath, activeKID) @@ -70,6 +73,7 @@ func NewTokenizer(activeKeyPath, retiringKeyPath string, idProvider supermq.IDPr privateKey: activePrivateJwk, publicKey: activePublicJwk, }, + cache: cache, } if retiringKeyPath != "" { @@ -95,8 +99,8 @@ func NewTokenizer(activeKeyPath, retiringKeyPath string, idProvider supermq.IDPr return mgr, nil } -func (km *tokenizer) Issue(key auth.Key) (string, error) { - if km.activeKey == nil { +func (tok *tokenizer) Issue(ctx context.Context, key auth.Key) (string, error) { + if tok.activeKey == nil { return "", errNoActiveKey } @@ -105,29 +109,94 @@ func (km *tokenizer) Issue(key auth.Key) (string, error) { return "", err } headers := jws.NewHeaders() - if err := headers.Set(jwk.KeyIDKey, km.activeKey.id); err != nil { + if err := headers.Set(jwk.KeyIDKey, tok.activeKey.id); err != nil { return "", err } - signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, km.activeKey.privateKey, jws.WithProtectedHeaders(headers))) + signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, tok.activeKey.privateKey, jws.WithProtectedHeaders(headers))) if err != nil { return "", err } + if key.Type == auth.RefreshKey && key.ID != "" && key.Subject != "" { + ttl := time.Until(key.ExpiresAt) + if ttl > 0 { + if err := tok.cache.SaveActive(ctx, key.Subject, key.ID, ttl); err != nil { + return "", err + } + } + } + return string(signedBytes), nil } -func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) { +func (tok *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) { + key, err := tok.parseToken(tokenString) + if err != nil { + return auth.Key{}, err + } + if key.Type == auth.RefreshKey { + found, err := tok.cache.IsActive(ctx, key.ID) + if err != nil { + return auth.Key{}, err + } + if !found { + return auth.Key{}, auth.ErrRevokedToken + } + } + + return key, nil +} + +func (tok *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { + publicKeys := make([]auth.PublicKeyInfo, 0, 2) + + if tok.activeKey != nil { + if pkInfo := extractPublicKeyInfo(tok.activeKey); pkInfo != nil { + publicKeys = append(publicKeys, *pkInfo) + } + } + + if tok.retiringKey != nil { + if pkInfo := extractPublicKeyInfo(tok.retiringKey); pkInfo != nil { + publicKeys = append(publicKeys, *pkInfo) + } + } + + if len(publicKeys) == 0 { + return nil, errNoValidPublicKeys + } + + return publicKeys, nil +} + +func (tok *tokenizer) Revoke(ctx context.Context, token string) error { + key, err := tok.parseToken(token) + if err != nil { + return err + } + + if key.Type == auth.RefreshKey { + // Remove the refresh token from active tokens + if err := tok.cache.RemoveActive(ctx, key.ID); err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + } + + return nil +} + +func (tok *tokenizer) parseToken(tokenString string) (auth.Key, error) { if len(tokenString) >= 3 && tokenString[:3] == patPrefix { return auth.Key{Type: auth.PersonalAccessToken}, nil } set := jwk.NewSet() - if err := set.AddKey(km.activeKey.publicKey); err != nil { + if err := set.AddKey(tok.activeKey.publicKey); err != nil { return auth.Key{}, err } - if km.retiringKey != nil { - if err := set.AddKey(km.retiringKey.publicKey); err != nil { + if tok.retiringKey != nil { + if err := set.AddKey(tok.retiringKey.publicKey); err != nil { return auth.Key{}, err } } @@ -148,28 +217,6 @@ func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, e return smqjwt.ToKey(tkn) } -func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { - publicKeys := make([]auth.PublicKeyInfo, 0, 2) - - if km.activeKey != nil { - if pkInfo := extractPublicKeyInfo(km.activeKey); pkInfo != nil { - publicKeys = append(publicKeys, *pkInfo) - } - } - - if km.retiringKey != nil { - if pkInfo := extractPublicKeyInfo(km.retiringKey); pkInfo != nil { - publicKeys = append(publicKeys, *pkInfo) - } - } - - if len(publicKeys) == 0 { - return nil, errNoValidPublicKeys - } - - return publicKeys, nil -} - func extractPublicKeyInfo(kp *keyPair) *auth.PublicKeyInfo { var rawKey ed25519.PublicKey if err := kp.publicKey.Raw(&rawKey); err != nil { diff --git a/auth/tokenizer/asymmetric/tokenizer_test.go b/auth/tokenizer/asymmetric/tokenizer_test.go index 4679ca6c5a..a8d3c9067c 100644 --- a/auth/tokenizer/asymmetric/tokenizer_test.go +++ b/auth/tokenizer/asymmetric/tokenizer_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/auth/mocks" "github.com/absmach/supermq/auth/tokenizer/asymmetric" smqerrors "github.com/absmach/supermq/pkg/errors" "github.com/lestrrat-go/jwx/v2/jwa" @@ -40,6 +41,7 @@ func newTestLogger() *slog.Logger { func TestNewKeyManager(t *testing.T) { idProvider := &mockIDProvider{id: "unused"} + cache := new(mocks.TokensCache) tmpDir := t.TempDir() keyPath := filepath.Join(tmpDir, "private.key") @@ -105,7 +107,7 @@ func TestNewKeyManager(t *testing.T) { t.Run(tc.name, func(t *testing.T) { path := tc.setupKey() - km, err := asymmetric.NewTokenizer(path, "", idProvider, newTestLogger()) + km, err := asymmetric.NewTokenizer(path, "", idProvider, cache, newTestLogger()) if tc.expectErr { assert.Error(t, err) @@ -123,6 +125,7 @@ func TestNewKeyManager(t *testing.T) { func TestSign(t *testing.T) { idProvider := &mockIDProvider{id: "unused"} + cache := new(mocks.TokensCache) tmpDir := t.TempDir() keyPath := filepath.Join(tmpDir, "private.key") @@ -141,7 +144,7 @@ func TestSign(t *testing.T) { err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600) require.NoError(t, err) - km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger()) + km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, cache, newTestLogger()) require.NoError(t, err) cases := []struct { @@ -187,7 +190,7 @@ func TestSign(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - token, err := km.Issue(tc.key) + token, err := km.Issue(context.Background(), tc.key) assert.NoError(t, err) assert.NotEmpty(t, token) @@ -199,6 +202,7 @@ func TestSign(t *testing.T) { func TestVerify(t *testing.T) { idProvider := &mockIDProvider{id: "unused"} + cache := new(mocks.TokensCache) tmpDir := t.TempDir() keyPath := filepath.Join(tmpDir, "private.key") @@ -218,7 +222,7 @@ func TestVerify(t *testing.T) { err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600) require.NoError(t, err) - km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger()) + km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, cache, newTestLogger()) require.NoError(t, err) validKey := auth.Key{ @@ -232,12 +236,12 @@ func TestVerify(t *testing.T) { Verified: true, } - validToken, err := km.Issue(validKey) + validToken, err := km.Issue(context.Background(), validKey) require.NoError(t, err, "Signing a valid token should succeed") expiredKey := validKey expiredKey.ExpiresAt = time.Now().Add(-1 * time.Hour).UTC() - expiredToken, err := km.Issue(expiredKey) + expiredToken, err := km.Issue(context.Background(), expiredKey) require.NoError(t, err, "Creating an expired token should succeed") wrongIssuerKey := validKey @@ -317,6 +321,7 @@ func TestVerify(t *testing.T) { func TestPublicKeys(t *testing.T) { idProvider := &mockIDProvider{id: "unused"} + cache := new(mocks.TokensCache) tmpDir := t.TempDir() keyPath := filepath.Join(tmpDir, "private.key") @@ -336,7 +341,7 @@ func TestPublicKeys(t *testing.T) { err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600) require.NoError(t, err) - km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger()) + km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, cache, newTestLogger()) require.NoError(t, err) keys, err := km.RetrieveJWKS() @@ -358,6 +363,7 @@ func TestPublicKeys(t *testing.T) { func TestSignAndVerifyRoundTrip(t *testing.T) { idProvider := &mockIDProvider{id: "unused"} + cache := new(mocks.TokensCache) tmpDir := t.TempDir() keyPath := filepath.Join(tmpDir, "private.key") @@ -376,7 +382,7 @@ func TestSignAndVerifyRoundTrip(t *testing.T) { err = os.WriteFile(keyPath, pem.EncodeToMemory(pemBlock), 0o600) require.NoError(t, err) - km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, newTestLogger()) + km, err := asymmetric.NewTokenizer(keyPath, "", idProvider, cache, newTestLogger()) require.NoError(t, err) originalKey := auth.Key{ @@ -390,7 +396,7 @@ func TestSignAndVerifyRoundTrip(t *testing.T) { Verified: true, } - token, err := km.Issue(originalKey) + token, err := km.Issue(context.Background(), originalKey) require.NoError(t, err) verifiedKey, err := km.Parse(context.Background(), token) diff --git a/auth/tokenizer/symmetric/tokenizer.go b/auth/tokenizer/symmetric/tokenizer.go index 5933bb697f..e6ea1b1220 100644 --- a/auth/tokenizer/symmetric/tokenizer.go +++ b/auth/tokenizer/symmetric/tokenizer.go @@ -5,6 +5,7 @@ package symmetric import ( "context" + "time" "github.com/absmach/supermq/auth" smqjwt "github.com/absmach/supermq/auth/tokenizer/util" @@ -14,20 +15,19 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" ) -const ( - patPrefix = "pat" -) +const patPrefix = "pat" var errJWTExpiryKey = errors.New(`"exp" not satisfied`) type tokenizer struct { algorithm jwa.KeyAlgorithm secret []byte + cache auth.TokensCache } var _ auth.Tokenizer = (*tokenizer)(nil) -func NewTokenizer(algorithm string, secret []byte) (auth.Tokenizer, error) { +func NewTokenizer(algorithm string, secret []byte, cache auth.TokensCache) (auth.Tokenizer, error) { alg := jwa.KeyAlgorithmFrom(algorithm) if _, ok := alg.(jwa.InvalidKeyAlgorithm); ok { return nil, auth.ErrUnsupportedKeyAlgorithm @@ -38,24 +38,74 @@ func NewTokenizer(algorithm string, secret []byte) (auth.Tokenizer, error) { return &tokenizer{ secret: secret, algorithm: alg, + cache: cache, }, nil } -func (km *tokenizer) Issue(key auth.Key) (string, error) { +func (tok *tokenizer) Issue(ctx context.Context, key auth.Key) (string, error) { tkn, err := smqjwt.BuildToken(key) if err != nil { return "", err } - signedBytes, err := jwt.Sign(tkn, jwt.WithKey(km.algorithm, km.secret)) + signedBytes, err := jwt.Sign(tkn, jwt.WithKey(tok.algorithm, tok.secret)) if err != nil { return "", err } + // Store refresh tokens as active with TTL + if key.Type == auth.RefreshKey && key.ID != "" && key.Subject != "" { + ttl := time.Until(key.ExpiresAt) + if ttl > 0 { + if err := tok.cache.SaveActive(ctx, key.Subject, key.ID, ttl); err != nil { + return "", err + } + } + } + return string(signedBytes), nil } -func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) { +func (tok *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) { + key, err := tok.parseToken(tokenString) + if err != nil { + return auth.Key{}, err + } + if key.Type == auth.RefreshKey { + // Check if the refresh token is active for this user + found, err := tok.cache.IsActive(ctx, key.ID) + if err != nil { + return auth.Key{}, err + } + if !found { + return auth.Key{}, auth.ErrRevokedToken + } + } + + return key, nil +} + +func (tok *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { + return nil, auth.ErrPublicKeysNotSupported +} + +func (tok *tokenizer) Revoke(ctx context.Context, token string) error { + key, err := tok.parseToken(token) + if err != nil { + return err + } + + if key.Type == auth.RefreshKey { + // Remove the refresh token from active tokens + if err := tok.cache.RemoveActive(ctx, key.ID); err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + } + + return nil +} + +func (tok *tokenizer) parseToken(tokenString string) (auth.Key, error) { if len(tokenString) >= 3 && tokenString[:3] == patPrefix { return auth.Key{Type: auth.PersonalAccessToken}, nil } @@ -63,7 +113,7 @@ func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, e tkn, err := jwt.Parse( []byte(tokenString), jwt.WithValidate(true), - jwt.WithKey(km.algorithm, km.secret), + jwt.WithKey(tok.algorithm, tok.secret), ) if err != nil { if errors.Contains(err, errJWTExpiryKey) { @@ -78,7 +128,3 @@ func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, e return smqjwt.ToKey(tkn) } - -func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { - return nil, auth.ErrPublicKeysNotSupported -} diff --git a/auth/tokenizer/symmetric/tokenizer_test.go b/auth/tokenizer/symmetric/tokenizer_test.go index 0703d8d36e..63b66c0aed 100644 --- a/auth/tokenizer/symmetric/tokenizer_test.go +++ b/auth/tokenizer/symmetric/tokenizer_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/auth/mocks" "github.com/absmach/supermq/auth/tokenizer/symmetric" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" @@ -17,6 +18,7 @@ import ( ) func TestNewTokenizer(t *testing.T) { + cache := new(mocks.TokensCache) cases := []struct { name string algorithm string @@ -67,7 +69,7 @@ func TestNewTokenizer(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - km, err := symmetric.NewTokenizer(tc.algorithm, tc.secret) + km, err := symmetric.NewTokenizer(tc.algorithm, tc.secret, cache) if tc.expectErr { assert.Error(t, err) @@ -85,8 +87,9 @@ func TestNewTokenizer(t *testing.T) { func TestSign(t *testing.T) { secret := []byte("my-super-secret-key-for-testing") + cache := new(mocks.TokensCache) - km, err := symmetric.NewTokenizer("HS256", secret) + km, err := symmetric.NewTokenizer("HS256", secret, cache) require.NoError(t, err) cases := []struct { @@ -132,7 +135,7 @@ func TestSign(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - token, err := km.Issue(tc.key) + token, err := km.Issue(context.Background(), tc.key) assert.NoError(t, err) assert.NotEmpty(t, token) @@ -145,8 +148,9 @@ func TestSign(t *testing.T) { func TestVerify(t *testing.T) { secret := []byte("my-super-secret-key-for-testing") + cache := new(mocks.TokensCache) - km, err := symmetric.NewTokenizer("HS256", secret) + km, err := symmetric.NewTokenizer("HS256", secret, cache) require.NoError(t, err) validKey := auth.Key{ @@ -160,12 +164,12 @@ func TestVerify(t *testing.T) { Verified: true, } - validToken, err := km.Issue(validKey) + validToken, err := km.Issue(context.Background(), validKey) require.NoError(t, err, "Signing valid token should succeed") expiredKey := validKey expiredKey.ExpiresAt = time.Now().Add(-1 * time.Hour).UTC() - expiredToken, err := km.Issue(expiredKey) + expiredToken, err := km.Issue(context.Background(), expiredKey) require.NoError(t, err) wrongIssuerKey := validKey @@ -188,9 +192,9 @@ func TestVerify(t *testing.T) { require.NoError(t, err) wrongIssuerToken := string(wrongIssuerTokenBytes) - wrongSecretKM, err := symmetric.NewTokenizer("HS256", []byte("different-secret-key-here")) + wrongSecretKM, err := symmetric.NewTokenizer("HS256", []byte("different-secret-key-here"), cache) require.NoError(t, err) - wrongSecretToken, err := wrongSecretKM.Issue(validKey) + wrongSecretToken, err := wrongSecretKM.Issue(context.Background(), validKey) require.NoError(t, err) cases := []struct { @@ -251,8 +255,9 @@ func TestVerify(t *testing.T) { func TestPublicKeys(t *testing.T) { secret := []byte("my-super-secret-key-for-testing") + cache := new(mocks.TokensCache) - km, err := symmetric.NewTokenizer("HS256", secret) + km, err := symmetric.NewTokenizer("HS256", secret, cache) require.NoError(t, err) keys, err := km.RetrieveJWKS() @@ -263,12 +268,13 @@ func TestPublicKeys(t *testing.T) { func TestSignAndVerifyRoundTrip(t *testing.T) { algorithms := []string{"HS256", "HS384", "HS512"} + cache := new(mocks.TokensCache) for _, alg := range algorithms { t.Run(alg, func(t *testing.T) { secret := []byte("my-super-secret-key-for-testing-" + alg) - km, err := symmetric.NewTokenizer(alg, secret) + km, err := symmetric.NewTokenizer(alg, secret, cache) require.NoError(t, err) originalKey := auth.Key{ @@ -282,7 +288,7 @@ func TestSignAndVerifyRoundTrip(t *testing.T) { Verified: true, } - token, err := km.Issue(originalKey) + token, err := km.Issue(context.Background(), originalKey) require.NoError(t, err) verifiedKey, err := km.Parse(context.Background(), token) @@ -301,6 +307,7 @@ func TestSignAndVerifyRoundTrip(t *testing.T) { func TestDifferentAlgorithms(t *testing.T) { secret := []byte("my-super-secret-key-for-testing-algorithms") + cache := new(mocks.TokensCache) key := auth.Key{ ID: "key-id", @@ -313,19 +320,19 @@ func TestDifferentAlgorithms(t *testing.T) { Verified: true, } - km256, err := symmetric.NewTokenizer("HS256", secret) + km256, err := symmetric.NewTokenizer("HS256", secret, cache) require.NoError(t, err) - token256, err := km256.Issue(key) + token256, err := km256.Issue(context.Background(), key) require.NoError(t, err) - km384, err := symmetric.NewTokenizer("HS384", secret) + km384, err := symmetric.NewTokenizer("HS384", secret, cache) require.NoError(t, err) - token384, err := km384.Issue(key) + token384, err := km384.Issue(context.Background(), key) require.NoError(t, err) - km512, err := symmetric.NewTokenizer("HS512", secret) + km512, err := symmetric.NewTokenizer("HS512", secret, cache) require.NoError(t, err) - token512, err := km512.Issue(key) + token512, err := km512.Issue(context.Background(), key) require.NoError(t, err) assert.NotEqual(t, token256, token384) diff --git a/cmd/auth/main.go b/cmd/auth/main.go index d2e36d71c8..f3b472fdd2 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -165,17 +165,19 @@ func main() { return } + tokensCache := cache.NewTokensCache(cacheclient, cfg.CacheKeyDuration) + var tokenizer auth.Tokenizer switch { case isSymmetric: - tokenizer, err = symmetric.NewTokenizer(cfg.KeyAlgorithm, []byte(cfg.SecretKey)) + tokenizer, err = symmetric.NewTokenizer(cfg.KeyAlgorithm, []byte(cfg.SecretKey), tokensCache) if err != nil { logger.Error(fmt.Sprintf("failed to create symmetric key manager: %s", err.Error())) exitCode = 1 return } default: - tokenizer, err = asymmetric.NewTokenizer(cfg.ActiveKeyPath, cfg.RetiringKeyPath, idProvider, logger) + tokenizer, err = asymmetric.NewTokenizer(cfg.ActiveKeyPath, cfg.RetiringKeyPath, idProvider, tokensCache, logger) if err != nil { logger.Error(fmt.Sprintf("failed to create asymmetric key manager: %s", err.Error())) exitCode = 1 @@ -292,11 +294,11 @@ func validateKeyConfig(isSymmetric bool, cfg config, l *slog.Logger) error { } func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, tokenizer auth.Tokenizer, idProvider supermq.IDProvider) (auth.Service, error) { - cache := cache.NewPatsCache(cacheClient, keyDuration) + patsCache := cache.NewPatsCache(cacheClient, keyDuration) database := pgclient.NewDatabase(db, dbConfig, tracer) keysRepo := apostgres.New(database) - patsRepo := apostgres.NewPatRepo(database, cache) + patsRepo := apostgres.NewPatRepo(database, patsCache) hasher := hasher.New() pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger) diff --git a/http/handler.go b/http/handler.go index 5f510b65cd..93871f0444 100644 --- a/http/handler.go +++ b/http/handler.go @@ -110,7 +110,6 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt clientID, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Publish, topicType) if err != nil { - fmt.Println("AuthPublish authAccess error:", err) return err } diff --git a/internal/proto/token/v1/token.proto b/internal/proto/token/v1/token.proto index 10c066511f..de8e61129e 100644 --- a/internal/proto/token/v1/token.proto +++ b/internal/proto/token/v1/token.proto @@ -9,6 +9,7 @@ option go_package = "github.com/absmach/supermq/api/grpc/token/v1"; service TokenService { rpc Issue(IssueReq) returns (Token) {} rpc Refresh(RefreshReq) returns (Token) {} + rpc Revoke(RevokeReq) returns (RevokeRes) {} } message IssueReq { @@ -23,6 +24,10 @@ message RefreshReq { bool verified = 2; } +message RevokeReq { + string token = 1; +} + // If a token is not carrying any information itself, the type // field can be used to determine how to validate the token. // Also, different tokens can be encoded in different ways. @@ -31,3 +36,7 @@ message Token { optional string refresh_token = 2; string access_type = 3; } + +message RevokeRes{ + +} \ No newline at end of file diff --git a/tools/config/.mockery.yaml b/tools/config/.mockery.yaml index ad8877e490..16ff239c93 100644 --- a/tools/config/.mockery.yaml +++ b/tools/config/.mockery.yaml @@ -58,6 +58,7 @@ packages: Cache: Hasher: KeyRepository: + TokensCache: Tokenizer: PATS: PATSRepository: @@ -144,4 +145,3 @@ packages: github.com/absmach/supermq/notifications: interfaces: Notifier: - diff --git a/users/api/endpoints.go b/users/api/endpoints.go index 60b5910db4..f1429a1989 100644 --- a/users/api/endpoints.go +++ b/users/api/endpoints.go @@ -453,6 +453,27 @@ func refreshTokenEndpoint(svc users.Service) endpoint.Endpoint { } } +func revokeRefreshTokenEndpoint(svc users.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(tokenReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthentication + } + + err := svc.RevokeRefreshToken(ctx, session, req.RefreshToken) + if err != nil { + return nil, err + } + + return revokeRes{revoked: true}, nil + } +} + func enableEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(changeUserStatusReq) diff --git a/users/api/responses.go b/users/api/responses.go index f862912d83..f78ff63e79 100644 --- a/users/api/responses.go +++ b/users/api/responses.go @@ -80,6 +80,22 @@ func (res tokenRes) Empty() bool { return res.AccessToken == "" || res.RefreshToken == "" } +type revokeRes struct { + revoked bool +} + +func (res revokeRes) Code() int { + return http.StatusNoContent +} + +func (res revokeRes) Headers() map[string]string { + return map[string]string{} +} + +func (res revokeRes) Empty() bool { + return false +} + type sendVerificationRes struct{} func (res sendVerificationRes) Code() int { diff --git a/users/api/users.go b/users/api/users.go index 01f3ab6184..7fec132e3f 100644 --- a/users/api/users.go +++ b/users/api/users.go @@ -78,6 +78,12 @@ func usersHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokenClient api.EncodeResponse, opts..., ), "refresh_token").ServeHTTP) + r.Post("/tokens/revoke", otelhttp.NewHandler(kithttp.NewServer( + revokeRefreshTokenEndpoint(svc), + decodeRevokeRefreshToken, + api.EncodeResponse, + opts..., + ), "revoke_refresh_token").ServeHTTP) r.Patch("/{id}/email", otelhttp.NewHandler(kithttp.NewServer( updateEmailEndpoint(svc), decodeUpdateUserEmail, @@ -528,6 +534,12 @@ func decodeRefreshToken(_ context.Context, r *http.Request) (any, error) { return req, nil } +func decodeRevokeRefreshToken(_ context.Context, r *http.Request) (any, error) { + req := tokenReq{RefreshToken: apiutil.ExtractBearerToken(r)} + + return req, nil +} + func decodeCreateUserReq(_ context.Context, r *http.Request) (any, error) { if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) diff --git a/users/events/events.go b/users/events/events.go index f298845350..895c1517a8 100644 --- a/users/events/events.go +++ b/users/events/events.go @@ -29,11 +29,10 @@ const ( profileView = userPrefix + "view_profile" userList = userPrefix + "list" userSearch = userPrefix + "search" - userListByGroup = userPrefix + "list_by_group" userIdentify = userPrefix + "identify" - generateResetToken = userPrefix + "generate_reset_token" issueToken = userPrefix + "issue_token" refreshToken = userPrefix + "refresh_token" + revokeRefreshToken = userPrefix + "revoke_refresh_token" resetSecret = userPrefix + "reset_secret" sendPasswordReset = userPrefix + "send_password_reset" oauthCallback = userPrefix + "oauth_callback" @@ -480,6 +479,17 @@ func (rte refreshTokenEvent) Encode() (map[string]any, error) { }, nil } +type revokeRefreshTokenEvent struct { + requestID string +} + +func (rrte revokeRefreshTokenEvent) Encode() (map[string]any, error) { + return map[string]any{ + "operation": revokeRefreshToken, + "request_id": rrte.requestID, + }, nil +} + type resetSecretEvent struct { requestID string } diff --git a/users/events/streams.go b/users/events/streams.go index ad3c5e376b..921b73dabe 100644 --- a/users/events/streams.go +++ b/users/events/streams.go @@ -15,33 +15,32 @@ import ( ) const ( - supermqPrefix = "supermq." - createStream = supermqPrefix + userCreate - sendVerificationStream = supermqPrefix + userSendVerification - verifyEmailStream = supermqPrefix + userVerifyEmail - updateStream = supermqPrefix + userUpdate - updateRoleStream = supermqPrefix + userUpdateRole - updateTagsStream = supermqPrefix + userUpdateTags - updateSecretStream = supermqPrefix + userUpdateSecret - updateUsernameStream = supermqPrefix + userUpdateUsername - updatePictureStream = supermqPrefix + userUpdateProfilePicture - UpdateEmailStream = supermqPrefix + userUpdateEmail - enableStream = supermqPrefix + userEnable - disableStream = supermqPrefix + userDisable - viewStream = supermqPrefix + userView - viewProfileStream = supermqPrefix + profileView - listStream = supermqPrefix + userList - searchStream = supermqPrefix + userSearch - listByGroupStream = supermqPrefix + userListByGroup - identifyStream = supermqPrefix + userIdentify - resetTokenStream = supermqPrefix + generateResetToken - issueTokenStream = supermqPrefix + issueToken - refreshTokenStream = supermqPrefix + refreshToken - resetSecretStream = supermqPrefix + resetSecret - sendPasswordResetStream = supermqPrefix + sendPasswordReset - oauthStream = supermqPrefix + oauthCallback - addPolicyStream = supermqPrefix + addClientPolicy - deleteStream = supermqPrefix + deleteUser + supermqPrefix = "supermq." + createStream = supermqPrefix + userCreate + sendVerificationStream = supermqPrefix + userSendVerification + verifyEmailStream = supermqPrefix + userVerifyEmail + updateStream = supermqPrefix + userUpdate + updateRoleStream = supermqPrefix + userUpdateRole + updateTagsStream = supermqPrefix + userUpdateTags + updateSecretStream = supermqPrefix + userUpdateSecret + updateUsernameStream = supermqPrefix + userUpdateUsername + updatePictureStream = supermqPrefix + userUpdateProfilePicture + UpdateEmailStream = supermqPrefix + userUpdateEmail + enableStream = supermqPrefix + userEnable + disableStream = supermqPrefix + userDisable + viewStream = supermqPrefix + userView + viewProfileStream = supermqPrefix + profileView + listStream = supermqPrefix + userList + searchStream = supermqPrefix + userSearch + identifyStream = supermqPrefix + userIdentify + issueTokenStream = supermqPrefix + issueToken + refreshTokenStream = supermqPrefix + refreshToken + revokeRefreshTokenStream = supermqPrefix + revokeRefreshToken + resetSecretStream = supermqPrefix + resetSecret + sendPasswordResetStream = supermqPrefix + sendPasswordReset + oauthStream = supermqPrefix + oauthCallback + addPolicyStream = supermqPrefix + addClientPolicy + deleteStream = supermqPrefix + deleteUser ) var _ users.Service = (*eventStore)(nil) @@ -387,6 +386,19 @@ func (es *eventStore) RefreshToken(ctx context.Context, session authn.Session, r return token, nil } +func (es *eventStore) RevokeRefreshToken(ctx context.Context, session authn.Session, refreshToken string) error { + err := es.svc.RevokeRefreshToken(ctx, session, refreshToken) + if err != nil { + return err + } + + event := revokeRefreshTokenEvent{ + requestID: middleware.GetReqID(ctx), + } + + return es.Publish(ctx, revokeRefreshTokenStream, event) +} + func (es *eventStore) ResetSecret(ctx context.Context, session authn.Session, secret string) error { if err := es.svc.ResetSecret(ctx, session, secret); err != nil { return err diff --git a/users/middleware/authorization.go b/users/middleware/authorization.go index 802f67eacb..109286fb35 100644 --- a/users/middleware/authorization.go +++ b/users/middleware/authorization.go @@ -169,6 +169,10 @@ func (am *authorizationMiddleware) RefreshToken(ctx context.Context, session aut return am.svc.RefreshToken(ctx, session, refreshToken) } +func (am *authorizationMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, refreshToken string) error { + return am.svc.RevokeRefreshToken(ctx, session, refreshToken) +} + func (am *authorizationMiddleware) OAuthCallback(ctx context.Context, user users.User) (users.User, error) { return am.svc.OAuthCallback(ctx, user) } diff --git a/users/middleware/logging.go b/users/middleware/logging.go index 8339cfe9ca..f5348eb741 100644 --- a/users/middleware/logging.go +++ b/users/middleware/logging.go @@ -129,6 +129,24 @@ func (lm *loggingMiddleware) RefreshToken(ctx context.Context, session authn.Ses return lm.svc.RefreshToken(ctx, session, refreshToken) } +// RevokeRefreshToken logs the revoke_refresh_token request. It logs the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, refreshToken string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Revoke refresh token failed", args...) + return + } + lm.logger.Info("Revoke refresh token completed successfully", args...) + }(time.Now()) + return lm.svc.RevokeRefreshToken(ctx, session, refreshToken) +} + // View logs the view_user request. It logs the user id and the time it took to complete the request. // If the request fails, it logs the error. func (lm *loggingMiddleware) View(ctx context.Context, session authn.Session, id string) (c users.User, err error) { diff --git a/users/middleware/metrics.go b/users/middleware/metrics.go index 1095e267d0..21be17601e 100644 --- a/users/middleware/metrics.go +++ b/users/middleware/metrics.go @@ -75,6 +75,15 @@ func (ms *metricsMiddleware) RefreshToken(ctx context.Context, session authn.Ses return ms.svc.RefreshToken(ctx, session, refreshToken) } +// RevokeRefreshToken instruments RevokeRefreshToken method with metrics. +func (ms *metricsMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, refreshToken string) error { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_refresh_token").Add(1) + ms.latency.With("method", "revoke_refresh_token").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.RevokeRefreshToken(ctx, session, refreshToken) +} + // View instruments View method with metrics. func (ms *metricsMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) { defer func(begin time.Time) { diff --git a/users/middleware/tracing.go b/users/middleware/tracing.go index 9ee39763e7..e5f1437d00 100644 --- a/users/middleware/tracing.go +++ b/users/middleware/tracing.go @@ -65,6 +65,14 @@ func (tm *tracingMiddleware) RefreshToken(ctx context.Context, session authn.Ses return tm.svc.RefreshToken(ctx, session, refreshToken) } +// RevokeRefreshToken traces the "RevokeRefreshToken" operation of the wrapped users.Service. +func (tm *tracingMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, refreshToken string) error { + ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_revoke_refresh_token", trace.WithAttributes(attribute.String("refresh_token", refreshToken))) + defer span.End() + + return tm.svc.RevokeRefreshToken(ctx, session, refreshToken) +} + // View traces the "View" operation of the wrapped users.Service. func (tm *tracingMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) { ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_view_user", trace.WithAttributes(attribute.String("id", id))) diff --git a/users/mocks/service.go b/users/mocks/service.go index 74a090668b..1ce29176f6 100644 --- a/users/mocks/service.go +++ b/users/mocks/service.go @@ -801,6 +801,69 @@ func (_c *Service_ResetSecret_Call) RunAndReturn(run func(ctx context.Context, s return _c } +// RevokeRefreshToken provides a mock function for the type Service +func (_mock *Service) RevokeRefreshToken(ctx context.Context, session authn.Session, refreshToken string) error { + ret := _mock.Called(ctx, session, refreshToken) + + if len(ret) == 0 { + panic("no return value specified for RevokeRefreshToken") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, refreshToken) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RevokeRefreshToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeRefreshToken' +type Service_RevokeRefreshToken_Call struct { + *mock.Call +} + +// RevokeRefreshToken is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - refreshToken string +func (_e *Service_Expecter) RevokeRefreshToken(ctx interface{}, session interface{}, refreshToken interface{}) *Service_RevokeRefreshToken_Call { + return &Service_RevokeRefreshToken_Call{Call: _e.mock.On("RevokeRefreshToken", ctx, session, refreshToken)} +} + +func (_c *Service_RevokeRefreshToken_Call) Run(run func(ctx context.Context, session authn.Session, refreshToken string)) *Service_RevokeRefreshToken_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RevokeRefreshToken_Call) Return(err error) *Service_RevokeRefreshToken_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RevokeRefreshToken_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, refreshToken string) error) *Service_RevokeRefreshToken_Call { + _c.Call.Return(run) + return _c +} + // SearchUsers provides a mock function for the type Service func (_mock *Service) SearchUsers(ctx context.Context, pm users.Page) (users.UsersPage, error) { ret := _mock.Called(ctx, pm) diff --git a/users/service.go b/users/service.go index 2dbcb57ea3..88859683eb 100644 --- a/users/service.go +++ b/users/service.go @@ -229,6 +229,22 @@ func (svc service) RefreshToken(ctx context.Context, session authn.Session, refr return token, nil } +func (svc service) RevokeRefreshToken(ctx context.Context, session authn.Session, refreshToken string) error { + dbUser, err := svc.users.RetrieveByID(ctx, session.UserID) + if err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + if dbUser.Status == DisabledStatus { + return errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser) + } + _, err = svc.token.Revoke(ctx, &grpcTokenV1.RevokeReq{Token: refreshToken}) + if err != nil { + return errors.Wrap(svcerr.ErrAuthorization, err) + } + + return nil +} + func (svc service) View(ctx context.Context, session authn.Session, id string) (User, error) { user, err := svc.users.RetrieveByID(ctx, id) if err != nil { diff --git a/users/users.go b/users/users.go index 4a269f4215..b6ca83fb75 100644 --- a/users/users.go +++ b/users/users.go @@ -237,6 +237,9 @@ type Service interface { // a new pair of access and refresh tokens. RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*grpcTokenV1.Token, error) + // RevokeRefreshToken revokes a refresh token. + RevokeRefreshToken(ctx context.Context, session authn.Session, refreshToken string) error + // OAuthCallback handles the callback from any supported OAuth provider. // It processes the OAuth tokens and either signs in or signs up the user based on the provided state. OAuthCallback(ctx context.Context, user User) (User, error)