Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mls): set removal-keys for 1on1 calls from conversation-response (WPB-10743) 🍒 #3019

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ internal class ConversationGroupRepositoryImpl(
val conversationEntity = conversationMapper.fromApiModelToDaoModel(
conversationResponse, mlsGroupState = ConversationEntity.GroupState.PENDING_CREATION, selfTeamId
)
val mlsPublicKeys = conversationMapper.fromApiModel(conversationResponse.publicKeys)
val protocol = protocolInfoMapper.fromEntity(conversationEntity.protocolInfo)

return wrapStorageRequest {
Expand All @@ -166,7 +167,8 @@ internal class ConversationGroupRepositoryImpl(
is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup(
groupID = protocol.groupId,
members = usersList + selfUserId,
allowSkippingUsersWithoutKeyPackages = true
publicKeys = mlsPublicKeys,
allowSkippingUsersWithoutKeyPackages = true,
).map { it.notAddedUsers }
}
}.flatMap { protocolSpecificAdditionFailures ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.message.MessagePreview
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.user.AvailabilityStatusMapper
import com.wire.kalium.logic.data.user.BotService
import com.wire.kalium.logic.data.user.Connection
Expand All @@ -41,6 +42,7 @@ import com.wire.kalium.network.api.authenticated.conversation.ConvTeamInfo
import com.wire.kalium.network.api.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.authenticated.conversation.CreateConversationRequest
import com.wire.kalium.network.api.authenticated.conversation.ReceiptMode
import com.wire.kalium.network.api.authenticated.serverpublickey.MLSPublicKeysDTO
import com.wire.kalium.network.api.model.ConversationAccessDTO
import com.wire.kalium.network.api.model.ConversationAccessRoleDTO
import com.wire.kalium.persistence.dao.conversation.ConversationEntity
Expand All @@ -59,6 +61,7 @@ import kotlin.time.toDuration

interface ConversationMapper {
fun fromApiModelToDaoModel(apiModel: ConversationResponse, mlsGroupState: GroupState?, selfUserTeamId: TeamId?): ConversationEntity
fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?): MLSPublicKeys?
fun fromDaoModel(daoModel: ConversationViewEntity): Conversation
fun fromDaoModel(daoModel: ConversationEntity): Conversation
fun fromDaoModelToDetails(
Expand Down Expand Up @@ -136,6 +139,12 @@ internal class ConversationMapperImpl(
legalHoldStatus = ConversationEntity.LegalHoldStatus.DISABLED
)

override fun fromApiModel(mlsPublicKeysDTO: MLSPublicKeysDTO?) = mlsPublicKeysDTO?.let {
MLSPublicKeys(
removal = mlsPublicKeysDTO.removal
)
}

override fun fromDaoModel(daoModel: ConversationViewEntity): Conversation = with(daoModel) {
val lastReadDateEntity = if (type == ConversationEntity.Type.CONNECTION_PENDING) Instant.UNIX_FIRST_DATE
else lastReadDate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.mlspublickeys.getRemovalKey
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.e2ei.RevocationListChecker
Expand Down Expand Up @@ -123,6 +125,7 @@ interface MLSConversationRepository {
suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
publicKeys: MLSPublicKeys? = null,
allowSkippingUsersWithoutKeyPackages: Boolean = false
): Either<CoreFailure, MLSAdditionResult>

Expand Down Expand Up @@ -554,16 +557,18 @@ internal class MLSConversationDataSource(
override suspend fun establishMLSGroup(
groupID: GroupID,
members: List<UserId>,
allowSkippingUsersWithoutKeyPackages: Boolean,
publicKeys: MLSPublicKeys?,
allowSkippingUsersWithoutKeyPackages: Boolean
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> {
mlsPublicKeysRepository.getKeyForCipherSuite(
CipherSuite.fromTag(it.getDefaultCipherSuite())
).flatMap { key ->
mlsClientProvider.getMLSClient().flatMap<MLSAdditionResult, CoreFailure, MLSClient> { mlsClient ->
val cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite())
val keys = publicKeys?.getRemovalKey(cipherSuite) ?: mlsPublicKeysRepository.getKeyForCipherSuite(cipherSuite)

keys.flatMap { externalSenders ->
establishMLSGroup(
groupID = groupID,
members = members,
externalSenders = key,
externalSenders = externalSenders,
allowPartialMemberList = allowSkippingUsersWithoutKeyPackages
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ data class MLSPublicKeys(
val removal: Map<String, String>?
)

fun MLSPublicKeys.getRemovalKey(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {
val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
val key = this.removal?.let { removalKeys ->
removalKeys[keySignature.value]
} ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite")))
return key.decodeBase64Bytes().right()
}

interface MLSPublicKeysRepository {
suspend fun fetchKeys(): Either<CoreFailure, MLSPublicKeys>
suspend fun getKeys(): Either<CoreFailure, MLSPublicKeys>
Expand All @@ -42,7 +51,6 @@ interface MLSPublicKeysRepository {

class MLSPublicKeysRepositoryImpl(
private val mlsPublicKeyApi: MLSPublicKeyApi,
private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper()
) : MLSPublicKeysRepository {

// TODO: make it thread safe
Expand All @@ -60,14 +68,8 @@ class MLSPublicKeysRepositoryImpl(
}

override suspend fun getKeyForCipherSuite(cipherSuite: CipherSuite): Either<CoreFailure, ByteArray> {

return getKeys().flatMap { serverPublicKeys ->
val keySignature = mlsPublicKeysMapper.fromCipherSuite(cipherSuite)
val key = serverPublicKeys.removal?.let { removalKeys ->
removalKeys[keySignature.value]
} ?: return Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite $cipherSuite")))
key.decodeBase64Bytes().right()
serverPublicKeys.getRemovalKey(cipherSuite)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ class ConversationGroupRepositoryTest {
}.wasInvoked(once)

coVerify {
mlsConversationRepository.establishMLSGroup(any(), any(), eq(true))
mlsConversationRepository.establishMLSGroup(any(), any(), any(), eq(true))
}.wasInvoked(once)

coVerify {
Expand Down Expand Up @@ -465,7 +465,7 @@ class ConversationGroupRepositoryTest {
}.wasInvoked(once)

coVerify {
mlsConversationRepository.establishMLSGroup(any(), any(), eq(true))
mlsConversationRepository.establishMLSGroup(any(), any(), any(), eq(true))
}.wasInvoked(once)

coVerify {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arr
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CRYPTO_CLIENT_ID
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.E2EI_CONVERSATION_CLIENT_INFO_ENTITY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.KEY_PACKAGE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.MLS_PUBLIC_KEY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.ROTATE_BUNDLE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.TEST_FAILURE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.WIRE_IDENTITY
Expand Down Expand Up @@ -98,6 +99,7 @@ import io.mockative.matches
import io.mockative.mock
import io.mockative.once
import io.mockative.twice
import io.mockative.verify
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
Expand Down Expand Up @@ -168,7 +170,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand Down Expand Up @@ -280,6 +282,84 @@ class MLSConversationRepositoryTest {
}.wasNotInvoked()
}

@Test
fun givenPublicKeysIsNotNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryNotCalled() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
.withCommitPendingProposalsReturningNothing()
.withClaimKeyPackagesSuccessful()
.withGetMLSClientSuccessful()
.withKeyForCipherSuite()
.withAddMLSMemberSuccessful()
.withSendCommitBundleSuccessful()
.arrange()

val result =
mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = MLS_PUBLIC_KEY)
result.shouldSucceed()

coVerify {
arrangement.mlsClient.createConversation(
groupId = eq(Arrangement.RAW_GROUP_ID),
externalSenders = any())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.addMember(
groupId = eq(Arrangement.RAW_GROUP_ID),
membersKeyPackages = any())
}.wasInvoked(once)

coVerify {
arrangement.mlsMessageApi.sendCommitBundle(any<MLSMessageApi.CommitBundle>())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.commitAccepted(eq(Arrangement.RAW_GROUP_ID))
}.wasInvoked(once)

coVerify {
arrangement.mlsPublicKeysRepository.getKeyForCipherSuite(any())
}.wasNotInvoked()
}

@Test
fun givenPublicKeysIsNull_whenCallingEstablishMLSGroup_ThenGetPublicKeysRepositoryIsCalled() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetDefaultCipherSuite(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
.withCommitPendingProposalsReturningNothing()
.withClaimKeyPackagesSuccessful()
.withGetMLSClientSuccessful()
.withKeyForCipherSuite()
.withAddMLSMemberSuccessful()
.withSendCommitBundleSuccessful()
.arrange()

val result =
mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
arrangement.mlsClient.createConversation(eq(Arrangement.RAW_GROUP_ID), any())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.addMember(eq(Arrangement.RAW_GROUP_ID), any())
}.wasInvoked(once)

coVerify {
arrangement.mlsMessageApi.sendCommitBundle(any<MLSMessageApi.CommitBundle>())
}.wasInvoked(once)

coVerify {
arrangement.mlsClient.commitAccepted(eq(Arrangement.RAW_GROUP_ID))
}.wasInvoked(once)

coVerify {
arrangement.mlsPublicKeysRepository.getKeyForCipherSuite(any())
}.wasInvoked(once)
}

@Test
fun givenNewCrlDistributionPoints_whenEstablishingMLSGroup_thenCheckRevocationList() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher)
Expand Down Expand Up @@ -329,7 +409,7 @@ class MLSConversationRepositoryTest {
.withWaitUntilLiveSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand Down Expand Up @@ -357,7 +437,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR)
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldFail()

coVerify {
Expand Down Expand Up @@ -385,7 +465,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1))
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand All @@ -410,7 +490,7 @@ class MLSConversationRepositoryTest {
.withSendCommitBundleSuccessful()
.arrange()

val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList())
val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, emptyList(), publicKeys = null)
result.shouldSucceed()

coVerify {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package com.wire.kalium.network.api.authenticated.conversation

import com.wire.kalium.network.api.authenticated.serverpublickey.MLSPublicKeysDTO
import com.wire.kalium.network.api.model.ConversationAccessDTO
import com.wire.kalium.network.api.model.ConversationAccessRoleDTO
import com.wire.kalium.network.api.model.ConversationId
Expand Down Expand Up @@ -86,7 +87,10 @@ data class ConversationResponse(
val accessRole: Set<ConversationAccessRoleDTO>?,

@SerialName("receipt_mode")
val receiptMode: ReceiptMode
val receiptMode: ReceiptMode,

@SerialName("public_keys")
val publicKeys: MLSPublicKeysDTO? = null
) {

@Suppress("MagicNumber")
Expand Down Expand Up @@ -155,6 +159,14 @@ data class ConversationResponseV3(
val receiptMode: ReceiptMode,
)

@Serializable
data class ConversationResponseV6(
@SerialName("conversation")
val conversation: ConversationResponseV3,
@SerialName("public_keys")
val publicKeys: MLSPublicKeysDTO
)

@Serializable
data class ConversationMembersResponse(
@SerialName("self")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package com.wire.kalium.network.api.model

import com.wire.kalium.network.api.authenticated.conversation.ConversationResponse
import com.wire.kalium.network.api.authenticated.conversation.ConversationResponseV3
import com.wire.kalium.network.api.authenticated.conversation.ConversationResponseV6
import com.wire.kalium.network.api.authenticated.conversation.CreateConversationRequest
import com.wire.kalium.network.api.authenticated.conversation.CreateConversationRequestV3
import com.wire.kalium.network.api.authenticated.conversation.UpdateConversationAccessRequest
Expand All @@ -33,6 +34,7 @@ interface ApiModelMapper {
fun toApiV3(request: CreateConversationRequest): CreateConversationRequestV3
fun toApiV3(request: UpdateConversationAccessRequest): UpdateConversationAccessRequestV3
fun fromApiV3(response: ConversationResponseV3): ConversationResponse
fun fromApiV6(response: ConversationResponseV6): ConversationResponse
}

class ApiModelMapperImpl : ApiModelMapper {
Expand Down Expand Up @@ -76,4 +78,23 @@ class ApiModelMapperImpl : ApiModelMapper {
response.receiptMode
)

override fun fromApiV6(response: ConversationResponseV6): ConversationResponse =
ConversationResponse(
creator = response.conversation.creator,
members = response.conversation.members,
name = response.conversation.name,
id = response.conversation.id,
groupId = response.conversation.groupId,
epoch = response.conversation.epoch,
type = response.conversation.type,
messageTimer = response.conversation.messageTimer,
teamId = response.conversation.teamId,
protocol = response.conversation.protocol,
lastEventTime = response.conversation.lastEventTime,
mlsCipherSuiteTag = response.conversation.mlsCipherSuiteTag,
access = response.conversation.access,
accessRole = response.conversation.accessRole,
receiptMode = response.conversation.receiptMode,
publicKeys = response.publicKeys
)
}
Loading
Loading