Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -379,22 +379,6 @@ public AccMgrAuthTokenProvider(ClientManager clientManager, String instanceUrl,
@Override
public String getNewAuthToken() {
SalesforceSDKLogger.i(TAG, "Need new access token");
UserAccountManager userAccountManager = SalesforceSDKManager.getInstance().getUserAccountManager();
Account[] accounts = clientManager.getAccounts();
Account matchingAccount = null;

// Find the account for this client.
for (Account account : accounts) {
UserAccount user = userAccountManager.buildUserAccount(account);
if (user != null && lastNewAuthToken.equals(user.getAuthToken())) {
matchingAccount = account;
}
}

// Fail early to ensure we don't logout the current user below by sending null.
if (matchingAccount == null) {
return null;
}

// Wait if another thread is already fetching an access token
synchronized (lock) {
Expand All @@ -408,10 +392,31 @@ public String getNewAuthToken() {
}
gettingAuthToken = true;
}

// Only check for matching account inside synchronized thread that
// is actually getting the new auth token.
UserAccountManager userAccountManager = SalesforceSDKManager.getInstance().getUserAccountManager();
Account[] accounts = clientManager.getAccounts();
Account matchingAccount = null;
String newAuthToken = null;
String newInstanceUrl = null;
try {

if (refreshToken != null) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the refreshToken is somehow null it seems like that should be a logout scenario, but we would need to cache the account/user in the client manager. This was mentioned on the previous PR as well, but feels like too large of a change for 13.1. I will create a 14.0 trust item to better handle this.

for (Account account : accounts) {
UserAccount user = userAccountManager.buildUserAccount(account);
if (user != null && refreshToken.equals(user.getRefreshToken())) {
matchingAccount = account;
break;
}
}
}

// Fail early to ensure we don't logout the current user below by sending null.
if (matchingAccount == null) {
return null;
}

try {
// Invalidate current auth token.
clientManager.invalidateToken(lastNewAuthToken);
final UserAccount userAccount = refreshStaleToken(matchingAccount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ import org.junit.Assert
import org.junit.Before
import org.junit.Test

private const val OLD_TOKEN = "old-token"
private const val REFRESHED_TOKEN = "refreshed-auth-token"
private const val OLD_ACCESS_TOKEN = "old-token"
private const val REFRESHED_ACCESS_TOKEN = "refreshed-auth-token"
private const val REFRESH_TOKEN = "refresh-token"

@SmallTest
class ClientManagerMockTest {
Expand Down Expand Up @@ -75,7 +76,7 @@ class ClientManagerMockTest {

val responseBody = """
{
"access_token": $REFRESHED_TOKEN,
"access_token": $REFRESHED_ACCESS_TOKEN,
"instance_url": "https://login.salesforce.com",
"id": "https://login.salesforce.com/id/orgId/userId",
"token_type": "Bearer",
Expand Down Expand Up @@ -108,7 +109,8 @@ class ClientManagerMockTest {
val broadcastIntentSlot = slot<Intent>()
val mockAccount = mockk<Account>(relaxed = true)
val mockUser = mockk<UserAccount>(relaxed = true) {
every { authToken } returns OLD_TOKEN
every { authToken } returns OLD_ACCESS_TOKEN
every { refreshToken } returns REFRESH_TOKEN
every { loginServer } returns "https://login.salesforce.com"
}
val mockClientManager = mockk<ClientManager>(relaxed = true) {
Expand All @@ -120,22 +122,22 @@ class ClientManagerMockTest {
val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
mockClientManager,
"https://login.salesforce.com",
OLD_TOKEN,
"",
OLD_ACCESS_TOKEN,
REFRESH_TOKEN,
)

val result = authTokenProvider.getNewAuthToken()
Assert.assertEquals(REFRESHED_TOKEN, result)
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, result)

verify(exactly = 0) {
mockSDKManager.logout(any(), any(), any(), any())
}
verify(exactly = 1) {
mockClientManager.invalidateToken(OLD_TOKEN)
mockClientManager.invalidateToken(OLD_ACCESS_TOKEN)
mockUserAccountManager.updateAccount(mockAccount, capture(userSlot))
mockAppContext.sendBroadcast(capture(broadcastIntentSlot))
}
Assert.assertEquals(REFRESHED_TOKEN, userSlot.captured.authToken)
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, userSlot.captured.authToken)
Assert.assertEquals(ClientManager.ACCESS_TOKEN_REFRESH_INTENT, broadcastIntentSlot.captured.action)
}

Expand All @@ -145,7 +147,8 @@ class ClientManagerMockTest {
val broadcastIntentSlot = slot<Intent>()
val mockAccount = mockk<Account>(relaxed = true)
val mockUser = mockk<UserAccount>(relaxed = true) {
every { authToken } returns OLD_TOKEN
every { authToken } returns OLD_ACCESS_TOKEN
every { refreshToken } returns REFRESH_TOKEN
every { loginServer } returns "https://login.salesforce.com"
}
val mockClientManager = mockk<ClientManager>(relaxed = true) {
Expand All @@ -157,22 +160,22 @@ class ClientManagerMockTest {
val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
mockClientManager,
"https://not.login.salesforce.com",
OLD_TOKEN,
"",
OLD_ACCESS_TOKEN,
REFRESH_TOKEN,
)

val result = authTokenProvider.getNewAuthToken()
Assert.assertEquals(REFRESHED_TOKEN, result)
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, result)

verify(exactly = 0) {
mockSDKManager.logout(any(), any(), any(), any())
}
verify(exactly = 1) {
mockClientManager.invalidateToken(OLD_TOKEN)
mockClientManager.invalidateToken(OLD_ACCESS_TOKEN)
mockUserAccountManager.updateAccount(mockAccount, capture(userSlot))
mockAppContext.sendBroadcast(capture(broadcastIntentSlot))
}
Assert.assertEquals(REFRESHED_TOKEN, userSlot.captured.authToken)
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, userSlot.captured.authToken)
Assert.assertEquals(ClientManager.INSTANCE_URL_UPDATE_INTENT, broadcastIntentSlot.captured.action)
}

Expand All @@ -184,8 +187,8 @@ class ClientManagerMockTest {
val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
mockClientManager,
"",
OLD_TOKEN,
"",
OLD_ACCESS_TOKEN,
REFRESH_TOKEN,
)

Assert.assertNull(authTokenProvider.getNewAuthToken())
Expand All @@ -201,17 +204,45 @@ class ClientManagerMockTest {
val mockAccount = mockk<Account>(relaxed = true)
val mockUser = mockk<UserAccount>(relaxed = true) {
every { authToken } returns "not-matching"
every { refreshToken } returns "not-matching"
}
val mockClientManager = mockk<ClientManager>(relaxed = true) {
every { accounts } returns emptyArray<Account>()
every { accounts } returns arrayOf(mockAccount)
}
every { mockUserAccountManager.currentUser } returns mockUser
every { mockUserAccountManager.buildUserAccount(mockAccount) } returns mockUser
val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
mockClientManager,
"",
OLD_TOKEN,
OLD_ACCESS_TOKEN,
REFRESH_TOKEN,
)

Assert.assertNull(authTokenProvider.getNewAuthToken())
verify(exactly = 0) {
mockSDKManager.logout(any(), any(), any(), any())
mockClientManager.invalidateToken(any())
mockAppContext.sendBroadcast(any())
}
}

@Test
fun testGetNewAuthToken_NullAuthToken() {
val mockAccount = mockk<Account>(relaxed = true)
val mockUser = mockk<UserAccount>(relaxed = true) {
every { authToken } returns "not-matching"
every { refreshToken } returns "not-matching"
}
val mockClientManager = mockk<ClientManager>(relaxed = true) {
every { accounts } returns arrayOf(mockAccount)
}
every { mockUserAccountManager.currentUser } returns mockUser
every { mockUserAccountManager.buildUserAccount(mockAccount) } returns mockUser
val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
mockClientManager,
"",
null,
REFRESH_TOKEN,
)

Assert.assertNull(authTokenProvider.getNewAuthToken())
Expand All @@ -229,11 +260,13 @@ class ClientManagerMockTest {
val mockAccount = mockk<Account>(relaxed = true)
val mockAccount2 = mockk<Account>(relaxed = true)
val mockUser = mockk<UserAccount>(relaxed = true) {
every { authToken } returns OLD_TOKEN
every { authToken } returns OLD_ACCESS_TOKEN
every { refreshToken } returns REFRESH_TOKEN
every { loginServer } returns "https://login.salesforce.com"
}
val mockUser2 = mockk<UserAccount>(relaxed = true) {
every { authToken } returns user2Token
every { refreshToken } returns "user2Refresh"
every { loginServer } returns "https://login.salesforce.com"
}
val mockClientManager = mockk<ClientManager>(relaxed = true) {
Expand All @@ -247,21 +280,21 @@ class ClientManagerMockTest {
val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
mockClientManager,
"https://login.salesforce.com",
OLD_TOKEN,
"",
OLD_ACCESS_TOKEN,
REFRESH_TOKEN,
)

Assert.assertEquals(REFRESHED_TOKEN, authTokenProvider.getNewAuthToken())
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, authTokenProvider.getNewAuthToken())
verify(exactly = 0) {
mockClientManager.invalidateToken(user2Token)
mockSDKManager.logout(any(), any(), any(), any())
mockUserAccountManager.updateAccount(mockAccount2, any())
}
verify(exactly = 1) {
mockClientManager.invalidateToken(OLD_TOKEN)
mockClientManager.invalidateToken(OLD_ACCESS_TOKEN)
mockUserAccountManager.updateAccount(mockAccount, capture(userSlot))
}
Assert.assertEquals(REFRESHED_TOKEN, userSlot.captured.authToken)
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, userSlot.captured.authToken)
}

@Test
Expand All @@ -278,7 +311,8 @@ class ClientManagerMockTest {
val broadcastIntentSlot = slot<Intent>()
val mockAccount = mockk<Account>(relaxed = true)
val mockUser = mockk<UserAccount>(relaxed = true) {
every { authToken } returns OLD_TOKEN
every { authToken } returns OLD_ACCESS_TOKEN
every { refreshToken } returns REFRESH_TOKEN
every { loginServer } returns "https://login.salesforce.com"
}

Expand All @@ -291,16 +325,16 @@ class ClientManagerMockTest {
val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
clientManagerSpy,
"https://login.salesforce.com",
OLD_TOKEN,
"",
OLD_ACCESS_TOKEN,
REFRESH_TOKEN,
)

Assert.assertNull(authTokenProvider.getNewAuthToken())
verify(exactly = 0) {
mockUserAccountManager.updateAccount(any(), any())
}
verify(exactly = 1) {
clientManagerSpy.invalidateToken(OLD_TOKEN)
clientManagerSpy.invalidateToken(OLD_ACCESS_TOKEN)
mockSDKManager.logout(capture(accountSlot), any(), any(), capture(reasonSlot))
mockAppContext.sendBroadcast(capture(broadcastIntentSlot))
}
Expand All @@ -321,11 +355,13 @@ class ClientManagerMockTest {
val mockAccount = mockk<Account>(relaxed = true)
val mockAccount2 = mockk<Account>(relaxed = true)
val mockUser = mockk<UserAccount>(relaxed = true) {
every { authToken } returns OLD_TOKEN
every { authToken } returns OLD_ACCESS_TOKEN
every { refreshToken } returns REFRESH_TOKEN
every { loginServer } returns "https://login.salesforce.com"
}
val mockUser2 = mockk<UserAccount>(relaxed = true) {
every { authToken } returns user2Token
every { refreshToken } returns "user2Refresh"
every { loginServer } returns "https://login.salesforce.com"
}
val mockClientManager = mockk<ClientManager>(relaxed = true) {
Expand All @@ -341,21 +377,21 @@ class ClientManagerMockTest {
val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
mockClientManager,
"https://login.salesforce.com",
OLD_TOKEN,
"",
OLD_ACCESS_TOKEN,
REFRESH_TOKEN,
)

Assert.assertEquals(REFRESHED_TOKEN, authTokenProvider.getNewAuthToken())
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, authTokenProvider.getNewAuthToken())
verify(exactly = 0) {
mockClientManager.invalidateToken(user2Token)
mockSDKManager.logout(any(), any(), any(), any())
mockUserAccountManager.updateAccount(mockAccount2, any())
}
verify(exactly = 1) {
mockClientManager.invalidateToken(OLD_TOKEN)
mockClientManager.invalidateToken(OLD_ACCESS_TOKEN)
mockUserAccountManager.updateAccount(mockAccount, capture(userSlot))
}
Assert.assertEquals(REFRESHED_TOKEN, userSlot.captured.authToken)
Assert.assertEquals(REFRESHED_ACCESS_TOKEN, userSlot.captured.authToken)
}

/*
Expand All @@ -379,11 +415,13 @@ class ClientManagerMockTest {
val mockAccount = mockk<Account>(relaxed = true)
val mockAccount2 = mockk<Account>(relaxed = true)
val mockUser = mockk<UserAccount>(relaxed = true) {
every { authToken } returns OLD_TOKEN
every { authToken } returns OLD_ACCESS_TOKEN
every { refreshToken } returns REFRESH_TOKEN
every { loginServer } returns "https://login.salesforce.com"
}
val mockUser2 = mockk<UserAccount>(relaxed = true) {
every { authToken } returns user2Token
every { refreshToken } returns "user2Refresh"
every { loginServer } returns "https://login.salesforce.com"
}
val mockClientManager = mockk<ClientManager>(relaxed = true) {
Expand All @@ -402,8 +440,8 @@ class ClientManagerMockTest {
val authTokenProvider = ClientManager.AccMgrAuthTokenProvider(
clientManagerSpy,
"https://login.salesforce.com",
OLD_TOKEN,
"",
OLD_ACCESS_TOKEN,
REFRESH_TOKEN,
)

Assert.assertNull(authTokenProvider.getNewAuthToken())
Expand All @@ -416,7 +454,7 @@ class ClientManagerMockTest {
}

verify(exactly = 1) {
clientManagerSpy.invalidateToken(OLD_TOKEN)
clientManagerSpy.invalidateToken(OLD_ACCESS_TOKEN)
mockSDKManager.logout(capture(accountSlot), any(), any(), capture(reasonSlot))
mockAppContext.sendBroadcast(capture(broadcastIntentSlot))
}
Expand Down
Loading