diff --git a/ably/auth_integration_test.go b/ably/auth_integration_test.go index b19ad264..821b1f4a 100644 --- a/ably/auth_integration_test.go +++ b/ably/auth_integration_test.go @@ -349,7 +349,7 @@ func TestAuth_JWT_Token_RSA8c(t *testing.T) { t.Run("Get JWT from echo server", func(t *testing.T) { app := ablytest.MustSandbox(nil) defer safeclose(t, app) - jwt, err := app.CreateJwt(3 * time.Second) + jwt, err := app.CreateJwt(3*time.Second, false) assert.NoError(t, err) assert.True(t, strings.HasPrefix(jwt, "ey")) }) @@ -361,7 +361,7 @@ func TestAuth_JWT_Token_RSA8c(t *testing.T) { rec, optn := ablytest.NewHttpRecorder() rest, err := ably.NewREST( ably.WithAuthURL(ablytest.CREATE_JWT_URL), - ably.WithAuthParams(app.GetJwtAuthParams(30*time.Second)), + ably.WithAuthParams(app.GetJwtAuthParams(30*time.Second, false)), ably.WithEnvironment(app.Environment), ably.WithKey(""), optn[0], @@ -396,7 +396,7 @@ func TestAuth_JWT_Token_RSA8c(t *testing.T) { jwtToken := "" authCallback := ably.WithAuthCallback(func(ctx context.Context, tp ably.TokenParams) (ably.Tokener, error) { - jwtTokenString, err := app.CreateJwt(time.Second * 30) + jwtTokenString, err := app.CreateJwt(time.Second*30, false) jwtToken = jwtTokenString if err != nil { return nil, err @@ -428,8 +428,30 @@ func TestAuth_JWT_Token_RSA8c(t *testing.T) { assert.Equal(t, "Bearer "+encodedToken, statsRequest.Header.Get("Authorization")) }) - t.Run("Should return error when JWT is invalid", func(t *testing.T) { + t.Run("RSA4e, RSA4b: Should return error when JWT is invalid", func(t *testing.T) { + app := ablytest.MustSandbox(nil) + defer safeclose(t, app) + rec, optn := ablytest.NewHttpRecorder() + rest, err := ably.NewREST( + ably.WithAuthURL(ablytest.CREATE_JWT_URL), + ably.WithAuthParams(app.GetJwtAuthParams(30*time.Second, true)), + ably.WithEnvironment(app.Environment), + ably.WithKey(""), + optn[0], + ) + + assert.NoError(t, err, "rest()=%v", err) + _, err = rest.Stats().Pages(context.Background()) + var errorInfo *ably.ErrorInfo + assert.Error(t, err, "Stats()=%v", err) + assert.ErrorAs(t, err, &errorInfo) + assert.Equal(t, 40144, int(errorInfo.Code)) + assert.Equal(t, 401, errorInfo.StatusCode) + assert.Contains(t, err.Error(), "invalid JWT format") + + assert.Len(t, rec.Requests(), 2) + assert.Len(t, rec.Responses(), 2) }) } diff --git a/ably/realtime_conn_spec_integration_test.go b/ably/realtime_conn_spec_integration_test.go index 838de2d9..eb4be2fa 100644 --- a/ably/realtime_conn_spec_integration_test.go +++ b/ably/realtime_conn_spec_integration_test.go @@ -3031,7 +3031,7 @@ func TestRealtimeConn_RTC8a_ExplicitAuthorizeWhileConnected(t *testing.T) { tokenExpiry := 3 * time.Second // Returns token that expires after 3 seconds causing disconnect every 3 seconds authCallback := func(ctx context.Context, tp ably.TokenParams) (ably.Tokener, error) { - jwtTokenString, err := app.CreateJwt(tokenExpiry) + jwtTokenString, err := app.CreateJwt(tokenExpiry, false) if err != nil { return nil, err } diff --git a/ablytest/sandbox.go b/ablytest/sandbox.go index 36596351..b96e3100 100644 --- a/ablytest/sandbox.go +++ b/ablytest/sandbox.go @@ -260,24 +260,28 @@ func (app *Sandbox) URL(paths ...string) string { var CREATE_JWT_URL string = "https://echo.ably.io/createJWT" // Returns authParams, required for authUrl as a mode of auth -func (app *Sandbox) GetJwtAuthParams(expiresIn time.Duration) url.Values { +func (app *Sandbox) GetJwtAuthParams(expiresIn time.Duration, invalid bool) url.Values { key, secret := app.KeyParts() authParams := url.Values{} authParams.Add("environment", app.Environment) authParams.Add("returnType", "jwt") authParams.Add("keyName", key) - authParams.Add("keySecret", secret) + if invalid { + authParams.Add("keySecret", "invalid") + } else { + authParams.Add("keySecret", secret) + } authParams.Add("expiresIn", fmt.Sprint(expiresIn.Seconds())) return authParams } // Returns JWT with given expiry -func (app *Sandbox) CreateJwt(expiresIn time.Duration) (string, error) { +func (app *Sandbox) CreateJwt(expiresIn time.Duration, invalid bool) (string, error) { u, err := url.Parse(CREATE_JWT_URL) if err != nil { return "", err } - u.RawQuery = app.GetJwtAuthParams(expiresIn).Encode() + u.RawQuery = app.GetJwtAuthParams(expiresIn, invalid).Encode() req, err := http.NewRequest(http.MethodGet, u.String(), nil) if err != nil { return "", fmt.Errorf("client: could not create request: %s", err)