Skip to content

Commit

Permalink
Add additional test coverage for helper functions in Util
Browse files Browse the repository at this point in the history
Currently helper functions in Util lacks test coverage

This commit fixes the following functions by addressing the review feedback.
- TestExpandFileArgs: Removed os.RemoveAll(testDir) line
- Changed name to setEnvVar and fixed all references
- TestGetGrpcConnection: use setEnvVar to set env variable
- TestGetJsonFromProto: Normalize json to handle empty space
- remved helper function createTestServer

Signed-off-by: Kugamoorthy Gajananan <[email protected]>
  • Loading branch information
gajananan committed Jan 27, 2025
1 parent 7a03106 commit af6a40c
Showing 1 changed file with 35 additions and 58 deletions.
93 changes: 35 additions & 58 deletions internal/util/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"slices"
Expand All @@ -35,16 +34,17 @@ var (
envLock = &sync.Mutex{}
)

func SetEnvVar(t *testing.T, env string, value string) {
func setEnvVar(t *testing.T, env string, value string) {
t.Helper() // Keep golangci-lint happy
envLock.Lock()
t.Cleanup(envLock.Unlock)

originalEnvToken := os.Getenv(env)
err := os.Setenv(env, value)
if err != nil {
t.Errorf("error setting %v: %v", env, err)
}

defer os.Setenv(env, originalEnvToken)
}

// TestGetConfigDirPath tests the GetConfigDirPath function
Expand Down Expand Up @@ -73,7 +73,7 @@ func TestGetConfigDirPath(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
SetEnvVar(t, "XDG_CONFIG_HOME", tt.envVar)
setEnvVar(t, "XDG_CONFIG_HOME", tt.envVar)
path, err := util.GetConfigDirPath()
if (err != nil) != tt.expectingError {
t.Errorf("expected error: %v, got: %v", tt.expectingError, err)
Expand Down Expand Up @@ -163,13 +163,8 @@ func TestGetGrpcConnection(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
originalEnvToken := os.Getenv(util.MinderAuthTokenEnvVar)
err := os.Setenv(util.MinderAuthTokenEnvVar, tt.envToken)
if err != nil {
t.Errorf("error setting %v: %v", util.MinderAuthTokenEnvVar, err)
}
// reset this the environment variable when complete.
defer os.Setenv(util.MinderAuthTokenEnvVar, originalEnvToken)

setEnvVar(t, util.MinderAuthTokenEnvVar, tt.envToken)

conn, err := util.GetGrpcConnection(tt.grpcHost, tt.grpcPort, tt.allowInsecure, tt.issuerUrl, tt.clientId)
if (err != nil) != tt.expectedError {
Expand Down Expand Up @@ -199,42 +194,31 @@ func TestSaveCredentials(t *testing.T) {
// Create a temporary directory
testDir := t.TempDir()

SetEnvVar(t, "XDG_CONFIG_HOME", testDir)
setEnvVar(t, "XDG_CONFIG_HOME", testDir)

cfgPath := filepath.Join(testDir, "minder")

expectedFilePath := filepath.Join(cfgPath, "credentials.json")

filePath, err := util.SaveCredentials(tokens)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
require.NoError(t, err)

if filePath != expectedFilePath {
t.Errorf("expected file path %v, got %v", expectedFilePath, filePath)
}

// Verify the file content
credsJSON, err := json.Marshal(tokens)
if err != nil {
t.Fatalf("error marshaling credentials: %v", err)
}
require.NoError(t, err)

fpath := filepath.Clean(filePath)
content, err := os.ReadFile(fpath)
if err != nil {
t.Fatalf("error reading file: %v", err)
}
cleanPath := filepath.Clean(filePath)
content, err := os.ReadFile(cleanPath)
require.NoError(t, err)

if string(content) != string(credsJSON) {
t.Errorf("expected file content %v, got %v", string(credsJSON), string(content))
}

// Clean up
err = os.Remove(filePath)
if err != nil {
t.Fatalf("error removing file: %v", err)
}
}

// TestRemoveCredentials tests the RemoveCredentials function
Expand All @@ -244,7 +228,7 @@ func TestRemoveCredentials(t *testing.T) {
// Create a temporary directory
testDir := t.TempDir()

SetEnvVar(t, "XDG_CONFIG_HOME", testDir)
setEnvVar(t, "XDG_CONFIG_HOME", testDir)
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")

filePath := filepath.Join(xdgConfigHome, "minder", "credentials.json")
Expand Down Expand Up @@ -278,7 +262,7 @@ func TestRefreshCredentials(t *testing.T) {
// Create a temporary directory
testDir := t.TempDir()

SetEnvVar(t, "XDG_CONFIG_HOME", testDir)
setEnvVar(t, "XDG_CONFIG_HOME", testDir)
tests := []struct {
name string
refreshToken string
Expand All @@ -291,7 +275,6 @@ func TestRefreshCredentials(t *testing.T) {
{
name: "Successful refresh with local server",
refreshToken: "valid_refresh_token",
issuerUrl: "http://localhost:8081",
clientId: "minder-cli",
responseBody: `{"access_token":"new_access_token","refresh_token":"new_refresh_token","expires_in":3600}`,
expectedResult: util.OpenIdCredentials{
Expand All @@ -303,22 +286,19 @@ func TestRefreshCredentials(t *testing.T) {
{
name: "Error fetching new credentials (responseBody is missing) rwith local server",
refreshToken: "valid_refresh_token",
issuerUrl: "http://localhost:8081",
clientId: "minder-cli",
expectedError: "error unmarshaling credentials: EOF",
},
{
name: "Error unmarshaling credentials with local server",
refreshToken: "valid_refresh_token",
issuerUrl: "http://localhost:8081",
clientId: "minder-cli",
responseBody: `invalid_json`,
expectedError: "error unmarshaling credentials: invalid character 'i' looking for beginning of value",
},
{
name: "Error refreshing credentials with local server",
refreshToken: "valid_refresh_token",
issuerUrl: "http://localhost:8081",
clientId: "minder-cli",
responseBody: `{"error":"invalid_grant","error_description":"Invalid refresh token"}`,
expectedError: "error refreshing credentials: invalid_grant: Invalid refresh token",
Expand All @@ -335,8 +315,7 @@ func TestRefreshCredentials(t *testing.T) {
}))
defer server.Close()

parsedURL, _ := url.Parse(server.URL)
tt.issuerUrl = parsedURL.String()
tt.issuerUrl = server.URL

result, err := util.RefreshCredentials(tt.refreshToken, tt.issuerUrl, tt.clientId)
if tt.expectedError != "" {
Expand Down Expand Up @@ -391,7 +370,7 @@ func TestLoadCredentials(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
testDir := t.TempDir()
SetEnvVar(t, "XDG_CONFIG_HOME", testDir)
setEnvVar(t, "XDG_CONFIG_HOME", testDir)
// Create the minder directory inside the temp directory
minderDir := filepath.Join(testDir, "minder")
err := os.MkdirAll(minderDir, 0750)
Expand All @@ -403,10 +382,7 @@ func TestLoadCredentials(t *testing.T) {

if tt.fileContent != "" {
// Create a temporary file with the specified content
err := os.WriteFile(filePath, []byte(tt.fileContent), 0600)
if err != nil {
t.Fatalf("failed to write test file: %v", err)
}
require.NoError(t, os.WriteFile(filePath, []byte(tt.fileContent), 0600))
// Print the file path and content for debugging
t.Logf("Test %s: written file path %s with content: %s", tt.name, filePath, tt.fileContent)
} else {
Expand All @@ -433,6 +409,7 @@ func TestLoadCredentials(t *testing.T) {
}
}

// TestCase struct for holding test case data
type TestCase struct {
name string
token string
Expand All @@ -444,18 +421,6 @@ type TestCase struct {
createServer func(t *testing.T, tt TestCase) *httptest.Server
}

func createTestServer(t *testing.T, tt TestCase) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
require.NoError(t, err, "error parsing form")

require.Equal(t, tt.clientId, r.Form.Get("client_id"))
require.Equal(t, tt.token, r.Form.Get("token"))
require.Equal(t, tt.tokenHint, r.Form.Get("token_type_hint"))
}))
}

// TestRevokeToken tests the RevokeToken function
func TestRevokeToken(t *testing.T) {
t.Parallel()
Expand All @@ -468,7 +433,16 @@ func TestRevokeToken(t *testing.T) {
tokenHint: "refresh_token",
expectedPath: "/realms/stacklok/protocol/openid-connect/revoke",
expectError: false,
createServer: createTestServer,
createServer: func(t *testing.T, tt TestCase) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
require.NoError(t, err, "error parsing form")
require.Equal(t, tt.clientId, r.Form.Get("client_id"))
require.Equal(t, tt.token, r.Form.Get("token"))
require.Equal(t, tt.tokenHint, r.Form.Get("token_type_hint"))
}))
},
},
{
name: "Invalid issuer URL",
Expand Down Expand Up @@ -499,7 +473,6 @@ func TestRevokeToken(t *testing.T) {
}
})
}

}

// TestGetJsonFromProto tests the GetJsonFromProto function
Expand Down Expand Up @@ -540,8 +513,13 @@ func TestGetJsonFromProto(t *testing.T) {
t.Errorf("GetJsonFromProto() error = %v, expectedError %v", err, tt.expectedError)
return
}
if strings.TrimSpace(jsonStr) != strings.TrimSpace(tt.expectedJson) {
t.Errorf("GetJsonFromProto() = %v, expected %v", jsonStr, tt.expectedJson)

// Normalize JSON strings by removing all whitespaces and new lines
normalizedResult := strings.Join(strings.Fields(jsonStr), "")
normalizedExpected := strings.Join(strings.Fields(tt.expectedJson), "")

if normalizedResult != normalizedExpected {
t.Errorf("GetJsonFromProto() = %v, expected %v", normalizedResult, normalizedExpected)
}
})
}
Expand Down Expand Up @@ -724,7 +702,6 @@ func TestExpandFileArgs(t *testing.T) {
}) {
t.Errorf("ExpandFileArgs() = %v, want %v", got, tt.expected)
}
defer os.RemoveAll(testDir)
})
}
}
Expand Down

0 comments on commit af6a40c

Please sign in to comment.