diff --git a/collection/skipmap/gen_func.go b/collection/skipmap/gen_func.go index c482f9b..fb9243f 100644 --- a/collection/skipmap/gen_func.go +++ b/collection/skipmap/gen_func.go @@ -137,15 +137,21 @@ func (s *FuncMap[keyT, valueT]) Store(key keyT, value valueT) { for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just replace the value. - nodeFound.storeVal(value) - return + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + // Lock the node to prevent a concurrent delete from + // marking and unlinking it while we update the value. + nodeFound.mu.Lock() + if nodeFound.flags.Get(marked) { + nodeFound.mu.Unlock() + continue + } + nodeFound.storeVal(value) + nodeFound.mu.Unlock() + return } // Add this node into skip list. var ( @@ -302,14 +308,12 @@ func (s *FuncMap[keyT, valueT]) LoadOrStore(key keyT, value valueT) (actual valu for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just return the value. - return nodeFound.loadVal(), true + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + return nodeFound.loadVal(), true } // Add this node into skip list. var ( @@ -371,14 +375,12 @@ func (s *FuncMap[keyT, valueT]) LoadOrStoreLazy(key keyT, f func() valueT) (actu for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just return the value. - return nodeFound.loadVal(), true + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + return nodeFound.loadVal(), true } // Add this node into skip list. var ( diff --git a/collection/skipmap/gen_ordered.go b/collection/skipmap/gen_ordered.go index 2239623..ec302f8 100644 --- a/collection/skipmap/gen_ordered.go +++ b/collection/skipmap/gen_ordered.go @@ -137,15 +137,21 @@ func (s *OrderedMap[keyT, valueT]) Store(key keyT, value valueT) { for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just replace the value. - nodeFound.storeVal(value) - return + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + // Lock the node to prevent a concurrent delete from + // marking and unlinking it while we update the value. + nodeFound.mu.Lock() + if nodeFound.flags.Get(marked) { + nodeFound.mu.Unlock() + continue + } + nodeFound.storeVal(value) + nodeFound.mu.Unlock() + return } // Add this node into skip list. var ( @@ -302,14 +308,12 @@ func (s *OrderedMap[keyT, valueT]) LoadOrStore(key keyT, value valueT) (actual v for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just return the value. - return nodeFound.loadVal(), true + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + return nodeFound.loadVal(), true } // Add this node into skip list. var ( @@ -371,14 +375,12 @@ func (s *OrderedMap[keyT, valueT]) LoadOrStoreLazy(key keyT, f func() valueT) (a for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just return the value. - return nodeFound.loadVal(), true + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + return nodeFound.loadVal(), true } // Add this node into skip list. var ( diff --git a/collection/skipmap/gen_ordereddesc.go b/collection/skipmap/gen_ordereddesc.go index 3f94634..9632424 100644 --- a/collection/skipmap/gen_ordereddesc.go +++ b/collection/skipmap/gen_ordereddesc.go @@ -137,15 +137,21 @@ func (s *OrderedMapDesc[keyT, valueT]) Store(key keyT, value valueT) { for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just replace the value. - nodeFound.storeVal(value) - return + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + // Lock the node to prevent a concurrent delete from + // marking and unlinking it while we update the value. + nodeFound.mu.Lock() + if nodeFound.flags.Get(marked) { + nodeFound.mu.Unlock() + continue + } + nodeFound.storeVal(value) + nodeFound.mu.Unlock() + return } // Add this node into skip list. var ( @@ -302,14 +308,12 @@ func (s *OrderedMapDesc[keyT, valueT]) LoadOrStore(key keyT, value valueT) (actu for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just return the value. - return nodeFound.loadVal(), true + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + return nodeFound.loadVal(), true } // Add this node into skip list. var ( @@ -371,14 +375,12 @@ func (s *OrderedMapDesc[keyT, valueT]) LoadOrStoreLazy(key keyT, f func() valueT for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just return the value. - return nodeFound.loadVal(), true + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + return nodeFound.loadVal(), true } // Add this node into skip list. var ( diff --git a/collection/skipmap/skipmap.tpl b/collection/skipmap/skipmap.tpl index eb81bcd..961e0ff 100644 --- a/collection/skipmap/skipmap.tpl +++ b/collection/skipmap/skipmap.tpl @@ -1,3 +1,17 @@ +// Copyright 2025 Bytedance Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + // Code generated by gen.go; DO NOT EDIT. package {{.Package}} @@ -121,15 +135,21 @@ func (s *{{.StructPrefix}}Map{{.StructSuffix}}{{.TypeArgument}}) Store(key {{.Ke for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just replace the value. - nodeFound.storeVal(value) - return + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + // Lock the node to prevent a concurrent delete from + // marking and unlinking it while we update the value. + nodeFound.mu.Lock() + if nodeFound.flags.Get(marked) { + nodeFound.mu.Unlock() + continue + } + nodeFound.storeVal(value) + nodeFound.mu.Unlock() + return } // Add this node into skip list. var ( @@ -286,14 +306,12 @@ func (s *{{.StructPrefix}}Map{{.StructSuffix}}{{.TypeArgument}}) LoadOrStore(key for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just return the value. - return nodeFound.loadVal(), true + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + return nodeFound.loadVal(), true } // Add this node into skip list. var ( @@ -355,14 +373,12 @@ func (s *{{.StructPrefix}}Map{{.StructSuffix}}{{.TypeArgument}}) LoadOrStoreLazy for { nodeFound := s.findNode(key, &preds, &succs) if nodeFound != nil { // indicating the key is already in the skip-list - if !nodeFound.flags.Get(marked) { - // We don't need to care about whether or not the node is fully linked, - // just return the value. - return nodeFound.loadVal(), true + if !nodeFound.flags.MGet(fullyLinked|marked, fullyLinked) { + // If the node is not fully linked or is marked for deletion, + // we need to retry in the next loop. + continue } - // If the node is marked, represents some other goroutines is in the process of deleting this node, - // we need to add this node in next loop. - continue + return nodeFound.loadVal(), true } // Add this node into skip list. var ( diff --git a/collection/skipmap/skipmap_test.go b/collection/skipmap/skipmap_test.go index c2f9ff2..dbc3ec9 100644 --- a/collection/skipmap/skipmap_test.go +++ b/collection/skipmap/skipmap_test.go @@ -686,3 +686,93 @@ func testSkipMapIntUnmarshalJSON[T int | uint](t *testing.T, newset func() anysk } } } + +// Store can race with LoadAndDelete such that Store writes to a node that is +// concurrently marked and unlinked, silently losing the value. To detect this, +// set key=oldValue, then race Store(key, newValue) vs LoadAndDelete(key). If +// the delete returns oldValue it went first, so the key must still exist with +// newValue afterward. Finding the key absent means the Store was lost. +// +// See: https://github.com/bytedance/gg/issues/36 +func TestStoreLoadAndDeleteRace(t *testing.T) { + const key = "k" + rounds := 100_000 + if testing.Short() { + rounds = 10_000 + } + + for round := 0; round < rounds; round++ { + m := New[string, int]() + oldValue := round*2 + 1 + newValue := round*2 + 2 + + m.Store(key, oldValue) + + var wg sync.WaitGroup + var delValue int + var delLoaded bool + + wg.Add(2) + go func() { + defer wg.Done() + m.Store(key, newValue) + }() + go func() { + defer wg.Done() + delValue, delLoaded = m.LoadAndDelete(key) + }() + + wg.Wait() + + val, exists := m.Load(key) + + if delLoaded && delValue == oldValue && !exists { + t.Fatalf("round %d: Store(%s, %d) lost: LoadAndDelete returned old=%d but key is absent", + round, key, newValue, oldValue) + } + + if exists && val != newValue { + t.Fatalf("round %d: key has value %d, want %d", round, val, newValue) + } + } +} + +// LoadOrStore can return a value from a node that hasn't been fully linked +// into the skip list yet. Load checks the fullyLinked flag and rejects such +// nodes, so it returns nil for a key that LoadOrStore just reported as present. +// Race 8 goroutines doing LoadOrStore on the same absent key; since nothing +// deletes the key, every goroutine's follow-up Load must succeed. +// +// See: https://github.com/bytedance/gg/issues/36 +func TestLoadOrStoreLoadRace(t *testing.T) { + const key = "k" + rounds := 100_000 + if testing.Short() { + rounds = 10_000 + } + + for round := 0; round < rounds; round++ { + m := New[string, int]() + + var wg sync.WaitGroup + var failed int32 + + for i := 0; i < 8; i++ { + wg.Add(1) + go func(v int) { + defer wg.Done() + m.LoadOrStore(key, v) + if _, ok := m.Load(key); !ok { + atomic.AddInt32(&failed, 1) + } + }(round*8 + i) + } + + wg.Wait() + + if n := atomic.LoadInt32(&failed); n > 0 { + t.Fatalf("round %d: Load returned nil %d times after LoadOrStore (no deletes running)", + round, n) + } + } +}