Skip to content

Commit

Permalink
Merge pull request #146 from monzo/h2-alpn
Browse files Browse the repository at this point in the history
Support HTTP2 Prior Knowledge
  • Loading branch information
milesbxf authored Mar 31, 2022
2 parents f9f07d6 + 525ba6c commit f03e3c0
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 24 deletions.
45 changes: 42 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,41 @@
package typhon

import (
"context"
"crypto/tls"
"net"
"net/http"
"time"

"github.com/monzo/terrors"
"golang.org/x/net/http2"
)

var (
// Client is used to send all requests by default. It can be overridden globally but MUST only be done before use
// takes place; access is not synchronised.
Client Service = BareClient
// RoundTripper is used by default in Typhon clients
RoundTripper http.RoundTripper = &http.Transport{
// RoundTripper chooses HTTP1, or H2C based on a context flag (see WithH2C)
RoundTripper http.RoundTripper = dynamicRoundTripper{}

// HTTPRoundTripper is a HTTP1 and TLS HTTP2 client
HTTPRoundTripper http.RoundTripper = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DisableKeepAlives: false,
DisableCompression: false,
IdleConnTimeout: 10 * time.Minute,
MaxIdleConnsPerHost: 10}
MaxIdleConnsPerHost: 10,
}

// H2cRoundTripper is a prior-knowledge H2c client. It does not support ProxyFromEnvironment.
H2cRoundTripper http.RoundTripper = &http2.Transport{
AllowHTTP: true,
// This monstrosity is needed to get the http2 Transport to dial over cleartext.
// See https://github.com/thrawn01/h2c-golang-example
DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
return net.Dial(network, addr)
},
}
)

// A ResponseFuture is a container for a Response which will materialise at some point.
Expand Down Expand Up @@ -91,3 +109,24 @@ func SendVia(req Request, svc Service) *ResponseFuture {
func Send(req Request) *ResponseFuture {
return SendVia(req, Client)
}

type withH2C struct{}

// WithH2C instructs the dynamicRoundTripper to use prior-knowledge cleartext HTTP2 instead of HTTP1.1
func WithH2C(ctx context.Context) context.Context {
return context.WithValue(ctx, withH2C{}, true)
}

func isH2C(ctx context.Context) bool {
b, _ := ctx.Value(withH2C{}).(bool)
return b
}

type dynamicRoundTripper struct{}

func (d dynamicRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
if r.URL.Scheme == "http" && isH2C(r.Context()) {
return H2cRoundTripper.RoundTrip(r)
}
return HTTPRoundTripper.RoundTrip(r)
}
9 changes: 9 additions & 0 deletions e2e_http1_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package typhon

import (
"context"
"crypto/tls"
"fmt"
"testing"
Expand All @@ -26,6 +27,10 @@ func (f http1Flavour) Proto() string {
return "HTTP/1.1"
}

func (f http1Flavour) Context() (context.Context, func()) {
return context.WithCancel(context.Background())
}

type http1TLSFlavour struct {
T *testing.T
cert tls.Certificate
Expand All @@ -48,3 +53,7 @@ func (f http1TLSFlavour) URL(s *Server) string {
func (f http1TLSFlavour) Proto() string {
return "HTTP/1.1"
}

func (f http1TLSFlavour) Context() (context.Context, func()) {
return context.WithCancel(context.Background())
}
35 changes: 35 additions & 0 deletions e2e_http2_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package typhon

import (
"context"
"crypto/tls"
"fmt"
"testing"
Expand Down Expand Up @@ -28,6 +29,36 @@ func (f http2H2cFlavour) Proto() string {
return "HTTP/2.0"
}

func (f http2H2cFlavour) Context() (context.Context, func()) {
return context.WithCancel(context.Background())
}

type http2H2cPriorKnowledgeFlavour struct {
T *testing.T
client Service
}

func (f http2H2cPriorKnowledgeFlavour) Serve(svc Service) *Server {
svc = svc.Filter(H2cFilter)
s, err := Listen(svc, "localhost:0")
require.NoError(f.T, err)
return s
}

func (f http2H2cPriorKnowledgeFlavour) URL(s *Server) string {
return fmt.Sprintf("http://%s", s.Listener().Addr())
}

func (f http2H2cPriorKnowledgeFlavour) Proto() string {
return "HTTP/2.0"
}

func (f http2H2cPriorKnowledgeFlavour) Context() (context.Context, func()) {
ctx, cancel := context.WithCancel(context.Background())
ctx = WithH2C(ctx)
return ctx, cancel
}

type http2H2Flavour struct {
T *testing.T
client Service
Expand All @@ -52,3 +83,7 @@ func (f http2H2Flavour) URL(s *Server) string {
func (f http2H2Flavour) Proto() string {
return "HTTP/2.0"
}

func (f http2H2Flavour) Context() (context.Context, func()) {
return context.WithCancel(context.Background())
}
55 changes: 34 additions & 21 deletions e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type e2eFlavour interface {
Serve(Service) *Server
URL(*Server) string
Proto() string
Context() (context.Context, func())
}

// flavours runs the passed E2E test with all test flavours (HTTP/1.1, HTTP/2.0/h2c, etc.)
Expand Down Expand Up @@ -85,6 +86,13 @@ func someFlavours(t *testing.T, only []string, impl func(*testing.T, e2eFlavour)
impl(t, http2H2cFlavour{T: t})
})
}
if run("http2.0-h2c-prior-knowledge") {
t.Run("http2.0-h2c-prior-knowledge", func(t *testing.T) {
defer leaktest.Check(t)()
Client = Service(BareClient).Filter(ErrorFilter)
impl(t, http2H2cPriorKnowledgeFlavour{T: t})
})
}
if run("http2.0-h2") {
t.Run("http2.0-h2", func(t *testing.T) {
defer leaktest.Check(t)()
Expand All @@ -103,7 +111,7 @@ func someFlavours(t *testing.T, only []string, impl func(*testing.T, e2eFlavour)

func TestE2E(t *testing.T) {
flavours(t, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()

svc := Service(func(req Request) Response {
Expand Down Expand Up @@ -137,7 +145,7 @@ func TestE2E(t *testing.T) {

func TestE2EProtobuf(t *testing.T) {
flavours(t, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()

svc := Service(func(req Request) Response {
Expand Down Expand Up @@ -177,7 +185,7 @@ func TestE2EProtobuf(t *testing.T) {

func TestE2EStreaming(t *testing.T) {
someFlavours(t, []string{"http1.1", "http1.1-tls"}, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()

chunks := make(chan []byte)
Expand Down Expand Up @@ -219,8 +227,8 @@ func TestE2EStreaming(t *testing.T) {
// The HTTP/2.0 streaming implementation is more advanced, as it allows the response body to be streamed back
// concurrently with the request body. This test constructs a server that echoes the request body back to the client
// and asserts that the chunks are returned in real time.
someFlavours(t, []string{"http2.0-h2", "http2.0-h2c"}, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := context.WithCancel(context.Background())
someFlavours(t, []string{"http2.0-h2", "http2.0-h2c", "http2.0-h2c-prior-knowledge", "http"}, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := flav.Context()
defer cancel()

svc := Service(func(req Request) Response {
Expand Down Expand Up @@ -256,7 +264,7 @@ func TestE2EStreaming(t *testing.T) {

func TestE2EDomainSocket(t *testing.T) {
someFlavours(t, []string{"http1.1"}, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()

svc := Service(func(req Request) Response {
Expand Down Expand Up @@ -289,7 +297,7 @@ func TestE2EDomainSocket(t *testing.T) {

func TestE2EError(t *testing.T) {
flavours(t, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()

expectedErr := terrors.Unauthorized("ah_ah_ah", "You didn't say the magic word!", map[string]string{
Expand Down Expand Up @@ -323,7 +331,7 @@ func TestE2EError(t *testing.T) {

func TestE2EErrorWithProtobuf(t *testing.T) {
flavours(t, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()

expectedErr := terrors.Unauthorized("ah_ah_ah", "You didn't say the magic word!", map[string]string{
Expand Down Expand Up @@ -371,7 +379,7 @@ func TestE2ECancellation(t *testing.T) {
s := flav.Serve(svc)
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
req := NewRequest(ctx, "GET", flav.URL(s), nil)
req.Send()
select {
Expand Down Expand Up @@ -403,7 +411,7 @@ func TestE2ENoFollowRedirect(t *testing.T) {
s := flav.Serve(svc)
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()
req := NewRequest(ctx, "GET", flav.URL(s), nil)
rsp := req.Send().Response()
Expand All @@ -415,7 +423,7 @@ func TestE2ENoFollowRedirect(t *testing.T) {

func TestE2EProxiedStreamer(t *testing.T) {
flavours(t, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()

chunks := make(chan bool)
Expand Down Expand Up @@ -470,7 +478,8 @@ func TestE2EProxiedStreamer(t *testing.T) {
// cancelled) is used to make a request.
func TestE2EInfiniteContext(t *testing.T) {
flavours(t, func(t *testing.T, flav e2eFlavour) {
ctx := context.Background()
ctx, cancel := flav.Context()
defer cancel()

var receivedCtx context.Context
svc := Service(func(req Request) Response {
Expand Down Expand Up @@ -517,7 +526,7 @@ func TestE2ERequestAutoChunking(t *testing.T) {
s := flav.Serve(svc)
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()

// Streamer; should be chunked
Expand Down Expand Up @@ -567,7 +576,7 @@ func TestE2EResponseAutoChunking(t *testing.T) {
s := flav.Serve(svc)
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()

// Streamer; should be chunked
Expand Down Expand Up @@ -634,7 +643,7 @@ func TestE2EStreamingCancellation(t *testing.T) {
s := flav.Serve(svc)
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
req := NewRequest(ctx, "GET", flav.URL(s), nil)
req.Send().Response()
cancel()
Expand Down Expand Up @@ -662,7 +671,9 @@ func TestE2EStreamingServerAbort(t *testing.T) {
s := flav.Serve(svc)
defer s.Stop(context.Background())

ctx := context.Background()
ctx, cancel := flav.Context()
defer cancel()

req := NewRequest(ctx, "GET", flav.URL(s), nil)
rsp := req.Send().Response()
<-done
Expand All @@ -671,7 +682,7 @@ func TestE2EStreamingServerAbort(t *testing.T) {
assert.EqualError(t, err, io.ErrUnexpectedEOF.Error())
})

someFlavours(t, []string{"http2.0-h2", "http2.0-h2c"}, func(t *testing.T, flav e2eFlavour) {
someFlavours(t, []string{"http2.0-h2", "http2.0-h2c", "http2.0-h2c-prior-knowledge"}, func(t *testing.T, flav e2eFlavour) {
done := make(chan struct{})
svc := Service(func(req Request) Response {
s := Streamer()
Expand All @@ -688,7 +699,9 @@ func TestE2EStreamingServerAbort(t *testing.T) {
s := flav.Serve(svc)
defer s.Stop(context.Background())

ctx := context.Background()
ctx, cancel := flav.Context()
defer cancel()

req := NewRequest(ctx, "GET", flav.URL(s), nil)
rsp := req.Send().Response()
<-done
Expand All @@ -704,7 +717,7 @@ func TestE2EStreamingServerAbort(t *testing.T) {
// will write chunks of output by both copying them from the request body, and writing them directly. The two should
// be interleaved and sent to the client without delay.
func TestE2EFullDuplex(t *testing.T) {
someFlavours(t, []string{"http2.0-h2", "http2.0-h2c"}, func(t *testing.T, flav e2eFlavour) {
someFlavours(t, []string{"http2.0-h2", "http2.0-h2c", "http2.0-h2c-prior-knowledge"}, func(t *testing.T, flav e2eFlavour) {
chunks := make(chan []byte)
svc := Service(func(req Request) Response {
body := Streamer()
Expand All @@ -721,7 +734,7 @@ func TestE2EFullDuplex(t *testing.T) {
s := flav.Serve(svc)
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()
req := NewRequest(ctx, "GET", flav.URL(s), nil)
req.Body = Streamer()
Expand Down Expand Up @@ -749,7 +762,7 @@ func TestE2EFullDuplex(t *testing.T) {

func TestE2EDraining(t *testing.T) {
flavours(t, func(t *testing.T, flav e2eFlavour) {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := flav.Context()
defer cancel()

returnRsp := make(chan bool)
Expand Down

0 comments on commit f03e3c0

Please sign in to comment.