Skip to content

Commit a16524f

Browse files
Garzasgithub-actions[bot]
authored andcommitted
fix: mls client init [WPB-15022] (#3178)
* fix: secure mls client creation with is mls enabled * fix: dont persist mls conversations when mls is disabled * review improvements
1 parent 6fb2177 commit a16524f

File tree

13 files changed

+212
-53
lines changed

13 files changed

+212
-53
lines changed

Diff for: logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt

+1
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ interface MLSFailure : CoreFailure {
211211
data object StaleProposal : MLSFailure
212212
data object StaleCommit : MLSFailure
213213
data object InternalErrors : MLSFailure
214+
data object Disabled : MLSFailure
214215

215216
data class Generic(internal val exception: Exception) : MLSFailure {
216217
val rootCause: Throwable get() = exception

Diff for: logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/MLSClientProvider.kt

+5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import com.wire.kalium.cryptography.coreCryptoCentral
2828
import com.wire.kalium.logger.KaliumLogLevel
2929
import com.wire.kalium.logic.CoreFailure
3030
import com.wire.kalium.logic.E2EIFailure
31+
import com.wire.kalium.logic.MLSFailure
3132
import com.wire.kalium.logic.configuration.UserConfigRepository
3233
import com.wire.kalium.logic.data.conversation.ClientId
3334
import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository
@@ -130,6 +131,10 @@ class MLSClientProviderImpl(
130131
}
131132

132133
override suspend fun getOrFetchMLSConfig(): Either<CoreFailure, SupportedCipherSuite> {
134+
if (!userConfigRepository.isMLSEnabled().getOrElse(true)) {
135+
kaliumLogger.w("$TAG: Cannot fetch MLS config, MLS is disabled.")
136+
return MLSFailure.Disabled.left()
137+
}
133138
return userConfigRepository.getSupportedCipherSuite().flatMapLeft<CoreFailure, SupportedCipherSuite> {
134139
featureConfigRepository.getFeatureConfigs().map {
135140
it.mlsModel.supportedCipherSuite

Diff for: logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt

+35-25
Original file line numberDiff line numberDiff line change
@@ -432,16 +432,19 @@ internal class ConversationDataSource internal constructor(
432432
): Either<CoreFailure, Boolean> = wrapStorageRequest {
433433
val isNewConversation = conversationDAO.getConversationById(conversation.id.toDao()) == null
434434
if (isNewConversation) {
435-
conversationDAO.insertConversation(
436-
conversationMapper.fromApiModelToDaoModel(
437-
conversation,
438-
mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) },
439-
selfTeamIdProvider().getOrNull(),
435+
val mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) }
436+
if (shouldPersistMLSConversation(mlsGroupState)) {
437+
conversationDAO.insertConversation(
438+
conversationMapper.fromApiModelToDaoModel(
439+
conversation,
440+
mlsGroupState = mlsGroupState?.getOrNull(),
441+
selfTeamIdProvider().getOrNull(),
442+
)
440443
)
441-
)
442-
memberDAO.insertMembersWithQualifiedId(
443-
memberMapper.fromApiModelToDaoModel(conversation.members), idMapper.fromApiToDao(conversation.id)
444-
)
444+
memberDAO.insertMembersWithQualifiedId(
445+
memberMapper.fromApiModelToDaoModel(conversation.members), idMapper.fromApiToDao(conversation.id)
446+
)
447+
}
445448
}
446449
isNewConversation
447450
}
@@ -453,17 +456,19 @@ internal class ConversationDataSource internal constructor(
453456
invalidateMembers: Boolean
454457
) = wrapStorageRequest {
455458
val conversationEntities = conversations
456-
.map { conversationResponse ->
457-
conversationMapper.fromApiModelToDaoModel(
458-
conversationResponse,
459-
mlsGroupState = conversationResponse.groupId?.let {
460-
mlsGroupState(
461-
idMapper.fromGroupIDEntity(it),
462-
originatedFromEvent
463-
)
464-
},
465-
selfTeamIdProvider().getOrNull(),
466-
)
459+
.mapNotNull { conversationResponse ->
460+
val mlsGroupState = conversationResponse.groupId?.let {
461+
mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent)
462+
}
463+
if (shouldPersistMLSConversation(mlsGroupState)) {
464+
conversationMapper.fromApiModelToDaoModel(
465+
conversationResponse,
466+
mlsGroupState = mlsGroupState?.getOrNull(),
467+
selfTeamIdProvider().getOrNull(),
468+
)
469+
} else {
470+
null
471+
}
467472
}
468473
conversationDAO.insertConversations(conversationEntities)
469474
conversations.forEach { conversationsResponse ->
@@ -483,10 +488,11 @@ internal class ConversationDataSource internal constructor(
483488
}
484489
}
485490

486-
private suspend fun mlsGroupState(groupId: GroupID, originatedFromEvent: Boolean = false): ConversationEntity.GroupState =
487-
hasEstablishedMLSGroup(groupId).fold({
488-
throw IllegalStateException(it.toString()) // TODO find a more fitting exception?
489-
}, { exists ->
491+
private suspend fun mlsGroupState(
492+
groupId: GroupID,
493+
originatedFromEvent: Boolean = false
494+
): Either<CoreFailure, ConversationEntity.GroupState> = hasEstablishedMLSGroup(groupId)
495+
.map { exists ->
490496
if (exists) {
491497
ConversationEntity.GroupState.ESTABLISHED
492498
} else {
@@ -496,7 +502,7 @@ internal class ConversationDataSource internal constructor(
496502
ConversationEntity.GroupState.PENDING_JOIN
497503
}
498504
}
499-
})
505+
}
500506

501507
private suspend fun hasEstablishedMLSGroup(groupID: GroupID): Either<CoreFailure, Boolean> =
502508
mlsClientProvider.getMLSClient()
@@ -506,6 +512,10 @@ internal class ConversationDataSource internal constructor(
506512
}
507513
}
508514

515+
// if group state is not null and is left, then we don't want to persist the MLS conversation
516+
private fun shouldPersistMLSConversation(groupState: Either<CoreFailure, ConversationEntity.GroupState>?): Boolean =
517+
groupState?.fold({ true }, { false }) != true
518+
509519
@DelicateKaliumApi("This function does not get values from cache")
510520
override suspend fun getProteusSelfConversationId(): Either<StorageFailure, ConversationId> =
511521
wrapStorageRequest { conversationDAO.getSelfConversationId(ConversationEntity.Protocol.PROTEUS) }

Diff for: logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt

+14-8
Original file line numberDiff line numberDiff line change
@@ -670,17 +670,23 @@ internal class MLSConversationDataSource(
670670
})
671671

672672
override suspend fun getClientIdentity(clientId: ClientId) =
673-
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap {
674-
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
675-
wrapMLSRequest {
673+
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }
674+
.flatMap { conversationClientInfo ->
675+
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
676+
wrapMLSRequest {
676677

677-
mlsClient.getDeviceIdentities(
678-
it.mlsGroupId,
679-
listOf(CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto()))
680-
).firstOrNull()
678+
mlsClient.getDeviceIdentities(
679+
conversationClientInfo.mlsGroupId,
680+
listOf(
681+
CryptoQualifiedClientId(
682+
conversationClientInfo.clientId,
683+
conversationClientInfo.userId.toModel().toCrypto()
684+
)
685+
)
686+
).firstOrNull()
687+
}
681688
}
682689
}
683-
}
684690

685691
override suspend fun getUserIdentity(userId: UserId) =
686692
wrapStorageRequest {

Diff for: logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1771,7 +1771,8 @@ class UserSessionScope internal constructor(
17711771
cachedClientIdClearer,
17721772
updateSupportedProtocolsAndResolveOneOnOnes,
17731773
registerMLSClientUseCase,
1774-
syncFeatureConfigsUseCase
1774+
syncFeatureConfigsUseCase,
1775+
userConfigRepository
17751776
)
17761777
}
17771778
val conversations: ConversationScope by lazy {

Diff for: logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt

+4-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
package com.wire.kalium.logic.feature.client
2020

21+
import com.wire.kalium.logic.configuration.UserConfigRepository
2122
import com.wire.kalium.logic.configuration.notification.NotificationTokenRepository
2223
import com.wire.kalium.logic.data.auth.verification.SecondFactorVerificationRepository
2324
import com.wire.kalium.logic.data.client.ClientRepository
@@ -71,7 +72,8 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
7172
private val cachedClientIdClearer: CachedClientIdClearer,
7273
private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase,
7374
private val registerMLSClientUseCase: RegisterMLSClientUseCase,
74-
private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase
75+
private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase,
76+
private val userConfigRepository: UserConfigRepository
7577
) {
7678
@OptIn(DelicateKaliumApi::class)
7779
val register: RegisterClientUseCase
@@ -102,7 +104,7 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
102104
val deregisterNativePushToken: DeregisterTokenUseCase
103105
get() = DeregisterTokenUseCaseImpl(clientRepository, notificationTokenRepository)
104106
val mlsKeyPackageCountUseCase: MLSKeyPackageCountUseCase
105-
get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider)
107+
get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider, userConfigRepository)
106108
val restartSlowSyncProcessForRecoveryUseCase: RestartSlowSyncProcessForRecoveryUseCase
107109
get() = RestartSlowSyncProcessForRecoveryUseCaseImpl(slowSyncRepository)
108110
val refillKeyPackages: RefillKeyPackagesUseCase

Diff for: logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/IsAllowedToRegisterMLSClientUseCase.kt

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package com.wire.kalium.logic.feature.client
2121
import com.wire.kalium.logic.configuration.UserConfigRepository
2222
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
2323
import com.wire.kalium.logic.featureFlags.FeatureSupport
24-
import com.wire.kalium.logic.functional.fold
24+
import com.wire.kalium.logic.functional.getOrElse
2525
import com.wire.kalium.logic.functional.isRight
2626
import com.wire.kalium.util.DelicateKaliumApi
2727

@@ -45,8 +45,8 @@ internal class IsAllowedToRegisterMLSClientUseCaseImpl(
4545
) : IsAllowedToRegisterMLSClientUseCase {
4646

4747
override suspend operator fun invoke(): Boolean {
48-
return featureSupport.isMLSSupported &&
49-
mlsPublicKeysRepository.getKeys().isRight() &&
50-
userConfigRepository.isMLSEnabled().fold({ false }, { isEnabled -> isEnabled })
48+
return featureSupport.isMLSSupported
49+
&& userConfigRepository.isMLSEnabled().getOrElse(false)
50+
&& mlsPublicKeysRepository.getKeys().isRight()
5151
}
5252
}

Diff for: logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt

+13-4
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ package com.wire.kalium.logic.feature.keypackage
2020

2121
import com.wire.kalium.logic.CoreFailure
2222
import com.wire.kalium.logic.NetworkFailure
23+
import com.wire.kalium.logic.configuration.UserConfigRepository
2324
import com.wire.kalium.logic.data.conversation.ClientId
2425
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
2526
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
2627
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
2728
import com.wire.kalium.logic.functional.fold
29+
import com.wire.kalium.logic.functional.getOrElse
2830

2931
/**
3032
* This use case will return the current number of key packages.
@@ -37,6 +39,7 @@ internal class MLSKeyPackageCountUseCaseImpl(
3739
private val keyPackageRepository: KeyPackageRepository,
3840
private val currentClientIdProvider: CurrentClientIdProvider,
3941
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
42+
private val userConfigRepository: UserConfigRepository
4043
) : MLSKeyPackageCountUseCase {
4144
override suspend operator fun invoke(fromAPI: Boolean): MLSKeyPackageCountResult =
4245
when (fromAPI) {
@@ -47,10 +50,15 @@ internal class MLSKeyPackageCountUseCaseImpl(
4750
private suspend fun validKeyPackagesCountFromAPI() = currentClientIdProvider().fold({
4851
MLSKeyPackageCountResult.Failure.FetchClientIdFailure(it)
4952
}, { selfClient ->
50-
keyPackageRepository.getAvailableKeyPackageCount(selfClient).fold(
51-
{
52-
MLSKeyPackageCountResult.Failure.NetworkCallFailure(it)
53-
}, { MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) })
53+
if (userConfigRepository.isMLSEnabled().getOrElse(false)) {
54+
keyPackageRepository.getAvailableKeyPackageCount(selfClient)
55+
.fold(
56+
{ MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) },
57+
{ MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) }
58+
)
59+
} else {
60+
MLSKeyPackageCountResult.Failure.NotEnabled
61+
}
5462
})
5563

5664
private suspend fun validKeyPackagesCountFromMLSClient() =
@@ -70,6 +78,7 @@ sealed class MLSKeyPackageCountResult {
7078
sealed class Failure : MLSKeyPackageCountResult() {
7179
class NetworkCallFailure(val networkFailure: NetworkFailure) : Failure()
7280
class FetchClientIdFailure(val genericFailure: CoreFailure) : Failure()
81+
data object NotEnabled : Failure()
7382
data class Generic(val genericFailure: CoreFailure) : Failure()
7483
}
7584
}

Diff for: logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ internal object MLSMessageFailureHandler {
4343
is MLSFailure.StaleCommit -> MLSMessageFailureResolution.Ignore
4444
is MLSFailure.MessageEpochTooOld -> MLSMessageFailureResolution.Ignore
4545
is MLSFailure.InternalErrors -> MLSMessageFailureResolution.Ignore
46+
is MLSFailure.Disabled -> MLSMessageFailureResolution.Ignore
4647
else -> MLSMessageFailureResolution.InformUser
4748
}
4849
}

Diff for: logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/MLSClientProviderTest.kt

+44
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package com.wire.kalium.logic.data.client
1919

20+
import com.wire.kalium.logic.CoreFailure
2021
import com.wire.kalium.logic.StorageFailure
2122
import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest
2223
import com.wire.kalium.logic.data.featureConfig.MLSModel
@@ -32,12 +33,15 @@ import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepository
3233
import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepositoryArrangementImpl
3334
import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement
3435
import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl
36+
import com.wire.kalium.logic.util.shouldFail
3537
import com.wire.kalium.logic.util.shouldSucceed
3638
import com.wire.kalium.persistence.dbPassphrase.PassphraseStorage
39+
import io.ktor.util.reflect.instanceOf
3740
import io.mockative.Mock
3841
import io.mockative.coVerify
3942
import io.mockative.mock
4043
import io.mockative.once
44+
import io.mockative.verify
4145
import kotlinx.coroutines.runBlocking
4246
import kotlinx.coroutines.test.runTest
4347
import kotlin.test.Test
@@ -63,12 +67,16 @@ class MLSClientProviderTest {
6367
val (arrangement, mlsClientProvider) = Arrangement().arrange {
6468
withGetSupportedCipherSuitesReturning(StorageFailure.DataNotFound.left())
6569
withGetFeatureConfigsReturning(FeatureConfigTest.newModel(mlsModel = expected).right())
70+
withGetMLSEnabledReturning(true.right())
6671
}
6772

6873
mlsClientProvider.getOrFetchMLSConfig().shouldSucceed {
6974
assertEquals(expected.supportedCipherSuite, it)
7075
}
7176

77+
verify { arrangement.userConfigRepository.isMLSEnabled() }
78+
.wasInvoked(exactly = once)
79+
7280
coVerify { arrangement.userConfigRepository.getSupportedCipherSuite() }
7381
.wasInvoked(exactly = once)
7482

@@ -88,12 +96,17 @@ class MLSClientProviderTest {
8896

8997
val (arrangement, mlsClientProvider) = Arrangement().arrange {
9098
withGetSupportedCipherSuitesReturning(expected.right())
99+
withGetMLSEnabledReturning(true.right())
100+
withGetFeatureConfigsReturning(FeatureConfigTest.newModel().right())
91101
}
92102

93103
mlsClientProvider.getOrFetchMLSConfig().shouldSucceed {
94104
assertEquals(expected, it)
95105
}
96106

107+
verify { arrangement.userConfigRepository.isMLSEnabled() }
108+
.wasInvoked(exactly = once)
109+
97110
coVerify {
98111
arrangement.userConfigRepository.getSupportedCipherSuite()
99112
}.wasInvoked(exactly = once)
@@ -103,6 +116,37 @@ class MLSClientProviderTest {
103116
}.wasNotInvoked()
104117
}
105118

119+
@Test
120+
fun givenMLSDisabledWhenGetOrFetchMLSConfigIsCalledThenDoNotCallGetSupportedCipherSuiteOrGetFeatureConfigs() = runTest {
121+
// given
122+
val (arrangement, mlsClientProvider) = Arrangement().arrange {
123+
withGetMLSEnabledReturning(false.right())
124+
withGetSupportedCipherSuitesReturning(
125+
SupportedCipherSuite(
126+
supported = listOf(
127+
CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
128+
CipherSuite.MLS_256_DHKEMP384_AES256GCM_SHA384_P384
129+
),
130+
default = CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256
131+
).right()
132+
)
133+
}
134+
135+
// when
136+
val result = mlsClientProvider.getOrFetchMLSConfig()
137+
138+
// then
139+
result.shouldFail {
140+
it.instanceOf(CoreFailure.Unknown::class)
141+
}
142+
143+
coVerify { arrangement.userConfigRepository.getSupportedCipherSuite() }
144+
.wasNotInvoked()
145+
146+
coVerify { arrangement.featureConfigRepository.getFeatureConfigs() }
147+
.wasNotInvoked()
148+
}
149+
106150
private class Arrangement : UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl(),
107151
FeatureConfigRepositoryArrangement by FeatureConfigRepositoryArrangementImpl() {
108152

0 commit comments

Comments
 (0)