diff --git a/txnkv/transaction/txn.go b/txnkv/transaction/txn.go index 3237ddbcaa..61e88a5d29 100644 --- a/txnkv/transaction/txn.go +++ b/txnkv/transaction/txn.go @@ -834,6 +834,17 @@ func (txn *KVTxn) filterAggressiveLockedKeys(lockCtx *tikv.LockCtx, allKeys [][] // LockKeys tries to lock the entries with the keys in KV store. // lockCtx is the context for lock, lockCtx.lockWaitTime in ms func (txn *KVTxn) LockKeys(ctx context.Context, lockCtx *tikv.LockCtx, keysInput ...[]byte) error { + return txn.lockKeys(ctx, lockCtx, nil, keysInput...) +} + +// LockKeysFunc tries to lock the entries with the keys in KV store. +// lockCtx is the context for lock, lockCtx.lockWaitTime in ms +// fn is a function which run before the lock is released. +func (txn *KVTxn) LockKeysFunc(ctx context.Context, lockCtx *tikv.LockCtx, fn func(), keysInput ...[]byte) error { + return txn.lockKeys(ctx, lockCtx, fn, keysInput...) +} + +func (txn *KVTxn) lockKeys(ctx context.Context, lockCtx *tikv.LockCtx, fn func(), keysInput ...[]byte) error { if txn.interceptor != nil { // User has called txn.SetRPCInterceptor() to explicitly set an interceptor, we // need to bind it to ctx so that the internal client can perceive and execute @@ -871,6 +882,11 @@ func (txn *KVTxn) LockKeys(ctx context.Context, lockCtx *tikv.LockCtx, keysInput } } }() + defer func() { + if fn != nil { + fn() + } + }() if !txn.IsPessimistic() && txn.aggressiveLockingContext != nil { return errors.New("trying to perform aggressive locking in optimistic transaction")