Skip to content

Commit

Permalink
feat(e2ei): separate e2ei enrollment flow (WPB-5229) (#2223)
Browse files Browse the repository at this point in the history
* chore: separate e2ei enrollment steps

* chore: refactor and added new error types for e2eiFailure

* fix detekt

* chore: get acme url from team settings

* fix detekt

* introduce new inLine function to Either

* fix detekt

* fix tests

* fix tests

* cover comments

* fix detekt
  • Loading branch information
mchenani authored Nov 17, 2023
1 parent 21ae14c commit b80ce0f
Show file tree
Hide file tree
Showing 17 changed files with 275 additions and 344 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ class E2EIClientImpl(
)

private fun toAcmeChallenge(value: com.wire.crypto.AcmeChallenge) = AcmeChallenge(
value.delegate.toUByteArray().asByteArray(), value.url
value.delegate.toUByteArray().asByteArray(),
value.url,
value.target
)

fun toNewAcmeAuthz(value: com.wire.crypto.NewAcmeAuthz) = NewAcmeAuthz(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ data class NewAcmeOrder(

data class AcmeChallenge(
var delegate: JsonRawData,
var url: String
var url: String,
var target: String
)

data class NewAcmeAuthz(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ data class CryptoQualifiedClientId(
}

data class WireIdentity(
var clientId: String,
var handle: String,
var displayName: String,
var domain: String,
var certificate: String
val clientId: String,
val handle: String,
val displayName: String,
val domain: String,
val certificate: String
)

@Suppress("MagicNumber")
Expand Down
13 changes: 10 additions & 3 deletions logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package com.wire.kalium.logic

import com.wire.kalium.cryptography.exceptions.ProteusException
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.usecase.E2EIEnrollmentResult
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.network.exceptions.APINotSupported
import com.wire.kalium.network.exceptions.KaliumException
Expand Down Expand Up @@ -189,9 +190,15 @@ interface MLSFailure : CoreFailure {
}
}

class E2EIFailure(internal val exception: Exception) : CoreFailure {
interface E2EIFailure : CoreFailure {
data class FailedInitialization(val step: E2EIEnrollmentResult.E2EIStep) : E2EIFailure
data class FailedOAuth(val reason: String) : E2EIFailure
data class FailedFinalization(val step: E2EIEnrollmentResult.E2EIStep) : E2EIFailure
data object FailedRotationAndMigration : E2EIFailure

val rootCause: Throwable get() = exception
class Generic(internal val exception: Exception) : E2EIFailure {
val rootCause: Throwable get() = exception
}
}

class ProteusFailure(internal val proteusException: ProteusException) : CoreFailure {
Expand Down Expand Up @@ -299,7 +306,7 @@ internal inline fun <T> wrapE2EIRequest(e2eiRequest: () -> T): Either<E2EIFailur
Either.Right(e2eiRequest())
} catch (e: Exception) {
kaliumLogger.e(e.stackTraceToString())
Either.Left(E2EIFailure(e))
Either.Left(E2EIFailure.Generic(e))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ internal class EI2EIClientProviderImpl(
private suspend fun getSelfUserInfo(): Either<CoreFailure, Pair<String, String>> {
val selfUser = userRepository.getSelfUser() ?: return Either.Left(CoreFailure.Unknown(NullPointerException()))
return if (selfUser.name == null || selfUser.handle == null)
Either.Left(E2EIFailure(IllegalArgumentException(ERROR_NAME_AND_HANDLE_MUST_NOT_BE_NULL)))
Either.Left(E2EIFailure.Generic(IllegalArgumentException(ERROR_NAME_AND_HANDLE_MUST_NOT_BE_NULL)))
else Either.Right(Pair(selfUser.name, selfUser.handle))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ data class DecryptedMessageBundle(
val identity: E2EIdentity?
)

data class E2EIdentity(var clientId: String, var handle: String, var displayName: String, var domain: String)
data class E2EIdentity(val clientId: String, val handle: String, val displayName: String, val domain: String)

@Suppress("TooManyFunctions", "LongParameterList")
interface MLSConversationRepository {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.wire.kalium.cryptography.AcmeDirectory
import com.wire.kalium.cryptography.NewAcmeAuthz
import com.wire.kalium.cryptography.NewAcmeOrder
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.client.E2EIClientProvider
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
Expand Down Expand Up @@ -63,21 +64,25 @@ interface E2EIRepository {
suspend fun rotateKeysAndMigrateConversations(certificateChain: String): Either<CoreFailure, Unit>
}

@Suppress("LongParameterList")
class E2EIRepositoryImpl(
private val e2EIApi: E2EIApi,
private val acmeApi: ACMEApi,
private val e2EIClientProvider: E2EIClientProvider,
private val mlsClientProvider: MLSClientProvider,
private val currentClientIdProvider: CurrentClientIdProvider,
private val mlsConversationRepository: MLSConversationRepository
private val mlsConversationRepository: MLSConversationRepository,
private val userConfigRepository: UserConfigRepository
) : E2EIRepository {

override suspend fun loadACMEDirectories(): Either<CoreFailure, AcmeDirectory> = wrapApiRequest {
acmeApi.getACMEDirectories()
}.flatMap { directories ->
e2EIClientProvider.getE2EIClient().flatMap { e2eiClient ->
wrapE2EIRequest {
e2eiClient.directoryResponse(Json.encodeToString(directories).encodeToByteArray())
override suspend fun loadACMEDirectories(): Either<CoreFailure, AcmeDirectory> = userConfigRepository.getE2EISettings().flatMap {
wrapApiRequest {
acmeApi.getACMEDirectories(TEMP_ACME_DISCOVER_URL)
}.flatMap { directories ->
e2EIClientProvider.getE2EIClient().flatMap { e2eiClient ->
wrapE2EIRequest {
e2eiClient.directoryResponse(Json.encodeToString(directories).encodeToByteArray())
}
}
}
}
Expand Down Expand Up @@ -200,4 +205,8 @@ class E2EIRepositoryImpl(
}
}

companion object {
// todo: remove after testing e2ei
const val TEMP_ACME_DISCOVER_URL = "https://acme.elna.wire.link/acme/defaultteams"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,8 @@ class UserSessionScope internal constructor(
e2EIClientProvider,
mlsClientProvider,
clientIdProvider,
mlsConversationRepository
mlsConversationRepository,
userConfigRepository
)

private val e2EIClientProvider: E2EIClientProvider by lazy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,106 +17,135 @@
*/
package com.wire.kalium.logic.feature.e2ei.usecase

import com.wire.kalium.cryptography.NewAcmeAuthz
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.E2EIFailure
import com.wire.kalium.logic.data.e2ei.E2EIRepository
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrFail
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.kaliumLogger

/**
* Issue an E2EI certificate and re-initiate the MLSClient
*/
interface EnrollE2EIUseCase {
suspend operator fun invoke(idToken: String): Either<CoreFailure, E2EIEnrollmentResult>
suspend fun initialEnrollment(): Either<CoreFailure, E2EIEnrollmentResult>
suspend fun finalizeEnrollment(
idToken: String,
initializationResult: E2EIEnrollmentResult.Initialized
): Either<CoreFailure, E2EIEnrollmentResult>
}

@Suppress("ReturnCount")
class EnrollE2EIUseCaseImpl internal constructor(
private val e2EIRepository: E2EIRepository,
) : EnrollE2EIUseCase {
/**
* Operation to issue an E2EI certificate and re-initiate MLS Client
* Operation to initial E2EI certificate enrollment
*
* @param idToken id token generated by the IdP
* @return [Either] [CoreFailure] or [E2EIEnrollmentResult]
*/
override suspend fun invoke(idToken: String): Either<CoreFailure, E2EIEnrollmentResult> {
override suspend fun initialEnrollment(): Either<CoreFailure, E2EIEnrollmentResult> {
kaliumLogger.i("start E2EI Enrollment Initialization")

val acmeDirectories = e2EIRepository.loadACMEDirectories().fold({
val acmeDirectories = e2EIRepository.loadACMEDirectories().getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.AcmeDirectories, it).toEitherLeft()
}, { it })
}

var prevNonce = e2EIRepository.getACMENonce(acmeDirectories.newNonce).fold({
var prevNonce = e2EIRepository.getACMENonce(acmeDirectories.newNonce).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.AcmeNonce, it).toEitherLeft()
}, { it })
}

prevNonce = e2EIRepository.createNewAccount(prevNonce, acmeDirectories.newAccount).fold({
prevNonce = e2EIRepository.createNewAccount(prevNonce, acmeDirectories.newAccount).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.AcmeNewAccount, it).toEitherLeft()
}, { it })
}

val newOrderResponse = e2EIRepository.createNewOrder(prevNonce, acmeDirectories.newOrder).fold({
val newOrderResponse = e2EIRepository.createNewOrder(prevNonce, acmeDirectories.newOrder).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.AcmeNewOrder, it).toEitherLeft()
}, { it })
}

prevNonce = newOrderResponse.second

val authzResponse = e2EIRepository.createAuthz(prevNonce, newOrderResponse.first.authorizations[0]).fold({
val authzResponse = e2EIRepository.createAuthz(prevNonce, newOrderResponse.first.authorizations[0]).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.AcmeNewAuthz, it).toEitherLeft()
}, { it })
}

val initializationResult = E2EIEnrollmentResult.Initialized(
authzResponse.first.wireOidcChallenge!!.target, authzResponse.first, authzResponse.second, newOrderResponse.third
)

kaliumLogger.i("E2EI Enrollment Initialization Result: $initializationResult")

return Either.Right(initializationResult)
}

/**
* Operation to finalize E2EI certificate enrollment
*
* @param idToken id token generated by the IdP
* @param initializationResult e2ei initialization result
*
* @return [Either] [CoreFailure] or [E2EIEnrollmentResult]
*/
override suspend fun finalizeEnrollment(
idToken: String,
initializationResult: E2EIEnrollmentResult.Initialized
): Either<CoreFailure, E2EIEnrollmentResult> {

prevNonce = authzResponse.second
var prevNonce = initializationResult.lastNonce
val authz = initializationResult.authz
val orderLocation = initializationResult.orderLocation

val wireNonce = e2EIRepository.getWireNonce().fold({
val wireNonce = e2EIRepository.getWireNonce().getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.WireNonce, it).toEitherLeft()
}, { it })
}

val dpopToken = e2EIRepository.getDPoPToken(wireNonce).fold({
val dpopToken = e2EIRepository.getDPoPToken(wireNonce).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.DPoPToken, it).toEitherLeft()
}, { it })
}

val wireAccessToken = e2EIRepository.getWireAccessToken(dpopToken).fold({
val wireAccessToken = e2EIRepository.getWireAccessToken(dpopToken).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.WireAccessToken, it).toEitherLeft()
}, { it })
}

val dpopChallengeResponse = e2EIRepository.validateDPoPChallenge(
wireAccessToken.token, prevNonce, authzResponse.first.wireDpopChallenge!!
).fold({
wireAccessToken.token, prevNonce, authz.wireDpopChallenge!!
).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.DPoPChallenge, it).toEitherLeft()
}, { it })
}

prevNonce = dpopChallengeResponse.nonce

val oidcChallengeResponse = e2EIRepository.validateOIDCChallenge(
idToken, prevNonce, authzResponse.first.wireOidcChallenge!!
).fold({
idToken, prevNonce, authz.wireOidcChallenge!!
).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.OIDCChallenge, it).toEitherLeft()
}, { it })
}

prevNonce = oidcChallengeResponse.nonce

val orderResponse = e2EIRepository.checkOrderRequest(newOrderResponse.third, prevNonce).fold({
val orderResponse = e2EIRepository.checkOrderRequest(orderLocation, prevNonce).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.CheckOrderRequest, it).toEitherLeft()
}, { it })
}

prevNonce = orderResponse.first.nonce

// TODO(fix): replace with orderResponse.third
val finalizeResponse = e2EIRepository.finalize(orderResponse.second, prevNonce).fold({
val finalizeResponse = e2EIRepository.finalize(orderResponse.second, prevNonce).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.FinalizeRequest, it).toEitherLeft()
}, { it })
}

prevNonce = finalizeResponse.first.nonce

val certificateRequest = e2EIRepository.certificateRequest(finalizeResponse.second, prevNonce).fold({
val certificateRequest = e2EIRepository.certificateRequest(finalizeResponse.second, prevNonce).getOrFail {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.Certificate, it).toEitherLeft()
}, { it })
}

e2EIRepository.rotateKeysAndMigrateConversations(certificateRequest.response.decodeToString()).onFailure {
return E2EIEnrollmentResult.Failed(E2EIEnrollmentResult.E2EIStep.ConversationMigration, it).toEitherLeft()
}

return Either.Right(E2EIEnrollmentResult.Success(certificateRequest.response.decodeToString()))
return Either.Right(E2EIEnrollmentResult.Finalized(certificateRequest.response.decodeToString()))
}

}
Expand All @@ -128,6 +157,7 @@ sealed interface E2EIEnrollmentResult {
AcmeNewAccount,
AcmeNewOrder,
AcmeNewAuthz,
OAuth,
WireNonce,
DPoPToken,
WireAccessToken,
Expand All @@ -139,13 +169,15 @@ sealed interface E2EIEnrollmentResult {
Certificate
}

class Success(val certificate: String) : E2EIEnrollmentResult
class Initialized(val target: String, val authz: NewAcmeAuthz, val lastNonce: String, val orderLocation: String) : E2EIEnrollmentResult

class Finalized(val certificate: String) : E2EIEnrollmentResult

data class Failed(val step: E2EIStep, val failure: CoreFailure) : E2EIEnrollmentResult {
override fun toString(): String {
return "E2EI enrollment failed at $step: with $failure"
}

fun toEitherLeft() = Either.Left(E2EIFailure(Exception(this.toString())))
fun toEitherLeft() = Either.Left(E2EIFailure.Generic(Exception(this.toString())))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,16 @@ inline fun <T, L, R> Either<L, R>.flatMap(fn: (R) -> Either<L, T>): Either<L, T>
is Right -> fn(value)
}

/**
* Right-biased getOrFail() FP convention which means that Right is assumed to be the default case
* to operate on and return the result. If it is Left, operations like map, flatMap, ... return the Left value unchanged.
*/
inline fun <L, R> Either<L, R>.getOrFail(fn: (failure: L) -> R): R =
when (this) {
is Left -> fn(value)
is Right -> value
}

/**
* Left-biased flatMap() FP convention which means that Left is assumed to be the default case
* to operate on. If it is Right, operations like map, flatMap, ... return the Right value unchanged.
Expand Down
Loading

0 comments on commit b80ce0f

Please sign in to comment.