Skip to content

Commit

Permalink
Merge pull request #42 from gotd/feature/proto-gzip-mitigate-oom
Browse files Browse the repository at this point in the history
feat(proto): mitigate possible DOS in gzip decoding
  • Loading branch information
ernado authored Dec 19, 2020
2 parents 12233b7 + dfdb786 commit 22334d9
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 16 deletions.
10 changes: 10 additions & 0 deletions bin/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ type Buffer struct {
Buf []byte
}

// Encode wrapper.
func (b *Buffer) Encode(e Encoder) error {
return e.Encode(b)
}

// Decode wrapper.
func (b *Buffer) Decode(d Decoder) error {
return d.Decode(b)
}

// ResetN resets buffer and expands it to fit n bytes.
func (b *Buffer) ResetN(n int) {
b.Buf = append(b.Buf[:0], make([]byte, n)...)
Expand Down
3 changes: 2 additions & 1 deletion internal/proto/codec/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import (
"github.com/gotd/td/bin"
)

func tryReadLength(r io.Reader, b *bin.Buffer) (int, error) {
// readLen reads 32-bit integer and validates it as message length.
func readLen(r io.Reader, b *bin.Buffer) (int, error) {
b.ResetN(bin.Word)
if _, err := io.ReadFull(r, b.Buf[:bin.Word]); err != nil {
return 0, fmt.Errorf("failed to read length: %w", err)
Expand Down
4 changes: 2 additions & 2 deletions internal/proto/codec/full.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ var errSeqNoMismatch = errors.New("seq_no mismatch")
var errCRCMismatch = errors.New("crc mismatch")

func readFull(r io.Reader, seqNo int, b *bin.Buffer) error {
n, err := tryReadLength(r, b)
n, err := readLen(r, b)
if err != nil {
return err
return xerrors.Errorf("len: %w", err)
}

// Put length, because it need to count CRC.
Expand Down
2 changes: 1 addition & 1 deletion internal/proto/codec/intermediate.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func writeIntermediate(w io.Writer, b *bin.Buffer) error {

// readIntermediate reads payload from r to b.
func readIntermediate(r io.Reader, b *bin.Buffer) error {
n, err := tryReadLength(r, b)
n, err := readLen(r, b)
if err != nil {
return err
}
Expand Down
33 changes: 29 additions & 4 deletions internal/proto/gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package proto
import (
"bytes"
"compress/gzip"
"io"
"io/ioutil"

"golang.org/x/xerrors"
Expand All @@ -21,6 +22,26 @@ type GZIP struct {
// GZIPTypeID is TL type id of GZIP.
const GZIPTypeID = 0x3072cfa1

// Encode implements bin.Encoder.
func (g GZIP) Encode(b *bin.Buffer) error {
b.PutID(GZIPTypeID)

// Writing compressed data to buf.
buf := new(bytes.Buffer)
w := gzip.NewWriter(buf)
if _, err := io.Copy(w, bytes.NewReader(g.Data)); err != nil {
return xerrors.Errorf("compress: %w", err)
}
if err := w.Close(); err != nil {
return xerrors.Errorf("close: %w", err)
}

// Writing compressed data as bytes.
b.PutBytes(buf.Bytes())

return nil
}

// Decode implements bin.Decoder.
func (g *GZIP) Decode(b *bin.Buffer) error {
if err := b.ConsumeID(GZIPTypeID); err != nil {
Expand All @@ -37,13 +58,17 @@ func (g *GZIP) Decode(b *bin.Buffer) error {
}
defer func() { _ = r.Close() }()

if g.Data, err = ioutil.ReadAll(r); err != nil {
return err
// Apply mitigation for reading too much data which can result in OOM.
const maxUncompressedSize = 1024 * 1024 * 10 // 10 mb
// TODO(ernado): fail explicitly if limit is reached
// Currently we just return nil, but it is better than failing with OOM.
if g.Data, err = ioutil.ReadAll(io.LimitReader(r, maxUncompressedSize)); err != nil {
return xerrors.Errorf("decompress: %w", err)
}

if err := r.Close(); err != nil {
// This will verify checksum.
return xerrors.Errorf("gzip error: %w", err)
// This will verify checksum only if limit is not reached.
return xerrors.Errorf("checksum: %w", err)
}

return nil
Expand Down
37 changes: 37 additions & 0 deletions internal/proto/gzip_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package proto

import (
"bytes"
"testing"

"github.com/stretchr/testify/require"

"github.com/gotd/td/bin"
)

func TestGZIP_Encode(t *testing.T) {
data := bytes.Repeat([]byte{1, 2, 3}, 100)
g := &GZIP{
Data: data,
}

var b bin.Buffer
require.NoError(t, b.Encode(g))

var decoded GZIP
require.NoError(t, b.Decode(&decoded))
require.Equal(t, data, decoded.Data)
}

func TestGZIP_Decode(t *testing.T) {
g := &GZIP{
Data: make([]byte, 1024*1024*15),
}
var b bin.Buffer
require.NoError(t, b.Encode(g))

var decoded GZIP
// TODO(ernado): fail explicitly if limit is reached
require.NoError(t, b.Decode(&decoded))
require.Less(t, len(decoded.Data), len(g.Data))
}
10 changes: 6 additions & 4 deletions telegram/client_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ func (h handler) OnMessage(k tgtest.Session, msgID int64, in *bin.Buffer) error

func testTransport(trp *transport.Transport) func(t *testing.T) {
return func(t *testing.T) {
srv := tgtest.NewUnstartedServer(t, trp.Codec())
t.Helper()

ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()

srv := tgtest.NewUnstartedServer(ctx, t, trp.Codec())
h := handler{
server: srv,
t: t,
Expand All @@ -93,9 +98,6 @@ func testTransport(trp *transport.Transport) func(t *testing.T) {
srv.Start()
defer srv.Close()

ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()

dispatcher := tg.NewUpdateDispatcher()
log, _ := zap.NewDevelopment(zap.IncreaseLevel(zapcore.DebugLevel))
client := NewClient(1, "hash", Options{
Expand Down
8 changes: 4 additions & 4 deletions telegram/internal/tgtest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,20 @@ func (s *Server) Close() {
_ = s.server.Close()
}

func NewServer(tb TB, codec transport.Codec, h Handler) *Server {
s := NewUnstartedServer(tb, codec)
func NewServer(ctx context.Context, tb TB, codec transport.Codec, h Handler) *Server {
s := NewUnstartedServer(ctx, tb, codec)
s.SetHandler(h)
s.Start()
return s
}

func NewUnstartedServer(tb TB, codec transport.Codec) *Server {
func NewUnstartedServer(ctx context.Context, tb TB, codec transport.Codec) *Server {
k, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(ctx)
s := &Server{
server: transport.NewCustomServer(codec, newLocalListener()),
tb: tb,
Expand Down

0 comments on commit 22334d9

Please sign in to comment.