Skip to content

Commit

Permalink
fix: sub conversation epoch verification [WPB-15647] (#3244)
Browse files Browse the repository at this point in the history
* fix: sub conversation epoch verification

* test fix

* detekt fix

* test fix
  • Loading branch information
Garzas authored and github-actions[bot] committed Jan 24, 2025
1 parent 544ccea commit cd2f9e0
Show file tree
Hide file tree
Showing 15 changed files with 441 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Wire
* Copyright (C) 2025 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.data.conversation

import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.SubconversationId

data class SubConversation(
val id: SubconversationId,
val parentId: ConversationId,
val groupId: GroupID,
val epoch: ULong,
val epochTimestamp: String?,
val mlsCipherSuiteTag: Int?,

val members: List<SubconversationMember>,
)

data class SubconversationMember(
val clientId: String,
val userId: String,
val domain: String
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.logic.wrapMLSRequest
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi
import com.wire.kalium.network.api.authenticated.conversation.SubconversationMember
import com.wire.kalium.network.api.authenticated.conversation.SubconversationMemberDTO

/**
* Leave a sub-conversation you've previously joined
Expand Down Expand Up @@ -71,7 +71,7 @@ internal class LeaveSubconversationUseCaseImpl(
subconversationRepository.getSubconversationInfo(conversationId, subconversationId)?.let {
Either.Right(it)
} ?: wrapApiRequest { conversationApi.fetchSubconversationDetails(conversationId.toApi(), subconversationId.toApi()) }.flatMap {
if (it.members.contains(SubconversationMember(selfClientId.value, selfUserId.value, selfUserId.domain))) {
if (it.members.contains(SubconversationMemberDTO(selfClientId.value, selfUserId.value, selfUserId.domain))) {
Either.Right(GroupID(it.groupId))
} else {
Either.Right(null)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Wire
* Copyright (C) 2025 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.data.conversation

import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.network.api.authenticated.conversation.SubconversationMemberDTO
import com.wire.kalium.network.api.authenticated.conversation.SubconversationResponse

fun SubconversationResponse.toModel(): SubConversation {
return SubConversation(
id = id.toModel(),
parentId = parentId.toModel(),
groupId = GroupID(groupId),
epoch = epoch,
epochTimestamp = epochTimestamp,
mlsCipherSuiteTag = mlsCipherSuiteTag,
members = members.map { it.toModel() }
)
}

fun SubconversationMemberDTO.toModel(): SubconversationMember {
return SubconversationMember(
clientId = clientId,
userId = userId,
domain = domain
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.SubconversationId
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.network.api.authenticated.conversation.SubconversationDeleteRequest
import com.wire.kalium.network.api.authenticated.conversation.SubconversationResponse
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi
import io.ktor.util.collections.ConcurrentMap
import kotlinx.coroutines.sync.Mutex
Expand Down Expand Up @@ -68,7 +68,7 @@ interface SubconversationRepository {
suspend fun fetchRemoteSubConversationDetails(
conversationId: ConversationId,
subConversationId: SubconversationId
): Either<NetworkFailure, SubconversationResponse>
): Either<NetworkFailure, SubConversation>
}

class SubconversationRepositoryImpl(
Expand Down Expand Up @@ -139,14 +139,15 @@ class SubconversationRepositoryImpl(
)
}

// TODO: Replace SubconversationResponse with a domain model
override suspend fun fetchRemoteSubConversationDetails(
conversationId: ConversationId,
subConversationId: SubconversationId
): Either<NetworkFailure, SubconversationResponse> = wrapApiRequest {
): Either<NetworkFailure, SubConversation> = wrapApiRequest {
conversationApi.fetchSubconversationDetails(
conversationId.toApi(),
subConversationId.toApi()
)
}.map {
it.toModel()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.wire.kalium.cryptography.MLSGroupId
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.network.api.model.UserAssetDTO
import com.wire.kalium.persistence.dao.QualifiedIDEntity
import com.wire.kalium.network.api.model.SubconversationId as NetworkSubConversationId

internal typealias NetworkQualifiedId = com.wire.kalium.network.api.model.QualifiedID
internal typealias PersistenceQualifiedId = QualifiedIDEntity
Expand Down Expand Up @@ -53,3 +54,5 @@ internal fun SubconversationId.toApi(): String = value
internal fun GroupID.toCrypto(): MLSGroupId = value

internal fun CryptoQualifiedClientId.toModel() = QualifiedClientID(ClientId(value), userId.toModel())

internal fun NetworkSubConversationId.toModel() = SubconversationId(this)
Original file line number Diff line number Diff line change
Expand Up @@ -1399,7 +1399,8 @@ class UserSessionScope internal constructor(
systemMessageInserter = systemMessageInserter,
conversationRepository = conversationRepository,
mlsConversationRepository = mlsConversationRepository,
joinExistingMLSConversation = joinExistingMLSConversationUseCase
joinExistingMLSConversation = joinExistingMLSConversationUseCase,
subconversationRepository = subconversationRepository
)

private val newMessageHandler: NewMessageEventHandler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.JoinExistingMLSConversationUseCase
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.conversation.SubconversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.SubconversationId
import com.wire.kalium.logic.data.message.SystemMessageInserter
import com.wire.kalium.logic.data.conversation.JoinExistingMLSConversationUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.map
Expand All @@ -34,19 +36,36 @@ import kotlinx.datetime.Clock
import kotlinx.datetime.Instant

interface StaleEpochVerifier {
suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant? = null): Either<CoreFailure, Unit>
suspend fun verifyEpoch(
conversationId: ConversationId,
subConversationId: SubconversationId? = null,
timestamp: Instant? = null
): Either<CoreFailure, Unit>
}

internal class StaleEpochVerifierImpl(
private val systemMessageInserter: SystemMessageInserter,
private val conversationRepository: ConversationRepository,
private val subconversationRepository: SubconversationRepository,
private val mlsConversationRepository: MLSConversationRepository,
private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase
) : StaleEpochVerifier {

private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.MESSAGES) }
override suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant?): Either<CoreFailure, Unit> {
logger.i("Verifying stale epoch")
override suspend fun verifyEpoch(
conversationId: ConversationId,
subConversationId: SubconversationId?,
timestamp: Instant?
): Either<CoreFailure, Unit> {
return if (subConversationId != null) {
verifySubConversationEpoch(conversationId, subConversationId)
} else {
verifyConversationEpoch(conversationId)
}
}

private suspend fun verifyConversationEpoch(conversationId: ConversationId): Either<CoreFailure, Unit> {
logger.i("Verifying stale epoch for conversation ${conversationId.toLogString()}")
return getUpdatedConversationProtocolInfo(conversationId).flatMap { protocol ->
if (protocol is Conversation.ProtocolInfo.MLS) {
Either.Right(protocol)
Expand Down Expand Up @@ -74,6 +93,32 @@ internal class StaleEpochVerifierImpl(
}
}

private suspend fun verifySubConversationEpoch(
conversationId: ConversationId,
subConversationId: SubconversationId
): Either<CoreFailure, Unit> {
logger.i("Verifying stale epoch for subconversation ${subConversationId.toLogString()}")
return subconversationRepository.fetchRemoteSubConversationDetails(conversationId, subConversationId)
.flatMap { subConversationDetails ->
mlsConversationRepository.isGroupOutOfSync(subConversationDetails.groupId, subConversationDetails.epoch)
.map { epochIsStale ->
epochIsStale
}
.flatMap { hasMissedCommits ->
if (hasMissedCommits) {
logger.w("Epoch stale due to missing commits, joining by external commit")
subconversationRepository.fetchRemoteSubConversationGroupInfo(conversationId, subConversationId)
.flatMap { groupInfo ->
mlsConversationRepository.joinGroupByExternalCommit(subConversationDetails.groupId, groupInfo)
}
} else {
logger.i("Epoch stale due to unprocessed events")
Either.Right(Unit)
}
}
}
}

private suspend fun getUpdatedConversationProtocolInfo(conversationId: ConversationId): Either<CoreFailure, Conversation.ProtocolInfo> {
return conversationRepository.fetchConversation(conversationId).flatMap {
conversationRepository.getConversationProtocolInfo(conversationId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ internal class NewMessageEventHandlerImpl(
eventLogger.logFailure(it, "protocol" to "MLS", "mlsOutcome" to "OUT_OF_SYNC")
staleEpochVerifier.verifyEpoch(
event.conversationId,
event.subconversationId,
event.messageInstant
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ import com.wire.kalium.logic.util.thenReturnSequentially
import com.wire.kalium.network.api.base.authenticated.message.MLSMessageApi
import com.wire.kalium.network.api.model.ErrorResponse
import com.wire.kalium.network.exceptions.KaliumException
import com.wire.kalium.util.DateTimeUtil
import com.wire.kalium.util.time.UNIX_FIRST_DATE
import io.ktor.utils.io.core.toByteArray
import io.mockative.Mock
Expand Down Expand Up @@ -323,7 +322,7 @@ class MessageSenderTest {
// then
result.shouldSucceed()
coVerify {
arrangement.staleEpochVerifier.verifyEpoch(eq(Arrangement.TEST_CONVERSATION_ID), any())
arrangement.staleEpochVerifier.verifyEpoch(eq(Arrangement.TEST_CONVERSATION_ID), any(), any())
}.wasInvoked(once)
}
}
Expand Down
Loading

0 comments on commit cd2f9e0

Please sign in to comment.