diff --git a/decoders/netflow/netflow.go b/decoders/netflow/netflow.go index 3a415d2..8836494 100644 --- a/decoders/netflow/netflow.go +++ b/decoders/netflow/netflow.go @@ -100,7 +100,7 @@ func DecodeTemplateSet(payload *bytes.Buffer) ([]TemplateRecord, error) { break } - if int(templateRecord.FieldCount) < 0 { + if templateRecord.FieldCount == 0 { return records, NewErrorDecodingNetFlow("Error decoding TemplateSet: zero count.") } diff --git a/decoders/sflow/sflow.go b/decoders/sflow/sflow.go index 5208809..5a8037a 100644 --- a/decoders/sflow/sflow.go +++ b/decoders/sflow/sflow.go @@ -16,6 +16,19 @@ const ( FORMAT_ETH = 2 FORMAT_IPV4 = 3 FORMAT_IPV6 = 4 + + // The following max constants control what goflow considers reasonable amounts of data objects reported in a single packet. + // This is to prevent an attacker from making us allocate arbitrary amounts of memory and let goflow be killed by the OOM killer. + // sflow samples are reported in UDP packets which have a maximum PDU size of 64Kib. The following numbers are derived from that fact. + MAX_UDP_PKT_SIZE = 65535 + USUAL_SAMPLED_HEADER_SIZE = 128 + FLOW_RECORD_HEADER_SIZE = 8 + ASN_SIZE = 4 + COMMUNITY_SIZE = 4 + MAX_SAMPLES_PER_PACKET = MAX_UDP_PKT_SIZE / USUAL_SAMPLED_HEADER_SIZE + MAX_FLOW_RECORDS = MAX_UDP_PKT_SIZE / FLOW_RECORD_HEADER_SIZE + MAX_AS_PATH_LENGTH = MAX_UDP_PKT_SIZE / ASN_SIZE + MAX_COMMUNITIES_LENGTH = MAX_UDP_PKT_SIZE / COMMUNITY_SIZE ) type ErrorDecodingSFlow struct { @@ -203,6 +216,9 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, if int(extendedGateway.ASPathLength) > payload.Len()-4 { return flowRecord, errors.New(fmt.Sprintf("Invalid AS path length: %v.", extendedGateway.ASPathLength)) } + if extendedGateway.ASPathLength > MAX_AS_PATH_LENGTH { + return flowRecord, fmt.Errorf("Invalid AS path length: %d", extendedGateway.ASPathLength) + } asPath = make([]uint32, extendedGateway.ASPathLength) if len(asPath) > 0 { err = utils.BinaryDecoder(payload, asPath) @@ -218,7 +234,11 @@ func DecodeFlowRecord(header *RecordHeader, payload *bytes.Buffer) (FlowRecord, return flowRecord, err } if int(extendedGateway.CommunitiesLength) > payload.Len()-4 { - return flowRecord, errors.New(fmt.Sprintf("Invalid Communities length: %v.", extendedGateway.ASPathLength)) + return flowRecord, errors.New(fmt.Sprintf("Invalid Communities length: %v.", extendedGateway.CommunitiesLength)) + } + + if extendedGateway.CommunitiesLength > MAX_COMMUNITIES_LENGTH { + return flowRecord, fmt.Errorf("Invalid communities length: %d", extendedGateway.CommunitiesLength) } communities := make([]uint32, extendedGateway.CommunitiesLength) if len(communities) > 0 { @@ -280,6 +300,9 @@ func DecodeSample(header *SampleHeader, payload *bytes.Buffer) (interface{}, err return sample, err } recordsCount = flowSample.FlowRecordsCount + if recordsCount > MAX_FLOW_RECORDS { + return flowSample, fmt.Errorf("Invalid number of flows records: %d", recordsCount) + } flowSample.Records = make([]FlowRecord, recordsCount) sample = flowSample } else if format == FORMAT_ETH || format == FORMAT_IPV6 { @@ -291,6 +314,10 @@ func DecodeSample(header *SampleHeader, payload *bytes.Buffer) (interface{}, err Header: *header, CounterRecordsCount: recordsCount, } + + if recordsCount > MAX_SAMPLES_PER_PACKET { + return flowSample, fmt.Errorf("Invalid number of samples: %d", recordsCount) + } counterSample.Records = make([]CounterRecord, recordsCount) sample = counterSample } else if format == FORMAT_IPV4 {