
Protect memory mappings #1

Merged · 1 commit · Aug 4, 2024
19 changes: 18 additions & 1 deletion README.md
@@ -6,9 +6,10 @@

Data structures and operations for managing memory mappings.

-It is useful to implement [`mmap`][1] and [`munmap`][1].
+It is useful for implementing [`mmap`][1], [`munmap`][1] and [`mprotect`][2].

[1]: https://man7.org/linux/man-pages/man2/mmap.2.html
[2]: https://man7.org/linux/man-pages/man2/mprotect.2.html

## Examples

@@ -65,5 +66,21 @@ impl MappingBackend<MockFlags, MockPageTable> for MockBackend {
}
true
}

fn protect(
&self,
start: VirtAddr,
size: usize,
new_flags: MockFlags,
pt: &mut MockPageTable,
) -> bool {
for entry in pt.iter_mut().skip(start.as_usize()).take(size) {
if *entry == 0 {
return false;
}
*entry = new_flags;
}
true
}
}
```
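
With `protect` implemented on the backend, changing permissions is a single call on the set. Below is a minimal usage sketch (assuming the `MockFlags`/`MockPageTable` aliases and the `MAX_ADDR` constant from the elided setup above): the closure receives the old flags and returns the new ones, or `None` to skip areas that already have the desired bits.

```rust
use memory_set::{MemoryArea, MemorySet};

let mut set = MemorySet::<MockFlags, MockPageTable, MockBackend>::new();
let mut pt = [0; MAX_ADDR];

// Map one page [0x1000, 0x2000) as read|write|execute (0b111).
assert!(set
    .map(
        MemoryArea::new(0x1000.into(), 0x1000, 0b111, MockBackend),
        &mut pt,
        false,
    )
    .is_ok());

// Make the interior [0x1400, 0x1c00) read-only. `protect` splits the
// area into three parts and re-flags only the middle one; the closure
// returns `None` when nothing would change, so such areas are skipped.
assert!(set
    .protect(
        0x1400.into(),
        0x800,
        |old| {
            let new = old & 0b001;
            (new != old).then_some(new)
        },
        &mut pt,
    )
    .is_ok());
assert_eq!(set.len(), 3);
```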
19 changes: 19 additions & 0 deletions src/area.rs
@@ -17,6 +17,8 @@ pub trait MappingBackend<F: Copy, P>: Clone {
fn map(&self, start: VirtAddr, size: usize, flags: F, page_table: &mut P) -> bool;
/// What to do when unmapping a memory region within the area.
fn unmap(&self, start: VirtAddr, size: usize, page_table: &mut P) -> bool;
/// What to do when changing the access flags of a memory region within the area.
fn protect(&self, start: VirtAddr, size: usize, new_flags: F, page_table: &mut P) -> bool;
}

/// A memory area represents a contiguous range of virtual memory with the same
@@ -74,6 +76,16 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
}

impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
/// Changes the access flags of the memory area.
pub(crate) fn set_flags(&mut self, new_flags: F) {
self.flags = new_flags;
}

/// Changes the end address of the memory area.
pub(crate) fn set_end(&mut self, new_end: VirtAddr) {
self.va_range.end = new_end;
}

/// Maps the whole memory area in the page table.
pub(crate) fn map_area(&self, page_table: &mut P) -> MappingResult {
self.backend
@@ -90,6 +102,13 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
.ok_or(MappingError::BadState)
}

/// Changes the flags in the page table.
pub(crate) fn protect_area(&mut self, new_flags: F, page_table: &mut P) -> MappingResult {
self.backend
.protect(self.start(), self.size(), new_flags, page_table);
Ok(())
}

/// Shrinks the memory area at the left side.
///
/// The start address of the memory area is increased by `new_size`. The
71 changes: 70 additions & 1 deletion src/set.rs
@@ -1,4 +1,4 @@
-use alloc::collections::BTreeMap;
+use alloc::{collections::BTreeMap, vec::Vec};
use core::fmt;

use memory_addr::{VirtAddr, VirtAddrRange};
@@ -176,6 +176,75 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemorySet<F, P, B> {
self.areas.clear();
Ok(())
}

/// Changes the flags of memory mappings within the given address range.
///
/// `update_flags` takes the old flags and returns the new flags, or
/// [`None`] if the flags should not be changed (e.g., because no bit
/// would actually change, or because some bits must not be modified
/// through this interface). Memory areas for which it returns [`None`]
/// are skipped.
///
/// Areas fully contained in the range, areas fully containing the range,
/// and areas intersecting one of its boundaries are split as needed,
/// similarly to `munmap`, and only the parts inside the range receive
/// the new flags.
pub fn protect(
&mut self,
start: VirtAddr,
size: usize,
update_flags: impl Fn(F) -> Option<F>,
page_table: &mut P,
) -> MappingResult {
let end = start + size;
let mut to_insert = Vec::new();
for (_, area) in self.areas.iter_mut() {
if let Some(new_flags) = update_flags(area.flags()) {
if area.start() >= end {
// [ prot ]
// [ area ]
break;
} else if area.end() <= start {
// [ prot ]
// [ area ]
// Do nothing
} else if area.start() >= start && area.end() <= end {
// [ prot ]
// [ area ]
area.protect_area(new_flags, page_table)?;
area.set_flags(new_flags);
} else if area.start() < start && area.end() > end {
// [ prot ]
// [ left | area | right ]
let right_part = area.split(end).unwrap();
area.set_end(start);

let mut middle_part =
MemoryArea::new(start, size, area.flags(), area.backend().clone());
middle_part.protect_area(new_flags, page_table)?;
middle_part.set_flags(new_flags);

to_insert.push((right_part.start(), right_part));
to_insert.push((middle_part.start(), middle_part));
} else if area.end() > end {
// [ prot ]
// [ area | right ]
let right_part = area.split(end).unwrap();
area.protect_area(new_flags, page_table)?;
area.set_flags(new_flags);

to_insert.push((right_part.start(), right_part));
} else {
// [ prot ]
// [ left | area ]
let mut right_part = area.split(start).unwrap();
right_part.protect_area(new_flags, page_table)?;
right_part.set_flags(new_flags);

to_insert.push((right_part.start(), right_part));
}
}
}
self.areas.extend(to_insert.into_iter());
Ok(())
}
}

impl<F: Copy + fmt::Debug, P, B: MappingBackend<F, P>> fmt::Debug for MemorySet<F, P, B> {
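The splitting logic above is easiest to see on the middle case. The following is a minimal sketch in the style of `src/tests.rs` (reusing its `MockMemorySet`, `MockBackend` and `MAX_ADDR` fixtures; not part of this diff): protecting the interior of a single area yields three areas, and only the pages inside the range are re-flagged.

```rust
#[test]
fn protect_middle_split_sketch() {
    let mut set = MockMemorySet::new();
    let mut pt = [0; MAX_ADDR];

    // One area covering [0x0, 0x1000) with flags 0x7.
    assert!(set
        .map(MemoryArea::new(0.into(), 0x1000, 0x7, MockBackend), &mut pt, false)
        .is_ok());

    // Protect the interior [0x400, 0xc00). This hits the
    // `area.start() < start && area.end() > end` branch: the area is
    // split into left, middle and right parts, and only the middle
    // part gets the new flags.
    assert!(set
        .protect(
            0x400.into(),
            0x800,
            |old| if old == 0x1 { None } else { Some(0x1) },
            &mut pt,
        )
        .is_ok());
    assert_eq!(set.len(), 3);

    // The page-table entries reflect the split boundaries.
    assert_eq!(pt[0x3ff], 0x7);
    assert_eq!(pt[0x400], 0x1);
    assert_eq!(pt[0xbff], 0x1);
    assert_eq!(pt[0xc00], 0x7);
}
```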
116 changes: 115 additions & 1 deletion src/tests.rs
@@ -32,6 +32,22 @@ impl MappingBackend<MockFlags, MockPageTable> for MockBackend {
}
true
}

fn protect(
&self,
start: VirtAddr,
size: usize,
new_flags: MockFlags,
pt: &mut MockPageTable,
) -> bool {
for entry in pt.iter_mut().skip(start.as_usize()).take(size) {
if *entry == 0 {
return false;
}
*entry = new_flags;
}
true
}
}

macro_rules! assert_ok {
@@ -168,7 +184,7 @@ fn test_unmap_split() {
}
}

-// Unmap [0x800, 0x900), [0x2800, 0x4400), [0x4800, 0x4900), ...
+// Unmap [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ...
// The areas are split into two areas.
for start in (0..MAX_ADDR).step_by(0x2000) {
assert_ok!(set.unmap((start + 0x800).into(), 0x100, &mut pt));
@@ -208,3 +224,101 @@ fn test_unmap_split() {
assert_eq!(pt[addr], 0);
}
}

#[test]
fn test_protect() {
let mut set = MockMemorySet::new();
let mut pt = [0; MAX_ADDR];
let update_flags = |new_flags: MockFlags| {
move |old_flags: MockFlags| -> Option<MockFlags> {
if (old_flags & 0x7) == (new_flags & 0x7) {
return None;
}
let flags = (new_flags & 0x7) | (old_flags & !0x7);
Some(flags)
}
};

// Map [0, 0x1000), [0x2000, 0x3000), [0x4000, 0x5000), ...
for start in (0..MAX_ADDR).step_by(0x2000) {
assert_ok!(set.map(
MemoryArea::new(start.into(), 0x1000, 0x7, MockBackend),
&mut pt,
false,
));
}
assert_eq!(set.len(), 8);

// Protect [0xc00, 0x2400), [0x2c00, 0x4400), [0x4c00, 0x6400), ...
// The areas are split into two areas.
for start in (0..MAX_ADDR).step_by(0x2000) {
assert_ok!(set.protect((start + 0xc00).into(), 0x1800, update_flags(0x1), &mut pt));
}
dump_memory_set(&set);
assert_eq!(set.len(), 23);

for area in set.iter() {
let off = area.start().align_offset_4k();
if area.start().as_usize() == 0 {
assert_eq!(area.size(), 0xc00);
assert_eq!(area.flags(), 0x7);
} else {
if off == 0 {
assert_eq!(area.size(), 0x400);
assert_eq!(area.flags(), 0x1);
} else if off == 0x400 {
assert_eq!(area.size(), 0x800);
assert_eq!(area.flags(), 0x7);
} else if off == 0xc00 {
assert_eq!(area.size(), 0x400);
assert_eq!(area.flags(), 0x1);
}
}
}

// Protect [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ...
// The areas are split into three areas.
for start in (0..MAX_ADDR).step_by(0x2000) {
assert_ok!(set.protect((start + 0x800).into(), 0x100, update_flags(0x13), &mut pt));
}
dump_memory_set(&set);
assert_eq!(set.len(), 39);

for area in set.iter() {
let off = area.start().align_offset_4k();
if area.start().as_usize() == 0 {
assert_eq!(area.size(), 0x800);
assert_eq!(area.flags(), 0x7);
} else {
if off == 0 {
assert_eq!(area.size(), 0x400);
assert_eq!(area.flags(), 0x1);
} else if off == 0x400 {
assert_eq!(area.size(), 0x400);
assert_eq!(area.flags(), 0x7);
} else if off == 0x800 {
assert_eq!(area.size(), 0x100);
assert_eq!(area.flags(), 0x3);
} else if off == 0x900 {
assert_eq!(area.size(), 0x300);
assert_eq!(area.flags(), 0x7);
} else if off == 0xc00 {
assert_eq!(area.size(), 0x400);
assert_eq!(area.flags(), 0x1);
}
}
}

// Protect [0x880, 0x900), [0x2880, 0x2900), [0x4880, 0x4900), ... with the
// same permission bits. `update_flags` returns `None` for every affected
// area, so all of them are skipped and the layout is unchanged.
for start in (0..MAX_ADDR).step_by(0x2000) {
assert_ok!(set.protect((start + 0x880).into(), 0x80, update_flags(0x3), &mut pt));
}
assert_eq!(set.len(), 39);

// Unmap all areas.
assert_ok!(set.unmap(0.into(), MAX_ADDR, &mut pt));
assert_eq!(set.len(), 0);
for addr in 0..MAX_ADDR {
assert_eq!(pt[addr], 0);
}
}