Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: sub conversation epoch verification [WPB-15647] #3244

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Wire
* Copyright (C) 2025 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.data.conversation

import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.SubconversationId

data class SubConversation(
val id: SubconversationId,
val parentId: ConversationId,
val groupId: GroupID,
val epoch: ULong,
val epochTimestamp: String?,
val mlsCipherSuiteTag: Int?,

val members: List<SubconversationMember>,
)

data class SubconversationMember(
val clientId: String,
val userId: String,
val domain: String
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.logic.wrapMLSRequest
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi
import com.wire.kalium.network.api.authenticated.conversation.SubconversationMember
import com.wire.kalium.network.api.authenticated.conversation.SubconversationMemberDTO

/**
* Leave a sub-conversation you've previously joined
Expand Down Expand Up @@ -71,7 +71,7 @@ internal class LeaveSubconversationUseCaseImpl(
subconversationRepository.getSubconversationInfo(conversationId, subconversationId)?.let {
Either.Right(it)
} ?: wrapApiRequest { conversationApi.fetchSubconversationDetails(conversationId.toApi(), subconversationId.toApi()) }.flatMap {
if (it.members.contains(SubconversationMember(selfClientId.value, selfUserId.value, selfUserId.domain))) {
if (it.members.contains(SubconversationMemberDTO(selfClientId.value, selfUserId.value, selfUserId.domain))) {
Either.Right(GroupID(it.groupId))
} else {
Either.Right(null)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Wire
* Copyright (C) 2025 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.data.conversation

import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.network.api.authenticated.conversation.SubconversationMemberDTO
import com.wire.kalium.network.api.authenticated.conversation.SubconversationResponse

fun SubconversationResponse.toModel(): SubConversation {
return SubConversation(
id = id.toModel(),
parentId = parentId.toModel(),
groupId = GroupID(groupId),
epoch = epoch,
epochTimestamp = epochTimestamp,
mlsCipherSuiteTag = mlsCipherSuiteTag,
members = members.map { it.toModel() }
)
}

fun SubconversationMemberDTO.toModel(): SubconversationMember {
return SubconversationMember(
clientId = clientId,
userId = userId,
domain = domain
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.SubconversationId
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.network.api.authenticated.conversation.SubconversationDeleteRequest
import com.wire.kalium.network.api.authenticated.conversation.SubconversationResponse
import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi
import io.ktor.util.collections.ConcurrentMap
import kotlinx.coroutines.sync.Mutex
Expand Down Expand Up @@ -68,7 +68,7 @@ interface SubconversationRepository {
suspend fun fetchRemoteSubConversationDetails(
conversationId: ConversationId,
subConversationId: SubconversationId
): Either<NetworkFailure, SubconversationResponse>
): Either<NetworkFailure, SubConversation>
}

class SubconversationRepositoryImpl(
Expand Down Expand Up @@ -139,14 +139,15 @@ class SubconversationRepositoryImpl(
)
}

// TODO: Replace SubconversationResponse with a domain model
override suspend fun fetchRemoteSubConversationDetails(
conversationId: ConversationId,
subConversationId: SubconversationId
): Either<NetworkFailure, SubconversationResponse> = wrapApiRequest {
): Either<NetworkFailure, SubConversation> = wrapApiRequest {
conversationApi.fetchSubconversationDetails(
conversationId.toApi(),
subConversationId.toApi()
)
}.map {
it.toModel()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.wire.kalium.cryptography.MLSGroupId
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.network.api.model.UserAssetDTO
import com.wire.kalium.persistence.dao.QualifiedIDEntity
import com.wire.kalium.network.api.model.SubconversationId as NetworkSubConversationId

internal typealias NetworkQualifiedId = com.wire.kalium.network.api.model.QualifiedID
internal typealias PersistenceQualifiedId = QualifiedIDEntity
Expand Down Expand Up @@ -53,3 +54,5 @@ internal fun SubconversationId.toApi(): String = value
internal fun GroupID.toCrypto(): MLSGroupId = value

internal fun CryptoQualifiedClientId.toModel() = QualifiedClientID(ClientId(value), userId.toModel())

internal fun NetworkSubConversationId.toModel() = SubconversationId(this)
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,8 @@ class UserSessionScope internal constructor(
systemMessageInserter = systemMessageInserter,
conversationRepository = conversationRepository,
mlsConversationRepository = mlsConversationRepository,
joinExistingMLSConversation = joinExistingMLSConversationUseCase
joinExistingMLSConversation = joinExistingMLSConversationUseCase,
subconversationRepository = subconversationRepository
)

private val newMessageHandler: NewMessageEventHandler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.JoinExistingMLSConversationUseCase
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.conversation.SubconversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.SubconversationId
import com.wire.kalium.logic.data.message.SystemMessageInserter
import com.wire.kalium.logic.data.conversation.JoinExistingMLSConversationUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.map
Expand All @@ -34,19 +36,36 @@ import kotlinx.datetime.Clock
import kotlinx.datetime.Instant

interface StaleEpochVerifier {
suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant? = null): Either<CoreFailure, Unit>
suspend fun verifyEpoch(
conversationId: ConversationId,
subConversationId: SubconversationId? = null,
timestamp: Instant? = null
): Either<CoreFailure, Unit>
}

internal class StaleEpochVerifierImpl(
private val systemMessageInserter: SystemMessageInserter,
private val conversationRepository: ConversationRepository,
private val subconversationRepository: SubconversationRepository,
private val mlsConversationRepository: MLSConversationRepository,
private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase
) : StaleEpochVerifier {

private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.MESSAGES) }
override suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant?): Either<CoreFailure, Unit> {
logger.i("Verifying stale epoch")
override suspend fun verifyEpoch(
conversationId: ConversationId,
subConversationId: SubconversationId?,
timestamp: Instant?
): Either<CoreFailure, Unit> {
return if (subConversationId != null) {
verifySubConversationEpoch(conversationId, subConversationId)
} else {
verifyConversationEpoch(conversationId)
}
}

private suspend fun verifyConversationEpoch(conversationId: ConversationId): Either<CoreFailure, Unit> {
logger.i("Verifying stale epoch for conversation ${conversationId.toLogString()}")
return getUpdatedConversationProtocolInfo(conversationId).flatMap { protocol ->
if (protocol is Conversation.ProtocolInfo.MLS) {
Either.Right(protocol)
Expand Down Expand Up @@ -74,6 +93,32 @@ internal class StaleEpochVerifierImpl(
}
}

private suspend fun verifySubConversationEpoch(
conversationId: ConversationId,
subConversationId: SubconversationId
): Either<CoreFailure, Unit> {
logger.i("Verifying stale epoch for subconversation ${subConversationId.toLogString()}")
return subconversationRepository.fetchRemoteSubConversationDetails(conversationId, subConversationId)
.flatMap { subConversationDetails ->
mlsConversationRepository.isGroupOutOfSync(subConversationDetails.groupId, subConversationDetails.epoch)
.map { epochIsStale ->
epochIsStale
}
.flatMap { hasMissedCommits ->
if (hasMissedCommits) {
logger.w("Epoch stale due to missing commits, joining by external commit")
subconversationRepository.fetchRemoteSubConversationGroupInfo(conversationId, subConversationId)
.flatMap { groupInfo ->
mlsConversationRepository.joinGroupByExternalCommit(subConversationDetails.groupId, groupInfo)
}
} else {
logger.i("Epoch stale due to unprocessed events")
Either.Right(Unit)
}
}
}
}

private suspend fun getUpdatedConversationProtocolInfo(conversationId: ConversationId): Either<CoreFailure, Conversation.ProtocolInfo> {
return conversationRepository.fetchConversation(conversationId).flatMap {
conversationRepository.getConversationProtocolInfo(conversationId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ internal class NewMessageEventHandlerImpl(
eventLogger.logFailure(it, "protocol" to "MLS", "mlsOutcome" to "OUT_OF_SYNC")
staleEpochVerifier.verifyEpoch(
event.conversationId,
event.subconversationId,
event.messageInstant
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ import com.wire.kalium.logic.util.thenReturnSequentially
import com.wire.kalium.network.api.base.authenticated.message.MLSMessageApi
import com.wire.kalium.network.api.model.ErrorResponse
import com.wire.kalium.network.exceptions.KaliumException
import com.wire.kalium.util.DateTimeUtil
import com.wire.kalium.util.time.UNIX_FIRST_DATE
import io.ktor.utils.io.core.toByteArray
import io.mockative.Mock
Expand Down Expand Up @@ -323,7 +322,7 @@ class MessageSenderTest {
// then
result.shouldSucceed()
coVerify {
arrangement.staleEpochVerifier.verifyEpoch(eq(Arrangement.TEST_CONVERSATION_ID), any())
arrangement.staleEpochVerifier.verifyEpoch(eq(Arrangement.TEST_CONVERSATION_ID), any(), any())
}.wasInvoked(once)
}
}
Expand Down
Loading
Loading