diff --git a/attributes.go b/attributes.go new file mode 100644 index 0000000..f010f7c --- /dev/null +++ b/attributes.go @@ -0,0 +1,161 @@ +package jsonapi + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" + "time" +) + +// NOTE: reciever for MarshalJSON() should not be a pointer +// https://play.golang.org/p/Cf9yYLIzJA (MarshalJSON() w/ pointer reciever) +// https://play.golang.org/p/5EsItAtgXy (MarshalJSON() w/o pointer reciever) + +const iso8601Layout = "2006-01-02T15:04:05Z07:00" + +var ( + jsonUnmarshaler = reflect.TypeOf(new(json.Unmarshaler)).Elem() +) + +// iso8601Datetime represents a ISO8601 formatted datetime +// It is a time.Time instance that marshals and unmarshals to the ISO8601 ref +type iso8601Datetime struct { + time.Time +} + +// MarshalJSON implements the json.Marshaler interface. +func (t iso8601Datetime) MarshalJSON() ([]byte, error) { + s := t.Time.Format(iso8601Layout) + return json.Marshal(s) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (t *iso8601Datetime) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + // Fractional seconds are handled implicitly by Parse. + var err error + if t.Time, err = time.Parse(strconv.Quote(iso8601Layout), string(data)); err != nil { + return ErrInvalidISO8601 + } + return err +} + +// iso8601Datetime.String() - override default String() on time +func (t iso8601Datetime) String() string { + return t.Format(iso8601Layout) +} + +// unix(Unix Seconds) marshals/unmarshals the number of milliseconds elapsed since January 1, 1970 UTC +type unix struct { + time.Time +} + +// MarshalJSON implements the json.Marshaler interface. +func (t unix) MarshalJSON() ([]byte, error) { + return json.Marshal(t.Unix()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (t *unix) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + s := string(data) + if s == "null" { + return nil + } + + v, err := stringToInt64(s) + if err != nil { + // return this specific error to maintain existing tests. + // TODO: consider refactoring tests to not assert against error string + return ErrInvalidTime + } + + t.Time = time.Unix(v, 0).In(time.UTC) + + return nil +} + +// unixMilli (Unix Millisecond) marshals/unmarshals the number of milliseconds elapsed since January 1, 1970 UTC +type unixMilli struct { + time.Time +} + +// MarshalJSON implements the json.Marshaler interface. +func (t unixMilli) MarshalJSON() ([]byte, error) { + return json.Marshal(t.UnixNano() / int64(time.Millisecond)) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (t *unixMilli) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + s := string(data) + if s == "null" { + return nil + } + + v, err := stringToInt64(s) + if err != nil { + return err + } + + t.Time = time.Unix(v/1000, (v % 1000 * int64(time.Millisecond))).In(time.UTC) + + return nil +} + +// stringToInt64 convert time in either decimal or exponential notation to int64 +// https://golang.org/doc/go1.8#encoding_json +// go1.8 prefers decimal notation +// go1.7 may use exponetial notation, so check if it came in as a float +func stringToInt64(s string) (int64, error) { + var v int64 + if strings.Contains(s, ".") { + fv, err := strconv.ParseFloat(s, 64) + if err != nil { + return v, err + } + v = int64(fv) + } else { + iv, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return v, err + } + v = iv + } + return v, nil +} + +func implementsJSONUnmarshaler(t reflect.Type) bool { + ok, _ := deepCheckImplementation(t, jsonUnmarshaler) + return ok +} + +func deepCheckImplementation(t, interfaceType reflect.Type) (bool, reflect.Type) { + // check as-is + if t.Implements(interfaceType) { + return true, t + } + + switch t.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: + // check ptr implementation + ptrType := reflect.PtrTo(t) + if ptrType.Implements(interfaceType) { + return true, ptrType + } + // since these are reference types, re-check on the element of t + return deepCheckImplementation(t.Elem(), interfaceType) + default: + // check ptr implementation + ptrType := reflect.PtrTo(t) + if ptrType.Implements(interfaceType) { + return true, ptrType + } + // nothing else to check, return false + return false, nil + } +} diff --git a/attributes_test.go b/attributes_test.go new file mode 100644 index 0000000..04ee7f4 --- /dev/null +++ b/attributes_test.go @@ -0,0 +1,339 @@ +package jsonapi + +import ( + "encoding/json" + "reflect" + "strconv" + "testing" + "time" +) + +func TestIso8601Datetime(t *testing.T) { + pacific, err := time.LoadLocation("America/Los_Angeles") + if err != nil { + t.Fatal(err) + } + + type test struct { + stringVal string + dtVal iso8601Datetime + } + + tests := []*test{ + &test{ + stringVal: strconv.Quote("2017-04-06T13:00:00-07:00"), + dtVal: iso8601Datetime{Time: time.Date(2017, time.April, 6, 13, 0, 0, 0, pacific)}, + }, + &test{ + stringVal: strconv.Quote("2007-05-06T13:00:00-07:00"), + dtVal: iso8601Datetime{Time: time.Date(2007, time.May, 6, 13, 0, 0, 0, pacific)}, + }, + &test{ + stringVal: strconv.Quote("2016-12-08T15:18:54Z"), + dtVal: iso8601Datetime{Time: time.Date(2016, time.December, 8, 15, 18, 54, 0, time.UTC)}, + }, + } + + for _, test := range tests { + // unmarshal stringVal by calling UnmarshalJSON() + dt := &iso8601Datetime{} + if err := dt.UnmarshalJSON([]byte(test.stringVal)); err != nil { + t.Fatal(err) + } + + // compare unmarshaled stringVal to dtVal + if !dt.Time.Equal(test.dtVal.Time) { + t.Errorf("\n\tE=%+v\n\tA=%+v", test.dtVal.UnixNano(), dt.UnixNano()) + } + + // marshal dtVal by calling MarshalJSON() + b, err := test.dtVal.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + // compare marshaled dtVal to stringVal + if test.stringVal != string(b) { + t.Errorf("\n\tE=%+v\n\tA=%+v", test.stringVal, string(b)) + } + } +} + +func TestUnixMilliVariations(t *testing.T) { + control := unixMilli{ + Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), + } + + { + var val map[string]unixMilli + t.Logf("\nval: %#v\n", val) + + payload := []byte(`{"foo": 1257894000000, "bar":1257894000000}`) + json.Unmarshal(payload, &val) + + if !val["foo"].Time.Equal(control.Time) { + t.Errorf("\n\tE=%+v\n\tA=%+v", control.Time, val["foo"].Time) + } + + b, _ := json.Marshal(val) + is, err := isJSONEqual(b, payload) + if err != nil { + t.Fatal(err) + } + if !is { + t.Errorf("\n\tE=%s\n\tA=%s", payload, b) + } + } + { + var val map[string]*unixMilli + t.Logf("\nval: %#v\n", val) + + payload := []byte(`{"foo": 1257894000000, "bar":1257894000000}`) + json.Unmarshal(payload, &val) + + if !val["foo"].Time.Equal(control.Time) { + t.Errorf("\n\tE=%+v\n\tA=%+v", control.Time, val["foo"].Time) + } + + b, _ := json.Marshal(val) + is, err := isJSONEqual(b, payload) + if err != nil { + t.Fatal(err) + } + if !is { + t.Errorf("\n\tE=%s\n\tA=%s", payload, b) + } + } + { + var val []*unixMilli + t.Logf("\nval: %#v\n", val) + + payload := []byte(`[1257894000000,1257894000000]`) + json.Unmarshal(payload, &val) + + if !val[0].Time.Equal(control.Time) { + t.Errorf("\n\tE=%+v\n\tA=%+v", control.Time, val[0].Time) + } + + b, _ := json.Marshal(val) + is, err := isJSONEqual(b, payload) + if err != nil { + t.Fatal(err) + } + if !is { + t.Errorf("\n\tE=%s\n\tA=%s", payload, b) + } + } + { + var val []unixMilli + t.Logf("\nval: %#v\n", val) + + payload := []byte(`[1257894000000,1257894000000]`) + json.Unmarshal(payload, &val) + + if !val[0].Time.Equal(control.Time) { + t.Errorf("\n\tE=%+v\n\tA=%+v", control.Time, val[0].Time) + } + + b, _ := json.Marshal(val) + is, err := isJSONEqual(b, payload) + if err != nil { + t.Fatal(err) + } + if !is { + t.Errorf("\n\tE=%s\n\tA=%s", payload, b) + } + } + { + var val unixMilli + t.Logf("\nval: %#v\n", val) + + payload := []byte(`1257894000000`) + json.Unmarshal(payload, &val) + + if !val.Time.Equal(control.Time) { + t.Errorf("\n\tE=%+v\n\tA=%+v", control.Time, val.Time) + } + + b, _ := json.Marshal(val) + is, err := isJSONEqual(b, payload) + if err != nil { + t.Fatal(err) + } + if !is { + t.Errorf("\n\tE=%s\n\tA=%s", payload, b) + } + } + { + var val *unixMilli + t.Logf("\nval: %#v\n", val) + + payload := []byte(`1257894000000`) + json.Unmarshal(payload, &val) + + if !val.Time.Equal(control.Time) { + t.Errorf("\n\tE=%+v\n\tA=%+v", control.Time, val.Time) + } + + b, _ := json.Marshal(val) + is, err := isJSONEqual(b, payload) + if err != nil { + t.Fatal(err) + } + if !is { + t.Errorf("\n\tE=%s\n\tA=%s", payload, b) + } + } +} +func TestUnixMilli(t *testing.T) { + type test struct { + stringVal string + dtVal unixMilli + } + + tests := []*test{ + &test{ + stringVal: "1257894000000", + dtVal: unixMilli{Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)}, + }, + &test{ + stringVal: "1257894000999", + dtVal: unixMilli{Time: time.Date(2009, time.November, 10, 23, 0, 0, 999000000, time.UTC)}, + }, + } + + for _, test := range tests { + // unmarshal stringVal by calling UnmarshalJSON() + dt := &unixMilli{} + if err := dt.UnmarshalJSON([]byte(test.stringVal)); err != nil { + t.Fatal(err) + } + + // compare unmarshaled stringVal to dtVal + if !dt.Time.Equal(test.dtVal.Time) { + t.Errorf("\n\tE=%+v\n\tA=%+v", test.dtVal.UnixNano(), dt.UnixNano()) + } + + // marshal dtVal by calling MarshalJSON() + b, err := test.dtVal.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + // compare marshaled dtVal to stringVal + if test.stringVal != string(b) { + t.Errorf("\n\tE=%+v\n\tA=%+v", test.stringVal, string(b)) + } + } +} + +func TestImplementsJSONUnmarshaler(t *testing.T) { + { // positive + raw := json.RawMessage{} + typ := reflect.TypeOf(&raw) + if ok := implementsJSONUnmarshaler(typ); !ok { + t.Error("expected json.RawMessage to implement json.Unmarshaler") + } + } + { // positive + isoDateTime := iso8601Datetime{} + typ := reflect.TypeOf(&isoDateTime) + if ok := implementsJSONUnmarshaler(typ); !ok { + t.Error("expected ISO8601Datetime to implement json.Unmarshaler") + } + } + { // negative + type customString string + input := customString("foo") + typ := reflect.TypeOf(&input) + if ok := implementsJSONUnmarshaler(typ); ok { + t.Error("got true; expected customString to not implement json.Unmarshaler") + } + } +} + +func TestDeepCheckImplementation(t *testing.T) { + tests := []struct { + name string + input interface{} + }{ + { + name: "concrete ( RawMessage is a reflect.Type of slice)", + input: json.RawMessage{}, + }, + { + name: "RawMessage ptr", + input: &json.RawMessage{}, + }, + { + name: "concrete slice of RawMessage", + input: []json.RawMessage{}, + }, + { + name: "slice of RawMessage ptrs", + input: []*json.RawMessage{}, + }, + { + name: "concrete map of RawMessage", + input: map[string]json.RawMessage{}, + }, + { + name: "map of RawMessage ptrs", + input: map[string]*json.RawMessage{}, + }, + { + name: "map of RawMessage slice", + input: map[string][]json.RawMessage{}, + }, + { + name: "ptr ptr of RawMessage", + input: func() **json.RawMessage { + r := &json.RawMessage{} + return &r + }(), + }, + { + name: "concrete unixMilli (struct)", + input: unixMilli{}, + }, + { + name: "unixMilli ptr", + input: &unixMilli{}, + }, + { + name: "concrete slice of unixMilli", + input: []unixMilli{}, + }, + { + name: "slice of unixMilli ptrs", + input: []*unixMilli{}, + }, + { + name: "concrete map of unixMilli", + input: map[string]unixMilli{}, + }, + { + name: "map of unixMilli ptrs", + input: map[string]*unixMilli{}, + }, + { + name: "map of unixMilli slice", + input: map[string][]unixMilli{}, + }, + { + name: "ptr ptr of unixMilli", + input: func() **unixMilli { + r := &unixMilli{} + return &r + }(), + }, + } + + for _, scenario := range tests { + typ := reflect.TypeOf(scenario.input) + ok, elemType := deepCheckImplementation(typ, jsonUnmarshaler) + if !ok { + t.Errorf("\n\tE=%v\n\tA=%v", typ, elemType) + } + } +} diff --git a/request.go b/request.go index 9e0eb1a..86ac9ef 100644 --- a/request.go +++ b/request.go @@ -32,6 +32,9 @@ var ( ErrUnsupportedPtrType = errors.New("Pointer type in struct is not supported") // ErrInvalidType is returned when the given type is incompatible with the expected type. ErrInvalidType = errors.New("Invalid type provided") // I wish we used punctuation. + + ptrTimeType = reflect.TypeOf(new(time.Time)) + timeType = ptrTimeType.Elem() ) // UnmarshalPayload converts an io into a struct instance using jsonapi tags on @@ -138,9 +141,9 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) embeddeds := []*embedded{} for i := 0; i < modelValue.NumField(); i++ { - fieldType := modelType.Field(i) + structField := modelType.Field(i) fieldValue := modelValue.Field(i) - tag := fieldType.Tag.Get(annotationJSONAPI) + tag := structField.Tag.Get(annotationJSONAPI) // handle explicit ignore annotation if shouldIgnoreField(tag) { @@ -148,7 +151,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) } // handles embedded structs - if isEmbeddedStruct(fieldType) { + if isEmbeddedStruct(structField) { embeddeds = append(embeddeds, &embedded{ model: reflect.ValueOf(fieldValue.Addr().Interface()), @@ -159,7 +162,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) } // handles pointers to embedded structs - if isEmbeddedStructPtr(fieldType) { + if isEmbeddedStructPtr(structField) { embeddeds = append(embeddeds, &embedded{ model: reflect.ValueOf(fieldValue.Interface()), @@ -187,11 +190,11 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) return err } case annotationPrimary: - if err := handlePrimaryUnmarshal(data, args, fieldType, fieldValue); err != nil { + if err := handlePrimaryUnmarshal(data, args, structField, fieldValue); err != nil { return err } case annotationAttribute: - if err := handleAttributeUnmarshal(data, args, fieldType, fieldValue); err != nil { + if err := handleAttributeUnmarshal(data, args, structField, fieldValue); err != nil { return err } case annotationRelation: @@ -218,10 +221,12 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) assign(em.structField, tmp) data = copy } - return nil + } else { + // handle non-nil scenarios + if err := unmarshalNode(data, em.model, included); err != nil { + return err + } } - // handle non-nil scenarios - return unmarshalNode(data, em.model, included) } return nil @@ -269,65 +274,34 @@ func handlePrimaryUnmarshal(data *Node, args []string, fieldType reflect.StructF kind = fieldType.Type.Kind() } - var idValue reflect.Value + switch kind { + default: + // only handle strings and numerics + return ErrBadJSONAPIID + case reflect.String: + assign(fieldValue, reflect.ValueOf(data.ID)) + case + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - // Handle String case - if kind == reflect.String { - // ID will have to be transmitted as a string per the JSON API spec - idValue = reflect.ValueOf(data.ID) - } else { - // Value was not a string... only other supported type was a numeric, - // which would have been sent as a float value. - floatValue, err := strconv.ParseFloat(data.ID, 64) + fv, err := strconv.ParseFloat(data.ID, 64) if err != nil { - // Could not convert the value in the "id" attr to a float return ErrBadJSONAPIID } - // Convert the numeric float to one of the supported ID numeric types - // (int[8,16,32,64] or uint[8,16,32,64]) - switch kind { - case reflect.Int: - n := int(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Int8: - n := int8(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Int16: - n := int16(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Int32: - n := int32(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Int64: - n := int64(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Uint: - n := uint(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Uint8: - n := uint8(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Uint16: - n := uint16(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Uint32: - n := uint32(floatValue) - idValue = reflect.ValueOf(&n) - case reflect.Uint64: - n := uint64(floatValue) - idValue = reflect.ValueOf(&n) - default: - // We had a JSON float (numeric), but our field was not one of the - // allowed numeric types - return ErrBadJSONAPIID + b, err := json.Marshal(fv) + if err != nil { + return err + } + + v := fieldValue.Addr().Interface() + if err := json.Unmarshal(b, v); err != nil { + return err } } - // set value and clear ID to denote it's already been processed - assign(fieldValue, idValue) + // clear ID to denote it's already been processed data.ID = "" - return nil } @@ -418,7 +392,7 @@ func handleToManyRelationUnmarshal(relationData interface{}, fieldType reflect.T return &models, nil } -// TODO: break this out into smaller funcs +// handleAttributeUnmarshal func handleAttributeUnmarshal(data *Node, args []string, fieldType reflect.StructField, fieldValue reflect.Value) error { if len(args) < 2 { return ErrBadJSONAPIStructTag @@ -427,17 +401,6 @@ func handleAttributeUnmarshal(data *Node, args []string, fieldType reflect.Struc if attributes == nil || len(data.Attributes) == 0 { return nil } - - var iso8601 bool - - if len(args) > 2 { - for _, arg := range args[2:] { - if arg == annotationISO8601 { - iso8601 = true - } - } - } - val := attributes[args[1]] // continue if the attribute was not included in the request @@ -445,213 +408,128 @@ func handleAttributeUnmarshal(data *Node, args []string, fieldType reflect.Struc return nil } - v := reflect.ValueOf(val) - - // Handle field of type time.Time - if fieldValue.Type() == reflect.TypeOf(time.Time{}) { - if iso8601 { - var tm string - if v.Kind() == reflect.String { - tm = v.Interface().(string) - } else { - return ErrInvalidISO8601 - } - - t, err := time.Parse(iso8601TimeFormat, tm) - if err != nil { - return ErrInvalidISO8601 - } - - fieldValue.Set(reflect.ValueOf(t)) - delete(data.Attributes, args[1]) - return nil - } - - var at int64 + // custom handling of time + if isTimeValue(fieldValue) { + return handleTimeAttributes(data, args, fieldValue, fieldType) + } - if v.Kind() == reflect.Float64 { - at = int64(v.Interface().(float64)) - } else if v.Kind() == reflect.Int { - at = v.Int() - } else { - return ErrInvalidTime - } + // standard attributes that the json package knows how to handle, plus implementions on json.Unmarshaler + return handleWithJSONMarshaler(data, args, fieldValue) +} - t := time.Unix(at, 0) +func fullNode(n *Node, included *map[string]*Node) *Node { + includedKey := fmt.Sprintf("%s,%s", n.Type, n.ID) - fieldValue.Set(reflect.ValueOf(t)) - delete(data.Attributes, args[1]) - return nil + if included != nil && (*included)[includedKey] != nil { + return deepCopyNode((*included)[includedKey]) } - if fieldValue.Type() == reflect.TypeOf([]string{}) { - values := make([]string, v.Len()) - for i := 0; i < v.Len(); i++ { - values[i] = v.Index(i).Interface().(string) - } + return deepCopyNode(n) +} - fieldValue.Set(reflect.ValueOf(values)) - delete(data.Attributes, args[1]) - return nil +// assign will take the value specified and assign it to the field; if +// field is expecting a ptr assign will assign a ptr. +func assign(field, value reflect.Value) { + if field.Kind() == reflect.Ptr { + field.Set(value) + } else { + field.Set(reflect.Indirect(value)) } +} - if fieldValue.Type() == reflect.TypeOf(new(time.Time)) { - if iso8601 { - var tm string - if v.Kind() == reflect.String { - tm = v.Interface().(string) - } else { - return ErrInvalidISO8601 - - } - - v, err := time.Parse(iso8601TimeFormat, tm) - if err != nil { - return ErrInvalidISO8601 - } - - t := &v - - fieldValue.Set(reflect.ValueOf(t)) - delete(data.Attributes, args[1]) - return nil - } - - var at int64 - - if v.Kind() == reflect.Float64 { - at = int64(v.Interface().(float64)) - } else if v.Kind() == reflect.Int { - at = v.Int() - } else { - return ErrInvalidTime - } - - v := time.Unix(at, 0) - t := &v - - fieldValue.Set(reflect.ValueOf(t)) - delete(data.Attributes, args[1]) - return nil +// handleTimeAttributes - handle field of type time.Time and *time.Time +// TODO: consider refactoring/removing this toggling (would be a breaking change) +// standard time.Time implements RFC3339 (https://golang.org/pkg/time/#Time.UnmarshalJSON) but is overridden here. +// jsonapi doesn't specify, but does recommends ISO8601 (http://jsonapi.org/recommendations/#date-and-time-fields) +// IMHO (skimata): just default on recommended ISO8601, all others desired formats +// should implement w/ a custom marshaler/unmarshaler +func handleTimeAttributes(data *Node, args []string, fieldValue reflect.Value, structField reflect.StructField) error { + b, err := json.Marshal(data.Attributes[args[1]]) + if err != nil { + return err } - - // JSON value was a float (numeric) - if v.Kind() == reflect.Float64 { - floatValue := v.Interface().(float64) - - // The field may or may not be a pointer to a numeric; the kind var - // will not contain a pointer type - var kind reflect.Kind - if fieldValue.Kind() == reflect.Ptr { - kind = fieldType.Type.Elem().Kind() - } else { - kind = fieldType.Type.Kind() + var tm time.Time + if useISO8601(args) { + iso := &iso8601Datetime{} + if err := iso.UnmarshalJSON(b); err != nil { + return err } - - var numericValue reflect.Value - - switch kind { - case reflect.Int: - n := int(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Int8: - n := int8(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Int16: - n := int16(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Int32: - n := int32(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Int64: - n := int64(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Uint: - n := uint(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Uint8: - n := uint8(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Uint16: - n := uint16(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Uint32: - n := uint32(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Uint64: - n := uint64(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Float32: - n := float32(floatValue) - numericValue = reflect.ValueOf(&n) - case reflect.Float64: - n := floatValue - numericValue = reflect.ValueOf(&n) - default: - return ErrUnknownFieldNumberType + tm = iso.Time + } else { + epoch := &unix{} + if err := epoch.UnmarshalJSON(b); err != nil { + return err } + tm = epoch.Time + } - assign(fieldValue, numericValue) - delete(data.Attributes, args[1]) - return nil + if structField.Type.Kind() == reflect.Ptr { + fieldValue.Set(reflect.ValueOf(&tm)) + } else { + fieldValue.Set(reflect.ValueOf(tm)) } - // Field was a Pointer type - if fieldValue.Kind() == reflect.Ptr { - var concreteVal reflect.Value - - switch cVal := val.(type) { - case string: - concreteVal = reflect.ValueOf(&cVal) - case bool: - concreteVal = reflect.ValueOf(&cVal) - case complex64: - concreteVal = reflect.ValueOf(&cVal) - case complex128: - concreteVal = reflect.ValueOf(&cVal) - case uintptr: - concreteVal = reflect.ValueOf(&cVal) - default: - return ErrUnsupportedPtrType - } + delete(data.Attributes, args[1]) + return nil +} - if fieldValue.Type() != concreteVal.Type() { - return ErrUnsupportedPtrType - } +func handleWithJSONMarshaler(data *Node, args []string, fieldValue reflect.Value) error { + v := fieldValue.Addr().Interface() - fieldValue.Set(concreteVal) - delete(data.Attributes, args[1]) - return nil + b, err := json.Marshal(data.Attributes[args[1]]) + if err != nil { + return err } - // As a final catch-all, ensure types line up to avoid a runtime panic. - // Ignore interfaces since interfaces are poly - if fieldValue.Kind() != reflect.Interface && fieldValue.Kind() != v.Kind() { - return ErrInvalidType + if err := json.Unmarshal(b, v); err != nil { + return err } - // set val and clear attribute key so its not processed again - fieldValue.Set(reflect.ValueOf(val)) + // success; clear value delete(data.Attributes, args[1]) return nil } -func fullNode(n *Node, included *map[string]*Node) *Node { - includedKey := fmt.Sprintf("%s,%s", n.Type, n.ID) - - if included != nil && (*included)[includedKey] != nil { - return deepCopyNode((*included)[includedKey]) +func isTimeValue(fieldValue reflect.Value) bool { + switch fieldValue.Type() { + default: + return false + case timeType, ptrTimeType: + return true } - return deepCopyNode(n) } -// assign will take the value specified and assign it to the field; if -// field is expecting a ptr assign will assign a ptr. -func assign(field, value reflect.Value) { - if field.Kind() == reflect.Ptr { - field.Set(value) - } else { - field.Set(reflect.Indirect(value)) +func hasStandardJSONSupport(structField reflect.StructField) bool { + kind := structField.Type.Kind() + if kind == reflect.Ptr { + kind = structField.Type.Elem().Kind() + } + + switch kind { + default: + return false + case reflect.Map, reflect.Slice: + return hasStandardJSONSupport(reflect.StructField{Type: structField.Type.Elem()}) + case + reflect.Bool, + reflect.String, + reflect.Complex64, reflect.Complex128, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Uintptr, + reflect.Float32, reflect.Float64: + return true + } +} + +func useISO8601(args []string) bool { + if len(args) > 2 { + for _, arg := range args[2:] { + if arg == annotationISO8601 { + return true + } + } } + return false } diff --git a/request_test.go b/request_test.go index ea9a15f..06a0bc0 100644 --- a/request_test.go +++ b/request_test.go @@ -126,16 +126,18 @@ func TestUnmarshalToStructWithPointerAttr_BadType(t *testing.T) { in := map[string]interface{}{ "name": true, // This is the wrong type. } - expectedErrorMessage := ErrUnsupportedPtrType.Error() err := UnmarshalPayload(sampleWithPointerPayload(in), out) if err == nil { t.Fatalf("Expected error due to invalid type.") } - if err.Error() != expectedErrorMessage { - t.Fatalf("Unexpected error message: %s", err.Error()) + + jTypeErr, ok := err.(*json.UnmarshalTypeError) + if !ok { + t.Fatalf("Expected an unmarshal error, got %#v\n", err) } + t.Logf("successfully returned err when unmarshaling a %s, to a %s field\n", jTypeErr.Value, jTypeErr.Type.String()) } func TestStringPointerField(t *testing.T) { @@ -196,8 +198,36 @@ func TestUnmarshalInvalidJSON_BadType(t *testing.T) { BadValue interface{} Error error }{ // The `Field` values here correspond to the `ModelBadTypes` jsonapi fields. - {Field: "string_field", BadValue: 0, Error: ErrUnknownFieldNumberType}, // Expected string. - {Field: "float_field", BadValue: "A string.", Error: ErrInvalidType}, // Expected float64. + {Field: "string_field", BadValue: 0, Error: ErrUnknownFieldNumberType}, // Expected string. + {Field: "float_field", BadValue: "A string.", Error: ErrInvalidType}, // Expected float64. + } + for _, test := range badTypeTests { + t.Run(fmt.Sprintf("Test_%s", test.Field), func(t *testing.T) { + out := new(ModelBadTypes) + in := map[string]interface{}{} + in[test.Field] = test.BadValue + // should not compare err string + // expectedErrorMessage := test.Error.Error() + + err := UnmarshalPayload(samplePayloadWithBadTypes(in), out) + + if err == nil { + t.Fatalf("Expected error due to invalid type.") + } + jTypeErr, ok := err.(*json.UnmarshalTypeError) + if !ok { + t.Fatalf("Expected an unmarshal error, got %#v\n", err) + } + t.Logf("successfully returned err when unmarshaling a %s, to a %s field\n", jTypeErr.Value, jTypeErr.Type.String()) + }) + } +} +func TestUnmarshalInvalidJSON_BadType_Time(t *testing.T) { + var badTypeTests = []struct { + Field string + BadValue interface{} + Error error + }{ // The `Field` values here correspond to the `ModelBadTypes` jsonapi fields. {Field: "time_field", BadValue: "A string.", Error: ErrInvalidTime}, // Expected int64. {Field: "time_ptr_field", BadValue: "A string.", Error: ErrInvalidTime}, // Expected *time / int64. } @@ -213,6 +243,7 @@ func TestUnmarshalInvalidJSON_BadType(t *testing.T) { if err == nil { t.Fatalf("Expected error due to invalid type.") } + if err.Error() != expectedErrorMessage { t.Fatalf("Unexpected error message: %s", err.Error()) } @@ -332,7 +363,7 @@ func TestUnmarshalInvalidISO8601(t *testing.T) { out := new(Timestamp) if err := UnmarshalPayload(in, out); err != ErrInvalidISO8601 { - t.Fatalf("Expected ErrInvalidISO8601, got %v", err) + t.Fatalf("Expected %v, got %v", ErrInvalidISO8601, err) } } diff --git a/response.go b/response.go index 2e9acd7..1c582f9 100644 --- a/response.go +++ b/response.go @@ -202,8 +202,10 @@ func MarshalOnePayloadEmbedded(w io.Writer, model interface{}) error { return nil } -func visitModelNode(model interface{}, included *map[string]*Node, - sideload bool) (*Node, error) { +// visitModelNode converts models to jsonapi payloads +// it handles the deepest models first. (i.e. embedded models) +// this is so that upper-level attributes can overwrite lower-level attributes +func visitModelNode(model interface{}, included *map[string]*Node, sideload bool) (*Node, error) { node := new(Node) var er error @@ -211,12 +213,13 @@ func visitModelNode(model interface{}, included *map[string]*Node, modelValue := reflect.ValueOf(model).Elem() modelType := reflect.ValueOf(model).Type().Elem() + // handle just the embedded models first for i := 0; i < modelValue.NumField(); i++ { fieldValue := modelValue.Field(i) fieldType := modelType.Field(i) + // skip if annotated w/ ignore tag := fieldType.Tag.Get(annotationJSONAPI) - if shouldIgnoreField(tag) { continue } @@ -239,6 +242,22 @@ func visitModelNode(model interface{}, included *map[string]*Node, break } node.merge(embNode) + } + } + + // handle everthing else + for i := 0; i < modelValue.NumField(); i++ { + fieldValue := modelValue.Field(i) + fieldType := modelType.Field(i) + + tag := fieldType.Tag.Get(annotationJSONAPI) + + if shouldIgnoreField(tag) { + continue + } + + // skip embedded because it was handled in a previous loop + if isEmbeddedStruct(fieldType) || isEmbeddedStructPtr(fieldType) { continue } diff --git a/response_test.go b/response_test.go index f3a0d92..3cb4893 100644 --- a/response_test.go +++ b/response_test.go @@ -1236,12 +1236,17 @@ func TestMarshalUnmarshalCompositeStruct(t *testing.T) { { type Model struct { - Thing `jsonapi:"-"` - ModelID int `jsonapi:"primary,models"` - Foo string `jsonapi:"attr,foo"` - Bar string `jsonapi:"attr,bar"` - Bat string `jsonapi:"attr,bat"` - Buzz int `jsonapi:"attr,buzz"` + Thing `jsonapi:"-"` + ModelID int `jsonapi:"primary,models"` + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + Buzz int `jsonapi:"attr,buzz"` + CreateDate iso8601Datetime `jsonapi:"attr,create-date"` + } + + isoDate := iso8601Datetime{ + Time: time.Date(2016, time.December, 8, 15, 18, 54, 0, time.UTC), } scenarios = append(scenarios, test{ @@ -1252,19 +1257,21 @@ func TestMarshalUnmarshalCompositeStruct(t *testing.T) { Type: "models", ID: "1", Attributes: map[string]interface{}{ - "bar": "barry", - "bat": "batty", - "buzz": 99, - "foo": "fooey", + "bar": "barry", + "bat": "batty", + "buzz": 99, + "foo": "fooey", + "create-date": isoDate.String(), }, }, }, expected: &Model{ - ModelID: 1, - Foo: "fooey", - Bar: "barry", - Bat: "batty", - Buzz: 99, + ModelID: 1, + Foo: "fooey", + Bar: "barry", + Bat: "batty", + Buzz: 99, + CreateDate: isoDate, }, }) } @@ -1405,6 +1412,73 @@ func TestMarshalUnmarshalCompositeStruct(t *testing.T) { }, }) } + + { + type Model struct { + *Thing + ModelID int `jsonapi:"primary,models"` + Foo string `jsonapi:"attr,foo"` + Bar string `jsonapi:"attr,bar"` + Bat string `jsonapi:"attr,bat"` + FunTimes []unixMilli `jsonapi:"attr,fun-times"` + SadTimes []*unixMilli `jsonapi:"attr,sad-times"` + GoodTimes map[string]unixMilli `jsonapi:"attr,good-times"` + BadTimes map[string]*unixMilli `jsonapi:"attr,bad-times"` + CreateDate *unixMilli `jsonapi:"attr,create-date"` + UpdateDate unixMilli `jsonapi:"attr,update-date"` + } + + unixMs := unixMilli{ + Time: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC), + } + + scenarios = append(scenarios, test{ + name: "unixMilli in all supported variations", + dst: &Model{}, + payload: &OnePayload{ + Data: &Node{ + Type: "models", + ID: "1", + Attributes: map[string]interface{}{ + "bar": "barry", + "bat": "batty", + "foo": "fooey", + "fun-times": []int64{1257894000000, 1257894000000}, + "sad-times": []int64{1257894000000, 1257894000000}, + "bad-times": map[string]int64{ + "abc": 1257894000000, + "xyz": 1257894000000, + }, + "good-times": map[string]int64{ + "abc": 1257894000000, + "xyz": 1257894000000, + }, + "create-date": 1257894000000, + "update-date": 1257894000000, + }, + }, + }, + expected: &Model{ + ModelID: 1, + Foo: "fooey", + Bar: "barry", + Bat: "batty", + FunTimes: []unixMilli{unixMs, unixMs}, + SadTimes: []*unixMilli{&unixMs, &unixMs}, + GoodTimes: map[string]unixMilli{ + "abc": unixMs, + "xyz": unixMs, + }, + BadTimes: map[string]*unixMilli{ + "abc": &unixMs, + "xyz": &unixMs, + }, + CreateDate: &unixMs, + UpdateDate: unixMs, + }, + }) + } + for _, scenario := range scenarios { t.Logf("running scenario: %s\n", scenario.name) @@ -1426,7 +1500,7 @@ func TestMarshalUnmarshalCompositeStruct(t *testing.T) { t.Fatal(err) } if !isJSONEqual { - t.Errorf("Got\n%s\nExpected\n%s\n", buf.Bytes(), payload) + t.Errorf("Marshaling Got\n%s\nExpected\n%s\n", buf.Bytes(), payload) } // run jsonapi unmarshal @@ -1436,7 +1510,7 @@ func TestMarshalUnmarshalCompositeStruct(t *testing.T) { // assert decoded and expected models are equal if !reflect.DeepEqual(scenario.expected, scenario.dst) { - t.Errorf("Got\n%#v\nExpected\n%#v\n", scenario.dst, scenario.expected) + t.Errorf("Unmarshaling Got\n%#v\nExpected\n%#v\n", scenario.dst, scenario.expected) } } }