36 changes: 29 additions & 7 deletions service/history/queues/dlq_writer.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync"

"go.temporal.io/server/common/log"
"go.temporal.io/server/common/log/tag"
@@ -20,6 +21,7 @@ type (
metricsHandler metrics.Handler
logger log.SnTaggedLogger
namespaceRegistry namespace.Registry
enqueueMutex sync.Map // map[persistence.QueueKey]*sync.Mutex for per-queue locking
}
// QueueWriter is a subset of persistence.HistoryTaskQueueManager.
QueueWriter interface {
@@ -77,13 +79,21 @@ func (q *DLQWriter) WriteTaskToDLQ(
}
}

resp, err := q.dlqWriter.EnqueueTask(ctx, &persistence.EnqueueTaskRequest{
QueueType: queueKey.QueueType,
SourceCluster: queueKey.SourceCluster,
TargetCluster: queueKey.TargetCluster,
Task: task,
SourceShardID: sourceShardID,
})
resp, err := func() (*persistence.EnqueueTaskResponse, error) {
// Acquire a process-level lock for this specific DLQ to prevent concurrent writes
// from multiple shards causing CAS conflicts in the persistence layer.
mu := q.getQueueMutex(queueKey)
mu.Lock()
defer mu.Unlock()

return q.dlqWriter.EnqueueTask(ctx, &persistence.EnqueueTaskRequest{
QueueType: queueKey.QueueType,
SourceCluster: queueKey.SourceCluster,
TargetCluster: queueKey.TargetCluster,
Task: task,
SourceShardID: sourceShardID,
})
}()
if err != nil {
return fmt.Errorf("%w: %v", ErrSendTaskToDLQ, err)
}
@@ -117,3 +127,15 @@ func (q *DLQWriter) WriteTaskToDLQ(
)
return nil
}

// getQueueMutex returns a per-queue mutex, creating it if it doesn't exist.
// This provides process-level locking to serialize concurrent writes to the same queue.
func (q *DLQWriter) getQueueMutex(queueKey persistence.QueueKey) *sync.Mutex {
if mu, ok := q.enqueueMutex.Load(queueKey); ok {
return mu.(*sync.Mutex) //nolint:revive
}

newMutex := &sync.Mutex{}
actual, _ := q.enqueueMutex.LoadOrStore(queueKey, newMutex)
return actual.(*sync.Mutex) //nolint:revive
}
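
The getQueueMutex helper above follows a common Go pattern for per-key locking: a sync.Map keyed by the queue identity, with LoadOrStore resolving the race when two goroutines create a mutex for the same key at once. A minimal standalone sketch of the same idea, using an illustrative queueKey type rather than persistence.QueueKey:

package main

import (
	"fmt"
	"sync"
)

// queueKey stands in for persistence.QueueKey; any comparable type works as a
// sync.Map key.
type queueKey struct {
	sourceCluster string
	targetCluster string
}

type lockRegistry struct {
	locks sync.Map // map[queueKey]*sync.Mutex
}

// get returns the mutex for key, creating it on first use. LoadOrStore makes
// the creation race safe: if two goroutines allocate a mutex concurrently,
// both end up using whichever one won the store.
func (r *lockRegistry) get(key queueKey) *sync.Mutex {
	if mu, ok := r.locks.Load(key); ok {
		return mu.(*sync.Mutex)
	}
	actual, _ := r.locks.LoadOrStore(key, &sync.Mutex{})
	return actual.(*sync.Mutex)
}

func main() {
	var registry lockRegistry
	key := queueKey{sourceCluster: "a", targetCluster: "b"}

	var wg sync.WaitGroup
	counter := 0
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			mu := registry.get(key)
			mu.Lock()
			defer mu.Unlock()
			counter++ // serialized because every goroutine shares the same key
		}()
	}
	wg.Wait()
	fmt.Println(counter) // always 10
}

The Load fast path avoids allocating a mutex on every call; only the first writer for a given key pays the LoadOrStore cost, which matches the shape of getQueueMutex above.
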
163 changes: 163 additions & 0 deletions service/history/queues/dlq_writer_test.go
@@ -3,7 +3,10 @@ package queues_test
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -13,11 +16,13 @@ import (
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/metrics/metricstest"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/service/history/queues"
"go.temporal.io/server/service/history/queues/queuestest"
"go.temporal.io/server/service/history/tasks"
"go.temporal.io/server/service/history/tests"
"go.uber.org/mock/gomock"
"golang.org/x/sync/errgroup"
)

type (
@@ -27,11 +32,14 @@ type (
}
logRecorder struct {
log.SnTaggedLogger
mu sync.Mutex
records []logRecord
}
)

func (l *logRecorder) Warn(msg string, tags ...tag.Tag) {
l.mu.Lock()
defer l.mu.Unlock()
l.records = append(l.records, logRecord{msg: msg, tags: tags})
}

@@ -128,3 +136,158 @@ func TestDLQWriter_Ok(t *testing.T) {
namespaceStateTag := metrics.NamespaceStateTag("active")
assert.Equal(t, "active", recordings[0].Tags[namespaceStateTag.Key])
}

func TestDLQWriter_ConcurrentWrites(t *testing.T) {
t.Parallel()

// This test verifies that the DLQ writer serializes concurrent writes to the same queue
// using a process-level lock, preventing CAS conflicts in the persistence layer.
queueWriter := &queuestest.FakeQueueWriter{}
ctrl := gomock.NewController(t)
namespaceRegistry := namespace.NewMockRegistry(ctrl)
namespaceRegistry.EXPECT().GetNamespaceByID(gomock.Any()).Return(&namespace.Namespace{}, nil).AnyTimes()
logger := &logRecorder{SnTaggedLogger: log.NewTestLogger()}
metricsHandler := metricstest.NewCaptureHandler()
writer := queues.NewDLQWriter(queueWriter, metricsHandler, logger, namespaceRegistry)

const numConcurrentWrites = 50
var g errgroup.Group
var concurrentAccessCount atomic.Int32
var maxConcurrentAccess atomic.Int32

// Create tasks that will write to the same DLQ (same category, source, target)
testTasks := make([]*tasks.WorkflowTask, numConcurrentWrites)
for i := 0; i < numConcurrentWrites; i++ {
testTasks[i] = &tasks.WorkflowTask{
WorkflowKey: definition.WorkflowKey{
NamespaceID: string(tests.NamespaceID),
WorkflowID: tests.WorkflowID,
RunID: tests.RunID,
},
}
}

// Wrap the queue writer to track concurrent access
queueWriter.EnqueueTaskFunc = func(ctx context.Context, request *persistence.EnqueueTaskRequest) (*persistence.EnqueueTaskResponse, error) {
// Increment concurrent access counter
current := concurrentAccessCount.Add(1)

// Track max concurrent access
for {
maxWrites := maxConcurrentAccess.Load()
if current <= maxWrites || maxConcurrentAccess.CompareAndSwap(maxWrites, current) {
break
}
}

// Simulate some work that could cause race conditions.
time.Sleep(10 * time.Millisecond) //nolint:forbidigo

// Decrement counter
concurrentAccessCount.Add(-1)

return &persistence.EnqueueTaskResponse{Metadata: persistence.MessageMetadata{ID: 0}}, nil
}

// Launch concurrent writes to the same DLQ
for i := 0; i < numConcurrentWrites; i++ {
task := testTasks[i]
g.Go(func() error {
err := writer.WriteTaskToDLQ(
context.Background(),
"source-cluster",
"target-cluster",
1, // same shard ID
task,
true,
)
require.NoError(t, err)
return nil
})
}

require.NoError(t, g.Wait())

// Verify all writes succeeded
require.Len(t, queueWriter.EnqueueTaskRequests, numConcurrentWrites)

// The key assertion: with the lock in place, we should never have more than 1 concurrent access
assert.Equal(t, int32(1), maxConcurrentAccess.Load(),
"Expected serialized access (max 1 concurrent), but got %d concurrent accesses. "+
"This indicates the lock is not working properly.", maxConcurrentAccess.Load())
}

func TestDLQWriter_ConcurrentWritesDifferentQueues(t *testing.T) {
t.Parallel()

// This test verifies that concurrent writes to DIFFERENT queues can proceed in parallel
queueWriter := &queuestest.FakeQueueWriter{}
ctrl := gomock.NewController(t)
namespaceRegistry := namespace.NewMockRegistry(ctrl)
namespaceRegistry.EXPECT().GetNamespaceByID(gomock.Any()).Return(&namespace.Namespace{}, nil).AnyTimes()
logger := &logRecorder{SnTaggedLogger: log.NewTestLogger()}
metricsHandler := metricstest.NewCaptureHandler()
writer := queues.NewDLQWriter(queueWriter, metricsHandler, logger, namespaceRegistry)

const numConcurrentWrites = 50
const numQueues = 5
var g errgroup.Group
var concurrentAccessCount atomic.Int32
var maxConcurrentAccess atomic.Int32

// Wrap the queue writer to track concurrent access
queueWriter.EnqueueTaskFunc = func(ctx context.Context, request *persistence.EnqueueTaskRequest) (*persistence.EnqueueTaskResponse, error) {
current := concurrentAccessCount.Add(1)

for {
maxValue := maxConcurrentAccess.Load()
if current <= maxValue || maxConcurrentAccess.CompareAndSwap(maxValue, current) {
break
}
}

// Simulate some work that could cause race conditions.
time.Sleep(10 * time.Millisecond) //nolint:forbidigo
concurrentAccessCount.Add(-1)

return &persistence.EnqueueTaskResponse{Metadata: persistence.MessageMetadata{ID: 0}}, nil
}

// Launch concurrent writes to DIFFERENT target clusters (different DLQs)
for i := 0; i < numConcurrentWrites; i++ {
index := i
g.Go(func() error {
task := &tasks.WorkflowTask{
WorkflowKey: definition.WorkflowKey{
NamespaceID: string(tests.NamespaceID),
WorkflowID: tests.WorkflowID,
RunID: tests.RunID,
},
}
// Use different target clusters to create different queue keys
targetCluster := "target-cluster-" + string(rune('A'+index%numQueues))
err := writer.WriteTaskToDLQ(
context.Background(),
"source-cluster",
targetCluster,
1,
task,
true,
)
require.NoError(t, err)
return nil
})
}

require.NoError(t, g.Wait())

// Verify all writes succeeded
require.Len(t, queueWriter.EnqueueTaskRequests, numConcurrentWrites)

// Since these are different queues, they should be able to execute concurrently
// We expect to see more than 1 concurrent access, but less than or equal to numQueues.
assert.Greater(t, maxConcurrentAccess.Load(), int32(1),
"Expected concurrent access to different queues (> 1), but got %d.", maxConcurrentAccess.Load())
assert.LessOrEqual(t, maxConcurrentAccess.Load(), int32(numQueues),
"Expected less than %d concurrent accesses, but got %d.", numQueues, maxConcurrentAccess.Load())
}
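
Both tests measure the peak number of in-flight EnqueueTask calls with a compare-and-swap loop over atomic counters: a goroutine bumps the in-flight count, then raises the recorded maximum unless another goroutine already stored a value at least as high. A small self-contained sketch of that pattern, with illustrative names:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

func main() {
	var inFlight, peak atomic.Int32
	var wg sync.WaitGroup

	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			current := inFlight.Add(1)
			// Raise peak to current unless another goroutine already recorded
			// a value at least as high; retry only if the CAS loses a race.
			for {
				observed := peak.Load()
				if current <= observed || peak.CompareAndSwap(observed, current) {
					break
				}
			}

			time.Sleep(5 * time.Millisecond) // simulated work while "in flight"
			inFlight.Add(-1)
		}()
	}
	wg.Wait()
	fmt.Println("peak concurrency:", peak.Load())
}

With no locking, the printed peak is usually 8; serialize the body behind a shared mutex and it drops to 1, which is exactly the distinction the two tests assert on.
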
17 changes: 16 additions & 1 deletion service/history/queues/queuestest/fake_queue_writer.go
@@ -2,25 +2,40 @@ package queuestest

import (
"context"
"sync"

"go.temporal.io/server/common/persistence"
"go.temporal.io/server/service/history/queues"
)

// EnqueueTaskFunc is a function type for custom EnqueueTask behavior in tests
type EnqueueTaskFunc func(context.Context, *persistence.EnqueueTaskRequest) (*persistence.EnqueueTaskResponse, error)

// FakeQueueWriter is a [queues.QueueWriter] which records the requests it receives and returns the provided errors.
type FakeQueueWriter struct {
mu sync.Mutex
EnqueueTaskRequests []*persistence.EnqueueTaskRequest
EnqueueTaskErr error
CreateQueueErr error
// EnqueueTaskFunc allows tests to provide custom behavior for EnqueueTask calls.
// If set, this function is called instead of the default behavior.
EnqueueTaskFunc EnqueueTaskFunc
}

var _ queues.QueueWriter = (*FakeQueueWriter)(nil)

func (d *FakeQueueWriter) EnqueueTask(
_ context.Context,
ctx context.Context,
request *persistence.EnqueueTaskRequest,
) (*persistence.EnqueueTaskResponse, error) {
// Protect the slice append from concurrent access
d.mu.Lock()
d.EnqueueTaskRequests = append(d.EnqueueTaskRequests, request)
d.mu.Unlock()

if d.EnqueueTaskFunc != nil {
return d.EnqueueTaskFunc(ctx, request)
}
return &persistence.EnqueueTaskResponse{Metadata: persistence.MessageMetadata{ID: 0}}, d.EnqueueTaskErr
}

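
The new EnqueueTaskFunc hook lets a test inject per-call behavior while the fake still records every request under its mutex. A hedged sketch of how a caller might use it to simulate a transient enqueue failure; the helper name newFlakyFakeWriter is hypothetical, while the types are the ones added in this diff:

package queuestest_test

import (
	"context"
	"errors"
	"sync/atomic"

	"go.temporal.io/server/common/persistence"
	"go.temporal.io/server/service/history/queues/queuestest"
)

// newFlakyFakeWriter is a hypothetical helper: it returns a FakeQueueWriter
// whose first EnqueueTask call fails and whose later calls succeed, which is
// one way a test could exercise retry or error-handling paths.
func newFlakyFakeWriter() *queuestest.FakeQueueWriter {
	writer := &queuestest.FakeQueueWriter{}
	var calls atomic.Int32
	writer.EnqueueTaskFunc = func(
		_ context.Context,
		_ *persistence.EnqueueTaskRequest,
	) (*persistence.EnqueueTaskResponse, error) {
		if calls.Add(1) == 1 {
			return nil, errors.New("simulated enqueue failure")
		}
		return &persistence.EnqueueTaskResponse{
			Metadata: persistence.MessageMetadata{ID: 0},
		}, nil
	}
	return writer
}
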