Skip to content

Commit

Permalink
[minor] Allow setting custom claims in JWT auth (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall authored Apr 15, 2024
1 parent 832deb5 commit c6e314b
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 12 deletions.
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# rollout

Trigger a deployment of your application from a CI/CD pipeline to an instance of your application running on a VM.
Deploy your application from a CI/CD pipeline via `cURL` + JWT auth.

```
$ curl -s -H "Authorization: bearer abc..." https://example.com/your/rollout/path
Expand All @@ -9,9 +9,9 @@ Rollout complete

## Purpose

Instead of managing SSH keys in your CI/CD that has access to your production environment to run deployment scripts, this serivce can be running in your production environment to handle deploying code changes.
Instead of managing SSH keys in your CI/CD for accounts that have privileged access to perform deployments in your production environment, this service can handle deploying code changes.

Requires creating a JWT from your CI provider, and sending that token to this service running in your deployment environment to trigger the deployment script.
Requires creating a JWT from your CI provider, and sending that token to this service running in your deployment environment to trigger a deployment script.

Also requires a `rollout.sh` script that can handle all the command needing ran to rollout your software.

Expand All @@ -33,6 +33,10 @@ This service requires two envionrment variables.

- `JWKS_URI` - the URL of the OIDC Provider's [JSON Web Key (JWK) set document](https://www.rfc-editor.org/info/rfc7517). This is used to ensure the JWT was signed by the provider.
- `JWT_AUD` - the audience set in the JWT token.
- `CUSTOM_CLAIMS` - (optional) JSON of key/value pairs to validate in the JWT e.g.
```
{"foo": "bar", "foo2": "bar2"}
```
- `ROLLOUT_CMD` (default: `/bin/bash`) - the command to execute a rollout
- `ROLLOUT_ARGS` (default: `/rollout.sh` ) - the args to pass to `ROLLOUT_CMD`

Expand All @@ -56,4 +60,3 @@ JWT_AUD=aud-string-you-set-in-your-job
- [ ] Add a full example for GitHub
- [ ] Install instructions using binary
- [ ] Tag/push versions to dockerhub
- [ ] Allow more custom auth handling
26 changes: 25 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"net/http"
Expand Down Expand Up @@ -51,7 +52,30 @@ func Rollout(w http.ResponseWriter, r *http.Request) {
}

if !token.Valid {
log.Println("Invalid token for", realIp, ",", lastIP, err.Error())
log.Println("Invalid token for", realIp, ",", lastIP)
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}

ccStr := os.Getenv("CUSTOM_CLAIMS")
claims, ok := token.Claims.(jwt.MapClaims)
if ok && ccStr != "" {
var cc map[string]string
err = json.Unmarshal([]byte(ccStr), &cc)
if err != nil {
log.Println("Unable to read token claims for", realIp, ",", lastIP)
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
for k, v := range cc {
if claims[k] != v {
log.Println("Claim for", k, "doesn't match", realIp, ",", lastIP)
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
}
} else if !ok {
log.Println("Unable to read token claims for", realIp, ",", lastIP)
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}
Expand Down
39 changes: 32 additions & 7 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,14 @@ func setupMockJwksServer(pub *rsa.PublicKey, kid string) *httptest.Server {
return httptest.NewServer(handler)
}

func CreateSignedJWT(kid, aud string, exp int64, privateKey *rsa.PrivateKey) (string, error) {
func CreateSignedJWT(kid, aud, claim string, exp int64, privateKey *rsa.PrivateKey) (string, error) {
// Define the claims of the token. You can add more claims based on your needs.
claims := jwt.MapClaims{
"sub": "1234567890",
"aud": aud,
"iat": time.Now().Unix(),
"exp": exp,
"foo": claim,
}

// Create a new token object with the claims and the signing method
Expand Down Expand Up @@ -126,6 +127,7 @@ func TestRollout(t *testing.T) {
os.Setenv("JWT_AUD", "test-success")
kid := "no-kidding"
aud := os.Getenv("JWT_AUD")
claim := "bar"
privateKey, publicKey, err := GenerateRSAKeys()
if err != nil {
log.Fatalf("Unable to generate RSA keys: %v", err)
Expand All @@ -137,13 +139,13 @@ func TestRollout(t *testing.T) {

// get a valid token
exp := time.Now().Add(time.Hour * 1).Unix()
jwtToken, err := CreateSignedJWT(kid, aud, exp, privateKey)
jwtToken, err := CreateSignedJWT(kid, aud, claim, exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}

// make sure invalid kids fail
badKidJwtToken, err := CreateSignedJWT("just-kidding", aud, exp, privateKey)
badKidJwtToken, err := CreateSignedJWT("just-kidding", aud, claim, exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}
Expand All @@ -153,20 +155,26 @@ func TestRollout(t *testing.T) {
if err != nil {
t.Fatalf("Unable to generate a new private key")
}
badPrivKeyjwtToken, err := CreateSignedJWT(kid, aud, exp, badPrivateKey)
badPrivKeyjwtToken, err := CreateSignedJWT(kid, aud, claim, exp, badPrivateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our new test key: %v", err)
}

// make sure expired JWTs fail
expired := time.Now().Add(time.Hour * -1).Unix()
expiredJwtToken, err := CreateSignedJWT(kid, aud, expired, privateKey)
expiredJwtToken, err := CreateSignedJWT(kid, aud, claim, expired, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}

// make sure bad audience JWTs fail
badAudJwtToken, err := CreateSignedJWT(kid, "different-audience", exp, privateKey)
badAudJwtToken, err := CreateSignedJWT(kid, "different-audience", claim, exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}

// make sure JWTs with a bad custom claim fail
badClaimJwtToken, err := CreateSignedJWT(kid, aud, "bad-claim", exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}
Expand All @@ -177,6 +185,7 @@ func TestRollout(t *testing.T) {
authHeader string
expectedStatus int
expectedBody string
claim map[string]string
}{
{
name: "No Authorization Header",
Expand Down Expand Up @@ -214,19 +223,35 @@ func TestRollout(t *testing.T) {
expectedStatus: http.StatusUnauthorized,
expectedBody: "Failed to verify token.\n",
},
{
name: "Bad custom claim",
authHeader: "Bearer " + badClaimJwtToken,
expectedStatus: http.StatusUnauthorized,
expectedBody: "Invalid token\n",
},
{
name: "No custom claim",
authHeader: "Bearer " + jwtToken,
expectedStatus: http.StatusOK,
expectedBody: "Rollout complete\n",
},
{
name: "Valid Token and Successful Command",
authHeader: "Bearer " + jwtToken,
expectedStatus: http.StatusOK,
expectedBody: "Rollout complete\n",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
request := createRequest(tt.authHeader)
if tt.name == "No custom claim" {
os.Setenv("CUSTOM_CLAIMS", "")
} else {
os.Setenv("CUSTOM_CLAIMS", `{"foo": "bar"}`)

}
Rollout(recorder, request)

assert.Equal(t, tt.expectedStatus, recorder.Code)
Expand Down

0 comments on commit c6e314b

Please sign in to comment.