diff --git a/staticcheck/fakejson/encode.go b/staticcheck/fakejson/encode.go
new file mode 100644
index 000000000..f5e6c4010
--- /dev/null
+++ b/staticcheck/fakejson/encode.go
@@ -0,0 +1,373 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file contains a modified copy of the encoding/json encoder.
+// All dynamic behavior has been removed, and reflection has been replaced with go/types.
+// This allows us to statically find unmarshalable types
+// with the same rules for tags, shadowing and addressability as encoding/json.
+// This is used for SA1026.
+
+package fakejson
+
+import (
+	"go/token"
+	"go/types"
+	"sort"
+	"strings"
+	"unicode"
+
+	"honnef.co/go/tools/staticcheck/fakereflect"
+)
+
+// parseTag returns the name part of a struct field's json tag, that is,
+// everything before the first comma. Anything after the comma is dropped.
+func parseTag(tag string) string {
+	comma := strings.Index(tag, ",")
+	if comma < 0 {
+		return tag
+	}
+	return tag[:comma]
+}
+
+// Marshal statically checks the type v and returns the first unsupported
+// type found in it, or nil if there is none. Paths in the returned error
+// are rooted at "x".
+func Marshal(v types.Type) *UnsupportedTypeError {
+	root := fakereflect.TypeAndCanAddr{Type: v}
+	enc := encoder{seen: make(map[fakereflect.TypeAndCanAddr]struct{})}
+	return enc.newTypeEncoder(root, "x")
+}
+
+// An UnsupportedTypeError is returned by Marshal when attempting
+// to encode an unsupported value type.
+type UnsupportedTypeError struct {
+	// Type is the type that cannot be encoded.
+	Type types.Type
+	// Path is the access path that leads to Type. Map elements are
+	// written as "[k]", slice and array elements as "[0]", and struct
+	// fields as ".Name".
+	Path string
+}
+
+// newMarshalerInterface returns the completed interface type
+//
+//	interface{ method() ([]byte, error) }
+//
+// for the given method name.
+func newMarshalerInterface(method string) *types.Interface {
+	results := types.NewTuple(
+		types.NewVar(token.NoPos, nil, "", types.NewSlice(types.Typ[types.Byte])),
+		types.NewVar(token.NoPos, nil, "", types.Universe.Lookup("error").Type()),
+	)
+	sig := types.NewSignature(nil, types.NewTuple(), results, false)
+	fn := types.NewFunc(token.NoPos, nil, method, sig)
+	return types.NewInterfaceType([]*types.Func{fn}, nil).Complete()
+}
+
+var (
+	// marshalerType is interface{ MarshalJSON() ([]byte, error) }.
+	marshalerType = newMarshalerInterface("MarshalJSON")
+	// textMarshalerType is interface{ MarshalText() ([]byte, error) }.
+	textMarshalerType = newMarshalerInterface("MarshalText")
+)
+
+// encoder carries the state of a single check.
+type encoder struct {
+	// seen records the types that were already visited, so that
+	// recursive types terminate and no type is checked twice.
+	seen map[fakereflect.TypeAndCanAddr]struct{}
+}
+
+// newTypeEncoder checks whether values of type t can be encoded. It
+// returns nil if they can, or if t was visited before, and otherwise an
+// error describing the first unsupported type found. stack is the access
+// path that led to t; it becomes the Path of the returned error.
+func (enc *encoder) newTypeEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
+	if _, ok := enc.seen[t]; ok {
+		return nil
+	}
+	enc.seen[t] = struct{}{}
+
+	// Types with marshaling methods are always accepted. Methods with
+	// pointer receivers only count if the value is addressable.
+	for _, iface := range []*types.Interface{marshalerType, textMarshalerType} {
+		if t.Implements(iface) {
+			return nil
+		}
+		if !t.IsPtr() && t.CanAddr() && fakereflect.PtrTo(t).Implements(iface) {
+			return nil
+		}
+	}
+
+	switch typ := t.Type.Underlying().(type) {
+	case *types.Basic:
+		// encoding/json has no encoder for complex numbers or
+		// unsafe.Pointer; every other basic kind is fine.
+		if typ.Info()&types.IsComplex != 0 || typ.Kind() == types.UnsafePointer {
+			return &UnsupportedTypeError{Type: t.Type, Path: stack}
+		}
+		return nil
+	case *types.Interface:
+		return nil
+	case *types.Struct:
+		return enc.typeFields(t, stack)
+	case *types.Map:
+		return enc.newMapEncoder(t, stack)
+	case *types.Slice:
+		return enc.newSliceEncoder(t, stack)
+	case *types.Array:
+		return enc.newArrayEncoder(t, stack)
+	case *types.Pointer:
+		// we don't have to express the pointer dereference in the path; x.f is syntactic sugar for (*x).f
+		return enc.newTypeEncoder(t.Elem(), stack)
+	default:
+		return &UnsupportedTypeError{Type: t.Type, Path: stack}
+	}
+}
+
+// newMapEncoder checks the map type t. The map itself is reported as
+// unsupported if its key type is not usable as a JSON object key;
+// otherwise the element type is checked, with "[k]" appended to the path.
+func (enc *encoder) newMapEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
+	// encoding/json only accepts string and integer kinds as map keys;
+	// any other key type has to implement encoding.TextMarshaler.
+	key := t.Key()
+	basic, ok := key.Type.Underlying().(*types.Basic)
+	if !ok || basic.Info()&(types.IsString|types.IsInteger) == 0 {
+		if !key.Implements(textMarshalerType) {
+			return &UnsupportedTypeError{Type: t.Type, Path: stack}
+		}
+	}
+	return enc.newTypeEncoder(t.Elem(), stack+"[k]")
+}
+
+// newSliceEncoder checks the slice type t. A slice whose element type has
+// the underlying type uint8 is accepted right away, unless a pointer to
+// the element type implements one of the marshaler interfaces. Every
+// other slice is checked like an array.
+func (enc *encoder) newSliceEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
+	// Byte slices get special treatment; arrays don't.
+	elem := t.Elem()
+	if basic, ok := elem.Type.Underlying().(*types.Basic); ok && basic.Kind() == types.Uint8 {
+		p := fakereflect.PtrTo(elem)
+		if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
+			return nil
+		}
+	}
+	return enc.newArrayEncoder(t, stack)
+}
+
+// newArrayEncoder checks the element type of t, with "[0]" appended to the path.
+func (enc *encoder) newArrayEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
+	return enc.newTypeEncoder(t.Elem(), stack+"[0]")
+}
+
+// isValidTag reports whether s is usable as a field name taken from a
+// json tag: it must be non-empty and consist only of letters, digits and
+// the punctuation characters listed below.
+func isValidTag(s string) bool {
+	if s == "" {
+		return false
+	}
+	for _, c := range s {
+		switch {
+		case strings.ContainsRune("!#$%&()*+-./:;<=>?@[]^_{|}~ ", c):
+			// Backslash and quote chars are reserved, but
+			// otherwise any punctuation chars are allowed
+			// in a tag name.
+		case !unicode.IsLetter(c) && !unicode.IsDigit(c):
+			return false
+		}
+	}
+	return true
+}
+
+// typeByIndex returns the type of the nested field of t selected by
+// index. A pointer is dereferenced before each field access.
+func typeByIndex(t fakereflect.TypeAndCanAddr, index []int) fakereflect.TypeAndCanAddr {
+	for _, i := range index {
+		if t.IsPtr() {
+			t = t.Elem()
+		}
+		t = t.Field(i).Type
+	}
+	return t
+}
+
+// pathByIndex returns the selector path, in the form ".A.B", of the
+// nested field of t selected by index. A pointer is dereferenced before
+// each field access.
+func pathByIndex(t fakereflect.TypeAndCanAddr, index []int) string {
+	path := ""
+	for _, i := range index {
+		if t.IsPtr() {
+			t = t.Elem()
+		}
+		f := t.Field(i)
+		path += "." + f.Name
+		t = f.Type
+	}
+	return path
+}
+
+// A field represents a single field found in a struct.
+type field struct {
+	// name is the name the field is known by.
+	name string
+
+	// tag reports whether name came from the json tag.
+	tag bool
+	// index is the sequence of field indices leading to the field.
+	index []int
+	// typ is the type recorded for the field.
+	typ fakereflect.TypeAndCanAddr
+}
+
+// byIndex sorts field by index sequence.
+type byIndex []field
+
+// Len implements sort.Interface.
+func (x byIndex) Len() int { return len(x) }
+
+// Swap implements sort.Interface.
+func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
+
+// Less implements sort.Interface. Index sequences are compared element
+// by element; if one is a prefix of the other, the shorter one sorts first.
+func (x byIndex) Less(i, j int) bool {
+	for k, xik := range x[i].index {
+		if k >= len(x[j].index) {
+			return false
+		}
+		if xik != x[j].index[k] {
+			return xik < x[j].index[k]
+		}
+	}
+	return len(x[i].index) < len(x[j].index)
+}
+
+// typeFields checks the fields of the struct type t that JSON should recognize
+// and returns the first unsupported type found among them, or nil if there is none.
+// stack is the access path of t; the path of each field is appended to it.
+//
+// The fields are collected by a breadth-first search over the set of structs
+// to include - the top struct and then any reachable anonymous structs.
+func (enc *encoder) typeFields(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
+	// Anonymous fields to explore at the current level and the next.
+	current := []field{}
+	next := []field{{typ: t}}
+
+	// Count of queued names for current level and the next.
+	var count, nextCount map[fakereflect.TypeAndCanAddr]int
+
+	// Types already visited at an earlier level.
+	visited := map[fakereflect.TypeAndCanAddr]bool{}
+
+	// Fields found.
+	var fields []field
+
+	for len(next) > 0 {
+		current, next = next, current[:0]
+		count, nextCount = nextCount, map[fakereflect.TypeAndCanAddr]int{}
+
+		for _, f := range current {
+			if visited[f.typ] {
+				continue
+			}
+			visited[f.typ] = true
+
+			// Scan f.typ for fields to include.
+			for i := 0; i < f.typ.NumField(); i++ {
+				sf := f.typ.Field(i)
+				if sf.Anonymous {
+					et := sf.Type
+					if et.IsPtr() {
+						et = et.Elem()
+					}
+					if !sf.IsExported() && !et.IsStruct() {
+						// Ignore embedded fields of unexported non-struct types.
+						continue
+					}
+					// Do not ignore embedded fields of unexported struct types
+					// since they may have exported fields.
+				} else if !sf.IsExported() {
+					// Ignore unexported non-embedded fields.
+					continue
+				}
+				tag := sf.Tag.Get("json")
+				if tag == "-" {
+					continue
+				}
+				name := parseTag(tag)
+				if !isValidTag(name) {
+					name = ""
+				}
+				index := make([]int, len(f.index)+1)
+				copy(index, f.index)
+				index[len(f.index)] = i
+
+				ft := sf.Type
+				if ft.Name() == "" && ft.IsPtr() {
+					// Follow pointer.
+					ft = ft.Elem()
+				}
+
+				// Record found field and index sequence.
+				if name != "" || !sf.Anonymous || !ft.IsStruct() {
+					tagged := name != ""
+					if name == "" {
+						name = sf.Name
+					}
+					fields = append(fields, field{
+						name:  name,
+						tag:   tagged,
+						index: index,
+						typ:   ft,
+					})
+					if count[f.typ] > 1 {
+						// If there were multiple instances, add a second,
+						// so that the annihilation code will see a duplicate.
+						// It only cares about the distinction between 1 or 2,
+						// so don't bother generating any more copies.
+						fields = append(fields, fields[len(fields)-1])
+					}
+					continue
+				}
+
+				// Record new anonymous struct to explore in next round.
+				nextCount[ft]++
+				if nextCount[ft] == 1 {
+					next = append(next, field{name: ft.Name(), index: index, typ: ft})
+				}
+			}
+		}
+	}
+
+	sort.Slice(fields, func(i, j int) bool {
+		x := fields
+		// sort field by name, breaking ties with depth, then
+		// breaking ties with "name came from json tag", then
+		// breaking ties with index sequence.
+		if x[i].name != x[j].name {
+			return x[i].name < x[j].name
+		}
+		if len(x[i].index) != len(x[j].index) {
+			return len(x[i].index) < len(x[j].index)
+		}
+		if x[i].tag != x[j].tag {
+			return x[i].tag
+		}
+		return byIndex(x).Less(i, j)
+	})
+
+	// Delete all fields that are hidden by the Go rules for embedded fields,
+	// except that fields with JSON tags are promoted.
+
+	// The fields are sorted in primary order of name, secondary order
+	// of field index length. Loop over names; for each name, delete
+	// hidden fields by choosing the one dominant field that survives.
+	out := fields[:0]
+	for advance, i := 0, 0; i < len(fields); i += advance {
+		// One iteration per name.
+		// Find the sequence of fields with the name of this first field.
+		fi := fields[i]
+		name := fi.name
+		for advance = 1; i+advance < len(fields); advance++ {
+			fj := fields[i+advance]
+			if fj.name != name {
+				break
+			}
+		}
+		if advance == 1 { // Only one field with this name
+			out = append(out, fi)
+			continue
+		}
+		dominant, ok := dominantField(fields[i : i+advance])
+		if ok {
+			out = append(out, dominant)
+		}
+	}
+
+	fields = out
+	sort.Sort(byIndex(fields))
+
+	// Check the surviving fields in index order. The type is looked up
+	// again from t rather than taken from f.typ, because f.typ may have
+	// had a pointer followed.
+	for _, f := range fields {
+		if err := enc.newTypeEncoder(typeByIndex(t, f.index), stack+pathByIndex(t, f.index)); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// dominantField looks through the fields, all of which are known to
+// have the same name, to find the single field that dominates the
+// others using Go's embedding rules, modified by the presence of
+// JSON tags. If there are multiple top-level fields, the boolean
+// will be false: This condition is an error in Go and we skip all
+// the fields.
+func dominantField(fields []field) (field, bool) {
+	// The fields are sorted in increasing index-length order, then by presence of tag.
+	// That means that the first field is the dominant one. We need only check
+	// for error cases: two fields at top level, either both tagged or neither tagged.
+	if len(fields) > 1 && len(fields[0].index) == len(fields[1].index) && fields[0].tag == fields[1].tag {
+		return field{}, false
+	}
+	return fields[0], true
+}
diff --git a/staticcheck/fakereflect/fakereflect.go b/staticcheck/fakereflect/fakereflect.go
new file mode 100644
index 000000000..7f8fd4799
--- /dev/null
+++ b/staticcheck/fakereflect/fakereflect.go
@@ -0,0 +1,131 @@
+package fakereflect
+
+import (
+	"fmt"
+	"go/types"
+	"reflect"
+)
+
+// TypeAndCanAddr pairs a type with a flag that records whether a value
+// of that type, reached the way this type was reached, is addressable.
+type TypeAndCanAddr struct {
+	Type    types.Type
+	canAddr bool
+}
+
+// StructField describes a single field of a struct type.
+type StructField struct {
+	Index     []int
+	Name      string
+	Anonymous bool
+	Tag       reflect.StructTag
+	f         *types.Var
+	Type      TypeAndCanAddr
+}
+
+// IsExported reports whether the field is exported.
+func (sf StructField) IsExported() bool { return sf.f.Exported() }
+
+// Field returns the i'th field of the struct type t. The field inherits
+// the addressability of t. Field panics if the underlying type of t is
+// not a struct.
+func (t TypeAndCanAddr) Field(i int) StructField {
+	st := t.Type.Underlying().(*types.Struct)
+	f := st.Field(i)
+	return StructField{
+		f:         f,
+		Index:     []int{i},
+		Name:      f.Name(),
+		Anonymous: f.Anonymous(),
+		Tag:       reflect.StructTag(st.Tag(i)),
+		Type: TypeAndCanAddr{
+			Type:    f.Type(),
+			canAddr: t.canAddr,
+		},
+	}
+}
+
+// FieldByIndex returns the nested field selected by index, which must
+// not be empty. A field whose type is a pointer is dereferenced before
+// the next field is looked up in it.
+func (t TypeAndCanAddr) FieldByIndex(index []int) StructField {
+	f := t.Field(index[0])
+	for _, idx := range index[1:] {
+		ft := f.Type
+		if ft.IsPtr() {
+			// Field asserts a struct type and would panic on a pointer.
+			ft = ft.Elem()
+		}
+		f = ft.Field(idx)
+	}
+	f.Index = index
+	return f
+}
+
+// PtrTo returns the type of a pointer to t.
+func PtrTo(t TypeAndCanAddr) TypeAndCanAddr {
+	// Note that we don't care about canAddr here because it's irrelevant to all uses of PtrTo
+	return TypeAndCanAddr{Type: types.NewPointer(t.Type)}
+}
+
+// CanAddr reports the addressability flag of t.
+func (t TypeAndCanAddr) CanAddr() bool { return t.canAddr }
+
+// Implements reports whether t implements the interface ityp.
+func (t TypeAndCanAddr) Implements(ityp *types.Interface) bool {
+	return types.Implements(t.Type, ityp)
+}
+
+// IsSlice reports whether the underlying type of t is a slice.
+func (t TypeAndCanAddr) IsSlice() bool {
+	_, ok := t.Type.Underlying().(*types.Slice)
+	return ok
+}
+
+// IsArray reports whether the underlying type of t is an array.
+func (t TypeAndCanAddr) IsArray() bool {
+	_, ok := t.Type.Underlying().(*types.Array)
+	return ok
+}
+
+// IsPtr reports whether the underlying type of t is a pointer.
+func (t TypeAndCanAddr) IsPtr() bool {
+	_, ok := t.Type.Underlying().(*types.Pointer)
+	return ok
+}
+
+// IsInterface reports whether the underlying type of t is an interface.
+func (t TypeAndCanAddr) IsInterface() bool {
+	_, ok := t.Type.Underlying().(*types.Interface)
+	return ok
+}
+
+// IsStruct reports whether the underlying type of t is a struct.
+func (t TypeAndCanAddr) IsStruct() bool {
+	_, ok := t.Type.Underlying().(*types.Struct)
+	return ok
+}
+
+// Name returns the name of t if it is a named type, and "" otherwise.
+func (t TypeAndCanAddr) Name() string {
+	named, ok := t.Type.(*types.Named)
+	if !ok {
+		return ""
+	}
+	return named.Obj().Name()
+}
+
+// NumField returns the number of fields of the struct type t.
+// It panics if the underlying type of t is not a struct.
+func (t TypeAndCanAddr) NumField() int {
+	return t.Type.Underlying().(*types.Struct).NumFields()
+}
+
+// String returns the string form of the wrapped type.
+func (t TypeAndCanAddr) String() string {
+	return t.Type.String()
+}
+
+// Key returns the key type of the map type t, marked as not addressable.
+// It panics if the underlying type of t is not a map.
+func (t TypeAndCanAddr) Key() TypeAndCanAddr {
+	return TypeAndCanAddr{Type: t.Type.Underlying().(*types.Map).Key()}
+}
+
+// Elem returns the element type of the pointer, slice, array or map
+// type t. Elements of pointers and slices are marked addressable,
+// elements of arrays inherit the addressability of the array, and
+// elements of maps are marked not addressable. Elem panics for any
+// other underlying type.
+func (t TypeAndCanAddr) Elem() TypeAndCanAddr {
+	switch typ := t.Type.Underlying().(type) {
+	case *types.Pointer:
+		return TypeAndCanAddr{
+			Type:    typ.Elem(),
+			canAddr: true,
+		}
+	case *types.Slice:
+		return TypeAndCanAddr{
+			Type:    typ.Elem(),
+			canAddr: true,
+		}
+	case *types.Array:
+		return TypeAndCanAddr{
+			Type:    typ.Elem(),
+			canAddr: t.canAddr,
+		}
+	case *types.Map:
+		return TypeAndCanAddr{
+			Type:    typ.Elem(),
+			canAddr: false,
+		}
+	default:
+		panic(fmt.Sprintf("unhandled type %T", typ))
+	}
+}
diff --git a/staticcheck/lint.go b/staticcheck/lint.go index d09d079b3..10aa866b7 100644 --- a/staticcheck/lint.go +++ b/staticcheck/lint.go @@ -34,6 +34,7 @@ import ( "honnef.co/go/tools/knowledge" "honnef.co/go/tools/pattern" "honnef.co/go/tools/printf" + "honnef.co/go/tools/staticcheck/fakejson" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" @@ -254,9 +255,9 @@ var ( } checkUnsupportedMarshal = map[string]CallCheck{ - "encoding/json.Marshal": checkUnsupportedMarshalImpl(knowledge.Arg("json.Marshal.v"), "json", "MarshalJSON", "MarshalText"), + "encoding/json.Marshal": checkUnsupportedMarshalJSON, "encoding/xml.Marshal": checkUnsupportedMarshalImpl(knowledge.Arg("xml.Marshal.v"), "xml", "MarshalXML", "MarshalText"), - "(*encoding/json.Encoder).Encode": checkUnsupportedMarshalImpl(knowledge.Arg("(*encoding/json.Encoder).Encode.v"), "json", 
"MarshalJSON", "MarshalText"), + "(*encoding/json.Encoder).Encode": checkUnsupportedMarshalJSON, "(*encoding/xml.Encoder).Encode": checkUnsupportedMarshalImpl(knowledge.Arg("(*encoding/xml.Encoder).Encode.v"), "xml", "MarshalXML", "MarshalText"), } @@ -868,6 +869,19 @@ func checkNoopMarshalImpl(argN int, meths ...string) CallCheck { } } +func checkUnsupportedMarshalJSON(call *Call) { + arg := call.Args[0] + T := arg.Value.Value.Type() + if err := fakejson.Marshal(T); err != nil { + typ := types.TypeString(err.Type, types.RelativeTo(arg.Value.Value.Parent().Pkg.Pkg)) + if err.Path == "x" { + arg.Invalid(fmt.Sprintf("trying to marshal unsupported type %s", typ)) + } else { + arg.Invalid(fmt.Sprintf("trying to marshal unsupported type %s, via %s", typ, err.Path)) + } + } +} + func checkUnsupportedMarshalImpl(argN int, tag string, meths ...string) CallCheck { // TODO(dh): flag slices and maps of unsupported types return func(call *Call) { diff --git a/staticcheck/testdata/src/CheckUnsupportedMarshal/CheckUnsupportedMarshal.go b/staticcheck/testdata/src/CheckUnsupportedMarshal/CheckUnsupportedMarshal.go index 63b1aec8d..ee0c86bae 100644 --- a/staticcheck/testdata/src/CheckUnsupportedMarshal/CheckUnsupportedMarshal.go +++ b/staticcheck/testdata/src/CheckUnsupportedMarshal/CheckUnsupportedMarshal.go @@ -3,6 +3,7 @@ package pkg import ( "encoding/json" "encoding/xml" + "time" ) type T1 struct { @@ -16,11 +17,11 @@ type T2 struct { } type T3 struct { - C chan int + Ch chan int } type T4 struct { - C C + C ValueMarshaler } type T5 struct { @@ -42,9 +43,27 @@ type T8 struct { *T7 } -type C chan int +type T9 struct { + F PointerMarshaler +} + +type T10 struct { + F *struct { + PointerMarshaler + } +} + +type Recursive struct { + Field *Recursive +} + +type ValueMarshaler chan int + +func (ValueMarshaler) MarshalText() ([]byte, error) { return nil, nil } -func (C) MarshalText() ([]byte, error) { return nil, nil } +type PointerMarshaler chan int + +func (*PointerMarshaler) 
MarshalText() ([]byte, error) { return nil, nil } func fn() { var t1 T1 @@ -54,17 +73,20 @@ func fn() { var t5 T5 var t6 T6 var t8 T8 + var t9 T9 + var t10 T10 + var t11 Recursive json.Marshal(t1) json.Marshal(t2) - json.Marshal(t3) // want `trying to marshal chan or func value, field CheckUnsupportedMarshal\.T3\.C` + json.Marshal(t3) // want `unsupported type chan int, via x\.Ch` json.Marshal(t4) - json.Marshal(t5) // want `trying to marshal chan or func value, field CheckUnsupportedMarshal\.T5\.B` + json.Marshal(t5) // want `unsupported type func\(\), via x\.B` json.Marshal(t6) (*json.Encoder)(nil).Encode(t1) (*json.Encoder)(nil).Encode(t2) - (*json.Encoder)(nil).Encode(t3) // want `trying to marshal chan or func value, field CheckUnsupportedMarshal\.T3\.C` + (*json.Encoder)(nil).Encode(t3) // want `unsupported type chan int, via x\.Ch` (*json.Encoder)(nil).Encode(t4) - (*json.Encoder)(nil).Encode(t5) // want `trying to marshal chan or func value, field CheckUnsupportedMarshal\.T5\.B` + (*json.Encoder)(nil).Encode(t5) // want `unsupported type func\(\), via x\.B` (*json.Encoder)(nil).Encode(t6) xml.Marshal(t1) @@ -80,5 +102,114 @@ func fn() { (*xml.Encoder)(nil).Encode(t5) (*xml.Encoder)(nil).Encode(t6) // want `trying to marshal chan or func value, field CheckUnsupportedMarshal\.T6\.B` - json.Marshal(t8) // want `trying to marshal chan or func value, field CheckUnsupportedMarshal\.T8\.T7\.T3\.C` + json.Marshal(t8) // want `unsupported type chan int, via x\.T7\.T3\.Ch` + json.Marshal(t9) // want `unsupported type PointerMarshaler, via x\.F` + json.Marshal(&t9) // this is fine, t9 is addressable, therefore T9.F is, too + json.Marshal(t10) // this is fine, T10.F.PointerMarshaler is addressable + + json.Marshal(t11) + xml.Marshal(t11) +} + +func addressability() { + var a PointerMarshaler + var b []PointerMarshaler + var c struct { + F PointerMarshaler + } + var d [4]PointerMarshaler + json.Marshal(a) // want `unsupported type PointerMarshaler$` + json.Marshal(&a) + 
json.Marshal(b) + json.Marshal(&b) + json.Marshal(c) // want `unsupported type PointerMarshaler, via x\.F` + json.Marshal(&c) + json.Marshal(d) // want `unsupported type PointerMarshaler, via x\[0\]` + json.Marshal(&d) + + var m1 map[string]PointerMarshaler + json.Marshal(m1) // want `unsupported type PointerMarshaler, via x\[k\]` + json.Marshal(&m1) // want `unsupported type PointerMarshaler, via x\[k\]` + json.Marshal([]map[string]PointerMarshaler{m1}) // want `unsupported type PointerMarshaler, via x\[0\]\[k\]` + + var m2 map[string]*PointerMarshaler + json.Marshal(m2) + json.Marshal(&m2) + json.Marshal([]map[string]*PointerMarshaler{m2}) +} + +func maps() { + var good map[int]string + var bad map[interface{}]string + // the map key has to be statically known good; it must be a number or a string + json.Marshal(good) + json.Marshal(bad) // want `unsupported type map\[interface\{\}\]string$` + + var m1 map[string]PointerMarshaler + json.Marshal(m1) // want `unsupported type PointerMarshaler, via x\[k\]` + json.Marshal(&m1) // want `unsupported type PointerMarshaler, via x\[k\]` + json.Marshal([]map[string]PointerMarshaler{m1}) // want `unsupported type PointerMarshaler, via x\[0\]\[k\]` + + var m2 map[string]*PointerMarshaler + json.Marshal(m2) + json.Marshal(&m2) + json.Marshal([]map[string]*PointerMarshaler{m2}) + + var m3 map[string]ValueMarshaler + json.Marshal(m3) + json.Marshal(&m3) + json.Marshal([]map[string]ValueMarshaler{m3}) + + var m4 map[string]*ValueMarshaler + json.Marshal(m4) + json.Marshal(&m4) + json.Marshal([]map[string]*ValueMarshaler{m4}) + + var m5 map[ValueMarshaler]string + var m6 map[*ValueMarshaler]string + var m7 map[PointerMarshaler]string + var m8 map[*PointerMarshaler]string + + json.Marshal(m5) + json.Marshal(m6) + json.Marshal(m7) // want `unsupported type map\[PointerMarshaler\]string$` + json.Marshal(m8) +} + +func fieldPriority() { + // In this example, the channel doesn't matter, because T1.F has higher priority than T1.T2.F + 
type lT2 struct { + F chan int + } + type lT1 struct { + F int + lT2 + } + json.Marshal(lT1{}) + + // In this example, it does matter + type lT4 struct { + C chan int + } + type lT3 struct { + F int + lT4 + } + json.Marshal(lT3{}) // want `unsupported type chan int, via x\.lT4\.C` +} + +func longPath() { + var foo struct { + Field struct { + Field2 []struct { + Map map[string]chan int + } + } + } + json.Marshal(foo) // want `unsupported type chan int, via x\.Field\.Field2\[0\].Map\[k\]` +} + +func otherPackage() { + var x time.Ticker + json.Marshal(x) // want `unsupported type <-chan time\.Time, via x\.C` }