diff --git a/http2/frame.go b/http2/frame.go index 105c3b279..bdd7861f0 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -225,6 +225,11 @@ var fhBytes = sync.Pool{ }, } +func invalidHTTP1LookingFrameHeader() FrameHeader { + fh, _ := readFrameHeader(make([]byte, frameHeaderLen), strings.NewReader("HTTP/1.1 ")) + return fh +} + // ReadFrameHeader reads 9 bytes from r and returns a FrameHeader. // Most users should use Framer.ReadFrame instead. func ReadFrameHeader(r io.Reader) (FrameHeader, error) { @@ -503,10 +508,16 @@ func (fr *Framer) ReadFrame() (Frame, error) { return nil, err } if fh.Length > fr.maxReadSize { + if fh == invalidHTTP1LookingFrameHeader() { + return nil, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", err) + } return nil, ErrFrameTooLarge } payload := fr.getReadBuf(fh.Length) if _, err := io.ReadFull(fr.r, payload); err != nil { + if fh == invalidHTTP1LookingFrameHeader() { + return nil, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", err) + } return nil, err } f, err := typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload) diff --git a/http2/transport_test.go b/http2/transport_test.go index 757a45a7a..e4d44bac4 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -272,6 +272,48 @@ func TestTransport(t *testing.T) { } } +func TestTransportFailureErrorForHTTP1Response(t *testing.T) { + const expectedHTTP1PayloadHint = "frame header looked like an HTTP/1.1 header" + + ts := httptest.NewServer(http.NewServeMux()) + t.Cleanup(ts.Close) + + for _, tc := range []struct { + name string + maxFrameSize uint32 + expectedErrorIs error + }{ + { + name: "with default max frame size", + maxFrameSize: 0, + }, + { + name: "with enough frame size to start reading", + maxFrameSize: invalidHTTP1LookingFrameHeader.Length + 1, + }, + } { + t.Run(tc.name, func(t *testing.T) { + tr := &Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + return net.Dial(network, addr) + }, + MaxReadFrameSize: tc.maxFrameSize, + AllowHTTP: true, + } + + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + + _, err = tr.RoundTrip(req) + if !strings.Contains(err.Error(), expectedHTTP1PayloadHint) { + t.Errorf("expected error to contain %q, got %v", expectedHTTP1PayloadHint, err) + } + }) + } +} + func testTransportReusesConns(t *testing.T, useClient, wantSame bool, modReq func(*http.Request)) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, r.RemoteAddr)