Skip to content

Commit

Permalink
SHARD-460 - feat: Make payload size limit a shardus config (#15)
Browse files Browse the repository at this point in the history
* feat: add payload size limit to SnOpts type and update Sn function to use the provided limit

* Add additional network param as shardus config

* Make sender id constant

* Make sign params constant

* Update tests

* Update test files

* Fix test_multi_send bug

* Move payload and header sizes under same config

* Update test file

* Remove unnecessary clone for NetConfig
  • Loading branch information
jintukumardas authored Dec 9, 2024
1 parent e7889bd commit 023b6dd
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 70 deletions.
38 changes: 23 additions & 15 deletions shardus_net/src/header/header_v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ extern crate serde_json;
use crate::compression::Compression;
use serde::Deserialize;

use crate::{check_variable_size, HEADER_SIZE_LIMIT_IN_BYTES};
use crate::check_variable_size;
use crate::NetConfig;

#[derive(Deserialize)]
pub struct HeaderV1 {
Expand All @@ -23,6 +24,7 @@ pub struct HeaderV1 {
pub compression: Compression,
}

const SENDER_ID_SIZE: usize = 64;
impl HeaderV1 {
// Serialize the struct into a Vec<u8>
pub fn serialize(&self) -> Vec<u8> {
Expand Down Expand Up @@ -59,7 +61,7 @@ impl HeaderV1 {
}

// Deserialize a Vec<u8> cursor into a HeaderV1 struct
pub fn deserialize(cursor: &mut Cursor<Vec<u8>>) -> Option<Self> {
pub fn deserialize(cursor: &mut Cursor<Vec<u8>>, net_config: NetConfig) -> Option<Self> {
// Deserialize uuid
let mut uuid_bytes = [0u8; 16];
cursor.read_exact(&mut uuid_bytes).ok()?;
Expand All @@ -74,7 +76,7 @@ impl HeaderV1 {
let mut sender_id_len_bytes = [0u8; 4];
cursor.read_exact(&mut sender_id_len_bytes).ok()?;
let sender_id_len = u32::from_le_bytes(sender_id_len_bytes);
check_variable_size(sender_id_len, 64);
check_variable_size(sender_id_len, SENDER_ID_SIZE);
let mut sender_id_bytes = vec![0u8; sender_id_len as usize];
cursor.read_exact(&mut sender_id_bytes).ok()?;
let sender_id = String::from_utf8(sender_id_bytes).ok()?;
Expand All @@ -83,7 +85,7 @@ impl HeaderV1 {
let mut tracker_id_len_bytes = [0u8; 4];
cursor.read_exact(&mut tracker_id_len_bytes).ok()?;
let tracker_id_len = u32::from_le_bytes(tracker_id_len_bytes);
check_variable_size(tracker_id_len, HEADER_SIZE_LIMIT_IN_BYTES);
check_variable_size(tracker_id_len, net_config.header_size_limit);
let mut tracker_id_bytes = vec![0u8; tracker_id_len as usize];
cursor.read_exact(&mut tracker_id_bytes).ok()?;
let tracker_id = String::from_utf8(tracker_id_bytes).ok()?;
Expand All @@ -92,7 +94,7 @@ impl HeaderV1 {
let mut verification_data_len_bytes = [0u8; 4];
cursor.read_exact(&mut verification_data_len_bytes).ok()?;
let verification_data_len = u32::from_le_bytes(verification_data_len_bytes);
check_variable_size(verification_data_len, HEADER_SIZE_LIMIT_IN_BYTES);
check_variable_size(verification_data_len, net_config.header_size_limit);
let mut verification_data_bytes = vec![0u8; verification_data_len as usize];
cursor.read_exact(&mut verification_data_bytes).ok()?;
let verification_data = String::from_utf8(verification_data_bytes).ok()?;
Expand Down Expand Up @@ -146,10 +148,13 @@ mod tests {
verification_data: "verification_data_1".to_string(),
compression: Compression::None,
};

let net_config = NetConfig {
header_size_limit: 2 * 1024,
payload_size_limit: 2 * 1024 * 1024,
};
let serialized = header.serialize();
let mut cursor = Cursor::new(serialized);
let deserialized = HeaderV1::deserialize(&mut cursor).unwrap();
let deserialized = HeaderV1::deserialize(&mut cursor, net_config).unwrap();

assert_eq!(header.uuid, deserialized.uuid);
assert_eq!(header.message_length, deserialized.message_length);
Expand Down Expand Up @@ -194,23 +199,26 @@ mod tests {
#[test]
#[should_panic(expected = "variable_len exceeds the limit")]
fn test_check_variable_size_panic() {
use crate::HEADER_SIZE_LIMIT_IN_BYTES;

let net_config = NetConfig {
header_size_limit: 2 * 1024,
payload_size_limit: 2 * 1024 * 1024,
};
// Define a variable length that exceeds the limit
let oversized_length = HEADER_SIZE_LIMIT_IN_BYTES as u32 + 1;
let oversized_length = net_config.header_size_limit as u32 + 1;

// Call the function, expecting it to panic
check_variable_size(oversized_length, HEADER_SIZE_LIMIT_IN_BYTES);
check_variable_size(oversized_length, net_config.header_size_limit);
}

#[test]
fn test_check_variable_size_no_panic() {
use crate::HEADER_SIZE_LIMIT_IN_BYTES;

let net_config = NetConfig {
header_size_limit: 2 * 1024,
payload_size_limit: 2 * 1024 * 1024,
};
// Define a variable length within the limit : 2048 (0x800)
let valid_length = 0x799;

// Call the function, ensuring it does not panic
check_variable_size(valid_length, HEADER_SIZE_LIMIT_IN_BYTES);
check_variable_size(valid_length, net_config.header_size_limit);
}
}
5 changes: 3 additions & 2 deletions shardus_net/src/header_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::io::Cursor;

use crate::header::header_types::Header;
use crate::header::header_v1::HeaderV1;
use crate::NetConfig;

pub fn wrap_serialized_message(mut serialized_message: Vec<u8>) -> Vec<u8> {
let mut buffer = Vec::new();
Expand All @@ -10,10 +11,10 @@ pub fn wrap_serialized_message(mut serialized_message: Vec<u8>) -> Vec<u8> {
buffer
}

pub fn header_deserialize_factory(version: u8, serialized_header_cursor: &mut Cursor<Vec<u8>>) -> Option<Header> {
pub fn header_deserialize_factory(version: u8, serialized_header_cursor: &mut Cursor<Vec<u8>>, net_config: NetConfig) -> Option<Header> {
match version {
1 => {
let deserialized = HeaderV1::deserialize(serialized_header_cursor)?;
let deserialized = HeaderV1::deserialize(serialized_header_cursor, net_config)?;
Some(Header::V1(deserialized))
}
_ => None,
Expand Down
24 changes: 17 additions & 7 deletions shardus_net/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ use tokio::sync::Mutex;
use crate::shardus_net_sender::Connection;

const ENABLE_COMPRESSION: bool = false;
const HEADER_SIZE_LIMIT_IN_BYTES: usize = 2 * 1024; // 2KB
const PAYLOAD_SIZE_LIMIT_IN_BYTES: usize = 2 * 1024 * 1024; // 2MB

#[derive(Copy, Clone)]
pub struct NetConfig {
pub header_size_limit: usize,
pub payload_size_limit: usize,
}
const SIGNATURE_SIZE_LIMIT_IN_BYTES: usize = 96;
const OWNER_SIZE_LIMIT_IN_BYTES: usize = 32;

Expand All @@ -52,13 +56,19 @@ fn create_shardus_net(mut cx: FunctionContext) -> JsResult<JsObject> {
let use_lru = cx.argument::<JsBoolean>(2)?.value(cx);
let lru_size = cx.argument::<JsNumber>(3)?.value(cx);
let hash_key = cx.argument::<JsString>(4)?.value(cx);
let hex_signing_sk = cx.argument::<JsString>(5)?.value(cx);
let payload_size_limit = cx.argument::<JsNumber>(6)?.value(cx) as usize;
let header_size_limit = cx.argument::<JsNumber>(7)?.value(cx) as usize;

let net_config = NetConfig {
header_size_limit,
payload_size_limit,
};

shardus_crypto::initialize_shardus_crypto_instance(&hash_key);

let hex_signing_sk = cx.argument::<JsString>(5)?.value(cx);
let key_pair = shardus_crypto::get_shardus_crypto_instance().get_key_pair_using_sk(&crypto::HexStringOrBuffer::Hex(hex_signing_sk));

let shardus_net_listener = create_shardus_net_listener(cx, port, host)?;
let shardus_net_listener = create_shardus_net_listener(cx, port, host, net_config)?;
let shardus_net_sender = create_shardus_net_sender(use_lru, NonZeroUsize::new(lru_size as usize).unwrap(), key_pair);
let (stats, stats_incrementers) = Stats::new();
let shardus_net_listener = cx.boxed(shardus_net_listener);
Expand Down Expand Up @@ -419,11 +429,11 @@ fn evict_socket(mut cx: FunctionContext) -> JsResult<JsUndefined> {
}
}

fn create_shardus_net_listener(cx: &mut FunctionContext, port: f64, host: String) -> Result<Arc<ShardusNetListener>, Throw> {
fn create_shardus_net_listener(cx: &mut FunctionContext, port: f64, host: String, net_config: NetConfig) -> Result<Arc<ShardusNetListener>, Throw> {
// @TODO: Verify that a javascript number properly converts here without loss.
let address = (host, port as u16);

let shardus_net = ShardusNetListener::new(address);
let shardus_net = ShardusNetListener::new(address, net_config);

match shardus_net {
Ok(net) => Ok(Arc::new(net)),
Expand Down
21 changes: 15 additions & 6 deletions shardus_net/src/message.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io::{Cursor, Read, Write};

use crate::{check_variable_size, HEADER_SIZE_LIMIT_IN_BYTES, OWNER_SIZE_LIMIT_IN_BYTES, PAYLOAD_SIZE_LIMIT_IN_BYTES, SIGNATURE_SIZE_LIMIT_IN_BYTES};
use crate::NetConfig;
use crate::{check_variable_size, OWNER_SIZE_LIMIT_IN_BYTES, SIGNATURE_SIZE_LIMIT_IN_BYTES};
use crypto::Format::Buffer;
use crypto::{KeyPair, ShardusCrypto};

Expand Down Expand Up @@ -84,7 +85,7 @@ impl Message {
buffer
}

pub fn deserialize(cursor: &mut Cursor<Vec<u8>>) -> Option<Message> {
pub fn deserialize(cursor: &mut Cursor<Vec<u8>>, net_config: NetConfig) -> Option<Message> {
// Deserialize header_version
let mut header_version_bytes = [0u8; 1];
cursor.read_exact(&mut header_version_bytes).ok()?;
Expand All @@ -94,7 +95,7 @@ impl Message {
let mut header_len_bytes = [0u8; 4];
cursor.read_exact(&mut header_len_bytes).ok()?;
let header_len = u32::from_le_bytes(header_len_bytes);
check_variable_size(header_len, HEADER_SIZE_LIMIT_IN_BYTES);
check_variable_size(header_len, net_config.header_size_limit);
let mut header_bytes = vec![0u8; header_len as usize];
cursor.read_exact(&mut header_bytes).ok()?;
let header = header_bytes;
Expand All @@ -103,7 +104,8 @@ impl Message {
let mut data_len_bytes = [0u8; 4];
cursor.read_exact(&mut data_len_bytes).ok()?;
let data_len = u32::from_le_bytes(data_len_bytes);
check_variable_size(data_len, PAYLOAD_SIZE_LIMIT_IN_BYTES);
let data_len_limit = net_config.payload_size_limit - header_len as usize; // Since Payload size = header + data
check_variable_size(data_len, data_len_limit);
let mut data_bytes = vec![0u8; data_len as usize];
cursor.read_exact(&mut data_bytes).ok()?;
let data = data_bytes;
Expand Down Expand Up @@ -190,7 +192,10 @@ mod tests {
owner: vec![0x12, 0x34, 0x56, 0x78],
sig: vec![0x9a, 0xbc, 0xde, 0xf0],
};

let net_config = NetConfig {
header_size_limit: 2 * 1024,
payload_size_limit: 2 * 1024 * 1024,
};
let serialized = sign.serialize();
let mut cursor = Cursor::new(serialized);
let deserialized = Sign::deserialize(&mut cursor).unwrap();
Expand All @@ -213,9 +218,13 @@ mod tests {
sign,
};

let net_config = NetConfig {
header_size_limit: 2 * 1024,
payload_size_limit: 2 * 1024 * 1024,
};
let serialized = message.serialize();
let mut cursor = Cursor::new(serialized);
let deserialized = Message::deserialize(&mut cursor).unwrap();
let deserialized = Message::deserialize(&mut cursor, net_config).unwrap();

assert_eq!(message.header_version, deserialized.header_version);
assert_eq!(message.header, deserialized.header);
Expand Down
43 changes: 19 additions & 24 deletions shardus_net/src/shardus_net_listener.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use super::runtime::RUNTIME;
use crate::header::header_types::RequestMetadata;
use crate::header_factory::header_deserialize_factory;
use crate::message::Message;
use crate::{shardus_crypto, HEADER_SIZE_LIMIT_IN_BYTES, PAYLOAD_SIZE_LIMIT_IN_BYTES};

use super::runtime::RUNTIME;

use crate::shardus_crypto;
use crate::NetConfig;
use log::{error, info};
use std::io::Cursor;
use std::net::{SocketAddr, ToSocketAddrs};
Expand All @@ -18,6 +17,7 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};

pub struct ShardusNetListener {
address: SocketAddr,
net_config: NetConfig,
}

#[derive(Error, Debug)]
Expand All @@ -34,31 +34,31 @@ pub enum ListenerError {
type ListenerResult<T> = Result<T, ListenerError>;

impl ShardusNetListener {
pub fn new<A: ToSocketAddrs>(address: A) -> Result<Self, ()> {
pub fn new<A: ToSocketAddrs>(address: A, net_config: NetConfig) -> Result<Self, ()> {
let mut addresses = address.to_socket_addrs().map_err(|_| ())?;
let address = addresses.next().ok_or(())?;

Ok(Self { address })
Ok(Self { address, net_config })
}

pub fn listen(&self) -> UnboundedReceiver<(String, SocketAddr, Option<RequestMetadata>)> {
Self::spawn_listener(self.address)
Self::spawn_listener(self.address, self.net_config)
}

fn spawn_listener(address: SocketAddr) -> UnboundedReceiver<(String, SocketAddr, Option<RequestMetadata>)> {
fn spawn_listener(address: SocketAddr, net_config: NetConfig) -> UnboundedReceiver<(String, SocketAddr, Option<RequestMetadata>)> {
let (tx, rx) = unbounded_channel();
RUNTIME.spawn(Self::bind_to_socket(address, tx));
RUNTIME.spawn(Self::bind_to_socket(address, tx, net_config));
rx
}

async fn bind_to_socket(address: SocketAddr, tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>) {
async fn bind_to_socket(address: SocketAddr, tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>, net_config: NetConfig) {
loop {
let listener = TcpListener::bind(address).await;

match listener {
Ok(listener) => {
let tx = tx.clone();
match Self::accept_connections(listener, tx).await {
match Self::accept_connections(listener, tx, net_config.clone()).await {
Ok(_) => unreachable!(),
Err(err) => {
error!("Failed to accept connection to {} due to {}", address, err)
Expand All @@ -72,13 +72,14 @@ impl ShardusNetListener {
}
}

async fn accept_connections(listener: TcpListener, received_msg_tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>) -> std::io::Result<()> {
async fn accept_connections(listener: TcpListener, received_msg_tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>, net_config: NetConfig) -> std::io::Result<()> {
loop {
let (socket, remote_addr) = listener.accept().await?;
let received_msg_tx = received_msg_tx.clone();
let net_config = net_config.clone();

RUNTIME.spawn(async move {
let result = Self::receive(socket, remote_addr, received_msg_tx).await;
let result = Self::receive(socket, remote_addr, received_msg_tx, net_config).await;
match result {
Ok(_) => info!("Connection safely completed and shutdown with {}", remote_addr),
Err(err) => {
Expand All @@ -89,11 +90,11 @@ impl ShardusNetListener {
}
}

async fn receive(socket_stream: TcpStream, remote_addr: SocketAddr, received_msg_tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>) -> ListenerResult<()> {
async fn receive(socket_stream: TcpStream, remote_addr: SocketAddr, received_msg_tx: UnboundedSender<(String, SocketAddr, Option<RequestMetadata>)>, net_config: NetConfig) -> ListenerResult<()> {
let mut socket_stream: TcpStream = socket_stream;
while let Ok(msg_len) = socket_stream.read_u32().await {
if (msg_len as usize) > PAYLOAD_SIZE_LIMIT_IN_BYTES {
error!("Message length exceeds the limit of 2MB");
if (msg_len as usize) > net_config.payload_size_limit {
error!("Message length exceeds the limit of {} bytes", net_config.payload_size_limit);
continue;
}

Expand All @@ -115,21 +116,15 @@ impl ShardusNetListener {
let msg_bytes = &buffer[1..];

let mut cursor = Cursor::new(msg_bytes.to_vec());
let message = Message::deserialize(&mut cursor).expect("Failed to deserialize message");

if message.header.len() > HEADER_SIZE_LIMIT_IN_BYTES {
error!("Header exceeds the limit of {} bytes", HEADER_SIZE_LIMIT_IN_BYTES);
continue;
}

let message = Message::deserialize(&mut cursor, net_config).expect("Failed to deserialize message");
if !message.verify(shardus_crypto::get_shardus_crypto_instance()) {
error!("Failed to verify message signature");
continue;
}
info!("Message verified!");

let header_cursor = &mut Cursor::new(message.header);
let header = header_deserialize_factory(message.header_version, header_cursor).expect("Failed to deserialize header");
let header = header_deserialize_factory(message.header_version, header_cursor, net_config).expect("Failed to deserialize header");

let data = message.data;

Expand Down
13 changes: 12 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,23 @@ export const Sn = (opts: SnOpts) => {
const LRU_SIZE = (opts.senderOpts && opts.senderOpts.lruSize) || 1028
const HASH_KEY = opts.crypto.hashKey
const SIGNING_SECRET_KEY_HEX = opts.crypto.signingSecretKeyHex
const PAYLOAD_SIZE_LIMIT = opts.payloadOpts?.payloadSizeLimitInBytes || 2 * 1024 * 1024 // 2MB
const HEADER_SIZE_LIMIT = opts.payloadOpts?.headerSizeLimitInBytes || 2 * 1024 // 2KB

const HEADER_OPTS = opts.headerOpts || {
sendHeaderVersion: 0,
}

const _net = net.Sn(PORT, ADDRESS, USE_LRU_CACHE, LRU_SIZE, HASH_KEY, SIGNING_SECRET_KEY_HEX)
const _net = net.Sn(
PORT,
ADDRESS,
USE_LRU_CACHE,
LRU_SIZE,
HASH_KEY,
SIGNING_SECRET_KEY_HEX,
PAYLOAD_SIZE_LIMIT,
HEADER_SIZE_LIMIT
)

net.setLoggingEnabled(false)

Expand Down
Loading

0 comments on commit 023b6dd

Please sign in to comment.