From 535030de2b209ab271ddfa0bc10a1effcfd37c8f Mon Sep 17 00:00:00 2001 From: TKF Date: Sun, 4 Aug 2024 13:40:10 +0800 Subject: [PATCH] Protect memory mappings --- README.md | 19 ++++++++- src/area.rs | 19 +++++++++ src/set.rs | 71 ++++++++++++++++++++++++++++++- src/tests.rs | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 222 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3d5deec..33edae8 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,10 @@ Data structures and operations for managing memory mappings. -It is useful to implement [`mmap`][1] and [`munmap`][1]. +It is useful to implement [`mmap`][1], [`munmap`][1] and [`mprotect`][2]. [1]: https://man7.org/linux/man-pages/man2/mmap.2.html +[2]: https://man7.org/linux/man-pages/man2/mprotect.2.html ## Examples @@ -65,5 +66,21 @@ impl MappingBackend for MockBackend { } true } + + fn protect( + &self, + start: VirtAddr, + size: usize, + new_flags: MockFlags, + pt: &mut MockPageTable, + ) -> bool { + for entry in pt.iter_mut().skip(start.as_usize()).take(size) { + if *entry == 0 { + return false; + } + *entry = new_flags; + } + true + } } ``` diff --git a/src/area.rs b/src/area.rs index 18cd40b..5d89931 100644 --- a/src/area.rs +++ b/src/area.rs @@ -17,6 +17,8 @@ pub trait MappingBackend: Clone { fn map(&self, start: VirtAddr, size: usize, flags: F, page_table: &mut P) -> bool; /// What to do when unmaping a memory region within the area. fn unmap(&self, start: VirtAddr, size: usize, page_table: &mut P) -> bool; + /// What to do when changing access flags. + fn protect(&self, start: VirtAddr, size: usize, new_flags: F, page_table: &mut P) -> bool; } /// A memory area represents a continuous range of virtual memory with the same @@ -74,6 +76,16 @@ impl> MemoryArea { } impl> MemoryArea { + /// Changes the flags. + pub(crate) fn set_flags(&mut self, new_flags: F) { + self.flags = new_flags; + } + + /// Changes the end address of the memory area. 
+ pub(crate) fn set_end(&mut self, new_end: VirtAddr) { + self.va_range.end = new_end; + } + /// Maps the whole memory area in the page table. pub(crate) fn map_area(&self, page_table: &mut P) -> MappingResult { self.backend @@ -90,6 +102,13 @@ impl> MemoryArea { .ok_or(MappingError::BadState) } + /// Changes the flags in the page table. + pub(crate) fn protect_area(&mut self, new_flags: F, page_table: &mut P) -> MappingResult { + self.backend + .protect(self.start(), self.size(), new_flags, page_table); + Ok(()) + } + /// Shrinks the memory area at the left side. /// /// The start address of the memory area is increased by `new_size`. The diff --git a/src/set.rs b/src/set.rs index 1da2e81..5f4979c 100644 --- a/src/set.rs +++ b/src/set.rs @@ -1,4 +1,4 @@ -use alloc::collections::BTreeMap; +use alloc::{collections::BTreeMap, vec::Vec}; use core::fmt; use memory_addr::{VirtAddr, VirtAddrRange}; @@ -176,6 +176,75 @@ impl> MemorySet { self.areas.clear(); Ok(()) } + + /// Change the flags of memory mappings within the given address range. + /// + /// `update_flags` is a function that receives old flags and processes + /// new flags (e.g., some flags cannot be changed through this interface). + /// It returns [`None`] if there is no bit to change. + /// + /// Memory areas will be skipped according to `update_flags`. Memory areas + /// that are fully contained in the range, contain the range, or intersect + /// with the boundary will be handled similarly to `munmap`.
+ pub fn protect( + &mut self, + start: VirtAddr, + size: usize, + update_flags: impl Fn(F) -> Option, + page_table: &mut P, + ) -> MappingResult { + let end = start + size; + let mut to_insert = Vec::new(); + for (_, area) in self.areas.iter_mut() { + if let Some(new_flags) = update_flags(area.flags()) { + if area.start() >= end { + // [ prot ] + // [ area ] + break; + } else if area.end() <= start { + // [ prot ] + // [ area ] + // Do nothing + } else if area.start() >= start && area.end() <= end { + // [ prot ] + // [ area ] + area.protect_area(new_flags, page_table)?; + area.set_flags(new_flags); + } else if area.start() < start && area.end() > end { + // [ prot ] + // [ left | area | right ] + let right_part = area.split(end).unwrap(); + area.set_end(start); + + let mut middle_part = + MemoryArea::new(start, size, area.flags(), area.backend().clone()); + middle_part.protect_area(new_flags, page_table)?; + middle_part.set_flags(new_flags); + + to_insert.push((right_part.start(), right_part)); + to_insert.push((middle_part.start(), middle_part)); + } else if area.end() > end { + // [ prot ] + // [ area | right ] + let right_part = area.split(end).unwrap(); + area.protect_area(new_flags, page_table)?; + area.set_flags(new_flags); + + to_insert.push((right_part.start(), right_part)); + } else { + // [ prot ] + // [ left | area ] + let mut right_part = area.split(start).unwrap(); + right_part.protect_area(new_flags, page_table)?; + right_part.set_flags(new_flags); + + to_insert.push((right_part.start(), right_part)); + } + } + } + self.areas.extend(to_insert.into_iter()); + Ok(()) + } } impl> fmt::Debug for MemorySet { diff --git a/src/tests.rs b/src/tests.rs index f8fc586..e2b3e9b 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -32,6 +32,22 @@ impl MappingBackend for MockBackend { } true } + + fn protect( + &self, + start: VirtAddr, + size: usize, + new_flags: MockFlags, + pt: &mut MockPageTable, + ) -> bool { + for entry in 
pt.iter_mut().skip(start.as_usize()).take(size) { + if *entry == 0 { + return false; + } + *entry = new_flags; + } + true + } } macro_rules! assert_ok { @@ -168,7 +184,7 @@ fn test_unmap_split() { } } - // Unmap [0x800, 0x900), [0x2800, 0x4400), [0x4800, 0x4900), ... + // Unmap [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ... // The areas are split into two areas. for start in (0..MAX_ADDR).step_by(0x2000) { assert_ok!(set.unmap((start + 0x800).into(), 0x100, &mut pt)); @@ -208,3 +224,101 @@ fn test_unmap_split() { assert_eq!(pt[addr], 0); } } + +#[test] +fn test_protect() { + let mut set = MockMemorySet::new(); + let mut pt = [0; MAX_ADDR]; + let update_flags = |new_flags: MockFlags| { + move |old_flags: MockFlags| -> Option { + if (old_flags & 0x7) == (new_flags & 0x7) { + return None; + } + let flags = (new_flags & 0x7) | (old_flags & !0x7); + Some(flags) + } + }; + + // Map [0, 0x1000), [0x2000, 0x3000), [0x4000, 0x5000), ... + for start in (0..MAX_ADDR).step_by(0x2000) { + assert_ok!(set.map( + MemoryArea::new(start.into(), 0x1000, 0x7, MockBackend), + &mut pt, + false, + )); + } + assert_eq!(set.len(), 8); + + // Protect [0xc00, 0x2400), [0x2c00, 0x4400), [0x4c00, 0x6400), ... + // The areas are split into two areas. + for start in (0..MAX_ADDR).step_by(0x2000) { + assert_ok!(set.protect((start + 0xc00).into(), 0x1800, update_flags(0x1), &mut pt)); + } + dump_memory_set(&set); + assert_eq!(set.len(), 23); + + for area in set.iter() { + let off = area.start().align_offset_4k(); + if area.start().as_usize() == 0 { + assert_eq!(area.size(), 0xc00); + assert_eq!(area.flags(), 0x7); + } else { + if off == 0 { + assert_eq!(area.size(), 0x400); + assert_eq!(area.flags(), 0x1); + } else if off == 0x400 { + assert_eq!(area.size(), 0x800); + assert_eq!(area.flags(), 0x7); + } else if off == 0xc00 { + assert_eq!(area.size(), 0x400); + assert_eq!(area.flags(), 0x1); + } + } + } + + // Protect [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ... 
+ // The areas are split into three areas. + for start in (0..MAX_ADDR).step_by(0x2000) { + assert_ok!(set.protect((start + 0x800).into(), 0x100, update_flags(0x13), &mut pt)); + } + dump_memory_set(&set); + assert_eq!(set.len(), 39); + + for area in set.iter() { + let off = area.start().align_offset_4k(); + if area.start().as_usize() == 0 { + assert_eq!(area.size(), 0x800); + assert_eq!(area.flags(), 0x7); + } else { + if off == 0 { + assert_eq!(area.size(), 0x400); + assert_eq!(area.flags(), 0x1); + } else if off == 0x400 { + assert_eq!(area.size(), 0x400); + assert_eq!(area.flags(), 0x7); + } else if off == 0x800 { + assert_eq!(area.size(), 0x100); + assert_eq!(area.flags(), 0x3); + } else if off == 0x900 { + assert_eq!(area.size(), 0x300); + assert_eq!(area.flags(), 0x7); + } else if off == 0xc00 { + assert_eq!(area.size(), 0x400); + assert_eq!(area.flags(), 0x1); + } + } + } + + // Test skip [0x880, 0x900), [0x2880, 0x2900), [0x4880, 0x4900), ... + for start in (0..MAX_ADDR).step_by(0x2000) { + assert_ok!(set.protect((start + 0x880).into(), 0x80, update_flags(0x3), &mut pt)); + } + assert_eq!(set.len(), 39); + + // Unmap all areas. + assert_ok!(set.unmap(0.into(), MAX_ADDR, &mut pt)); + assert_eq!(set.len(), 0); + for addr in 0..MAX_ADDR { + assert_eq!(pt[addr], 0); + } +}