diff --git a/shadow/diff.go b/shadow/diff.go index 69bd0a50..635af682 100644 --- a/shadow/diff.go +++ b/shadow/diff.go @@ -17,6 +17,7 @@ package shadow import ( "errors" "reflect" + "strings" ) var errInvalidAttribute = errors.New("invalid attribute key") @@ -87,17 +88,30 @@ func stateDiff(base, in interface{}) (interface{}, bool, error) { for _, k := range keysIn { keysInMap[k] = struct{}{} } + baseMatcher, err := newAttributeMatcher(reflect.ValueOf(base)) + if err != nil { + return nil, false, err + } + inMatcher, err := newAttributeMatcher(reflect.ValueOf(in)) + if err != nil { + return nil, false, err + } out := make(map[string]interface{}) for _, k := range keys { if _, ok := keysInMap[k]; !ok { continue } delete(keysInMap, k) - a, _ := attributeByKey(base, k) - b, err := attributeByKey(in, k) + a, _, _ := baseMatcher.byKey(k) + b, bInfo, err := inMatcher.byKey(k) if err != nil { return nil, false, err } + + if bInfo.omitempty && bInfo.val.IsZero() { + continue + } + d, difer, err := stateDiff(a, b) if err != nil { return nil, false, err @@ -107,7 +121,10 @@ func stateDiff(base, in interface{}) (interface{}, bool, error) { } } for k := range keysInMap { - b, _ := attributeByKey(in, k) + b, bInfo, _ := inMatcher.byKey(k) + if bInfo.omitempty && bInfo.val.IsZero() { + continue + } out[k] = b } if len(out) == 0 { @@ -134,9 +151,18 @@ func attributeKeys(a interface{}) ([]string, bool, error) { } return out, true, nil case reflect.Struct: - out := make([]string, v.NumField()) - for i := range out { - out[i] = t.Field(i).Name + n := v.NumField() + out := make([]string, 0, n) + for i := 0; i < n; i++ { + jsonName, _, ok := jsonFieldInfo(t.Field(i).Tag) + if !ok { + continue + } + if jsonName == "" { + out = append(out, t.Field(i).Name) + } else { + out = append(out, jsonName) + } } return out, true, nil case reflect.Ptr: @@ -145,30 +171,82 @@ func attributeKeys(a interface{}) ([]string, bool, error) { return nil, false, nil } -func attributeByKey(a interface{}, k string) (interface{}, error) { - ret, err := attributeByKeyImpl(reflect.ValueOf(a), k) - if err != nil { - return nil, err - } - return ret.Interface(), nil +type attributeMatcher struct { + byName map[string]attributeInfo } -func attributeByKeyImpl(v reflect.Value, k string) (reflect.Value, error) { - if !v.IsValid() { - return reflect.Value{}, errInvalidAttribute +type attributeInfo struct { + val reflect.Value + omitempty bool +} + +func newAttributeMatcher(val reflect.Value) (*attributeMatcher, error) { + if !val.IsValid() { + return nil, errInvalidAttribute } - t := v.Type() + t := val.Type() switch t.Kind() { case reflect.Struct: - return v.FieldByName(k), nil + a := &attributeMatcher{byName: make(map[string]attributeInfo)} + n := t.NumField() + for i := 0; i < n; i++ { + jsonName, omitempty, ok := jsonFieldInfo(t.Field(i).Tag) + if !ok { + continue + } + var name string + if jsonName == "" { + name = t.Field(i).Name + } else { + name = jsonName + } + a.byName[name] = attributeInfo{ + val: val.Field(i), + omitempty: omitempty, + } + } + return a, nil case reflect.Map: - val := v.MapIndex(reflect.ValueOf(k)) - if !val.IsValid() { - return reflect.Value{}, nil + a := &attributeMatcher{byName: make(map[string]attributeInfo)} + for _, key := range val.MapKeys() { + name := key.String() + a.byName[name] = attributeInfo{val: val.MapIndex(key)} } - return val, nil + return a, nil case reflect.Ptr, reflect.Interface: - return attributeByKeyImpl(v.Elem(), k) + return newAttributeMatcher(val.Elem()) + } + return nil, errInvalidAttribute +} + +func (a *attributeMatcher) byKey(k string) (interface{}, attributeInfo, error) { + val, ok := a.byName[k] + if !ok { + return reflect.Value{}, attributeInfo{}, errInvalidAttribute + } + return val.val.Interface(), val, nil +} + +func jsonFieldInfo(t reflect.StructTag) (string, bool, bool) { + tag, ok := t.Lookup("json") + if !ok { + // Use struct field name. + return "", false, true + } + if tag == "-" { + // Field is ignored. + return "", false, false + } + tags := strings.Split(tag, ",") + var omitempty bool + for _, tag := range tags { + if tag == "omitempty" { + omitempty = true + } + } + if tags[0] == "" { + // Use struct field name. + return "", omitempty, true } - return reflect.Value{}, errInvalidAttribute + return tags[0], omitempty, true } diff --git a/shadow/diff_test.go b/shadow/diff_test.go index ba7857d1..d10c9974 100644 --- a/shadow/diff_test.go +++ b/shadow/diff_test.go @@ -26,6 +26,18 @@ type testStruct struct { A, B, C int S string } +type testStructWithTag struct { + AX int `json:"A,random_tag=aaa"` + BX int `json:"B"` + C int `json:",random_tag=bbb"` + SX string `json:"S"` + Garbage int `json:"-"` +} +type testStructWithOmitempty struct { + A int `json:",omitempty"` + B, C int + S string +} type testSubStruct struct { S1, S2 int } @@ -83,12 +95,36 @@ func TestStateDiff(t *testing.T) { diff: map[string]interface{}{"B": 3, "C": 0}, hasDiff: true, }, + "Map2StructWithTag": { + base: map[string]interface{}{"A": 1, "B": 2, "S": "test"}, + input: testStructWithTag{AX: 1, BX: 3, SX: "test"}, + diff: map[string]interface{}{"B": 3, "C": 0}, + hasDiff: true, + }, + "Map2StructWithoutOmitempty": { + base: map[string]interface{}{"A": 1, "B": 2, "S": "test"}, + input: testStruct{A: 0, B: 3, S: "test"}, + diff: map[string]interface{}{"A": 0, "B": 3, "C": 0}, + hasDiff: true, + }, + "Map2StructWithOmitempty": { + base: map[string]interface{}{"A": 1, "B": 2, "S": "test"}, + input: testStructWithOmitempty{A: 0, B: 3, S: "test"}, + diff: map[string]interface{}{"B": 3, "C": 0}, + hasDiff: true, + }, "Map2StructPtr": { base: map[string]interface{}{"A": 1, "B": 2, "S": "test"}, input: &testStruct{A: 1, B: 3, S: "test"}, diff: map[string]interface{}{"B": 3, "C": 0}, hasDiff: true, }, + "Map2StructWithTagPtr": { + base: map[string]interface{}{"A": 1, "B": 2, "S": "test"}, + input: &testStructWithTag{AX: 1, BX: 3, SX: "test"}, + diff: map[string]interface{}{"B": 3, "C": 0}, + hasDiff: true, + }, "Map2NestedStruct_TypeChange": { base: map[string]interface{}{"A": 1, "B": 2, "S": 3}, input: testStructNested{A: 1, B: 3, S: testSubStruct{S1: 2}}, @@ -100,11 +136,21 @@ func TestStateDiff(t *testing.T) { input: testStruct{A: 1, B: 2, S: "test"}, hasDiff: false, }, + "Map2StructWithTag_Equal": { + base: map[string]interface{}{"B": 2, "A": 1, "C": 0, "S": "test"}, + input: testStructWithTag{AX: 1, BX: 2, SX: "test"}, + hasDiff: false, + }, "Map2StructPtr_Equal": { base: map[string]interface{}{"B": 2, "A": 1, "C": 0, "S": "test"}, input: &testStruct{A: 1, B: 2, S: "test"}, hasDiff: false, }, + "Map2StructWithTagPtr_Equal": { + base: map[string]interface{}{"B": 2, "A": 1, "C": 0, "S": "test"}, + input: &testStructWithTag{AX: 1, BX: 2, SX: "test"}, + hasDiff: false, + }, "Nil2Map": { base: nil, input: map[string]interface{}{"A": 1, "B": 3, "S": "test"}, @@ -208,11 +254,21 @@ func TestAttributeKeys(t *testing.T) { keys: []string{"A", "B", "C", "S"}, hasChild: true, }, + "StructWithTag": { + input: testStructWithTag{}, + keys: []string{"A", "B", "C", "S"}, + hasChild: true, + }, "StructPtr": { input: &testStruct{}, keys: []string{"A", "B", "C", "S"}, hasChild: true, }, + "StructWithTagPtr": { + input: &testStructWithTag{}, + keys: []string{"A", "B", "C", "S"}, + hasChild: true, + }, "NestedStruct": { input: testStructNested{}, keys: []string{"A", "B", "C", "S"}, @@ -244,7 +300,7 @@ func TestAttributeKeys(t *testing.T) { } } -func TestAttributeByKey(t *testing.T) { +func TestAttributeMatcher(t *testing.T) { testCases := map[string]struct { input interface{} keyValue map[string]interface{} @@ -258,10 +314,18 @@ func TestAttributeByKey(t *testing.T) { input: testStruct{A: 2, S: "test"}, keyValue: map[string]interface{}{"A": 2, "B": 0, "C": 0, "S": "test"}, }, + "StructWithTag": { + input: testStructWithTag{AX: 2, SX: "test"}, + keyValue: map[string]interface{}{"A": 2, "B": 0, "C": 0, "S": "test"}, + }, "StructPtr": { input: &testStruct{A: 2, S: "test"}, keyValue: map[string]interface{}{"A": 2, "B": 0, "C": 0, "S": "test"}, }, + "StructWithTagPtr": { + input: &testStructWithTag{AX: 2, SX: "test"}, + keyValue: map[string]interface{}{"A": 2, "B": 0, "C": 0, "S": "test"}, + }, "NestedStruct": { input: testStructNested{A: 3, S: testSubStruct{S2: 4}}, keyValue: map[string]interface{}{"A": 3, "S": testSubStruct{S2: 4}}, @@ -271,7 +335,11 @@ func TestAttributeByKey(t *testing.T) { tt := tt t.Run(name, func(t *testing.T) { for k, v := range tt.keyValue { - a, err := attributeByKey(tt.input, k) + matcher, err := newAttributeMatcher(reflect.ValueOf(tt.input)) + if err != nil { + t.Fatal(err) + } + a, _, err := matcher.byKey(k) if tt.err != nil { if !errors.Is(tt.err, err) { t.Errorf("Expected error: '%v', got: '%v'", tt.err, err)