diff --git a/diam/avp.go b/diam/avp.go index 501a19ee..23fe903f 100644 --- a/diam/avp.go +++ b/diam/avp.go @@ -14,6 +14,9 @@ import ( "github.com/fiorix/go-diameter/v4/diam/dict" ) +// Used to signal that parsing should not stop. +type DecodeError error + // AVP is a Diameter attribute-value-pair. type AVP struct { Code uint32 // Code of this AVP @@ -91,7 +94,7 @@ func (a *AVP) DecodeFromBytes(data []byte, application uint32, dictionary *dict. } a.Data, err = datatype.Decode(dictAVP.Data.Type, payload) if err != nil { - return err + return DecodeError(fmt.Errorf("%s(%d): %v", dictAVP.Name, dictAVP.Code, err)) } // Handle grouped AVPs. if a.Data.Type() == datatype.GroupedType { @@ -100,7 +103,7 @@ func (a *AVP) DecodeFromBytes(data []byte, application uint32, dictionary *dict. application, dictionary, ) if err != nil { - return err + return DecodeError(fmt.Errorf("%s(%d): Grouped{%v}", dictAVP.Name, dictAVP.Code, err)) } } return nil diff --git a/diam/datatype/address.go b/diam/datatype/address.go index a168b20e..d5269ce5 100644 --- a/diam/datatype/address.go +++ b/diam/datatype/address.go @@ -17,19 +17,19 @@ type Address []byte // DecodeAddress decodes an Address data type from byte array. func DecodeAddress(b []byte) (Type, error) { if len(b) < 3 { - return nil, fmt.Errorf("Not enough data to make an Address from byte[%d] = %+v", len(b), b) + return Address{}, fmt.Errorf("Not enough data to make an Address from byte[%d] = %+v", len(b), b) } if binary.BigEndian.Uint16(b[:2]) == 0 || binary.BigEndian.Uint16(b[:2]) == 65535 { - return nil, errors.New("Invalid address type received") + return Address{}, errors.New("Invalid address type received") } switch binary.BigEndian.Uint16(b[:2]) { case 0x01: if len(b[2:]) != 4 { - return nil, errors.New("Invalid length for IPv4") + return Address{}, errors.New("Invalid length for IPv4") } case 0x02: if len(b[2:]) != 16 { - return nil, errors.New("Invalid length for IPv6") + return Address{}, errors.New("Invalid length for IPv6") } default: return Address(b), nil @@ -92,5 +92,8 @@ func (addr Address) String() string { if ip6 := net.IP(addr).To16(); ip6 != nil { return fmt.Sprintf("Address{%s},Padding:%d", net.IP(addr), addr.Padding()) } - return fmt.Sprintf("Address{%#v}, Type{%#v} Padding:%d", addr[2:], addr[:2], addr.Padding()) + if len(addr) == 0 { + return "Address{},Padding:0" // NOTE: To avoid panicking on addr[2:] + } + return fmt.Sprintf("Address{%#v},Type{%#v},Padding:%d", addr[2:], addr[:2], addr.Padding()) } diff --git a/diam/datatype/enum.go b/diam/datatype/enum.go index 5f8eb264..589f49e4 100644 --- a/diam/datatype/enum.go +++ b/diam/datatype/enum.go @@ -12,10 +12,7 @@ type Enumerated Integer32 // DecodeEnumerated decodes an Enumerated data type from byte array. func DecodeEnumerated(b []byte) (Type, error) { v, err := DecodeInteger32(b) - if err != nil { - return nil, err - } - return Enumerated(v.(Integer32)), nil + return Enumerated(v.(Integer32)), err } // Serialize implements the Type interface. diff --git a/diam/dict/parser.go b/diam/dict/parser.go index e8830074..5c33885f 100644 --- a/diam/dict/parser.go +++ b/diam/dict/parser.go @@ -40,6 +40,14 @@ type Parser struct { command map[codeIdx]*Command // Command index mu sync.Mutex // Protects all maps once sync.Once + + // Strict indicates whether an error should be returned when one or more + // AVPs are invalid/empty and cannot be properly decoded. + // + // Defaults to true. When set to false, all decoding errors found during the + // parsing process will be stored in the Message's DecodeErr field which is + // accessible from a request handler. + Strict bool } type codeIdx struct { @@ -62,6 +70,7 @@ type appIdTypeIdx struct { // NewParser allocates a new Parser optionally loading dictionary XML files. func NewParser(filename ...string) (*Parser, error) { p := new(Parser) + p.Strict = true var err error for _, f := range filename { if err = p.LoadFile(f); err != nil { diff --git a/diam/group.go b/diam/group.go index 384b855c..e7b394a2 100644 --- a/diam/group.go +++ b/diam/group.go @@ -7,6 +7,7 @@ package diam import ( "bytes" "fmt" + "strings" "github.com/fiorix/go-diameter/v4/diam/datatype" "github.com/fiorix/go-diameter/v4/diam/dict" @@ -25,15 +26,19 @@ type GroupedAVP struct { func DecodeGrouped(data datatype.Grouped, application uint32, dictionary *dict.Parser) (*GroupedAVP, error) { g := &GroupedAVP{} b := []byte(data) + var errs []string for n := 0; n < len(b); { avp, err := DecodeAVP(b[n:], application, dictionary) if err != nil { - return nil, err + errs = append(errs, err.Error()) } g.AVP = append(g.AVP, avp) n += avp.Len() } // TODO: handle nested groups? + if len(errs) > 0 { + return g, fmt.Errorf("%s", strings.Join(errs, "; ")) + } return g, nil } diff --git a/diam/message.go b/diam/message.go index 73c2d721..0938e6d6 100644 --- a/diam/message.go +++ b/diam/message.go @@ -12,6 +12,7 @@ import ( "io" "math/rand" "net" + "strings" "sync" "time" @@ -32,9 +33,9 @@ type Message struct { Header *Header AVP []*AVP // AVPs in this message. - // dictionary parser object used to encode and decode AVPs. - dictionary *dict.Parser - stream uint // the stream this message was received on (if any) + DecodeErr error // Possible decoding error on one or more AVPs (does not halt parsing) + dictionary *dict.Parser // dictionary parser object used to encode and decode AVPs. + stream uint // the stream this message was received on (if any) ctx context.Context } @@ -74,7 +75,10 @@ func ReadMessage(reader io.Reader, dictionary *dict.Parser) (*Message, error) { } m.stream = stream if err = m.readBody(reader, buf, cmd, stream); err != nil { - return nil, err + return m, err + } + if dictionary.Strict { + return m, m.DecodeErr } return m, nil } @@ -149,15 +153,24 @@ func (m *Message) maxAVPsFor(cmd *dict.Command) int { func (m *Message) decodeAVPs(b []byte) error { var a *AVP + var decodeErrs []string var err error for n := 0; n < len(b); { a, err = DecodeAVP(b[n:], m.Header.ApplicationID, m.Dictionary()) if err != nil { - return fmt.Errorf("Failed to decode AVP: %s", err) + if decodeErr, ok := err.(DecodeError); ok { + decodeErrs = append(decodeErrs, decodeErr.Error()) + } else { + return err + } } m.AVP = append(m.AVP, a) n += a.Len() } + if len(decodeErrs) > 0 { + // Depending on the settings, this will be thrown by the state machine or passed to the best handler + m.DecodeErr = fmt.Errorf("Failed to decode one or more AVPs: {%s}", strings.Join(decodeErrs, "; ")) + } return nil } diff --git a/diam/server.go b/diam/server.go index 44514bcc..9cae5212 100644 --- a/diam/server.go +++ b/diam/server.go @@ -169,10 +169,7 @@ func (c *conn) readMessage() (m *Message, err error) { } else { m, err = ReadMessage(c.buf.Reader, c.dictionary()) } - if err != nil { - return nil, err - } - return m, nil + return m, err } // Serve a new connection.