Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support atomic writes in the docstore #3500

Merged
merged 3 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions docstore/awsdynamodb/dynamo.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,12 @@ func (c *collection) RevisionField() string { return c.opts.RevisionField }

func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError {
errs := make([]error, len(actions))
beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
beforeGets, gets, writes, writesTx, afterGets := driver.GroupActions(actions)
c.runGets(ctx, beforeGets, errs, opts)
ch := make(chan struct{})
ch2 := make(chan struct{})
go func() { defer close(ch); c.runWrites(ctx, writes, errs, opts) }()
go func() { defer close(ch2); c.transactWrite(ctx, writesTx, errs, opts) }()
c.runGets(ctx, gets, errs, opts)
<-ch
jba marked this conversation as resolved.
Show resolved Hide resolved
c.runGets(ctx, afterGets, errs, opts)
Expand Down Expand Up @@ -613,25 +615,26 @@ func revisionPrecondition(doc driver.Document, revField string) (*expression.Con
return &cb, nil
}

// TODO(jba): use this if/when we support atomic writes.
func (c *collection) transactWrite(ctx context.Context, actions []*driver.Action, errs []error, opts *driver.RunActionsOptions, start, end int) {
func (c *collection) transactWrite(ctx context.Context, actions []*driver.Action, errs []error, opts *driver.RunActionsOptions) {
jba marked this conversation as resolved.
Show resolved Hide resolved
if len(actions) == 0 {
return
}
setErr := func(err error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea here is that if anything in this function fails, everything fails (since it's atomic). Document that here.

for i := start; i <= end; i++ {
errs[actions[i].Index] = err
for _, a := range actions {
errs[a.Index] = err
}
}

tws := make([]*dyn.TransactWriteItem, 0, len(actions))
var ops []*writeOp
tws := make([]*dyn.TransactWriteItem, 0, end-start+1)
for i := start; i <= end; i++ {
a := actions[i]
op, err := c.newWriteOp(a, opts)
for _, w := range actions {
op, err := c.newWriteOp(w, opts)
if err != nil {
setErr(err)
return
errs[w.Index] = err
jba marked this conversation as resolved.
Show resolved Hide resolved
} else {
ops = append(ops, op)
tws = append(tws, op.writeItem)
}
ops = append(ops, op)
tws = append(tws, op.writeItem)
}

in := &dyn.TransactWriteItemsInput{
Expand Down
39 changes: 24 additions & 15 deletions docstore/docstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,20 @@ func (c *Collection) Actions() *ActionList {
// document; a Get after the write will see the new value if the service is strongly
// consistent, but may see the old value if the service is eventually consistent.
type ActionList struct {
coll *Collection
actions []*Action
beforeDo func(asFunc func(interface{}) bool) error
coll *Collection
actions []*Action
enableAtomicWrites bool
beforeDo func(asFunc func(interface{}) bool) error
}

// An Action is a read or write on a single document.
// Use the methods of ActionList to create and execute Actions.
type Action struct {
kind driver.ActionKind
doc Document
fieldpaths []FieldPath // paths to retrieve, for Get
mods Mods // modifications to make, for Update
kind driver.ActionKind
doc Document
fieldpaths []FieldPath // paths to retrieve, for Get
mods Mods // modifications to make, for Update
inAtomicWrite bool // if this action is a part of atomic writes
}

func (l *ActionList) add(a *Action) *ActionList {
Expand All @@ -170,7 +172,7 @@ func (l *ActionList) add(a *Action) *ActionList {
// Except for setting the revision field and possibly setting the key fields, the doc
// argument is not modified.
func (l *ActionList) Create(doc Document) *ActionList {
return l.add(&Action{kind: driver.Create, doc: doc})
return l.add(&Action{kind: driver.Create, doc: doc, inAtomicWrite: l.enableAtomicWrites})
}

// Replace adds an action that replaces a document to the given ActionList, and
Expand All @@ -182,7 +184,7 @@ func (l *ActionList) Create(doc Document) *ActionList {
// See the Revisions section of the package documentation for how revisions are
// handled.
func (l *ActionList) Replace(doc Document) *ActionList {
return l.add(&Action{kind: driver.Replace, doc: doc})
return l.add(&Action{kind: driver.Replace, doc: doc, inAtomicWrite: l.enableAtomicWrites})
}

// Put adds an action that adds or replaces a document to the given ActionList, and returns the ActionList.
Expand All @@ -195,7 +197,7 @@ func (l *ActionList) Replace(doc Document) *ActionList {
// See the Revisions section of the package documentation for how revisions are
// handled.
func (l *ActionList) Put(doc Document) *ActionList {
return l.add(&Action{kind: driver.Put, doc: doc})
return l.add(&Action{kind: driver.Put, doc: doc, inAtomicWrite: l.enableAtomicWrites})
}

// Delete adds an action that deletes a document to the given ActionList, and returns
Expand All @@ -210,7 +212,7 @@ func (l *ActionList) Delete(doc Document) *ActionList {
// semantics of an action list are to stop at first error, then we might abort a
// list of Deletes just because one of the docs was not present, and that seems
// wrong, or at least something you'd want to turn off.
return l.add(&Action{kind: driver.Delete, doc: doc})
return l.add(&Action{kind: driver.Delete, doc: doc, inAtomicWrite: l.enableAtomicWrites})
}

// Get adds an action that retrieves a document to the given ActionList, and
Expand Down Expand Up @@ -252,9 +254,10 @@ func (l *ActionList) Get(doc Document, fps ...FieldPath) *ActionList {
// the updated document, call Get after calling Update.
func (l *ActionList) Update(doc Document, mods Mods) *ActionList {
return l.add(&Action{
kind: driver.Update,
doc: doc,
mods: mods,
kind: driver.Update,
doc: doc,
mods: mods,
inAtomicWrite: l.enableAtomicWrites,
})
}

Expand Down Expand Up @@ -430,7 +433,7 @@ func (c *Collection) toDriverAction(a *Action) (*driver.Action, error) {
// A Put with a revision field is equivalent to a Replace.
kind = driver.Replace
}
d := &driver.Action{Kind: kind, Doc: ddoc, Key: key}
d := &driver.Action{Kind: kind, Doc: ddoc, Key: key, InAtomicWrite: a.inAtomicWrite}
if a.fieldpaths != nil {
d.FieldPaths, err = parseFieldPaths(a.fieldpaths)
if err != nil {
Expand Down Expand Up @@ -534,6 +537,12 @@ func (l *ActionList) String() string {
return "[" + strings.Join(as, ", ") + "]"
}

// AtomicWrites causes all following writes in the list to execute atomically.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"to execute as a single atomic operation"

(to avoid confusion with the fact that each write already happens atomically by itself)

func (l *ActionList) AtomicWrites() *ActionList {
l.enableAtomicWrites = true
return l
}

func (a *Action) String() string {
buf := &strings.Builder{}
fmt.Fprintf(buf, "%s(%v", a.kind, a.doc)
Expand Down
14 changes: 7 additions & 7 deletions docstore/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ const (

//go:generate stringer -type=ActionKind

// An Action describes a single operation on a single document.
type Action struct {
Kind ActionKind // the kind of action
Doc Document // the document on which to perform the action
Key interface{} // the document key returned by Collection.Key, to avoid recomputing it
FieldPaths [][]string // field paths to retrieve, for Get only
Mods []Mod // modifications to make, for Update only
Index int // the index of the action in the original action list
Kind ActionKind // the kind of action
Doc Document // the document on which to perform the action
Key interface{} // the document key returned by Collection.Key, to avoid recomputing it
FieldPaths [][]string // field paths to retrieve, for Get only
Mods []Mod // modifications to make, for Update only
Index int // the index of the action in the original action list
InAtomicWrite bool // if this action is a part of transaction
}

// A Mod is a modification to a field path in a document.
Expand Down
22 changes: 18 additions & 4 deletions docstore/driver/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ func SplitActions(actions []*Action, split func(a, b *Action) bool) [][]*Action

// GroupActions separates actions into four sets: writes, gets that must happen before the writes,
// gets that must happen after the writes, and gets that can happen concurrently with the writes.
func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets []*Action) {
func GroupActions(actions []*Action) (beforeGets, getList, writeList, writesTxList, afterGets []*Action) {
// maps from key to action
bgets := map[interface{}]*Action{}
agets := map[interface{}]*Action{}
cgets := map[interface{}]*Action{}
writes := map[interface{}]*Action{}
writesTx := map[interface{}]*Action{}
var nilkeys []*Action
for _, a := range actions {
if a.Key == nil {
Expand All @@ -69,7 +70,7 @@ func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets
} else if a.Kind == Get {
// If there was a prior write with this key, make sure this get
// happens after the writes.
if _, ok := writes[a.Key]; ok {
if valueExistsInMaps(a.Key, writes, writesTx) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need a whole function.

if _, ok := writes[a.Key]; ok {
    agets[a.Key] = a
} else if _, ok :=  writesTx[a.Key]; ok {
   agets[a.Key] = a
} else ...

agets[a.Key] = a
} else {
cgets[a.Key] = a
Expand All @@ -81,7 +82,11 @@ func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets
delete(cgets, a.Key)
bgets[a.Key] = g
}
writes[a.Key] = a
if a.InAtomicWrite {
writesTx[a.Key] = a
} else {
writes[a.Key] = a
}
}
}

Expand All @@ -95,7 +100,16 @@ func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets
return as
}

return vals(bgets), vals(cgets), append(vals(writes), nilkeys...), vals(agets)
return vals(bgets), vals(cgets), append(vals(writes), nilkeys...), vals(writesTx), vals(agets)
}

func valueExistsInMaps(key interface{}, maps ...map[interface{}]*Action) bool {
jba marked this conversation as resolved.
Show resolved Hide resolved
for _, m := range maps {
if _, ok := m[key]; ok {
return true
}
}
return false
}

// AsFunc creates and returns an "as function" that behaves as follows:
Expand Down
12 changes: 6 additions & 6 deletions docstore/driver/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestGroupActions(t *testing.T) {
}{
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add various cases involving atomic writes, mixed with other writes and gets before, after, and concurrently.

in: []*Action{{Kind: Get, Key: 1}},
want: [][]int{nil, {0}, nil, nil},
want: [][]int{nil, {0}, nil, nil, nil},
},
{
in: []*Action{
Expand All @@ -89,16 +89,16 @@ func TestGroupActions(t *testing.T) {
{Kind: Replace, Key: 2},
{Kind: Get, Key: 2},
},
want: [][]int{{0}, {1}, {2, 3}, {4}},
want: [][]int{{0}, {1}, {2, 3}, nil, {4}},
},
{
in: []*Action{{Kind: Create}, {Kind: Create}, {Kind: Create}},
want: [][]int{nil, nil, {0, 1, 2}, nil},
want: [][]int{nil, nil, {0, 1, 2}, nil, nil},
},
} {
got := make([][]*Action, 4)
got[0], got[1], got[2], got[3] = GroupActions(test.in)
want := make([][]*Action, 4)
got := make([][]*Action, 5)
got[0], got[1], got[2], got[3], got[4] = GroupActions(test.in)
want := make([][]*Action, 5)
for i, s := range test.want {
for _, x := range s {
want[i] = append(want[i], test.in[x])
Expand Down
76 changes: 76 additions & 0 deletions docstore/drivertest/drivertest.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
jba marked this conversation as resolved.
Show resolved Hide resolved
"gocloud.dev/docstore"
"gocloud.dev/docstore/driver"
"gocloud.dev/gcerrors"
Expand Down Expand Up @@ -1900,6 +1901,81 @@ func testMultipleActions(t *testing.T, coll *docstore.Collection, revField strin
}
}

func testAtomicWrites(t *testing.T, coll *docstore.Collection, revField string) {
t.Helper()

ctx := context.Background()

must := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}

var docs []docmap
for i := 0; i < 9; i++ {
docs = append(docs, docmap{
KeyField: fmt.Sprintf("testAtomicWrites%d", i),
"s": fmt.Sprint(i),
revField: nil,
})
}

compare := func(gots, wants []docmap) {
t.Helper()
for i := 0; i < len(gots); i++ {
got := gots[i]
want := clone(wants[i])
want[revField] = got[revField]
if !cmp.Equal(got, want, cmpopts.IgnoreUnexported(tspb.Timestamp{})) {
t.Errorf("index #%d:\ngot %v\nwant %v", i, got, want)
}
}
}

// Put the first six docs.
actions := coll.Actions()
for i := 0; i < 6; i++ {
actions.Create(docs[i])
}
must(actions.Do(ctx))

// Delete the first three, get the second three, and update last three in transaction.
gdocs := []docmap{
{KeyField: docs[3][KeyField]},
{KeyField: docs[4][KeyField]},
{KeyField: docs[5][KeyField]},
}
actions = coll.Actions()
actions.Get(gdocs[0])
actions.Delete(docs[0])
actions.Delete(docs[1])
actions.Get(gdocs[1])
actions.Delete(docs[2])
actions.Get(gdocs[2])
actions.AtomicWrites()
actions.Update(docs[6], docstore.Mods{"s": "66'"})
actions.Update(docs[7], docstore.Mods{"s": "77'"})
actions.Update(docs[8], docstore.Mods{"s": "88"})

must(actions.Do(ctx))
compare(gdocs, docs[3:6])

// Get the docs updated as part of atomic writes and verify that got written.
actions = coll.Actions()

doc := docmap{KeyField: docs[6][KeyField]}
_ = coll.Get(ctx, doc)
assert.Equal(t, "66", doc["s"])
doc = docmap{KeyField: docs[7][KeyField]}
_ = coll.Get(ctx, doc)
assert.Equal(t, "77", doc["s"])
doc = docmap{KeyField: docs[8][KeyField]}
_ = coll.Get(ctx, doc)
assert.Equal(t, "88", doc["s"])
}
jba marked this conversation as resolved.
Show resolved Hide resolved

func testActionsOnStructNoRev(t *testing.T, _ Harness, coll *docstore.Collection) {
t.Helper()

Expand Down
2 changes: 1 addition & 1 deletion docstore/gcpfirestore/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func (c *collection) RevisionField() string {
// RunActions implements driver.RunActions.
func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError {
errs := make([]error, len(actions))
beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
beforeGets, gets, writes, _, afterGets := driver.GroupActions(actions)
calls := c.buildCommitCalls(writes, errs)
// runGets does not issue concurrent RPCs, so it doesn't need a throttle.
c.runGets(ctx, beforeGets, errs, opts)
Expand Down
2 changes: 1 addition & 1 deletion docstore/memdocstore/mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, o
}
}

beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
beforeGets, gets, writes, _, afterGets := driver.GroupActions(actions)
sandeepvinayak marked this conversation as resolved.
Show resolved Hide resolved
run(beforeGets)
run(gets)
run(writes)
Expand Down
2 changes: 1 addition & 1 deletion docstore/mongodocstore/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ const mongoIDField = "_id"

func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError {
errs := make([]error, len(actions))
beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
beforeGets, gets, writes, _, afterGets := driver.GroupActions(actions)
c.runGets(ctx, beforeGets, errs, opts)
ch := make(chan []error)
go func() { ch <- c.bulkWrite(ctx, writes, errs, opts) }()
Expand Down
Loading