From 1e26d718d7d4f08b7b3175358a8e5c0f23667061 Mon Sep 17 00:00:00 2001 From: Gary Tierney Date: Sat, 17 Aug 2024 06:26:46 +0100 Subject: [PATCH 1/3] Read event data in parallel to backtest Remove the `read_data()` calls within the backtest implementations and replace them with `recv()` calls on a lock-free queue. This avoids the pause that previously occurred when a backtest reached the end of the current period's data and began loading the next file. With this model, the data for the next period should be available by the time the previous one finishes. --- hftbacktest/Cargo.toml | 1 + hftbacktest/src/backtest/data/bus.rs | 106 +++++++++++ hftbacktest/src/backtest/data/mod.rs | 4 + hftbacktest/src/backtest/data/reader.rs | 3 +- hftbacktest/src/backtest/mod.rs | 49 +++-- hftbacktest/src/backtest/proc/l3_local.rs | 13 +- .../backtest/proc/l3_nopartialfillexchange.rs | 154 ++++++--------- hftbacktest/src/backtest/proc/local.rs | 86 ++++----- hftbacktest/src/backtest/proc/mod.rs | 6 +- .../backtest/proc/nopartialfillexchange.rs | 175 +++++++++--------- .../src/backtest/proc/partialfillexchange.rs | 155 +++++++--------- 11 files changed, 411 insertions(+), 341 deletions(-) create mode 100644 hftbacktest/src/backtest/data/bus.rs diff --git a/hftbacktest/Cargo.toml b/hftbacktest/Cargo.toml index 82a9655e..47b62abd 100644 --- a/hftbacktest/Cargo.toml +++ b/hftbacktest/Cargo.toml @@ -49,6 +49,7 @@ hmac = { version = "0.13.0-pre.3", optional = true } rand = { version = "0.8.5", optional = true } uuid = { version = "1.8.0", features = ["v4"], optional = true } nom = { version = "7.1.3", optional = true } +bus = { version = "2.4" } hftbacktest-derive = { path = "../hftbacktest-derive", optional = true, version = "0.1.0" } [dev-dependencies] diff --git a/hftbacktest/src/backtest/data/bus.rs b/hftbacktest/src/backtest/data/bus.rs new file mode 100644 index 00000000..2e60d26e --- /dev/null +++ b/hftbacktest/src/backtest/data/bus.rs @@ -0,0 +1,106 @@ +use std::{io, 
io::ErrorKind}; +use std::iter::Peekable; +use bus::{Bus, BusIntoIter, BusReader}; +use tracing::{error, info, info_span}; + +use crate::backtest::{ + data::{read_npy_file, read_npz_file, Data, NpyDTyped}, + BacktestError, +}; + +#[derive(Copy, Clone)] +pub enum EventBusMessage { + Item(EventT), + EndOfData, +} + +pub struct EventBusReader { + reader: Peekable>>, +} + +impl EventBusReader { + pub fn new(reader: BusReader>) -> Self { + Self { + reader: reader.into_iter().peekable() + } + } + + pub fn peek(&mut self) -> Option<&EventT> { + self.reader.peek().and_then(|ev| match ev { + EventBusMessage::Item(item) => Some(item), + EventBusMessage::EndOfData => None, + }) + } + + pub fn next(&mut self) -> Option { + self.reader.next().and_then(|ev| match ev { + EventBusMessage::Item(item) => Some(item), + EventBusMessage::EndOfData => None, + }) + } +} + +pub trait TimestampedEventQueue { + fn next_event(&mut self) -> Option; + + fn peek_event(&mut self) -> Option<&EventT>; + + fn event_time(value: &EventT) -> i64; +} + +pub trait EventConsumer { + fn is_event_relevant(event: &EventT) -> bool; + + fn process_event(&mut self, event: EventT) -> Result<(), BacktestError>; +} + +fn load_data( + filepath: String, +) -> Result, BacktestError> { + let data = if filepath.ends_with(".npy") { + read_npy_file(&filepath)? + } else if filepath.ends_with(".npz") { + read_npz_file(&filepath, "data")? + } else { + return Err(BacktestError::DataError(io::Error::new( + ErrorKind::InvalidData, + "unsupported data type", + ))); + }; + + Ok(data) +} + +#[tracing::instrument(skip_all)] +pub fn replay_events_to_bus( + mut bus: Bus>, + mut sources: Vec, +) { + for source in sources.drain(..) 
{ + let source_load_span = info_span!("load_data", source = &source); + let _source_load_span = source_load_span.entered(); + + let data = load_data::(source); + + match data { + Ok(data) => { + info!( + records = data.len(), + "found {} events in data source", + data.len() + ); + + for row in 0..data.len() { + bus.broadcast(EventBusMessage::Item(data[row].clone())); + } + } + Err(e) => { + error!("encountered error loading data source: {}", e); + // TODO: handle as an error. + break; + } + } + } + + bus.broadcast(EventBusMessage::EndOfData); +} diff --git a/hftbacktest/src/backtest/data/mod.rs b/hftbacktest/src/backtest/data/mod.rs index a7b38685..d6426e04 100644 --- a/hftbacktest/src/backtest/data/mod.rs +++ b/hftbacktest/src/backtest/data/mod.rs @@ -1,3 +1,4 @@ +mod bus; mod npy; mod reader; @@ -10,6 +11,7 @@ use std::{ slice::SliceIndex, }; +pub use bus::{replay_events_to_bus, EventBusMessage, EventBusReader, EventConsumer, TimestampedEventQueue}; pub use npy::{read_npy_file, read_npz_file, write_npy, Field, NpyDTyped, NpyHeader}; pub use reader::{Cache, DataSource, Reader}; @@ -107,6 +109,8 @@ where } } +unsafe impl Send for DataPtr {} + #[derive(Debug)] pub struct DataPtr { ptr: *mut [u8], diff --git a/hftbacktest/src/backtest/data/reader.rs b/hftbacktest/src/backtest/data/reader.rs index 7c396bab..3c5b221d 100644 --- a/hftbacktest/src/backtest/data/reader.rs +++ b/hftbacktest/src/backtest/data/reader.rs @@ -3,6 +3,7 @@ use std::{ collections::HashMap, io::{Error as IoError, ErrorKind}, rc::Rc, + sync::Arc, }; use uuid::Uuid; @@ -60,7 +61,7 @@ where /// Provides a data cache that allows both the local processor and exchange processor to access the /// same or different data based on their timestamps without the need for reloading. 
#[derive(Clone, Debug)] -pub struct Cache(Rc>>>) +pub struct Cache(Arc>>>) where D: POD + Clone; diff --git a/hftbacktest/src/backtest/mod.rs b/hftbacktest/src/backtest/mod.rs index 2bc5d9bf..f0b61d01 100644 --- a/hftbacktest/src/backtest/mod.rs +++ b/hftbacktest/src/backtest/mod.rs @@ -1,5 +1,6 @@ -use std::{collections::HashMap, io::Error as IoError, marker::PhantomData}; +use std::{collections::HashMap, io::Error as IoError, marker::PhantomData, sync::mpsc::Receiver}; +use bus::Bus; pub use data::DataSource; use data::{Cache, Reader}; use models::FeeModel; @@ -13,10 +14,18 @@ pub use crate::backtest::{ use crate::{ backtest::{ assettype::AssetType, + data::replay_events_to_bus, evs::{EventIntentKind, EventSet}, models::{LatencyModel, QueueModel}, order::OrderBus, - proc::{Local, LocalProcessor, NoPartialFillExchange, PartialFillExchange, Processor}, + proc::{ + Local, + LocalProcessor, + NoPartialFillExchange, + OrderConsumer, + PartialFillExchange, + Processor, + }, state::State, }, depth::{HashMapMarketDepth, L2MarketDepth, L3MarketDepth, MarketDepth}, @@ -34,6 +43,7 @@ use crate::{ }, types::{BuildError, Event}, }; +use crate::backtest::data::EventBusReader; /// Provides asset types. pub mod assettype; @@ -113,11 +123,11 @@ pub enum ExchangeKind { /// A builder for `Asset`. pub struct AssetBuilder { + data_sources: Vec, latency_model: Option, asset_type: Option, queue_model: Option, depth_builder: Option MD>>, - reader: Reader, fee_model: Option, exch_kind: ExchangeKind, last_trades_cap: usize, @@ -133,18 +143,15 @@ where { /// Constructs an instance of `AssetBuilder`. 
pub fn new() -> Self { - let cache = Cache::new(); - let reader = Reader::new(cache); - Self { latency_model: None, asset_type: None, queue_model: None, depth_builder: None, - reader, fee_model: None, exch_kind: ExchangeKind::NoPartialFillExchange, last_trades_cap: 0, + data_sources: vec![], } } @@ -153,10 +160,10 @@ where for item in data { match item { DataSource::File(filename) => { - self.reader.add_file(filename); + self.data_sources.push(filename); } - DataSource::Data(data) => { - self.reader.add_data(data); + DataSource::Data(_) => { + todo!("involves a copy"); } } } @@ -242,8 +249,16 @@ where .clone() .ok_or(BuildError::BuilderIncomplete("fee_model"))?; + let mut bus = Bus::new(10_000); + let exch_bus = bus.add_rx(); + let local_bus = bus.add_rx(); + + std::thread::spawn(move || { + replay_events_to_bus(bus, self.data_sources); + }); + let local = Local::new( - self.reader.clone(), + EventBusReader::new(local_bus), create_depth(), State::new(asset_type, fee_model), order_latency, @@ -271,7 +286,7 @@ where match self.exch_kind { ExchangeKind::NoPartialFillExchange => { let exch = NoPartialFillExchange::new( - self.reader.clone(), + EventBusReader::new(exch_bus), create_depth(), State::new(asset_type, fee_model), order_latency, @@ -287,7 +302,7 @@ where } ExchangeKind::PartialFillExchange => { let exch = PartialFillExchange::new( - self.reader.clone(), + EventBusReader::new(exch_bus), create_depth(), State::new(asset_type, fee_model), order_latency, @@ -330,8 +345,12 @@ where .clone() .ok_or(BuildError::BuilderIncomplete("fee_model"))?; + let mut bus = Bus::new(1000); + let local_reader = EventBusReader::new(bus.add_rx()); + let exch_reader = EventBusReader::new(bus.add_rx()); + let local = Local::new( - self.reader.clone(), + local_reader, create_depth(), State::new(asset_type, fee_model), order_latency, @@ -356,7 +375,7 @@ where .clone() .ok_or(BuildError::BuilderIncomplete("fee_model"))?; let exch = NoPartialFillExchange::new( - self.reader.clone(), + 
exch_reader, create_depth(), State::new(asset_type, fee_model), order_latency, diff --git a/hftbacktest/src/backtest/proc/l3_local.rs b/hftbacktest/src/backtest/proc/l3_local.rs index db962539..8218c3d0 100644 --- a/hftbacktest/src/backtest/proc/l3_local.rs +++ b/hftbacktest/src/backtest/proc/l3_local.rs @@ -9,7 +9,7 @@ use crate::{ data::{Data, Reader}, models::{FeeModel, LatencyModel}, order::OrderBus, - proc::{LocalProcessor, Processor}, + proc::{LocalProcessor, OrderConsumer, Processor}, state::State, BacktestError, }, @@ -312,7 +312,16 @@ where Ok((next_ts, i64::MAX)) } +} +impl OrderConsumer for L3Local +where + AT: AssetType, + LM: LatencyModel, + MD: L3MarketDepth, + FM: FeeModel, + BacktestError: From<::Error>, +{ fn process_recv_order( &mut self, timestamp: i64, @@ -347,11 +356,9 @@ where } Ok(wait_resp_order_received) } - fn earliest_recv_order_timestamp(&self) -> i64 { self.orders_from.earliest_timestamp().unwrap_or(i64::MAX) } - fn earliest_send_order_timestamp(&self) -> i64 { self.orders_to.earliest_timestamp().unwrap_or(i64::MAX) } diff --git a/hftbacktest/src/backtest/proc/l3_nopartialfillexchange.rs b/hftbacktest/src/backtest/proc/l3_nopartialfillexchange.rs index 705bf6d8..4cd48289 100644 --- a/hftbacktest/src/backtest/proc/l3_nopartialfillexchange.rs +++ b/hftbacktest/src/backtest/proc/l3_nopartialfillexchange.rs @@ -1,12 +1,14 @@ use std::mem; +use bus::BusReader; + use crate::{ backtest::{ assettype::AssetType, - data::{Data, Reader}, + data::{Data, EventConsumer, Reader}, models::{FeeModel, L3QueueModel, LatencyModel}, order::OrderBus, - proc::Processor, + proc::{OrderConsumer, Processor}, state::State, BacktestError, }, @@ -63,8 +65,7 @@ where MD: L3MarketDepth, FM: FeeModel, { - reader: Reader, - data: Data, + reader: BusReader, row_num: usize, orders_to: OrderBus, orders_from: OrderBus, @@ -86,7 +87,7 @@ where { /// Constructs an instance of `NoPartialFillExchange`. 
pub fn new( - reader: Reader, + reader: BusReader, depth: MD, state: State, queue_model: QM, @@ -96,7 +97,6 @@ where ) -> Self { Self { reader, - data: Data::empty(), row_num: 0, orders_to, orders_from, @@ -300,7 +300,7 @@ where } } -impl Processor for L3NoPartialFillExchange +impl EventConsumer for L3NoPartialFillExchange where AT: AssetType, LM: LatencyModel, @@ -309,134 +309,88 @@ where FM: FeeModel, BacktestError: From<::Error>, { - fn initialize_data(&mut self) -> Result { - self.data = self.reader.next_data()?; - for rn in 0..self.data.len() { - if self.data[rn].is(EXCH_EVENT) { - self.row_num = rn; - return Ok(self.data[rn].exch_ts); - } - } - Err(BacktestError::EndOfData) + fn is_event_relevant(&self, event: &EventT) -> bool { + event.is(EXCH_EVENT) } - fn process_data(&mut self) -> Result<(i64, i64), BacktestError> { - let row_num = self.row_num; - if self.data[row_num].is(EXCH_BID_DEPTH_CLEAR_EVENT) { + fn process_event(&mut self, event: Event) -> Result<(), BacktestError> { + if event.is(EXCH_BID_DEPTH_CLEAR_EVENT) { self.depth.clear_orders(Side::Buy); let expired = self.queue_model.clear_orders(Side::Buy); for order in expired { - self.expired(order, self.data[row_num].exch_ts)?; + self.expired(order, event.exch_ts)?; } - } else if self.data[row_num].is(EXCH_ASK_DEPTH_CLEAR_EVENT) { + } else if event.is(EXCH_ASK_DEPTH_CLEAR_EVENT) { self.depth.clear_orders(Side::Sell); let expired = self.queue_model.clear_orders(Side::Sell); for order in expired { - self.expired(order, self.data[row_num].exch_ts)?; + self.expired(order, event.exch_ts)?; } - } else if self.data[row_num].is(EXCH_DEPTH_CLEAR_EVENT) { + } else if event.is(EXCH_DEPTH_CLEAR_EVENT) { self.depth.clear_orders(Side::None); let expired = self.queue_model.clear_orders(Side::None); for order in expired { - self.expired(order, self.data[row_num].exch_ts)?; + self.expired(order, event.exch_ts)?; } - } else if self.data[row_num].is(EXCH_BID_ADD_ORDER_EVENT) { - let (prev_best_bid_tick, best_bid_tick) = 
self.depth.add_buy_order( - self.data[row_num].order_id, - self.data[row_num].px, - self.data[row_num].qty, - self.data[row_num].exch_ts, - )?; + } else if event.is(EXCH_BID_ADD_ORDER_EVENT) { + let (prev_best_bid_tick, best_bid_tick) = + self.depth + .add_buy_order(event.order_id, event.px, event.qty, event.exch_ts)?; self.queue_model - .add_market_feed_order(&self.data[row_num], &self.depth)?; + .add_market_feed_order(&event, &self.depth)?; if best_bid_tick > prev_best_bid_tick { - self.fill_ask_orders_by_crossing( - prev_best_bid_tick, - best_bid_tick, - self.data[row_num].exch_ts, - )?; + self.fill_ask_orders_by_crossing(prev_best_bid_tick, best_bid_tick, event.exch_ts)?; } - } else if self.data[row_num].is(EXCH_ASK_ADD_ORDER_EVENT) { - let (prev_best_ask_tick, best_ask_tick) = self.depth.add_sell_order( - self.data[row_num].order_id, - self.data[row_num].px, - self.data[row_num].qty, - self.data[row_num].exch_ts, - )?; + } else if event.is(EXCH_ASK_ADD_ORDER_EVENT) { + let (prev_best_ask_tick, best_ask_tick) = + self.depth + .add_sell_order(event.order_id, event.px, event.qty, event.exch_ts)?; self.queue_model - .add_market_feed_order(&self.data[row_num], &self.depth)?; + .add_market_feed_order(&event, &self.depth)?; if best_ask_tick < prev_best_ask_tick { - self.fill_bid_orders_by_crossing( - prev_best_ask_tick, - best_ask_tick, - self.data[row_num].exch_ts, - )?; + self.fill_bid_orders_by_crossing(prev_best_ask_tick, best_ask_tick, event.exch_ts)?; } - } else if self.data[row_num].is(EXCH_MODIFY_ORDER_EVENT) { - let (side, prev_best_tick, best_tick) = self.depth.modify_order( - self.data[row_num].order_id, - self.data[row_num].px, - self.data[row_num].qty, - self.data[row_num].exch_ts, - )?; - self.queue_model.modify_market_feed_order( - self.data[row_num].order_id, - &self.data[row_num], - &self.depth, - )?; + } else if event.is(EXCH_MODIFY_ORDER_EVENT) { + let (side, prev_best_tick, best_tick) = + self.depth + .modify_order(event.order_id, event.px, 
event.qty, event.exch_ts)?; + self.queue_model + .modify_market_feed_order(event.order_id, &event, &self.depth)?; if side == Side::Buy { if best_tick > prev_best_tick { - self.fill_ask_orders_by_crossing( - prev_best_tick, - best_tick, - self.data[row_num].exch_ts, - )?; + self.fill_ask_orders_by_crossing(prev_best_tick, best_tick, event.exch_ts)?; } } else if best_tick < prev_best_tick { - self.fill_bid_orders_by_crossing( - prev_best_tick, - best_tick, - self.data[row_num].exch_ts, - )?; + self.fill_bid_orders_by_crossing(prev_best_tick, best_tick, event.exch_ts)?; } - } else if self.data[row_num].is(EXCH_CANCEL_ORDER_EVENT) { - let _ = self - .depth - .delete_order(self.data[row_num].order_id, self.data[row_num].exch_ts)?; + } else if event.is(EXCH_CANCEL_ORDER_EVENT) { + let _ = self.depth.delete_order(event.order_id, event.exch_ts)?; self.queue_model - .cancel_market_feed_order(self.data[row_num].order_id, &self.depth)?; - } else if self.data[row_num].is(EXCH_FILL_EVENT) { + .cancel_market_feed_order(event.order_id, &self.depth)?; + } else if event.is(EXCH_FILL_EVENT) { let filled = self .queue_model - .fill_market_feed_order::(self.data[row_num].order_id, &self.depth)?; - let timestamp = self.data[row_num].exch_ts; + .fill_market_feed_order::(event.order_id, &self.depth)?; + let timestamp = event.exch_ts; for mut order in filled { let price_tick = order.price_tick; self.fill(&mut order, timestamp, true, price_tick)?; } } - // Checks - let mut next_ts = 0; - for rn in (self.row_num + 1)..self.data.len() { - if self.data[rn].is(EXCH_EVENT) { - self.row_num = rn; - next_ts = self.data[rn].exch_ts; - break; - } - } - - if next_ts <= 0 { - let next_data = self.reader.next_data()?; - let next_row = &next_data[0]; - next_ts = next_row.exch_ts; - let data = mem::replace(&mut self.data, next_data); - self.reader.release(data); - self.row_num = 0; - } - Ok((next_ts, i64::MAX)) + Ok(()) } +} +impl OrderConsumer for L3NoPartialFillExchange +where + AT: AssetType, + LM: 
LatencyModel, + QM: L3QueueModel, + MD: L3MarketDepth, + FM: FeeModel, + BacktestError: From<::Error>, +{ fn process_recv_order( &mut self, timestamp: i64, @@ -455,11 +409,9 @@ where } Ok(false) } - fn earliest_recv_order_timestamp(&self) -> i64 { self.orders_from.earliest_timestamp().unwrap_or(i64::MAX) } - fn earliest_send_order_timestamp(&self) -> i64 { self.orders_to.earliest_timestamp().unwrap_or(i64::MAX) } diff --git a/hftbacktest/src/backtest/proc/local.rs b/hftbacktest/src/backtest/proc/local.rs index 980c30c0..db4b08af 100644 --- a/hftbacktest/src/backtest/proc/local.rs +++ b/hftbacktest/src/backtest/proc/local.rs @@ -1,15 +1,18 @@ use std::{ collections::{hash_map::Entry, HashMap}, mem, + sync::mpsc::Receiver, }; +use bus::BusReader; + use crate::{ backtest::{ assettype::AssetType, - data::{Data, Reader}, + data::{Data, EventBusMessage, EventConsumer, Reader, TimestampedEventQueue}, models::{FeeModel, LatencyModel}, order::OrderBus, - proc::{LocalProcessor, Processor}, + proc::{LocalProcessor, OrderConsumer, Processor}, state::State, BacktestError, }, @@ -34,6 +37,7 @@ use crate::{ LOCAL_TRADE_EVENT, }, }; +use crate::backtest::data::EventBusReader; /// The local model. pub struct Local @@ -43,8 +47,8 @@ where MD: MarketDepth, FM: FeeModel, { - reader: Reader, - data: Data, + reader: EventBusReader, + next: Option, row_num: usize, orders: HashMap, orders_to: OrderBus, @@ -66,7 +70,7 @@ where { /// Constructs an instance of `Local`. 
pub fn new( - reader: Reader, + receiver: EventBusReader, depth: MD, state: State, order_latency: LM, @@ -75,8 +79,8 @@ where orders_from: OrderBus, ) -> Self { Self { - reader, - data: Data::empty(), + reader: receiver, + next: None, row_num: 0, orders: Default::default(), orders_to, @@ -241,27 +245,38 @@ where } } -impl Processor for Local +impl TimestampedEventQueue for Local where AT: AssetType, LM: LatencyModel, MD: MarketDepth + L2MarketDepth, FM: FeeModel, { - fn initialize_data(&mut self) -> Result { - self.data = self.reader.next_data()?; - for rn in 0..self.data.len() { - if self.data[rn].is(LOCAL_EVENT) { - self.row_num = rn; - let tmp = self.data[rn].local_ts; - return Ok(tmp); - } - } - Err(BacktestError::EndOfData) + fn next_event(&mut self) -> Option { + self.reader.next() + } + + fn peek_event(&mut self) -> Option<&Event> { + self.reader.peek() + } + + fn event_time(value: &Event) -> i64 { + value.local_ts + } +} + +impl EventConsumer for Local +where + AT: AssetType, + LM: LatencyModel, + MD: MarketDepth + L2MarketDepth, + FM: FeeModel, +{ + fn is_event_relevant(event: &Event) -> bool { + event.is(LOCAL_EVENT) } - fn process_data(&mut self) -> Result<(i64, i64), BacktestError> { - let ev = &self.data[self.row_num]; + fn process_event(&mut self, ev: Event) -> Result<(), BacktestError> { // Processes a depth event if ev.is(LOCAL_BID_DEPTH_CLEAR_EVENT) { self.depth.clear_depth(Side::Buy, ev.px); @@ -282,28 +297,17 @@ where // Stores the current feed latency self.last_feed_latency = Some((ev.exch_ts, ev.local_ts)); - // Checks - let mut next_ts = 0; - for rn in (self.row_num + 1)..self.data.len() { - if self.data[rn].is(LOCAL_EVENT) { - self.row_num = rn; - next_ts = self.data[rn].local_ts; - break; - } - } - - if next_ts <= 0 { - let next_data = self.reader.next_data()?; - let next_row = &next_data[0]; - next_ts = next_row.local_ts; - let data = mem::replace(&mut self.data, next_data); - self.reader.release(data); - self.row_num = 0; - } - - 
Ok((next_ts, i64::MAX)) + Ok(()) } +} +impl OrderConsumer for Local +where + AT: AssetType, + LM: LatencyModel, + MD: MarketDepth + L2MarketDepth, + FM: FeeModel, +{ fn process_recv_order( &mut self, timestamp: i64, @@ -338,11 +342,9 @@ where } Ok(wait_resp_order_received) } - fn earliest_recv_order_timestamp(&self) -> i64 { self.orders_from.earliest_timestamp().unwrap_or(i64::MAX) } - fn earliest_send_order_timestamp(&self) -> i64 { self.orders_to.earliest_timestamp().unwrap_or(i64::MAX) } diff --git a/hftbacktest/src/backtest/proc/mod.rs b/hftbacktest/src/backtest/proc/mod.rs index cfcc0125..66b6ee44 100644 --- a/hftbacktest/src/backtest/proc/mod.rs +++ b/hftbacktest/src/backtest/proc/mod.rs @@ -91,7 +91,7 @@ where } /// Processes the historical feed data and the order interaction. -pub trait Processor { +pub trait Processor: OrderConsumer { /// Prepares to process the data. This is invoked when the backtesting is initiated. /// If successful, returns the timestamp of the first event. fn initialize_data(&mut self) -> Result; @@ -100,7 +100,9 @@ pub trait Processor { /// event to be processed in the data. /// If successful, returns the timestamp of the next event. fn process_data(&mut self) -> Result<(i64, i64), BacktestError>; +} +pub trait OrderConsumer { /// Processes an order upon receipt. This is invoked when the backtesting time reaches the order /// receipt timestamp. /// Returns Ok(true) if the order with `wait_resp_order_id` is received and processed. @@ -109,10 +111,8 @@ pub trait Processor { timestamp: i64, wait_resp_order_id: Option, ) -> Result; - /// Returns the foremost timestamp at which an order is to be received by this processor. fn earliest_recv_order_timestamp(&self) -> i64; - /// Returns the foremost timestamp at which an order sent by this processor is to be received by /// the corresponding processor. 
fn earliest_send_order_timestamp(&self) -> i64; diff --git a/hftbacktest/src/backtest/proc/nopartialfillexchange.rs b/hftbacktest/src/backtest/proc/nopartialfillexchange.rs index e54455b7..0945a1ea 100644 --- a/hftbacktest/src/backtest/proc/nopartialfillexchange.rs +++ b/hftbacktest/src/backtest/proc/nopartialfillexchange.rs @@ -6,13 +6,15 @@ use std::{ rc::Rc, }; +use bus::BusReader; + use crate::{ backtest::{ assettype::AssetType, - data::{Data, Reader}, + data::{Data, EventBusMessage, EventConsumer, Reader, TimestampedEventQueue}, models::{FeeModel, LatencyModel, QueueModel}, order::OrderBus, - proc::Processor, + proc::{OrderConsumer, Processor}, state::State, BacktestError, }, @@ -36,6 +38,7 @@ use crate::{ EXCH_SELL_TRADE_EVENT, }, }; +use crate::backtest::data::EventBusReader; /// The exchange model without partial fills. /// @@ -70,8 +73,8 @@ where MD: MarketDepth, FM: FeeModel, { - reader: Reader, - data: Data, + reader: EventBusReader, + next: Option, row_num: usize, // key: order_id, value: Order @@ -101,7 +104,7 @@ where { /// Constructs an instance of `NoPartialFillExchange`. 
pub fn new( - reader: Reader, + reader: EventBusReader, depth: MD, state: State, order_latency: LM, @@ -111,7 +114,7 @@ where ) -> Self { Self { reader, - data: Data::empty(), + next: None, row_num: 0, orders: Default::default(), buy_orders: Default::default(), @@ -595,7 +598,7 @@ where } } -impl Processor for NoPartialFillExchange +impl EventConsumer for NoPartialFillExchange where AT: AssetType, LM: LatencyModel, @@ -603,54 +606,36 @@ where MD: MarketDepth + L2MarketDepth, FM: FeeModel, { - fn initialize_data(&mut self) -> Result { - self.data = self.reader.next_data()?; - for rn in 0..self.data.len() { - if self.data[rn].is(EXCH_EVENT) { - self.row_num = rn; - return Ok(self.data[rn].exch_ts); - } - } - Err(BacktestError::EndOfData) + fn is_event_relevant(event: &Event) -> bool { + event.is(EXCH_EVENT) } - fn process_data(&mut self) -> Result<(i64, i64), BacktestError> { - let row_num = self.row_num; - if self.data[row_num].is(EXCH_BID_DEPTH_CLEAR_EVENT) { - self.depth.clear_depth(Side::Buy, self.data[row_num].px); - } else if self.data[row_num].is(EXCH_ASK_DEPTH_CLEAR_EVENT) { - self.depth.clear_depth(Side::Sell, self.data[row_num].px); - } else if self.data[row_num].is(EXCH_DEPTH_CLEAR_EVENT) { + fn process_event(&mut self, event: Event) -> Result<(), BacktestError> { + if event.is(EXCH_BID_DEPTH_CLEAR_EVENT) { + self.depth.clear_depth(Side::Buy, event.px); + } else if event.is(EXCH_ASK_DEPTH_CLEAR_EVENT) { + self.depth.clear_depth(Side::Sell, event.px); + } else if event.is(EXCH_DEPTH_CLEAR_EVENT) { self.depth.clear_depth(Side::None, 0.0); - } else if self.data[row_num].is(EXCH_BID_DEPTH_EVENT) - || self.data[row_num].is(EXCH_BID_DEPTH_SNAPSHOT_EVENT) - { + } else if event.is(EXCH_BID_DEPTH_EVENT) || event.is(EXCH_BID_DEPTH_SNAPSHOT_EVENT) { let (price_tick, prev_best_bid_tick, best_bid_tick, prev_qty, new_qty, timestamp) = - self.depth.update_bid_depth( - self.data[row_num].px, - self.data[row_num].qty, - self.data[row_num].exch_ts, - ); + self.depth + 
.update_bid_depth(event.px, event.qty, event.exch_ts); self.on_bid_qty_chg(price_tick, prev_qty, new_qty); if best_bid_tick > prev_best_bid_tick { self.on_best_bid_update(prev_best_bid_tick, best_bid_tick, timestamp)?; } - } else if self.data[row_num].is(EXCH_ASK_DEPTH_EVENT) - || self.data[row_num].is(EXCH_ASK_DEPTH_SNAPSHOT_EVENT) - { + } else if event.is(EXCH_ASK_DEPTH_EVENT) || event.is(EXCH_ASK_DEPTH_SNAPSHOT_EVENT) { let (price_tick, prev_best_ask_tick, best_ask_tick, prev_qty, new_qty, timestamp) = - self.depth.update_ask_depth( - self.data[row_num].px, - self.data[row_num].qty, - self.data[row_num].exch_ts, - ); + self.depth + .update_ask_depth(event.px, event.qty, event.exch_ts); self.on_ask_qty_chg(price_tick, prev_qty, new_qty); if best_ask_tick < prev_best_ask_tick { self.on_best_ask_update(prev_best_ask_tick, best_ask_tick, timestamp)?; } - } else if self.data[row_num].is(EXCH_BUY_TRADE_EVENT) { - let price_tick = (self.data[row_num].px / self.depth.tick_size()).round() as i64; - let qty = self.data[row_num].qty; + } else if event.is(EXCH_BUY_TRADE_EVENT) { + let price_tick = (event.px / self.depth.tick_size()).round() as i64; + let qty = event.qty; { let orders = self.orders.clone(); let mut orders_borrowed = orders.borrow_mut(); @@ -659,12 +644,7 @@ where { for (_, order) in orders_borrowed.iter_mut() { if order.side == Side::Sell { - self.check_if_sell_filled( - order, - price_tick, - qty, - self.data[row_num].exch_ts, - )?; + self.check_if_sell_filled(order, price_tick, qty, event.exch_ts)?; } } } else { @@ -672,21 +652,16 @@ where if let Some(order_ids) = self.sell_orders.get(&t) { for order_id in order_ids.clone().iter() { let order = orders_borrowed.get_mut(order_id).unwrap(); - self.check_if_sell_filled( - order, - price_tick, - qty, - self.data[row_num].exch_ts, - )?; + self.check_if_sell_filled(order, price_tick, qty, event.exch_ts)?; } } } } } self.remove_filled_orders(); - } else if self.data[row_num].is(EXCH_SELL_TRADE_EVENT) { - let 
price_tick = (self.data[row_num].px / self.depth.tick_size()).round() as i64; - let qty = self.data[row_num].qty; + } else if event.is(EXCH_SELL_TRADE_EVENT) { + let price_tick = (event.px / self.depth.tick_size()).round() as i64; + let qty = event.qty; { let orders = self.orders.clone(); let mut orders_borrowed = orders.borrow_mut(); @@ -695,12 +670,7 @@ where { for (_, order) in orders_borrowed.iter_mut() { if order.side == Side::Buy { - self.check_if_buy_filled( - order, - price_tick, - qty, - self.data[row_num].exch_ts, - )?; + self.check_if_buy_filled(order, price_tick, qty, event.exch_ts)?; } } } else { @@ -708,12 +678,7 @@ where if let Some(order_ids) = self.buy_orders.get(&t) { for order_id in order_ids.clone().iter() { let order = orders_borrowed.get_mut(order_id).unwrap(); - self.check_if_buy_filled( - order, - price_tick, - qty, - self.data[row_num].exch_ts, - )?; + self.check_if_buy_filled(order, price_tick, qty, event.exch_ts)?; } } } @@ -722,27 +687,67 @@ where self.remove_filled_orders(); } - // Checks - let mut next_ts = 0; - for rn in (self.row_num + 1)..self.data.len() { - if self.data[rn].is(EXCH_EVENT) { - self.row_num = rn; - next_ts = self.data[rn].exch_ts; - break; + Ok(()) + } +} + +impl Processor for ExchT +where + ExchT: OrderConsumer + EventConsumer + TimestampedEventQueue, +{ + fn initialize_data(&mut self) -> Result { + while let Some(event) = self.peek_event() { + if Self::is_event_relevant(event) { + let ts = Self::event_time(event); + return Ok(ts); } - } - if next_ts <= 0 { - let next_data = self.reader.next_data()?; - let next_row = &next_data[0]; - next_ts = next_row.exch_ts; - let data = mem::replace(&mut self.data, next_data); - self.reader.release(data); - self.row_num = 0; + // Consume the peeked event. 
+ let _ = self.next_event(); } + + Err(BacktestError::EndOfData) + } + + fn process_data(&mut self) -> Result<(i64, i64), BacktestError> { + let current = self.next_event().ok_or(BacktestError::EndOfData)?; + self.process_event(current)?; + let next = self.peek_event().ok_or(BacktestError::EndOfData)?; + let next_ts = Self::event_time(&next); + Ok((next_ts, i64::MAX)) } +} +impl TimestampedEventQueue for NoPartialFillExchange +where + AT: AssetType, + LM: LatencyModel, + QM: QueueModel, + MD: MarketDepth + L2MarketDepth, + FM: FeeModel, +{ + fn next_event(&mut self) -> Option { + self.reader.next() + } + + fn peek_event(&mut self) -> Option<&Event> { + self.reader.peek() + } + + fn event_time(value: &Event) -> i64 { + value.exch_ts + } +} + +impl OrderConsumer for NoPartialFillExchange +where + AT: AssetType, + LM: LatencyModel, + QM: QueueModel, + MD: MarketDepth + L2MarketDepth, + FM: FeeModel, +{ fn process_recv_order( &mut self, timestamp: i64, @@ -761,11 +766,9 @@ where } Ok(false) } - fn earliest_recv_order_timestamp(&self) -> i64 { self.orders_from.earliest_timestamp().unwrap_or(i64::MAX) } - fn earliest_send_order_timestamp(&self) -> i64 { self.orders_to.earliest_timestamp().unwrap_or(i64::MAX) } diff --git a/hftbacktest/src/backtest/proc/partialfillexchange.rs b/hftbacktest/src/backtest/proc/partialfillexchange.rs index 6b834c56..f0ca00b1 100644 --- a/hftbacktest/src/backtest/proc/partialfillexchange.rs +++ b/hftbacktest/src/backtest/proc/partialfillexchange.rs @@ -6,13 +6,15 @@ use std::{ rc::Rc, }; +use bus::BusReader; + use crate::{ backtest::{ assettype::AssetType, - data::{Data, Reader}, + data::{Data, EventBusMessage, EventConsumer, Reader, TimestampedEventQueue}, models::{FeeModel, LatencyModel, QueueModel}, order::OrderBus, - proc::Processor, + proc::{OrderConsumer, Processor}, state::State, BacktestError, }, @@ -36,6 +38,7 @@ use crate::{ EXCH_SELL_TRADE_EVENT, }, }; +use crate::backtest::data::EventBusReader; /// The exchange model with partial 
fills. /// @@ -84,8 +87,8 @@ where MD: MarketDepth, FM: FeeModel, { - reader: Reader, - data: Data, + reader: EventBusReader, + next: Option, row_num: usize, // key: order_id, value: Order @@ -115,7 +118,7 @@ where { /// Constructs an instance of `PartialFillExchange`. pub fn new( - reader: Reader, + reader: EventBusReader, depth: MD, state: State, order_latency: LM, @@ -125,7 +128,7 @@ where ) -> Self { Self { reader, - data: Data::empty(), + next: None, row_num: 0, orders: Default::default(), buy_orders: Default::default(), @@ -779,7 +782,7 @@ where } } -impl Processor for PartialFillExchange +impl TimestampedEventQueue for PartialFillExchange where AT: AssetType, LM: LatencyModel, @@ -787,54 +790,57 @@ where MD: MarketDepth + L2MarketDepth, FM: FeeModel, { - fn initialize_data(&mut self) -> Result { - self.data = self.reader.next_data()?; - for rn in 0..self.data.len() { - if self.data[rn].is(EXCH_EVENT) { - self.row_num = rn; - return Ok(self.data[rn].exch_ts); - } - } - Err(BacktestError::EndOfData) + fn next_event(&mut self) -> Option { + self.reader.next() + } + + fn peek_event(&mut self) -> Option<&Event> { + self.reader.peek() + } + + fn event_time(value: &Event) -> i64 { + value.exch_ts + } +} + +impl EventConsumer for PartialFillExchange +where + AT: AssetType, + LM: LatencyModel, + QM: QueueModel, + MD: MarketDepth + L2MarketDepth, + FM: FeeModel, +{ + fn is_event_relevant(event: &Event) -> bool { + event.is(EXCH_EVENT) } - fn process_data(&mut self) -> Result<(i64, i64), BacktestError> { - let row_num = self.row_num; - if self.data[row_num].is(EXCH_BID_DEPTH_CLEAR_EVENT) { - self.depth.clear_depth(Side::Buy, self.data[row_num].px); - } else if self.data[row_num].is(EXCH_ASK_DEPTH_CLEAR_EVENT) { - self.depth.clear_depth(Side::Sell, self.data[row_num].px); - } else if self.data[row_num].is(EXCH_DEPTH_CLEAR_EVENT) { + fn process_event(&mut self, event: Event) -> Result<(), BacktestError> { + if event.is(EXCH_BID_DEPTH_CLEAR_EVENT) { + 
self.depth.clear_depth(Side::Buy, event.px); + } else if event.is(EXCH_ASK_DEPTH_CLEAR_EVENT) { + self.depth.clear_depth(Side::Sell, event.px); + } else if event.is(EXCH_DEPTH_CLEAR_EVENT) { self.depth.clear_depth(Side::None, 0.0); - } else if self.data[row_num].is(EXCH_BID_DEPTH_EVENT) - || self.data[row_num].is(EXCH_BID_DEPTH_SNAPSHOT_EVENT) - { + } else if event.is(EXCH_BID_DEPTH_EVENT) || event.is(EXCH_BID_DEPTH_SNAPSHOT_EVENT) { let (price_tick, prev_best_bid_tick, best_bid_tick, prev_qty, new_qty, timestamp) = - self.depth.update_bid_depth( - self.data[row_num].px, - self.data[row_num].qty, - self.data[row_num].exch_ts, - ); + self.depth + .update_bid_depth(event.px, event.qty, event.exch_ts); self.on_bid_qty_chg(price_tick, prev_qty, new_qty); if best_bid_tick > prev_best_bid_tick { self.on_best_bid_update(prev_best_bid_tick, best_bid_tick, timestamp)?; } - } else if self.data[row_num].is(EXCH_ASK_DEPTH_EVENT) - || self.data[row_num].is(EXCH_ASK_DEPTH_SNAPSHOT_EVENT) - { + } else if event.is(EXCH_ASK_DEPTH_EVENT) || event.is(EXCH_ASK_DEPTH_SNAPSHOT_EVENT) { let (price_tick, prev_best_ask_tick, best_ask_tick, prev_qty, new_qty, timestamp) = - self.depth.update_ask_depth( - self.data[row_num].px, - self.data[row_num].qty, - self.data[row_num].exch_ts, - ); + self.depth + .update_ask_depth(event.px, event.qty, event.exch_ts); self.on_ask_qty_chg(price_tick, prev_qty, new_qty); if best_ask_tick < prev_best_ask_tick { self.on_best_ask_update(prev_best_ask_tick, best_ask_tick, timestamp)?; } - } else if self.data[row_num].is(EXCH_BUY_TRADE_EVENT) { - let price_tick = (self.data[row_num].px / self.depth.tick_size()).round() as i64; - let qty = self.data[row_num].qty; + } else if event.is(EXCH_BUY_TRADE_EVENT) { + let price_tick = (event.px / self.depth.tick_size()).round() as i64; + let qty = event.qty; { let orders = self.orders.clone(); let mut orders_borrowed = orders.borrow_mut(); @@ -843,12 +849,7 @@ where { for (_, order) in orders_borrowed.iter_mut() { if 
order.side == Side::Sell { - self.check_if_sell_filled( - order, - price_tick, - qty, - self.data[row_num].exch_ts, - )?; + self.check_if_sell_filled(order, price_tick, qty, event.exch_ts)?; } } } else { @@ -856,21 +857,16 @@ where if let Some(order_ids) = self.sell_orders.get(&t) { for order_id in order_ids.clone().iter() { let order = orders_borrowed.get_mut(order_id).unwrap(); - self.check_if_sell_filled( - order, - price_tick, - qty, - self.data[row_num].exch_ts, - )?; + self.check_if_sell_filled(order, price_tick, qty, event.exch_ts)?; } } } } } self.remove_filled_orders(); - } else if self.data[row_num].is(EXCH_SELL_TRADE_EVENT) { - let price_tick = (self.data[row_num].px / self.depth.tick_size()).round() as i64; - let qty = self.data[row_num].qty; + } else if event.is(EXCH_SELL_TRADE_EVENT) { + let price_tick = (event.px / self.depth.tick_size()).round() as i64; + let qty = event.qty; { let orders = self.orders.clone(); let mut orders_borrowed = orders.borrow_mut(); @@ -879,12 +875,7 @@ where { for (_, order) in orders_borrowed.iter_mut() { if order.side == Side::Buy { - self.check_if_buy_filled( - order, - price_tick, - qty, - self.data[row_num].exch_ts, - )?; + self.check_if_buy_filled(order, price_tick, qty, event.exch_ts)?; } } } else { @@ -892,12 +883,7 @@ where if let Some(order_ids) = self.buy_orders.get(&t) { for order_id in order_ids.clone().iter() { let order = orders_borrowed.get_mut(order_id).unwrap(); - self.check_if_buy_filled( - order, - price_tick, - qty, - self.data[row_num].exch_ts, - )?; + self.check_if_buy_filled(order, price_tick, qty, event.exch_ts)?; } } } @@ -906,27 +892,18 @@ where self.remove_filled_orders(); } - // Checks - let mut next_ts = 0; - for rn in (self.row_num + 1)..self.data.len() { - if self.data[rn].is(EXCH_EVENT) { - self.row_num = rn; - next_ts = self.data[rn].exch_ts; - break; - } - } - - if next_ts <= 0 { - let next_data = self.reader.next_data()?; - let next_row = &next_data[0]; - next_ts = next_row.exch_ts; - let 
data = mem::replace(&mut self.data, next_data); - self.reader.release(data); - self.row_num = 0; - } - Ok((next_ts, i64::MAX)) + Ok(()) } +} +impl OrderConsumer for PartialFillExchange +where + AT: AssetType, + LM: LatencyModel, + QM: QueueModel, + MD: MarketDepth + L2MarketDepth, + FM: FeeModel, +{ fn process_recv_order( &mut self, timestamp: i64, @@ -945,11 +922,9 @@ where } Ok(false) } - fn earliest_recv_order_timestamp(&self) -> i64 { self.orders_from.earliest_timestamp().unwrap_or(i64::MAX) } - fn earliest_send_order_timestamp(&self) -> i64 { self.orders_to.earliest_timestamp().unwrap_or(i64::MAX) } From 018d0fe80a990955b8f178ab5c2600035b5ea817 Mon Sep 17 00:00:00 2001 From: Gary Tierney Date: Fri, 23 Aug 2024 12:51:20 +0100 Subject: [PATCH 2/3] Support incremental reads of .npz data --- hftbacktest/src/backtest/data/bus.rs | 68 ++++++----- hftbacktest/src/backtest/data/npy/mod.rs | 140 +++++++++++++++++++++++ hftbacktest/src/backtest/data/queue.rs | 0 3 files changed, 181 insertions(+), 27 deletions(-) create mode 100644 hftbacktest/src/backtest/data/queue.rs diff --git a/hftbacktest/src/backtest/data/bus.rs b/hftbacktest/src/backtest/data/bus.rs index 2e60d26e..ad3657b0 100644 --- a/hftbacktest/src/backtest/data/bus.rs +++ b/hftbacktest/src/backtest/data/bus.rs @@ -1,11 +1,15 @@ -use std::{io, io::ErrorKind}; -use std::iter::Peekable; +use std::{fs::File, io, io::ErrorKind, iter::Peekable, num::NonZeroUsize}; + use bus::{Bus, BusIntoIter, BusReader}; use tracing::{error, info, info_span}; - -use crate::backtest::{ - data::{read_npy_file, read_npz_file, Data, NpyDTyped}, - BacktestError, +use zip::ZipArchive; + +use crate::{ + backtest::{ + data::{npy::NpyReader, read_npy_file, read_npz_file, Data, NpyDTyped}, + BacktestError, + }, + types::Event, }; #[derive(Copy, Clone)] @@ -21,7 +25,7 @@ pub struct EventBusReader { impl EventBusReader { pub fn new(reader: BusReader>) -> Self { Self { - reader: reader.into_iter().peekable() + reader: 
reader.into_iter().peekable(), } } @@ -71,6 +75,35 @@ fn load_data( Ok(data) } +#[tracing::instrument(skip(bus))] +pub fn replay_event_file( + path: String, + bus: &mut Bus>, +) -> std::io::Result<()> { + if !path.ends_with(".npz") { + todo!("Only .npz is supported in this branch") + } + + let mut archive = ZipArchive::new(File::open(path)?)?; + let mut reader = NpyReader::<_, EventT>::new( + archive.by_name("data.npy")?, + NonZeroUsize::new(512).unwrap(), + )?; + + loop { + let read = reader.read(|event| { + bus.broadcast(EventBusMessage::Item(event.clone())); + })?; + + // EOF + if read == 0 { + break; + } + } + + Ok(()) +} + #[tracing::instrument(skip_all)] pub fn replay_events_to_bus( mut bus: Bus>, @@ -80,26 +113,7 @@ pub fn replay_events_to_bus( let source_load_span = info_span!("load_data", source = &source); let _source_load_span = source_load_span.entered(); - let data = load_data::(source); - - match data { - Ok(data) => { - info!( - records = data.len(), - "found {} events in data source", - data.len() - ); - - for row in 0..data.len() { - bus.broadcast(EventBusMessage::Item(data[row].clone())); - } - } - Err(e) => { - error!("encountered error loading data source: {}", e); - // TODO: handle as an error. - break; - } - } + replay_event_file(source, &mut bus).unwrap(); } bus.broadcast(EventBusMessage::EndOfData); diff --git a/hftbacktest/src/backtest/data/npy/mod.rs b/hftbacktest/src/backtest/data/npy/mod.rs index 36c17768..76fa309d 100644 --- a/hftbacktest/src/backtest/data/npy/mod.rs +++ b/hftbacktest/src/backtest/data/npy/mod.rs @@ -1,6 +1,9 @@ use std::{ + alloc::{alloc, dealloc, Layout}, fs::File, io::{Error, ErrorKind, Read, Write}, + marker::PhantomData, + num::NonZeroUsize, }; use crate::backtest::data::{npy::parser::Value, Data, DataPtr, POD}; @@ -162,6 +165,143 @@ fn check_field_consistency( Ok(discrepancies) } +pub struct NpyReader { + reader: R, + + /// Input buffer aligned to [T]. + buffer: *mut u8, + + /// Current buffer position in bytes. 
+ buffer_pos: usize, + + /// Number of bytes available in the buffer for reading. + buffer_filled: usize, + + /// Maximum number of bytes the buffer of this reader can hold. + buffer_capacity: usize, + + phantom_data: PhantomData, +} + +impl Drop for NpyReader { + fn drop(&mut self) { + unsafe { + dealloc( + self.buffer, + Layout::from_size_align_unchecked(self.buffer_capacity, align_of::()), + ) + } + } +} + +impl NpyReader { + pub fn new(mut reader: R, buffer_size: NonZeroUsize) -> std::io::Result { + let header = read_npy_header::(&mut reader)?; + + if T::descr() != header.descr { + match check_field_consistency(&T::descr(), &header.descr) { + Ok(diff) => { + println!("Warning: Field name mismatch - {:?}", diff); + } + Err(err) => { + return Err(Error::new(ErrorKind::InvalidData, err)); + } + } + } + + let buffer_capacity = buffer_size.get() * size_of::(); + let buffer = unsafe { + alloc(Layout::from_size_align_unchecked( + buffer_capacity, + align_of::(), + )) + }; + + Ok(Self { + buffer, + buffer_pos: 0, + buffer_filled: 0, + buffer_capacity, + reader, + phantom_data: Default::default(), + }) + } + + pub fn read(&mut self, mut collector: impl FnMut(&T)) -> std::io::Result { + if self.buffer_pos == self.buffer_capacity { + self.buffer_pos = 0; + self.buffer_filled = 0; + } + + let io_buf = unsafe { std::slice::from_raw_parts_mut(self.buffer, self.buffer_capacity) }; + let io_buf_cursor = &mut io_buf[self.buffer_pos..]; + + let io_buf_unconsumed = self.buffer_filled - self.buffer_pos; + let bytes_read = self.reader.read(&mut io_buf_cursor[io_buf_unconsumed..])?; + let items_read = (io_buf_unconsumed + bytes_read) / size_of::(); + let io_buf_consumed = items_read * size_of::(); + + let item_buf: &[T] = unsafe { + std::slice::from_raw_parts( + self.buffer.offset(self.buffer_pos as isize).cast(), + items_read, + ) + }; + + for item in item_buf { + collector(item); + } + + self.buffer_filled += bytes_read; + self.buffer_pos += io_buf_consumed; + + Ok(items_read) + } 
+} + +pub fn read_npy_header( + reader: &mut R, +) -> std::io::Result { + let mut buf = Vec::with_capacity(10); + let mut magic = reader.take(10); + magic.read_to_end(&mut buf)?; + + if buf[0..6].to_vec() != b"\x93NUMPY" { + return Err(Error::new( + ErrorKind::InvalidData, + "must start with \\x93NUMPY", + )); + } + if buf[6..8].to_vec() != b"\x01\x00" { + return Err(Error::new( + ErrorKind::InvalidData, + "support only version 1.0", + )); + } + + let header_len = u16::from_le_bytes(buf[8..10].try_into().unwrap()) as usize; + let reader = magic.into_inner(); + + reader.take(header_len as u64).read_to_end(&mut buf)?; + + let header = String::from_utf8(buf[10..(10 + header_len)].to_vec()) + .map_err(|err| Error::new(ErrorKind::InvalidData, err.to_string()))?; + let header = NpyHeader::from_header(&header)?; + + if header.fortran_order { + return Err(Error::new( + ErrorKind::InvalidData, + "fortran order is unsupported", + )); + } + + if header.shape.len() != 1 { + return Err(Error::new(ErrorKind::InvalidData, "only 1-d is supported")); + } + + Ok(header) +} + pub fn read_npy( reader: &mut R, size: usize, diff --git a/hftbacktest/src/backtest/data/queue.rs b/hftbacktest/src/backtest/data/queue.rs new file mode 100644 index 00000000..e69de29b From caf8d5e9bf53907caee93f39397ba7bc5a7a67a7 Mon Sep 17 00:00:00 2001 From: Gary Tierney Date: Fri, 23 Aug 2024 21:19:35 +0100 Subject: [PATCH 3/3] Fix unsoundness issues with NpyReader buffer --- hftbacktest/src/backtest/data/npy/mod.rs | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/hftbacktest/src/backtest/data/npy/mod.rs b/hftbacktest/src/backtest/data/npy/mod.rs index 76fa309d..74374d5c 100644 --- a/hftbacktest/src/backtest/data/npy/mod.rs +++ b/hftbacktest/src/backtest/data/npy/mod.rs @@ -181,6 +181,8 @@ pub struct NpyReader { buffer_capacity: usize, phantom_data: PhantomData, + + layout: Layout, } impl Drop for NpyReader { @@ -188,7 +190,7 @@ impl Drop for NpyReader { unsafe { 
dealloc( self.buffer, - Layout::from_size_align_unchecked(self.buffer_capacity, align_of::()), + self.layout, ) } } @@ -209,19 +211,20 @@ impl NpyReader { } } + let layout = Layout::array::(buffer_size.get()).map_err(|_| Error::other("Buffer size is too large"))?; let buffer_capacity = buffer_size.get() * size_of::(); - let buffer = unsafe { - alloc(Layout::from_size_align_unchecked( - buffer_capacity, - align_of::(), - )) - }; + let buffer = unsafe { alloc(layout) }; + if buffer.is_null() { + return Err(std::io::Error::new(ErrorKind::OutOfMemory, "unable to allocate buffer")) + } + Ok(Self { buffer, buffer_pos: 0, buffer_filled: 0, buffer_capacity, + layout, reader, phantom_data: Default::default(), }) @@ -259,9 +262,7 @@ impl NpyReader { } } -pub fn read_npy_header( - reader: &mut R, -) -> std::io::Result { +pub fn read_npy_header(reader: &mut R) -> std::io::Result { let mut buf = Vec::with_capacity(10); let mut magic = reader.take(10); magic.read_to_end(&mut buf)?;