Skip to content

Commit

Permalink
better buffer management
Browse files Browse the repository at this point in the history
  • Loading branch information
Aidan63 committed Nov 23, 2024
1 parent db00bf1 commit c916a19
Showing 1 changed file with 111 additions and 57 deletions.
168 changes: 111 additions & 57 deletions src/hx/libs/ssl/windows/SSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,26 @@ namespace
::Dynamic socket;

/**
* Buffer for handshake or encrypted data.
* This buffer is a fixed sized.
* Buffer for encrypted message data.
*/
::Array<uint8_t> input;
::Array<uint8_t> encryptedBuffer;
/**
* Number of bytes in the input array.
* Number of bytes in the encrypted array.
*/
int received;
int encryptedCursor;

/**
* buffered message data.
* Array is expanded and shrunk as data is added and read.
* Buffer for decrypted message data.
*/
::Array<uint8_t> decrypted;
::Array<uint8_t> decryptedBuffer;
/**
* Position in the decrypted buffer the data starts.
*/
int decryptedLower;
/**
* Position in the decrypted buffer the data ends.
*/
int decryptedUpper;

CredHandle credHandle;
TimeStamp credTimestamp;
Expand All @@ -65,9 +71,11 @@ namespace
SChannelContext(::String inHost)
: host(inHost)
, socket(null())
, input(10000, 10000)
, decrypted(0, 0)
, received(0)
, encryptedBuffer(0, 0)
, decryptedBuffer(0, 0)
, encryptedCursor(0)
, decryptedLower(0)
, decryptedUpper(0)
, credHandle()
, credTimestamp()
, ctxtHandle()
Expand All @@ -79,24 +87,75 @@ namespace
HX_OBJ_WB_NEW_MARKED_OBJECT(this);
}

int DecryptedBytes() const
{
return decryptedUpper - decryptedLower;
}

int DecryptedEndSpace() const
{
return decryptedBuffer->length - decryptedUpper;
}

int DecryptedBeginningSpace() const
{
return decryptedLower;
}

void AppendDecrypted(const uint8_t* ptr, const int length)
{
// We need to make more space for the incoming bytes.
if (length > DecryptedEndSpace())
{
if (DecryptedBeginningSpace() + DecryptedEndSpace() >= length)
{
// If we have a hole at the beginning of the decrypted buffer
// and that space plus the end space is big enough,
// shuffle the existing data down.
DefragmentDecrypted();
}
else
{
if (DecryptedBeginningSpace() > 0)
{
// Defragment in prep for growing the buffer.
DefragmentDecrypted();
}
}
}

decryptedBuffer->memcpy(decryptedUpper, ptr, length);

decryptedUpper += length;
}

void DefragmentDecrypted()
{
auto size = DecryptedBytes();

std::memmove(decryptedBuffer->getBase(), decryptedBuffer->getBase() + decryptedLower, DecryptedBytes());

decryptedLower = 0;
decryptedUpper = size;
}

void __Mark(HX_MARK_PARAMS) override
{
HX_MARK_MEMBER(host);
HX_MARK_MEMBER(socket);
HX_MARK_MEMBER(input);
HX_MARK_MEMBER(decrypted);
HX_MARK_MEMBER(encryptedBuffer);
HX_MARK_MEMBER(decryptedBuffer);
}

#ifdef HXCPP_VISIT_ALLOCS
void __Visit(HX_VISIT_PARAMS) override
{
HX_VISIT_MEMBER(host);
HX_VISIT_MEMBER(socket);
HX_VISIT_MEMBER(input);
HX_VISIT_MEMBER(decrypted);
HX_VISIT_MEMBER(encryptedBuffer);
HX_VISIT_MEMBER(decryptedBuffer);
}
#endif

};

void DestroyCert(Dynamic obj)
Expand Down Expand Up @@ -208,7 +267,9 @@ void _hx_ssl_handshake(Dynamic handle)
auto outputBuffer1 = std::vector<char>(1024);
auto outputBuffer2 = std::vector<char>(1024);

auto initial = true;
auto initial = true;
auto staging = std::array<char, std::numeric_limits<uint16_t>::max()>();
auto cursor = 0;

auto outputBuffers = std::array<SecBuffer, 3>();
auto outputBufferDescription = SecBufferDesc();
Expand All @@ -219,7 +280,7 @@ void _hx_ssl_handshake(Dynamic handle)
while (true)
{
// Input Buffers
init_sec_buffer(&inputBuffers[0], SECBUFFER_TOKEN, ctx->input->getBase(), ctx->received);
init_sec_buffer(&inputBuffers[0], SECBUFFER_TOKEN, staging.data(), cursor);
init_sec_buffer(&inputBuffers[1], SECBUFFER_EMPTY, nullptr, 0);
init_sec_buffer_desc(&inputBufferDescription, inputBuffers.data(), inputBuffers.size());

Expand Down Expand Up @@ -252,22 +313,18 @@ void _hx_ssl_handshake(Dynamic handle)
{
printf("handshake complete\n");

QueryContextAttributes(&ctx->ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &ctx->sizes);

ctx->encryptedBuffer->EnsureSize(ctx->sizes.cbMaximumMessage);
ctx->decryptedBuffer->EnsureSize(ctx->sizes.cbMaximumMessage);

if (SECBUFFER_EXTRA == inputBuffers[1].BufferType)
{
std::memmove(ctx->input->getBase(), ctx->input->getBase() + ctx->received - inputBuffers[1].cbBuffer, inputBuffers[1].cbBuffer);

ctx->received = inputBuffers[1].cbBuffer;
ctx->encryptedBuffer->memcpy(0, static_cast<uint8_t*>(inputBuffers[1].pvBuffer), inputBuffers[1].cbBuffer);
ctx->encryptedCursor = inputBuffers[1].cbBuffer;

printf("%i bytes of extra data found\n", inputBuffers[1].cbBuffer);
}
else
{
ctx->received = 0;
}

QueryContextAttributes(&ctx->ctxtHandle, SECPKG_ATTR_STREAM_SIZES, &ctx->sizes);

ctx->input->EnsureSize(ctx->sizes.cbMaximumMessage);

return;
}
Expand All @@ -279,18 +336,18 @@ void _hx_ssl_handshake(Dynamic handle)
// https://learn.microsoft.com/en-us/windows/win32/secauthn/acceptsecuritycontext--schannel
if (SECBUFFER_MISSING != inputBuffers[1].BufferType)
{
auto targetReceived = ctx->received + inputBuffers[1].cbBuffer;
auto targetReceived = cursor + inputBuffers[1].cbBuffer;

// Loop until we've read at least the required amount to avoid excessive calls to InitializeSecurityContextA
while (ctx->received < targetReceived)
while (cursor < targetReceived)
{
auto read = recv(wrapper->socket, ctx->input->getBase() + ctx->received, ctx->input->length - ctx->received, 0);
auto read = recv(wrapper->socket, staging.data() + cursor, staging.size() - cursor, 0);
if (read <= 0)
{
hx::Throw(HX_CSTRING("Failed to read handshake message"));
}

ctx->received += read;
cursor += read;
}
}
break;
Expand All @@ -303,13 +360,13 @@ void _hx_ssl_handshake(Dynamic handle)
// Otherwise we can just set received to zero as we've consumed all data in the buffer.
if (SECBUFFER_EXTRA == inputBuffers[1].BufferType)
{
std::memmove(ctx->input->getBase(), ctx->input->getBase() + ctx->received - inputBuffers[1].cbBuffer, inputBuffers[1].cbBuffer);
std::memmove(staging.data(), staging.data() + cursor - inputBuffers[1].cbBuffer, inputBuffers[1].cbBuffer);

ctx->received = inputBuffers[1].cbBuffer;
cursor = inputBuffers[1].cbBuffer;
}
else
{
ctx->received = 0;
cursor = 0;
}

// Send all data in the output token buffer to the remote end.
Expand All @@ -330,13 +387,13 @@ void _hx_ssl_handshake(Dynamic handle)
}

// Read more data from the remote end and loop.
auto read = recv(wrapper->socket, ctx->input->getBase() + ctx->received, ctx->input->length - ctx->received, 0);
auto read = recv(wrapper->socket, staging.data() + cursor, staging.size() - cursor, 0);
if (read <= 0)
{
hx::Throw(HX_CSTRING("Failed to read handshake message"));
}

ctx->received += read;
cursor += read;

break;
}
Expand Down Expand Up @@ -452,24 +509,24 @@ int _hx_ssl_recv(Dynamic hssl, Array<unsigned char> buf, int p, int l)

while (true)
{
if (ctx->decrypted->length > 0)
if (ctx->DecryptedBytes() > 0)
{
auto taking = std::min(l, ctx->decrypted->length);
auto taking = std::min(l, ctx->DecryptedBytes());

printf("taking %i cached bytes\n", taking);

buf->memcpy(p, &ctx->decrypted[0], taking);
buf->memcpy(p, &ctx->decryptedBuffer[ctx->decryptedLower], taking);

ctx->decrypted->removeRange(0, taking);
ctx->decryptedLower += taking;

return taking;
}

if (ctx->received > 0)
if (ctx->encryptedCursor > 0)
{
auto result = SECURITY_STATUS{ SEC_E_OK };

init_sec_buffer(&buffers[0], SECBUFFER_DATA, ctx->input->getBase(), ctx->received);
init_sec_buffer(&buffers[0], SECBUFFER_DATA, ctx->encryptedBuffer->getBase(), ctx->encryptedCursor);
init_sec_buffer(&buffers[1], SECBUFFER_EMPTY, nullptr, 0);
init_sec_buffer(&buffers[2], SECBUFFER_EMPTY, nullptr, 0);
init_sec_buffer(&buffers[3], SECBUFFER_EMPTY, nullptr, 0);
Expand All @@ -493,24 +550,21 @@ int _hx_ssl_recv(Dynamic hssl, Array<unsigned char> buf, int p, int l)
printf("extra ( ptr : %p, len : %i)\n", buffers[3].pvBuffer, buffers[3].cbBuffer);
}

for (auto i = 0; i < buffers[1].cbBuffer; i++)
{
ctx->decrypted->push(static_cast<uint8_t*>(buffers[1].pvBuffer)[i]);
}
ctx->AppendDecrypted(static_cast<uint8_t*>(buffers[1].pvBuffer), buffers[1].cbBuffer);

if (SECBUFFER_EXTRA == buffers[3].BufferType)
{
printf("moving %i to extra buffer\n", buffers[3].cbBuffer);

std::memmove(ctx->input->getBase(), buffers[3].pvBuffer, buffers[3].cbBuffer);
std::memmove(ctx->encryptedBuffer->getBase(), buffers[3].pvBuffer, buffers[3].cbBuffer);

ctx->received = buffers[3].cbBuffer;
ctx->encryptedCursor = buffers[3].cbBuffer;
}
else
{
printf("no extra buffer, resetting recieved\n");

ctx->received = 0;
ctx->encryptedCursor = 0;
}

break;
Expand All @@ -524,22 +578,22 @@ int _hx_ssl_recv(Dynamic hssl, Array<unsigned char> buf, int p, int l)
printf("\tSECBUFFER_MISSING indicates it wants %i more bytes\n", buffers[0].cbBuffer);
printf("\tcurrent receive position is %i, so %i free space\n", ctx->received, ctx->input->length - ctx->received);

if (ctx->received + buffers[0].cbBuffer > ctx->input->length)
if (ctx->encryptedCursor + buffers[0].cbBuffer > ctx->encryptedBuffer->length)
{
printf("\t\tgrowing input buffer\n");

ctx->input->EnsureSize(ctx->received + buffers[0].cbBuffer);
ctx->encryptedBuffer->EnsureSize(ctx->encryptedCursor + buffers[0].cbBuffer);
}

auto count = recv(wrapper->socket, ctx->input->getBase() + ctx->received, buffers[0].cbBuffer, 0);
auto count = recv(wrapper->socket, ctx->encryptedBuffer->getBase() + ctx->encryptedCursor, buffers[0].cbBuffer, 0);
if (count <= 0)
{
printf("about to throw leaving behind %i encrypted and %i decrypted bytes\n", ctx->received, ctx->decrypted->length);

block_error();
}

ctx->received += count;
ctx->encryptedCursor += count;

printf("socket read, added %i\n", count);

Expand All @@ -554,15 +608,15 @@ int _hx_ssl_recv(Dynamic hssl, Array<unsigned char> buf, int p, int l)
{
printf("no buffered input, reading block from socket (%i)\n", ctx->sizes.cbBlockSize);

auto count = recv(wrapper->socket, ctx->input->getBase(), ctx->input->length, 0);
auto count = recv(wrapper->socket, ctx->encryptedBuffer->getBase(), ctx->encryptedBuffer->length, 0);
if (count <= 0)
{
block_error();
}

printf("added to received buffer (total %i)\n", count);

ctx->received = count;
ctx->encryptedCursor = count;
}
}

Expand Down

0 comments on commit c916a19

Please sign in to comment.