Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SHARD-460 - feat: Make payload size limit a shardus config #15

Merged
merged 10 commits into from
Dec 9, 2024
38 changes: 23 additions & 15 deletions shardus_net/src/header/header_v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ extern crate serde_json;
use crate::compression::Compression;
use serde::Deserialize;

use crate::{check_variable_size, HEADER_SIZE_LIMIT_IN_BYTES};
use crate::check_variable_size;
use crate::NetConfig;

#[derive(Deserialize)]
pub struct HeaderV1 {
Expand All @@ -23,6 +24,7 @@ pub struct HeaderV1 {
pub compression: Compression,
}

const SENDER_ID_SIZE: usize = 64;
impl HeaderV1 {
// Serialize the struct into a Vec<u8>
pub fn serialize(&self) -> Vec<u8> {
Expand Down Expand Up @@ -59,7 +61,7 @@ impl HeaderV1 {
}

// Deserialize a Vec<u8> cursor into a HeaderV1 struct
pub fn deserialize(cursor: &mut Cursor<Vec<u8>>) -> Option<Self> {
pub fn deserialize(cursor: &mut Cursor<Vec<u8>>, net_config: NetConfig) -> Option<Self> {
// Deserialize uuid
let mut uuid_bytes = [0u8; 16];
cursor.read_exact(&mut uuid_bytes).ok()?;
Expand All @@ -74,7 +76,7 @@ impl HeaderV1 {
let mut sender_id_len_bytes = [0u8; 4];
cursor.read_exact(&mut sender_id_len_bytes).ok()?;
let sender_id_len = u32::from_le_bytes(sender_id_len_bytes);
check_variable_size(sender_id_len, 64);
check_variable_size(sender_id_len, SENDER_ID_SIZE);
let mut sender_id_bytes = vec![0u8; sender_id_len as usize];
cursor.read_exact(&mut sender_id_bytes).ok()?;
let sender_id = String::from_utf8(sender_id_bytes).ok()?;
Expand All @@ -83,7 +85,7 @@ impl HeaderV1 {
let mut tracker_id_len_bytes = [0u8; 4];
cursor.read_exact(&mut tracker_id_len_bytes).ok()?;
let tracker_id_len = u32::from_le_bytes(tracker_id_len_bytes);
check_variable_size(tracker_id_len, HEADER_SIZE_LIMIT_IN_BYTES);
check_variable_size(tracker_id_len, net_config.header_size_limit);
let mut tracker_id_bytes = vec![0u8; tracker_id_len as usize];
cursor.read_exact(&mut tracker_id_bytes).ok()?;
let tracker_id = String::from_utf8(tracker_id_bytes).ok()?;
Expand All @@ -92,7 +94,7 @@ impl HeaderV1 {
let mut verification_data_len_bytes = [0u8; 4];
cursor.read_exact(&mut verification_data_len_bytes).ok()?;
let verification_data_len = u32::from_le_bytes(verification_data_len_bytes);
check_variable_size(verification_data_len, HEADER_SIZE_LIMIT_IN_BYTES);
check_variable_size(verification_data_len, net_config.header_size_limit);
let mut verification_data_bytes = vec![0u8; verification_data_len as usize];
cursor.read_exact(&mut verification_data_bytes).ok()?;
let verification_data = String::from_utf8(verification_data_bytes).ok()?;
Expand Down Expand Up @@ -146,10 +148,13 @@ mod tests {
verification_data: "verification_data_1".to_string(),
compression: Compression::None,
};

let net_config = NetConfig {
header_size_limit: 2 * 1024,
payload_size_limit: 2 * 1024 * 1024,
};
let serialized = header.serialize();
let mut cursor = Cursor::new(serialized);
let deserialized = HeaderV1::deserialize(&mut cursor).unwrap();
let deserialized = HeaderV1::deserialize(&mut cursor, net_config).unwrap();

assert_eq!(header.uuid, deserialized.uuid);
assert_eq!(header.message_length, deserialized.message_length);
Expand Down Expand Up @@ -194,23 +199,26 @@ mod tests {
#[test]
#[should_panic(expected = "variable_len exceeds the limit")]
fn test_check_variable_size_panic() {
use crate::HEADER_SIZE_LIMIT_IN_BYTES;

let net_config = NetConfig {
header_size_limit: 2 * 1024,
payload_size_limit: 2 * 1024 * 1024,
};
// Define a variable length that exceeds the limit
let oversized_length = HEADER_SIZE_LIMIT_IN_BYTES as u32 + 1;
let oversized_length = net_config.header_size_limit as u32 + 1;

// Call the function, expecting it to panic
check_variable_size(oversized_length, HEADER_SIZE_LIMIT_IN_BYTES);
check_variable_size(oversized_length, net_config.header_size_limit);
}

#[test]
fn test_check_variable_size_no_panic() {
use crate::HEADER_SIZE_LIMIT_IN_BYTES;

let net_config = NetConfig {
header_size_limit: 2 * 1024,
payload_size_limit: 2 * 1024 * 1024,
};
// Define a variable length within the limit : 2048 (0x800)
let valid_length = 0x799;

// Call the function, ensuring it does not panic
check_variable_size(valid_length, HEADER_SIZE_LIMIT_IN_BYTES);
check_variable_size(valid_length, net_config.header_size_limit);
}
}
5 changes: 3 additions & 2 deletions shardus_net/src/header_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::io::Cursor;

use crate::header::header_types::Header;
use crate::header::header_v1::HeaderV1;
use crate::NetConfig;

pub fn wrap_serialized_message(mut serialized_message: Vec<u8>) -> Vec<u8> {
let mut buffer = Vec::new();
Expand All @@ -10,10 +11,10 @@ pub fn wrap_serialized_message(mut serialized_message: Vec<u8>) -> Vec<u8> {
buffer
}

pub fn header_deserialize_factory(version: u8, serialized_header_cursor: &mut Cursor<Vec<u8>>) -> Option<Header> {
pub fn header_deserialize_factory(version: u8, serialized_header_cursor: &mut Cursor<Vec<u8>>, net_config: NetConfig) -> Option<Header> {
match version {
1 => {
let deserialized = HeaderV1::deserialize(serialized_header_cursor)?;
let deserialized = HeaderV1::deserialize(serialized_header_cursor, net_config)?;
Some(Header::V1(deserialized))
}
_ => None,
Expand Down
24 changes: 17 additions & 7 deletions shardus_net/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ use tokio::sync::Mutex;
use crate::shardus_net_sender::Connection;

const ENABLE_COMPRESSION: bool = false;
const HEADER_SIZE_LIMIT_IN_BYTES: usize = 2 * 1024; // 2KB
const PAYLOAD_SIZE_LIMIT_IN_BYTES: usize = 2 * 1024 * 1024; // 2MB

#[derive(Copy, Clone)]
pub struct NetConfig {
pub header_size_limit: usize,
pub payload_size_limit: usize,
}
const SIGNATURE_SIZE_LIMIT_IN_BYTES: usize = 96;
const OWNER_SIZE_LIMIT_IN_BYTES: usize = 32;

Expand All @@ -52,13 +56,19 @@ fn create_shardus_net(mut cx: FunctionContext) -> JsResult<JsObject> {
let use_lru = cx.argument::<JsBoolean>(2)?.value(cx);
let lru_size = cx.argument::<JsNumber>(3)?.value(cx);
let hash_key = cx.argument::<JsString>(4)?.value(cx);
let hex_signing_sk = cx.argument::<JsString>(5)?.value(cx);
let payload_size_limit = cx.argument::<JsNumber>(6)?.value(cx) as usize;
let header_size_limit = cx.argument::<JsNumber>(7)?.value(cx) as usize;

let net_config = NetConfig {
jintukumardas marked this conversation as resolved.
Show resolved Hide resolved
header_size_limit,
payload_size_limit,
};

shardus_crypto::initialize_shardus_crypto_instance(&hash_key);

let hex_signing_sk = cx.argument::<JsString>(5)?.value(cx);
let key_pair = shardus_crypto::get_shardus_crypto_instance().get_key_pair_using_sk(&crypto::HexStringOrBuffer::Hex(hex_signing_sk));

let shardus_net_listener = create_shardus_net_listener(cx, port, host)?;
let shardus_net_listener = create_shardus_net_listener(cx, port, host, net_config)?;
let shardus_net_sender = create_shardus_net_sender(use_lru, NonZeroUsize::new(lru_size as usize).unwrap(), key_pair);
let (stats, stats_incrementers) = Stats::new();
let shardus_net_listener = cx.boxed(shardus_net_listener);
Expand Down Expand Up @@ -419,11 +429,11 @@ fn evict_socket(mut cx: FunctionContext) -> JsResult<JsUndefined> {
}
}

fn create_shardus_net_listener(cx: &mut FunctionContext, port: f64, host: String) -> Result<Arc<ShardusNetListener>, Throw> {
fn create_shardus_net_listener(cx: &mut FunctionContext, port: f64, host: String, net_config: NetConfig) -> Result<Arc<ShardusNetListener>, Throw> {
// @TODO: Verify that a javascript number properly converts here without loss.
let address = (host, port as u16);

let shardus_net = ShardusNetListener::new(address);
let shardus_net = ShardusNetListener::new(address, net_config);

match shardus_net {
Ok(net) => Ok(Arc::new(net)),
Expand Down
21 changes: 15 additions & 6 deletions shardus_net/src/message.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io::{Cursor, Read, Write};

use crate::{check_variable_size, HEADER_SIZE_LIMIT_IN_BYTES, OWNER_SIZE_LIMIT_IN_BYTES, PAYLOAD_SIZE_LIMIT_IN_BYTES, SIGNATURE_SIZE_LIMIT_IN_BYTES};
use crate::NetConfig;
use crate::{check_variable_size, OWNER_SIZE_LIMIT_IN_BYTES, SIGNATURE_SIZE_LIMIT_IN_BYTES};
use crypto::Format::Buffer;
use crypto::{KeyPair, ShardusCrypto};

Expand Down Expand Up @@ -84,7 +85,7 @@ impl Message {
buffer
}

pub fn deserialize(cursor: &mut Cursor<Vec<u8>>) -> Option<Message> {
pub fn deserialize(cursor: &mut Cursor<Vec<u8>>, net_config: NetConfig) -> Option<Message> {
// Deserialize header_version
let mut header_version_bytes = [0u8; 1];
cursor.read_exact(&mut header_version_bytes).ok()?;
Expand All @@ -94,7 +95,7 @@ impl Message {
let mut header_len_bytes = [0u8; 4];
cursor.read_exact(&mut header_len_bytes).ok()?;
let header_len = u32::from_le_bytes(header_len_bytes);
check_variable_size(header_len, HEADER_SIZE_LIMIT_IN_BYTES);
check_variable_size(header_len, net_config.header_size_limit);
let mut header_bytes = vec![0u8; header_len as usize];
cursor.read_exact(&mut header_bytes).ok()?;
let header = header_bytes;
Expand All @@ -103,7 +104,8 @@ impl Message {
let mut data_len_bytes = [0u8; 4];
cursor.read_exact(&mut data_len_bytes).ok()?;
let data_len = u32::from_le_bytes(data_len_bytes);
check_variable_size(data_len, PAYLOAD_SIZE_LIMIT_IN_BYTES);
let data_len_limit = net_config.payload_size_limit - header_len as usize; // Since Payload size = header + data
check_variable_size(data_len, data_len_limit);
let mut data_bytes = vec![0u8; data_len as usize];
cursor.read_exact(&mut data_bytes).ok()?;
let data = data_bytes;
Expand Down Expand Up @@ -190,7 +192,10 @@ mod tests {
owner: vec![0x12, 0x34, 0x56, 0x78],
sig: vec![0x9a, 0xbc, 0xde, 0xf0],
};

let net_config = NetConfig {
header_size_limit: 2 * 1024,
payload_size_limit: 2 * 1024 * 1024,
};
let serialized = sign.serialize();
let mut cursor = Cursor::new(serialized);
let deserialized = Sign::deserialize(&mut cursor).unwrap();
Expand All @@ -213,9 +218,13 @@ mod tests {
sign,
};

let net_config = NetConfig {
header_size_limit: 2 * 1024,
payload_size_limit: 2 * 1024 * 1024,
};
let serialized = message.serialize();
let mut cursor = Cursor::new(serialized);
let deserialized = Message::deserialize(&mut cursor).unwrap();
let deserialized = Message::deserialize(&mut cursor, net_config).unwrap();

assert_eq!(message.header_version, deserialized.header_version);
assert_eq!(message.header, deserialized.header);
Expand Down
43 changes: 19 additions & 24 deletions shardus_net/src/shardus_net_listener.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use super::runtime::RUNTIME;
use crate::header::header_types::RequestMetadata;
use crate::header_factory::header_deserialize_factory;
use crate::message::Message;
use crate::{shardus_crypto, HEADER_SIZE_LIMIT_IN_BYTES, PAYLOAD_SIZE_LIMIT_IN_BYTES};

use super::runtime::RUNTIME;

use crate::shardus_crypto;
use crate::NetConfig;
use log::{error, info};
use std::io::Cursor;
use std::net::{SocketAddr, ToSocketAddrs};
Expand All @@ -18,6 +17,7 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};

pub struct ShardusNetListener {
address: SocketAddr,
net_config: NetConfig,
}

#[derive(Error, Debug)]
Expand All @@ -34,31 +34,31 @@ pub enum ListenerError {
type ListenerResult<T> = Result<T, ListenerError>;

impl ShardusNetListener {
pub fn new<A: ToSocketAddrs>(address: A) -> Result<Self, ()> {
pub fn new<A: ToSocketAddrs>(address: A, net_config: NetConfig) -> Result<Self, ()> {
let mut addresses = address.to_socket_addrs().map_err(|_| ())?;
let address = addresses.next().ok_or(())?;

Ok(Self { address })
Ok(Self { address, net_config })
}

pub fn listen(&self) -> UnboundedReceiver<(String, SocketAddr, Option<RequestMetadata>)> {
Self::spawn_listener(self.address)
Self::spawn_listener(self.address, self.net_config)
}

fn spawn_listener(address: SocketAddr) -> UnboundedReceiver<(String, SocketAddr, Option<RequestMetadata>)> {
fn spawn_listener(address: SocketAddr, net_config: NetConfig) -> UnboundedReceiver<(String, SocketAddr, Option<RequestMetadata>)> {
let (tx, rx) = unbounded_channel();
RUNTIME.spawn(Self::bind_to_socket(address, tx));
RUNTIME.spawn(Self::bind_to_socket(address, tx, net_config));
rx
}

async fn bind_to_socket(address: SocketAddr, tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>) {
async fn bind_to_socket(address: SocketAddr, tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>, net_config: NetConfig) {
loop {
let listener = TcpListener::bind(address).await;

match listener {
Ok(listener) => {
let tx = tx.clone();
match Self::accept_connections(listener, tx).await {
match Self::accept_connections(listener, tx, net_config.clone()).await {
Ok(_) => unreachable!(),
Err(err) => {
error!("Failed to accept connection to {} due to {}", address, err)
Expand All @@ -72,13 +72,14 @@ impl ShardusNetListener {
}
}

async fn accept_connections(listener: TcpListener, received_msg_tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>) -> std::io::Result<()> {
async fn accept_connections(listener: TcpListener, received_msg_tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>, net_config: NetConfig) -> std::io::Result<()> {
loop {
let (socket, remote_addr) = listener.accept().await?;
let received_msg_tx = received_msg_tx.clone();
let net_config = net_config.clone();

RUNTIME.spawn(async move {
let result = Self::receive(socket, remote_addr, received_msg_tx).await;
let result = Self::receive(socket, remote_addr, received_msg_tx, net_config).await;
match result {
Ok(_) => info!("Connection safely completed and shutdown with {}", remote_addr),
Err(err) => {
Expand All @@ -89,11 +90,11 @@ impl ShardusNetListener {
}
}

async fn receive(socket_stream: TcpStream, remote_addr: SocketAddr, received_msg_tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>) -> ListenerResult<()> {
async fn receive(socket_stream: TcpStream, remote_addr: SocketAddr, received_msg_tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>, net_config: NetConfig) -> ListenerResult<()> {
let mut socket_stream: TcpStream = socket_stream;
while let Ok(msg_len) = socket_stream.read_u32().await {
if (msg_len as usize) > PAYLOAD_SIZE_LIMIT_IN_BYTES {
error!("Message length exceeds the limit of 2MB");
if (msg_len as usize) > net_config.payload_size_limit {
error!("Message length exceeds the limit of {} bytes", net_config.payload_size_limit);
continue;
}

Expand All @@ -115,21 +116,15 @@ impl ShardusNetListener {
let msg_bytes = &buffer[1..];

let mut cursor = Cursor::new(msg_bytes.to_vec());
let message = Message::deserialize(&mut cursor).expect("Failed to deserialize message");

if message.header.len() > HEADER_SIZE_LIMIT_IN_BYTES {
error!("Header exceeds the limit of {} bytes", HEADER_SIZE_LIMIT_IN_BYTES);
continue;
}

let message = Message::deserialize(&mut cursor, net_config).expect("Failed to deserialize message");
if !message.verify(shardus_crypto::get_shardus_crypto_instance()) {
error!("Failed to verify message signature");
continue;
}
info!("Message verified!");

let header_cursor = &mut Cursor::new(message.header);
let header = header_deserialize_factory(message.header_version, header_cursor).expect("Failed to deserialize header");
let header = header_deserialize_factory(message.header_version, header_cursor, net_config).expect("Failed to deserialize header");

let data = message.data;

Expand Down
13 changes: 12 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,23 @@ export const Sn = (opts: SnOpts) => {
const LRU_SIZE = (opts.senderOpts && opts.senderOpts.lruSize) || 1028
const HASH_KEY = opts.crypto.hashKey
const SIGNING_SECRET_KEY_HEX = opts.crypto.signingSecretKeyHex
const PAYLOAD_SIZE_LIMIT = opts.payloadOpts?.payloadSizeLimitInBytes || 2 * 1024 * 1024 // 2MB
const HEADER_SIZE_LIMIT = opts.payloadOpts?.headerSizeLimitInBytes || 2 * 1024 // 2KB

const HEADER_OPTS = opts.headerOpts || {
sendHeaderVersion: 0,
}

const _net = net.Sn(PORT, ADDRESS, USE_LRU_CACHE, LRU_SIZE, HASH_KEY, SIGNING_SECRET_KEY_HEX)
const _net = net.Sn(
PORT,
ADDRESS,
USE_LRU_CACHE,
LRU_SIZE,
HASH_KEY,
SIGNING_SECRET_KEY_HEX,
PAYLOAD_SIZE_LIMIT,
HEADER_SIZE_LIMIT
)

net.setLoggingEnabled(false)

Expand Down
Loading
Loading