Skip to content

Commit

Permalink
Refactor and add tests around JWT auth
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall committed Apr 15, 2024
1 parent a285c2c commit 1c82000
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 97 deletions.
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ go 1.22.0
require (
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/lestrrat-go/jwx v1.2.29
github.com/stretchr/testify v1.9.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect
Expand All @@ -16,5 +18,7 @@ require (
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.22.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
191 changes: 94 additions & 97 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,111 +14,23 @@ import (
)

var (
gitLabJwksURL = "https://%s/oauth/discovery/keys"
aud string
gitLabJwksURL, aud string
)

func init() {
domain := os.Getenv("GITLAB_DOMAIN")
if domain == "" {
log.Fatal("GITLAB_DOMAIN is required. You could use GITLAB_DOMAIN=gitlab.com")
}
gitLabJwksURL = fmt.Sprintf(gitLabJwksURL, domain)
func main() {

gitLabJwksURL = os.Getenv("JWKS_URI")
if gitLabJwksURL == "" {
log.Fatal("JWKS_URI is required. e.g. JWKS_URI=https://gitlab.com/oauth/discovery/keys")
}
aud = os.Getenv("JWT_AUD")
if aud == "" {
log.Fatal("JWT_AUD is required. This needs to be the aud in the JWT you except this service to handle.")
}

}

func main() {
ctx := context.Background()

// Fetch the JWKS from GitLab
set, err := jwk.Fetch(ctx, gitLabJwksURL)
if err != nil {
fmt.Printf("Failed to fetch JWKS: %v\n", err)
return
}

http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
realIp, lastIP := readUserIP(r)

a := r.Header.Get("Authorization")
if len(a) < 10 {
log.Println("Not auth header for", realIp, ",", lastIP)
http.Error(w, "need authorizaton: bearer xyz header", http.StatusUnauthorized)
return
}
// Assuming "Bearer " prefix
tokenString := a[7:]

// Parse and verify the token
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

// Check audience claim
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("error retrieving claims from token")
}
if !claims.VerifyAudience(aud, true) {
return nil, fmt.Errorf("invalid audience. Expected: %s", aud)
}

kid, ok := token.Header["kid"].(string)
if !ok {
return nil, fmt.Errorf("expecting JWT header to have string 'kid'")
}

// Find the appropriate key in JWKS
key, ok := set.LookupKeyID(kid)
if !ok {
return nil, fmt.Errorf("unable to find key '%s'", kid)
}

var pubkey interface{}
if err := key.Raw(&pubkey); err != nil {
return nil, fmt.Errorf("failed to get raw key: %v", err)
}

return pubkey, nil
})

if err != nil {
log.Println("Failed to verify token for", realIp, ",", lastIP, err.Error())
http.Error(w, "Failed to verify token.", http.StatusUnauthorized)
return
}

if !token.Valid {
log.Println("Invalid token for", realIp, ",", lastIP, err.Error())
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}

// TODO make this more customizable
// but for now this fills the need
cmd := exec.Command("/bin/bash", "/rollout.sh")

var stdErr bytes.Buffer
cmd.Stderr = &stdErr
cmd.Env = os.Environ()
if err := cmd.Run(); err != nil {
log.Printf("Error running %s command: %s", cmd.String(), stdErr.String())
http.Error(w, "Script execution failed", http.StatusInternalServerError)
return
}

log.Println("Rollout complete for", realIp, ",", lastIP)
fmt.Fprintln(w, "Rollout complete")
})

fmt.Println("Server is running on http://localhost:8080/")
err = http.ListenAndServe(":8080", nil)
http.HandleFunc("/", Rollout)
log.Println("Server is running on :8080")
err := http.ListenAndServe(":8080", nil)
if err != nil {
log.Fatal("Unable to start service")
}
Expand All @@ -132,3 +44,88 @@ func readUserIP(r *http.Request) (string, string) {
}
return realIP, lastIP
}

func Rollout(w http.ResponseWriter, r *http.Request) {
realIp, lastIP := readUserIP(r)

a := r.Header.Get("Authorization")
if len(a) < 10 {
log.Println("Not auth header for", realIp, ",", lastIP)
http.Error(w, "need authorizaton: bearer xyz header", http.StatusUnauthorized)
return
}
// Assuming "Bearer " prefix
tokenString := a[7:]

// Parse and verify the token
token, err := jwt.Parse(tokenString, ParseToken)
if err != nil {
log.Println("Failed to verify token for", realIp, ",", lastIP, err.Error())
http.Error(w, "Failed to verify token.", http.StatusUnauthorized)
return
}

if !token.Valid {
log.Println("Invalid token for", realIp, ",", lastIP, err.Error())
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}

// TODO make this more customizable
// but for now this fills the need
cmd := exec.Command("/bin/bash", "/rollout.sh")

var stdOut, stdErr bytes.Buffer
cmd.Stdout = &stdOut
cmd.Stderr = &stdErr
cmd.Env = os.Environ()
if err := cmd.Run(); err != nil {
log.Printf("Error running %s command: %s", cmd.String(), stdOut.String())
log.Printf("stderr: %s", stdErr.String())
http.Error(w, "Script execution failed", http.StatusInternalServerError)
return
}

log.Println("Rollout complete for", realIp, ",", lastIP)
fmt.Fprintln(w, "Rollout complete")
}

func ParseToken(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

// Check audience claim
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("error retrieving claims from token")
}
aud := os.Getenv("JWT_AUD")
if !claims.VerifyAudience(aud, true) {
return nil, fmt.Errorf("invalid audience. Expected: %s", aud)
}

kid, ok := token.Header["kid"].(string)
if !ok {
return nil, fmt.Errorf("expecting JWT header to have string 'kid'")
}

ctx := context.Background()
gitLabJwksURL = os.Getenv("JWKS_URI")
jwksSet, err := jwk.Fetch(ctx, gitLabJwksURL)
if err != nil {
log.Fatalf("Unable to fetch JWK set from %s: %v", gitLabJwksURL, err)
}
// Find the appropriate key in JWKS
key, ok := jwksSet.LookupKeyID(kid)
if !ok {
return nil, fmt.Errorf("unable to find key '%s'", kid)
}

var pubkey interface{}
if err := key.Raw(&pubkey); err != nil {
return nil, fmt.Errorf("failed to get raw key: %v", err)
}

return pubkey, nil
}
Loading

0 comments on commit 1c82000

Please sign in to comment.