diff --git a/src/lib.rs b/src/lib.rs index 4953b1e..cec533e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ mod error; mod util; +use std::collections::hash_map::Entry; use std::ffi::OsString; use std::io::{Read, Write}; use std::mem::size_of; @@ -18,10 +19,13 @@ use std::os::windows::ffi::OsStringExt; use std::path::PathBuf; use std::ptr::{addr_of, addr_of_mut}; use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; +use std::sync::{Arc, Mutex, Weak}; use object::Object; use pdb_addr2line::pdb::PDB; use pdb_addr2line::ContextPdbData; +use rustc_hash::FxHashMap; use windows::core::{GUID, PCSTR, PSTR}; use windows::Win32::Foundation::{ CloseHandle, ERROR_SUCCESS, ERROR_WMI_INSTANCE_NOT_FOUND, HANDLE, INVALID_HANDLE_VALUE, @@ -48,17 +52,14 @@ pub use crate::error::{Error, Result}; // const MAX_STACK_DEPTH: usize = 192; const MAX_STACK_DEPTH: usize = 200; -/// map[array_of_stacktrace_addrs] = sample_count -type StackMap = rustc_hash::FxHashMap<[u64; MAX_STACK_DEPTH], u64>; - -/// Stateful context provided to `event_record_callback`. +/// Context for a specific process's trace. struct TraceContext { target_process_handle: HANDLE, - stack_counts_hashmap: StackMap, target_proc_pid: u32, - trace_running: AtomicBool, show_kernel_samples: bool, + /// map[array_of_stacktrace_addrs] = sample_count + stack_counts_hashmap: FxHashMap<[u64; MAX_STACK_DEPTH], u64>, /// (image_path, image_base, image_size) image_paths: Vec<(OsString, u64, u64)>, } @@ -78,7 +79,6 @@ impl TraceContext { target_process_handle, stack_counts_hashmap: Default::default(), target_proc_pid, - trace_running: AtomicBool::new(false), show_kernel_samples: std::env::var("BLONDIE_KERNEL") .map(|value| { let upper = value.to_uppercase(); @@ -101,19 +101,112 @@ impl Drop for TraceContext { } } -/// The main tracing logic. Traces the process with the given `target_process_id`. -/// -/// # Safety +/// Stateful context provided to `event_record_callback`, containing multiple [`TraceContext`]s. +struct Context { + /// Keys are process IDs. `Weak` deallocates when tracing should stop. + traces: FxHashMap>, + /// Receive new processes to subscribe to tracing. + subscribe_recv: Receiver>, + /// Set to true once the trace starts running. + trace_running: Arc, +} +impl Context { + /// SAFETY: May only be called by `event_record_callback`, while tracing. + unsafe fn get_trace_context(&mut self, pid: u32) -> Option> { + // TODO: handle PID reuse??? + let Entry::Occupied(entry) = self.traces.entry(pid) else { + // Ignore dlls for other processes + return None; + }; + if let Some(trace_context) = Weak::upgrade(entry.get()) { + Some(trace_context) + } else { + // Tracing just stopped, remove deallocated `Weak`. + entry.remove(); + return None; + } + } +} + +/// Global tracing session. /// -/// `is_suspended` may only be true if `target_process` is suspended -unsafe fn trace_from_process_id( - target_process_id: u32, - is_suspended: bool, - kernel_stacks: bool, -) -> Result { +/// When this is dropped, the tracing session will be stopped. +struct Session { + /// Send new processes to subscribe to tracing. + subscribe_send: SyncSender>, + /// Box allocation for [`UserContext`]. + context: *mut Context, + /// Box allocation for event trace props, need deallocation after. + event_trace_props: *mut EVENT_TRACE_PROPERTIES_WITH_STRING, + /// Box allocation for Logfile. + log: *mut EVENT_TRACE_LOGFILEA, +} +unsafe impl Send for Session {} +unsafe impl Sync for Session {} + +impl Session { + fn start(self: &Arc, trace_context: TraceContext) -> TraceGuard { + let trace_context = Arc::new(trace_context); + self.subscribe_send + .send(Arc::downgrade(&trace_context)) + .unwrap(); + TraceGuard { + trace_context, + _session: Arc::clone(&self), + } + } +} + +impl Drop for Session { + fn drop(&mut self) { + let ret = unsafe { + // This unblocks ProcessTrace + ControlTraceA( + ::default(), + KERNEL_LOGGER_NAMEA, + self.event_trace_props.cast(), + EVENT_TRACE_CONTROL_STOP, + ) + }; + unsafe { + drop(Box::from_raw(self.context)); + drop(Box::from_raw(self.event_trace_props)); + drop(Box::from_raw(self.log)); + } + if ret != ERROR_SUCCESS { + eprintln!( + "Error dropping GlobalContext: {:?}", + get_last_error("ControlTraceA STOP ProcessTrace") + ); + } + } +} + +struct TraceGuard { + trace_context: Arc, + /// Ensure session stays alive while `TraceGuard` is alive. + _session: Arc, +} +impl TraceGuard { + fn stop(self) -> TraceContext { + Arc::try_unwrap(self.trace_context) + .map_err(drop) + .expect("TraceContext Arc count should never have been incremented.") + } +} + +/// Gets the global context. Begins tracing if not already running. +fn get_global_context() -> Result> { + static GLOBAL_CONTEXT: Mutex> = Mutex::new(Weak::new()); + + let mut unlocked = GLOBAL_CONTEXT.lock().unwrap(); + if let Some(global_context) = unlocked.upgrade() { + return Ok(global_context); + } + let mut winver_info = OSVERSIONINFOA::default(); winver_info.dwOSVersionInfoSize = size_of::() as u32; - let ret = GetVersionExA(&mut winver_info); + let ret = unsafe { GetVersionExA(addr_of_mut!(winver_info)) }; if ret.0 == 0 { return Err(get_last_error("GetVersionExA")); } @@ -134,19 +227,21 @@ unsafe fn trace_from_process_id( let mut interval = TRACE_PROFILE_INTERVAL::default(); // TODO: Parameter? interval.Interval = (1000000000 / 8000) / 100; - let ret = TraceSetInformation( - None, - // The value is supported on Windows 8, Windows Server 2012, and later. - TraceSampledProfileIntervalInfo, - addr_of!(interval).cast(), - size_of::() as u32, - ); + let ret = unsafe { + TraceSetInformation( + None, + // The value is supported on Windows 8, Windows Server 2012, and later. + TraceSampledProfileIntervalInfo, + addr_of!(interval).cast(), + size_of::() as u32, + ) + }; if ret != ERROR_SUCCESS { return Err(get_last_error("TraceSetInformation interval")); } } - let mut kernel_logger_name_with_nul = KERNEL_LOGGER_NAMEA.as_bytes().to_vec(); + let mut kernel_logger_name_with_nul = unsafe { KERNEL_LOGGER_NAMEA.as_bytes() }.to_vec(); kernel_logger_name_with_nul.push(b'\0'); // Build the trace properties, we want EVENT_TRACE_FLAG_PROFILE for the "SampledProfile" event // https://docs.microsoft.com/en-us/windows/win32/etw/sampledprofile @@ -158,27 +253,11 @@ unsafe fn trace_from_process_id( // Events are delivered when the buffers are flushed (https://docs.microsoft.com/en-us/windows/win32/etw/logging-mode-constants) // We also use Image_Load events to know which dlls to load debug information from for symbol resolution // Which is enabled by the EVENT_TRACE_FLAG_IMAGE_LOAD flag - const KERNEL_LOGGER_NAMEA_LEN: usize = unsafe { - let mut ptr = KERNEL_LOGGER_NAMEA.0; - let mut len = 0; - while *ptr != 0 { - len += 1; - ptr = ptr.add(1); - } - len - }; const PROPS_SIZE: usize = size_of::() + KERNEL_LOGGER_NAMEA_LEN + 1; - #[derive(Clone)] - #[repr(C)] - #[allow(non_camel_case_types)] - struct EVENT_TRACE_PROPERTIES_WITH_STRING { - data: EVENT_TRACE_PROPERTIES, - s: [u8; KERNEL_LOGGER_NAMEA_LEN + 1], - } - let mut event_trace_props = EVENT_TRACE_PROPERTIES_WITH_STRING { + let mut event_trace_props = Box::new(EVENT_TRACE_PROPERTIES_WITH_STRING { data: EVENT_TRACE_PROPERTIES::default(), s: [0u8; KERNEL_LOGGER_NAMEA_LEN + 1], - }; + }); event_trace_props.data.EnableFlags = EVENT_TRACE_FLAG_PROFILE | EVENT_TRACE_FLAG_IMAGE_LOAD; event_trace_props.data.LogFileMode = EVENT_TRACE_REAL_TIME_MODE; event_trace_props.data.Wnode.BufferSize = PROPS_SIZE as u32; @@ -195,17 +274,19 @@ unsafe fn trace_from_process_id( .s .copy_from_slice(&kernel_logger_name_with_nul[..]); - let kernel_logger_name_with_nul_pcstr = PCSTR(kernel_logger_name_with_nul.as_ptr()); + // let kernel_logger_name_with_nul_pcstr = PCSTR(kernel_logger_name_with_nul.as_ptr()); // Stop an existing session with the kernel logger, if it exists // We use a copy of `event_trace_props` since ControlTrace overwrites it { - let mut event_trace_props_copy = event_trace_props.clone(); - let control_stop_retcode = ControlTraceA( - None, - kernel_logger_name_with_nul_pcstr, - addr_of_mut!(event_trace_props_copy) as *mut _, - EVENT_TRACE_CONTROL_STOP, - ); + let mut event_trace_props_copy = (*event_trace_props).clone(); + let control_stop_retcode = unsafe { + ControlTraceA( + None, + KERNEL_LOGGER_NAMEA, + addr_of_mut!(event_trace_props_copy).cast(), + EVENT_TRACE_CONTROL_STOP, + ) + }; if control_stop_retcode != ERROR_SUCCESS && control_stop_retcode != ERROR_WMI_INSTANCE_NOT_FOUND { @@ -216,11 +297,13 @@ unsafe fn trace_from_process_id( // Start kernel trace session let mut trace_session_handle: CONTROLTRACE_HANDLE = Default::default(); { - let start_retcode = StartTraceA( - addr_of_mut!(trace_session_handle), - kernel_logger_name_with_nul_pcstr, - addr_of_mut!(event_trace_props) as *mut _, - ); + let start_retcode = unsafe { + StartTraceA( + addr_of_mut!(trace_session_handle), + KERNEL_LOGGER_NAMEA, + addr_of_mut!(*event_trace_props).cast(), + ) + }; if start_retcode != ERROR_SUCCESS { return Err(get_last_error("StartTraceA")); } @@ -238,41 +321,51 @@ unsafe fn trace_from_process_id( }; stack_event_id.EventGuid = perfinfo_guid; stack_event_id.Type = 46; // Sampled profile event - let enable_stacks_retcode = TraceSetInformation( - trace_session_handle, - TraceStackTracingInfo, - addr_of!(stack_event_id).cast(), - size_of::() as u32, - ); + let enable_stacks_retcode = unsafe { + TraceSetInformation( + trace_session_handle, + TraceStackTracingInfo, + addr_of!(stack_event_id).cast(), + size_of::() as u32, + ) + }; if enable_stacks_retcode != ERROR_SUCCESS { return Err(get_last_error("TraceSetInformation stackwalk")); } } - let target_proc_handle = util::handle_from_process_id(target_process_id)?; - let mut context = TraceContext::new(target_proc_handle, target_process_id, kernel_stacks)?; - // TODO: Do we need to Box the context? - - let mut log = EVENT_TRACE_LOGFILEA::default(); + let mut log = Box::new(EVENT_TRACE_LOGFILEA::default()); log.LoggerName = PSTR(kernel_logger_name_with_nul.as_mut_ptr()); log.Anonymous1.ProcessTraceMode = PROCESS_TRACE_MODE_REAL_TIME | PROCESS_TRACE_MODE_EVENT_RECORD | PROCESS_TRACE_MODE_RAW_TIMESTAMP; - log.Context = addr_of_mut!(context).cast(); unsafe extern "system" fn event_record_callback(record: *mut EVENT_RECORD) { let provider_guid_data1 = (*record).EventHeader.ProviderId.data1; let event_opcode = (*record).EventHeader.EventDescriptor.Opcode; - let context = &mut *(*record).UserContext.cast::(); + + let context = &mut *(*record).UserContext.cast::(); context.trace_running.store(true, Ordering::Relaxed); + // Subscribe any new processes. + context + .traces + .extend(context.subscribe_recv.try_iter().filter_map(|weak| { + let pid = Weak::upgrade(&weak)?.target_proc_pid; + Some((pid, weak)) + })); const EVENT_TRACE_TYPE_LOAD: u8 = 10; if event_opcode == EVENT_TRACE_TYPE_LOAD { let event = (*record).UserData.cast::().read_unaligned(); - if event.ProcessId != context.target_proc_pid { + + let Some(trace_context) = context.get_trace_context(event.ProcessId) else { // Ignore dlls for other processes return; - } + }; + // TODO: use `Arc::get_mut_unchecked` once stable. + // SAFETY: Only the callback may modify the `TraceContext` while running. + let trace_context = Arc::into_raw(trace_context).cast_mut(); + let filename_p = (*record) .UserData .cast::() @@ -282,12 +375,15 @@ unsafe fn trace_from_process_id( filename_p, ((*record).UserDataLength as usize - size_of::()) / 2, )); - context.image_paths.push(( + (*trace_context).image_paths.push(( filename_os_string, event.ImageBase as u64, event.ImageSize as u64, )); + // SAFETY: De-increments Arc from above. + drop(Arc::from_raw(trace_context)); + return; } @@ -302,10 +398,15 @@ unsafe fn trace_from_process_id( let _timestamp = ud_p.cast::().read_unaligned(); let proc = ud_p.cast::().offset(2).read_unaligned(); let _thread = ud_p.cast::().offset(3).read_unaligned(); - if proc != context.target_proc_pid { + + // TODO: handle PID reuse??? + let Some(trace_context) = context.get_trace_context(proc) else { // Ignore stackwalks for other processes return; - } + }; + // TODO: use `Arc::get_mut_unchecked` once stable. + // SAFETY: Only the callback may modify the `TraceContext` while running. + let trace_context = Arc::into_raw(trace_context).cast_mut(); let stack_depth_32 = ((*record).UserDataLength - 16) / 4; let stack_depth_64 = stack_depth_32 / 2; @@ -336,9 +437,12 @@ unsafe fn trace_from_process_id( let mut stack = [0u64; MAX_STACK_DEPTH]; stack[..(stack_depth as usize).min(MAX_STACK_DEPTH)].copy_from_slice(stack_addrs); - let entry = context.stack_counts_hashmap.entry(stack); + let entry = (*trace_context).stack_counts_hashmap.entry(stack); *entry.or_insert(0) += 1; + // SAFETY: De-increments Arc from above. + drop(Arc::from_raw(trace_context)); + const DEBUG_OUTPUT_EVENTS: bool = false; if DEBUG_OUTPUT_EVENTS { #[repr(C)] @@ -386,17 +490,28 @@ unsafe fn trace_from_process_id( } log.Anonymous2.EventRecordCallback = Some(event_record_callback); - let trace_processing_handle = OpenTraceA(&mut log); + let (subscribe_send, subscribe_recv) = sync_channel(16); + let trace_running = Arc::new(AtomicBool::new(false)); + let context = Box::into_raw(Box::new(Context { + traces: Default::default(), + subscribe_recv, + trace_running: Arc::clone(&trace_running), + })); + log.Context = context.cast(); + + let trace_processing_handle = unsafe { OpenTraceA(addr_of_mut!(*log)) }; if trace_processing_handle.0 == INVALID_HANDLE_VALUE.0 as u64 { return Err(get_last_error("OpenTraceA processing")); } - let processing_thread = std::thread::spawn(move || { - SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL); - // This blocks - ProcessTrace(&[trace_processing_handle], None, None); + let _ = std::thread::spawn(move || { + unsafe { + SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL); + // This blocks until `EVENT_TRACE_CONTROL_STOP` on `GlobalContext::drop`. + ProcessTrace(&[trace_processing_handle], None, None) + }; - let ret = CloseTrace(trace_processing_handle); + let ret = unsafe { CloseTrace(trace_processing_handle) }; if ret != ERROR_SUCCESS { return Err(get_last_error("Error closing trace")); } @@ -404,9 +519,55 @@ unsafe fn trace_from_process_id( }); // Wait until we know for sure the trace is running - while !context.trace_running.load(Ordering::Relaxed) { + while !trace_running.load(Ordering::Relaxed) { std::hint::spin_loop(); } + + // Store the session. + let session = Arc::new(Session { + subscribe_send, + context, + event_trace_props: Box::into_raw(event_trace_props), + // TODO: does log need to survive past the `OpenTraceA` call? Maybe not + log: Box::into_raw(log), + }); + *unlocked = Arc::downgrade(&session); + Ok(session) +} + +const KERNEL_LOGGER_NAMEA_LEN: usize = unsafe { + let mut ptr = KERNEL_LOGGER_NAMEA.0; + let mut len = 0; + while *ptr != 0 { + len += 1; + ptr = ptr.add(1); + } + len +}; + +#[derive(Clone)] +#[repr(C)] +#[allow(non_camel_case_types)] +struct EVENT_TRACE_PROPERTIES_WITH_STRING { + data: EVENT_TRACE_PROPERTIES, + s: [u8; KERNEL_LOGGER_NAMEA_LEN + 1], +} + +/// The main tracing logic. Traces the process with the given `target_process_id`. +/// +/// # Safety +/// +/// `is_suspended` may only be true if `target_process` is suspended +unsafe fn trace_from_process_id( + target_process_id: u32, + is_suspended: bool, + kernel_stacks: bool, +) -> Result { + let target_proc_handle = util::handle_from_process_id(target_process_id)?; + let trace_context = + unsafe { TraceContext::new(target_proc_handle, target_process_id, kernel_stacks)? }; + let trace_guard = get_global_context()?.start(trace_context); + // Resume the suspended process if is_suspended { // TODO: Do something less gross here @@ -425,34 +586,18 @@ unsafe fn trace_from_process_id( #[allow(non_snake_case)] let NtResumeProcess: extern "system" fn(isize) -> i32 = std::mem::transmute(NtResumeProcess); - NtResumeProcess(context.target_process_handle.0); + NtResumeProcess(target_proc_handle.0); } // Wait for it to end util::wait_for_process_by_handle(target_proc_handle)?; - // This unblocks ProcessTrace - let ret = ControlTraceA( - ::default(), - PCSTR(kernel_logger_name_with_nul.as_ptr()), - addr_of_mut!(event_trace_props) as *mut _, - EVENT_TRACE_CONTROL_STOP, - ); - if ret != ERROR_SUCCESS { - return Err(get_last_error("ControlTraceA STOP ProcessTrace")); - } - - // Block until processing thread is done - // (Safeguard to make sure we don't deallocate the context before the other thread finishes using it) - processing_thread - .join() - .map_err(|_err_any| Error::UnknownError)??; - if context.show_kernel_samples { + let mut trace_context = trace_guard.stop(); + if trace_context.show_kernel_samples { let kernel_module_paths = util::list_kernel_modules(); - context.image_paths.extend(kernel_module_paths); + trace_context.image_paths.extend(kernel_module_paths); } - - Ok(context) + Ok(trace_context) } /// The sampled results from a process execution @@ -521,7 +666,7 @@ type PdbDb<'a, 'b> = /// Returns Vec<(image_base, image_size, image_name, addr2line pdb context)> fn find_pdbs(images: &[(OsString, u64, u64)]) -> Vec<(u64, u64, OsString, OwnedPdb)> { - let mut pdb_db = Vec::with_capacity(images.len()); + let mut pdb_db = Vec::new(); fn owned_pdb(pdb_file_bytes: Vec) -> Option { let pdb = PDB::open(std::io::Cursor::new(pdb_file_bytes)).ok()?; @@ -681,6 +826,7 @@ impl<'a> CallStack<'a> { Ok(()) } } + impl CollectionResults { /// Iterate the distinct callstacks sampled in this execution pub fn iter_callstacks(&self) -> impl std::iter::Iterator> { diff --git a/tests/multi.rs b/tests/multi.rs new file mode 100644 index 0000000..d6cc168 --- /dev/null +++ b/tests/multi.rs @@ -0,0 +1,16 @@ +use std::process::Command; + +#[test] +fn test_multi() { + let handle = std::thread::spawn(|| { + let mut cmd = Command::new("ping"); + cmd.arg("localhost"); + let _ctx = blondie::trace_command(cmd, false).unwrap(); + }); + + let mut cmd = Command::new("ping"); + cmd.arg("localhost"); + let _ctx = blondie::trace_command(cmd, false).unwrap(); + + handle.join().unwrap(); +}