Skip to content
This repository has been archived by the owner on Sep 14, 2024. It is now read-only.

Commit

Permalink
Protect memory mappings
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf2019 committed Aug 4, 2024
1 parent c5e582a commit 362e250
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 3 deletions.
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

Data structures and operations for managing memory mappings.

It is useful to implement [`mmap`][1] and [`munmap`][1].
It is useful to implement [`mmap`][1], [`munmap`][1] and [`mprotect`][2].

[1]: https://man7.org/linux/man-pages/man2/mmap.2.html
[2]: https://man7.org/linux/man-pages/man2/mprotect.2.html

## Examples

Expand Down Expand Up @@ -65,5 +66,21 @@ impl MappingBackend<MockFlags, MockPageTable> for MockBackend {
}
true
}

fn protect(
&self,
start: VirtAddr,
size: usize,
new_flags: MockFlags,
pt: &mut MockPageTable,
) -> bool {
for entry in pt.iter_mut().skip(start.as_usize()).take(size) {
if *entry == 0 {
return false;
}
*entry = new_flags;
}
true
}
}
```
20 changes: 20 additions & 0 deletions src/area.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ pub trait MappingBackend<F: Copy, P>: Clone {
fn map(&self, start: VirtAddr, size: usize, flags: F, page_table: &mut P) -> bool;
/// What to do when unmaping a memory region within the area.
fn unmap(&self, start: VirtAddr, size: usize, page_table: &mut P) -> bool;
/// What to do when changing access flags.
fn protect(&self, start: VirtAddr, size: usize, new_flags: F, page_table: &mut P) -> bool;
}

/// A memory area represents a continuous range of virtual memory with the same
/// flags.
///
/// The target physical memory frames are determined by [`MappingBackend`] and
/// may not be contiguous.
#[derive(Clone)]
pub struct MemoryArea<F: Copy, P, B: MappingBackend<F, P>> {
va_range: VirtAddrRange,
flags: F,
Expand Down Expand Up @@ -74,6 +77,16 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
}

impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
/// Changes the flags.
pub(crate) fn set_flags(&mut self, new_flags: F) {
self.flags = new_flags;
}

/// Changes the end address of the memory area.
pub(crate) fn set_end(&mut self, new_end: VirtAddr) {
self.va_range.end = new_end;
}

/// Maps the whole memory area in the page table.
pub(crate) fn map_area(&self, page_table: &mut P) -> MappingResult {
self.backend
Expand All @@ -90,6 +103,13 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
.ok_or(MappingError::BadState)
}

/// Changes the flags in the page table.
pub(crate) fn protect_area(&mut self, new_flags: F, page_table: &mut P) -> MappingResult {
self.backend
.protect(self.start(), self.size(), new_flags, page_table);
Ok(())
}

/// Shrinks the memory area at the left side.
///
/// The start address of the memory area is increased by `new_size`. The
Expand Down
71 changes: 70 additions & 1 deletion src/set.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use alloc::collections::BTreeMap;
use alloc::{collections::BTreeMap, vec::Vec};
use core::fmt;

use memory_addr::{VirtAddr, VirtAddrRange};
Expand Down Expand Up @@ -176,6 +176,75 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemorySet<F, P, B> {
self.areas.clear();
Ok(())
}

/// Change the flags of memory mappings within the given address range.
///
/// `update_flags` is a function that receives old flags and processes
/// new flags (e.g., some flags can not be changed through this interface).
/// It returns [`None`] if there is no bit to change.
///
/// Memory areas will be skipped according to `update_flags`. Memory areas
/// that are fully contained in the range or contains the range or intersects
/// with the boundary will be handled similarly to `munmap`.
pub fn protect(
&mut self,
start: VirtAddr,
size: usize,
update_flags: impl Fn(F) -> Option<F>,
page_table: &mut P,
) -> MappingResult {
let end = start + size;
let mut to_insert = Vec::new();
for (_, area) in self.areas.iter_mut() {
if let Some(new_flags) = update_flags(area.flags()) {
if area.start() >= end {
// [ prot ]
// [ area ]
break;
} else if area.end() <= start {
// [ prot ]
// [ area ]
// Do nothing
} else if area.start() >= start && area.end() <= end {
// [ prot ]
// [ area ]
area.protect_area(new_flags, page_table)?;
area.set_flags(new_flags);
} else if area.start() < start && area.end() > end {
// [ prot ]
// [ left | area | right ]
let right_part = area.split(end).unwrap();
area.set_end(start);

let mut middle_part =
MemoryArea::new(start, size, area.flags(), area.backend().clone());
middle_part.protect_area(new_flags, page_table)?;
middle_part.set_flags(new_flags);

to_insert.push((right_part.start(), right_part));
to_insert.push((middle_part.start(), middle_part));
} else if area.end() > end {
// [ prot ]
// [ area | right ]
let right_part = area.split(end).unwrap();
area.protect_area(new_flags, page_table)?;
area.set_flags(new_flags);

to_insert.push((right_part.start(), right_part));
} else {
// [ prot ]
// [ left | area ]
let mut right_part = area.split(start).unwrap();
right_part.protect_area(new_flags, page_table)?;
right_part.set_flags(new_flags);

to_insert.push((right_part.start(), right_part));
}
}
}
self.areas.extend(to_insert.into_iter());
Ok(())
}
}

impl<F: Copy + fmt::Debug, P, B: MappingBackend<F, P>> fmt::Debug for MemorySet<F, P, B> {
Expand Down
116 changes: 115 additions & 1 deletion src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ impl MappingBackend<MockFlags, MockPageTable> for MockBackend {
}
true
}

fn protect(
&self,
start: VirtAddr,
size: usize,
new_flags: MockFlags,
pt: &mut MockPageTable,
) -> bool {
for entry in pt.iter_mut().skip(start.as_usize()).take(size) {
if *entry == 0 {
return false;
}
*entry = new_flags;
}
true
}
}

macro_rules! assert_ok {
Expand Down Expand Up @@ -168,7 +184,7 @@ fn test_unmap_split() {
}
}

// Unmap [0x800, 0x900), [0x2800, 0x4400), [0x4800, 0x4900), ...
// Unmap [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ...
// The areas are split into two areas.
for start in (0..MAX_ADDR).step_by(0x2000) {
assert_ok!(set.unmap((start + 0x800).into(), 0x100, &mut pt));
Expand Down Expand Up @@ -208,3 +224,101 @@ fn test_unmap_split() {
assert_eq!(pt[addr], 0);
}
}

#[test]
fn test_protect() {
let mut set = MockMemorySet::new();
let mut pt = [0; MAX_ADDR];
let update_flags = |new_flags: MockFlags| {
move |old_flags: MockFlags| -> Option<MockFlags> {
if (old_flags & 0x7) == (new_flags & 0x7) {
return None;
}
let flags = (new_flags & 0x7) | (old_flags & !0x7);
Some(flags)
}
};

// Map [0, 0x1000), [0x2000, 0x3000), [0x4000, 0x5000), ...
for start in (0..MAX_ADDR).step_by(0x2000) {
assert_ok!(set.map(
MemoryArea::new(start.into(), 0x1000, 0x7, MockBackend),
&mut pt,
false,
));
}
assert_eq!(set.len(), 8);

// Protect [0xc00, 0x2400), [0x2c00, 0x4400), [0x4c00, 0x6400), ...
// The areas are split into two areas.
for start in (0..MAX_ADDR).step_by(0x2000) {
assert_ok!(set.protect((start + 0xc00).into(), 0x1800, update_flags(0x1), &mut pt));
}
dump_memory_set(&set);
assert_eq!(set.len(), 23);

for area in set.iter() {
let off = area.start().align_offset_4k();
if area.start().as_usize() == 0 {
assert_eq!(area.size(), 0xc00);
assert_eq!(area.flags(), 0x7);
} else {
if off == 0 {
assert_eq!(area.size(), 0x400);
assert_eq!(area.flags(), 0x1);
} else if off == 0x400 {
assert_eq!(area.size(), 0x800);
assert_eq!(area.flags(), 0x7);
} else if off == 0xc00 {
assert_eq!(area.size(), 0x400);
assert_eq!(area.flags(), 0x1);
}
}
}

// Protect [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ...
// The areas are split into three areas.
for start in (0..MAX_ADDR).step_by(0x2000) {
assert_ok!(set.protect((start + 0x800).into(), 0x100, update_flags(0x13), &mut pt));
}
dump_memory_set(&set);
assert_eq!(set.len(), 39);

for area in set.iter() {
let off = area.start().align_offset_4k();
if area.start().as_usize() == 0 {
assert_eq!(area.size(), 0x800);
assert_eq!(area.flags(), 0x7);
} else {
if off == 0 {
assert_eq!(area.size(), 0x400);
assert_eq!(area.flags(), 0x1);
} else if off == 0x400 {
assert_eq!(area.size(), 0x400);
assert_eq!(area.flags(), 0x7);
} else if off == 0x800 {
assert_eq!(area.size(), 0x100);
assert_eq!(area.flags(), 0x3);
} else if off == 0x900 {
assert_eq!(area.size(), 0x300);
assert_eq!(area.flags(), 0x7);
} else if off == 0xc00 {
assert_eq!(area.size(), 0x400);
assert_eq!(area.flags(), 0x1);
}
}
}

// Test skip [0x880, 0x900), [0x2880, 0x2900), [0x4880, 0x4900), ...
for start in (0..MAX_ADDR).step_by(0x2000) {
assert_ok!(set.protect((start + 0x880).into(), 0x80, update_flags(0x3), &mut pt));
}
assert_eq!(set.len(), 39);

// Unmap all areas.
assert_ok!(set.unmap(0.into(), MAX_ADDR, &mut pt));
assert_eq!(set.len(), 0);
for addr in 0..MAX_ADDR {
assert_eq!(pt[addr], 0);
}
}

0 comments on commit 362e250

Please sign in to comment.