diff --git a/go.mod b/go.mod
index 857e1f01..5652d2e4 100644
--- a/go.mod
+++ b/go.mod
@@ -2,6 +2,8 @@ module github.com/mholt/caddy-l4
 
 go 1.25
 
+replace github.com/caddyserver/caddy/v2 v2.10.1 => github.com/vnxme/caddy/v2 v2.0.0-20250822175201-1e2ae1b66bee
+
 require (
 	github.com/caddyserver/caddy/v2 v2.10.1
 	github.com/fsnotify/fsnotify v1.9.0
diff --git a/go.sum b/go.sum
index 590f98e8..35082272 100644
--- a/go.sum
+++ b/go.sum
@@ -92,8 +92,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g=
 github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s=
-github.com/caddyserver/caddy/v2 v2.10.1 h1:j//AmT4hh88we6HWjzTIxk2Y7jOypsUvmZH1K1AOh38=
-github.com/caddyserver/caddy/v2 v2.10.1/go.mod h1:TXLQHx+ev4HDpkO6PnVVHUbL6OXt6Dfe7VcIBdQnPL0=
+github.com/caddyserver/certmagic v0.23.0 h1:CfpZ/50jMfG4+1J/u2LV6piJq4HOfO6ppOnOf7DkFEU=
+github.com/caddyserver/certmagic v0.23.0/go.mod h1:9mEZIWqqWoI+Gf+4Trh04MOVPD0tGSxtqsxg87hAIH4=
 github.com/caddyserver/certmagic v0.24.0 h1:EfXTWpxHAUKgDfOj6MHImJN8Jm4AMFfMT6ITuKhrDF0=
 github.com/caddyserver/certmagic v0.24.0/go.mod h1:xPT7dC1DuHHnS2yuEQCEyks+b89sUkMENh8dJF+InLE=
 github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA=
@@ -436,6 +436,10 @@ github.com/urfave/cli v1.22.17 h1:SYzXoiPfQjHBbkYxbew5prZHS1TOLT3ierW8SYLqtVQ=
 github.com/urfave/cli v1.22.17/go.mod h1:b0ht0aqgH/6pBYzzxURyrM4xXNgsoT/n2ZzwQiEhNVo=
 github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
 github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
+github.com/vnxme/caddy/v2 v2.0.0-20250813194301-985e05d95dee h1:jzcXTnn8nOlIX0aazdAiHPmYyaXjpCEvYaai7bqqAb0=
+github.com/vnxme/caddy/v2 v2.0.0-20250813194301-985e05d95dee/go.mod h1:ly5YIVCbxP8LITj7dMu33zMd9EwvLdcuuQoUwzNiQ0Y=
+github.com/vnxme/caddy/v2 v2.0.0-20250822175201-1e2ae1b66bee h1:Thr+OARAP3gHUQhWCW0/bU97gBW9V7wfrK2TeglVTLU=
+github.com/vnxme/caddy/v2 v2.0.0-20250822175201-1e2ae1b66bee/go.mod h1:TXLQHx+ev4HDpkO6PnVVHUbL6OXt6Dfe7VcIBdQnPL0=
 github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
 github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
 github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
diff --git a/integration/caddyfile_adapt/pcw_empty.caddytest b/integration/caddyfile_adapt/pcw_empty.caddytest
new file mode 100644
index 00000000..0cc357de
--- /dev/null
+++ b/integration/caddyfile_adapt/pcw_empty.caddytest
@@ -0,0 +1,76 @@
+{
+	servers {
+		packet_conn_wrappers {
+			layer4
+		}
+	}
+}
+https://localhost {
+	tls {
+		issuer internal
+	}
+	respond "OK" 200
+}
+----------
+{
+	"apps": {
+		"http": {
+			"servers": {
+				"srv0": {
+					"listen": [
+						":443"
+					],
+					"packet_conn_wrappers": [
+						{
+							"wrapper": "layer4"
+						}
+					],
+					"routes": [
+						{
+							"match": [
+								{
+									"host": [
+										"localhost"
+									]
+								}
+							],
+							"handle": [
+								{
+									"handler": "subroute",
+									"routes": [
+										{
+											"handle": [
+												{
+													"body": "OK",
+													"handler": "static_response",
+													"status_code": 200
+												}
+											]
+										}
+									]
+								}
+							],
+							"terminal": true
+						}
+					]
+				}
+			}
+		},
+		"tls": {
+			"automation": {
+				"policies": [
+					{
+						"subjects": [
+							"localhost"
+						],
+						"issuers": [
+							{
+								"module": "internal"
+							}
+						]
+					}
+				]
+			}
+		}
+	}
+}
\ No newline at end of file
diff --git a/integration/caddyfile_adapt/pcw_matcher_sets.caddytest b/integration/caddyfile_adapt/pcw_matcher_sets.caddytest
new file mode 100644
index 00000000..ca6dd3a1
--- /dev/null
+++ b/integration/caddyfile_adapt/pcw_matcher_sets.caddytest
@@ -0,0 +1,125 @@
+{
+	servers {
+		packet_conn_wrappers {
+			layer4 {
+				@d dns
+				route @d {
+					proxy udp/one.one.one.one:53
+				}
+				@w wireguard
+				route @w {
+					proxy udp/192.168.1.1:51820
+				}
+			}
+		}
+	}
+}
+https://localhost {
+	tls {
+		issuer internal
+	}
+	respond "OK" 200
+}
+----------
+{
+	"apps": {
+		"http": {
+			"servers": {
+				"srv0": {
+					"listen": [
+						":443"
+					],
+					"packet_conn_wrappers": [
+						{
+							"routes": [
+								{
+									"handle": [
+										{
+											"handler": "proxy",
+											"upstreams": [
+												{
+													"dial": [
+														"udp/one.one.one.one:53"
+													]
+												}
+											]
+										}
+									],
+									"match": [
+										{
+											"dns": {}
+										}
+									]
+								},
+								{
+									"handle": [
+										{
+											"handler": "proxy",
+											"upstreams": [
+												{
+													"dial": [
+														"udp/192.168.1.1:51820"
+													]
+												}
+											]
+										}
+									],
+									"match": [
+										{
+											"wireguard": {}
+										}
+									]
+								}
+							],
+							"wrapper": "layer4"
+						}
+					],
+					"routes": [
+						{
+							"match": [
+								{
+									"host": [
+										"localhost"
+									]
+								}
+							],
+							"handle": [
+								{
+									"handler": "subroute",
+									"routes": [
+										{
+											"handle": [
+												{
+													"body": "OK",
+													"handler": "static_response",
+													"status_code": 200
+												}
+											]
+										}
+									]
+								}
+							],
+							"terminal": true
+						}
+					]
+				}
+			}
+		},
+		"tls": {
+			"automation": {
+				"policies": [
+					{
+						"subjects": [
+							"localhost"
+						],
+						"issuers": [
+							{
+								"module": "internal"
+							}
+						]
+					}
+				]
+			}
+		}
+	}
+}
\ No newline at end of file
diff --git a/integration/common_test.go b/integration/common_test.go
new file mode 100644
index 00000000..a5bcc719
--- /dev/null
+++ b/integration/common_test.go
@@ -0,0 +1,355 @@
+package integration
+
+import (
+	"context"
+	"crypto/tls"
+	"io"
+	"net"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/caddyserver/caddy/v2"
+	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
+	"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
+	"github.com/miekg/dns"
+	"github.com/quic-go/quic-go/http3"
+)
+
+// dnsQueryConfig contains DNS query parameters and a function for response message validation.
+type dnsQueryConfig struct {
+	// QueryId is an identifier used for logging.
+	QueryId string
+	// Timeout is a DNS request timeout (set via context.WithTimeout).
+	Timeout time.Duration
+	// TLSConfig is a TLS configuration.
+	TLSConfig *tls.Config
+
+	// Net can be "tcp-tls", "tcp" or "udp"; "" is a synonym for "udp".
+	Net string
+	// Address must be given in "{address}:{port}" or another format supported by net.Dial.
+	Address string
+	// QueryType must be one of dns.TypeA, dns.TypeAAAA, dns.TypeMX, dns.TypeNS, etc.
+	QueryType uint16
+	// DomainName must be a FQDN ending with a dot.
+	DomainName string
+
+	// Check validates a DNS response message.
+	Check func(*testing.T, *dns.Msg)
+}
+
+// httpRequestConfig contains HTTP request parameters and a function for response validation.
+type httpRequestConfig struct {
+	// RequestId is an identifier used for logging.
+	RequestId string
+	// Timeout is an HTTP request timeout (set via context.WithTimeout).
+	Timeout time.Duration
+	// TLSConfig is a TLS configuration.
+	TLSConfig *tls.Config
+
+	// Method is an HTTP method, e.g. "GET", "PUT", etc.
+	Method string
+	// Url is a resource locator in "{scheme}://{hostname}:{port}/{path}?{query}" format.
+	Url string
+	// Body is an optional request body.
+	Body io.Reader
+
+	// Check validates an HTTP response.
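+	// If nil, requestHTTP falls back to a status-code-only check expecting 200.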
+	Check func(*testing.T, *http.Response)
+	// Transport returns an http.RoundTripper implementation.
+	Transport func() http.RoundTripper
+}
+
+// udpExchangeConfig contains UDP exchange parameters and a function for response message validation.
+type udpExchangeConfig struct {
+	// ExchangeId is an identifier used for logging.
+	ExchangeId string
+	// Timeout is a UDP connection timeout (set via conn.SetDeadline).
+	Timeout time.Duration
+	// Delay is the period of time the UDP exchange sleeps between messages.
+	Delay time.Duration
+
+	// Address must be given in "{address}:{port}" or another format supported by net.Dial.
+	Address string
+
+	// Messages contains a list of byte sequences to send.
+	Messages [][]byte
+	// Check validates a UDP response message.
+	Check func(*testing.T, []byte, []byte)
+}
+
+func exchangeUDP(t *testing.T, config *udpExchangeConfig) {
+	t.Helper()
+
+	var logPrefix string
+	if len(config.ExchangeId) > 0 {
+		logPrefix = "[" + config.ExchangeId + "] "
+	}
+
+	// Resolve address
+	udpAddr, err := net.ResolveUDPAddr("udp", config.Address)
+	if err != nil {
+		t.Fatalf("%sresolve UDP address: %v", logPrefix, err)
+	}
+
+	// Dial UDP (connectionless, but gives us a socket)
+	conn, err := net.DialUDP("udp", nil, udpAddr)
+	if err != nil {
+		t.Fatalf("%sdial UDP: %v", logPrefix, err)
+	}
+	defer func() { _ = conn.Close() }()
+
+	// Set a deadline so we don't block forever
+	_ = conn.SetDeadline(time.Now().Add(config.Timeout))
+
+	// Send messages and read responses
+	buf := make([][]byte, len(config.Messages))
+	for i, m := range config.Messages {
+		// Send message
+		_, err = conn.Write(m)
+		if err != nil {
+			t.Fatalf("%swrite UDP message: %v", logPrefix, err)
+		}
+
+		if config.Check != nil {
+			// Receive response
+			buf[i] = make([]byte, 9000)
+			n, _, err := conn.ReadFromUDP(buf[i])
+			if err != nil {
+				t.Fatalf("%sread UDP response: %v", logPrefix, err)
+			}
+
+			r := buf[i][:n]
+			t.Logf("%swrote UDP message %v, read UDP response %v", logPrefix, string(m), string(r))
+
+			config.Check(t, m, r)
+		}
+
+		time.Sleep(config.Delay)
+	}
+}
+
+// loadCaddyWithJSON launches a Caddy instance with a given JSON config.
+func loadCaddyWithJSON(t *testing.T, config []byte) {
+	t.Helper()
+
+	// Start Caddy with the given config
+	err := caddy.Load(config, true)
+	if err != nil {
+		t.Fatalf("load Caddy: %v", err)
+	}
+
+	// Give Caddy a moment to come up
+	time.Sleep(waitForCaddyLaunch)
+}
+
+// loadCaddyWithCaddyfile launches a Caddy instance with a given Caddyfile config.
+func loadCaddyWithCaddyfile(t *testing.T, caddyfileConfig string) {
+	t.Helper()
+
+	// Create a caddyfile adapter
+	adapter := caddyfile.Adapter{ServerType: httpcaddyfile.ServerType{}}
+
+	// Parse the given config, process warnings and errors
+	config, warnings, err := adapter.Adapt([]byte(caddyfileConfig), nil)
+	if len(warnings) > 0 {
+		t.Logf("caddyfile warnings: %v", warnings)
+	}
+	if err != nil {
+		t.Fatalf("adapt caddyfile: %v", err)
+	}
+
+	loadCaddyWithJSON(t, config)
+}
+
+// provideDNSMessageCheck returns a function that conducts DNS response code and resource record checks.
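+// The returned function fails the test on any response code other than NOERROR; when checkRecords
+// is true, it also expects at least one answer record of expectedRecordType.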
+func provideDNSMessageCheck(queryId string, checkRecords bool, expectedRecordType uint16) func(*testing.T, *dns.Msg) {
+	return func(t *testing.T, r *dns.Msg) {
+		var logPrefix string
+		if len(queryId) > 0 {
+			logPrefix = "[" + queryId + "] "
+		}
+
+		// Do a status code check
+		if r.Rcode != dns.RcodeSuccess {
+			t.Fatalf("%sexpected DNS response code NOERROR, got %s", logPrefix, dns.RcodeToString[r.Rcode])
+		}
+
+		if checkRecords {
+			// Process resource records
+			found := false
+			for _, ans := range r.Answer {
+				t.Logf("%sreceived DNS resource record: %s", logPrefix, ans.String())
+				if ans.Header().Rrtype == expectedRecordType {
+					found = true
+				}
+			}
+			if !found {
+				t.Errorf("%sexpected at least one DNS record of type %s, got none", logPrefix, dns.TypeToString[expectedRecordType])
+			}
+		}
+	}
+}
+
+// provideHTTPResponseCheck returns a function that conducts HTTP status code and response body checks.
+func provideHTTPResponseCheck(requestId string, expectedCode int, expectedBody string) func(*testing.T, *http.Response) {
+	return func(t *testing.T, resp *http.Response) {
+		var logPrefix string
+		if len(requestId) > 0 {
+			logPrefix = "[" + requestId + "] "
+		}
+
+		if expectedCode > 0 && resp.StatusCode != expectedCode {
+			t.Errorf("%sexpected HTTP status code %d, got %d", logPrefix, expectedCode, resp.StatusCode)
+		}
+
+		if len(expectedBody) > 0 && resp.Body != nil {
+			respBody, err := io.ReadAll(resp.Body)
+			if err != nil {
+				t.Fatalf("%sread HTTP response: %v", logPrefix, err)
+			}
+
+			if string(respBody) != expectedBody {
+				t.Errorf("%sexpected HTTP response body %s, got %s", logPrefix, expectedBody, string(respBody))
+			}
+
+			t.Logf("%sreceived HTTP response body: %s", logPrefix, string(respBody))
+		}
+	}
+}
+
+// provideTLSConfig returns a basic TLS config.
+func provideTLSConfig(insecureSkipVerify bool, serverName string) *tls.Config {
+	return &tls.Config{
+		InsecureSkipVerify: insecureSkipVerify,
+		ServerName:         serverName,
+	}
+}
+
+// queryDNS makes a DNS query.
+func queryDNS(t *testing.T, config *dnsQueryConfig) {
+	t.Helper()
+
+	var logPrefix string
+	if len(config.QueryId) > 0 {
+		logPrefix = "[" + config.QueryId + "] "
+	}
+
+	// Compose a DNS query
+	m := new(dns.Msg)
+	m.SetQuestion(config.DomainName, config.QueryType)
+
+	// Create a DNS client
+	c := &dns.Client{
+		Net:       config.Net,
+		TLSConfig: config.TLSConfig,
+	}
+
+	// Derive a context with the given timeout from the test context
+	ctx, cancel := context.WithTimeout(t.Context(), config.Timeout)
+	defer cancel()
+
+	// Make the client conduct the query
+	r, _, err := c.ExchangeContext(ctx, m, config.Address)
+	if err != nil {
+		t.Fatalf("%smake DNS query: %v", logPrefix, err)
+	}
+
+	// Do status code and resource record checks
+	check := config.Check
+	if check == nil {
+		check = provideDNSMessageCheck(config.QueryId, true, config.QueryType)
+	}
+	check(t, r)
+}
+
+// requestHTTP3 makes an HTTP/3 request.
+func requestHTTP3(t *testing.T, config *httpRequestConfig) {
+	t.Helper()
+
+	// Create an HTTP/3 transport
+	config.Transport = func() http.RoundTripper {
+		return &http3.Transport{
+			TLSClientConfig: config.TLSConfig,
+		}
+	}
+
+	requestHTTP(t, config)
+}
+
+// requestHTTP makes an HTTP request.
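+// It uses config.Transport if set (closing it afterwards when it implements io.Closer), or a plain
+// http.Transport with config.TLSConfig; a nil config.Check defaults to a 200 status code check.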
+func requestHTTP(t *testing.T, config *httpRequestConfig) {
+	t.Helper()
+
+	var logPrefix string
+	if len(config.RequestId) > 0 {
+		logPrefix = "[" + config.RequestId + "] "
+	}
+
+	// Create an HTTP transport unless there is a custom http.RoundTripper implementation
+	var transport http.RoundTripper
+	if config.Transport != nil {
+		transport = config.Transport()
+	}
+	if transport == nil {
+		transport = &http.Transport{
+			TLSClientConfig: config.TLSConfig,
+		}
+	}
+	if closable, ok := transport.(io.Closer); ok {
+		defer func() {
+			_ = closable.Close()
+		}()
+	}
+
+	// Create an HTTP client with a custom transport
+	client := http.Client{
+		Transport: transport,
+	}
+
+	// Derive a context with the given timeout from the test context
+	ctx, cancel := context.WithTimeout(t.Context(), config.Timeout)
+	defer cancel()
+
+	// Compose an HTTP request with the context and given parameters
+	req, err := http.NewRequestWithContext(ctx, config.Method, config.Url, config.Body)
+	if err != nil {
+		t.Fatalf("%scompose HTTP request: %v", logPrefix, err)
+	}
+
+	// Make the client conduct the request
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatalf("%smake HTTP request: %v", logPrefix, err)
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	// Do status code, response body or custom checks
+	check := config.Check
+	if check == nil {
+		check = provideHTTPResponseCheck(config.RequestId, 200, "")
+	}
+	check(t, resp)
+}
+
+// stopCaddy stops the Caddy instance.
+func stopCaddy(t *testing.T) {
+	t.Helper()
+
+	err := caddy.Stop()
+	if err != nil {
+		t.Fatalf("stop Caddy: %v", err)
+	}
+}
+
+const (
+	waitForCaddyLaunch  = 500 * time.Millisecond
+	waitForDNSRequest   = 5 * time.Second
+	waitForHTTP3Request = 5 * time.Second
+	waitForUDPExchange  = 5 * time.Second
+)
diff --git a/integration/packetconn_test.go b/integration/packetconn_test.go
new file mode 100644
index 00000000..927f74ee
--- /dev/null
+++ b/integration/packetconn_test.go
@@ -0,0 +1,243 @@
+package integration
+
+import (
+	"bytes"
+	"fmt"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/miekg/dns"
+
+	_ "github.com/mholt/caddy-l4/layer4"
+	_ "github.com/mholt/caddy-l4/modules/l4dns"
+	_ "github.com/mholt/caddy-l4/modules/l4echo"
+	_ "github.com/mholt/caddy-l4/modules/l4proxy"
+	_ "github.com/mholt/caddy-l4/modules/l4regexp"
+)
+
+const (
+	testPCWCaddyfile = `
+{
+	debug
+	layer4 {
+		udp/:2000 {
+			@e regexp ^ECHO\d\d\d\d$ 8
+			route {
+				echo
+			}
+		}
+		udp/:5300 {
+			@d dns
+			route @d {
+				proxy udp/one.one.one.one:53
+			}
+		}
+	}
+	servers :443 {
+		packet_conn_wrappers {
+			layer4 {
+				@d dns
+				route @d {
+					proxy udp/one.one.one.one:53
+				}
+				@e regexp ^ECHO\d\d\d\d$ 8
+				route @e {
+					echo
+				}
+			}
+		}
+		protocols h3
+	}
+	servers :8443 {
+		protocols h3
+	}
+}
+https://localhost, https://localhost:8443 {
+	tls {
+		issuer internal
+	}
+	respond "{http.request.uri.query}" 200
+}
+	` // Caddy configuration to test
+
+	testPCWHostname = "localhost" // Hostname that clients connect to
+
+	testPCWPortMultiplex = 443  // Port serving multiple services
+	testPCWPortEchoOnly  = 2000 // Port serving echo only
+	testPCWPortDNSOnly   = 5300 // Port serving DNS only
+	testPCWPortHTTP3Only = 8443 // Port serving HTTP/3 only
+
+	testPCWMultiplier = 10 // How many times each domain is requested by the DNS client
+
+	testPCWDelay = 100 * time.Millisecond // How much time to wait between spawning clients
+)
+
+var (
+	testPCWDomains = []string{
+		"baidu.com.",
+		"bing.com.",
+		"chatgpt.com.",
+		"facebook.com.",
+		"google.com.",
+		"instagram.com.",
+		"reddit.com.",
+		"tiktok.com.",
"x.com.", + "wikipedia.org.", + } // List of domains requested by DNS client. Its length determines the number of tests (adjusted by the multiplier). +) + +func testPCWWithDNSQueries(t *testing.T, wg *sync.WaitGroup, port int) { + t.Helper() + + defer wg.Done() + + address := fmt.Sprintf("%s:%d", testPCWHostname, port) + + multiplier, qType := testPCWMultiplier, dns.TypeA + if port == testPCWPortDNSOnly { + multiplier, qType = 1, dns.TypeNS // Disregard the multiplier for DNS only tests, query NS instead of A + } + + var c int + l := len(testPCWDomains) + for j := range multiplier { + for i, domain := range testPCWDomains { + c = i + j*l + + wg.Add(1) + go func() { + defer wg.Done() + + t.Logf("[DNS-%d.%d] connecting to %s, querying %s IN %s, expecting at least one", port, c, address, domain, dns.TypeToString[qType]) + + config := dnsQueryConfig{ + QueryId: fmt.Sprintf("DNS-%d.%d", port, c), + Timeout: waitForDNSRequest, + TLSConfig: nil, + + Net: "udp", + Address: address, + QueryType: qType, + DomainName: domain, + } + config.Check = provideDNSMessageCheck(config.QueryId, true, qType) + + queryDNS(t, &config) + }() + + time.Sleep(testPCWDelay) + } + } +} + +func testPCWWithHTTP3Requests(t *testing.T, wg *sync.WaitGroup, port int) { + t.Helper() + + defer wg.Done() + + address := fmt.Sprintf("%s:%d", testPCWHostname, port) + + var payload, url string + for i := range testPCWMultiplier * len(testPCWDomains) { + payload = strconv.Itoa(i) + url = "https://" + address + "/?" + payload + + wg.Add(1) + go func() { + defer wg.Done() + + t.Logf("[HTTP/3-%d.%d] connecting to %s, requesting %s, expecting %s", port, i, address, url, payload) + + config := httpRequestConfig{ + RequestId: fmt.Sprintf("HTTP/3-%d.%d", port, i), + Timeout: waitForHTTP3Request, + TLSConfig: provideTLSConfig(true, testPCWHostname), + + Method: "GET", + Url: url, + Body: nil, + } + config.Check = provideHTTPResponseCheck(config.RequestId, 200, payload) + + requestHTTP3(t, &config) + }() + + time.Sleep(testPCWDelay) + } +} + +func testPCWWithUDPExchanges(t *testing.T, wg *sync.WaitGroup, port int) { + t.Helper() + + defer wg.Done() + + address := fmt.Sprintf("%s:%d", testPCWHostname, port) + + messages := make([][]byte, testPCWMultiplier*len(testPCWDomains)) + for i, _ := range messages { + messages[i] = []byte(fmt.Sprintf("ECHO%04d", i)) + } + + for i := range 1 { + wg.Add(1) + go func() { + defer wg.Done() + + t.Logf("[UDP-%d.%d] connecting to %s, sending %s, expecting %s", port, i, address, messages[0], messages[0]) + + config := udpExchangeConfig{ + ExchangeId: fmt.Sprintf("UDP-%d.%d", port, i), + Timeout: 3 * waitForUDPExchange, + Delay: testPCWDelay, + + Address: address, + Messages: messages, + } + config.Check = func(t *testing.T, m []byte, r []byte) { + if !bytes.Equal(m, r) { + t.Errorf("[%s] expected %s, got %s", config.ExchangeId, string(m), string(r)) + } + } + + exchangeUDP(t, &config) + }() + + time.Sleep(testPCWDelay) + } +} + +// TestPCW tests the packet conn wrapper implementation. 
+func TestPCW(t *testing.T) {
+	// Load Caddy
+	loadCaddyWithCaddyfile(t, testPCWCaddyfile[1:][:len(testPCWCaddyfile)-2]) // Workaround to avoid caddyfile warnings
+
+	// Use a wait group to sync goroutines
+	var wg sync.WaitGroup
+
+	// Spawn goroutines that make HTTP/3 requests
+	for _, port := range []int{testPCWPortHTTP3Only, testPCWPortMultiplex} {
+		wg.Add(1)
+		go testPCWWithHTTP3Requests(t, &wg, port)
+	}
+
+	//// Spawn goroutines that make DNS queries
+	//for _, port := range []int{testPCWPortDNSOnly, testPCWPortMultiplex} {
+	//	wg.Add(1)
+	//	go testPCWWithDNSQueries(t, &wg, port)
+	//}
+
+	// Spawn goroutines that exchange UDP messages
+	for _, port := range []int{testPCWPortEchoOnly, testPCWPortMultiplex} {
+		wg.Add(1)
+		go testPCWWithUDPExchanges(t, &wg, port)
+	}
+
+	// Wait for all goroutines to finish
+	wg.Wait()
+
+	// Stop Caddy
+	stopCaddy(t)
+}
diff --git a/integration/udp_test.go b/integration/udp_test.go
new file mode 100644
index 00000000..00577efb
--- /dev/null
+++ b/integration/udp_test.go
@@ -0,0 +1,77 @@
+package integration
+
+import (
+	"bytes"
+	"fmt"
+	"sync"
+	"testing"
+	"time"
+
+	_ "github.com/mholt/caddy-l4/layer4"
+	_ "github.com/mholt/caddy-l4/modules/l4echo"
+	_ "github.com/mholt/caddy-l4/modules/l4regexp"
+)
+
+const (
+	testUDPCaddyfile = `
+{
+	debug
+	layer4 {
+		udp/:2000 {
+			@e regexp ^ECHO\d\d\d\d$ 8
+			route {
+				echo
+			}
+		}
+	}
+}
+	` // Caddy configuration to test
+
+	testUDPHostname = "localhost"            // Hostname that clients connect to
+	testUDPPortEcho = 2000                   // Port serving echo only
+	testUDPDelay    = 0 * time.Millisecond   // How much time to wait between sending messages
+)
+
+func TestUDP(t *testing.T) {
+	// Load Caddy
+	loadCaddyWithCaddyfile(t, testUDPCaddyfile[1:][:len(testUDPCaddyfile)-2]) // Workaround to avoid caddyfile warnings
+
+	// Use a wait group to sync goroutines
+	var wg sync.WaitGroup
+
+	address := fmt.Sprintf("%s:%d", testUDPHostname, testUDPPortEcho)
+
+	messages := make([][]byte, 1)
+	for i := range messages {
+		messages[i] = []byte(fmt.Sprintf("ECHO%04d", i))
+	}
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+
+		t.Logf("[UDP-%d] connecting to %s, sending %s, expecting %s", testUDPPortEcho, address, messages[0], messages[0])
+
+		config := udpExchangeConfig{
+			ExchangeId: fmt.Sprintf("UDP-%d", testUDPPortEcho),
+			Timeout:    waitForUDPExchange,
+			Delay:      testUDPDelay,
+
+			Address:  address,
+			Messages: messages,
+		}
+		config.Check = func(t *testing.T, m []byte, r []byte) {
+			if !bytes.Equal(m, r) {
+				t.Errorf("[%s] expected %s, got %s", config.ExchangeId, string(m), string(r))
+			}
+		}
+
+		exchangeUDP(t, &config)
+	}()
+
+	// Wait for all goroutines to finish
+	wg.Wait()
+
+	// Stop Caddy
+	stopCaddy(t)
+}
diff --git a/layer4/packetconn.go b/layer4/packetconn.go
new file mode 100644
index 00000000..f6c268c9
--- /dev/null
+++ b/layer4/packetconn.go
@@ -0,0 +1,281 @@
+package layer4
+
+import (
+	"bytes"
+	"errors"
+	"net"
+	"os"
+	"sync/atomic"
+	"time"
+
+	"github.com/caddyserver/caddy/v2"
+	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
+)
+
+func init() {
+	caddy.RegisterModule(&PacketConnWrapper{})
+}
+
+// PacketConnWrapper is a Caddy module that wraps App as a packet conn wrapper; it does not support TCP.
+type PacketConnWrapper struct {
+	// Routes express composable logic for handling byte streams.
+	Routes RouteList `json:"routes,omitempty"`
+
+	// Maximum time connections have to complete the matching phase (the first terminal handler is matched). Default: 3s.
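+	// Zero and negative values are replaced with MatchingTimeoutDefault during provisioning.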
+	MatchingTimeout caddy.Duration `json:"matching_timeout,omitempty"`
+
+	// probably should extract packet conn handling logic, but this will do
+	server *Server
+
+	ctx caddy.Context
+}
+
+// CaddyModule returns the Caddy module information.
+func (*PacketConnWrapper) CaddyModule() caddy.ModuleInfo {
+	return caddy.ModuleInfo{
+		ID:  "caddy.packetconns.layer4",
+		New: func() caddy.Module { return new(PacketConnWrapper) },
+	}
+}
+
+// Provision sets up the PacketConnWrapper.
+func (pcw *PacketConnWrapper) Provision(ctx caddy.Context) error {
+	pcw.ctx = ctx
+
+	if pcw.MatchingTimeout <= 0 {
+		pcw.MatchingTimeout = caddy.Duration(MatchingTimeoutDefault)
+	}
+
+	err := pcw.Routes.Provision(ctx)
+	if err != nil {
+		return err
+	}
+
+	logger := ctx.Logger()
+
+	pcw.server = &Server{
+		logger:        logger,
+		compiledRoute: pcw.Routes.Compile(logger, time.Duration(pcw.MatchingTimeout), packetConnHandler{}),
+	}
+
+	return nil
+}
+
+// WrapPacketConn wraps up a packet conn.
+func (pcw *PacketConnWrapper) WrapPacketConn(pc net.PacketConn) net.PacketConn {
+	pipe := make(chan *packet, 10)
+	go func() {
+		err := pcw.server.servePacket(&packetConnWithPipe{
+			PacketConn: pc,
+			packetPipe: pipe,
+		})
+		pipe <- &packet{
+			err: err,
+		}
+		// server.servePacket will wait for all handling to finish before returning,
+		// so it's safe to close the pipe here as no new value will be sent
+		close(pipe)
+	}()
+	wpc := &wrappedPacketConn{
+		pc:         pc,
+		packetPipe: pipe,
+	}
+	// set the deadline to zero time to initialize the timer
+	_ = wpc.SetReadDeadline(time.Time{})
+	return wpc
+}
+
+// UnmarshalCaddyfile sets up the PacketConnWrapper from Caddyfile tokens. Syntax:
+//
+//	layer4 {
+//		matching_timeout <duration>
+//		@a <matcher> [<matcher_args>]
+//		@b {
+//			<matcher> [<matcher_args>]
+//			<matcher> [<matcher_args>]
+//		}
+//		route @a @b {
+//			<handler> [<handler_args>]
+//		}
+//		@c <matcher> {
+//			<matcher_option> [<matcher_option_args>]
+//		}
+//		route @c {
+//			<handler> [<handler_args>]
+//			<handler> {
+//				<handler_option> [<handler_option_args>]
+//			}
+//		}
+//	}
+func (pcw *PacketConnWrapper) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+	d.Next() // consume wrapper name
+
+	// No same-line options are supported
+	if d.CountRemainingArgs() > 0 {
+		return d.ArgErr()
+	}
+
+	if err := ParseCaddyfileNestedRoutes(d, &pcw.Routes, &pcw.MatchingTimeout); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// packetConnHandler is the default (terminal) connection handler: it pipes a hijacked connection's packets back to the wrapped packet conn.
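+// Handle first flushes any bytes buffered during matching, then relays packets from the connection's
+// read channel into the wrapper's pipe until the idle timer fires, at which point it returns errHijacked.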
+type packetConnHandler struct{}
+
+func (packetConnHandler) Handle(conn *Connection) error {
+	// perhaps an interface is better
+	pc, ok := conn.Context.Value(connCtxKey).(*packetConn)
+	if !ok {
+		return errNotPacketConn
+	}
+	// impossible to be false, check nonetheless
+	pcwp, ok := pc.PacketConn.(*packetConnWithPipe)
+	if !ok {
+		return errNotPacketConn
+	}
+	// get the first buffer to read; Read shouldn't be called on packetConn from now on
+	var firstBuf []byte
+	if len(conn.buf) > 0 && conn.offset < len(conn.buf) {
+		switch {
+		// data is fully consumed
+		case pc.lastBuf == nil:
+			firstBuf = conn.buf[conn.offset:]
+		// data is partially consumed
+		case pc.lastBuf != nil && pc.lastBuf.Len() > 0:
+			// reuse matching buffer
+			n := copy(conn.buf, conn.buf[conn.offset:])
+			buf := bytes.NewBuffer(conn.buf[:n])
+			_, _ = buf.ReadFrom(pc.lastBuf)
+
+			// release last packet buffer
+			udpBufPool.Put(pc.lastPacket.pooledBuf)
+			pc.lastPacket = nil
+			pc.lastBuf = nil
+
+			firstBuf = buf.Bytes()
+		}
+	}
+
+	// flush the buffered data first, if any
+	if len(firstBuf) > 0 {
+		pcwp.packetPipe <- &packet{
+			pooledBuf: firstBuf,
+			n:         len(firstBuf),
+			err:       nil,
+			addr:      pc.addr,
+		}
+	}
+
+	// pass subsequent packets to the pipe;
+	// reuse the idle timer for idle timeout since Read isn't called anymore
+	if pc.idleTimer == nil {
+		pc.idleTimer = time.NewTimer(udpAssociationIdleTimeout)
+	} else {
+		pc.idleTimer.Reset(udpAssociationIdleTimeout)
+	}
+	for {
+		select {
+		case pkt := <-pc.readCh:
+			pcwp.packetPipe <- pkt
+			pc.idleTimer.Reset(udpAssociationIdleTimeout)
+		case <-pc.idleTimer.C:
+			return errHijacked
+		}
+	}
+}
+
+// packetConnWithPipe sends all the data it reads to the channel from which the wrapper
+// receives typical UDP data.
+type packetConnWithPipe struct {
+	net.PacketConn
+	packetPipe chan *packet
+}
+
+type wrappedPacketConn struct {
+	pc         net.PacketConn
+	packetPipe chan *packet
+	// stores the deadline as Unix seconds because ReadFrom may be called concurrently with SetReadDeadline
+	deadline      atomic.Int64
+	deadlineTimer *time.Timer
+}
+
+func (w *wrappedPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+	// check deadline
+	if isDeadlineExceeded(time.Unix(w.deadline.Load(), 0)) {
+		return 0, nil, os.ErrDeadlineExceeded
+	}
+	for {
+		select {
+		case pkt := <-w.packetPipe:
+			if pkt == nil {
+				// Channel is closed; report net.ErrClosed.
+				return 0, nil, net.ErrClosed
+			}
+			if pkt.err != nil {
+				return 0, nil, pkt.err
+			}
+			n = copy(p, pkt.pooledBuf[:pkt.n])
+			// any data beyond len(p) is discarded; return the buffer to the pool
+			udpBufPool.Put(pkt.pooledBuf)
+			return n, pkt.addr, nil
+		case <-w.deadlineTimer.C:
+			// deadline may change during the wait, recheck
+			if isDeadlineExceeded(time.Unix(w.deadline.Load(), 0)) {
+				return 0, nil, os.ErrDeadlineExceeded
+			}
+		}
+	}
+}
+
+func (w *wrappedPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+	return w.pc.WriteTo(p, addr)
+}
+
+func (w *wrappedPacketConn) Close() error {
+	return w.pc.Close()
+}
+
+func (w *wrappedPacketConn) LocalAddr() net.Addr {
+	return w.pc.LocalAddr()
+}
+
+func (w *wrappedPacketConn) SetDeadline(t time.Time) error {
+	_ = w.SetReadDeadline(t)
+	return w.pc.SetWriteDeadline(t)
+}
+
+// SetReadDeadline sets the read deadline, resetting the internal timer if one is already set.
+// The returned error is always nil.
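+// Note that the deadline is stored with second granularity (via t.Unix()), so sub-second precision
+// is lost in the deadline check performed by ReadFrom.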
+func (w *wrappedPacketConn) SetReadDeadline(t time.Time) error {
+	w.deadline.Store(t.Unix())
+	if w.deadlineTimer != nil {
+		w.deadlineTimer.Reset(time.Until(t))
+	} else {
+		w.deadlineTimer = time.NewTimer(time.Until(t))
+	}
+	return nil
+}
+
+func (w *wrappedPacketConn) SetWriteDeadline(t time.Time) error {
+	return w.pc.SetWriteDeadline(t)
+}
+
+var (
+	errNotPacketConn = errors.New("no packetConn found in connection context")
+	connCtxKey       = caddy.CtxKey("underlying_conn")
+)
+
+// Interface guards
+var (
+	_ caddy.Module            = (*PacketConnWrapper)(nil)
+	_ caddy.PacketConnWrapper = (*PacketConnWrapper)(nil)
+	_ caddyfile.Unmarshaler   = (*PacketConnWrapper)(nil)
+)
diff --git a/layer4/server.go b/layer4/server.go
index 306a8831..153f43b7 100644
--- a/layer4/server.go
+++ b/layer4/server.go
@@ -16,6 +16,7 @@ package layer4
 
 import (
 	"bytes"
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -93,6 +94,8 @@ func (s *Server) serve(ln net.Listener) error {
 }
 
 func (s *Server) servePacket(pc net.PacketConn) error {
+	// cnt tracks active handler goroutines so servePacket can wait for all of them before returning
+	var cnt atomic.Uint64
 	// Spawn a goroutine whose only job is to consume packets from the socket
 	// and send to the packets channel.
 	packets := make(chan packet, 10)
@@ -128,12 +131,7 @@ func (s *Server) servePacket(pc net.PacketConn) error {
 		case addr := <-closeCh:
 			conn, ok := udpConns[addr]
 			if ok {
-				// This will abort any active Read() from another goroutine and return EOF
-				close(conn.readCh)
-				// Drain pending packets to ensure we release buffers back to the pool
-				for pkt := range conn.readCh {
-					udpBufPool.Put(pkt.pooledBuf)
-				}
+				conn.drainBuffer()
 			}
 			// UDP connection is closed (either implicitly through timeout or by
 			// explicit call to Close()).
@@ -141,6 +139,20 @@
 
 		case pkt := <-packets:
 			if pkt.err != nil {
+				// wait for all connections to finish
+				for _, conn := range udpConns {
+					conn.drainBuffer()
+				}
+				// drain the channel so wg.Done() will be called properly
+				// reason: in this switch case, closeCh won't be read like in the other case, and
+				// connections send to closeCh during Close; if that send blocked, their Close would be stuck
+				for addr := range closeCh {
+					delete(udpConns, addr)
+					// no new connection will be created, safe to close now
+					if cnt.Load() == 0 && len(closeCh) == 0 {
+						close(closeCh)
+					}
+				}
 				return pkt.err
 			}
 			conn, ok := udpConns[pkt.addr.String()]
@@ -154,6 +166,7 @@ func (s *Server) servePacket(pc net.PacketConn) error {
 					closeCh: closeCh,
 				}
 				udpConns[pkt.addr.String()] = conn
+				cnt.Add(1)
 				go func(conn *packetConn) {
 					s.handle(conn)
 					// It might seem cleaner to send to closeCh here rather than
@@ -162,6 +175,8 @@ func (s *Server) servePacket(pc net.PacketConn) error {
 					// packets coming in from the same downstream. Should that
 					// happen, we'll just spin up a new handler concurrent to
 					// the old one shutting down.
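+					// atomically decrement cnt: adding ^uint64(0) is equivalent to subtracting 1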
+					cnt.Add(^uint64(0))
 				}(conn)
 			}
 			conn.readCh <- &pkt
@@ -177,11 +192,13 @@ func (s *Server) handle(conn net.Conn) {
 	defer bufPool.Put(buf)
 
 	cx := WrapConnection(conn, buf, s.logger)
+	// used to retrieve the original connection inside handlers
+	cx.Context = context.WithValue(cx.Context, connCtxKey, conn)
 
 	start := time.Now()
 	err := s.compiledRoute.Handle(cx)
 	duration := time.Since(start)
 
-	if err != nil {
+	if err != nil && !errors.Is(err, errHijacked) {
 		s.logger.Error("handling connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err))
 	}
@@ -275,6 +292,17 @@ func isDeadlineExceeded(t time.Time) bool {
 	return !t.IsZero() && t.Before(time.Now())
 }
 
+// drainBuffer closes the read channel (aborting any active Read with EOF) and releases pending
+// packet buffers back to the pool. It should only be called from the server loop goroutine.
+func (pc *packetConn) drainBuffer() {
+	// This will abort any active Read() from another goroutine and return EOF
+	close(pc.readCh)
+	// Drain pending packets to ensure we release buffers back to the pool
+	for pkt := range pc.readCh {
+		udpBufPool.Put(pkt.pooledBuf)
+	}
+}
+
 func (pc *packetConn) Read(b []byte) (n int, err error) {
 	if pc.lastPacket != nil {
 		// There is a partial buffer to continue reading from the previous