diff --git a/src/area.rs b/src/area.rs index c6611bf..6a8ba9e 100644 --- a/src/area.rs +++ b/src/area.rs @@ -17,18 +17,8 @@ pub trait MappingBackend: Clone { fn map(&self, start: VirtAddr, size: usize, flags: F, page_table: &mut P) -> bool; /// What to do when unmaping a memory region within the area. fn unmap(&self, start: VirtAddr, size: usize, page_table: &mut P) -> bool; - /// What to do when changing access flags. We pass both old and new flags - /// for backend's processing (e.g., some flags can not be changed through this - /// interface). Flags for replacement will be returned. And this interface - /// returns [None] if we need to skip this area for no flag will be changed. - fn protect( - &self, - start: VirtAddr, - size: usize, - old_flags: F, - new_flags: F, - page_table: &mut P, - ) -> Option; + /// What to do when changing access flags. + fn protect(&self, start: VirtAddr, size: usize, new_flags: F, page_table: &mut P) -> bool; } /// A memory area represents a continuous range of virtual memory with the same @@ -114,9 +104,10 @@ impl> MemoryArea { } /// Changes the flags in the page table. - pub(crate) fn protect_area(&mut self, new_flags: F, page_table: &mut P) -> Option { + pub(crate) fn protect_area(&mut self, new_flags: F, page_table: &mut P) -> MappingResult { self.backend - .protect(self.start(), self.size(), self.flags, new_flags, page_table) + .protect(self.start(), self.size(), new_flags, page_table); + Ok(()) } /// Shrinks the memory area at the left side. diff --git a/src/set.rs b/src/set.rs index fbf642e..53b5caf 100644 --- a/src/set.rs +++ b/src/set.rs @@ -178,69 +178,67 @@ impl> MemorySet { } /// Change the flags of memory mappings within the given address range. - /// - /// Memory areas with the same flags will be skipped. Memory areas that - /// are fully contained in the range or contains the range or intersects + /// + /// `update_flags` is a function that receives old flags and processes + /// new flags (e.g., some flags can not be changed through this interface). + /// It returns [None] if there is no bit to change. + /// + /// Memory areas will be skipped according to `update_flags`. Memory areas + /// that are fully contained in the range or contains the range or intersects /// with the boundary will be handled similarly to `munmap`. pub fn protect( &mut self, start: VirtAddr, size: usize, - new_flags: F, + update_flags: impl Fn(F) -> Option, page_table: &mut P, ) -> MappingResult { let end = start + size; let mut to_insert = Vec::new(); for (_, area) in self.areas.iter_mut() { - if area.start() >= end { - // [ prot ] - // [ area ] - break; - } else if area.end() <= start { - // [ prot ] - // [ area ] - // Do nothing - } else if area.start() >= start && area.end() <= end { - // [ prot ] - // [ area ] - if let Some(new_flags) = area.protect_area(new_flags, page_table) { + if let Some(new_flags) = update_flags(area.flags()) { + if area.start() >= end { + // [ prot ] + // [ area ] + break; + } else if area.end() <= start { + // [ prot ] + // [ area ] + // Do nothing + } else if area.start() >= start && area.end() <= end { + // [ prot ] + // [ area ] + area.protect_area(new_flags, page_table)?; area.set_flags(new_flags); - } - } else if area.start() < start && area.end() > end { - // [ prot ] - // [ left | area | right ] - let mut middle_part = - MemoryArea::new(start, size, area.flags(), area.backend().clone()); - if let Some(new_flags) = middle_part.protect_area(new_flags, page_table) { - middle_part.set_flags(new_flags); + } else if area.start() < start && area.end() > end { + // [ prot ] + // [ left | area | right ] let right_part = area.split(end).unwrap(); area.set_end(start); + + let mut middle_part = + MemoryArea::new(start, size, area.flags(), area.backend().clone()); + middle_part.protect_area(new_flags, page_table)?; + middle_part.set_flags(new_flags); + to_insert.push((right_part.start(), right_part)); to_insert.push((middle_part.start(), middle_part)); - } - } else if area.end() > end { - // [ prot ] - // [ area | right ] - let mut left_part = MemoryArea::new( - area.start(), - end.as_usize() - area.start().as_usize(), - area.flags(), - area.backend().clone(), - ); - if let Some(new_flags) = left_part.protect_area(new_flags, page_table) { + } else if area.end() > end { + // [ prot ] + // [ area | right ] let right_part = area.split(end).unwrap(); + area.protect_area(new_flags, page_table)?; area.set_flags(new_flags); + to_insert.push((right_part.start(), right_part)); - } - } else { - // [ prot ] - // [ left | area ] - let mut right_part = area.split(start).unwrap(); - if let Some(new_flags) = right_part.protect_area(new_flags, page_table) { + } else { + // [ prot ] + // [ left | area ] + let mut right_part = area.split(start).unwrap(); + right_part.protect_area(new_flags, page_table)?; right_part.set_flags(new_flags); + to_insert.push((right_part.start(), right_part)); - } else { - area.set_end(right_part.end()); // rollback the end } } } diff --git a/src/tests.rs b/src/tests.rs index a479b2f..6cbb9cd 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -37,18 +37,16 @@ impl MappingBackend for MockBackend { &self, start: VirtAddr, size: usize, - old_flags: MockFlags, new_flags: MockFlags, pt: &mut MockPageTable, - ) -> Option { - if (old_flags & 0x7) == (new_flags & 0x7) { - return None; - } - let flags = (new_flags & 0x7) | (old_flags & !0x7); + ) -> bool { for entry in pt.iter_mut().skip(start.as_usize()).take(size) { - *entry = flags; + if *entry == 0 { + return false; + } + *entry = new_flags; } - Some(flags) + true } } @@ -231,6 +229,15 @@ fn test_unmap_split() { fn test_protect() { let mut set = MockMemorySet::new(); let mut pt = [0; MAX_ADDR]; + let update_flags = |new_flags: MockFlags| { + move |old_flags: MockFlags| -> Option { + if (old_flags & 0x7) == (new_flags & 0x7) { + return None; + } + let flags = (new_flags & 0x7) | (old_flags & !0x7); + Some(flags) + } + }; // Map [0, 0x1000), [0x2000, 0x3000), [0x4000, 0x5000), ... for start in (0..MAX_ADDR).step_by(0x2000) { @@ -245,7 +252,7 @@ fn test_protect() { // Protect [0xc00, 0x2400), [0x2c00, 0x4400), [0x4c00, 0x6400), ... // The areas are split into two areas. for start in (0..MAX_ADDR).step_by(0x2000) { - assert_ok!(set.protect((start + 0xc00).into(), 0x1800, 0x1, &mut pt)); + assert_ok!(set.protect((start + 0xc00).into(), 0x1800, update_flags(0x1), &mut pt)); } dump_memory_set(&set); assert_eq!(set.len(), 23); @@ -272,7 +279,7 @@ fn test_protect() { // Protect [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ... // The areas are split into three areas. for start in (0..MAX_ADDR).step_by(0x2000) { - assert_ok!(set.protect((start + 0x800).into(), 0x100, 0x13, &mut pt)); + assert_ok!(set.protect((start + 0x800).into(), 0x100, update_flags(0x13), &mut pt)); } dump_memory_set(&set); assert_eq!(set.len(), 39); @@ -304,7 +311,7 @@ fn test_protect() { // Test skip [0x850, 0x900), [0x2850, 0x2900), [0x4850, 0x4900), ... for start in (0..MAX_ADDR).step_by(0x2000) { - assert_ok!(set.protect((start + 0x880).into(), 0x80, 0x3, &mut pt)); + assert_ok!(set.protect((start + 0x880).into(), 0x80, update_flags(0x3), &mut pt)); } dump_memory_set(&set); assert_eq!(set.len(), 39);