diff --git a/README.md b/README.md index cd51fdd..59a55fe 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ These are the environment variables currently supported, keyed by their respecti |----------------|---------------| ------------------------------------- | `docker-image` | `DOCKER_IMAGE`| `{"docker-image": "foo/bar:latest"}` | | `docker-tag` | `DOCKER_TAG` | `{"docker-tag": "latest"}` | -| `git-repo` | `GIT_REPO` | `{"repo": "foo/bar:latest"}` | +| `git-repo` | `GIT_REPO` | `{"git-repo": "foo/bar:latest"}` | | `git-branch` | `GIT_BRANCH` | `{"git-branch": "main"}` | | `rollout-arg1` | `ROLLOUT_ARG1`| `{"rollout-arg1": "FOO"}` | | `rollout-arg2` | `ROLLOUT_ARG2`| `{"rollout-arg2": "BAR"}` | diff --git a/main.go b/main.go index 18e220a..48947b4 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "reflect" + "regexp" "github.com/golang-jwt/jwt/v5" "github.com/google/shlex" @@ -101,8 +102,8 @@ func Rollout(w http.ResponseWriter, r *http.Request) { err = setCustomArgs(r) if err != nil { - slog.Error("Error setting custom logs", "err", err) - http.Error(w, "Script execution failed", http.StatusInternalServerError) + slog.Error("Error setting custom args", "err", err) + http.Error(w, "Bad request", http.StatusBadRequest) return } @@ -219,6 +220,11 @@ func setCustomArgs(r *http.Request) error { } func setEnvFromStruct(data interface{}) error { + regex, err := regexp.Compile(`^[a-zA-Z0-9._\-:\/@]+$`) + if err != nil { + return fmt.Errorf("failed to compile regex: %v", err) + } + v := reflect.ValueOf(data) if v.Kind() == reflect.Ptr { v = v.Elem() @@ -233,6 +239,9 @@ func setEnvFromStruct(data interface{}) error { if value == "" { continue } + if !regex.MatchString(value) { + return fmt.Errorf("invalid input for environment variable %s:%s", envTag, value) + } if err := os.Setenv(envTag, value); err != nil { return fmt.Errorf("could not set environment variable %s: %v", envTag, err) } diff --git a/main_test.go b/main_test.go index f319ea3..e204716 100644 --- a/main_test.go +++ b/main_test.go @@ -325,10 +325,10 @@ func TestRolloutCmdArgs(t *testing.T) { } payloads := map[string]string{ - "docker-image": "rollout-docker-image-test", + "docker-image": "us-docker.pkg.dev-project-interal-image:latest", "docker-tag": "rollout-docker-tag-test", "git-branch": "rollout-git-branch-test", - "git-repo": "rollout-git-repo-test", + "git-repo": "git@github.com:lehigh-university-libraries-rollout.git", "rollout-arg1": "rollout-arg1-test", "rollout-arg2": "rollout-arg2-test", "rollout-arg3": "rollout-arg3-test", @@ -392,6 +392,93 @@ func TestRolloutCmdArgs(t *testing.T) { } } +func TestBadRolloutCmdArgs(t *testing.T) { + os.Setenv("ROLLOUT_CMD", "/bin/bash") + s := createMockJwksServer() + defer s.Close() + + // get a valid token + exp := time.Now().Add(time.Hour * 1).Unix() + jwtToken, err := CreateSignedJWT(kid, aud, claim, exp, privateKey) + if err != nil { + t.Fatalf("Unable to create a JWT with our test key: %v", err) + } + + payloads := map[string]string{ + "rollout-arg1": "bad1;", + "rollout-arg2": "bad2&", + "rollout-arg3": "bad3|", + "bad4": "bad4$", + "bad5": "any`thing", + "bad6": `any"thing`, + "bad7": "any\thing", + "bad8": "any*thing", + "bad9": "any?thing", + "bad10": "any[thing", + "bad11": "any]thing", + "bad12": "any{thing", + "bad13": "any}thing", + "bad14": "any(thing", + "bad15": "any)thing", + "bad16": "anything", + "bad18": "anything!", + } + for k, v := range payloads { + var e string + switch k { + case "rollout-arg1": + e = "ROLLOUT_ARG1" + case "rollout-arg2": + e = "ROLLOUT_ARG2" + case "rollout-arg3": + e = "ROLLOUT_ARG3" + default: + k = "rollout-arg1" + e = "ROLLOUT_ARG1" + } + tt := Test{ + name: fmt.Sprintf("%s custom arg doesn't pass to rollout.sh", k), + authHeader: "Bearer " + jwtToken, + expectedStatus: http.StatusBadRequest, + cmdArgs: fmt.Sprintf(`-c "touch /tmp/$%s"`, e), + method: "POST", + payload: fmt.Sprintf(`{"%s": "%s"}`, k, v), + expectedBody: "Bad request\n", + } + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + method := "POST" + body := strings.NewReader(tt.payload) + request := createRequest(tt.authHeader, method, body) + os.Setenv("ROLLOUT_ARGS", tt.cmdArgs) + + Rollout(recorder, request) + + assert.Equal(t, tt.expectedStatus, recorder.Code) + assert.Equal(t, tt.expectedBody, recorder.Body.String()) + }) + } + + for _, v := range payloads { + f := "/tmp/" + v + // make sure the rollout command didn't run the command + // which creates the file + _, err = os.Stat(f) + if err != nil && os.IsNotExist(err) { + continue + } + t.Errorf("The test created a bad file name. Check sanitizing inputs to catch %s", f) + + // cleanup + err = RemoveFileIfExists(f) + if err != nil { + slog.Error("Unable to cleanup test file", "file", f, "err", err) + os.Exit(1) + } + } +} + func RemoveFileIfExists(filePath string) error { _, err := os.Stat(filePath) if err == nil {