diff --git a/shardus_net/src/header/header_v1.rs b/shardus_net/src/header/header_v1.rs index a29b7f8..00aa0f8 100644 --- a/shardus_net/src/header/header_v1.rs +++ b/shardus_net/src/header/header_v1.rs @@ -6,7 +6,8 @@ extern crate serde_json; use crate::compression::Compression; use serde::Deserialize; -use crate::{check_variable_size, HEADER_SIZE_LIMIT_IN_BYTES}; +use crate::check_variable_size; +use crate::NetConfig; #[derive(Deserialize)] pub struct HeaderV1 { @@ -23,6 +24,7 @@ pub struct HeaderV1 { pub compression: Compression, } +const SENDER_ID_SIZE: usize = 64; impl HeaderV1 { // Serialize the struct into a Vec pub fn serialize(&self) -> Vec { @@ -59,7 +61,7 @@ impl HeaderV1 { } // Deserialize a Vec cursor into a HeaderV1 struct - pub fn deserialize(cursor: &mut Cursor>) -> Option { + pub fn deserialize(cursor: &mut Cursor>, net_config: NetConfig) -> Option { // Deserialize uuid let mut uuid_bytes = [0u8; 16]; cursor.read_exact(&mut uuid_bytes).ok()?; @@ -74,7 +76,7 @@ impl HeaderV1 { let mut sender_id_len_bytes = [0u8; 4]; cursor.read_exact(&mut sender_id_len_bytes).ok()?; let sender_id_len = u32::from_le_bytes(sender_id_len_bytes); - check_variable_size(sender_id_len, 64); + check_variable_size(sender_id_len, SENDER_ID_SIZE); let mut sender_id_bytes = vec![0u8; sender_id_len as usize]; cursor.read_exact(&mut sender_id_bytes).ok()?; let sender_id = String::from_utf8(sender_id_bytes).ok()?; @@ -83,7 +85,7 @@ impl HeaderV1 { let mut tracker_id_len_bytes = [0u8; 4]; cursor.read_exact(&mut tracker_id_len_bytes).ok()?; let tracker_id_len = u32::from_le_bytes(tracker_id_len_bytes); - check_variable_size(tracker_id_len, HEADER_SIZE_LIMIT_IN_BYTES); + check_variable_size(tracker_id_len, net_config.header_size_limit); let mut tracker_id_bytes = vec![0u8; tracker_id_len as usize]; cursor.read_exact(&mut tracker_id_bytes).ok()?; let tracker_id = String::from_utf8(tracker_id_bytes).ok()?; @@ -92,7 +94,7 @@ impl HeaderV1 { let mut verification_data_len_bytes = [0u8; 4]; cursor.read_exact(&mut verification_data_len_bytes).ok()?; let verification_data_len = u32::from_le_bytes(verification_data_len_bytes); - check_variable_size(verification_data_len, HEADER_SIZE_LIMIT_IN_BYTES); + check_variable_size(verification_data_len, net_config.header_size_limit); let mut verification_data_bytes = vec![0u8; verification_data_len as usize]; cursor.read_exact(&mut verification_data_bytes).ok()?; let verification_data = String::from_utf8(verification_data_bytes).ok()?; @@ -146,10 +148,13 @@ mod tests { verification_data: "verification_data_1".to_string(), compression: Compression::None, }; - + let net_config = NetConfig { + header_size_limit: 2 * 1024, + payload_size_limit: 2 * 1024 * 1024, + }; let serialized = header.serialize(); let mut cursor = Cursor::new(serialized); - let deserialized = HeaderV1::deserialize(&mut cursor).unwrap(); + let deserialized = HeaderV1::deserialize(&mut cursor, net_config).unwrap(); assert_eq!(header.uuid, deserialized.uuid); assert_eq!(header.message_length, deserialized.message_length); @@ -194,23 +199,26 @@ mod tests { #[test] #[should_panic(expected = "variable_len exceeds the limit")] fn test_check_variable_size_panic() { - use crate::HEADER_SIZE_LIMIT_IN_BYTES; - + let net_config = NetConfig { + header_size_limit: 2 * 1024, + payload_size_limit: 2 * 1024 * 1024, + }; // Define a variable length that exceeds the limit - let oversized_length = HEADER_SIZE_LIMIT_IN_BYTES as u32 + 1; + let oversized_length = net_config.header_size_limit as u32 + 1; // Call the function, expecting it to panic - check_variable_size(oversized_length, HEADER_SIZE_LIMIT_IN_BYTES); + check_variable_size(oversized_length, net_config.header_size_limit); } #[test] fn test_check_variable_size_no_panic() { - use crate::HEADER_SIZE_LIMIT_IN_BYTES; - + let net_config = NetConfig { + header_size_limit: 2 * 1024, + payload_size_limit: 2 * 1024 * 1024, + }; // Define a variable length within the limit : 2048 (0x800) let valid_length = 0x799; - // Call the function, ensuring it does not panic - check_variable_size(valid_length, HEADER_SIZE_LIMIT_IN_BYTES); + check_variable_size(valid_length, net_config.header_size_limit); } } diff --git a/shardus_net/src/header_factory.rs b/shardus_net/src/header_factory.rs index 5cc091a..0e2bf63 100644 --- a/shardus_net/src/header_factory.rs +++ b/shardus_net/src/header_factory.rs @@ -2,6 +2,7 @@ use std::io::Cursor; use crate::header::header_types::Header; use crate::header::header_v1::HeaderV1; +use crate::NetConfig; pub fn wrap_serialized_message(mut serialized_message: Vec) -> Vec { let mut buffer = Vec::new(); @@ -10,10 +11,10 @@ pub fn wrap_serialized_message(mut serialized_message: Vec) -> Vec { buffer } -pub fn header_deserialize_factory(version: u8, serialized_header_cursor: &mut Cursor>) -> Option
{ +pub fn header_deserialize_factory(version: u8, serialized_header_cursor: &mut Cursor>, net_config: NetConfig) -> Option
{ match version { 1 => { - let deserialized = HeaderV1::deserialize(serialized_header_cursor)?; + let deserialized = HeaderV1::deserialize(serialized_header_cursor, net_config)?; Some(Header::V1(deserialized)) } _ => None, diff --git a/shardus_net/src/lib.rs b/shardus_net/src/lib.rs index 3b01f5f..c8e2cd4 100644 --- a/shardus_net/src/lib.rs +++ b/shardus_net/src/lib.rs @@ -38,8 +38,12 @@ use tokio::sync::Mutex; use crate::shardus_net_sender::Connection; const ENABLE_COMPRESSION: bool = false; -const HEADER_SIZE_LIMIT_IN_BYTES: usize = 2 * 1024; // 2KB -const PAYLOAD_SIZE_LIMIT_IN_BYTES: usize = 2 * 1024 * 1024; // 2MB + +#[derive(Copy, Clone)] +pub struct NetConfig { + pub header_size_limit: usize, + pub payload_size_limit: usize, +} const SIGNATURE_SIZE_LIMIT_IN_BYTES: usize = 96; const OWNER_SIZE_LIMIT_IN_BYTES: usize = 32; @@ -52,13 +56,19 @@ fn create_shardus_net(mut cx: FunctionContext) -> JsResult { let use_lru = cx.argument::(2)?.value(cx); let lru_size = cx.argument::(3)?.value(cx); let hash_key = cx.argument::(4)?.value(cx); + let hex_signing_sk = cx.argument::(5)?.value(cx); + let payload_size_limit = cx.argument::(6)?.value(cx) as usize; + let header_size_limit = cx.argument::(7)?.value(cx) as usize; + + let net_config = NetConfig { + header_size_limit, + payload_size_limit, + }; shardus_crypto::initialize_shardus_crypto_instance(&hash_key); - let hex_signing_sk = cx.argument::(5)?.value(cx); let key_pair = shardus_crypto::get_shardus_crypto_instance().get_key_pair_using_sk(&crypto::HexStringOrBuffer::Hex(hex_signing_sk)); - - let shardus_net_listener = create_shardus_net_listener(cx, port, host)?; + let shardus_net_listener = create_shardus_net_listener(cx, port, host, net_config)?; let shardus_net_sender = create_shardus_net_sender(use_lru, NonZeroUsize::new(lru_size as usize).unwrap(), key_pair); let (stats, stats_incrementers) = Stats::new(); let shardus_net_listener = cx.boxed(shardus_net_listener); @@ -419,11 +429,11 @@ fn evict_socket(mut cx: FunctionContext) -> JsResult { } } -fn create_shardus_net_listener(cx: &mut FunctionContext, port: f64, host: String) -> Result, Throw> { +fn create_shardus_net_listener(cx: &mut FunctionContext, port: f64, host: String, net_config: NetConfig) -> Result, Throw> { // @TODO: Verify that a javascript number properly converts here without loss. let address = (host, port as u16); - let shardus_net = ShardusNetListener::new(address); + let shardus_net = ShardusNetListener::new(address, net_config); match shardus_net { Ok(net) => Ok(Arc::new(net)), diff --git a/shardus_net/src/message.rs b/shardus_net/src/message.rs index b9f6118..e9b9fe9 100644 --- a/shardus_net/src/message.rs +++ b/shardus_net/src/message.rs @@ -1,6 +1,7 @@ use std::io::{Cursor, Read, Write}; -use crate::{check_variable_size, HEADER_SIZE_LIMIT_IN_BYTES, OWNER_SIZE_LIMIT_IN_BYTES, PAYLOAD_SIZE_LIMIT_IN_BYTES, SIGNATURE_SIZE_LIMIT_IN_BYTES}; +use crate::NetConfig; +use crate::{check_variable_size, OWNER_SIZE_LIMIT_IN_BYTES, SIGNATURE_SIZE_LIMIT_IN_BYTES}; use crypto::Format::Buffer; use crypto::{KeyPair, ShardusCrypto}; @@ -84,7 +85,7 @@ impl Message { buffer } - pub fn deserialize(cursor: &mut Cursor>) -> Option { + pub fn deserialize(cursor: &mut Cursor>, net_config: NetConfig) -> Option { // Deserialize header_version let mut header_version_bytes = [0u8; 1]; cursor.read_exact(&mut header_version_bytes).ok()?; @@ -94,7 +95,7 @@ impl Message { let mut header_len_bytes = [0u8; 4]; cursor.read_exact(&mut header_len_bytes).ok()?; let header_len = u32::from_le_bytes(header_len_bytes); - check_variable_size(header_len, HEADER_SIZE_LIMIT_IN_BYTES); + check_variable_size(header_len, net_config.header_size_limit); let mut header_bytes = vec![0u8; header_len as usize]; cursor.read_exact(&mut header_bytes).ok()?; let header = header_bytes; @@ -103,7 +104,8 @@ impl Message { let mut data_len_bytes = [0u8; 4]; cursor.read_exact(&mut data_len_bytes).ok()?; let data_len = u32::from_le_bytes(data_len_bytes); - check_variable_size(data_len, PAYLOAD_SIZE_LIMIT_IN_BYTES); + let data_len_limit = net_config.payload_size_limit - header_len as usize; // Since Payload size = header + data + check_variable_size(data_len, data_len_limit); let mut data_bytes = vec![0u8; data_len as usize]; cursor.read_exact(&mut data_bytes).ok()?; let data = data_bytes; @@ -190,7 +192,10 @@ mod tests { owner: vec![0x12, 0x34, 0x56, 0x78], sig: vec![0x9a, 0xbc, 0xde, 0xf0], }; - + let net_config = NetConfig { + header_size_limit: 2 * 1024, + payload_size_limit: 2 * 1024 * 1024, + }; let serialized = sign.serialize(); let mut cursor = Cursor::new(serialized); let deserialized = Sign::deserialize(&mut cursor).unwrap(); @@ -213,9 +218,13 @@ mod tests { sign, }; + let net_config = NetConfig { + header_size_limit: 2 * 1024, + payload_size_limit: 2 * 1024 * 1024, + }; let serialized = message.serialize(); let mut cursor = Cursor::new(serialized); - let deserialized = Message::deserialize(&mut cursor).unwrap(); + let deserialized = Message::deserialize(&mut cursor, net_config).unwrap(); assert_eq!(message.header_version, deserialized.header_version); assert_eq!(message.header, deserialized.header); diff --git a/shardus_net/src/shardus_net_listener.rs b/shardus_net/src/shardus_net_listener.rs index 527842d..f015413 100644 --- a/shardus_net/src/shardus_net_listener.rs +++ b/shardus_net/src/shardus_net_listener.rs @@ -1,10 +1,9 @@ +use super::runtime::RUNTIME; use crate::header::header_types::RequestMetadata; use crate::header_factory::header_deserialize_factory; use crate::message::Message; -use crate::{shardus_crypto, HEADER_SIZE_LIMIT_IN_BYTES, PAYLOAD_SIZE_LIMIT_IN_BYTES}; - -use super::runtime::RUNTIME; - +use crate::shardus_crypto; +use crate::NetConfig; use log::{error, info}; use std::io::Cursor; use std::net::{SocketAddr, ToSocketAddrs}; @@ -18,6 +17,7 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; pub struct ShardusNetListener { address: SocketAddr, + net_config: NetConfig, } #[derive(Error, Debug)] @@ -34,31 +34,31 @@ pub enum ListenerError { type ListenerResult = Result; impl ShardusNetListener { - pub fn new(address: A) -> Result { + pub fn new(address: A, net_config: NetConfig) -> Result { let mut addresses = address.to_socket_addrs().map_err(|_| ())?; let address = addresses.next().ok_or(())?; - Ok(Self { address }) + Ok(Self { address, net_config }) } pub fn listen(&self) -> UnboundedReceiver<(String, SocketAddr, Option)> { - Self::spawn_listener(self.address) + Self::spawn_listener(self.address, self.net_config) } - fn spawn_listener(address: SocketAddr) -> UnboundedReceiver<(String, SocketAddr, Option)> { + fn spawn_listener(address: SocketAddr, net_config: NetConfig) -> UnboundedReceiver<(String, SocketAddr, Option)> { let (tx, rx) = unbounded_channel(); - RUNTIME.spawn(Self::bind_to_socket(address, tx)); + RUNTIME.spawn(Self::bind_to_socket(address, tx, net_config)); rx } - async fn bind_to_socket(address: SocketAddr, tx: UnboundedSender<(String, SocketAddr, Option)>) { + async fn bind_to_socket(address: SocketAddr, tx: UnboundedSender<(String, SocketAddr, Option)>, net_config: NetConfig) { loop { let listener = TcpListener::bind(address).await; match listener { Ok(listener) => { let tx = tx.clone(); - match Self::accept_connections(listener, tx).await { + match Self::accept_connections(listener, tx, net_config.clone()).await { Ok(_) => unreachable!(), Err(err) => { error!("Failed to accept connection to {} due to {}", address, err) @@ -72,13 +72,14 @@ impl ShardusNetListener { } } - async fn accept_connections(listener: TcpListener, received_msg_tx: UnboundedSender<(String, SocketAddr, Option)>) -> std::io::Result<()> { + async fn accept_connections(listener: TcpListener, received_msg_tx: UnboundedSender<(String, SocketAddr, Option)>, net_config: NetConfig) -> std::io::Result<()> { loop { let (socket, remote_addr) = listener.accept().await?; let received_msg_tx = received_msg_tx.clone(); + let net_config = net_config.clone(); RUNTIME.spawn(async move { - let result = Self::receive(socket, remote_addr, received_msg_tx).await; + let result = Self::receive(socket, remote_addr, received_msg_tx, net_config).await; match result { Ok(_) => info!("Connection safely completed and shutdown with {}", remote_addr), Err(err) => { @@ -89,11 +90,11 @@ impl ShardusNetListener { } } - async fn receive(socket_stream: TcpStream, remote_addr: SocketAddr, received_msg_tx: UnboundedSender<(String, SocketAddr, Option)>) -> ListenerResult<()> { + async fn receive(socket_stream: TcpStream, remote_addr: SocketAddr, received_msg_tx: UnboundedSender<(String, SocketAddr, Option)>, net_config: NetConfig) -> ListenerResult<()> { let mut socket_stream: TcpStream = socket_stream; while let Ok(msg_len) = socket_stream.read_u32().await { - if (msg_len as usize) > PAYLOAD_SIZE_LIMIT_IN_BYTES { - error!("Message length exceeds the limit of 2MB"); + if (msg_len as usize) > net_config.payload_size_limit { + error!("Message length exceeds the limit of {} bytes", net_config.payload_size_limit); continue; } @@ -115,13 +116,7 @@ impl ShardusNetListener { let msg_bytes = &buffer[1..]; let mut cursor = Cursor::new(msg_bytes.to_vec()); - let message = Message::deserialize(&mut cursor).expect("Failed to deserialize message"); - - if message.header.len() > HEADER_SIZE_LIMIT_IN_BYTES { - error!("Header exceeds the limit of {} bytes", HEADER_SIZE_LIMIT_IN_BYTES); - continue; - } - + let message = Message::deserialize(&mut cursor, net_config).expect("Failed to deserialize message"); if !message.verify(shardus_crypto::get_shardus_crypto_instance()) { error!("Failed to verify message signature"); continue; @@ -129,7 +124,7 @@ impl ShardusNetListener { info!("Message verified!"); let header_cursor = &mut Cursor::new(message.header); - let header = header_deserialize_factory(message.header_version, header_cursor).expect("Failed to deserialize header"); + let header = header_deserialize_factory(message.header_version, header_cursor, net_config).expect("Failed to deserialize header"); let data = message.data; diff --git a/src/index.ts b/src/index.ts index f8a1752..fa79a9d 100644 --- a/src/index.ts +++ b/src/index.ts @@ -49,12 +49,23 @@ export const Sn = (opts: SnOpts) => { const LRU_SIZE = (opts.senderOpts && opts.senderOpts.lruSize) || 1028 const HASH_KEY = opts.crypto.hashKey const SIGNING_SECRET_KEY_HEX = opts.crypto.signingSecretKeyHex + const PAYLOAD_SIZE_LIMIT = opts.payloadOpts?.payloadSizeLimitInBytes || 2 * 1024 * 1024 // 2MB + const HEADER_SIZE_LIMIT = opts.payloadOpts?.headerSizeLimitInBytes || 2 * 1024 // 2KB const HEADER_OPTS = opts.headerOpts || { sendHeaderVersion: 0, } - const _net = net.Sn(PORT, ADDRESS, USE_LRU_CACHE, LRU_SIZE, HASH_KEY, SIGNING_SECRET_KEY_HEX) + const _net = net.Sn( + PORT, + ADDRESS, + USE_LRU_CACHE, + LRU_SIZE, + HASH_KEY, + SIGNING_SECRET_KEY_HEX, + PAYLOAD_SIZE_LIMIT, + HEADER_SIZE_LIMIT + ) net.setLoggingEnabled(false) diff --git a/src/types.ts b/src/types.ts index e5ce3fb..78448f5 100644 --- a/src/types.ts +++ b/src/types.ts @@ -60,6 +60,10 @@ export type SnOpts = { hashKey: string signingSecretKeyHex: string } + payloadOpts?: { + payloadSizeLimitInBytes?: number + headerSizeLimitInBytes?: number + } } /** @@ -70,13 +74,33 @@ export type SnOpts = { export const validateSnOpts = (opts: SnOpts): void => { if (!opts) throw new Error('snq: must supply options') - if (!opts.port || typeof opts.port !== 'number') throw new Error('snq: must supply port') - - if (!opts.crypto.hashKey || typeof opts.crypto.hashKey !== 'string') - throw new Error('snq: must supply hashKey') - - if (opts.senderOpts && opts.senderOpts.useLruCache && !opts.senderOpts.lruSize) - throw new Error('snq: must supply lruSize when using lruCache') + const validations = [ + { condition: !opts.port || typeof opts.port !== 'number', message: 'snq: must supply port' }, + { + condition: !opts.crypto.hashKey || typeof opts.crypto.hashKey !== 'string', + message: 'snq: must supply hashKey', + }, + { + condition: opts.senderOpts?.useLruCache && !opts.senderOpts.lruSize, + message: 'snq: must supply lruSize when using lruCache', + }, + { + condition: + opts.payloadOpts?.payloadSizeLimitInBytes && + typeof opts.payloadOpts.payloadSizeLimitInBytes !== 'number', + message: 'snq: payloadSizeLimitInBytes must be a number', + }, + { + condition: + opts.payloadOpts?.headerSizeLimitInBytes && + typeof opts.payloadOpts.headerSizeLimitInBytes !== 'number', + message: 'snq: headerSizeLimitInBytes must be a number', + }, + ] + + for (const { condition, message } of validations) { + if (condition) throw new Error(message) + } } export interface RemoteSender { diff --git a/test/test_multi_send.ts b/test/test_multi_send.ts index cc383f5..f115b96 100644 --- a/test/test_multi_send.ts +++ b/test/test_multi_send.ts @@ -1,8 +1,14 @@ import { Command } from 'commander' import { Sn } from '../.' -import { AppHeader, Sign } from '../build/src/types' -const setupLruSender = (port: number, lruSize: number) => { +const setupLruSender = ( + port: number, + lruSize: number, + limits: { + payloadSize?: number + headerSize?: number + } +) => { return Sn({ port, address: '127.0.0.1', @@ -18,6 +24,10 @@ const setupLruSender = (port: number, lruSize: number) => { headerOpts: { sendHeaderVersion: 1, }, + payloadOpts: { + payloadSizeLimitInBytes: limits.payloadSize || 2 * 1024 * 1024, // Default 2MB + headerSizeLimitInBytes: limits.headerSize || 2 * 1024, // Default 2KB + }, }) } @@ -26,10 +36,13 @@ const main = async () => { create a cli with the following options: -p, --port Port to listen on -c, --cache Size of the LRU cache + --payload-size Payload size limit in bytes + --header-size Header size limit in bytes the cli should create a sender with the following options: - lruSize: - port: + - limits: { payloadSize, headerSize} on running the cli a listener should be started and sending of message with input from terminal should be allowed */ @@ -37,8 +50,8 @@ const main = async () => { /* Commands to use for multi_send_with_header - ts-node test/test_multi_send.ts -p 44000 -c 2 - path/to/test_multi_send.ts -p -c + ts-node test/test_multi_send.ts -p 44000 -c 2 --payload-size 2097152 --header-size 2048 + path/to/test_multi_send.ts -p -c --payload-size --header-size data 3 ping @@ -49,14 +62,20 @@ const main = async () => { const program = new Command() program.requiredOption('-p, --port ', 'Port to listen on') program.option('-c, --cache ', 'Size of the LRU cache', '2') + program.option('--payload-size ', 'Payload size limit in bytes', '2097152') // Default 2MB + program.option('--header-size ', 'Header size limit in bytes', '2048') // Default 2KB program.parse(process.argv) const port = program.port.toString() const cacheSize = program.cache.toString() + const limits = { + payloadSize: +program.payloadSize, + headerSize: +program.headerSize, + } console.log(`Starting listener on port ${port} with cache size ${cacheSize}`) - - const sn = setupLruSender(+port, +cacheSize) + console.log(`Limits: ${JSON.stringify(limits, null, 2)}`) + const sn = setupLruSender(+port, +cacheSize, limits) const input = process.stdin input.addListener('data', async (data: Buffer) => { @@ -86,7 +105,7 @@ const main = async () => { }, 1000 ) - console.log('Message sent', message) + console.log('Message sent: ', message) } else if (inputs.length === 2) { sn.evictSocket(+inputs[1], '127.0.0.1') console.log('Cache cleared') @@ -100,7 +119,6 @@ const main = async () => { if (data && data.message === 'ping') { console.log('Received ping from:', data.fromPort) console.log('Ping header:', JSON.stringify(header, null, 2)) - // await sleep(10000) return respond( { message: 'pong', fromPort: +port }, {