Skip to content

Commit 4dcfe71

Browse files
committed
Add N on the fly
1 parent 3dbdd3d commit 4dcfe71

File tree

1 file changed

+70
-62
lines changed

1 file changed

+70
-62
lines changed

common/testing/testvars/test_vars.go

+70-62
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ type (
4747
testName string
4848
testHash uint32
4949
an Any
50-
kv sync.Map
50+
values sync.Map
51+
ns sync.Map
5152
}
5253
testNamer interface {
5354
Name() string
@@ -67,10 +68,26 @@ func newFromName(testName string) *TestVars {
6768
}
6869
}
6970

70-
func getOrCreate[T any](tv *TestVars, key string, initialValGen func(key string) T) T {
71-
v, _ := tv.kv.LoadOrStore(key, initialValGen(key))
71+
func getOrCreate[T any](tv *TestVars, key string, initialValGen func(key string) T, valNSetter func(val T, n int) T) T {
72+
v, _ := tv.values.LoadOrStore(key, initialValGen(key))
73+
74+
n, ok := tv.ns.Load(key)
75+
if !ok {
76+
//revive:disable-next-line:unchecked-type-assertion
77+
return v.(T)
78+
}
79+
7280
//revive:disable-next-line:unchecked-type-assertion
73-
return v.(T)
81+
return valNSetter(v.(T), n.(int))
82+
}
83+
84+
func (tv *TestVars) stringNSetter(v string, n int) string {
85+
return fmt.Sprintf("%s_%d", v, n)
86+
}
87+
88+
// Use this setter for entities that don't support setting n (like uuid).
89+
func unsupportedNSetter[T any](v T, _ int) T {
90+
panic(fmt.Sprintf("setting n on type %T is not supported", v))
7491
}
7592

7693
func (tv *TestVars) uniqueString(key string) string {
@@ -83,75 +100,66 @@ func (tv *TestVars) uuidString(_ string) string {
83100

84101
func (tv *TestVars) clone() *TestVars {
85102
tv2 := newFromName(tv.testName)
86-
tv.kv.Range(func(key, value any) bool {
87-
tv2.kv.Store(key, value)
103+
tv.values.Range(func(key, value any) bool {
104+
tv2.values.Store(key, value)
105+
return true
106+
})
107+
tv.ns.Range(func(key, value any) bool {
108+
tv2.ns.Store(key, value)
88109
return true
89110
})
90111

91112
return tv2
92113
}
93114
func (tv *TestVars) cloneSetVal(key string, val any) *TestVars {
94115
tv2 := tv.clone()
95-
tv2.kv.Store(key, val)
116+
tv2.values.Store(key, val)
96117
return tv2
97118
}
98119

99-
func (tv *TestVars) cloneSetN(key string, defaultValGen func(key string) string, n int) *TestVars {
120+
func (tv *TestVars) cloneSetN(key string, n int) *TestVars {
100121
tv2 := tv.clone()
101-
102-
v, ok := tv.kv.Load(key)
103-
if !ok {
104-
v = defaultValGen(key)
105-
}
106-
107-
vStr := valString(key, v)
108-
tv2.kv.Store(key, fmt.Sprintf("%s_%d", vStr, n))
109-
122+
tv2.ns.Store(key, n)
110123
return tv2
111124
}
112125

113-
func valString(key string, v any) string {
114-
vString, vIsString := v.(string)
115-
if !vIsString {
116-
vStringer, vIsStringer := v.(fmt.Stringer)
117-
if !vIsStringer {
118-
panic(fmt.Sprintf("value of key %s is of type %T but must be of type %T or implement fmt.Stringer", key, v, ""))
119-
}
120-
vString = vStringer.String()
121-
}
122-
return vString
123-
}
124-
125126
// ----------- Methods for every entity ------------
126127
// Add more as you need them following the pattern below.
127128
// Replace "Entity" with the name of the entity, i.e., UpdateID, ActivityType, etc.
128-
// Add only the necessary methods (in most cases only getter).
129+
// Add only the necessary methods (in most cases only first getter).
129130
/*
130131
func (tv *TestVars) Entity() string {
131-
return getOrCreate(tv, "entity", tv.uniqueString)
132+
return getOrCreate(tv, "entity", tv.uniqueString, tv.stringNSetter)
132133
}
133134
func (tv *TestVars) WithEntity(entity string) *TestVars {
134135
return tv.cloneSetVal("entity", entity)
135136
}
136137
func (tv *TestVars) WithEntityN(n int) *TestVars {
137-
return tv.cloneSetN("entity", tv.uniqueString, n)
138+
return tv.cloneSetN("entity", n)
138139
}
139140
*/
140141

141142
func (tv *TestVars) NamespaceID() namespace.ID {
142143
return getOrCreate(tv, "namespace_id", func(key string) namespace.ID {
143144
return namespace.ID(tv.uuidString(key))
144-
})
145+
},
146+
unsupportedNSetter,
147+
)
145148
}
146149

147150
func (tv *TestVars) WithNamespaceID(namespaceID namespace.ID) *TestVars {
148151
return tv.cloneSetVal("namespace_id", namespaceID)
149152
}
150153

151154
func (tv *TestVars) NamespaceName() namespace.Name {
152-
return getOrCreate(tv, "namespace_name", func(key string) namespace.Name {
153-
return namespace.Name(tv.uniqueString(key))
154-
})
155+
return getOrCreate(tv, "namespace_name",
156+
func(key string) namespace.Name {
157+
return namespace.Name(tv.uniqueString(key))
158+
},
159+
func(val namespace.Name, n int) namespace.Name {
160+
return namespace.Name(tv.stringNSetter(val.String(), n))
161+
},
162+
)
155163
}
156164

157165
func (tv *TestVars) WithNamespaceName(namespaceName namespace.Name) *TestVars {
@@ -177,15 +185,15 @@ func (tv *TestVars) Namespace() *namespace.Namespace {
177185
}
178186

179187
func (tv *TestVars) WorkflowID() string {
180-
return getOrCreate(tv, "workflow_id", tv.uniqueString)
188+
return getOrCreate(tv, "workflow_id", tv.uniqueString, tv.stringNSetter)
181189
}
182190

183191
func (tv *TestVars) WithWorkflowIDN(n int) *TestVars {
184-
return tv.cloneSetN("workflow_id", tv.uniqueString, n)
192+
return tv.cloneSetN("workflow_id", n)
185193
}
186194

187195
func (tv *TestVars) RunID() string {
188-
return getOrCreate(tv, "run_id", tv.uuidString)
196+
return getOrCreate(tv, "run_id", tv.uuidString, unsupportedNSetter)
189197
}
190198

191199
func (tv *TestVars) WithRunID(runID string) *TestVars {
@@ -200,23 +208,23 @@ func (tv *TestVars) WorkflowExecution() *commonpb.WorkflowExecution {
200208
}
201209

202210
func (tv *TestVars) RequestID() string {
203-
return getOrCreate(tv, "request_id", tv.uuidString)
211+
return getOrCreate(tv, "request_id", tv.uuidString, unsupportedNSetter)
204212
}
205213

206214
func (tv *TestVars) BuildID() string {
207-
return getOrCreate(tv, "build_id", tv.uniqueString)
215+
return getOrCreate(tv, "build_id", tv.uniqueString, tv.stringNSetter)
208216
}
209217

210218
func (tv *TestVars) WithBuildID(buildId string) *TestVars {
211219
return tv.cloneSetVal("build_id", buildId)
212220
}
213221

214222
func (tv *TestVars) WithBuildIDN(n int) *TestVars {
215-
return tv.cloneSetN("build_id", tv.uniqueString, n)
223+
return tv.cloneSetN("build_id", n)
216224
}
217225

218226
func (tv *TestVars) DeploymentSeries() string {
219-
return getOrCreate(tv, "deployment_series", tv.uniqueString)
227+
return getOrCreate(tv, "deployment_series", tv.uniqueString, tv.stringNSetter)
220228
}
221229

222230
func (tv *TestVars) WithDeploymentSeries(series string) *TestVars {
@@ -232,7 +240,7 @@ func (tv *TestVars) Deployment() *deploymentpb.Deployment {
232240

233241
func (tv *TestVars) TaskQueue() *taskqueuepb.TaskQueue {
234242
return &taskqueuepb.TaskQueue{
235-
Name: getOrCreate(tv, "task_queue", tv.uniqueString),
243+
Name: getOrCreate(tv, "task_queue", tv.uniqueString, tv.stringNSetter),
236244
Kind: enumspb.TASK_QUEUE_KIND_NORMAL,
237245
}
238246
}
@@ -242,12 +250,12 @@ func (tv *TestVars) WithTaskQueue(taskQueue string) *TestVars {
242250
}
243251

244252
func (tv *TestVars) WithTaskQueueN(n int) *TestVars {
245-
return tv.cloneSetN("task_queue", tv.uniqueString, n)
253+
return tv.cloneSetN("task_queue", n)
246254
}
247255

248256
func (tv *TestVars) StickyTaskQueue() *taskqueuepb.TaskQueue {
249257
return &taskqueuepb.TaskQueue{
250-
Name: getOrCreate(tv, "sticky_task_queue", tv.uniqueString),
258+
Name: getOrCreate(tv, "sticky_task_queue", tv.uniqueString, tv.stringNSetter),
251259
Kind: enumspb.TASK_QUEUE_KIND_STICKY,
252260
NormalName: tv.TaskQueue().Name,
253261
}
@@ -262,38 +270,38 @@ func (tv *TestVars) StickyExecutionAttributes(timeout time.Duration) *taskqueuep
262270

263271
func (tv *TestVars) WorkflowType() *commonpb.WorkflowType {
264272
return &commonpb.WorkflowType{
265-
Name: getOrCreate(tv, "workflow_type", tv.uniqueString),
273+
Name: getOrCreate(tv, "workflow_type", tv.uniqueString, tv.stringNSetter),
266274
}
267275
}
268276

269277
func (tv *TestVars) ActivityID() string {
270-
return getOrCreate(tv, "activity_id", tv.uniqueString)
278+
return getOrCreate(tv, "activity_id", tv.uniqueString, tv.stringNSetter)
271279
}
272280

273281
func (tv *TestVars) WithActivityIDN(n int) *TestVars {
274-
return tv.cloneSetN("activity_id", tv.uniqueString, n)
282+
return tv.cloneSetN("activity_id", n)
275283
}
276284

277285
func (tv *TestVars) ActivityType() *commonpb.ActivityType {
278286
return &commonpb.ActivityType{
279-
Name: getOrCreate(tv, "activity_type", tv.uniqueString),
287+
Name: getOrCreate(tv, "activity_type", tv.uniqueString, tv.stringNSetter),
280288
}
281289
}
282290

283291
func (tv *TestVars) MessageID() string {
284-
return getOrCreate(tv, "message_id", tv.uniqueString)
292+
return getOrCreate(tv, "message_id", tv.uniqueString, tv.stringNSetter)
285293
}
286294

287295
func (tv *TestVars) WithMessageIDN(n int) *TestVars {
288-
return tv.cloneSetN("message_id", tv.uniqueString, n)
296+
return tv.cloneSetN("message_id", n)
289297
}
290298

291299
func (tv *TestVars) UpdateID() string {
292-
return getOrCreate(tv, "update_id", tv.uniqueString)
300+
return getOrCreate(tv, "update_id", tv.uniqueString, tv.stringNSetter)
293301
}
294302

295303
func (tv *TestVars) WithUpdateIDN(n int) *TestVars {
296-
return tv.cloneSetN("update_id", tv.uniqueString, n)
304+
return tv.cloneSetN("update_id", n)
297305
}
298306

299307
func (tv *TestVars) UpdateRef() *updatepb.UpdateRef {
@@ -304,35 +312,35 @@ func (tv *TestVars) UpdateRef() *updatepb.UpdateRef {
304312
}
305313

306314
func (tv *TestVars) HandlerName() string {
307-
return getOrCreate(tv, "handler_name", tv.uniqueString)
315+
return getOrCreate(tv, "handler_name", tv.uniqueString, tv.stringNSetter)
308316
}
309317

310318
func (tv *TestVars) ClientIdentity() string {
311-
return getOrCreate(tv, "client_identity", tv.uniqueString)
319+
return getOrCreate(tv, "client_identity", tv.uniqueString, tv.stringNSetter)
312320
}
313321

314322
func (tv *TestVars) WorkerIdentity() string {
315-
return getOrCreate(tv, "worker_identity", tv.uniqueString)
323+
return getOrCreate(tv, "worker_identity", tv.uniqueString, tv.stringNSetter)
316324
}
317325

318326
func (tv *TestVars) TimerID() string {
319-
return getOrCreate(tv, "timer_id", tv.uniqueString)
327+
return getOrCreate(tv, "timer_id", tv.uniqueString, tv.stringNSetter)
320328
}
321329

322330
func (tv *TestVars) WithTimerIDN(n int) *TestVars {
323-
return tv.cloneSetN("timer_id", tv.uniqueString, n)
331+
return tv.cloneSetN("timer_id", n)
324332
}
325333

326334
func (tv *TestVars) QueryType() string {
327-
return getOrCreate(tv, "query_type", tv.uniqueString)
335+
return getOrCreate(tv, "query_type", tv.uniqueString, tv.stringNSetter)
328336
}
329337

330338
func (tv *TestVars) SignalName() string {
331-
return getOrCreate(tv, "signal_name", tv.uniqueString)
339+
return getOrCreate(tv, "signal_name", tv.uniqueString, tv.stringNSetter)
332340
}
333341

334342
func (tv *TestVars) IndexName() string {
335-
return getOrCreate(tv, "index_name", tv.uniqueString)
343+
return getOrCreate(tv, "index_name", tv.uniqueString, tv.stringNSetter)
336344
}
337345

338346
// ----------- Generic methods ------------

0 commit comments

Comments
 (0)