From 0dd8d0236b21ea19ce8addf4b04cb07a13798171 Mon Sep 17 00:00:00 2001 From: Brandon Page Date: Thu, 25 Sep 2025 16:54:48 -0700 Subject: [PATCH] Improve getNewAuthToken. --- .../androidsdk/rest/ClientManager.java | 39 +++--- .../androidsdk/rest/ClientManagerMockTest.kt | 116 ++++++++++++------ 2 files changed, 99 insertions(+), 56 deletions(-) diff --git a/libs/SalesforceSDK/src/com/salesforce/androidsdk/rest/ClientManager.java b/libs/SalesforceSDK/src/com/salesforce/androidsdk/rest/ClientManager.java index e6ff6bfb5..271cc5db0 100644 --- a/libs/SalesforceSDK/src/com/salesforce/androidsdk/rest/ClientManager.java +++ b/libs/SalesforceSDK/src/com/salesforce/androidsdk/rest/ClientManager.java @@ -379,22 +379,6 @@ public AccMgrAuthTokenProvider(ClientManager clientManager, String instanceUrl, @Override public String getNewAuthToken() { SalesforceSDKLogger.i(TAG, "Need new access token"); - UserAccountManager userAccountManager = SalesforceSDKManager.getInstance().getUserAccountManager(); - Account[] accounts = clientManager.getAccounts(); - Account matchingAccount = null; - - // Find the account for this client. - for (Account account : accounts) { - UserAccount user = userAccountManager.buildUserAccount(account); - if (user != null && lastNewAuthToken.equals(user.getAuthToken())) { - matchingAccount = account; - } - } - - // Fail early to ensure we don't logout the current user below by sending null. - if (matchingAccount == null) { - return null; - } // Wait if another thread is already fetching an access token synchronized (lock) { @@ -408,10 +392,31 @@ public String getNewAuthToken() { } gettingAuthToken = true; } + + // Only check for matching account inside synchronized thread that + // is actually getting the new auth token. + UserAccountManager userAccountManager = SalesforceSDKManager.getInstance().getUserAccountManager(); + Account[] accounts = clientManager.getAccounts(); + Account matchingAccount = null; String newAuthToken = null; String newInstanceUrl = null; - try { + if (refreshToken != null) { + for (Account account : accounts) { + UserAccount user = userAccountManager.buildUserAccount(account); + if (user != null && refreshToken.equals(user.getRefreshToken())) { + matchingAccount = account; + break; + } + } + } + + // Fail early to ensure we don't logout the current user below by sending null. + if (matchingAccount == null) { + return null; + } + + try { // Invalidate current auth token. clientManager.invalidateToken(lastNewAuthToken); final UserAccount userAccount = refreshStaleToken(matchingAccount); diff --git a/libs/test/SalesforceSDKTest/src/com/salesforce/androidsdk/rest/ClientManagerMockTest.kt b/libs/test/SalesforceSDKTest/src/com/salesforce/androidsdk/rest/ClientManagerMockTest.kt index 14fce6937..8097e592c 100644 --- a/libs/test/SalesforceSDKTest/src/com/salesforce/androidsdk/rest/ClientManagerMockTest.kt +++ b/libs/test/SalesforceSDKTest/src/com/salesforce/androidsdk/rest/ClientManagerMockTest.kt @@ -32,8 +32,9 @@ import org.junit.Assert import org.junit.Before import org.junit.Test -private const val OLD_TOKEN = "old-token" -private const val REFRESHED_TOKEN = "refreshed-auth-token" +private const val OLD_ACCESS_TOKEN = "old-token" +private const val REFRESHED_ACCESS_TOKEN = "refreshed-auth-token" +private const val REFRESH_TOKEN = "refresh-token" @SmallTest class ClientManagerMockTest { @@ -75,7 +76,7 @@ class ClientManagerMockTest { val responseBody = """ { - "access_token": $REFRESHED_TOKEN, + "access_token": $REFRESHED_ACCESS_TOKEN, "instance_url": "https://login.salesforce.com", "id": "https://login.salesforce.com/id/orgId/userId", "token_type": "Bearer", @@ -108,7 +109,8 @@ class ClientManagerMockTest { val broadcastIntentSlot = slot() val mockAccount = mockk(relaxed = true) val mockUser = mockk(relaxed = true) { - every { authToken } returns OLD_TOKEN + every { authToken } returns OLD_ACCESS_TOKEN + every { refreshToken } returns REFRESH_TOKEN every { loginServer } returns "https://login.salesforce.com" } val mockClientManager = mockk(relaxed = true) { @@ -120,22 +122,22 @@ class ClientManagerMockTest { val authTokenProvider = ClientManager.AccMgrAuthTokenProvider( mockClientManager, "https://login.salesforce.com", - OLD_TOKEN, - "", + OLD_ACCESS_TOKEN, + REFRESH_TOKEN, ) val result = authTokenProvider.getNewAuthToken() - Assert.assertEquals(REFRESHED_TOKEN, result) + Assert.assertEquals(REFRESHED_ACCESS_TOKEN, result) verify(exactly = 0) { mockSDKManager.logout(any(), any(), any(), any()) } verify(exactly = 1) { - mockClientManager.invalidateToken(OLD_TOKEN) + mockClientManager.invalidateToken(OLD_ACCESS_TOKEN) mockUserAccountManager.updateAccount(mockAccount, capture(userSlot)) mockAppContext.sendBroadcast(capture(broadcastIntentSlot)) } - Assert.assertEquals(REFRESHED_TOKEN, userSlot.captured.authToken) + Assert.assertEquals(REFRESHED_ACCESS_TOKEN, userSlot.captured.authToken) Assert.assertEquals(ClientManager.ACCESS_TOKEN_REFRESH_INTENT, broadcastIntentSlot.captured.action) } @@ -145,7 +147,8 @@ class ClientManagerMockTest { val broadcastIntentSlot = slot() val mockAccount = mockk(relaxed = true) val mockUser = mockk(relaxed = true) { - every { authToken } returns OLD_TOKEN + every { authToken } returns OLD_ACCESS_TOKEN + every { refreshToken } returns REFRESH_TOKEN every { loginServer } returns "https://login.salesforce.com" } val mockClientManager = mockk(relaxed = true) { @@ -157,22 +160,22 @@ class ClientManagerMockTest { val authTokenProvider = ClientManager.AccMgrAuthTokenProvider( mockClientManager, "https://not.login.salesforce.com", - OLD_TOKEN, - "", + OLD_ACCESS_TOKEN, + REFRESH_TOKEN, ) val result = authTokenProvider.getNewAuthToken() - Assert.assertEquals(REFRESHED_TOKEN, result) + Assert.assertEquals(REFRESHED_ACCESS_TOKEN, result) verify(exactly = 0) { mockSDKManager.logout(any(), any(), any(), any()) } verify(exactly = 1) { - mockClientManager.invalidateToken(OLD_TOKEN) + mockClientManager.invalidateToken(OLD_ACCESS_TOKEN) mockUserAccountManager.updateAccount(mockAccount, capture(userSlot)) mockAppContext.sendBroadcast(capture(broadcastIntentSlot)) } - Assert.assertEquals(REFRESHED_TOKEN, userSlot.captured.authToken) + Assert.assertEquals(REFRESHED_ACCESS_TOKEN, userSlot.captured.authToken) Assert.assertEquals(ClientManager.INSTANCE_URL_UPDATE_INTENT, broadcastIntentSlot.captured.action) } @@ -184,8 +187,8 @@ class ClientManagerMockTest { val authTokenProvider = ClientManager.AccMgrAuthTokenProvider( mockClientManager, "", - OLD_TOKEN, - "", + OLD_ACCESS_TOKEN, + REFRESH_TOKEN, ) Assert.assertNull(authTokenProvider.getNewAuthToken()) @@ -201,17 +204,45 @@ class ClientManagerMockTest { val mockAccount = mockk(relaxed = true) val mockUser = mockk(relaxed = true) { every { authToken } returns "not-matching" + every { refreshToken } returns "not-matching" } val mockClientManager = mockk(relaxed = true) { - every { accounts } returns emptyArray() + every { accounts } returns arrayOf(mockAccount) } every { mockUserAccountManager.currentUser } returns mockUser every { mockUserAccountManager.buildUserAccount(mockAccount) } returns mockUser val authTokenProvider = ClientManager.AccMgrAuthTokenProvider( mockClientManager, "", - OLD_TOKEN, + OLD_ACCESS_TOKEN, + REFRESH_TOKEN, + ) + + Assert.assertNull(authTokenProvider.getNewAuthToken()) + verify(exactly = 0) { + mockSDKManager.logout(any(), any(), any(), any()) + mockClientManager.invalidateToken(any()) + mockAppContext.sendBroadcast(any()) + } + } + + @Test + fun testGetNewAuthToken_NullAuthToken() { + val mockAccount = mockk(relaxed = true) + val mockUser = mockk(relaxed = true) { + every { authToken } returns "not-matching" + every { refreshToken } returns "not-matching" + } + val mockClientManager = mockk(relaxed = true) { + every { accounts } returns arrayOf(mockAccount) + } + every { mockUserAccountManager.currentUser } returns mockUser + every { mockUserAccountManager.buildUserAccount(mockAccount) } returns mockUser + val authTokenProvider = ClientManager.AccMgrAuthTokenProvider( + mockClientManager, "", + null, + REFRESH_TOKEN, ) Assert.assertNull(authTokenProvider.getNewAuthToken()) @@ -229,11 +260,13 @@ class ClientManagerMockTest { val mockAccount = mockk(relaxed = true) val mockAccount2 = mockk(relaxed = true) val mockUser = mockk(relaxed = true) { - every { authToken } returns OLD_TOKEN + every { authToken } returns OLD_ACCESS_TOKEN + every { refreshToken } returns REFRESH_TOKEN every { loginServer } returns "https://login.salesforce.com" } val mockUser2 = mockk(relaxed = true) { every { authToken } returns user2Token + every { refreshToken } returns "user2Refresh" every { loginServer } returns "https://login.salesforce.com" } val mockClientManager = mockk(relaxed = true) { @@ -247,21 +280,21 @@ class ClientManagerMockTest { val authTokenProvider = ClientManager.AccMgrAuthTokenProvider( mockClientManager, "https://login.salesforce.com", - OLD_TOKEN, - "", + OLD_ACCESS_TOKEN, + REFRESH_TOKEN, ) - Assert.assertEquals(REFRESHED_TOKEN, authTokenProvider.getNewAuthToken()) + Assert.assertEquals(REFRESHED_ACCESS_TOKEN, authTokenProvider.getNewAuthToken()) verify(exactly = 0) { mockClientManager.invalidateToken(user2Token) mockSDKManager.logout(any(), any(), any(), any()) mockUserAccountManager.updateAccount(mockAccount2, any()) } verify(exactly = 1) { - mockClientManager.invalidateToken(OLD_TOKEN) + mockClientManager.invalidateToken(OLD_ACCESS_TOKEN) mockUserAccountManager.updateAccount(mockAccount, capture(userSlot)) } - Assert.assertEquals(REFRESHED_TOKEN, userSlot.captured.authToken) + Assert.assertEquals(REFRESHED_ACCESS_TOKEN, userSlot.captured.authToken) } @Test @@ -278,7 +311,8 @@ class ClientManagerMockTest { val broadcastIntentSlot = slot() val mockAccount = mockk(relaxed = true) val mockUser = mockk(relaxed = true) { - every { authToken } returns OLD_TOKEN + every { authToken } returns OLD_ACCESS_TOKEN + every { refreshToken } returns REFRESH_TOKEN every { loginServer } returns "https://login.salesforce.com" } @@ -291,8 +325,8 @@ class ClientManagerMockTest { val authTokenProvider = ClientManager.AccMgrAuthTokenProvider( clientManagerSpy, "https://login.salesforce.com", - OLD_TOKEN, - "", + OLD_ACCESS_TOKEN, + REFRESH_TOKEN, ) Assert.assertNull(authTokenProvider.getNewAuthToken()) @@ -300,7 +334,7 @@ class ClientManagerMockTest { mockUserAccountManager.updateAccount(any(), any()) } verify(exactly = 1) { - clientManagerSpy.invalidateToken(OLD_TOKEN) + clientManagerSpy.invalidateToken(OLD_ACCESS_TOKEN) mockSDKManager.logout(capture(accountSlot), any(), any(), capture(reasonSlot)) mockAppContext.sendBroadcast(capture(broadcastIntentSlot)) } @@ -321,11 +355,13 @@ class ClientManagerMockTest { val mockAccount = mockk(relaxed = true) val mockAccount2 = mockk(relaxed = true) val mockUser = mockk(relaxed = true) { - every { authToken } returns OLD_TOKEN + every { authToken } returns OLD_ACCESS_TOKEN + every { refreshToken } returns REFRESH_TOKEN every { loginServer } returns "https://login.salesforce.com" } val mockUser2 = mockk(relaxed = true) { every { authToken } returns user2Token + every { refreshToken } returns "user2Refresh" every { loginServer } returns "https://login.salesforce.com" } val mockClientManager = mockk(relaxed = true) { @@ -341,21 +377,21 @@ class ClientManagerMockTest { val authTokenProvider = ClientManager.AccMgrAuthTokenProvider( mockClientManager, "https://login.salesforce.com", - OLD_TOKEN, - "", + OLD_ACCESS_TOKEN, + REFRESH_TOKEN, ) - Assert.assertEquals(REFRESHED_TOKEN, authTokenProvider.getNewAuthToken()) + Assert.assertEquals(REFRESHED_ACCESS_TOKEN, authTokenProvider.getNewAuthToken()) verify(exactly = 0) { mockClientManager.invalidateToken(user2Token) mockSDKManager.logout(any(), any(), any(), any()) mockUserAccountManager.updateAccount(mockAccount2, any()) } verify(exactly = 1) { - mockClientManager.invalidateToken(OLD_TOKEN) + mockClientManager.invalidateToken(OLD_ACCESS_TOKEN) mockUserAccountManager.updateAccount(mockAccount, capture(userSlot)) } - Assert.assertEquals(REFRESHED_TOKEN, userSlot.captured.authToken) + Assert.assertEquals(REFRESHED_ACCESS_TOKEN, userSlot.captured.authToken) } /* @@ -379,11 +415,13 @@ class ClientManagerMockTest { val mockAccount = mockk(relaxed = true) val mockAccount2 = mockk(relaxed = true) val mockUser = mockk(relaxed = true) { - every { authToken } returns OLD_TOKEN + every { authToken } returns OLD_ACCESS_TOKEN + every { refreshToken } returns REFRESH_TOKEN every { loginServer } returns "https://login.salesforce.com" } val mockUser2 = mockk(relaxed = true) { every { authToken } returns user2Token + every { refreshToken } returns "user2Refresh" every { loginServer } returns "https://login.salesforce.com" } val mockClientManager = mockk(relaxed = true) { @@ -402,8 +440,8 @@ class ClientManagerMockTest { val authTokenProvider = ClientManager.AccMgrAuthTokenProvider( clientManagerSpy, "https://login.salesforce.com", - OLD_TOKEN, - "", + OLD_ACCESS_TOKEN, + REFRESH_TOKEN, ) Assert.assertNull(authTokenProvider.getNewAuthToken()) @@ -416,7 +454,7 @@ class ClientManagerMockTest { } verify(exactly = 1) { - clientManagerSpy.invalidateToken(OLD_TOKEN) + clientManagerSpy.invalidateToken(OLD_ACCESS_TOKEN) mockSDKManager.logout(capture(accountSlot), any(), any(), capture(reasonSlot)) mockAppContext.sendBroadcast(capture(broadcastIntentSlot)) }