Skip to content

Commit

Permalink
Login response handler validation rework (#54)
Browse files Browse the repository at this point in the history
* refactor: perform client type validations during decoding, not encoding

* refactor: remove channel activity check in writeSuccessfulResponse

* refactor: separate validateNewConnection from writeSuccessfulResponse
  • Loading branch information
Z-Kris authored Jan 25, 2025
1 parent 691a839 commit e947e32
Show file tree
Hide file tree
Showing 64 changed files with 648 additions and 663 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import net.rsprot.protocol.api.game.GameMessageHandler
import net.rsprot.protocol.api.logging.networkLog
import net.rsprot.protocol.common.client.OldSchoolClientType
import net.rsprot.protocol.loginprot.incoming.util.LoginBlock
import net.rsprot.protocol.loginprot.incoming.util.LoginClientType
import net.rsprot.protocol.loginprot.outgoing.LoginResponse
import java.util.concurrent.TimeUnit

Expand All @@ -34,6 +33,26 @@ public class GameLoginResponseHandler<R>(
public val networkService: NetworkService<R>,
public val ctx: ChannelHandlerContext,
) {
/**
* Validates the new connection by ensuring the connected user hasn't reached an IP limitation due
* to too many connections from the same IP.
* This function does not write any response to the client. It simply returns whether a new connection
* is allowed to take place. The server is responsible for writing the [LoginResponse.TooManyAttempts]
* response back to the client via [writeFailedResponse], should they wish to do so.
*/
public fun validateNewConnection(): Boolean {
val address = ctx.inetAddress()
val count =
networkService
.iNetAddressHandlers
.gameInetAddressTracker
.getCount(address)
return networkService
.iNetAddressHandlers
.inetAddressValidator
.acceptGameConnection(address, count)
}

/**
* Writes a successful login response to the client.
* @param response the login response to write
Expand All @@ -43,46 +62,13 @@ public class GameLoginResponseHandler<R>(
public fun writeSuccessfulResponse(
response: LoginResponse.Ok,
loginBlock: LoginBlock<*>,
): Session<R>? {
): Session<R> {
// Ensure it isn't null - our decoder pre-validates it long before hitting this function,
// so this exception should never be hit.
val oldSchoolClientType =
getOldSchoolClientType(loginBlock)
if (oldSchoolClientType == null || !networkService.isSupported(oldSchoolClientType)) {
networkLog(logger) {
"Unsupported client type received from channel " +
"'${ctx.channel()}': $oldSchoolClientType, login block: $loginBlock"
}
ctx
.writeAndFlush(LoginResponse.InvalidLoginPacket)
.addListener(ChannelFutureListener.CLOSE)
return null
}
if (!ctx.channel().isActive) {
networkLog(logger) {
"Channel '${ctx.channel()}' has gone inactive; login block: $loginBlock"
checkNotNull(loginBlock.clientType.toOldSchoolClientType()) {
"Login client type cannot be null"
}
return null
}
val address = ctx.inetAddress()
val count =
networkService
.iNetAddressHandlers
.gameInetAddressTracker
.getCount(address)
val accepted =
networkService
.iNetAddressHandlers
.inetAddressValidator
.acceptGameConnection(address, count)
// Secondary validation just before we allow the server to log the user in
if (!accepted) {
networkLog(logger) {
"INetAddressValidator rejected game login for channel ${ctx.channel()}"
}
ctx
.writeAndFlush(LoginResponse.TooManyAttempts)
.addListener(ChannelFutureListener.CLOSE)
return null
}
val cipher = createStreamCipherPair(loginBlock)

if (networkService.betaWorld) {
Expand Down Expand Up @@ -112,23 +98,13 @@ public class GameLoginResponseHandler<R>(
public fun writeSuccessfulResponse(
response: LoginResponse.ReconnectOk,
loginBlock: LoginBlock<*>,
): Session<R>? {
): Session<R> {
// Ensure it isn't null - our decoder pre-validates it long before hitting this function,
// so this exception should never be hit.
val oldSchoolClientType =
getOldSchoolClientType(loginBlock)
if (oldSchoolClientType == null || !networkService.isSupported(oldSchoolClientType)) {
networkLog(logger) {
"Unsupported client type received from channel " +
"'${ctx.channel()}': $oldSchoolClientType, login block: $loginBlock"
checkNotNull(loginBlock.clientType.toOldSchoolClientType()) {
"Login client type cannot be null"
}
ctx.writeAndFlush(LoginResponse.InvalidLoginPacket)
return null
}
if (!ctx.channel().isActive) {
networkLog(logger) {
"Channel '${ctx.channel()}' has gone inactive; login block: $loginBlock"
}
return null
}
val (encodingCipher, decodingCipher) = createStreamCipherPair(loginBlock)

// Unlike in the above case, we kind of have to assume it was successful
Expand Down Expand Up @@ -157,20 +133,6 @@ public class GameLoginResponseHandler<R>(
return StreamCipherPair(encodingCipher, decodingCipher)
}

private fun getOldSchoolClientType(loginBlock: LoginBlock<*>): OldSchoolClientType? {
val oldSchoolClientType =
when (loginBlock.clientType) {
LoginClientType.DESKTOP -> OldSchoolClientType.DESKTOP
LoginClientType.ENHANCED_WINDOWS -> OldSchoolClientType.DESKTOP
LoginClientType.ENHANCED_LINUX -> OldSchoolClientType.DESKTOP
LoginClientType.ENHANCED_MAC -> OldSchoolClientType.DESKTOP
LoginClientType.ENHANCED_ANDROID -> OldSchoolClientType.ANDROID
LoginClientType.ENHANCED_IOS -> OldSchoolClientType.IOS
else -> null
}
return oldSchoolClientType
}

private fun createSession(
loginBlock: LoginBlock<*>,
pipeline: ChannelPipeline,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ public class MessageDecoderRepositories private constructor(
public val gameMessageDecoderRepositories: ClientTypeMap<MessageDecoderRepository<ClientProt>>,
) {
public constructor(
clientTypes: List<OldSchoolClientType>,
exp: BigInteger,
mod: BigInteger,
gameMessageDecoderRepositories: ClientTypeMap<MessageDecoderRepository<ClientProt>>,
) : this(
LoginMessageDecoderRepository.build(exp, mod),
LoginMessageDecoderRepository.build(clientTypes, exp, mod),
Js5MessageDecoderRepository.build(),
gameMessageDecoderRepositories,
)
Expand All @@ -49,6 +50,7 @@ public class MessageDecoderRepositories private constructor(
repositories,
)
return MessageDecoderRepositories(
clientTypes,
rsaKeyPair.exponent,
rsaKeyPair.modulus,
clientTypeMap,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package net.rsprot.protocol.loginprot.incoming.util

import net.rsprot.protocol.common.client.OldSchoolClientType

public enum class LoginClientType(
public val id: Int,
) {
Expand All @@ -13,6 +15,18 @@ public enum class LoginClientType(
ENHANCED_LINUX(10),
;

public fun toOldSchoolClientType(): OldSchoolClientType? {
return when (this) {
DESKTOP -> OldSchoolClientType.DESKTOP
ENHANCED_WINDOWS -> OldSchoolClientType.DESKTOP
ENHANCED_LINUX -> OldSchoolClientType.DESKTOP
ENHANCED_MAC -> OldSchoolClientType.DESKTOP
ENHANCED_ANDROID -> OldSchoolClientType.ANDROID
ENHANCED_IOS -> OldSchoolClientType.IOS
else -> null
}
}

public companion object {
public operator fun get(id: Int): LoginClientType =
entries.firstOrNull { it.id == id }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package net.rsprot.protocol.common.loginprot.incoming.codec
import net.rsprot.buffer.JagByteBuf
import net.rsprot.buffer.extensions.toJagByteBuf
import net.rsprot.protocol.ClientProt
import net.rsprot.protocol.common.client.OldSchoolClientType
import net.rsprot.protocol.common.loginprot.incoming.codec.shared.LoginBlockDecoder
import net.rsprot.protocol.common.loginprot.incoming.prot.LoginClientProt
import net.rsprot.protocol.loginprot.incoming.GameLogin
Expand All @@ -14,6 +15,7 @@ import net.rsprot.protocol.message.codec.MessageDecoder
import java.math.BigInteger

public class GameLoginDecoder(
private val supportedClientTypes: List<OldSchoolClientType>,
exp: BigInteger,
mod: BigInteger,
) : LoginBlockDecoder<AuthenticationType<*>>(exp, mod),
Expand All @@ -25,7 +27,7 @@ public class GameLoginDecoder(
// Mark the buffer as "read" as copy function doesn't do it automatically.
buffer.buffer.readerIndex(buffer.buffer.writerIndex())
return GameLogin(copy.toJagByteBuf()) { jagByteBuf, betaWorld ->
decodeLoginBlock(jagByteBuf, betaWorld)
decodeLoginBlock(jagByteBuf, betaWorld, supportedClientTypes)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import net.rsprot.buffer.JagByteBuf
import net.rsprot.buffer.extensions.toJagByteBuf
import net.rsprot.crypto.xtea.XteaKey
import net.rsprot.protocol.ClientProt
import net.rsprot.protocol.common.client.OldSchoolClientType
import net.rsprot.protocol.common.loginprot.incoming.codec.shared.LoginBlockDecoder
import net.rsprot.protocol.common.loginprot.incoming.prot.LoginClientProt
import net.rsprot.protocol.loginprot.incoming.GameReconnect
import net.rsprot.protocol.message.codec.MessageDecoder
import java.math.BigInteger

public class GameReconnectDecoder(
private val supportedClientTypes: List<OldSchoolClientType>,
exp: BigInteger,
mod: BigInteger,
) : LoginBlockDecoder<XteaKey>(exp, mod),
Expand All @@ -22,7 +24,7 @@ public class GameReconnectDecoder(
// Mark the buffer as "read" as copy function doesn't do it automatically.
buffer.buffer.readerIndex(buffer.buffer.writerIndex())
return GameReconnect(copy.toJagByteBuf()) { jagByteBuf, betaWorld ->
decodeLoginBlock(jagByteBuf, betaWorld)
decodeLoginBlock(jagByteBuf, betaWorld, supportedClientTypes)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ import net.rsprot.buffer.extensions.toJagByteBuf
import net.rsprot.crypto.rsa.decipherRsa
import net.rsprot.crypto.xtea.xteaDecrypt
import net.rsprot.protocol.common.RSProtConstants
import net.rsprot.protocol.common.client.OldSchoolClientType
import net.rsprot.protocol.common.loginprot.incoming.codec.shared.exceptions.InvalidVersionException
import net.rsprot.protocol.common.loginprot.incoming.codec.shared.exceptions.UnsupportedClientException
import net.rsprot.protocol.loginprot.incoming.util.CyclicRedundancyCheckBlock
import net.rsprot.protocol.loginprot.incoming.util.HostPlatformStats
import net.rsprot.protocol.loginprot.incoming.util.LoginBlock
import net.rsprot.protocol.loginprot.incoming.util.LoginClientType
import java.math.BigInteger

@Suppress("DuplicatedCode")
Expand All @@ -21,6 +24,7 @@ public abstract class LoginBlockDecoder<T>(
protected fun decodeLoginBlock(
buffer: JagByteBuf,
betaWorld: Boolean,
supportedClientTypes: List<OldSchoolClientType>,
): LoginBlock<T> {
try {
val version = buffer.g4()
Expand All @@ -29,6 +33,11 @@ public abstract class LoginBlockDecoder<T>(
}
val subVersion = buffer.g4()
val firstClientType = buffer.g1()
val loginClientType = LoginClientType[firstClientType]
val oldSchoolClientType = loginClientType.toOldSchoolClientType()
if (oldSchoolClientType !in supportedClientTypes) {
throw UnsupportedClientException
}
val platformType = buffer.g1()
val constZero1 = buffer.g1()
val rsaSize = buffer.g2()
Expand Down Expand Up @@ -71,6 +80,9 @@ public abstract class LoginBlockDecoder<T>(
val constZero2 = xteaBuffer.g1()
val hostPlatformStats = decodeHostPlatformStats(xteaBuffer)
val secondClientType = xteaBuffer.g1()
if (secondClientType != firstClientType) {
throw UnsupportedClientException
}
val crcBlockHeader = xteaBuffer.g4()
val crc =
if (betaWorld) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package net.rsprot.protocol.common.loginprot.incoming.codec.shared.exceptions

/**
* A singleton exception for whenever login decoding fails due to an unsupported client connecting to it.
* It is not ideal to be using exceptions for flow control, but it is by far the easiest option
* here. From a performance standpoint, only building stack traces is slow, which we aren't
* going to be doing for this type of exception.
*/
public data object UnsupportedClientException : RuntimeException() {
@Suppress("unused")
private fun readResolve(): Any = UnsupportedClientException
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package net.rsprot.protocol.common.loginprot.incoming.prot

import net.rsprot.protocol.ProtRepository
import net.rsprot.protocol.common.client.OldSchoolClientType
import net.rsprot.protocol.common.loginprot.incoming.codec.GameLoginDecoder
import net.rsprot.protocol.common.loginprot.incoming.codec.GameReconnectDecoder
import net.rsprot.protocol.common.loginprot.incoming.codec.InitGameConnectionDecoder
Expand All @@ -14,6 +15,7 @@ import java.math.BigInteger
public object LoginMessageDecoderRepository {
@ExperimentalStdlibApi
public fun build(
supportedClientTypes: List<OldSchoolClientType>,
exp: BigInteger,
mod: BigInteger,
): MessageDecoderRepository<LoginClientProt> {
Expand All @@ -24,8 +26,8 @@ public object LoginMessageDecoderRepository {
).apply {
bind(InitGameConnectionDecoder())
bind(InitJs5RemoteConnectionDecoder())
bind(GameLoginDecoder(exp, mod))
bind(GameReconnectDecoder(exp, mod))
bind(GameLoginDecoder(supportedClientTypes, exp, mod))
bind(GameReconnectDecoder(supportedClientTypes, exp, mod))
bind(ProofOfWorkReplyDecoder())
bind(RemainingBetaArchivesDecoder())
}
Expand Down
Loading

0 comments on commit e947e32

Please sign in to comment.