Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add unix sockets support for URLs #874

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
86 changes: 86 additions & 0 deletions helper/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package helper

import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
)

type transport struct {
base *http.Transport
dialer *net.Dialer
tlsDialer *tls.Dialer
}

func (t *transport) handleUnixAddr(addr string) (string, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return "", err
}
path, err := url.PathUnescape(host)
if err != nil {
return "", err
}
return path, nil
}

func (t *transport) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if path, err := t.handleUnixAddr(addr); err != nil {
return nil, err
} else {
return t.dialer.DialContext(ctx, "unix", path)
}
}

func (t *transport) dialTlsContext(ctx context.Context, network, addr string) (net.Conn, error) {
if path, err := t.handleUnixAddr(addr); err != nil {
return nil, err
} else {
return t.tlsDialer.DialContext(ctx, "unix", path)
}
}

func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) {
if r.URL != nil {
switch r.URL.Scheme {
case "http", "https":
return http.DefaultTransport.RoundTrip(r)
case "unix":
urlValues := r.URL.Query()
req := r.Clone(r.Context())
if urlValues.Get("tls") != "" {
req.URL.Scheme = "https"
} else {
req.URL.Scheme = "http"
}
req.URL.Host = url.QueryEscape(r.URL.Path)
req.URL.Path = urlValues.Get("path")
v := req.URL.Query()
v.Del("tls")
v.Del("path")
req.URL.RawQuery = v.Encode()
return t.base.RoundTrip(req)
default:
}
}
return nil, fmt.Errorf("invalid request")
}

func NewRoundTripper() http.RoundTripper {
base := http.DefaultTransport.(*http.Transport).Clone()
dialer := &net.Dialer{}
t := &transport{
base: base,
dialer: dialer,
tlsDialer: &tls.Dialer{
NetDialer: dialer,
Config: base.TLSClientConfig,
},
}
t.base.DialContext = t.dialContext
t.base.DialTLSContext = t.dialTlsContext
return t
}
9 changes: 8 additions & 1 deletion pipeline/authn/authenticator_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,18 @@ type AuthenticatorBearerTokenConfiguration struct {

type AuthenticatorBearerToken struct {
c configuration.Provider
h *http.Client
}

func NewAuthenticatorBearerToken(c configuration.Provider) *AuthenticatorBearerToken {
return &AuthenticatorBearerToken{
c: c,
h: &http.Client{
Transport: helper.NewRoundTripper(),
CheckRedirect: http.DefaultClient.CheckRedirect,
Jar: http.DefaultClient.Jar,
Timeout: http.DefaultClient.Timeout,
},
}
}

Expand Down Expand Up @@ -85,7 +92,7 @@ func (a *AuthenticatorBearerToken) Authenticate(r *http.Request, session *Authen
return errors.WithStack(ErrAuthenticatorNotResponsible)
}

body, err := forwardRequestToSessionStore(r, cf.CheckSessionURL, cf.PreserveQuery, cf.PreservePath, cf.PreserveHost, cf.SetHeaders)
body, err := forwardRequestToSessionStore(r, a.h, cf.CheckSessionURL, cf.PreserveQuery, cf.PreservePath, cf.PreserveHost, cf.SetHeaders)
if err != nil {
return err
}
Expand Down
46 changes: 34 additions & 12 deletions pipeline/authn/authenticator_cookie_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,18 @@ type AuthenticatorCookieSessionConfiguration struct {

type AuthenticatorCookieSession struct {
c configuration.Provider
h *http.Client
}

func NewAuthenticatorCookieSession(c configuration.Provider) *AuthenticatorCookieSession {
return &AuthenticatorCookieSession{
c: c,
h: &http.Client{
Transport: helper.NewRoundTripper(),
CheckRedirect: http.DefaultClient.CheckRedirect,
Jar: http.DefaultClient.Jar,
Timeout: http.DefaultClient.Timeout,
},
}
}

Expand Down Expand Up @@ -88,7 +95,7 @@ func (a *AuthenticatorCookieSession) Authenticate(r *http.Request, session *Auth
return errors.WithStack(ErrAuthenticatorNotResponsible)
}

body, err := forwardRequestToSessionStore(r, cf.CheckSessionURL, cf.PreserveQuery, cf.PreservePath, cf.PreserveHost, cf.SetHeaders)
body, err := forwardRequestToSessionStore(r, a.h, cf.CheckSessionURL, cf.PreserveQuery, cf.PreservePath, cf.PreserveHost, cf.SetHeaders)
if err != nil {
return err
}
Expand Down Expand Up @@ -128,23 +135,14 @@ func cookieSessionResponsible(r *http.Request, only []string) bool {
return false
}

func forwardRequestToSessionStore(r *http.Request, checkSessionURL string, preserveQuery bool, preservePath bool, preserveHost bool, setHeaders map[string]string) (json.RawMessage, error) {
func forwardRequestToSessionStore(r *http.Request, httpClient *http.Client, checkSessionURL string, preserveQuery bool, preservePath bool, preserveHost bool, setHeaders map[string]string) (json.RawMessage, error) {
reqUrl, err := url.Parse(checkSessionURL)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to parse session check URL: %s", err))
}

if !preservePath {
reqUrl.Path = r.URL.Path
}

if !preserveQuery {
reqUrl.RawQuery = r.URL.RawQuery
}

req := http.Request{
Method: r.Method,
URL: reqUrl,
Header: http.Header{},
}

Expand All @@ -161,7 +159,31 @@ func forwardRequestToSessionStore(r *http.Request, checkSessionURL string, prese
req.Header.Set("X-Forwarded-Host", r.Host)
}

res, err := http.DefaultClient.Do(req.WithContext(r.Context()))
if reqUrl.Scheme == "unix" {
urlValues := reqUrl.Query()
if !preservePath {
urlValues.Set("path", r.URL.Path)
}

if !preserveQuery {
v := r.URL.Query()
v.Set("path", urlValues.Get("path"))
v.Set("tls", urlValues.Get("tls"))
urlValues = v
}
reqUrl.RawQuery = urlValues.Encode()
req.URL = reqUrl
} else {
if !preservePath {
reqUrl.Path = r.URL.Path
}

if !preserveQuery {
reqUrl.RawQuery = r.URL.RawQuery
}
req.URL = reqUrl
}
res, err := httpClient.Do(req.WithContext(r.Context()))
if err != nil {
return nil, helper.ErrForbidden.WithReason(err.Error()).WithTrace(err)
}
Expand Down
15 changes: 12 additions & 3 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"net/url"
"strings"

"github.com/ory/oathkeeper/helper"
"github.com/ory/oathkeeper/pipeline/authn"
"github.com/ory/oathkeeper/x"

Expand All @@ -44,11 +45,12 @@ type proxyRegistry interface {
}

func NewProxy(r proxyRegistry) *Proxy {
return &Proxy{r: r}
return &Proxy{r: r, t: helper.NewRoundTripper()}
}

type Proxy struct {
r proxyRegistry
t http.RoundTripper
}

type key int
Expand Down Expand Up @@ -88,7 +90,7 @@ func (d *Proxy) RoundTrip(r *http.Request) (*http.Response, error) {
Header: rw.header,
}, nil
} else if err == nil {
res, err := http.DefaultTransport.RoundTrip(r)
res, err := d.t.RoundTrip(r)
if err != nil {
d.r.Logger().
WithError(errors.WithStack(err)).
Expand Down Expand Up @@ -177,11 +179,18 @@ func ConfigureBackendURL(r *http.Request, rl *rule.Rule) error {
backendHost := p.Host
backendPath := p.Path
backendScheme := p.Scheme
backendQuery := p.Query()

forwardURL := r.URL
forwardURL.Scheme = backendScheme
forwardURL.Host = backendHost
forwardURL.Path = "/" + strings.TrimLeft("/"+strings.Trim(backendPath, "/")+"/"+strings.TrimLeft(proxyPath, "/"), "/")
if r.URL.Scheme == "unix" {
forwardURL.Path = "/" + strings.TrimLeft(backendPath, "/")
backendQuery.Set("path", "/"+strings.TrimLeft("/"+strings.Trim(backendQuery.Get("path"), "/")+"/"+strings.TrimLeft(proxyPath, "/"), "/"))
forwardURL.RawQuery = backendQuery.Encode()
} else {
forwardURL.Path = "/" + strings.TrimLeft("/"+strings.Trim(backendPath, "/")+"/"+strings.TrimLeft(proxyPath, "/"), "/")
}

if rl.Upstream.StripPath != "" {
forwardURL.Path = strings.Replace(forwardURL.Path, "/"+strings.Trim(rl.Upstream.StripPath, "/"), "", 1)
Expand Down
2 changes: 1 addition & 1 deletion rule/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func (v *ValidatorDefault) Validate(r *Rule) error {

if r.Upstream.URL == "" {
// Having no upstream URL is fine here because the judge does not need an upstream!
} else if !govalidator.IsURL(r.Upstream.URL) {
} else if !govalidator.IsRequestURI(r.Upstream.URL) {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`Value "%s" of "upstream.url" is not a valid url.`, r.Upstream.URL))
}

Expand Down