From 8ddf5ed9b4af0cd2e6ea140cdfe2d4c563dbd5d4 Mon Sep 17 00:00:00 2001 From: Shanin Roman Date: Wed, 31 Jan 2024 14:48:26 +0300 Subject: [PATCH] fix: use T aligned pointer in TempFdArray Signed-off-by: Shanin Roman --- src/collector.rs | 60 +++++++++++++++++++++++++++++++++++------------- src/report.rs | 6 +++-- 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/src/collector.rs b/src/collector.rs index c05fb18b..db9a9812 100644 --- a/src/collector.rs +++ b/src/collector.rs @@ -5,6 +5,7 @@ use std::convert::TryInto; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::io::{Read, Seek, SeekFrom, Write}; +use std::mem::ManuallyDrop; use crate::frames::UnresolvedFrames; @@ -148,6 +149,7 @@ pub struct TempFdArray { file: NamedTempFile, buffer: Box<[T; BUFFER_LENGTH]>, buffer_index: usize, + flush_n: usize, } impl TempFdArray { @@ -162,6 +164,7 @@ impl TempFdArray { file, buffer, buffer_index: 0, + flush_n: 0, }) } } @@ -175,6 +178,7 @@ impl TempFdArray { BUFFER_LENGTH * std::mem::size_of::(), ) }; + self.flush_n += 1; self.file.write_all(buf)?; Ok(()) @@ -191,24 +195,50 @@ impl TempFdArray { Ok(()) } - fn try_iter(&self) -> std::io::Result> { - let mut file_vec = Vec::new(); - let mut file = self.file.reopen()?; - file.seek(SeekFrom::Start(0))?; - file.read_to_end(&mut file_vec)?; - file.seek(SeekFrom::End(0))?; + fn try_iter<'lt>(&'lt self, file_buffer_container: &'lt mut Option]>>) -> std::io::Result> { + let file_buffer = self.file_buffer()?; + let file_buffer = file_buffer_container.insert(file_buffer); Ok(TempFdArrayIterator { buffer: &self.buffer[0..self.buffer_index], - file_vec, + file_buffer, index: 0, }) } + + fn file_buffer(&self) -> std::io::Result]>> { + if self.flush_n == 0 { + return Ok(Vec::new().into_boxed_slice()) + } + + let mut file = self.file.reopen()?; + file.seek(SeekFrom::Start(0))?; + let file_buffer = unsafe { + // Get properly aligned pointer + let len = BUFFER_LENGTH * self.flush_n; + // Expect T to be non-ZST + let layout = std::alloc::Layout::array::>(len).unwrap(); + let ptr = std::alloc::alloc(layout); + if ptr.is_null() { + std::alloc::handle_alloc_error(layout); + } + // Populate with bytes + file.read_exact(std::slice::from_raw_parts_mut( + ptr, + len * std::mem::size_of::(), + ))?; + // Cast to proper type + Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr.cast::>(), len)) + }; + file.seek(SeekFrom::End(0))?; + + Ok(file_buffer) + } } pub struct TempFdArrayIterator<'a, T> { pub buffer: &'a [T], - pub file_vec: Vec, + pub file_buffer: &'a [ManuallyDrop], pub index: usize, } @@ -220,12 +250,9 @@ impl<'a, T> Iterator for TempFdArrayIterator<'a, T> { self.index += 1; Some(&self.buffer[self.index - 1]) } else { - let length = self.file_vec.len() / std::mem::size_of::(); - let ts = - unsafe { std::slice::from_raw_parts(self.file_vec.as_ptr() as *const T, length) }; - if self.index - self.buffer.len() < ts.len() { + if self.index - self.buffer.len() < self.file_buffer.len() { self.index += 1; - Some(&ts[self.index - self.buffer.len() - 1]) + Some(&self.file_buffer[self.index - self.buffer.len() - 1]) } else { None } @@ -256,8 +283,8 @@ impl Collector { Ok(()) } - pub fn try_iter(&self) -> std::io::Result>> { - Ok(self.map.iter().chain(self.temp_array.try_iter()?)) + pub fn try_iter<'lt>(&'lt self, file_buffer_store: &'lt mut Option>]>>) -> std::io::Result>> { + Ok(self.map.iter().chain(self.temp_array.try_iter(file_buffer_store)?)) } } @@ -343,7 +370,8 @@ mod tests { } } - collector.try_iter().unwrap().for_each(|entry| { + let mut file_buffer_store = None; + collector.try_iter(&mut file_buffer_store).unwrap().for_each(|entry| { test_utils::add_map(&mut real_map, entry); }); diff --git a/src/report.rs b/src/report.rs index 971cb8d6..82b3ec4d 100644 --- a/src/report.rs +++ b/src/report.rs @@ -69,7 +69,8 @@ impl<'a> ReportBuilder<'a> { Err(Error::CreatingError) } Ok(profiler) => { - profiler.data.try_iter()?.for_each(|entry| { + let mut file_buffer_store = None; + profiler.data.try_iter(&mut file_buffer_store)?.for_each(|entry| { let count = entry.count; if count > 0 { let key = &entry.item; @@ -107,7 +108,8 @@ impl<'a> ReportBuilder<'a> { Err(Error::CreatingError) } Ok(profiler) => { - profiler.data.try_iter()?.for_each(|entry| { + let mut file_buffer_store = None; + profiler.data.try_iter(&mut file_buffer_store)?.for_each(|entry| { let count = entry.count; if count > 0 { let mut key = Frames::from(entry.item.clone());