Skip to content

Commit a2ffd62

Browse files
committed
fix mget test
1 parent e0b122a commit a2ffd62

File tree

1 file changed

+65
-13
lines changed

1 file changed

+65
-13
lines changed

osscluster_router.go

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -335,16 +335,41 @@ func (c *ClusterClient) executeParallel(ctx context.Context, cmd Cmder, nodes []
335335

336336
// aggregateMultiSlotResults aggregates results from multi-slot execution
337337
func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder, results <-chan slotResult, keyOrder []string, policy *routing.CommandPolicy) error {
338-
keyedResults := make(map[string]Cmder)
338+
keyedResults := make(map[string]interface{})
339339
var firstErr error
340340

341341
for result := range results {
342342
if result.err != nil && firstErr == nil {
343343
firstErr = result.err
344344
}
345-
if result.cmd != nil {
346-
for _, key := range result.keys {
347-
keyedResults[key] = result.cmd
345+
if result.cmd != nil && result.err == nil {
346+
// For MGET, extract individual values from the array result
347+
if strings.ToLower(cmd.Name()) == "mget" {
348+
if sliceCmd, ok := result.cmd.(*SliceCmd); ok {
349+
values := sliceCmd.Val()
350+
if len(values) == len(result.keys) {
351+
for i, key := range result.keys {
352+
keyedResults[key] = values[i]
353+
}
354+
} else {
355+
// Fallback: map all keys to the entire result
356+
for _, key := range result.keys {
357+
keyedResults[key] = values
358+
}
359+
}
360+
} else {
361+
// Fallback for non-SliceCmd results
362+
value := ExtractCommandValue(result.cmd)
363+
for _, key := range result.keys {
364+
keyedResults[key] = value
365+
}
366+
}
367+
} else {
368+
// For other commands, map each key to the entire result
369+
value := ExtractCommandValue(result.cmd)
370+
for _, key := range result.keys {
371+
keyedResults[key] = value
372+
}
348373
}
349374
}
350375
}
@@ -354,7 +379,36 @@ func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder
354379
return firstErr
355380
}
356381

357-
return c.aggregateKeyedResponses(cmd, keyedResults, keyOrder, policy)
382+
return c.aggregateKeyedValues(cmd, keyedResults, keyOrder, policy)
383+
}
384+
385+
// aggregateKeyedValues aggregates individual key-value pairs while preserving key order
386+
func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string]interface{}, keyOrder []string, policy *routing.CommandPolicy) error {
387+
if len(keyedResults) == 0 {
388+
return fmt.Errorf("redis: no results to aggregate")
389+
}
390+
391+
aggregator := c.createAggregator(policy, cmd, true)
392+
393+
// Set key order for keyed aggregators
394+
if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok {
395+
keyedAgg.SetKeyOrder(keyOrder)
396+
}
397+
398+
// Add results with keys
399+
for key, value := range keyedResults {
400+
if keyedAgg, ok := aggregator.(*routing.DefaultKeyedAggregator); ok {
401+
if err := keyedAgg.AddWithKey(key, value, nil); err != nil {
402+
return err
403+
}
404+
} else {
405+
if err := aggregator.Add(value, nil); err != nil {
406+
return err
407+
}
408+
}
409+
}
410+
411+
return c.finishAggregation(cmd, aggregator)
358412
}
359413

360414
// aggregateKeyedResponses aggregates responses while preserving key order
@@ -418,15 +472,13 @@ func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *rout
418472

419473
// createAggregator creates the appropriate response aggregator
420474
func (c *ClusterClient) createAggregator(policy *routing.CommandPolicy, cmd Cmder, isKeyed bool) routing.ResponseAggregator {
475+
cmdName := strings.ToLower(cmd.Name())
476+
// For MGET without policy, use keyed aggregator
477+
if cmdName == "mget" {
478+
return routing.NewDefaultAggregator(true)
479+
}
480+
421481
if policy != nil {
422-
// For specific multi-shard commands that need keyed aggregation despite having
423-
// all_succeeded policy (like MGET which needs to preserve key order)
424-
if policy.Request == routing.ReqMultiShard && policy.Response == routing.RespAllSucceeded && isKeyed {
425-
cmdName := strings.ToLower(cmd.Name())
426-
if cmdName == "mget" {
427-
return routing.NewDefaultAggregator(true)
428-
}
429-
}
430482
return routing.NewResponseAggregator(policy.Response, cmd.Name())
431483
}
432484

0 commit comments

Comments
 (0)