Skip to content

Commit

Permalink
Use shlex to parse command arguments (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall authored Apr 18, 2024
1 parent c6e314b commit 3a81181
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 29 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.22.0

require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/lestrrat-go/jwx v1.2.29
github.com/stretchr/testify v1.9.0
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A=
github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
Expand Down
26 changes: 20 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ import (
"net/http"
"os"
"os/exec"
"strings"

"github.com/golang-jwt/jwt/v5"
"github.com/google/shlex"
"github.com/lestrrat-go/jwx/jwk"
)

func init() {
// call getArgs early to fail on a bad config
getArgs()
}

func main() {
if os.Getenv("JWKS_URI") == "" {
log.Fatal("JWKS_URI is required. e.g. JWKS_URI=https://gitlab.com/oauth/discovery/keys")
Expand Down Expand Up @@ -84,11 +89,7 @@ func Rollout(w http.ResponseWriter, r *http.Request) {
if name == "" {
name = "/bin/bash"
}
args := os.Getenv("ROLLOUT_ARGS")
if args == "" {
args = "/rollout.sh"
}
cmd := exec.Command(name, strings.Split(args, " ")...)
cmd := exec.Command(name, getArgs()...)

var stdOut, stdErr bytes.Buffer
cmd.Stdout = &stdOut
Expand Down Expand Up @@ -162,3 +163,16 @@ func strInSlice(e string, s []string) bool {
}
return false
}

func getArgs() []string {
args := os.Getenv("ROLLOUT_ARGS")
if args == "" {
args = "/rollout.sh"
}
rolloutArgs, err := shlex.Split(args)
if err != nil {
log.Fatalf("Error parsing ROLLOUT_ARGS %s: %v", args, err)
}

return rolloutArgs
}
77 changes: 54 additions & 23 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ import (
"github.com/stretchr/testify/assert"
)

var (
kid, claim, aud string
privateKey *rsa.PrivateKey
)

// createJWKS creates a JWKS JSON representation with a single RSA key.
func mockJWKS(pub *rsa.PublicKey, kid string) (string, error) {
jwks := struct {
Expand Down Expand Up @@ -80,6 +85,23 @@ func setupMockJwksServer(pub *rsa.PublicKey, kid string) *httptest.Server {
return httptest.NewServer(handler)
}

func createMockJwksServer() *httptest.Server {
var publicKey *rsa.PublicKey
var err error

os.Setenv("JWT_AUD", "test-success")
kid = "no-kidding"
aud = os.Getenv("JWT_AUD")
claim = "bar"
privateKey, publicKey, err = GenerateRSAKeys()
if err != nil {
log.Fatalf("Unable to generate RSA keys: %v", err)
}
testServer := setupMockJwksServer(publicKey, kid)
os.Setenv("JWKS_URI", fmt.Sprintf("%s/oauth/discovery/keys", testServer.URL))
return testServer
}

func CreateSignedJWT(kid, aud, claim string, exp int64, privateKey *rsa.PrivateKey) (string, error) {
// Define the claims of the token. You can add more claims based on your needs.
claims := jwt.MapClaims{
Expand Down Expand Up @@ -114,6 +136,8 @@ func createRequest(authHeader string) *http.Request {
// TestRollout tests the Rollout function with various scenarios
func TestRollout(t *testing.T) {
testFile := "/tmp/rollout-test.txt"

// have our test rollout cmd just touch a file
os.Setenv("ROLLOUT_CMD", "touch")
os.Setenv("ROLLOUT_ARGS", testFile)

Expand All @@ -123,19 +147,8 @@ func TestRollout(t *testing.T) {
log.Fatalf("Unable to cleanup test file: %v", err)
}

// mock the JWKS server response
os.Setenv("JWT_AUD", "test-success")
kid := "no-kidding"
aud := os.Getenv("JWT_AUD")
claim := "bar"
privateKey, publicKey, err := GenerateRSAKeys()
if err != nil {
log.Fatalf("Unable to generate RSA keys: %v", err)
}
server := setupMockJwksServer(publicKey, kid)
defer server.Close()
jwkURL := fmt.Sprintf("%s/oauth/discovery/keys", server.URL)
os.Setenv("JWKS_URI", jwkURL)
s := createMockJwksServer()
defer s.Close()

// get a valid token
exp := time.Now().Add(time.Hour * 1).Unix()
Expand Down Expand Up @@ -179,13 +192,13 @@ func TestRollout(t *testing.T) {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}

// Define test cases
tests := []struct {
name string
authHeader string
expectedStatus int
expectedBody string
claim map[string]string
cmdArgs string
}{
{
name: "No Authorization Header",
Expand Down Expand Up @@ -235,6 +248,13 @@ func TestRollout(t *testing.T) {
expectedStatus: http.StatusOK,
expectedBody: "Rollout complete\n",
},
{
name: "Rollout cmd with quotes parsed correctly",
authHeader: "Bearer " + jwtToken,
expectedStatus: http.StatusOK,
cmdArgs: `/tmp/rollout-shlex-test /tmp/"rollout test filename wrapped in quotes"`,
expectedBody: "Rollout complete\n",
},
{
name: "Valid Token and Successful Command",
authHeader: "Bearer " + jwtToken,
Expand All @@ -250,25 +270,36 @@ func TestRollout(t *testing.T) {
os.Setenv("CUSTOM_CLAIMS", "")
} else {
os.Setenv("CUSTOM_CLAIMS", `{"foo": "bar"}`)

}
if tt.cmdArgs != "" {
os.Setenv("ROLLOUT_ARGS", tt.cmdArgs)
}

Rollout(recorder, request)

assert.Equal(t, tt.expectedStatus, recorder.Code)
assert.Equal(t, tt.expectedBody, recorder.Body.String())
})
}

// make sure the rollout command actually ran the command
_, err = os.Stat(testFile)
if err != nil && os.IsNotExist(err) {
t.Errorf("The successful test did not create the expected file")
testFiles := []string{
testFile,
"/tmp/rollout-shlex-test",
`/tmp/rollout test filename wrapped in quotes`,
}
for _, f := range testFiles {
// make sure the rollout command actually ran the command
// which creates the file
_, err = os.Stat(f)
if err != nil && os.IsNotExist(err) {
t.Errorf("The successful test did not create the expected file %s", f)
}

// cleanup
err = RemoveFileIfExists(testFile)
if err != nil {
log.Fatalf("Unable to cleanup test file: %v", err)
// cleanup
err = RemoveFileIfExists(f)
if err != nil {
log.Fatalf("Unable to cleanup test file: %v", err)
}
}
}

Expand Down

0 comments on commit 3a81181

Please sign in to comment.