Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option to loosen message parsing to allow AVPs with empty/invalid payloads while defaulting to strict parsing #147

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions diam/avp.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"github.com/fiorix/go-diameter/v4/diam/dict"
)

// Used to signal that parsing should not stop.
type DecodeError error

// AVP is a Diameter attribute-value-pair.
type AVP struct {
Code uint32 // Code of this AVP
Expand Down Expand Up @@ -91,7 +94,7 @@ func (a *AVP) DecodeFromBytes(data []byte, application uint32, dictionary *dict.
}
a.Data, err = datatype.Decode(dictAVP.Data.Type, payload)
if err != nil {
return err
return DecodeError(fmt.Errorf("%s(%d): %v", dictAVP.Name, dictAVP.Code, err))
}
// Handle grouped AVPs.
if a.Data.Type() == datatype.GroupedType {
Expand All @@ -100,7 +103,7 @@ func (a *AVP) DecodeFromBytes(data []byte, application uint32, dictionary *dict.
application, dictionary,
)
if err != nil {
return err
return DecodeError(fmt.Errorf("%s(%d): Grouped{%v}", dictAVP.Name, dictAVP.Code, err))
}
}
return nil
Expand Down
13 changes: 8 additions & 5 deletions diam/datatype/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@ type Address []byte
// DecodeAddress decodes an Address data type from byte array.
func DecodeAddress(b []byte) (Type, error) {
if len(b) < 3 {
return nil, fmt.Errorf("Not enough data to make an Address from byte[%d] = %+v", len(b), b)
return Address{}, fmt.Errorf("Not enough data to make an Address from byte[%d] = %+v", len(b), b)
}
if binary.BigEndian.Uint16(b[:2]) == 0 || binary.BigEndian.Uint16(b[:2]) == 65535 {
return nil, errors.New("Invalid address type received")
return Address{}, errors.New("Invalid address type received")
}
switch binary.BigEndian.Uint16(b[:2]) {
case 0x01:
if len(b[2:]) != 4 {
return nil, errors.New("Invalid length for IPv4")
return Address{}, errors.New("Invalid length for IPv4")
}
case 0x02:
if len(b[2:]) != 16 {
return nil, errors.New("Invalid length for IPv6")
return Address{}, errors.New("Invalid length for IPv6")
}
default:
return Address(b), nil
Expand Down Expand Up @@ -92,5 +92,8 @@ func (addr Address) String() string {
if ip6 := net.IP(addr).To16(); ip6 != nil {
return fmt.Sprintf("Address{%s},Padding:%d", net.IP(addr), addr.Padding())
}
return fmt.Sprintf("Address{%#v}, Type{%#v} Padding:%d", addr[2:], addr[:2], addr.Padding())
if len(addr) == 0 {
return "Address{},Padding:0" // NOTE: To avoid panicking on addr[2:]
}
return fmt.Sprintf("Address{%#v},Type{%#v},Padding:%d", addr[2:], addr[:2], addr.Padding())
}
5 changes: 1 addition & 4 deletions diam/datatype/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@ type Enumerated Integer32
// DecodeEnumerated decodes an Enumerated data type from byte array.
func DecodeEnumerated(b []byte) (Type, error) {
v, err := DecodeInteger32(b)
if err != nil {
return nil, err
}
return Enumerated(v.(Integer32)), nil
return Enumerated(v.(Integer32)), err
}

// Serialize implements the Type interface.
Expand Down
9 changes: 9 additions & 0 deletions diam/dict/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ type Parser struct {
command map[codeIdx]*Command // Command index
mu sync.Mutex // Protects all maps
once sync.Once

// Strict indicates whether an error should be returned when one or more
// AVPs are invalid/empty and cannot be properly decoded.
//
// Defaults to true. When set to false, all decoding errors found during the
// parsing process will be stored in the Message's DecodeErr field which is
// accessible from a request handler.
Strict bool
}

type codeIdx struct {
Expand All @@ -62,6 +70,7 @@ type appIdTypeIdx struct {
// NewParser allocates a new Parser optionally loading dictionary XML files.
func NewParser(filename ...string) (*Parser, error) {
p := new(Parser)
p.Strict = true
var err error
for _, f := range filename {
if err = p.LoadFile(f); err != nil {
Expand Down
7 changes: 6 additions & 1 deletion diam/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package diam
import (
"bytes"
"fmt"
"strings"

"github.com/fiorix/go-diameter/v4/diam/datatype"
"github.com/fiorix/go-diameter/v4/diam/dict"
Expand All @@ -25,15 +26,19 @@ type GroupedAVP struct {
func DecodeGrouped(data datatype.Grouped, application uint32, dictionary *dict.Parser) (*GroupedAVP, error) {
g := &GroupedAVP{}
b := []byte(data)
var errs []string
for n := 0; n < len(b); {
avp, err := DecodeAVP(b[n:], application, dictionary)
if err != nil {
return nil, err
errs = append(errs, err.Error())
}
g.AVP = append(g.AVP, avp)
n += avp.Len()
}
// TODO: handle nested groups?
if len(errs) > 0 {
return g, fmt.Errorf("%s", strings.Join(errs, "; "))
}
return g, nil
}

Expand Down
23 changes: 18 additions & 5 deletions diam/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"io"
"math/rand"
"net"
"strings"
"sync"
"time"

Expand All @@ -32,9 +33,9 @@ type Message struct {
Header *Header
AVP []*AVP // AVPs in this message.

// dictionary parser object used to encode and decode AVPs.
dictionary *dict.Parser
stream uint // the stream this message was received on (if any)
DecodeErr error // Possible decoding error on one or more AVPs (does not halt parsing)
dictionary *dict.Parser // dictionary parser object used to encode and decode AVPs.
stream uint // the stream this message was received on (if any)
ctx context.Context
}

Expand Down Expand Up @@ -74,7 +75,10 @@ func ReadMessage(reader io.Reader, dictionary *dict.Parser) (*Message, error) {
}
m.stream = stream
if err = m.readBody(reader, buf, cmd, stream); err != nil {
return nil, err
return m, err
}
if dictionary.Strict {
return m, m.DecodeErr
}
return m, nil
}
Expand Down Expand Up @@ -149,15 +153,24 @@ func (m *Message) maxAVPsFor(cmd *dict.Command) int {

func (m *Message) decodeAVPs(b []byte) error {
var a *AVP
var decodeErrs []string
var err error
for n := 0; n < len(b); {
a, err = DecodeAVP(b[n:], m.Header.ApplicationID, m.Dictionary())
if err != nil {
return fmt.Errorf("Failed to decode AVP: %s", err)
if decodeErr, ok := err.(DecodeError); ok {
decodeErrs = append(decodeErrs, decodeErr.Error())
} else {
return err
}
}
m.AVP = append(m.AVP, a)
n += a.Len()
}
if len(decodeErrs) > 0 {
// Depending on the settings, this will be thrown by the state machine or passed to the best handler
m.DecodeErr = fmt.Errorf("Failed to decode one or more AVPs: {%s}", strings.Join(decodeErrs, "; "))
}
return nil
}

Expand Down
5 changes: 1 addition & 4 deletions diam/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,7 @@ func (c *conn) readMessage() (m *Message, err error) {
} else {
m, err = ReadMessage(c.buf.Reader, c.dictionary())
}
if err != nil {
return nil, err
}
return m, nil
return m, err
}

// Serve a new connection.
Expand Down