Skip to content

Commit

Permalink
Allow passing other JWT claims
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall committed Apr 15, 2024
1 parent 832deb5 commit 8fe8c66
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 9 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ This service requires two envionrment variables.

- `JWKS_URI` - the URL of the OIDC Provider's [JSON Web Key (JWK) set document](https://www.rfc-editor.org/info/rfc7517). This is used to ensure the JWT was signed by the provider.
- `JWT_AUD` - the audience set in the JWT token.
- `CUSTOM_CLAIMS` - (optional) JSON of key/value pairs to validate in the JWT e.g.
```
{"foo": "bar", "foo2": "bar2"}
```
- `ROLLOUT_CMD` (default: `/bin/bash`) - the command to execute a rollout
- `ROLLOUT_ARGS` (default: `/rollout.sh` ) - the args to pass to `ROLLOUT_CMD`

Expand All @@ -56,4 +60,3 @@ JWT_AUD=aud-string-you-set-in-your-job
- [ ] Add a full example for GitHub
- [ ] Install instructions using binary
- [ ] Tag/push versions to dockerhub
- [ ] Allow more custom auth handling
21 changes: 20 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"net/http"
Expand Down Expand Up @@ -51,7 +52,25 @@ func Rollout(w http.ResponseWriter, r *http.Request) {
}

if !token.Valid {
log.Println("Invalid token for", realIp, ",", lastIP, err.Error())
log.Println("Invalid token for", realIp, ",", lastIP)
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}

if claims, ok := token.Claims.(jwt.MapClaims); ok {
ccStr := os.Getenv("CUSTOM_CLAIMS")
log.Println(ccStr)
var cc map[string]string
json.Unmarshal([]byte(ccStr), &cc)
for k, v := range cc {
if claims[k] != v {
log.Println("Claim for", k, "doesn't match", realIp, ",", lastIP)
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
}
} else {
log.Println("Unable to read token claims for", realIp, ",", lastIP)
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
Expand Down
39 changes: 32 additions & 7 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,14 @@ func setupMockJwksServer(pub *rsa.PublicKey, kid string) *httptest.Server {
return httptest.NewServer(handler)
}

func CreateSignedJWT(kid, aud string, exp int64, privateKey *rsa.PrivateKey) (string, error) {
func CreateSignedJWT(kid, aud, claim string, exp int64, privateKey *rsa.PrivateKey) (string, error) {
// Define the claims of the token. You can add more claims based on your needs.
claims := jwt.MapClaims{
"sub": "1234567890",
"aud": aud,
"iat": time.Now().Unix(),
"exp": exp,
"foo": claim,
}

// Create a new token object with the claims and the signing method
Expand Down Expand Up @@ -126,6 +127,7 @@ func TestRollout(t *testing.T) {
os.Setenv("JWT_AUD", "test-success")
kid := "no-kidding"
aud := os.Getenv("JWT_AUD")
claim := "bar"
privateKey, publicKey, err := GenerateRSAKeys()
if err != nil {
log.Fatalf("Unable to generate RSA keys: %v", err)
Expand All @@ -137,13 +139,13 @@ func TestRollout(t *testing.T) {

// get a valid token
exp := time.Now().Add(time.Hour * 1).Unix()
jwtToken, err := CreateSignedJWT(kid, aud, exp, privateKey)
jwtToken, err := CreateSignedJWT(kid, aud, claim, exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}

// make sure invalid kids fail
badKidJwtToken, err := CreateSignedJWT("just-kidding", aud, exp, privateKey)
badKidJwtToken, err := CreateSignedJWT("just-kidding", aud, claim, exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}
Expand All @@ -153,20 +155,26 @@ func TestRollout(t *testing.T) {
if err != nil {
t.Fatalf("Unable to generate a new private key")
}
badPrivKeyjwtToken, err := CreateSignedJWT(kid, aud, exp, badPrivateKey)
badPrivKeyjwtToken, err := CreateSignedJWT(kid, aud, claim, exp, badPrivateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our new test key: %v", err)
}

// make sure expired JWTs fail
expired := time.Now().Add(time.Hour * -1).Unix()
expiredJwtToken, err := CreateSignedJWT(kid, aud, expired, privateKey)
expiredJwtToken, err := CreateSignedJWT(kid, aud, claim, expired, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}

// make sure bad audience JWTs fail
badAudJwtToken, err := CreateSignedJWT(kid, "different-audience", exp, privateKey)
badAudJwtToken, err := CreateSignedJWT(kid, "different-audience", claim, exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}

// make sure JWTs with a bad custom claim fail
badClaimJwtToken, err := CreateSignedJWT(kid, aud, "bad-claim", exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}
Expand All @@ -177,6 +185,7 @@ func TestRollout(t *testing.T) {
authHeader string
expectedStatus int
expectedBody string
claim map[string]string
}{
{
name: "No Authorization Header",
Expand Down Expand Up @@ -214,19 +223,35 @@ func TestRollout(t *testing.T) {
expectedStatus: http.StatusUnauthorized,
expectedBody: "Failed to verify token.\n",
},
{
name: "Bad custom claim",
authHeader: "Bearer " + badClaimJwtToken,
expectedStatus: http.StatusUnauthorized,
expectedBody: "Invalid token\n",
},
{
name: "No custom claim",
authHeader: "Bearer " + jwtToken,
expectedStatus: http.StatusOK,
expectedBody: "Rollout complete\n",
},
{
name: "Valid Token and Successful Command",
authHeader: "Bearer " + jwtToken,
expectedStatus: http.StatusOK,
expectedBody: "Rollout complete\n",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
request := createRequest(tt.authHeader)
if tt.name == "No custom claim" {
os.Setenv("CUSTOM_CLAIMS", "")
} else {
os.Setenv("CUSTOM_CLAIMS", `{"foo": "bar"}`)

}
Rollout(recorder, request)

assert.Equal(t, tt.expectedStatus, recorder.Code)
Expand Down

0 comments on commit 8fe8c66

Please sign in to comment.