Skip to content

Commit

Permalink
marshal safe
Browse files Browse the repository at this point in the history
  • Loading branch information
AsterDY committed Jun 27, 2024
1 parent 2c4dbb2 commit 7e127bc
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 35 deletions.
30 changes: 18 additions & 12 deletions ast/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,12 @@ func quoteString(e *[]byte, s string) {
var bytesPool = sync.Pool{}

func (self *Node) MarshalJSON() ([]byte, error) {
return self.marshalJSON(noLazy)
}

func (self *Node) marshalJSON(locked bool) ([]byte, error) {
buf := newBuffer()
err := self.encode(buf)
err := self.encode(buf, locked)
if err != nil {
freeBuffer(buf)
return nil, err
Expand All @@ -117,9 +121,9 @@ func freeBuffer(buf *[]byte) {
bytesPool.Put(buf)
}

func (self *Node) encode(buf *[]byte) error {
func (self *Node) encode(buf *[]byte, locked bool) error {
if self.IsRaw() {
return self.encodeRaw(buf)
return self.encodeRaw(buf, locked)
}
switch int(self.itype()) {
case V_NONE : return ErrNotExist
Expand All @@ -136,16 +140,18 @@ func (self *Node) encode(buf *[]byte) error {
}
}

func (self *Node) encodeRaw(buf *[]byte) error {
if noLazy {
func (self *Node) encodeRaw(buf *[]byte, locked bool) error {
if locked {
self.rlock()
if !self.IsRaw() {
self.runlock()
return self.encode(buf)
return self.encode(buf, false)
}
defer self.runlock()
}
raw := self.toString()
if locked {
self.runlock()
}
*buf = append(*buf, raw...)
return nil
}
Expand Down Expand Up @@ -206,7 +212,7 @@ func (self *Node) encodeArray(buf *[]byte) error {
*buf = append(*buf, ',')
}
started = true
if err := n.encode(buf); err != nil {
if err := n.encode(buf, true); err != nil {
return err
}
}
Expand All @@ -215,16 +221,16 @@ func (self *Node) encodeArray(buf *[]byte) error {
return nil
}

func (self *Pair) encode(buf *[]byte) error {
func (self *Pair) encode(buf *[]byte, locked bool) error {
if len(*buf) == 0 {
*buf = append(*buf, '"', '"', ':')
return self.Value.encode(buf)
return self.Value.encode(buf, locked)
}

quote(buf, self.Key)
*buf = append(*buf, ':')

return self.Value.encode(buf)
return self.Value.encode(buf, locked)
}

func (self *Node) encodeObject(buf *[]byte) error {
Expand Down Expand Up @@ -252,7 +258,7 @@ func (self *Node) encodeObject(buf *[]byte) error {
*buf = append(*buf, ',')
}
started = true
if err := n.encode(buf); err != nil {
if err := n.encode(buf, true); err != nil {
return err
}
}
Expand Down
6 changes: 2 additions & 4 deletions ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"encoding/json"
"fmt"
"strconv"
"sync"
"unsafe"

"github.com/bytedance/sonic/internal/native/types"
Expand Down Expand Up @@ -57,7 +56,7 @@ type Node struct {
t types.ValueType
l uint
p unsafe.Pointer
m *sync.RWMutex
m unsafe.Pointer
}

// UnmarshalJSON is just an adapter to json.Unmarshaler.
Expand Down Expand Up @@ -116,7 +115,6 @@ func (self *Node) Check() error {
}

// IsRaw returns true if node's underlying value is raw json
//go:nocheckptr
func (self Node) IsRaw() bool {
return self.t&_V_RAW != 0
}
Expand All @@ -143,7 +141,7 @@ func (self *Node) Raw() (string, error) {
if noLazy {
self.runlock()
}
buf, err := self.MarshalJSON()
buf, err := self.marshalJSON(false)
return rt.Mem2Str(buf), err
}
ret := self.toString()
Expand Down
23 changes: 14 additions & 9 deletions ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package ast
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"

"github.com/bytedance/sonic/internal/native/types"
"github.com/bytedance/sonic/internal/rt"
Expand Down Expand Up @@ -606,7 +608,7 @@ func newRawNode(str string, typ types.ValueType) Node {
l: uint(len(str)),
}
if noLazy {
ret.m = new(sync.RWMutex)
ret.m = unsafe.Pointer(new(sync.RWMutex))
}
return ret
}
Expand Down Expand Up @@ -667,27 +669,30 @@ func switchRawType(c byte) types.ValueType {
return typeJumpTable[c]
}

func (self *Node) loadm() *sync.RWMutex {
return (*sync.RWMutex)(atomic.LoadPointer(&self.m))
}

func (self *Node) lock() {
if self.m != nil {
self.m.Lock()
if m := self.loadm(); m != nil {
m.Lock()
}
}

func (self *Node) unlock() {
if self.m != nil {
self.m.Unlock()
if m := self.loadm(); m != nil {
m.Unlock()
}
}

func (self *Node) rlock() {
if self.m != nil {
self.m.RLock()
if m := self.loadm(); m != nil {
m.RLock()
}
}

func (self *Node) runlock() {
if self.m != nil {
self.m.RUnlock()
if m := self.loadm(); m != nil {
m.RUnlock()
}
}
33 changes: 23 additions & 10 deletions ast/parser_norace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,31 @@ import (
)


func TestParseNoLazy(t *testing.T) {
func TestNodeRace(t *testing.T) {
noLazy = true

src := `{"1":1,"2": [ 1 , 2 , { "3" : 3 } ] }`
src := `{"1":1,"2": [ 1 , 1 , { "3" : 1 , "4" : [] } ] }`
node := NewRaw(src)

cases := []struct{
path []interface{}
exp []string
scalar bool
lv int
}{
{[]interface{}{"1"}, []string{`1`}},
{[]interface{}{"2"}, []string{`[ 1 , 2 , { "3" : 3 } ]`, `[1,2,{ "3" : 3 }]`, `[1,2,{"3":3}]`}},
{[]interface{}{"2", 1}, []string{`2`}},
{[]interface{}{"2", 2}, []string{`{ "3" : 3 }`, `{"3":3}`}},
{[]interface{}{"2", 2, "3"}, []string{`3`}},
{[]interface{}{"1"}, []string{`1`}, true, 0},
{[]interface{}{"2"}, []string{`[ 1 , 1 , { "3" : 1 , "4" : [] } ]`, `[1,1,{ "3" : 1 , "4" : [] }]`, `[1,1,{"3":1,"4":[]}]`}, false, 3},
{[]interface{}{"2", 1}, []string{`1`}, true, 1},
{[]interface{}{"2", 2}, []string{`{ "3" : 1 , "4" : [] }`, `{"3":1,"4":[]}`}, false, 2},
{[]interface{}{"2", 2, "3"}, []string{`1`}, true, 0},
{[]interface{}{"2", 2, "4"}, []string{`[]`}, false, 0},
}

wg := sync.WaitGroup{}
start := sync.RWMutex{}
start.Lock()

P := 1000
P := 10000
for i := range cases {
// println(i)
c := cases[i]
Expand All @@ -54,8 +57,18 @@ func TestParseNoLazy(t *testing.T) {
go func () {
defer wg.Done()
start.RLock()
v, err := node.GetByPath(c.path...).Raw()
n := node.GetByPath(c.path...)
v, err := n.Raw()
iv, _ := n.Int64()
lv, _ := n.Len()
_, e := n.Interface()
require.NoError(t, err)
require.NoError(t, e)
if c.scalar {
require.Equal(t, int64(1), iv)
} else {
require.Equal(t, c.lv, lv)
}
eq := false
for _, exp := range c.exp {
if exp == v {
Expand All @@ -72,4 +85,4 @@ func TestParseNoLazy(t *testing.T) {
wg.Wait()

noLazy = false
}
}

0 comments on commit 7e127bc

Please sign in to comment.