Skip to content

Commit 153d669

Browse files
committed
feat: implement KMS key revocation and coordinated rotation
Adds KMS key state monitoring and coordinated key rotation to prevent message queue consumption failures during encryption key updates. Key Changes: - Add KeyManager in RootCoord for periodic KMS state polling - Integrate KeyManager with QuotaCenter for access denial - Implement revocation checks in Proxy SimpleLimiter - Add rotation callback coordination via AlterDatabase broadcast - Drop internal properties before metadata persistence - Add GetStates() and InvalidateCipherCache() to hookutil Access Denial: - Revoked keys: Release collections + deny DML/DQL (DDL still allowed) - Check performed on every request at proxy layer - Manual LoadCollection required after key recovery Key Rotation Flow: 1. CipherPlugin rotates key, writes to etcd 2. Plugin invokes onKeyRotated callback 3. KeyManager broadcasts AlterDatabase with internal property 4. StreamingNode receives message and reloads cipher 5. ACK callback invalidates Proxy db cache and refreshes key See also: #45117, #45981, #45242 Signed-off-by: yangxuan <[email protected]>
1 parent 4f080bd commit 153d669

File tree

18 files changed

+546
-60
lines changed

18 files changed

+546
-60
lines changed

internal/coordinator/mix_coord.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ func (s *mixCoordImpl) initInternal() error {
215215
log.Error("queryCoord start failed", zap.Error(err))
216216
return err
217217
}
218+
218219
return nil
219220
}
220221

internal/flushcommon/pipeline/flow_graph_dd_node.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,19 @@ func (ddn *ddNode) Operate(in []Msg) []Msg {
297297
} else {
298298
logger.Info("handle put collection message success")
299299
}
300+
case commonpb.MsgType_AlterDatabase:
301+
alterDatabaseMsg := msg.(*adaptor.AlterDatabaseMessageBody)
302+
logger := log.With(
303+
zap.String("vchannel", ddn.Name()),
304+
zap.Int32("msgType", int32(msg.Type())),
305+
zap.Uint64("timetick", alterDatabaseMsg.AlterDatabaseMessage.TimeTick()),
306+
)
307+
logger.Info("receive alter database message")
308+
if err := ddn.msgHandler.HandleAlterDatabase(ddn.ctx, alterDatabaseMsg.AlterDatabaseMessage); err != nil {
309+
logger.Warn("handle alter database message failed", zap.Error(err))
310+
} else {
311+
logger.Info("handle alter database message success")
312+
}
300313
}
301314
}
302315

internal/flushcommon/util/msg_handler.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ type MsgHandler interface {
3434
HandleSchemaChange(ctx context.Context, schemaChangeMsg message.ImmutableSchemaChangeMessageV2) error
3535

3636
HandleAlterCollection(ctx context.Context, alterCollectionMsg message.ImmutableAlterCollectionMessageV2) error
37+
38+
HandleAlterDatabase(ctx context.Context, alterDatabaseMsg message.ImmutableAlterDatabaseMessageV2) error
3739
}
3840

3941
func ConvertInternalImportFile(file *msgpb.ImportFile, _ int) *internalpb.ImportFile {

internal/proxy/impl.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,13 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
169169
node.shardMgr.RemoveDatabase(request.GetDbName())
170170
fallthrough
171171
case commonpb.MsgType_AlterDatabase:
172-
globalMetaCache.RemoveDatabase(ctx, request.GetDbName())
172+
if db, err := globalMetaCache.GetDatabaseInfo(ctx, request.GetDbName()); err == nil {
173+
if db != nil {
174+
err := hookutil.RefreshEZ(db.properties)
175+
log.Info("failed to refresh ez hook", zap.Error(err))
176+
}
177+
}
178+
globalMetaCache.RemoveDatabase(ctx, dbName)
173179
case commonpb.MsgType_AlterCollection, commonpb.MsgType_AlterCollectionField:
174180
if request.CollectionID != UniqueID(0) {
175181
aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID, 0, false)

internal/proxy/simple_rate_limiter.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import (
3434
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
3535
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
3636
"github.com/milvus-io/milvus/pkg/v2/util"
37+
"github.com/milvus-io/milvus/pkg/v2/util/merr"
3738
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
3839
"github.com/milvus-io/milvus/pkg/v2/util/ratelimitutil"
3940
"github.com/milvus-io/milvus/pkg/v2/util/retry"
@@ -48,6 +49,9 @@ type SimpleLimiter struct {
4849
// for alloc
4950
allocWaitInterval time.Duration
5051
allocRetryTimes uint
52+
53+
// for KMS key revocation
54+
revokedDatabases sync.Map // map[int64]string (dbID → reason)
5155
}
5256

5357
// NewSimpleLimiter returns a new SimpleLimiter.
@@ -73,6 +77,12 @@ func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int6
7377
return nil
7478
}
7579

80+
// Check for KMS key revocation (highest priority)
81+
if dbID != util.InvalidDBID && m.isDatabaseRevoked(dbID) {
82+
reason := m.getRevokedReason(dbID)
83+
return merr.WrapErrKMSKeyRevoked(dbID, reason)
84+
}
85+
7686
m.quotaStatesMu.RLock()
7787
defer m.quotaStatesMu.RUnlock()
7888

@@ -378,3 +388,32 @@ func IsDDLRequest(rt internalpb.RateType) bool {
378388
return false
379389
}
380390
}
391+
392+
// MarkDatabaseRevoked marks a database as revoked due to KMS key issues
393+
func (m *SimpleLimiter) MarkDatabaseRevoked(dbID int64, reason string) {
394+
m.revokedDatabases.Store(dbID, reason)
395+
log.Info("database marked as revoked",
396+
zap.Int64("dbID", dbID),
397+
zap.String("reason", reason))
398+
}
399+
400+
// UnmarkDatabaseRevoked removes revocation mark from a database
401+
func (m *SimpleLimiter) UnmarkDatabaseRevoked(dbID int64) {
402+
m.revokedDatabases.Delete(dbID)
403+
log.Info("database revocation mark removed",
404+
zap.Int64("dbID", dbID))
405+
}
406+
407+
// isDatabaseRevoked checks if a database is currently revoked
408+
func (m *SimpleLimiter) isDatabaseRevoked(dbID int64) bool {
409+
_, ok := m.revokedDatabases.Load(dbID)
410+
return ok
411+
}
412+
413+
// getRevokedReason returns the revocation reason for a database
414+
func (m *SimpleLimiter) getRevokedReason(dbID int64) string {
415+
if reason, ok := m.revokedDatabases.Load(dbID); ok {
416+
return reason.(string)
417+
}
418+
return "unknown reason"
419+
}

internal/rootcoord/ddl_callbacks_alter_database.go

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,17 @@ func (c *Core) broadcastAlterDatabase(ctx context.Context, req *rootcoordpb.Alte
9191
}
9292
}
9393

94+
// Pop the internal properties before persist
95+
newProperties = dropInternalProperties(newProperties)
96+
97+
// Get all vchannels for collections in this database to broadcast the message
98+
broadcastChannels, err := c.getVChannelsForDatabase(ctx, req.GetDbName())
99+
if err != nil {
100+
return errors.Wrap(err, "failed to get vchannels for database")
101+
}
102+
// Always include control channel for ordering and callback coordination
103+
broadcastChannels = append(broadcastChannels, streaming.WAL().ControlChannel())
104+
94105
msg := message.NewAlterDatabaseMessageBuilderV2().
95106
WithHeader(&message.AlterDatabaseMessageHeader{
96107
DbName: req.GetDbName(),
@@ -100,12 +111,36 @@ func (c *Core) broadcastAlterDatabase(ctx context.Context, req *rootcoordpb.Alte
100111
Properties: newProperties,
101112
AlterLoadConfig: alterLoadConfig,
102113
}).
103-
WithBroadcast([]string{streaming.WAL().ControlChannel()}).
114+
WithBroadcast(broadcastChannels).
104115
MustBuildBroadcast()
105116
_, err = broadcaster.Broadcast(ctx, msg)
106117
return err
107118
}
108119

120+
func dropInternalProperties(props []*commonpb.KeyValuePair) []*commonpb.KeyValuePair {
121+
var newProps []*commonpb.KeyValuePair
122+
for _, prop := range props {
123+
if !common.IsInternalPropertyKey(prop) {
124+
newProps = append(newProps, prop)
125+
}
126+
}
127+
return newProps
128+
}
129+
130+
// getVChannelsForDatabase gets all virtual channels for collections in the database.
131+
func (c *Core) getVChannelsForDatabase(ctx context.Context, dbName string) ([]string, error) {
132+
colls, err := c.meta.ListCollections(ctx, dbName, typeutil.MaxTimestamp, true)
133+
if err != nil {
134+
return nil, err
135+
}
136+
137+
vchannels := make([]string, 0)
138+
for _, coll := range colls {
139+
vchannels = append(vchannels, coll.VirtualChannelNames...)
140+
}
141+
return vchannels, nil
142+
}
143+
109144
// getAlterLoadConfigOfAlterDatabase gets the alter load config of alter database.
110145
func (c *Core) getAlterLoadConfigOfAlterDatabase(ctx context.Context, dbName string, oldProps []*commonpb.KeyValuePair, newProps []*commonpb.KeyValuePair) (*message.AlterLoadConfigOfAlterDatabase, error) {
111146
oldReplicaNumber, _ := common.DatabaseLevelReplicaNumber(oldProps)
@@ -137,7 +172,7 @@ func (c *DDLCallback) alterDatabaseV1AckCallback(ctx context.Context, result mes
137172
header := result.Message.Header()
138173
body := result.Message.MustBody()
139174

140-
db := model.NewDatabase(header.DbId, header.DbName, etcdpb.DatabaseState_DatabaseCreated, result.Message.MustBody().Properties)
175+
db := model.NewDatabase(header.DbId, header.DbName, etcdpb.DatabaseState_DatabaseCreated, body.Properties)
141176
if err := c.meta.AlterDatabase(ctx, db, result.GetControlChannelResult().TimeTick); err != nil {
142177
return errors.Wrap(err, "failed to alter database")
143178
}

internal/rootcoord/key_manager.go

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
// Licensed to the LF AI & Data foundation under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
package rootcoord
18+
19+
import (
20+
"context"
21+
"errors"
22+
"fmt"
23+
24+
"go.uber.org/zap"
25+
26+
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
27+
"github.com/milvus-io/milvus/internal/metastore/model"
28+
"github.com/milvus-io/milvus/internal/types"
29+
"github.com/milvus-io/milvus/internal/util/hookutil"
30+
"github.com/milvus-io/milvus/pkg/v2/common"
31+
"github.com/milvus-io/milvus/pkg/v2/log"
32+
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
33+
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
34+
"github.com/milvus-io/milvus/pkg/v2/util"
35+
"github.com/milvus-io/milvus/pkg/v2/util/merr"
36+
"github.com/samber/lo"
37+
)
38+
39+
type KeyManager struct {
40+
ctx context.Context
41+
meta IMetaTable
42+
mixCoord types.MixCoord
43+
enabled bool
44+
}
45+
46+
func NewKeyManager(
47+
ctx context.Context,
48+
meta IMetaTable,
49+
mixCoord types.MixCoord,
50+
) *KeyManager {
51+
return &KeyManager{
52+
ctx: ctx,
53+
meta: meta,
54+
mixCoord: mixCoord,
55+
enabled: hookutil.GetCipherWithState() != nil,
56+
}
57+
}
58+
59+
func (km *KeyManager) Init() error {
60+
if !km.enabled {
61+
log.Info("KeyManager disabled (cipher plugin not loaded)")
62+
return nil
63+
}
64+
65+
hookutil.GetCipherWithState().RegisterRotationCallback(km.onKeyRotated)
66+
log.Info("KeyManager initialized")
67+
return nil
68+
}
69+
70+
func (km *KeyManager) GetDatabaseEzStates() ([]int64, error) {
71+
if !km.enabled {
72+
return nil, nil
73+
}
74+
75+
currentStates, err := hookutil.GetEzStates()
76+
if err != nil {
77+
return nil, fmt.Errorf("failed to get cipher states: %w", err)
78+
}
79+
80+
revokedDBs := make(map[int64]struct{})
81+
for ezID, currentState := range currentStates {
82+
switch currentState {
83+
case hookutil.KeyStateDisabled, hookutil.KeyStatePendingDeletion:
84+
db, err := km.getDatabaseByEzID(ezID)
85+
if err != nil {
86+
log.Warn("KeyManager: failed to get database for ezID", zap.Int64("ezID", ezID), zap.Error(err))
87+
continue
88+
}
89+
90+
revokedDBs[db.ID] = struct{}{}
91+
}
92+
}
93+
94+
revokedDBIDs := lo.Keys(revokedDBs)
95+
if err := km.releaseLoadedCollections(revokedDBIDs); err != nil {
96+
log.Warn("KeyManager: failed to release collections for revoked databases", zap.Error(err))
97+
}
98+
99+
return revokedDBIDs, nil
100+
}
101+
102+
func (km *KeyManager) releaseLoadedCollections(revokedDBIDs []int64) error {
103+
if len(revokedDBIDs) == 0 {
104+
return nil
105+
}
106+
107+
collectionIDsToCheck := make([]int64, 0)
108+
109+
for _, dbID := range revokedDBIDs {
110+
db, err := km.meta.GetDatabaseByID(km.ctx, dbID, 0)
111+
if err != nil {
112+
log.Warn("KeyManager: failed to get database metadata", zap.Int64("dbID", dbID), zap.Error(err))
113+
continue
114+
}
115+
116+
colls, err := km.meta.ListCollections(km.ctx, db.Name, 0, true)
117+
if err != nil {
118+
log.Warn("KeyManager: failed to list collections for revoked database",
119+
zap.Int64("dbID", dbID),
120+
zap.String("dbName", db.Name),
121+
zap.Error(err))
122+
continue
123+
}
124+
125+
for _, coll := range colls {
126+
collectionIDsToCheck = append(collectionIDsToCheck, coll.CollectionID)
127+
}
128+
}
129+
130+
if len(collectionIDsToCheck) == 0 {
131+
return nil
132+
}
133+
134+
resp, err := km.mixCoord.ShowLoadCollections(km.ctx, &querypb.ShowCollectionsRequest{
135+
CollectionIDs: collectionIDsToCheck,
136+
})
137+
if err := merr.CheckRPCCall(resp.GetStatus(), err); err != nil {
138+
if errors.Is(err, merr.ErrCollectionNotLoaded) {
139+
return nil
140+
}
141+
return fmt.Errorf("failed to get loaded collections: %w", err)
142+
}
143+
144+
log.Info("KeyManager: releasing loaded collection for revoked database",
145+
zap.Int64s("collectionIDs", resp.GetCollectionIDs()),
146+
)
147+
for _, collID := range resp.GetCollectionIDs() {
148+
req := &querypb.ReleaseCollectionRequest{CollectionID: collID}
149+
if _, err := km.mixCoord.ReleaseCollection(km.ctx, req); err != nil {
150+
log.Warn("KeyManager: failed to release collection", zap.Int64("collectionID", collID), zap.Error(err))
151+
continue
152+
}
153+
}
154+
return nil
155+
}
156+
157+
func (km *KeyManager) getDatabaseByEzID(ezID int64) (*model.Database, error) {
158+
db, err := km.meta.GetDatabaseByID(km.ctx, ezID, 0)
159+
if err != nil {
160+
return km.meta.GetDatabaseByID(km.ctx, util.DefaultDBID, 0)
161+
}
162+
return db, nil
163+
}
164+
165+
func (km *KeyManager) onKeyRotated(ezID int64) error {
166+
if !km.enabled {
167+
return nil
168+
}
169+
170+
log.Info("KeyManager: handling key rotation", zap.Int64("ezID", ezID))
171+
db, err := km.getDatabaseByEzID(ezID)
172+
if err != nil {
173+
return err
174+
}
175+
176+
req := &rootcoordpb.AlterDatabaseRequest{
177+
DbName: db.Name,
178+
Properties: []*commonpb.KeyValuePair{
179+
{
180+
Key: common.InternalCipherKeyRotatedKey,
181+
},
182+
},
183+
}
184+
185+
status, err := km.mixCoord.AlterDatabase(km.ctx, req)
186+
if err := merr.CheckRPCCall(status, err); err != nil {
187+
log.Error("KeyManager: failed to broadcast key rotation",
188+
zap.Int64("dbID", db.ID),
189+
zap.Int64("ezID", ezID),
190+
zap.String("dbName", db.Name),
191+
zap.Error(err))
192+
return fmt.Errorf("failed to broadcast key rotation: %w", err)
193+
}
194+
195+
log.Info("KeyManager: key rotation handled", zap.Int64("ezID", ezID))
196+
return nil
197+
}

0 commit comments

Comments
 (0)