Skip to content

Commit

Permalink
shadow: handle json tag in state struct (#237)
Browse files Browse the repository at this point in the history
* add test cases to validate json tag handling
* skip ignored json field
* support omitempty
  • Loading branch information
at-wat authored Mar 15, 2021
1 parent ec4e064 commit 39e8081
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 25 deletions.
124 changes: 101 additions & 23 deletions shadow/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package shadow
import (
"errors"
"reflect"
"strings"
)

var errInvalidAttribute = errors.New("invalid attribute key")
Expand Down Expand Up @@ -87,17 +88,30 @@ func stateDiff(base, in interface{}) (interface{}, bool, error) {
for _, k := range keysIn {
keysInMap[k] = struct{}{}
}
baseMatcher, err := newAttributeMatcher(reflect.ValueOf(base))
if err != nil {
return nil, false, err
}
inMatcher, err := newAttributeMatcher(reflect.ValueOf(in))
if err != nil {
return nil, false, err
}
out := make(map[string]interface{})
for _, k := range keys {
if _, ok := keysInMap[k]; !ok {
continue
}
delete(keysInMap, k)
a, _ := attributeByKey(base, k)
b, err := attributeByKey(in, k)
a, _, _ := baseMatcher.byKey(k)
b, bInfo, err := inMatcher.byKey(k)
if err != nil {
return nil, false, err
}

if bInfo.omitempty && bInfo.val.IsZero() {
continue
}

d, difer, err := stateDiff(a, b)
if err != nil {
return nil, false, err
Expand All @@ -107,7 +121,10 @@ func stateDiff(base, in interface{}) (interface{}, bool, error) {
}
}
for k := range keysInMap {
b, _ := attributeByKey(in, k)
b, bInfo, _ := inMatcher.byKey(k)
if bInfo.omitempty && bInfo.val.IsZero() {
continue
}
out[k] = b
}
if len(out) == 0 {
Expand All @@ -134,9 +151,18 @@ func attributeKeys(a interface{}) ([]string, bool, error) {
}
return out, true, nil
case reflect.Struct:
out := make([]string, v.NumField())
for i := range out {
out[i] = t.Field(i).Name
n := v.NumField()
out := make([]string, 0, n)
for i := 0; i < n; i++ {
jsonName, _, ok := jsonFieldInfo(t.Field(i).Tag)
if !ok {
continue
}
if jsonName == "" {
out = append(out, t.Field(i).Name)
} else {
out = append(out, jsonName)
}
}
return out, true, nil
case reflect.Ptr:
Expand All @@ -145,30 +171,82 @@ func attributeKeys(a interface{}) ([]string, bool, error) {
return nil, false, nil
}

func attributeByKey(a interface{}, k string) (interface{}, error) {
ret, err := attributeByKeyImpl(reflect.ValueOf(a), k)
if err != nil {
return nil, err
}
return ret.Interface(), nil
type attributeMatcher struct {
byName map[string]attributeInfo
}

func attributeByKeyImpl(v reflect.Value, k string) (reflect.Value, error) {
if !v.IsValid() {
return reflect.Value{}, errInvalidAttribute
type attributeInfo struct {
val reflect.Value
omitempty bool
}

func newAttributeMatcher(val reflect.Value) (*attributeMatcher, error) {
if !val.IsValid() {
return nil, errInvalidAttribute
}
t := v.Type()
t := val.Type()
switch t.Kind() {
case reflect.Struct:
return v.FieldByName(k), nil
a := &attributeMatcher{byName: make(map[string]attributeInfo)}
n := t.NumField()
for i := 0; i < n; i++ {
jsonName, omitempty, ok := jsonFieldInfo(t.Field(i).Tag)
if !ok {
continue
}
var name string
if jsonName == "" {
name = t.Field(i).Name
} else {
name = jsonName
}
a.byName[name] = attributeInfo{
val: val.Field(i),
omitempty: omitempty,
}
}
return a, nil
case reflect.Map:
val := v.MapIndex(reflect.ValueOf(k))
if !val.IsValid() {
return reflect.Value{}, nil
a := &attributeMatcher{byName: make(map[string]attributeInfo)}
for _, key := range val.MapKeys() {
name := key.String()
a.byName[name] = attributeInfo{val: val.MapIndex(key)}
}
return val, nil
return a, nil
case reflect.Ptr, reflect.Interface:
return attributeByKeyImpl(v.Elem(), k)
return newAttributeMatcher(val.Elem())
}
return nil, errInvalidAttribute
}

func (a *attributeMatcher) byKey(k string) (interface{}, attributeInfo, error) {
val, ok := a.byName[k]
if !ok {
return reflect.Value{}, attributeInfo{}, errInvalidAttribute
}
return val.val.Interface(), val, nil
}

func jsonFieldInfo(t reflect.StructTag) (string, bool, bool) {
tag, ok := t.Lookup("json")
if !ok {
// Use struct field name.
return "", false, true
}
if tag == "-" {
// Field is ignored.
return "", false, false
}
tags := strings.Split(tag, ",")
var omitempty bool
for _, tag := range tags {
if tag == "omitempty" {
omitempty = true
}
}
if tags[0] == "" {
// Use struct field name.
return "", omitempty, true
}
return reflect.Value{}, errInvalidAttribute
return tags[0], omitempty, true
}
72 changes: 70 additions & 2 deletions shadow/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ type testStruct struct {
A, B, C int
S string
}
type testStructWithTag struct {
AX int `json:"A,random_tag=aaa"`
BX int `json:"B"`
C int `json:",random_tag=bbb"`
SX string `json:"S"`
Garbage int `json:"-"`
}
type testStructWithOmitempty struct {
A int `json:",omitempty"`
B, C int
S string
}
type testSubStruct struct {
S1, S2 int
}
Expand Down Expand Up @@ -83,12 +95,36 @@ func TestStateDiff(t *testing.T) {
diff: map[string]interface{}{"B": 3, "C": 0},
hasDiff: true,
},
"Map2StructWithTag": {
base: map[string]interface{}{"A": 1, "B": 2, "S": "test"},
input: testStructWithTag{AX: 1, BX: 3, SX: "test"},
diff: map[string]interface{}{"B": 3, "C": 0},
hasDiff: true,
},
"Map2StructWithoutOmitempty": {
base: map[string]interface{}{"A": 1, "B": 2, "S": "test"},
input: testStruct{A: 0, B: 3, S: "test"},
diff: map[string]interface{}{"A": 0, "B": 3, "C": 0},
hasDiff: true,
},
"Map2StructWithOmitempty": {
base: map[string]interface{}{"A": 1, "B": 2, "S": "test"},
input: testStructWithOmitempty{A: 0, B: 3, S: "test"},
diff: map[string]interface{}{"B": 3, "C": 0},
hasDiff: true,
},
"Map2StructPtr": {
base: map[string]interface{}{"A": 1, "B": 2, "S": "test"},
input: &testStruct{A: 1, B: 3, S: "test"},
diff: map[string]interface{}{"B": 3, "C": 0},
hasDiff: true,
},
"Map2StructWithTagPtr": {
base: map[string]interface{}{"A": 1, "B": 2, "S": "test"},
input: &testStructWithTag{AX: 1, BX: 3, SX: "test"},
diff: map[string]interface{}{"B": 3, "C": 0},
hasDiff: true,
},
"Map2NestedStruct_TypeChange": {
base: map[string]interface{}{"A": 1, "B": 2, "S": 3},
input: testStructNested{A: 1, B: 3, S: testSubStruct{S1: 2}},
Expand All @@ -100,11 +136,21 @@ func TestStateDiff(t *testing.T) {
input: testStruct{A: 1, B: 2, S: "test"},
hasDiff: false,
},
"Map2StructWithTag_Equal": {
base: map[string]interface{}{"B": 2, "A": 1, "C": 0, "S": "test"},
input: testStructWithTag{AX: 1, BX: 2, SX: "test"},
hasDiff: false,
},
"Map2StructPtr_Equal": {
base: map[string]interface{}{"B": 2, "A": 1, "C": 0, "S": "test"},
input: &testStruct{A: 1, B: 2, S: "test"},
hasDiff: false,
},
"Map2StructWithTagPtr_Equal": {
base: map[string]interface{}{"B": 2, "A": 1, "C": 0, "S": "test"},
input: &testStructWithTag{AX: 1, BX: 2, SX: "test"},
hasDiff: false,
},
"Nil2Map": {
base: nil,
input: map[string]interface{}{"A": 1, "B": 3, "S": "test"},
Expand Down Expand Up @@ -208,11 +254,21 @@ func TestAttributeKeys(t *testing.T) {
keys: []string{"A", "B", "C", "S"},
hasChild: true,
},
"StructWithTag": {
input: testStructWithTag{},
keys: []string{"A", "B", "C", "S"},
hasChild: true,
},
"StructPtr": {
input: &testStruct{},
keys: []string{"A", "B", "C", "S"},
hasChild: true,
},
"StructWithTagPtr": {
input: &testStructWithTag{},
keys: []string{"A", "B", "C", "S"},
hasChild: true,
},
"NestedStruct": {
input: testStructNested{},
keys: []string{"A", "B", "C", "S"},
Expand Down Expand Up @@ -244,7 +300,7 @@ func TestAttributeKeys(t *testing.T) {
}
}

func TestAttributeByKey(t *testing.T) {
func TestAttributeMatcher(t *testing.T) {
testCases := map[string]struct {
input interface{}
keyValue map[string]interface{}
Expand All @@ -258,10 +314,18 @@ func TestAttributeByKey(t *testing.T) {
input: testStruct{A: 2, S: "test"},
keyValue: map[string]interface{}{"A": 2, "B": 0, "C": 0, "S": "test"},
},
"StructWithTag": {
input: testStructWithTag{AX: 2, SX: "test"},
keyValue: map[string]interface{}{"A": 2, "B": 0, "C": 0, "S": "test"},
},
"StructPtr": {
input: &testStruct{A: 2, S: "test"},
keyValue: map[string]interface{}{"A": 2, "B": 0, "C": 0, "S": "test"},
},
"StructWithTagPtr": {
input: &testStructWithTag{AX: 2, SX: "test"},
keyValue: map[string]interface{}{"A": 2, "B": 0, "C": 0, "S": "test"},
},
"NestedStruct": {
input: testStructNested{A: 3, S: testSubStruct{S2: 4}},
keyValue: map[string]interface{}{"A": 3, "S": testSubStruct{S2: 4}},
Expand All @@ -271,7 +335,11 @@ func TestAttributeByKey(t *testing.T) {
tt := tt
t.Run(name, func(t *testing.T) {
for k, v := range tt.keyValue {
a, err := attributeByKey(tt.input, k)
matcher, err := newAttributeMatcher(reflect.ValueOf(tt.input))
if err != nil {
t.Fatal(err)
}
a, _, err := matcher.byKey(k)
if tt.err != nil {
if !errors.Is(tt.err, err) {
t.Errorf("Expected error: '%v', got: '%v'", tt.err, err)
Expand Down

0 comments on commit 39e8081

Please sign in to comment.