diff --git a/.gitignore b/.gitignore index 63d76060..70c16c23 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,8 @@ /target /*.tar.gz /riscof -/build -/proofs +/build* +/proofs* *.pilout /tmp -*.log \ No newline at end of file +*.log diff --git a/Cargo.lock b/Cargo.lock index e59ff2b5..cd9341db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2238,10 +2238,14 @@ name = "sm-mem" version = "0.1.0" dependencies = [ "log", + "num-bigint", + "num-traits", "p3-field", + "pil-std-lib", "proofman", "proofman-common", "proofman-macros", + "proofman-util", "rayon", "sm-common", "zisk-core", diff --git a/core/src/elf2rom.rs b/core/src/elf2rom.rs index 668c8aee..0209aa27 100644 --- a/core/src/elf2rom.rs +++ b/core/src/elf2rom.rs @@ -28,7 +28,15 @@ pub fn elf2rom(elf_file: String) -> Result> { for section_header in section_headers { // Consider only the section headers that contain program data if section_header.sh_type == SHT_PROGBITS { - // Get the program section data as a vector of bytes + // Get the section header address + let addr = section_header.sh_addr; + + // Ignore sections with address = 0, as per ELF spec + if addr == 0 { + continue; + } + + // Get the section data let (data_u8, _) = elf_bytes.section_data(§ion_header)?; let mut data = data_u8.to_vec(); @@ -37,31 +45,58 @@ pub fn elf2rom(elf_file: String) -> Result> { data.pop(); } - // Get the section data address - let addr = section_header.sh_addr; - - // If the data contains instructions, parse them as RISC-V instructions and add them - // to the ROM instructions, at the specified program address + // If this is a code section, add it to program if (section_header.sh_flags & SHF_EXECINSTR as u64) != 0 { add_zisk_code(&mut rom, addr, &data); } + // Add init data as a read/write memory section, initialized by code // If the data is a writable memory section, add it to the ROM memory using Zisk // copy instructions if (section_header.sh_flags & SHF_WRITE as u64) != 0 && addr >= RAM_ADDR && addr + data.len() as u64 
<= RAM_ADDR + RAM_SIZE { - add_zisk_init_data(&mut rom, addr, &data); - // Otherwise, add it to the ROM as RO data - } else { - rom.ro_data.push(RoData::new(addr, data.len(), data)); + //println! {"elf2rom() new RW from={:x} length={:x}={}", addr, data.len(), + //data.len()}; + add_zisk_init_data(&mut rom, addr, &data, true); + } + // Add read-only data memory section + else { + // Search for an existing RO section previous to this one + let mut found = false; + for rd in rom.ro_data.iter_mut() { + // Section data should be previous to this one + if (rd.from + rd.length as u64) == addr { + rd.length += data.len(); + rd.data.extend(data.clone()); + found = true; + //println! {"elf2rom() adding RO from={:x} length={:x}={}", rd.from, + // rd.length, rd.length}; + break; + } + } + + // If not found, create a new RO section + if !found { + //println! {"elf2rom() new RO from={:x} length={:x}={}", addr, data.len(), + // data.len()}; + rom.ro_data.push(RoData::new(addr, data.len(), data)); + } } } } } - // Add the program setup, system call and program wrapup instructions + // Add RO data initialization code insctructions + let ro_data_len = rom.ro_data.len(); + for i in 0..ro_data_len { + let addr = rom.ro_data[i].from; + let mut data = Vec::new(); + data.extend(rom.ro_data[i].data.as_slice()); + add_zisk_init_data(&mut rom, addr, &data, true); + } + add_entry_exit_jmp(&mut rom, elf_bytes.ehdr.e_entry); // Preprocess the ROM (experimental) @@ -128,6 +163,8 @@ pub fn elf2rom(elf_file: String) -> Result> { } } + //println! {"elf2rom() got rom.insts.len={}", rom.insts.len()}; + Ok(rom) } diff --git a/core/src/mem.rs b/core/src/mem.rs index f5febf93..10c1d719 100644 --- a/core/src/mem.rs +++ b/core/src/mem.rs @@ -5,49 +5,48 @@ //! * The Zisk processor memory stores data in little-endian format. //! * The addressable memory space is divided into several regions described in the following map: //! -//! `|--------------- ROM_ENTRY: first BIOS instruction ( 0x1000)` -//! `|` -//! 
`| Performs memory initialization, calls program at ROM_ADDR,` -//! `| and after returning it performs memory finalization.` -//! `| Contains ecall/system call management code.` -//! `|` -//! `|--------------- ROM_EXIT: last BIOS instruction (0x10000000)` -//! ` ...` -//! `|--------------- ROM_ADDR: first program instruction (0x80000000)` -//! `|` -//! `| Contains program instructions.` -//! `| Calls ecalls/system calls when required.` -//! `|` -//! `|--------------- INPUT_ADDR (0x90000000)` -//! `|` -//! `| Contains program input data.` -//! `|` -//! `|--------------- SYS_ADDR (= RAM_ADDR = REG_FIRST) (0xa0000000)` -//! `|` -//! `| Contains system address.` -//! `| The first 256 bytes contain 32 8-byte registers` -//! `| The address UART_ADDR is used as a standard output` -//! `|` -//! `|--------------- OUTPUT_ADDR (0xa0010000)` -//! `|` -//! `| Contains output data, which is written during` -//! `| program execution and read during memory finalization` -//! `|` -//! `|--------------- AVAILABLE_MEM_ADDR (0xa0020000)` -//! `|` -//! `| Contains program memory, available for normal R/W` -//! `| use during program execution.` -//! `|` -//! `|--------------- (0xb0000000)` -//! ` ...` +//! `|--------------- ROM_ENTRY: first BIOS instruction ( 0x1000)` +//! `|` +//! `| Performs memory initialization, calls program at ROM_ADDR,` +//! `| and after returning it performs memory finalization.` +//! `| Contains ecall/system call management code.` +//! `|` +//! `|--------------- ROM_EXIT: last BIOS instruction (0x10000000)` +//! ` ...` +//! `|--------------- ROM_ADDR: first program instruction (0x80000000)` +//! `|` +//! `| Contains program instructions.` +//! `| Calls ecalls/system calls when required.` +//! `|` +//! `|--------------- INPUT_ADDR (0x90000000)` +//! `|` +//! `| Contains program input data.` +//! `|` +//! `|--------------- SYS_ADDR (= RAM_ADDR = REG_FIRST) (0xa0000000)` +//! `|` +//! `| Contains system address.` +//! 
`| The first 256 bytes contain 32 8-byte registers` +//! `| The address UART_ADDR is used as a standard output` +//! `|` +//! `|--------------- OUTPUT_ADDR (0xa0010000)` +//! `|` +//! `| Contains output data, which is written during` +//! `| program execution and read during memory finalization` +//! `|` +//! `|--------------- AVAILABLE_MEM_ADDR (0xa0020000)` +//! `|` +//! `| Contains program memory, available for normal R/W` +//! `| use during program execution.` +//! `|` +//! `|--------------- (0xb0000000)` +//! ` ...` //! //! ## ROM_ENTRY / ROM_ADDR / ROM_EXIT //! * The program will start executing at the first BIOS address `ROM_ENTRY`. //! * The first instructions do the basic program setup, including writing the input data into //! memory, configuring the ecall (system call) program address, and configuring the program //! completion return address. -//! * After the program setup, the program counter jumps to `ROM_ADDR`, executing the actual -//! program. +//! * After the program set1, the program counter jumps to `ROM_ADDR`, executing the actual program. //! * During the execution, the program can make system calls that will jump to the configured ecall //! program address, and return once the task has completed. The precompiled are implemented via //! ecall. @@ -79,14 +78,16 @@ //! * The third RW memory region going from `AVAILABLE_MEM_ADDR` onwards can be used during the //! program execution a general purpose memory. 
+use std::fmt; + /// Fist input data memory address pub const INPUT_ADDR: u64 = 0x90000000; /// Maximum size of the input data -pub const MAX_INPUT_SIZE: u64 = 0x10000000; // 256M, +pub const MAX_INPUT_SIZE: u64 = 0x08000000; // 128M, /// First globa RW memory address pub const RAM_ADDR: u64 = 0xa0000000; /// Size of the global RW memory -pub const RAM_SIZE: u64 = 0x10000000; // 256M +pub const RAM_SIZE: u64 = 0x08000000; // 128M /// First system RW memory address pub const SYS_ADDR: u64 = RAM_ADDR; /// Size of the system RW memory @@ -106,7 +107,7 @@ pub const ROM_EXIT: u64 = 0x10000000; /// First program ROM instruction address, i.e. first RISC-V transpiled instruction pub const ROM_ADDR: u64 = 0x80000000; /// Maximum program ROM instruction address -pub const ROM_ADDR_MAX: u64 = INPUT_ADDR - 1; +pub const ROM_ADDR_MAX: u64 = (ROM_ADDR + 0x08000000) - 1; // 128M /// Zisk architecture ID pub const ARCH_ID_ZISK: u64 = 0xFFFEEEE; /// UART memory address; single bytes written here will be copied to the standard output @@ -114,13 +115,46 @@ pub const UART_ADDR: u64 = SYS_ADDR + 512; /// Memory section data, including a buffer (a vector of bytes) and start and end program /// memory addresses. 
-#[derive(Default)] pub struct MemSection { pub start: u64, pub end: u64, + pub real_end: u64, pub buffer: Vec, } +/// Default constructor for MemSection structure +impl Default for MemSection { + fn default() -> Self { + Self::new() + } +} + +impl fmt::Debug for MemSection { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(&self.to_text()) + } +} + +/// Memory section structure implementation +impl MemSection { + /// Memory section constructor + pub fn new() -> MemSection { + MemSection { start: 0, end: 0, real_end: 0, buffer: Vec::new() } + } + pub fn to_text(&self) -> String { + format!( + "start={:x} real_end={:x} end={:x} diff={:x}={} buffer.len={:x}={}", + self.start, + self.real_end, + self.end, + self.end - self.start, + self.end - self.start, + self.buffer.len(), + self.buffer.len() + ) + } +} + /// Memory structure, containing several read sections and one single write section #[derive(Default)] pub struct Mem { @@ -129,6 +163,12 @@ pub struct Mem { } impl Mem { + /// Memory structue constructor + pub fn new() -> Mem { + //println!("Mem::new()"); + Mem { read_sections: Vec::new(), write_section: MemSection::new() } + } + /// Adds a read section to the memory structure pub fn add_read_section(&mut self, start: u64, buffer: &[u8]) { // Check that the start address is alligned to 8 bytes @@ -142,31 +182,60 @@ impl Mem { // Calculate the end address let end = start + buffer.len() as u64; - // Create a mem section with this data - let mut mem_section = MemSection { start, end, buffer: buffer.to_owned() }; + // If there exists a read section next to this one, reuse it + for existing_section in self.read_sections.iter_mut() { + if existing_section.real_end == start { + // Sanity check + assert!(existing_section.real_end <= existing_section.end); + assert!((existing_section.end - existing_section.real_end) < 8); + + // Pop tail zeros until end matches real_end + while existing_section.real_end > existing_section.end { + 
existing_section.buffer.pop(); + existing_section.end -= 1; + } + + // Append buffer + existing_section.buffer.extend(buffer); + existing_section.real_end += buffer.len() as u64; + existing_section.end = existing_section.real_end; + + // Append zeros until end is multiple of 8, so that we can read non-alligned reads + while (existing_section.end & 0x07) != 0 { + existing_section.buffer.push(0); + existing_section.end += 1; + } + + /*println!( + "Mem::add_read_section() start={:x} len={} existing section={}", + start, + buffer.len(), + existing_section.to_text() + );*/ + + return; + } + } + + // Create a new memory section + let mut new_section = MemSection { start, end, real_end: end, buffer: buffer.to_owned() }; - // Add zero-value bytes until the end address is alligned to 8 bytes - while (mem_section.end) % 8 != 0 { - mem_section.buffer.push(0); - mem_section.end += 1; + // Append zeros until end is multiple of 8, so that we can read non-alligned reads + while (new_section.end & 0x07) != 0 { + new_section.buffer.push(0); + new_section.end += 1; } - // Push the new read section to the read sections list - self.read_sections.push(mem_section); - - /*println!( - "Mem::add_read_section() start={:x}={} len={} end={:x}={}", - start, - start, - buffer.len(), - end, - end - );*/ + //println!("Mem::add_read_section() new section={}", new_section.to_text()); + + // Add the new section to the read sections + self.read_sections.push(new_section); } /// Adds a write section to the memory structure, which cannot be written twice pub fn add_write_section(&mut self, start: u64, size: u64) { - //println!("Mem::add_write_section() start={} size={}", start, size); + //println!("Mem::add_write_section() start={:x}={} size={:x}={}", start, start, size, + // size); // Check that the start address is alligned to 8 bytes if (start & 0x07) != 0 { @@ -262,6 +331,177 @@ impl Mem { } } + /* + Possible alignment situations: + - Full aligned = address is aligned to 8 bytes (last 3 bits are 
zero) and width is 8 + - Single not aligned = not full aligned, and the data fits into one aligned slice of 8 bytes + - Double not aligned = not full aligned, and the data needs 2 aligned slices of 8 bytes + + Data required for each situation: + - full_aligned + RD = value + - full_aligned + WR = value, full_value + - single_not_aligned + RD = value, full_value TODO: We can save the value space, optimization + - single_not_aligned + WR = value, previous_full_value + - double_not_aligned + RD = value, full_values_0, full_values_1 + - double_not_aligned + WR = value, previous_full_values_0, previous_full_values_1 + + read_required() returns read value, and a vector of additional data required to prove it + */ + + /// Read a u64 value from the memory read sections, based on the provided address and width + #[inline(always)] + pub fn read_required(&self, addr: u64, width: u64) -> (u64, Vec) { + // Calculate how aligned this operation is + let addr_req_1 = addr & 0xFFFF_FFFF_FFFF_FFF8; // Aligned address of the first 8-bytes chunk + let addr_req_2 = (addr + width - 1) & 0xFFFF_FFFF_FFFF_FFF8; // Aligned address of the second 8-bytes chunk, if needed + let is_full_aligned = ((addr & 0x03) == 0) && (width == 8); + let is_single_not_aligned = !is_full_aligned && (addr_req_1 == addr_req_2); + let is_double_not_aligned = !is_full_aligned && !is_single_not_aligned; + + // First try to read in the write section + if (addr >= self.write_section.start) && (addr <= (self.write_section.end - width)) { + // Calculate the read position + let read_position: usize = (addr - self.write_section.start) as usize; + + // Read the requested data based on the provided width + let value: u64 = match width { + 1 => self.write_section.buffer[read_position] as u64, + 2 => u16::from_le_bytes( + self.write_section.buffer[read_position..read_position + 2].try_into().unwrap(), + ) as u64, + 4 => u32::from_le_bytes( + self.write_section.buffer[read_position..read_position + 4].try_into().unwrap(), + 
) as u64, + 8 => u64::from_le_bytes( + self.write_section.buffer[read_position..read_position + 8].try_into().unwrap(), + ), + _ => panic!("Mem::read() invalid width={}", width), + }; + + // If is a single not aligned operation, return the aligned address value + if is_single_not_aligned { + let mut additional_data: Vec = Vec::new(); + + assert!(addr_req_1 >= self.write_section.start); + let read_position_req: usize = (addr_req_1 - self.write_section.start) as usize; + let value_req = u64::from_le_bytes( + self.write_section.buffer[read_position_req..read_position_req + 8] + .try_into() + .unwrap(), + ); + additional_data.push(value_req); + + return (value, additional_data); + } + + // If is a double not aligned operation, return the aligned address value and the next + // one + if is_double_not_aligned { + let mut additional_data: Vec = Vec::new(); + + assert!(addr_req_1 >= self.write_section.start); + let read_position_req_1: usize = (addr_req_1 - self.write_section.start) as usize; + let value_req_1 = u64::from_le_bytes( + self.write_section.buffer[read_position_req_1..read_position_req_1 + 8] + .try_into() + .unwrap(), + ); + additional_data.push(value_req_1); + + assert!(addr_req_2 >= self.write_section.start); + let read_position_req_2: usize = (addr_req_2 - self.write_section.start) as usize; + let value_req_2 = u64::from_le_bytes( + self.write_section.buffer[read_position_req_2..read_position_req_2 + 8] + .try_into() + .unwrap(), + ); + additional_data.push(value_req_2); + + return (value, additional_data); + } + + //println!("Mem::read() addr={:x} width={} value={:x}={}", addr, width, value, value); + return (value, Vec::new()); + } + + // Search for the section that contains the address using binary search (dicothomic search) + let section = if let Ok(section) = self.read_sections.binary_search_by(|section| { + if addr < section.start { + std::cmp::Ordering::Greater + } else if (addr + width) > section.end { + std::cmp::Ordering::Less + } else { + 
std::cmp::Ordering::Equal + } + }) { + &self.read_sections[section] + } else { + println!("sections: {:?}", self.read_sections); + panic!("Mem::read() section not found for addr: {} with width: {}", addr, width); + }; + + // Calculate the read position + let read_position: usize = (addr - section.start) as usize; + + // Read the requested data based on the provided width + let value: u64 = match width { + 1 => section.buffer[read_position] as u64, + 2 => u16::from_le_bytes( + section.buffer[read_position..read_position + 2].try_into().unwrap(), + ) as u64, + 4 => u32::from_le_bytes( + section.buffer[read_position..read_position + 4].try_into().unwrap(), + ) as u64, + 8 => u64::from_le_bytes( + section.buffer[read_position..read_position + 8].try_into().unwrap(), + ), + _ => panic!( + "Mem::read() invalid addr:0x{:X} read_position:{} width:{}", + addr, read_position, width + ), + }; + + // If is a single not aligned operation, return the aligned address value + if is_single_not_aligned { + let mut additional_data: Vec = Vec::new(); + + assert!(addr_req_1 >= section.start); + let read_position_req: usize = (addr_req_1 - section.start) as usize; + let value_req = u64::from_le_bytes( + section.buffer[read_position_req..read_position_req + 8].try_into().unwrap(), + ); + additional_data.push(value_req); + + return (value, additional_data); + } + + // If is a double not aligned operation, return the aligned address value and the next + // one + if is_double_not_aligned { + let mut additional_data: Vec = Vec::new(); + + assert!(addr_req_1 >= section.start); + let read_position_req_1: usize = (addr_req_1 - section.start) as usize; + let value_req_1 = u64::from_le_bytes( + section.buffer[read_position_req_1..read_position_req_1 + 8].try_into().unwrap(), + ); + additional_data.push(value_req_1); + + assert!(addr_req_2 >= section.start); + let read_position_req_2: usize = (addr_req_2 - section.start) as usize; + let value_req_2 = u64::from_le_bytes( + 
section.buffer[read_position_req_2..read_position_req_2 + 8].try_into().unwrap(), + ); + additional_data.push(value_req_2); + + return (value, additional_data); + } + + //println!("Mem::read() addr={:x} width={} value={:x}={}", addr, width, value, value); + + (value, Vec::new()) + } + /// Write a u64 value to the memory write section, based on the provided address and width #[inline(always)] pub fn write(&mut self, addr: u64, val: u64, width: u64) { @@ -280,8 +520,24 @@ impl Mem { //println!("Mem::write() addr={:x}={} width={} value={:x}={}", addr, addr, width, val, // val); - // Get a reference to the write section - let section = &mut self.write_section; + // Search for the section that contains the address using binary search (dicothomic search) + let section = if let Ok(section) = self.read_sections.binary_search_by(|section| { + if addr < section.start { + std::cmp::Ordering::Greater + } else if addr > (section.end - width) { + std::cmp::Ordering::Less + } else { + std::cmp::Ordering::Equal + } + }) { + &mut self.read_sections[section] + } else { + /*panic!( + "Mem::write_silent() section not found for addr={:x}={} with width: {}", + addr, addr, width + );*/ + &mut self.write_section + }; // Check that the address and width fall into this section address range if (addr < section.start) || ((addr + width) > section.end) { @@ -304,6 +560,110 @@ impl Mem { 8 => section.buffer[write_position..write_position + 8] .copy_from_slice(&val.to_le_bytes()), _ => panic!("Mem::write_silent() invalid width={}", width), + }; + } + + /// Write a u64 value to the memory write section, based on the provided address and width + #[inline(always)] + pub fn write_silent_required(&mut self, addr: u64, val: u64, width: u64) -> Vec { + //println!("Mem::write() addr={:x}={} width={} value={:x}={}", addr, addr, width, val, + // val); + + // Search for the section that contains the address using binary search (dicothomic search) + let section = if let Ok(section) = 
self.read_sections.binary_search_by(|section| { + if addr < section.start { + std::cmp::Ordering::Greater + } else if addr > (section.end - width) { + std::cmp::Ordering::Less + } else { + std::cmp::Ordering::Equal + } + }) { + &mut self.read_sections[section] + } else { + /*panic!( + "Mem::write_silent() section not found for addr={:x}={} with width: {}", + addr, addr, width + );*/ + &mut self.write_section + }; + + // Check that the address and width fall into this section address range + if (addr < section.start) || ((addr + width) > section.end) { + panic!( + "Mem::write_silent() invalid addr={}={:x} write section start={:x} end={:x}", + addr, addr, section.start, section.end + ); + } + + // Calculate how aligned this operation is + let addr_req_1 = addr & 0xFFFF_FFFF_FFFF_FFF8; // Aligned address of the first 8-bytes chunk + let addr_req_2 = (addr + width - 1) & 0xFFFF_FFFF_FFFF_FFF8; // Aligned address of the second 8-bytes chunk, if needed + let is_full_aligned = ((addr & 0x03) == 0) && (width == 8); + let is_single_not_aligned = !is_full_aligned && (addr_req_1 == addr_req_2); + let is_double_not_aligned = !is_full_aligned && !is_single_not_aligned; + + // Declare an empty vector + let mut additional_data: Vec = Vec::new(); + + // If is a single not aligned operation, return the aligned address value + if is_single_not_aligned { + assert!( + addr_req_1 >= section.start, + "addr_req_1: 0x{:X} 0x{:X}]", + addr_req_1, + section.start + ); + let read_position_req: usize = (addr_req_1 - section.start) as usize; + let value_req = u64::from_le_bytes( + section.buffer[read_position_req..read_position_req + 8].try_into().unwrap(), + ); + additional_data.push(value_req); } + + // If is a double not aligned operation, return the aligned address value and the next + // one + if is_double_not_aligned { + assert!( + addr_req_1 >= section.start, + "addr_req_1(d): 0x{:X} 0x{:X}]", + addr_req_1, + section.start + ); + let read_position_req_1: usize = (addr_req_1 - 
section.start) as usize; + let value_req_1 = u64::from_le_bytes( + section.buffer[read_position_req_1..read_position_req_1 + 8].try_into().unwrap(), + ); + additional_data.push(value_req_1); + + assert!( + addr_req_2 >= section.start, + "addr_req_2(d): 0x{:X} 0x{:X}]", + addr_req_2, + section.start + ); + let read_position_req_2: usize = (addr_req_2 - section.start) as usize; + let value_req_2 = u64::from_le_bytes( + section.buffer[read_position_req_2..read_position_req_2 + 8].try_into().unwrap(), + ); + additional_data.push(value_req_2); + } + + // Calculate the write position + let write_position: usize = (addr - section.start) as usize; + + // Write the value based on the provided width + match width { + 1 => section.buffer[write_position] = val as u8, + 2 => section.buffer[write_position..write_position + 2] + .copy_from_slice(&(val as u16).to_le_bytes()), + 4 => section.buffer[write_position..write_position + 4] + .copy_from_slice(&(val as u32).to_le_bytes()), + 8 => section.buffer[write_position..write_position + 8] + .copy_from_slice(&val.to_le_bytes()), + _ => panic!("Mem::write_silent() invalid width={}", width), + } + + additional_data } } diff --git a/core/src/riscv2zisk_context.rs b/core/src/riscv2zisk_context.rs index fd0ab47a..d9a12b4d 100644 --- a/core/src/riscv2zisk_context.rs +++ b/core/src/riscv2zisk_context.rs @@ -1308,8 +1308,13 @@ pub fn add_zisk_code(rom: &mut ZiskRom, addr: u64, data: &[u8]) { /// /// The initial data is copied in chunks of 8 bytes for efficiency, until less than 8 bytes are left /// to copy. The remaining bytes are copied in additional chunks of 4, 2 and 1 byte, if required. 
-pub fn add_zisk_init_data(rom: &mut ZiskRom, addr: u64, data: &[u8]) { - //print!("add_zisk_init_data() addr={}\n", addr); +pub fn add_zisk_init_data(rom: &mut ZiskRom, addr: u64, data: &[u8], force_aligned: bool) { + /*let mut s = String::new(); + for i in 0..min(50, data.len()) { + s += &format!("{:02x}", data[i]); + } + print!("add_zisk_init_data() addr={:x} len={} data={}...\n", addr, data.len(), s);*/ + let mut o = addr; // Read 64-bit input data chunks and store them in rom @@ -1330,6 +1335,29 @@ pub fn add_zisk_init_data(rom: &mut ZiskRom, addr: u64, data: &[u8]) { o += 8; } + // TODO: review if necessary + let bytes = addr + data.len() as u64 - o; + // If force_aligned is active always store aligned + if force_aligned && bytes > 0 { + let mut v: u64 = 0; + let from = (o - addr + bytes - 1) as usize; + for i in 0..bytes { + v = v * 256 + data[from - i as usize] as u64; + } + let mut zib = ZiskInstBuilder::new(rom.next_init_inst_addr); + zib.src_a("imm", o, false); + zib.src_b("imm", v, false); + zib.op("copyb").unwrap(); + zib.ind_width(8); + zib.store("ind", 0, false, false); + zib.j(4, 4); + zib.verbose(&format!("Init Data {:08x}: {:04x}", o, v)); + zib.build(); + rom.insts.insert(rom.next_init_inst_addr, zib); + rom.next_init_inst_addr += 4; + o += bytes; + } + // Read remaining 32-bit input data chunk, if any, and store them in rom if addr + data.len() as u64 - o >= 4 { let v = u32::from_le_bytes(data[o as usize..o as usize + 4].try_into().unwrap()); @@ -1366,7 +1394,7 @@ pub fn add_zisk_init_data(rom: &mut ZiskRom, addr: u64, data: &[u8]) { // Read remaining 8-bit input data chunk, if any, and store them in rom if addr + data.len() as u64 - o >= 1 { - let v = data[o as usize]; + let v = data[(o - addr) as usize]; let mut zib = ZiskInstBuilder::new(rom.next_init_inst_addr); zib.src_a("imm", o, false); zib.src_b("imm", v as u64, false); @@ -1380,7 +1408,21 @@ pub fn add_zisk_init_data(rom: &mut ZiskRom, addr: u64, data: &[u8]) { rom.next_init_inst_addr 
+= 4; o += 1; } - + /* + if force_aligned { + let mut zib = ZiskInstBuilder::new(rom.next_init_inst_addr); + zib.src_a("imm", o, false); + zib.src_b("imm", 0, false); + zib.op("copyb").unwrap(); + zib.ind_width(8); + zib.store("ind", 0, false, false); + zib.j(4, 4); + zib.verbose(&format!("Init Data {:08x}: {:04x}", o, 0)); + zib.build(); + rom.insts.insert(rom.next_init_inst_addr, zib); + rom.next_init_inst_addr += 4; + } + */ // Check resulting length if o != addr + data.len() as u64 { panic!("add_zisk_init_data() invalid length o={} addr={} data.len={}", o, addr, data.len()); diff --git a/core/src/zisk_ops.rs b/core/src/zisk_ops.rs index 71efe57f..c9e23c10 100644 --- a/core/src/zisk_ops.rs +++ b/core/src/zisk_ops.rs @@ -284,7 +284,7 @@ define_ops! { (MaxuW, "maxu_w", Binary, 77, 0x24, opc_maxu_w, op_maxu_w), (MaxW, "max_w", Binary, 77, 0x25, opc_max_w, op_max_w), (Keccak, "keccak", Keccak, 77, 0xf1, opc_keccak, op_keccak), - (PubOut, "pubout", PubOut, 77, 0x30, opc_pubout, op_pubout), // TODO: New type + (PubOut, "pubout", PubOut, 77, 0x30, opc_pubout, op_pubout), } /* INTERNAL operations */ diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index d82644da..81078f5b 100644 --- a/core/src/zisk_required_operation.rs +++ b/core/src/zisk_required_operation.rs @@ -1,14 +1,14 @@ //! Data required to prove the different Zisk operations -use std::collections::HashMap; +use std::{collections::HashMap, fmt}; -/// Required data to make an operation. +/// Required data to make an operation. /// /// Stores the minimum information to reproduce an operation execution: /// * The opcode and the a and b registers values (regardless of their sources) /// * The step is also stored to keep track of the program execution point /// -/// This data is generated during the first emulation execution. +/// This data is generated during the first emulation execution. 
/// This data is required by the main state machine executor to generate the witness computation. #[derive(Clone)] pub struct ZiskRequiredOperation { @@ -20,12 +20,52 @@ pub struct ZiskRequiredOperation { /// Stores the minimum information to generate the memory state machine witness computation. #[derive(Clone)] -pub struct ZiskRequiredMemory { - pub step: u64, - pub is_write: bool, - pub address: u64, - pub width: u64, - pub value: u64, +pub enum ZiskRequiredMemory { + Basic { step: u64, value: u64, address: u32, is_write: bool, width: u8, step_offset: u8 }, + Extended { values: [u64; 2], address: u32 }, +} + +impl fmt::Debug for ZiskRequiredMemory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ZiskRequiredMemory::Basic { step, value, address, is_write, width, step_offset: _ } => { + let label = if *is_write { "WR" } else { "RD" }; + write!( + f, + "{0} addr:{1:#08X}({1}) offset:{5} width:{2} value:{3:#016X}({3}) step:{4}", + label, + address, + width, + value, + step, + address & 0x07 + ) + } + ZiskRequiredMemory::Extended { values, address } => { + write!( + f, + "addr:{1:#08X}({0}) value[1]:{1} value[2]:{2}", + address, values[0], values[1], + ) + } + } + } +} + +impl ZiskRequiredMemory { + pub fn get_address(&self) -> u32 { + match self { + ZiskRequiredMemory::Basic { + step: _, + value: _, + address, + is_write: _, + width: _, + step_offset: _, + } => *address, + ZiskRequiredMemory::Extended { values: _, address } => *address, + } + } } /// Data required to get some operations proven by the secondary state machine @@ -37,9 +77,9 @@ pub struct ZiskRequired { pub memory: Vec, } -/// Histogram of the program counter values used during the program execution. +/// Histogram of the program counter values used during the program execution. /// -/// Each pc value has a u64 counter, associated to it via a hash map. +/// Each pc value has a u64 counter, associated to it via a hash map. 
/// The counter is increased every time the corresponding instruction is executed. #[derive(Clone, Default)] pub struct ZiskPcHistogram { diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index a5f78a6a..6ec7fa0f 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -9,9 +9,9 @@ use riscv::RiscVRegisters; // #[cfg(feature = "sp")] // use zisk_core::SRC_SP; use zisk_core::{ - InstContext, ZiskInst, ZiskOperationType, ZiskPcHistogram, ZiskRequiredOperation, ZiskRom, - OUTPUT_ADDR, ROM_ENTRY, SRC_C, SRC_IMM, SRC_IND, SRC_MEM, SRC_STEP, STORE_IND, STORE_MEM, - STORE_NONE, SYS_ADDR, ZISK_OPERATION_TYPE_VARIANTS, + InstContext, ZiskInst, ZiskOperationType, ZiskPcHistogram, ZiskRequiredMemory, + ZiskRequiredOperation, ZiskRom, OUTPUT_ADDR, ROM_ENTRY, SRC_C, SRC_IMM, SRC_IND, SRC_MEM, + SRC_STEP, STORE_IND, STORE_MEM, STORE_NONE, SYS_ADDR, ZISK_OPERATION_TYPE_VARIANTS, }; /// ZisK emulator structure, containing the ZisK rom, the list of ZisK operations, and the @@ -29,8 +29,8 @@ pub struct Emu<'a> { /// - run -> step -> source_a, source_b, store_c (full functionality, called by main state machine, /// calls callback with trace) /// - run -> run_fast -> step_fast -> source_a, source_b, store_c (maximum speed, for benchmarking) -/// - run_slice -> step_slice -> source_a_slice, source_b_slice, store_c_slice (generates full trace -/// and required input data for secondary state machines) +/// - run_slice -> step_slice -> source_a_slice, source_b_slice (generates full trace and required +/// input data for secondary state machines) impl<'a> Emu<'a> { pub fn new(rom: &ZiskRom) -> Emu { Emu { rom, ctx: EmuContext::default() } @@ -92,6 +92,62 @@ impl<'a> Emu<'a> { } } + /// Calculate the 'a' register value based on the source specified by the current instruction + #[inline(always)] + pub fn source_a_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + ) { + match instruction.a_src { + SRC_C => self.ctx.inst_ctx.a = self.ctx.inst_ctx.c, + SRC_MEM => { 
+ // Build the memory address + let mut addr = instruction.a_offset_imm0; + if instruction.a_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + + // Call read_required to get both the read value and the additional data (aligned + // read values required to construct the requested read value, if not aligned) + let additional_data: Vec; + (self.ctx.inst_ctx.a, additional_data) = + self.ctx.inst_ctx.mem.read_required(addr, 8); + + // Store the read value into the vector as a basic record + let required_memory = ZiskRequiredMemory::Basic { + step: self.ctx.inst_ctx.step, + step_offset: 0, + is_write: false, + address: addr as u32, + width: 8, + value: self.ctx.inst_ctx.a, + }; + emu_mem.push(required_memory); + + // Store the additional data, if any, as extended records + if !additional_data.is_empty() { + assert!(additional_data.len() <= 2); + let mut values: [u64; 2] = [0; 2]; + values[..additional_data.len()].copy_from_slice(&additional_data[..]); + let required_memory = + ZiskRequiredMemory::Extended { values, address: addr as u32 }; + emu_mem.push(required_memory); + } + } + SRC_IMM => { + self.ctx.inst_ctx.a = instruction.a_offset_imm0 | (instruction.a_use_sp_imm1 << 32) + } + SRC_STEP => self.ctx.inst_ctx.a = self.ctx.inst_ctx.step, + // #[cfg(feature = "sp")] + // SRC_SP => self.ctx.inst_ctx.a = self.ctx.inst_ctx.sp, + _ => panic!( + "Emu::source_a() Invalid a_src={} pc={}", + instruction.a_src, self.ctx.inst_ctx.pc + ), + } + } + /// Calculate the 'b' register value based on the source specified by the current instruction #[inline(always)] pub fn source_b(&mut self, instruction: &ZiskInst) { @@ -128,6 +184,94 @@ impl<'a> Emu<'a> { } } + /// Calculate the 'b' register value based on the source specified by the current instruction + #[inline(always)] + pub fn source_b_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + ) { + match instruction.b_src { + SRC_C => self.ctx.inst_ctx.b = self.ctx.inst_ctx.c, + SRC_MEM => { + // Build the 
memory address + let mut addr = instruction.b_offset_imm0; + if instruction.b_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + + // Call read_required to get both the read value and the additional data (aligned + // read values required to construct the requested read value, if not aligned) + let additional_data: Vec; + (self.ctx.inst_ctx.b, additional_data) = + self.ctx.inst_ctx.mem.read_required(addr, 8); + + // Store the read value into the vector as a basic record + let required_memory = ZiskRequiredMemory::Basic { + step: self.ctx.inst_ctx.step, + step_offset: 1, + is_write: false, + address: addr as u32, + width: 8, + value: self.ctx.inst_ctx.b, + }; + emu_mem.push(required_memory); + + // Store the additional data, if any, as extended records + if !additional_data.is_empty() { + assert!(additional_data.len() <= 2); + let mut values: [u64; 2] = [0; 2]; + values[..additional_data.len()].copy_from_slice(&additional_data[..]); + let required_memory = + ZiskRequiredMemory::Extended { values, address: addr as u32 }; + emu_mem.push(required_memory); + } + } + SRC_IMM => { + self.ctx.inst_ctx.b = instruction.b_offset_imm0 | (instruction.b_use_sp_imm1 << 32) + } + SRC_IND => { + // Build the memory address + let mut addr = + (self.ctx.inst_ctx.a as i64 + instruction.b_offset_imm0 as i64) as u64; + if instruction.b_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + + // Call read_required to get both the read value and the additional data (aligned + // read values required to construct the requested read value, if not aligned) + let additional_data: Vec; + (self.ctx.inst_ctx.b, additional_data) = + self.ctx.inst_ctx.mem.read_required(addr, instruction.ind_width); + + // Store the read value into the vector as a basic record + let required_memory = ZiskRequiredMemory::Basic { + step: self.ctx.inst_ctx.step, + step_offset: 1, + is_write: false, + address: addr as u32, + width: instruction.ind_width as u8, + value: self.ctx.inst_ctx.b, + }; + 
emu_mem.push(required_memory); + + // Store the additional data, if any, as extended records + if !additional_data.is_empty() { + assert!(additional_data.len() <= 2); + let mut values: [u64; 2] = [0; 2]; + values[..additional_data.len()].copy_from_slice(&additional_data[..]); + let required_memory = + ZiskRequiredMemory::Extended { values, address: addr as u32 }; + emu_mem.push(required_memory); + } + } + _ => panic!( + "Emu::source_b() Invalid b_src={} pc={}", + instruction.b_src, self.ctx.inst_ctx.pc + ), + } + } + /// Store the 'c' register value based on the storage specified by the current instruction #[inline(always)] pub fn store_c(&mut self, instruction: &ZiskInst) { @@ -171,45 +315,107 @@ impl<'a> Emu<'a> { } } - /// Store the 'c' register value based on the storage specified by the current instruction and - /// log memory access if required + /// Store the 'c' register value based on the storage specified by the current instruction #[inline(always)] - pub fn store_c_slice(&mut self, instruction: &ZiskInst) { + pub fn store_c_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + ) { match instruction.store { STORE_NONE => {} STORE_MEM => { + // Calculate the value to write let val: i64 = if instruction.store_ra { self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 } else { self.ctx.inst_ctx.c as i64 }; + + // Build the memory address let mut addr: i64 = instruction.store_offset; if instruction.store_use_sp { addr += self.ctx.inst_ctx.sp as i64; } - self.ctx.inst_ctx.mem.write_silent(addr as u64, val as u64, 8); + + // Call write_silent_required to get the additional data (aligned read values + // required to construct the new written data, if not aligned) + let additional_data: Vec = + self.ctx.inst_ctx.mem.write_silent_required(addr as u64, val as u64, 8); + + // Store the written value into the vector as a basic record + let required_memory = ZiskRequiredMemory::Basic { + step: self.ctx.inst_ctx.step, + step_offset: 2, + is_write: true, 
+ address: addr as u32, + width: 8, + value: val as u64, + }; + emu_mem.push(required_memory); + + // Store the additional data, if any, as extended records + if !additional_data.is_empty() { + assert!(additional_data.len() <= 2); + let mut values: [u64; 2] = [0; 2]; + values[..additional_data.len()].copy_from_slice(&additional_data[..]); + let required_memory = + ZiskRequiredMemory::Extended { values, address: addr as u32 }; + emu_mem.push(required_memory); + } } STORE_IND => { + // Calculate the value to write let val: i64 = if instruction.store_ra { self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 } else { self.ctx.inst_ctx.c as i64 }; + + // Build the memory address let mut addr = instruction.store_offset; if instruction.store_use_sp { addr += self.ctx.inst_ctx.sp as i64; } addr += self.ctx.inst_ctx.a as i64; - self.ctx.inst_ctx.mem.write_silent(addr as u64, val as u64, instruction.ind_width); + + // Call write_silent_required to get the additional data (aligned read values + // required to construct the new written data, if not aligned) + let additional_data: Vec = self.ctx.inst_ctx.mem.write_silent_required( + addr as u64, + val as u64, + instruction.ind_width, + ); + + // Store the written value into the vector as a basic record + let required_memory = ZiskRequiredMemory::Basic { + step: self.ctx.inst_ctx.step, + step_offset: 2, + is_write: true, + address: addr as u32, + width: instruction.ind_width as u8, + value: val as u64, + }; + emu_mem.push(required_memory); + + // Store the additional data, if any, as extended records + if !additional_data.is_empty() { + assert!(additional_data.len() <= 2); + let mut values: [u64; 2] = [0; 2]; + values[..additional_data.len()].copy_from_slice(&additional_data[..]); + let required_memory = + ZiskRequiredMemory::Extended { values, address: addr as u32 }; + emu_mem.push(required_memory); + } } _ => panic!( - "Emu::store_c_slice() Invalid store={} pc={}", + "Emu::store_c() Invalid store={} pc={}", instruction.store, 
self.ctx.inst_ctx.pc ), } } - // Set SP, if specified by the current instruction + /// Set SP, if specified by the current instruction // #[cfg(feature = "sp")] // #[inline(always)] // pub fn set_sp(&mut self, instruction: &ZiskInst) { @@ -449,6 +655,22 @@ impl<'a> Emu<'a> { (emu_traces, emu_segments) } + pub fn par_run_memory(&mut self, inputs: Vec) -> Vec { + // Context, where the state of the execution is stored and modified at every execution step + self.ctx = self.create_emu_context(inputs); + + // Init pc to the rom entry address + self.ctx.trace.start_state.pc = ROM_ENTRY; + + let mut emu_mem = Vec::new(); + + while !self.ctx.inst_ctx.end { + self.par_step_memory::(&mut emu_mem); + } + + emu_mem + } + /// Performs one single step of the emulation #[inline(always)] #[allow(unused_variables)] @@ -456,8 +678,13 @@ impl<'a> Emu<'a> { let pc = self.ctx.inst_ctx.pc; let instruction = self.rom.get_instruction(self.ctx.inst_ctx.pc); - //println!("Emu::step() executing step={} pc={:x} inst={}", ctx.step, ctx.pc, - // inst.i.to_string()); println!("Emu::step() step={} pc={}", ctx.step, ctx.pc); + /*println!( + "Emu::step() executing step={} pc={:x} inst={}", + self.ctx.inst_ctx.step, + self.ctx.inst_ctx.pc, + instruction.to_text() + );*/ + //println!("Emu::step() step={} pc={}", ctx.step, ctx.pc); // Build the 'a' register value based on the source specified by the current instruction self.source_a(instruction); @@ -622,6 +849,56 @@ impl<'a> Emu<'a> { self.ctx.inst_ctx.step += 1; } + /// Performs one single step of the emulation + #[inline(always)] + #[allow(unused_variables)] + pub fn par_step_memory(&mut self, emu_mem: &mut Vec) { + //let last_pc = self.ctx.inst_ctx.pc; + //let last_c = self.ctx.inst_ctx.c; + + let instruction = self.rom.get_instruction(self.ctx.inst_ctx.pc); + + // println!( + // "#### step={} pc={} op={}={} a={} b={} c={} flag={} inst={}", + // self.ctx.inst_ctx.step, + // self.ctx.inst_ctx.pc, + // instruction.op, + // instruction.op_str, + // 
self.ctx.inst_ctx.a, + // self.ctx.inst_ctx.b, + // self.ctx.inst_ctx.c, + // self.ctx.inst_ctx.flag, + // instruction.to_text() + // ); + // self.print_regs(); + // println!(); + + // Build the 'a' register value based on the source specified by the current instruction + self.source_a_memory(instruction, emu_mem); + + // Build the 'b' register value based on the source specified by the current instruction + self.source_b_memory(instruction, emu_mem); + + // Call the operation + (instruction.func)(&mut self.ctx.inst_ctx); + + // Store the 'c' register value based on the storage specified by the current instruction + self.store_c_memory(instruction, emu_mem); + + // Set SP, if specified by the current instruction + // #[cfg(feature = "sp")] + // self.set_sp(instruction); + + // Set PC, based on current PC, current flag and current instruction + self.set_pc(instruction); + + // If this is the last instruction, stop executing + self.ctx.inst_ctx.end = instruction.end; + + // Increment step counter + self.ctx.inst_ctx.step += 1; + } + /// Performs one single step of the emulation #[inline(always)] #[allow(unused_variables)] @@ -742,7 +1019,7 @@ impl<'a> Emu<'a> { self.ctx.inst_ctx.a = trace_step.a; self.ctx.inst_ctx.b = trace_step.b; (instruction.func)(&mut self.ctx.inst_ctx); - self.store_c_slice(instruction); + // No need to store c // #[cfg(feature = "sp")] // self.set_sp(instruction); self.set_pc(instruction); @@ -788,7 +1065,7 @@ impl<'a> Emu<'a> { self.ctx.inst_ctx.a = trace_step.a; self.ctx.inst_ctx.b = trace_step.b; (instruction.func)(&mut self.ctx.inst_ctx); - self.store_c_slice(instruction); + // No need to store c // #[cfg(feature = "sp")] // self.set_sp(instruction); self.set_pc(instruction); diff --git a/emulator/src/emu_context.rs b/emulator/src/emu_context.rs index 811baa6c..6ed4d581 100644 --- a/emulator/src/emu_context.rs +++ b/emulator/src/emu_context.rs @@ -64,6 +64,28 @@ impl EmuContext { impl Default for EmuContext { fn default() -> Self { - 
Self::new(Vec::new()) + EmuContext { + inst_ctx: InstContext { + mem: Mem::new(), + a: 0, + b: 0, + c: 0, + flag: false, + sp: 0, + pc: ROM_ENTRY, + step: 0, + end: false, + }, + tracerv: Vec::new(), + tracerv_step: 0, + tracerv_current_regs: [0; 32], + trace_pc: 0, + trace: EmuTrace::default(), + do_callback: false, + callback_steps: 0, + last_callback_step: 0, + do_stats: false, + stats: Stats::default(), + } } } diff --git a/emulator/src/emu_trace.rs b/emulator/src/emu_trace.rs index 7baff978..75b4b138 100644 --- a/emulator/src/emu_trace.rs +++ b/emulator/src/emu_trace.rs @@ -13,9 +13,10 @@ pub struct EmuTraceStart { pub step: u64, } -/// Trace data at every step. -/// Only the values of registers a and b are required. -/// The current value of pc evolves starting at the start pc value, as we execute the ROM. +/// Trace data at every step. +/// +/// Only the values of registers a and b are required. +/// The current value of pc evolves starting at the start pc value, as we execute the ROM. /// The value of c and flag can be obtained by executing the ROM instruction corresponding to the /// current value of pc and taking a and b as the input. #[derive(Default, Debug, Clone)] @@ -26,12 +27,13 @@ pub struct EmuTraceStep { pub b: u64, } -/// Trace data at the end of the program execution, including only the `end` flag. -/// If the `end` flag is true, the program executed completely. +/// Trace data at the end of the program execution, including only the `end` flag. +/// +/// If the `end` flag is true, the program executed completely. /// This does not mean that the program ended successfully; it could have found an error condition -/// due to, for example, invalid input data, and then jump directly to the end of the ROM. +/// due to, for example, invalid input data, and then jump directly to the end of the ROM. /// In this error situation, the output data should reveal the success or fail of the completed -/// execution. +/// execution. 
/// These are the possible combinations: /// * end = false --> program did not complete, e.g. the emulator run out of steps (you can /// configure more steps) diff --git a/emulator/src/emulator.rs b/emulator/src/emulator.rs index 36d3bceb..49d7b8d2 100644 --- a/emulator/src/emulator.rs +++ b/emulator/src/emulator.rs @@ -28,8 +28,8 @@ use std::{ }; use sysinfo::System; use zisk_core::{ - Riscv2zisk, ZiskOperationType, ZiskPcHistogram, ZiskRequiredOperation, ZiskRom, - ZISK_OPERATION_TYPE_VARIANTS, + Riscv2zisk, ZiskOperationType, ZiskPcHistogram, ZiskRequiredMemory, ZiskRequiredOperation, + ZiskRom, ZISK_OPERATION_TYPE_VARIANTS, }; pub trait Emulator { @@ -261,6 +261,22 @@ impl ZiskEmulator { Ok((vec_traces, emu_slices)) } + pub fn par_process_rom_memory( + rom: &ZiskRom, + inputs: &[u8], + ) -> Result, ZiskEmulatorErr> { + let mut emu = Emu::new(rom); + let result = emu.par_run_memory::(inputs.to_owned()); + + if !emu.terminated() { + panic!("Emulation did not complete"); + // TODO! + // return Err(ZiskEmulatorErr::EmulationNoCompleted); + } + + Ok(result) + } + /// Process a Zisk rom with the provided input data, according to the configured options, in /// order to generate a set of required operation data. 
#[inline] diff --git a/pil/src/lib.rs b/pil/src/lib.rs index aee8bab5..27705cb0 100644 --- a/pil/src/lib.rs +++ b/pil/src/lib.rs @@ -6,8 +6,5 @@ pub use pil_helpers::*; pub const ARITH32_AIR_IDS: &[usize] = &[4, 5]; pub const ARITH64_AIR_IDS: &[usize] = &[6]; pub const ARITH3264_AIR_IDS: &[usize] = &[7]; -pub const MEM_AIRGROUP_ID: usize = 105; -pub const MEM_ALIGN_AIR_IDS: &[usize] = &[1]; -pub const MEM_UNALIGNED_AIR_IDS: &[usize] = &[2, 3]; pub const QUICKOPS_AIRGROUP_ID: usize = 102; pub const QUICKOPS_AIR_IDS: &[usize] = &[10]; diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index 9098a62b..919399c6 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -14,21 +14,35 @@ pub const MAIN_AIR_IDS: &[usize] = &[0]; pub const ROM_AIR_IDS: &[usize] = &[1]; -pub const ARITH_AIR_IDS: &[usize] = &[2]; +pub const MEM_AIR_IDS: &[usize] = &[2]; -pub const ARITH_TABLE_AIR_IDS: &[usize] = &[3]; +pub const ROM_DATA_AIR_IDS: &[usize] = &[3]; -pub const ARITH_RANGE_TABLE_AIR_IDS: &[usize] = &[4]; +pub const INPUT_DATA_AIR_IDS: &[usize] = &[4]; -pub const BINARY_AIR_IDS: &[usize] = &[5]; +pub const MEM_ALIGN_AIR_IDS: &[usize] = &[5]; -pub const BINARY_TABLE_AIR_IDS: &[usize] = &[6]; +pub const MEM_ALIGN_ROM_AIR_IDS: &[usize] = &[6]; -pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[7]; +pub const ARITH_AIR_IDS: &[usize] = &[7]; -pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[8]; +pub const ARITH_TABLE_AIR_IDS: &[usize] = &[8]; -pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[9]; +pub const ARITH_RANGE_TABLE_AIR_IDS: &[usize] = &[9]; + +pub const BINARY_AIR_IDS: &[usize] = &[10]; + +pub const BINARY_TABLE_AIR_IDS: &[usize] = &[11]; + +pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[12]; + +pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[13]; + +pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[14]; + +pub const U_8_AIR_AIR_IDS: &[usize] = &[15]; + +pub const U_16_AIR_AIR_IDS: &[usize] = &[16]; pub struct 
Pilout; @@ -39,7 +53,12 @@ impl Pilout { let air_group = pilout.add_air_group(Some("Zisk")); air_group.add_air(Some("Main"), 2097152); - air_group.add_air(Some("Rom"), 1048576); + air_group.add_air(Some("Rom"), 4194304); + air_group.add_air(Some("Mem"), 2097152); + air_group.add_air(Some("RomData"), 2097152); + air_group.add_air(Some("InputData"), 2097152); + air_group.add_air(Some("MemAlign"), 2097152); + air_group.add_air(Some("MemAlignRom"), 256); air_group.add_air(Some("Arith"), 2097152); air_group.add_air(Some("ArithTable"), 128); air_group.add_air(Some("ArithRangeTable"), 4194304); @@ -48,6 +67,8 @@ impl Pilout { air_group.add_air(Some("BinaryExtension"), 2097152); air_group.add_air(Some("BinaryExtensionTable"), 4194304); air_group.add_air(Some("SpecifiedRanges"), 16777216); + air_group.add_air(Some("U8Air"), 256); + air_group.add_air(Some("U16Air"), 65536); pilout } diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 47541a74..6137366d 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -11,6 +11,26 @@ trace!(RomRow, RomTrace { multiplicity: F, }); +trace!(MemRow, MemTrace { + addr: F, step: F, sel: F, addr_changes: F, value: [F; 2], wr: F, increment: F, +}); + +trace!(RomDataRow, RomDataTrace { + addr: F, step: F, sel: F, addr_changes: F, value: [F; 2], +}); + +trace!(InputDataRow, InputDataTrace { + addr: F, step: F, sel: F, addr_changes: F, value_word: [F; 4], +}); + +trace!(MemAlignRow, MemAlignTrace { + addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], step: F, delta_addr: F, sel_prove: F, value: [F; 2], +}); + +trace!(MemAlignRomRow, MemAlignRomTrace { + multiplicity: F, +}); + trace!(ArithRow, ArithTrace { carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, sext: F, m32: F, div: F, fab: F, na_fb: F, nb_fa: F, debug_main_step: F, main_div: F, main_mul: F, signed: F, div_by_zero: F, 
div_overflow: F, inv_sum_all_bs: F, op: F, bus_res1: F, multiplicity: F, range_ab: F, range_cd: F, }); @@ -40,7 +60,15 @@ trace!(BinaryExtensionTableRow, BinaryExtensionTableTrace { }); trace!(SpecifiedRangesRow, SpecifiedRangesTrace { - mul: [F; 1], + mul: [F; 2], +}); + +trace!(U8AirRow, U8AirTrace { + mul: F, +}); + +trace!(U16AirRow, U16AirTrace { + mul: F, }); trace!(RomRomRow, RomRomTrace { diff --git a/pil/zisk.pil b/pil/zisk.pil index 6dc5052c..bce36ffa 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -1,17 +1,27 @@ require "rom/pil/rom.pil" require "main/pil/main.pil" +require "mem/pil/mem.pil" +require "mem/pil/mem_align.pil" +require "mem/pil/mem_align_rom.pil" require "binary/pil/binary.pil" require "binary/pil/binary_table.pil" require "binary/pil/binary_extension.pil" require "binary/pil/binary_extension_table.pil" require "arith/pil/arith.pil" -// require "mem/pil/mem.pil" const int OPERATION_BUS_ID = 5000; + airgroup Zisk { Main(N: 2**21, RC: 2, operation_bus_id: OPERATION_BUS_ID); - Rom(N: 2**20); - // Mem(N: 2**21, RC: 2); + Rom(N: 2**22); + + Mem(N: 2**21, RC: 2, base_address: 0xA000_0000); + Mem(N: 2**21, RC: 2, base_address: 0x8000_0000, immutable: 1) alias RomData; + Mem(N: 2**21, RC: 2, base_address: 0x9000_0000, free_input_mem: 1) alias InputData; + MemAlign(N: 2**21); + MemAlignRom(disable_fixed: 0); + // InputData(N: 2**21, RC: 2); + Arith(N: 2**21, operation_bus_id: OPERATION_BUS_ID); ArithTable(); ArithRangeTable(); diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index 6788f7de..e8bd35d7 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -225,9 +225,9 @@ airtemplate ArithTable(int N = 2**7, int generate_table = 1) { RANGE_CD[index] = range_cd; if (generate_table) { - println(`OP:${opcode} na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} m32:${m32} div:${div}`, - `div_by_zero:${div_by_zero} div_overflow:${div_overflow} sa:${sa} sb:${sb} 
main_mul:${main_mul}`, - `main_div:${main_div} signed:${signed} range_ab:${range_ab} range_cd:${range_cd} index:${(opcode - 0xb0) * 128 + icase} icase:${icase}`); + // println(`OP:${opcode} na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} m32:${m32} div:${div}`, + // `div_by_zero:${div_by_zero} div_overflow:${div_overflow} sa:${sa} sb:${sb} main_mul:${main_mul}`, + // `main_div:${main_div} signed:${signed} range_ab:${range_ab} range_cd:${range_cd} index:${(opcode - 0xb0) * 128 + icase} icase:${icase}`); op2row[(opcode - 0xb0) * 128 + icase] = index; code = code + `[${opcode}, ${flags}, ${range_ab}, ${range_cd}],`; diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index 99d71a65..afa3bd3c 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -78,7 +78,7 @@ impl ArithFullSM { let num_rows = air.num_rows(); timer_start_trace!(ARITH_TRACE); info!( - "{}: ยทยทยท Creating Arith instance KKKKK [{} / {} rows filled {:.2}%]", + "{}: ยทยทยท Creating Arith instance [{} / {} rows filled {:.2}%]", Self::MY_NAME, input.len(), num_rows, @@ -259,7 +259,6 @@ impl ArithFullSM { } timer_stop_and_log_trace!(ARITH_PADDING); timer_start_trace!(ARITH_TABLE); - info!("{}: ยทยทยท calling arit_table_sm", Self::MY_NAME); self.arith_table_sm.process_slice(&table_inputs); timer_stop_and_log_trace!(ARITH_TABLE); timer_start_trace!(ARITH_RANGE_TABLE); diff --git a/state-machines/arith/src/arith_table.rs b/state-machines/arith/src/arith_table.rs index dc535754..79fdbb2a 100644 --- a/state-machines/arith/src/arith_table.rs +++ b/state-machines/arith/src/arith_table.rs @@ -58,9 +58,7 @@ impl ArithTableSM { // Create the trace vector let mut _multiplicity = self.multiplicity.lock().unwrap(); - info!("{}: ยทยทยท process multiplicity", Self::MY_NAME); for (row, value) in inputs { - info!("{}: ยทยทยท Processing row {} with value {}", Self::MY_NAME, row, value); _multiplicity[row] += value; } self.used.store(true, 
Ordering::Relaxed); diff --git a/state-machines/arith/src/arith_table_helpers.rs b/state-machines/arith/src/arith_table_helpers.rs index ba557e07..67f2a730 100644 --- a/state-machines/arith/src/arith_table_helpers.rs +++ b/state-machines/arith/src/arith_table_helpers.rs @@ -25,9 +25,9 @@ impl ArithTableHelpers { sext as u64 * 16 + div_by_zero as u64 * 32 + div_overflow as u64 * 64; - assert!(index < ARITH_TABLE_ROWS.len() as u64); + debug_assert!(index < ARITH_TABLE_ROWS.len() as u64); let row = ARITH_TABLE_ROWS[index as usize]; - assert!( + debug_assert!( row < 255, "INVALID ROW row:{} op:0x{:x} na:{} nb:{} np:{} nr:{} sext:{} div_by_zero:{} div_overflow:{} index:{}", row, diff --git a/state-machines/binary/pil/binary_extension_table.pil b/state-machines/binary/pil/binary_extension_table.pil index 35e0ad35..f521b550 100644 --- a/state-machines/binary/pil/binary_extension_table.pil +++ b/state-machines/binary/pil/binary_extension_table.pil @@ -4,9 +4,9 @@ require "std_lookup.pil" // Operations Table: // Running Total // SLL (OP:0x31) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^19 -// SRL (OP:0x32) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 +// SRL (OP:0x32) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 // SRA (OP:0x33) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^20 + 2^19 -// SLL_W (OP:0x34) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 +// SLL_W (OP:0x34) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 // SRL_W (OP:0x35) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^19 // SRA_W (OP:0x36) 2^8 (A) * 2^3 (OFFSET) * 2^8 (B) = 2^19 | 2^21 + 2^20 // SE_B (OP:0x37) 2^8 (A) * 2^3 (OFFSET) = 2^11 | 2^21 + 2^20 + 2^11 @@ -16,7 +16,7 @@ require "std_lookup.pil" const int BINARY_EXTENSION_TABLE_ID = 124; airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = 0) { - + #pragma memory m1 start const int SE_MASK_32 = 0xFFFFFFFF00000000; @@ -28,15 +28,15 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = const int 
LS_5_BITS = 0x1F; const int LS_6_BITS = 0x3F; - + col witness multiplicity; if (disable_fixed) { col fixed _K = [0...]; // FORCE ONE TRACE multiplicity * _K === 0; - - println("*** DISABLE_FIXED ***"); + + println("*** DISABLE_FIXED ***"); return; } @@ -58,7 +58,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = // Input B (8 bits) col fixed B = [[0:P2_11..255:P2_11]:6, // SLL, SRL, SRA, SLL_W, SRL_W, SRA_W - 0:(P2_11*3)]...; // SE_B, SE_H, SE_W + 0:(P2_11*3)]...; // SE_B, SE_H, SE_W // Operation is shift (fixed values) col fixed OP_IS_SHIFT = [1:(P2_19*6), // SLL, SRL, SRA, SLL_W, SRL_W, SRA_W @@ -84,12 +84,12 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = const int _a = a << (8*offset); switch (op) { case 0x31: // SLL - _out = _a << (b & LS_6_BITS); + _out = _a << (b & LS_6_BITS); case 0x32: // SRL _out = _a >> (b & LS_6_BITS); - case 0x33: { // SRA + case 0x33: { // SRA const int _b = b & LS_6_BITS; _out = _a >> _b; if (offset == 7) { @@ -110,7 +110,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = _out = _out | SE_MASK_32; } } - + case 0x35: // SRL_W if (offset >= 4) { // last most significant bytes are ignored because it's 32-bit operation @@ -148,7 +148,7 @@ airtemplate BinaryExtensionTable(const int N = 2**22, const int disable_fixed = } case 0x38: // SE_H - if (offset == 0) { + if (offset == 0) { // fist byte not define the sign extend, but participate of result _out = a; } else if (offset == 1) { diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index aefe09f6..b37c5a22 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -10,7 +10,7 @@ use proofman_common::AirInstance; use proofman_util::{timer_start_trace, timer_stop_and_log_trace}; use std::cmp::Ordering as CmpOrdering; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::*; +use 
zisk_pil::{BinaryRow, BinaryTrace, BINARY_AIR_IDS, BINARY_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; use crate::{BinaryBasicTableOp, BinaryBasicTableSM}; diff --git a/state-machines/binary/src/binary_extension.rs b/state-machines/binary/src/binary_extension.rs index ded8972f..4e22f8fe 100644 --- a/state-machines/binary/src/binary_extension.rs +++ b/state-machines/binary/src/binary_extension.rs @@ -15,7 +15,10 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::*; +use zisk_pil::{ + BinaryExtensionRow, BinaryExtensionTrace, BINARY_EXTENSION_AIR_IDS, + BINARY_EXTENSION_TABLE_AIR_IDS, ZISK_AIRGROUP_ID, +}; const MASK_32: u64 = 0xFFFFFFFF; const MASK_64: u64 = 0xFFFFFFFFFFFFFFFF; diff --git a/state-machines/main/pil/main.pil b/state-machines/main/pil/main.pil index 740a6e63..6d8d7506 100644 --- a/state-machines/main/pil/main.pil +++ b/state-machines/main/pil/main.pil @@ -80,7 +80,7 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope col witness air.b_imm1; } col witness b_src_ind; - col witness ind_width; // 8 , 4, 2, 1 + col witness ind_width; // 8, 4, 2, 1 // Operations related @@ -113,8 +113,6 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope col witness jmp_offset1, jmp_offset2; // if flag, goto2, else goto 1 col witness m32; - const expr addr_step = STEP * 3; - const expr sel_mem_b; sel_mem_b = b_src_mem + b_src_ind; @@ -136,17 +134,18 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope } // Mem.load - //mem_load(sel: a_src_mem, - // step: addr_step, - // addr: addr0, - // value: a); + mem_load(sel: a_src_mem, + step: STEP, + addr: addr0, + value: a); // Mem.load - //mem_load(sel: sel_mem_b, - // step: addr_step + 1, - // bytes: ind_width, - // addr: addr1, - // value: b); + mem_load(sel: sel_mem_b, + 
step: STEP, + step_offset: 1, + bytes: b_src_ind * (ind_width - 8) + 8, + addr: addr1, + value: b); const expr store_value[2]; @@ -154,11 +153,12 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope store_value[1] = (1 - store_ra) * c[1]; // Mem.store - //mem_store(sel: store_mem + store_ind, - // step: addr_step + 2, - // bytes: ind_width, - // addr: addr2, - // value: store_value); + mem_store(sel: store_mem + store_ind, + step: STEP, + step_offset: 2, + bytes: store_ind * (ind_width - 8) + 8, + addr: addr2, + value: store_value); // Operation.assume => how organize software col witness __debug_operation_bus_enabled; @@ -241,12 +241,8 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope // const expr bus_main_segment = main_segment - SEGMENT_LAST * (main_segment * main_last_segment - 1 + main_last_segment); - // permutation_proves(MAIN_CONTINUATION_ID, cols: [bus_main_segment, is_last_continuation, ...specific_registers, c[0] * (1 - main_last_segment), c[1] * (1 - main_last_segment)], - // sel: SEGMENT_LAST - SEGMENT_L1, name: PIOP_NAME_ISOLATED, bus_type: PIOP_BUS_SUM); - permutation_proves(MAIN_CONTINUATION_ID, cols: [bus_main_segment, is_last_continuation, pc, c[0], c[1], set_pc, jmp_offset1, flag * SEGMENT_LAST * (jmp_offset1 - jmp_offset2) + jmp_offset2], - sel: SEGMENT_LAST, name: PIOP_NAME_ISOLATED, bus_type: PIOP_BUS_SUM); - permutation_assumes(MAIN_CONTINUATION_ID, cols: [bus_main_segment, is_last_continuation, pc, c[0], c[1], set_pc, jmp_offset1, flag * SEGMENT_LAST * (jmp_offset1 - jmp_offset2) + jmp_offset2], - sel: SEGMENT_L1, name: PIOP_NAME_ISOLATED, bus_type: PIOP_BUS_SUM); + permutation (MAIN_CONTINUATION_ID, cols: [bus_main_segment, is_last_continuation, pc, c[0], c[1], set_pc, jmp_offset1, flag * SEGMENT_LAST * (jmp_offset1 - jmp_offset2) + jmp_offset2], + sel: SEGMENT_LAST - SEGMENT_L1, name: PIOP_NAME_ISOLATED, bus_type: PIOP_BUS_SUM); flag * (1 - flag) === 0; diff --git 
a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index a7bcef56..52045db4 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -1,5 +1,6 @@ use log::info; use p3_field::PrimeField; +use sm_mem::MemProxy; use crate::InstanceExtensionCtx; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; @@ -12,7 +13,6 @@ use proofman_common::{AirInstance, ProofCtx}; use proofman::WitnessComponent; use sm_arith::ArithSM; -use sm_mem::MemSM; use zisk_pil::{ ArithTrace, BinaryExtensionTrace, BinaryTrace, MainRow, MainTrace, ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, @@ -28,14 +28,14 @@ pub struct MainSM { /// Witness computation manager wcm: Arc>, + /// Memory state machine + mem_proxy_sm: Arc>, + /// Arithmetic state machine arith_sm: Arc>, /// Binary state machine binary_sm: Arc>, - - /// Memory state machine - mem_sm: Arc, } impl MainSM { @@ -54,16 +54,16 @@ impl MainSM { /// * Arc to the MainSM state machine pub fn new( wcm: Arc>, + mem_proxy_sm: Arc>, arith_sm: Arc>, binary_sm: Arc>, - mem_sm: Arc, ) -> Arc { - let main_sm = Arc::new(Self { wcm: wcm.clone(), arith_sm, binary_sm, mem_sm }); + let main_sm = Arc::new(Self { wcm: wcm.clone(), mem_proxy_sm, arith_sm, binary_sm }); wcm.register_component(main_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MAIN_AIR_IDS)); // For all the secondary state machines, register the main state machine as a predecessor - main_sm.mem_sm.register_predecessor(); + main_sm.mem_proxy_sm.register_predecessor(); main_sm.binary_sm.register_predecessor(); main_sm.arith_sm.register_predecessor(); @@ -153,6 +153,39 @@ impl MainSM { segment_trace.steps[slice_start..slice_end].iter().enumerate() { partial_trace[i] = emu.step_slice_full_trace(emu_trace_step); + // if partial_trace[i].a_src_mem == F::one() { + // println!( + // "A=MEM_OP_RD({}) [{},{}] PC:{}", + // partial_trace[i].a_offset_imm0, + // partial_trace[i].a[0], + // 
partial_trace[i].a[1], + // partial_trace[i].pc + // ); + // } + // if partial_trace[i].b_src_mem == F::one() || partial_trace[i].b_src_ind == + // F::one() { + // println!( + // "B=MEM_OP_RD({0}) [{1},{2}] PC:{3}", + // partial_trace[i].addr1, + // partial_trace[i].b[0], + // partial_trace[i].b[1], + // partial_trace[i].pc + // ); + // } + // if partial_trace[i].b_src_mem == F::one() || partial_trace[i].b_src_ind == + // F::one() { + // println!( + // "MEM_OP_WR({}) [{}, {}] PC:{}", + // partial_trace[i].store_offset + // + partial_trace[i].store_ind * partial_trace[i].a[0], + // partial_trace[i].store_ra + // * (partial_trace[i].pc + partial_trace[i].jmp_offset2 + // - partial_trace[i].c[0]) + // + partial_trace[i].c[0], + // (F::one() - partial_trace[i].store_ra) * partial_trace[i].c[1], + // partial_trace[i].pc + // ); + // } } // if there are steps in the chunk update last row if slice_end - slice_start > 0 { diff --git a/state-machines/mem/Cargo.toml b/state-machines/mem/Cargo.toml index 3f8ee914..7cdb344d 100644 --- a/state-machines/mem/Cargo.toml +++ b/state-machines/mem/Cargo.toml @@ -7,14 +7,21 @@ edition = "2021" sm-common = { path = "../common" } zisk-core = { path = "../../core" } zisk-pil = { path = "../../pil" } +num-traits = "0.2" -p3-field = { workspace=true } proofman-common = { workspace = true } proofman-macros = { workspace = true } +proofman-util = { workspace = true } proofman = { workspace = true } +pil-std-lib = { workspace = true } + +p3-field = { workspace=true } log = { workspace = true } rayon = { workspace = true } +num-bigint = { workspace = true } [features] default = [] -no_lib_link = ["proofman-common/no_lib_link", "proofman/no_lib_link"] \ No newline at end of file +no_lib_link = ["proofman-common/no_lib_link", "proofman/no_lib_link"] +debug_mem_proxy_engine = [] +debug_mem_align = [] \ No newline at end of file diff --git a/state-machines/mem/pil/mem.pil b/state-machines/mem/pil/mem.pil index 50bd652e..f574f264 100644 --- 
a/state-machines/mem/pil/mem.pil +++ b/state-machines/mem/pil/mem.pil @@ -1,3 +1,42 @@ +/* + Memory Component (Mem) + ====================== + + - Allows to define a memory on a region with size <= MEMORY_MAX_DIFF * mem_bytes (2^24). + - Inside this component the address are mem-byte address, to translate internal adress to + external need to multiply by mem_bytes. + - For executors optimization, external addresses use 32-bits. + - The memory regions must be exclusive, to avoid collisions between different memories. + - The constraints over instances guarantees that the memory access are inside definited region. + - The constraints guarantees that only one cyle for memory region is allowed. + - For non-aligned or for non mem-bytes access, the MemAlign machine was used. + + Parameters: + + - N = number of rows + - id = bus_id used of memory operations + - RC = number of value chunks (2 by default) + - mem_bytes = number of bytes of memory word (8 bytes by default) + - base_address = base byte address when start the memory + - mem_size = size of memory in bytes (0x800_0000 by default) + - immutable = if memory is immutable, first access is a write (by default is mutable) + - free_input_mem = if memory is a free input memory, memory without write, all access are reads + with same value, this value it's stablished by executor. + + Continuations: + + - The memory continuation is used to proves the last row significant values of the current segment, + and the next segment assume these significant values. + - The first assume of memory is generated by global constraint to guarantees only one cycle by + memory region. + - In the last segment, the proves are not generated to avoid generate more than one memory cycle. + - The constraints that refer to the values of the previous row, in the first row, take the value + from the airvalue previous_segment_xxx, which contains the value at the end of the previous segment. 
- These previous airvalues are validated through the bus, because they are assumed at the end of the previous + segment.
+ + const expr internal_base_address = base_address / mem_bytes; + const expr internal_end_address = (base_address + mem_size - 1) / mem_bytes; + airval segment_id; + airval is_first_segment; + airval is_last_segment; - col witness addr; // n-byte address, real address = addr * MEM_BYTES + is_first_segment * (1 - is_first_segment) === 0; + is_last_segment * (1 - is_last_segment) === 0; + is_first_segment * segment_id === 0; + + col witness addr; // n-byte address, real address = addr * mem_bytes col witness step; - col witness sel, wr; - col witness value[RC]; + col witness sel; col witness addr_changes; - const expr rd = (1 - wr); - sel * (1 - sel) === 0; - wr * (1 - wr) === 0; + if (!free_input_mem) { + col witness air.value[RC]; + } else { + immutable = 1; + col witness air.value_word[RC*2]; + const expr air.value[RC]; + for (int index = 0; index < RC; ++index) { + value[index] = value_word[index*2] + 2**16 * value_word[index*2 + 1]; + + // how value is a free-input, must be checked that it's 32-bit well formed value + range_check(value_word[index*2], 0, 2**16 - 1); + range_check(value_word[index*2+1], 0, 2**16 - 1); + } + } + if (!immutable) { + col witness air.wr; + const expr air.rd = 1 - wr; + wr * (1 - wr) === 0; + } else { + // a free input memory must be read-only, an immutable memory must be write + // on first row of new address (addr_changes = 1) + const expr air.wr = free_input_mem ? 0 : addr_changes; + } // if wr is 1, sel must be 1 (not allowed writes) wr * (1 - sel) === 0; - // all time first line is lost, used for continuations - sel * SEGMENT_L1 === 0; + sel * (1 - sel) === 0; addr_changes * (1 - addr_changes) === 0; + airval previous_segment_value[RC]; + airval previous_segment_step, previous_segment_addr; + + // continuation for next segment, these values used on direct update to air bus, and after + // with constraints force that these values are the same as last row of current segment. 
+ + airval segment_last_value[RC]; + airval segment_last_step, segment_last_addr; + + for (int i = 0; i < RC; i++) { + SEGMENT_LAST * (value[i] - segment_last_value[i]) === 0; + } + + SEGMENT_LAST * (addr - segment_last_addr) === 0; + SEGMENT_LAST * (step - segment_last_step) === 0; + + // add base_address to the columns to avoid collisions between different memories + // for security send is_last_segment to avoid reuse end of last segment as start of new cycle of segments + direct_update_assumes(MEMORY_CONT_ID, + [ + base_address, // identify area of memory + segment_id, // current segment_id + // proves of last segment + previous_segment_addr, + previous_segment_step, + ...previous_segment_value + ]); + + direct_update_proves(MEMORY_CONT_ID, + [ + base_address, // identify area of memory + segment_id + 1, // next segment_id, for last segment + // this value is forced to 0 to match global constraint + segment_last_addr, // last addr of segment + segment_last_step, // last step of segment + ...segment_last_value + ], + sel: (1 - is_last_segment)); + + const int zeros[air.RC]; + for (int i = 0; i < length(zeros); ++i) { + zeros[i] = 0; + } + direct_global_update_proves(MEMORY_CONT_ID, [ base_address, 0, internal_base_address, 0, ...zeros]); + + // for security check that first address has correct value, to avoid add huge quantity of instances to "overflow" prime field. + range_check(colu: previous_segment_addr - internal_base_address + 1, min: 1, max: MEMORY_MAX_DIFF); + + // control final of memory + range_check(colu: internal_end_address - segment_last_addr + 1, min: 1, max: MEMORY_MAX_DIFF); + + // check increment of memory - range_check(sel: (1 - SEGMENT_L1), colu: addr_changes * (addr - 'addr - step + 'step) + step - 'step, min: 1, max: MEMORY_MAX_DIFF); + if (immutable) { + // addresses are incremental, to save range check, increment, etc, address must be consecutive. 
+ const expr air.previous_addr = SEGMENT_L1 * (previous_segment_addr - is_first_segment - 'addr) + 'addr; + const expr delta_addr = addr - previous_addr; + addr_changes * (delta_addr - 1) === 0; + (1 - addr_changes) * (addr - previous_addr) === 0; + } else { + const expr air.previous_addr = SEGMENT_L1 * (previous_segment_addr - 'addr) + 'addr; + const expr delta_addr = addr - previous_addr; - // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd = 1, wr = 0 - // setting mem_last_segment = 1 + // on first row of first segment could be the same and address_change = 1 because it's as a new addr + // SEGMENT_L1 * (x + is_first_segment * SEGMENT_L1) === SEGMENT_L1 * (x + is_first_segment) - // if addr_changes == 0 means that addr and previous address are the same - (1 - addr_changes) * ('addr - addr) === 0; + const expr previous_step = SEGMENT_L1 * (previous_segment_step - 'step) + 'step; + const expr delta_step = step - previous_step; - col witness same_value; - (1 - same_value) * (1 - wr) * (1 - addr_changes) === 0; + col witness increment; + increment === addr_changes * (delta_addr - delta_step) + delta_step; - col witness first_addr_access_is_read; - (1 - first_addr_access_is_read) * rd * (1 - addr_changes) === 0; + is_first_segment * SEGMENT_L1 * (1 - addr_changes) === 0; - for (int index = 0; index < length(value); index = index + 1) { - same_value * (value[index] - 'value[index]) === 0; - first_addr_access_is_read * value[index] === 0; + range_check(colu: increment, min: 1, max: MEMORY_MAX_DIFF); } - // CONTINUATIONS - // - // segments: S, S+1 - // - // CASE: last row of segment is read - // - // S[n-1] wr = 0, sel = 1, addr, step, value => BUS.proves(MEM_CONT_ID, S+1, addr, step-1, value) - // S+1[0] wr = 0, sel = 0, addr, step, value => BUS.assumes(MEM_CONT_ID, S, addr, step, value) - // - // CASE: last row of segment is write - // - // S[n-1] wr = 1, sel = 1, addr, step, value => BUS.proves(MEM_CONT_ID, S+1, addr, step-1, 
value) - // S+1[0] wr = 0, sel = 0, addr, step, value => BUS.assumes(MEM_CONT_ID, S, addr, step, value) - // - // NOTES: from row = 1 all constraints could be reference previous row, without problems - // on row = 0 forced by constraint that sel = 0 => wr = 0. - // on S+1[0].step = S[n-1].step - 1; - // - // FIRST SEGMENT: - // the BUS.proves needed by BUS.assumes of the first segment it's generated by global constraint to avoid - // generate more than one cycle of memory. In this constraint we could force the initial address (to split - // in two memories, one register-memory and other standard-memory). - // - // LAST SEGMENT: - // the last not used rows are filled with last addr and value and sel = 0 and wr = 0 incrementing steps. - // last BUS.proves not it's generated to avoid generate more than one memory cycle. - - // permutation_proves(MEMORY_CONT_ID, [(mem_segment + 1), addr, step, ...value], sel: mem_last_segment * 'SEGMENT_L1); // last row - // permutation_assumes(MEMORY_CONT_ID, [mem_segment, 0, addr, step, ...value], sel: SEGMENT_L1); // first row - - permutation_proves(MEMORY_ID, cols: [wr, addr * MEM_BYTES, step, MEM_BYTES, ...value], sel: sel); -} + (1 - addr_changes) * (addr - previous_addr) === 0; -// TODO: detect non default value but not called, mandatory parameter. 
-function mem_load(int id = MEMORY_ID, expr sel = 1, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) { - if (step_offset > MAX_MEM_STEP_OFFSET) { - error("max step_offset ${step_offset} is greater than max value ${MAX_MEM_STEP_OFFSET}"); + // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd = 1, wr = 0 + // setting is_last_segment = 1 + + // if addr_changes == 0 means that addr and previous address are the same + // TODO: + + for (int index = 0; index < length(value); index++) { + const expr previous_value = SEGMENT_L1 * (previous_segment_value[index] - 'value[index]) + 'value[index]; + if (immutable) { + // if address not change value must be equal to previous value + (1 - addr_changes) * (value[index] - previous_value) === 0; + + if (!free_input_mem) { + // if address changes => write, and it must be inserted on bus + addr_changes * (1 - sel) === 0; + } + } else { + // if address not change and it isn't write, value must be equal to previous value + // TODO: boundary constraints + (1 - addr_changes) * (1 - wr) * (value[index] - previous_value) === 0; + + // if address changes, and it isn't a write, value must be 0. + addr_changes * (1 - wr) * value[index] === 0; + } } - // adding one for first continuation - permutation_assumes(id, [MEMORY_LOAD_OP, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step) + step_offset, bytes, ...value], sel:sel); + + // The Memory component is only able to prove aligned memory access, since we force the bus address to be a multiple of mem_bytes + // and the width to be exactly mem_bytes + // Notice, however, that the main can also use widths of 4, 2, 1 and addresses that are not multiples of mem_bytes. 
+ // These are handled with the Memory Align component + + const expr mem_op = wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP; + permutation_proves(MEMORY_ID, cols: [mem_op, addr * mem_bytes, step, mem_bytes, ...value], sel: sel); +} + +function mem_load(int id = MEMORY_ID, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[], expr sel = 1) { + mem_assumes(id, MEMORY_LOAD_OP, addr, step, step_offset, bytes, value, sel); } -function mem_store(int id = MEMORY_ID, expr sel = 1, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) { +function mem_store(int id = MEMORY_ID, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[], expr sel = 1) { + mem_assumes(id, MEMORY_STORE_OP, addr, step, step_offset, bytes, value, sel); +} + +private function mem_assumes(int id, int mem_op, expr addr, expr step, expr step_offset, expr bytes, expr value[], expr sel) { if (step_offset > MAX_MEM_STEP_OFFSET) { - error("max step_offset ${step_offset} is greater than max value ${MAX_MEM_STEP_OFFSET}"); + error("step_offset ${step_offset} is greater than max value allowed ${MAX_MEM_STEP_OFFSET}"); } - // adding one for first continuation - permutation_assumes(id, [MEMORY_STORE_OP, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step), bytes, ...value], sel:sel); -} \ No newline at end of file + + // adding 1 at step for first continuation + permutation_assumes(id, [mem_op, addr, 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset, bytes, ...value], sel: sel); +} diff --git a/state-machines/mem/pil/mem_align.pil b/state-machines/mem/pil/mem_align.pil index e69de29b..8a23ab2a 100644 --- a/state-machines/mem/pil/mem_align.pil +++ b/state-machines/mem/pil/mem_align.pil @@ -0,0 +1,188 @@ +require "std_permutation.pil" +require "std_lookup.pil" +require "std_range_check.pil" + +// Problem to solve: +// ================= +// We are given an op (rd,wr), an addr, a step and a bytes-width (8,4,2,1) and we should prove that the memory 
access is correct. +// Note: Either the original addr is not a multiple of 8 or width < 8 to ensure it is a non-aligned access that should be +// handled by this component. + +/* + We will model it as a very specified processor with 8 registers and a very limited instruction set. + + This processor is limited to 4 possible subprograms: + + 1] Read operation that spans one memory word w = [w_0, w_1]: + w_0 w_1 + +---+===+===+===+ +===+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+===+===+===+ +===+---+---+---+ + |<------ v ------>| + + [R] In the first clock cycle, we perform an aligned read to w + [V] In the second clock cycle, we return the demanded value v from w + + 2] Write operation that spans one memory word w = [w_0, w_1]: + w_0 w_1 + +---+---+---+---+ +---+===+===+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+---+---+---+ +---+===+===+---+ + |<- v ->| + + [R] In the first clock cycle, we perform an aligned read to w + [W] In the second clock cycle, we compute an aligned write of v to w + [V] In the third clock cycle, we restore the demanded value from w + + 3] Read operation that spans two memory words w1 = [w1_0, w1_1] and w2 = [w2_0, w2_1]: + w1_0 w1_1 w2_0 w2_1 + +---+---+---+---+ +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+---+---+---+ +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ + |<---------------- v ---------------->| + + [R] In the first clock cycle, we perform an aligned read to w1 + [V] In the second clock cycle, we return the demanded value v from w1 and w2 + [R] In the third clock cycle, we perform an aligned read to w2 + + 4] Write operation that spans two memory words w1 = [w1_0, w1_1] and w2 = [w2_0, w2_1]: + w1_0 w1_1 w2_0 w2_1 + +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ +---+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ 
+---+---+---+---+ + |<---------------- v ---------------->| + + [R] In the first clock cycle, we perform an aligned read to w1 + [W] In the second clock cycle, we compute an aligned write of v to w1 + [V] In the third clock cycle, we restore the demanded value from w1 and w2 + [R] In the fourth clock cycle, we perform an aligned read to w2 + [W] In the fiveth clock cycle, we compute an aligned write of v to w2 + + Example: + ========================================================== + (offset = 6, width = 4) + +----+----+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | [R1] (assume, up_to_down) sel = [1,1,1,1,1,1,0,0] + +----+----+----+----+----+----+----+----+ + โ‡“ + +----+----+----+----+----+----+====+====+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | [W1] (assume, up_to_down) sel = [0,0,0,0,0,0,1,1] + +----+----+----+----+----+----+====+====+ + โ‡“ + +====+====+----+----+----+----+====+====+ + | V6 | V7 | V0 | V1 | V2 | V3 | V4 | V5 | [V] (prove) (shift (offset + width) % 8) sel = [0,0,0,0,0,0,1,0] (*) + +====+====+----+----+----+----+====+====+ + โ‡“ + +====+====+----+----+----+----+----+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | [W2] (assume, down_to_up) sel = [1,1,0,0,0,0,0,0] + +====+====+----+----+----+----+----+----+ + โ‡“ + +----+----+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | [R2] (assume, down_to_up) sel = [0,0,1,1,1,1,1,1] + +----+----+----+----+----+----+----+----+ + + (*) In this step, we use the selectors to indicate the "scanning" needed to form the bus value: + v_0 = sel[0] * [V1,V0,V7,V6] + sel[1] * [V0,V7,V6,V5] + sel[2] * [V7,V6,V5,V4] + sel[3] * [V6,V5,V4,V3] + v_1 = sel[4] * [V5,V4,V3,V2] + sel[5] * [V4,V3,V2,V1] + sel[6] * [V3,V2,V1,V0] + sel[7] * [V2,V1,V0,V7] + Notice that it is enough with 8 combinations. 
+*/ + +airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM = 8, const int CHUNK_BITS = 8) { + const int CHUNKS_BY_RC = CHUNK_NUM / RC; + + col witness addr; // CHUNK_NUM-byte address, real address = addr * CHUNK_NUM + col witness offset; // 0..7, position at which the operation starts + col witness width; // 1,2,4,8, width of the operation + col witness wr; // 1 if the operation is a write, 0 otherwise + col witness pc; // line of the program to execute + col witness reset; // 1 at the beginning of the operation (indicating an address reset), 0 otherwise + col witness sel_up_to_down; // 1 if the next value is the current value (e.g. R -> W) + col witness sel_down_to_up; // 1 if the next value is the previous value (e.g. W -> R) + col witness reg[CHUNK_NUM]; // Register values, 1 byte each + col witness sel[CHUNK_NUM]; // Selectors, 1 if the value is used, 0 otherwise + col witness step; // Memory step + + // 1] Ensure the MemAlign follows the program + + // Registers should be bytes and be shuch that: + // - reg' == reg in transitions R -> V, R -> W, W -> V, + // - 'reg == reg in transitions V <- W, W <- R, + // in any case, sel_up_to_down,sel_down_to_up are 0 in [V] steps. 
+ for (int i = 0; i < CHUNK_NUM; i++) { + range_check(reg[i], 0, 2**CHUNK_BITS-1); + + (reg[i]' - reg[i]) * sel[i] * sel_up_to_down === 0; + ('reg[i] - reg[i]) * sel[i] * sel_down_to_up === 0; + } + + col fixed L1 = [1,0...]; + L1 * pc === 0; // The program should start at the first line + + // We compress selectors, so we should ensure they are binary + for (int i = 0; i < CHUNK_NUM; i++) { + sel[i] * (1 - sel[i]) === 0; + } + wr * (1 - wr) === 0; + reset * (1 - reset) === 0; + sel_up_to_down * (1 - sel_up_to_down) === 0; + sel_down_to_up * (1 - sel_down_to_up) === 0; + + expr flags = 0; + for (int i = 0; i < CHUNK_NUM; i++) { + flags += sel[i] * 2**i; + } + flags += wr * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); + + // Perform the lookup against the program + expr delta_pc; + col witness delta_addr; // Auxiliary column + delta_pc = pc' - pc; + delta_addr === (addr - 'addr) * (1 - reset); + lookup_assumes(MEM_ALIGN_ROM_ID, [pc, delta_pc, delta_addr, offset, width, flags]); + + // 2] Assume aligned memory accesses against the Memory component + const expr sel_assume = sel_up_to_down + sel_down_to_up; + + // Offset should be 0 in aligned memory accesses, but this is ensured by the rom + // Width should be 8 in aligned memory accesses, but this is ensured by the rom + + // On assume steps, we reconstruct the value from the registers directly + expr assume_val[RC]; + for (int rc_index = 0; rc_index < RC; rc_index++) { + assume_val[rc_index] = 0; + int base = 1; + for (int _offset = 0; _offset < CHUNKS_BY_RC; _offset++) { + assume_val[rc_index] += reg[_offset + rc_index * CHUNKS_BY_RC] * base; + base *= 256; + } + } + + // 3] Prove unaligned memory accesses against the Main component + col witness sel_prove; + + sel_prove * sel_assume === 0; // Disjoint selectors + + // On prove steps, we reconstruct the value in the correct manner chosen by the selectors + expr prove_val[RC]; + for (int 
rc_index = 0; rc_index < RC; rc_index++) { + prove_val[rc_index] = 0; + } + for (int _offset = 0; _offset < CHUNK_NUM; _offset++) { + for (int rc_index = 0; rc_index < RC; rc_index++) { + expr _tmp = 0; + int base = 1; + for (int ichunk = 0; ichunk < CHUNKS_BY_RC; ichunk++) { + _tmp += reg[(_offset + rc_index * CHUNKS_BY_RC + ichunk) % CHUNK_NUM] * base; + base *= 256; + } + prove_val[rc_index] += sel[_offset] * _tmp; + } + } + + // We prove and assume with the same permutation check but with disjoint and different sign selectors + col witness value[RC]; // Auxiliary columns + for (int i = 0; i < RC; i++) { + value[i] === sel_prove * prove_val[i] + sel_assume * assume_val[i]; + } + permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...value], sel: sel_prove - sel_assume); +} \ No newline at end of file diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil new file mode 100644 index 00000000..3d7735bf --- /dev/null +++ b/state-machines/mem/pil/mem_align_rom.pil @@ -0,0 +1,323 @@ +require "std_lookup.pil" + +const int MEM_ALIGN_ROM_ID = 133; +const int MEM_ALIGN_ROM_SIZE = P2_8; + +airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = 8, const int DEFAULT_OFFSET = 0, const int DEFAULT_WIDTH = 8, const int disable_fixed = 0) { + if (N < MEM_ALIGN_ROM_SIZE) { + error(`N must be at least ${MEM_ALIGN_ROM_SIZE}, but N=${N} was provided`); + } + + col witness multiplicity; + + if (disable_fixed) { + col fixed _K = [0...]; + multiplicity * _K === 0; + + println("*** DISABLE_FIXED ***"); + return; + } + + // Define the size of each sub-program: RV, RWV, RVR, RWVWR + const int spsize[4] = [2, 3, 3, 5]; + + // Not all combinations of offset and width are valid for each program: + const int one_word_combinations = 20; // (0..4,[1,2,4]), (5,6,[1,2]), (7,[1]) -> 5*3 + 2*2 + 1*1 = 20 + const int two_word_combinations = 11; // (1..4,[8]), 
(5,6,[4,8]), (7,[2,4,8]) -> 4*1 + 2*2 + 1*3 = 11 + + // table_size = combinations * program_size + const int tsize[4] = [one_word_combinations*spsize[0], one_word_combinations*spsize[1], two_word_combinations*spsize[2], two_word_combinations*spsize[3]]; + const int psize = tsize[0] + tsize[1] + tsize[2] + tsize[3]; + + // Offset is set to DEFAULT_OFFSET and width to DEFAULT_WIDTH in aligned memory accesses. + // Offset and width are set to 0 in padding lines. + // size + col fixed OFFSET = [0, // Padding 1 = 1 | 1 + [[0,0]:3, [0,1]:3, [0,2]:3, [0,3]:3, [0,4]:3, [0,5]:2, [0,6]:2, [0,7]], // RV 6+6*4+4+4+2 = 40 | 41 + [[0,0,0]:3, [0,0,1]:3, [0,0,2]:3, [0,0,3]:3, [0,0,4]:3, [0,0,5]:2, [0,0,6]:2, [0,0,7]], // RWV 9+9*4+6+6+3 = 60 | 101 + [[0,1,0], [0,2,0], [0,3,0], [0,4,0], [0,5,0]:2, [0,6,0]:2, [0,7,0]:3], // RVR 3*4+6+6+9 = 33 | 134 + [[0,0,1,0,0], [0,0,2,0,0], [0,0,3,0,0], [0,0,4,0,0], [0,0,5,0,0]:2, [0,0,6,0,0]:2, [0,0,7,0,0]:3], // RWVWR 5*4+10+10+15 = 55 | 189 => N = 2^8 + 0...]; // Padding + + col fixed WIDTH = [0, // Padding + [[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV + [[8,8,1,8,8,2,8,8,4]:5, [8,8,1,8,8,2]:2, [8,8,1]], // RWV + [[8,8,8]:4, [8,4,8,8,8,8]:2, [8,2,8,8,4,8,8,8,8]], // RVR + [[8,8,8,8,8]:4, [8,8,4,8,8,8,8,8,8,8]:2, [8,8,2,8,8,8,8,4,8,8,8,8,8,8,8]], // RWVWR + 0...]; // Padding + + // line | pc | pc'-pc | reset | addr | (addr-'addr)*(1-reset) | + // 0 | 0 | 0 | 1 | 0 | 0 | // for padding + // 1 | 0 | 1 | 1 | X1 | 0 | // (RV) + // 2 | 1 | -1 | 0 | X1 | 0 | + // 3 | 0 | 3 | 1 | X2 | 0 | // (RV) + // 4 | 3 | -3 | 0 | X2 | 0 | + // 5 | 0 | 5 | 1 | X3 | 0 | // (RV) + // 6 | 5 | -5 | 0 | X3 | 0 | + // 7 | 0 | 7 | 1 | โ‹ฎ | โ‹ฎ | // (RV) + // โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | + // 41 | 0 | 41 | 1 | X4 | 0 | // (RWV) + // 42 | 41 | 1 | 0 | X4 | 0 | + // 43 | 42 | -42 | 0 | X4 | 0 | + // 44 | 0 | 44 | 1 | X5 | 0 | // (RWV) + // 45 | 44 | 1 | 0 | X5 | 0 | + // 46 | 45 | -45 | 0 | X5 | 0 | + // 47 | 0 | 47 | 1 | X6 | 0 | // (RWV) + // โ‹ฎ | โ‹ฎ | โ‹ฎ | 
โ‹ฎ | โ‹ฎ | โ‹ฎ | + // 101 | 0 | 101 | 1 | X7 | 0 | // (RVR) + // 102 |101 | 1 | 0 | X7 | 0 | + // 103 |102 | -102 | 0 | X7+1 | 1 | + // 104 | 0 | 104 | 1 | X8 | 0 | // (RVR) + // 105 |104 | 1 | 0 | X8 | 0 | + // 106 |105 | -105 | 0 | X8+1 | 1 | + // 107 | 0 | 107 | 1 | X9 | 0 | // (RVR) + // โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | + // 134 | 0 | 134 | 1 | X10 | 0 | // (RWVWR) + // 135 |134 | 1 | 0 | X10 | 0 | + // 136 |135 | 1 | 0 | X10 | 0 | + // 137 |136 | 1 | 0 | X10+1 | 1 | + // 138 |137 | -137 | 0 | X10+1 | 0 | + // 139 | 0 | 139 | 1 | X11 | 0 | // (RWVWR) + // 140 |139 | 1 | 0 | X11 | 0 | + // 141 |140 | 1 | 0 | X11 | 0 | + // 142 |141 | 1 | 0 | X11+1 | 1 | + // 143 |142 | -142 | 0 | X11+1 | 0 | + // 144 | 0 | 144 | 1 | X12 | 0 | // (RWVWR) + // โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | + // 188 |187 | -187 | 0 | X13+1 | 0 | + // 189 | 0 | 0 | 1 | 0 | 0 | // for padding + // โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | โ‹ฎ | + + // Note: The overall program contains "holes", meaning that pc can vary + // from program to program by any constant, as long as it is unique for each program. + // For example, the first program has pc=0,1, while the second has pc=0,3. + + col fixed PC; + col fixed DELTA_PC; + col fixed DELTA_ADDR; + col fixed FLAGS; + for (int i = 0; i < N; i++) { + int pc = 0; + int delta_pc = 0; + int delta_addr = 0; + int is_write = 0; + int reset = 0; + int sel[CHUNK_NUM]; + for (int j = 0; j < CHUNK_NUM; j++) { + sel[j] = 0; + } + int sel_up_to_down = 0; + int sel_down_to_up = 0; + + const int prev_line = i == 0 ? 
0 : i-1; + const int line = i; + if (line == 0 || line > psize) + { + // pc = 0; + // delta_pc = 0; + // delta_addr = 0; + // is_write = 0; + reset = 1; + // sel = [0:CHUNK_NUM] + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + else if (line < 1+tsize[0]) // RV + { + if (line % 2 == 1) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1] && j < OFFSET[i+1] + WIDTH[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else { + pc = prev_line; + delta_pc = -pc; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + } + else if (line < 1+tsize[0]+tsize[1]) // RWV + { + if (line % 3 == 2) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < OFFSET[i+2] || j >= OFFSET[i+2] + WIDTH[i+2]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 3 == 0) { + pc = prev_line; + delta_pc = 1; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1] && j < OFFSET[i+1] + WIDTH[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else { + pc = prev_line; + delta_pc = -pc; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + } + else if (line < 1+tsize[0]+tsize[1]+tsize[2]) // RVR + { + if (line % 3 == 2) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 3 == 0) { + pc = prev_line; + 
delta_pc = 1; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } else { + pc = prev_line; + delta_pc = -pc; + delta_addr = 1; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < (OFFSET[i-1] + WIDTH[i-1]) % CHUNK_NUM) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + sel_down_to_up = 1; + } + } + else if (line < 1+tsize[0]+tsize[1]+tsize[2]+tsize[3]) // RWVWR + { + if (line % 5 == 4) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < OFFSET[i+2]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 5 == 0) { + pc = prev_line; + delta_pc = 1; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 5 == 1) { + pc = prev_line; + delta_pc = 1; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } else if (line % 5 == 2) { + pc = prev_line; + delta_pc = 1; + delta_addr = 1; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < (OFFSET[i-1] + WIDTH[i-1]) % CHUNK_NUM) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + sel_down_to_up = 1; + } else { + pc = prev_line; + delta_pc = -pc; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= (OFFSET[i-2] + WIDTH[i-2]) % CHUNK_NUM) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + sel_down_to_up = 1; + } + } + PC[i] = pc; + DELTA_PC[i] = delta_pc; + DELTA_ADDR[i] = delta_addr; + int flags = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + flags += sel[j] * 2**j; + } + 
flags += is_write * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); + FLAGS[i] = flags; + } + + // Ensure the program is being followed by the MemAlign + lookup_proves(MEM_ALIGN_ROM_ID, [PC, DELTA_PC, DELTA_ADDR, OFFSET, WIDTH, FLAGS], multiplicity); +} \ No newline at end of file diff --git a/state-machines/mem/src/input_data_sm.rs b/state-machines/mem/src/input_data_sm.rs new file mode 100644 index 00000000..220fc33f --- /dev/null +++ b/state-machines/mem/src/input_data_sm.rs @@ -0,0 +1,377 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use crate::{ + MemAirValues, MemInput, MemModule, MemPreviousSegment, MEMORY_MAX_DIFF, MEM_BYTES_BITS, +}; +use num_bigint::BigInt; +use p3_field::PrimeField; +use pil_std_lib::Std; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; +use zisk_core::{INPUT_ADDR, MAX_INPUT_SIZE}; +use zisk_pil::{InputDataTrace, INPUT_DATA_AIR_IDS, ZISK_AIRGROUP_ID}; + +const INPUT_W_ADDR_INIT: u32 = INPUT_ADDR as u32 >> MEM_BYTES_BITS; +const INPUT_W_ADDR_END: u32 = (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32 >> MEM_BYTES_BITS; + +#[allow(clippy::assertions_on_constants)] +const _: () = { + assert!( + (MAX_INPUT_SIZE - 1) >> MEM_BYTES_BITS as u64 <= MEMORY_MAX_DIFF, + "INPUT_DATA is too large" + ); + assert!( + INPUT_ADDR + MAX_INPUT_SIZE - 1 <= 0xFFFF_FFFF, + "INPUT_DATA memory exceeds the 32-bit addressable range" + ); +}; + +pub struct InputDataSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, + + num_rows: usize, + // Count of registered predecessors + registered_predecessors: AtomicU32, +} + +#[allow(unused, unused_variables)] +impl InputDataSM { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, INPUT_DATA_AIR_IDS[0]); + let input_data_sm = Self { + wcm: wcm.clone(), + std: std.clone(), + num_rows: air.num_rows(), + 
registered_predecessors: AtomicU32::new(0), + }; + let input_data_sm = Arc::new(input_data_sm); + + wcm.register_component( + input_data_sm.clone(), + Some(ZISK_AIRGROUP_ID), + Some(INPUT_DATA_AIR_IDS), + ); + std.register_predecessor(); + + input_data_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + self.std.unregister_predecessor(pctx, None); + } + } + + pub fn prove(&self, inputs: &[MemInput]) { + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); + + // PRE: proxy calculate if exists jmp on step out-of-range, adding internal inputs + // memory only need to process these special inputs, but inputs no change. At end of + // inputs proxy add an extra internal input to jump to last address + + let air_id = INPUT_DATA_AIR_IDS[0]; + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, air_id); + let air_rows = air.num_rows(); + + // at least one row to go + let count = inputs.len(); + let count_rem = count % air_rows; + let num_segments = (count / air_rows) + if count_rem > 0 { 1 } else { 0 }; + + let mut prover_buffers = Mutex::new(vec![Vec::new(); num_segments]); + let mut global_idxs = vec![0; num_segments]; + + #[allow(clippy::needless_range_loop)] + for i in 0..num_segments { + // TODO: Review + if let (true, global_idx) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, air_id, 1) + { + let trace: InputDataTrace<'_, _> = InputDataTrace::new(air_rows); + let mut buffer = trace.buffer.unwrap(); + prover_buffers.lock().unwrap()[i] = buffer; + global_idxs[i] = global_idx; + } + } + + #[allow(clippy::needless_range_loop)] + for segment_id in 0..num_segments { + let is_last_segment = segment_id == num_segments - 1; + let input_offset = segment_id * air_rows; + let previous_segment = if 
(segment_id == 0) { + MemPreviousSegment { addr: INPUT_W_ADDR_INIT, step: 0, value: 0 } + } else { + MemPreviousSegment { + addr: inputs[input_offset - 1].addr, + step: inputs[input_offset - 1].step, + value: inputs[input_offset - 1].value, + } + }; + let input_end = + if (input_offset + air_rows) > count { count } else { input_offset + air_rows }; + let mem_ops = &inputs[input_offset..input_end]; + let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + + self.prove_instance( + mem_ops, + segment_id, + is_last_segment, + &previous_segment, + prover_buffer, + air_rows, + global_idxs[segment_id], + ); + } + } + + /// Finalizes the witness accumulation process and triggers the proof generation. + /// + /// This method is invoked by the executor when no further witness data remains to be added. + /// + /// # Parameters + /// + /// - `mem_inputs`: A slice of all `ZiskRequiredMemory` inputs + #[allow(clippy::too_many_arguments)] + pub fn prove_instance( + &self, + mem_ops: &[MemInput], + segment_id: usize, + is_last_segment: bool, + previous_segment: &MemPreviousSegment, + mut prover_buffer: Vec, + air_mem_rows: usize, + global_idx: usize, + ) -> Result<(), Box> { + assert!( + !mem_ops.is_empty() && mem_ops.len() <= air_mem_rows, + "InputDataSM: mem_ops.len()={} out of range {}", + mem_ops.len(), + air_mem_rows + ); + + // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR + // segments In a Memory AIR instance, the first row is reserved as a dummy row. + // This dummy row is used to facilitate the continuation state between different AIR + // segments. It ensures seamless transitions when multiple AIR segments are + // processed consecutively. 
This design avoids discontinuities in memory access + // patterns and ensures that the memory trace is continuous, For this reason we use + // AIR num_rows - 1 as the number of rows in each memory AIR instance + + // Create a vector of Mem0Row instances, one for each memory operation + // Recall that first row is a dummy row used for the continuations between AIR segments + // The length of the vector is the number of input memory operations plus one because + // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows + + //println! {"InputDataSM::prove_instance() mem_ops.len={} prover_buffer.len={} + // air.num_rows={}", mem_ops.len(), prover_buffer.len(), air.num_rows()}; + let mut trace = + InputDataTrace::::map_buffer(&mut prover_buffer, air_mem_rows, 0).unwrap(); + + let mut range_check_data: Vec = vec![0; 1 << 16]; + + let mut air_values = MemAirValues { + segment_id: segment_id as u32, + is_first_segment: segment_id == 0, + is_last_segment, + previous_segment_addr: previous_segment.addr, + previous_segment_step: previous_segment.step, + previous_segment_value: [ + previous_segment.value as u32, + (previous_segment.value >> 32) as u32, + ], + ..MemAirValues::default() + }; + + // range of instance + let range_id = self.std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); + self.std.range_check( + F::from_canonical_u32(previous_segment.addr - INPUT_W_ADDR_INIT + 1), + F::one(), + range_id, + ); + + // Fill the remaining rows + let mut last_addr: u32 = previous_segment.addr; + let mut last_step: u64 = previous_segment.step; + let mut last_value: u64 = previous_segment.value; + + for (i, mem_op) in mem_ops.iter().enumerate() { + trace[i].addr = F::from_canonical_u32(mem_op.addr); + trace[i].step = F::from_canonical_u64(mem_op.step); + trace[i].sel = F::from_bool(!mem_op.is_internal); + + let value = mem_op.value; + let value_words = self.get_u16_values(value); + for j in 0..4 { + range_check_data[value_words[j] as 
usize] += 1; + trace[i].value_word[j] = F::from_canonical_u16(value_words[j]); + } + + let addr_changes = last_addr != mem_op.addr; + trace[i].addr_changes = + if addr_changes || (i == 0 && segment_id == 0) { F::one() } else { F::zero() }; + + last_addr = mem_op.addr; + last_step = mem_op.step; + last_value = mem_op.value; + } + + // STEP3. Add dummy rows to the output vector to fill the remaining rows + //PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0 + let last_row_idx = mem_ops.len() - 1; + let addr = trace[last_row_idx].addr; + let value = trace[last_row_idx].value_word; + + let padding_size = air_mem_rows - mem_ops.len(); + for i in mem_ops.len()..air_mem_rows { + last_step += 1; + + // TODO CHECK + // trace[i].mem_segment = segment_id_field; + // trace[i].mem_last_segment = is_last_segment_field; + + trace[i].addr = addr; + trace[i].step = F::from_canonical_u64(last_step); + trace[i].sel = F::zero(); + + trace[i].value_word = value; + + trace[i].addr_changes = F::zero(); + } + + air_values.segment_last_addr = last_addr; + air_values.segment_last_step = last_step; + air_values.segment_last_value[0] = last_value as u32; + air_values.segment_last_value[1] = (last_value >> 32) as u32; + + self.std.range_check( + F::from_canonical_u32(INPUT_W_ADDR_END - last_addr + 1), + F::one(), + range_id, + ); + + // range of chunks + let range_id = self.std.get_range(BigInt::from(0), BigInt::from((1 << 16) - 1), None); + for (value, &multiplicity) in range_check_data.iter().enumerate() { + if (multiplicity == 0) { + continue; + } + + self.std.range_check( + F::from_canonical_usize(value), + F::from_canonical_u64(multiplicity), + range_id, + ); + } + for value_chunk in &value { + self.std.range_check(*value_chunk, F::from_canonical_usize(padding_size), range_id); + } + + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let sctx = wcm.get_sctx(); + + let mut air_instance = AirInstance::new( + self.wcm.get_sctx(), + 
ZISK_AIRGROUP_ID, + INPUT_DATA_AIR_IDS[0], + Some(segment_id), + prover_buffer, + ); + + self.set_airvalues("InputData", &mut air_instance, &air_values); + + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); + + Ok(()) + } + + fn get_u16_values(&self, value: u64) -> [u16; 4] { + [value as u16, (value >> 16) as u16, (value >> 32) as u16, (value >> 48) as u16] + } + fn set_airvalues( + &self, + prefix: &str, + air_instance: &mut AirInstance, + air_values: &MemAirValues, + ) { + air_instance.set_airvalue( + format!("{}.segment_id", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_id), + ); + air_instance.set_airvalue( + format!("{}.is_first_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_first_segment), + ); + air_instance.set_airvalue( + format!("{}.is_last_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_last_segment), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.previous_segment_addr), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.previous_segment_step), + ); + air_instance.set_airvalue( + format!("{}.segment_last_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_last_addr), + ); + air_instance.set_airvalue( + format!("{}.segment_last_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.segment_last_step), + ); + let count = air_values.previous_segment_value.len(); + for i in 0..count { + air_instance.set_airvalue( + format!("{}.previous_segment_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.previous_segment_value[i]), + ); + air_instance.set_airvalue( + format!("{}.segment_last_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.segment_last_value[i]), + ); + } + } +} + +impl MemModule for InputDataSM { + fn 
send_inputs(&self, mem_op: &[MemInput]) { + self.prove(mem_op); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + vec![(INPUT_ADDR as u32, (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32)] + } + fn get_flush_input_size(&self) -> u32 { + self.num_rows as u32 + } +} + +impl WitnessComponent for InputDataSM {} diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index 67bf225c..3c42869b 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,9 +1,23 @@ -mod mem; -mod mem_aligned; -mod mem_traces; -mod mem_unaligned; +mod input_data_sm; +mod mem_align_rom_sm; +mod mem_align_sm; +mod mem_constants; +mod mem_helpers; +mod mem_module; +mod mem_proxy; +mod mem_proxy_engine; +mod mem_sm; +mod mem_unmapped; +mod rom_data; -pub use mem::*; -pub use mem_aligned::*; -pub use mem_traces::*; -pub use mem_unaligned::*; +pub use input_data_sm::*; +pub use mem_align_rom_sm::*; +pub use mem_align_sm::*; +pub use mem_constants::*; +pub use mem_helpers::*; +pub use mem_module::*; +pub use mem_proxy::*; +pub use mem_proxy_engine::*; +pub use mem_sm::*; +pub use mem_unmapped::*; +pub use rom_data::*; diff --git a/state-machines/mem/src/mem.rs b/state-machines/mem/src/mem.rs deleted file mode 100644 index 065b1841..00000000 --- a/state-machines/mem/src/mem.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use crate::{MemAlignedSM, MemUnalignedSM}; -use p3_field::Field; -use rayon::Scope; -use sm_common::{MemOp, MemUnalignedOp, OpResult, Provable}; -use zisk_core::ZiskRequiredMemory; - -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; - -#[allow(dead_code)] -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -#[allow(dead_code)] -pub struct MemSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs_aligned: Mutex>, - inputs_unaligned: Mutex>, - - // Secondary State machines - 
mem_aligned_sm: Arc, - mem_unaligned_sm: Arc, -} - -impl MemSM { - pub fn new(wcm: Arc>) -> Arc { - let mem_aligned_sm = MemAlignedSM::new(wcm.clone()); - let mem_unaligned_sm = MemUnalignedSM::new(wcm.clone()); - - let mem_sm = Self { - registered_predecessors: AtomicU32::new(0), - inputs_aligned: Mutex::new(Vec::new()), - inputs_unaligned: Mutex::new(Vec::new()), - mem_aligned_sm: mem_aligned_sm.clone(), - mem_unaligned_sm: mem_unaligned_sm.clone(), - }; - let mem_sm = Arc::new(mem_sm); - - wcm.register_component(mem_sm.clone(), None, None); - - // For all the secondary state machines, register the main state machine as a predecessor - mem_sm.mem_aligned_sm.register_predecessor(); - mem_sm.mem_unaligned_sm.register_predecessor(); - - mem_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - - self.mem_aligned_sm.unregister_predecessor::(scope); - self.mem_unaligned_sm.unregister_predecessor::(scope); - } - } -} - -impl WitnessComponent for MemSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for MemSM { - fn calculate( - &self, - _operation: ZiskRequiredMemory, - ) -> Result> { - unimplemented!() - } - - fn prove(&self, _operations: &[ZiskRequiredMemory], _drain: bool, _scope: &Scope) { - // TODO! 
- } - - fn calculate_prove( - &self, - _operation: ZiskRequiredMemory, - _drain: bool, - _scope: &Scope, - ) -> Result> { - unimplemented!() - } -} diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs new file mode 100644 index 00000000..486c05dd --- /dev/null +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -0,0 +1,214 @@ +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, + }, +}; + +use log::info; +use p3_field::PrimeField; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; + +use zisk_pil::{MemAlignRomRow, MemAlignRomTrace, MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; + +#[derive(Debug, Clone, Copy)] +pub enum MemOp { + OneRead, + OneWrite, + TwoReads, + TwoWrites, +} + +const OP_SIZES: [u64; 4] = [2, 3, 3, 5]; +const ONE_WORD_COMBINATIONS: u64 = 20; // (0..4,[1,2,4]), (5,6,[1,2]), (7,[1]) -> 5*3 + 2*2 + 1*1 = 20 +const TWO_WORD_COMBINATIONS: u64 = 11; // (1..4,[8]), (5,6,[4,8]), (7,[2,4,8]) -> 4*1 + 2*2 + 1*3 = 11 + +pub struct MemAlignRomSM { + // Witness computation manager + wcm: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Rom data + num_rows: usize, + multiplicity: Mutex>, // row_num -> multiplicity +} + +#[derive(Debug)] +pub enum ExtensionTableSMErr { + InvalidOpcode, +} + +impl MemAlignRomSM { + const MY_NAME: &'static str = "MemAlignRom"; + + pub fn new(wcm: Arc>) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + let num_rows = air.num_rows(); + + let mem_align_rom = Self { + wcm: wcm.clone(), + registered_predecessors: AtomicU32::new(0), + num_rows, + multiplicity: Mutex::new(HashMap::with_capacity(num_rows)), + }; + let mem_align_rom = Arc::new(mem_align_rom); + wcm.register_component( + mem_align_rom.clone(), + Some(ZISK_AIRGROUP_ID), + Some(MEM_ALIGN_ROM_AIR_IDS), + ); + + mem_align_rom + } + + pub fn 
register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + self.create_air_instance(); + } + } + + pub fn calculate_next_pc(&self, opcode: MemOp, offset: usize, width: usize) -> u64 { + // Get the table offset + let (table_offset, one_word) = match opcode { + MemOp::OneRead => (1, true), + + MemOp::OneWrite => (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], true), + + MemOp::TwoReads => ( + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1], + false, + ), + + MemOp::TwoWrites => ( + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + + ONE_WORD_COMBINATIONS * OP_SIZES[1] + + TWO_WORD_COMBINATIONS * OP_SIZES[2], + false, + ), + }; + + // Get the first row index + let first_row_idx = Self::get_first_row_idx(opcode, offset, width, table_offset, one_word); + + // Based on the program size, return the row indices + let opcode_idx = opcode as usize; + let op_size = OP_SIZES[opcode_idx]; + for i in 0..op_size { + let row_idx = first_row_idx + i; + // Check whether the row index is within the bounds + debug_assert!(row_idx < self.num_rows as u64); + // Update the multiplicity + self.update_multiplicity_by_row_idx(row_idx, 1); + } + + first_row_idx + } + + fn get_first_row_idx( + opcode: MemOp, + offset: usize, + width: usize, + table_offset: u64, + one_word: bool, + ) -> u64 { + let opcode_idx = opcode as usize; + let op_size = OP_SIZES[opcode_idx]; + + // Go to the actual operation + let mut first_row_idx = table_offset; + + // Go to the actual offset + let first_valid_offset = if one_word { 0 } else { 1 }; + for i in first_valid_offset..offset { + let possible_widths = Self::calculate_possible_widths(one_word, i); + first_row_idx += op_size * possible_widths.len() as u64; + } + + // Go to the right width + let width_idx = Self::calculate_possible_widths(one_word, offset) + .iter() + .position(|&w| w == 
width) + .unwrap_or_else(|| panic!("Invalid width offset:{} width:{}", offset, width)); + first_row_idx += op_size * width_idx as u64; + + first_row_idx + } + + fn calculate_possible_widths(one_word: bool, offset: usize) -> Vec { + // Calculate the ROM rows based on the requested opcode, offset, and width + match one_word { + true => match offset { + x if x <= 4 => vec![1, 2, 4], + x if x <= 6 => vec![1, 2], + 7 => vec![1], + _ => panic!("Invalid offset={}", offset), + }, + false => match offset { + 0 => panic!("Invalid offset={}", offset), + x if x <= 4 => vec![8], + x if x <= 6 => vec![4, 8], + 7 => vec![2, 4, 8], + _ => panic!("Invalid offset={}", offset), + }, + } + } + + pub fn update_padding_row(&self, padding_len: u64) { + // Update entry at the padding row (pos = 0) with the given padding length + self.update_multiplicity_by_row_idx(0, padding_len); + } + + pub fn update_multiplicity_by_row_idx(&self, row_idx: u64, mul: u64) { + let mut multiplicity = self.multiplicity.lock().unwrap(); + *multiplicity.entry(row_idx).or_insert(0) += mul; + } + + pub fn create_air_instance(&self) { + // Get the contexts + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let sctx = wcm.get_sctx(); + + // Get the Mem Align ROM AIR + let air_mem_align_rom = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + let air_mem_align_rom_rows = air_mem_align_rom.num_rows(); + + let mut trace_buffer: MemAlignRomTrace<'_, _> = + MemAlignRomTrace::new(air_mem_align_rom_rows); + + // Initialize the trace buffer to zero + for i in 0..air_mem_align_rom_rows { + trace_buffer[i] = MemAlignRomRow { multiplicity: F::zero() }; + } + + // Fill the trace buffer with the multiplicity values + if let Ok(multiplicity) = self.multiplicity.lock() { + for (row_idx, multiplicity) in multiplicity.iter() { + trace_buffer[*row_idx as usize] = + MemAlignRomRow { multiplicity: F::from_canonical_u64(*multiplicity) }; + } + } + + info!("{}: ยทยทยท Creating Mem Align Rom instance", 
Self::MY_NAME,); + + let air_instance = AirInstance::new( + sctx, + ZISK_AIRGROUP_ID, + MEM_ALIGN_ROM_AIR_IDS[0], + None, + trace_buffer.buffer.unwrap(), + ); + pctx.air_instance_repo.add_air_instance(air_instance, None); + } +} + +impl WitnessComponent for MemAlignRomSM {} diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs new file mode 100644 index 00000000..7433c007 --- /dev/null +++ b/state-machines/mem/src/mem_align_sm.rs @@ -0,0 +1,1015 @@ +use core::panic; +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use log::info; +use num_bigint::BigInt; +use num_traits::cast::ToPrimitive; +use p3_field::PrimeField; +use pil_std_lib::Std; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; + +use zisk_pil::{MemAlignRow, MemAlignTrace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; + +use crate::{MemAlignInput, MemAlignRomSM, MemOp}; + +const RC: usize = 2; +const CHUNK_NUM: usize = 8; +const CHUNKS_BY_RC: usize = CHUNK_NUM / RC; +const CHUNK_BITS: usize = 8; +const RC_BITS: u64 = (CHUNKS_BY_RC * CHUNK_BITS) as u64; +const RC_MASK: u64 = (1 << RC_BITS) - 1; +const OFFSET_MASK: u32 = 0x07; +const OFFSET_BITS: u32 = 3; +const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; + +const fn generate_allowed_offsets() -> [u8; CHUNK_NUM] { + let mut offsets = [0; CHUNK_NUM]; + let mut i = 0; + while i < CHUNK_NUM { + offsets[i] = i as u8; + i += 1; + } + offsets +} + +const ALLOWED_OFFSETS: [u8; CHUNK_NUM] = generate_allowed_offsets(); +const ALLOWED_WIDTHS: [u8; 4] = [1, 2, 4, 8]; +const DEFAULT_OFFSET: u64 = 0; +const DEFAULT_WIDTH: u64 = 8; + +pub struct MemAlignResponse { + pub more_addr: bool, + pub step: u64, + pub value: Option, +} +pub struct MemAlignSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Computed row information + rows: Mutex>>, + #[cfg(feature = "debug_mem_align")] 
+ num_computed_rows: Mutex, + + // Secondary State machines + mem_align_rom_sm: Arc>, +} + +macro_rules! debug_info { + ($prefix:expr, $($arg:tt)*) => { + #[cfg(feature = "debug_mem_align")] + { + info!(concat!("MemAlign: ",$prefix), $($arg)*); + } + }; +} + +impl MemAlignSM { + const MY_NAME: &'static str = "MemAlign"; + + pub fn new( + wcm: Arc>, + std: Arc>, + mem_align_rom_sm: Arc>, + ) -> Arc { + let mem_align_sm = Self { + wcm: wcm.clone(), + std: std.clone(), + registered_predecessors: AtomicU32::new(0), + rows: Mutex::new(Vec::new()), + #[cfg(feature = "debug_mem_align")] + num_computed_rows: Mutex::new(0), + mem_align_rom_sm, + }; + let mem_align_sm = Arc::new(mem_align_sm); + + wcm.register_component( + mem_align_sm.clone(), + Some(ZISK_AIRGROUP_ID), + Some(MEM_ALIGN_AIR_IDS), + ); + + // Register the predecessors + std.register_predecessor(); + mem_align_sm.mem_align_rom_sm.register_predecessor(); + + mem_align_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + + // If there are remaining rows, generate the last instance + if let Ok(mut rows) = self.rows.lock() { + // Get the Mem Align AIR + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + + let rows_len = rows.len(); + debug_assert!(rows_len <= air_mem_align.num_rows()); + + let drained_rows = rows.drain(..rows_len).collect::>(); + + self.fill_new_air_instance(&drained_rows); + } + + self.mem_align_rom_sm.unregister_predecessor(); + self.std.unregister_predecessor(pctx, None); + } + } + + #[inline(always)] + pub fn get_mem_op(&self, input: &MemAlignInput, phase: usize) -> MemAlignResponse { + let addr = input.addr; + let width = input.width; + + // Compute the width + debug_assert!( + ALLOWED_WIDTHS.contains(&width), + "Width={} is not allowed. 
Allowed widths are {:?}", + width, + ALLOWED_WIDTHS + ); + let width = width as usize; + + // Compute the offset + let offset = (addr & OFFSET_MASK) as u8; + debug_assert!( + ALLOWED_OFFSETS.contains(&offset), + "Offset={} is not allowed. Allowed offsets are {:?}", + offset, + ALLOWED_OFFSETS + ); + let offset = offset as usize; + + #[cfg(feature = "debug_mem_align")] + let num_rows = self.num_computed_rows.lock().unwrap(); + match (input.is_write, offset + width > CHUNK_NUM) { + (false, false) => { + /* RV with offset=2, width=4 + +----+----+====+====+====+====+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+====+====+====+====+----+----+ + โ‡“ + +----+----+====+====+====+====+----+----+ + | V6 | V7 | V0 | V1 | V2 | V3 | V4 | V5 | + +----+----+====+====+====+====+----+----+ + */ + debug_assert!(phase == 0); + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Get the aligned address + let addr_read = addr >> OFFSET_BITS; + + // Get the aligned value + let value_read = input.mem_values[phase]; + + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::OneRead, offset, width); + + let mut read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + 
}; + + for i in 0..CHUNK_NUM { + read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_read, i, 0)); + if i >= offset && i < offset + width { + read_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_read = value_read; + let mut _value = value; + for i in 0..RC { + read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + _value_read >>= RC_BITS; + _value >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nOne Word Read\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value Read: {:?}\n\ + Value: {:?}\n\ + Flags Read: {:?}\n\ + Flags Value: {:?}", + [*num_rows, *num_rows + 1], + input, + phase, + value_read.to_le_bytes(), + value.to_le_bytes(), + [ + read_row.sel[0], read_row.sel[1], read_row.sel[2], read_row.sel[3], + read_row.sel[4], read_row.sel[5], read_row.sel[6], read_row.sel[7], + read_row.wr, read_row.reset, read_row.sel_up_to_down, read_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[read_row, value_row]); + + MemAlignResponse { more_addr: false, step, value: None } + } + (true, false) => { + /* RWV with offset=3, width=4 + +----+----+----+====+====+====+====+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+====+====+====+====+----+ + โ‡“ + +----+----+----+====+====+====+====+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +----+----+----+====+====+====+====+----+ + โ‡“ + +----+----+----+====+====+====+====+----+ + | V5 | V6 | V7 | V0 | V1 | V2 | V3 | V4 
| + +----+----+----+====+====+====+====+----+ + */ + debug_assert!(phase == 0); + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Get the aligned address + let addr_read = addr >> OFFSET_BITS; + + // Get the aligned value + let value_read = input.mem_values[phase]; + + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::OneWrite, offset, width); + + // Compute the write value + let value_write = { + // with:1 offset:4 + let width_bytes: u64 = (1 << (width * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + + // Get the first width bytes of the unaligned value + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + + // Write zeroes to value_read from offset to offset + width + // and add the value to write to the value read + (value_read & !mask) | value_to_write + }; + + let mut read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + wr: F::from_bool(true), + pc: 
F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_read, i, 0)); + if i < offset || i >= offset + width { + read_row.sel[i] = F::from_bool(true); + } + + write_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_write, i, 0)); + if i >= offset && i < offset + width { + write_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = { + if i >= offset && i < offset + width { + write_row.reg[i] + } else { + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)) + } + }; + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_read = value_read; + let mut _value_write = value_write; + let mut _value = value; + for i in 0..RC { + read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK); + write_row.value[i] = F::from_canonical_u64(_value_write & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + _value_read >>= RC_BITS; + _value_write >>= RC_BITS; + _value >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nOne Word Write\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value Read: {:?}\n\ + Value Write: {:?}\n\ + Value: {:?}\n\ + Flags Read: {:?}\n\ + Flags Write: {:?}\n\ + Flags Value: {:?}", + [*num_rows, *num_rows + 2], + input, + phase, + value_read.to_le_bytes(), + value_write.to_le_bytes(), + value.to_le_bytes(), + [ + read_row.sel[0], read_row.sel[1], read_row.sel[2], read_row.sel[3], + read_row.sel[4], read_row.sel[5], read_row.sel[6], read_row.sel[7], + read_row.wr, read_row.reset, read_row.sel_up_to_down, read_row.sel_down_to_up + ], + [ + write_row.sel[0], write_row.sel[1], write_row.sel[2], write_row.sel[3], + write_row.sel[4], write_row.sel[5], write_row.sel[6], write_row.sel[7], + write_row.wr, write_row.reset, write_row.sel_up_to_down, write_row.sel_down_to_up + ], + [ + value_row.sel[0], 
value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[read_row, write_row, value_row]); + + MemAlignResponse { more_addr: false, step, value: Some(value_write) } + } + (false, true) => { + /* RVR with offset=5, width=8 + +----+----+----+----+----+====+====+====+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+----+----+====+====+====+ + โ‡“ + +====+====+====+====+====+====+====+====+ + | V3 | V4 | V5 | V6 | V7 | V0 | V1 | V2 | + +====+====+====+====+====+====+====+====+ + โ‡“ + +====+====+====+====+====+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +====+====+====+====+====+----+----+----+ + */ + debug_assert!(phase == 0 || phase == 1); + + match phase { + // If phase == 0, do nothing, just ask for more + 0 => MemAlignResponse { more_addr: true, step: input.step, value: None }, + + // Otherwise, do the RVR + 1 => { + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the remaining bytes + let rem_bytes = (offset + width) % CHUNK_NUM; + + // Get the aligned address + let addr_first_read = addr >> OFFSET_BITS; + let addr_second_read = addr_first_read + 1; + + // Get the aligned value + let value_first_read = input.mem_values[0]; + let value_second_read = input.mem_values[1]; + + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::TwoReads, offset, width); + + let mut first_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + 
reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_second_read), + delta_addr: F::one(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + first_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); + if i >= offset { + first_read_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); + + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i < rem_bytes { + second_read_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_first_read = value_first_read; + let mut _value = value; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = + F::from_canonical_u64(_value_first_read & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_read_row.value[i] = + F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nTwo Words Read\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + 
Phase: {:?}\n\ + Value First Read: {:?}\n\ + Value: {:?}\n\ + Value Second Read: {:?}\n\ + Flags First Read: {:?}\n\ + Flags Value: {:?}\n\ + Flags Second Read: {:?}", + [*num_rows, *num_rows + 2], + input, + phase, + value_first_read.to_le_bytes(), + value.to_le_bytes(), + value_second_read.to_le_bytes(), + [ + first_read_row.sel[0], first_read_row.sel[1], first_read_row.sel[2], first_read_row.sel[3], + first_read_row.sel[4], first_read_row.sel[5], first_read_row.sel[6], first_read_row.sel[7], + first_read_row.wr, first_read_row.reset, first_read_row.sel_up_to_down, first_read_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ], + [ + second_read_row.sel[0], second_read_row.sel[1], second_read_row.sel[2], second_read_row.sel[3], + second_read_row.sel[4], second_read_row.sel[5], second_read_row.sel[6], second_read_row.sel[7], + second_read_row.wr, second_read_row.reset, second_read_row.sel_up_to_down, second_read_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[first_read_row, value_row, second_read_row]); + + MemAlignResponse { more_addr: false, step, value: None } + } + _ => panic!("Invalid phase={}", phase), + } + } + (true, true) => { + /* RWVWR with offset=6, width=4 + +----+----+----+----+----+----+====+====+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+----+----+----+====+====+ + โ‡“ + +----+----+----+----+----+----+====+====+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +----+----+----+----+----+----+====+====+ + โ‡“ + +====+====+----+----+----+----+====+====+ + | V2 | V3 | V4 | V5 | V6 | V7 | V0 | V1 | + +====+====+----+----+----+----+====+====+ + โ‡“ + +====+====+----+----+----+----+----+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + 
+====+====+----+----+----+----+----+----+ + โ‡“ + +====+====+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +====+====+----+----+----+----+----+----+ + */ + debug_assert!(phase == 0 || phase == 1); + + match phase { + // If phase == 0, compute the resulting write value and ask for more + 0 => { + // Unaligned memory op information thrown into the bus + let value = input.value; + let step = input.step; + + // Get the aligned value + let value_first_read = input.mem_values[0]; + + // Compute the write value + let value_first_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; + + let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + + // Get the first width bytes of the unaligned value + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + + // Write zeroes to value_read from offset to offset + width + // and add the value to write to the value read + (value_first_read & !mask) | value_to_write + }; + + MemAlignResponse { more_addr: true, step, value: Some(value_first_write) } + } + // Otherwise, do the RWVRW + 1 => { + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the shift + let rem_bytes = (offset + width) % CHUNK_NUM; + + // Get the aligned address + let addr_first_read_write = addr >> OFFSET_BITS; + let addr_second_read_write = addr_first_read_write + 1; + + // Get the first aligned value + let value_first_read = input.mem_values[0]; + + // Recompute the first write value + let value_first_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; + + let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + + // Get the first width bytes of the unaligned value + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + + // Write zeroes to value_read from offset 
to offset + width + // and add the value to write to the value read + (value_first_read & !mask) | value_to_write + }; + + // Get the second aligned value + let value_second_read = input.mem_values[1]; + + // Compute the second write value + let value_second_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; + + let mask: u64 = (1 << (rem_bytes * CHUNK_BITS)) - 1; + + // Get the first width bytes of the unaligned value + let value_to_write = (value >> (width_norm * CHUNK_BITS)) & mask; + + // Write zeroes to value_read from 0 to offset + width + // and add the value to write to the value read + (value_second_read & !mask) | value_to_write + }; + + // Get the next pc + let next_pc = self.mem_align_rom_sm.calculate_next_pc( + MemOp::TwoWrites, + offset, + width, + ); + + // RWVWR + let mut first_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut first_write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 1), + // reset: 
F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_second_read_write), + delta_addr: F::one(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 2), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_second_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 3), + reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + first_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); + if i < offset { + first_read_row.sel[i] = F::from_bool(true); + } + + first_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_write, i, 0)); + if i >= offset { + first_write_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = { + if i < rem_bytes { + second_write_row.reg[i] + } else if i >= offset { + first_write_row.reg[i] + } else { + F::from_canonical_u64(Self::get_byte( + value, + i, + CHUNK_NUM - offset, + )) + } + }; + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + + second_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_write, i, 0)); + if i < rem_bytes { + second_write_row.sel[i] = F::from_bool(true); + } + + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i >= rem_bytes { + second_read_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_first_read = value_first_read; + let mut 
_value_first_write = value_first_write; + let mut _value = value; + let mut _value_second_write = value_second_write; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = + F::from_canonical_u64(_value_first_read & RC_MASK); + first_write_row.value[i] = + F::from_canonical_u64(_value_first_write & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_write_row.value[i] = + F::from_canonical_u64(_value_second_write & RC_MASK); + second_read_row.value[i] = + F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value_first_write >>= RC_BITS; + _value >>= RC_BITS; + _value_second_write >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nTwo Words Write\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value First Read: {:?}\n\ + Value First Write: {:?}\n\ + Value: {:?}\n\ + Value Second Read: {:?}\n\ + Value Second Write: {:?}\n\ + Flags First Read: {:?}\n\ + Flags First Write: {:?}\n\ + Flags Value: {:?}\n\ + Flags Second Write: {:?}\n\ + Flags Second Read: {:?}", + [*num_rows, *num_rows + 4], + input, + phase, + value_first_read.to_le_bytes(), + value_first_write.to_le_bytes(), + value.to_le_bytes(), + value_second_write.to_le_bytes(), + value_second_read.to_le_bytes(), + [ + first_read_row.sel[0], first_read_row.sel[1], first_read_row.sel[2], first_read_row.sel[3], + first_read_row.sel[4], first_read_row.sel[5], first_read_row.sel[6], first_read_row.sel[7], + first_read_row.wr, first_read_row.reset, first_read_row.sel_up_to_down, first_read_row.sel_down_to_up + ], + [ + first_write_row.sel[0], first_write_row.sel[1], first_write_row.sel[2], first_write_row.sel[3], + first_write_row.sel[4], first_write_row.sel[5], first_write_row.sel[6], first_write_row.sel[7], + first_write_row.wr, first_write_row.reset, first_write_row.sel_up_to_down, first_write_row.sel_down_to_up + ], + [ + value_row.sel[0], 
value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ], + [ + second_write_row.sel[0], second_write_row.sel[1], second_write_row.sel[2], second_write_row.sel[3], + second_write_row.sel[4], second_write_row.sel[5], second_write_row.sel[6], second_write_row.sel[7], + second_write_row.wr, second_write_row.reset, second_write_row.sel_up_to_down, second_write_row.sel_down_to_up + ], + [ + second_read_row.sel[0], second_read_row.sel[1], second_read_row.sel[2], second_read_row.sel[3], + second_read_row.sel[4], second_read_row.sel[5], second_read_row.sel[6], second_read_row.sel[7], + second_read_row.wr, second_read_row.reset, second_read_row.sel_up_to_down, second_read_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[ + first_read_row, + first_write_row, + value_row, + second_write_row, + second_read_row, + ]); + + MemAlignResponse { more_addr: false, step, value: Some(value_second_write) } + } + _ => panic!("Invalid phase={}", phase), + } + } + } + } + + fn get_byte(value: u64, index: usize, offset: usize) -> u64 { + let chunk = (offset + index) % CHUNK_NUM; + (value >> (chunk * CHUNK_BITS)) & CHUNK_BITS_MASK + } + + pub fn prove(&self, computed_rows: &[MemAlignRow]) { + if let Ok(mut rows) = self.rows.lock() { + rows.extend_from_slice(computed_rows); + + #[cfg(feature = "debug_mem_align")] + { + let mut num_rows = self.num_computed_rows.lock().unwrap(); + *num_rows += computed_rows.len(); + drop(num_rows); + } + + let pctx = self.wcm.get_pctx(); + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + + while rows.len() >= air_mem_align.num_rows() { + let num_drained = std::cmp::min(air_mem_align.num_rows(), rows.len()); + let drained_rows = rows.drain(..num_drained).collect::>(); + + 
self.fill_new_air_instance(&drained_rows); + } + } + } + + fn fill_new_air_instance(&self, rows: &[MemAlignRow]) { + // Get the proof context + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + + // Get the Mem Align AIR + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + let air_mem_align_rows = air_mem_align.num_rows(); + let rows_len = rows.len(); + + // You cannot feed to the AIR more rows than it has + debug_assert!(rows_len <= air_mem_align_rows); + + // Get the execution and setup context + let sctx = wcm.get_sctx(); + + let mut trace_buffer: MemAlignTrace<'_, _> = MemAlignTrace::new(air_mem_align_rows); + + let mut reg_range_check: Vec = vec![0; 1 << CHUNK_BITS]; + // Add the input rows to the trace + for (i, &row) in rows.iter().enumerate() { + // Store the entire row + trace_buffer[i] = row; + // Store the value of all reg columns so that they can be range checked + for j in 0..CHUNK_NUM { + let element = + row.reg[j].as_canonical_biguint().to_usize().expect("Cannot convert to usize"); + reg_range_check[element] += 1; + } + } + + // Pad the remaining rows with trivially satisfying rows + let padding_row = MemAlignRow:: { reset: F::from_bool(true), ..Default::default() }; + let padding_size = air_mem_align_rows - rows_len; + + // Store the padding rows + for i in rows_len..air_mem_align_rows { + trace_buffer[i] = padding_row; + } + + // Store the value of all padding reg columns so that they can be range checked + for _ in 0..CHUNK_NUM { + reg_range_check[0] += padding_size as u64; + } + + // Perform the range checks + let std = self.std.clone(); + let range_id = std.get_range(BigInt::from(0), BigInt::from(CHUNK_BITS_MASK), None); + for (value, &multiplicity) in reg_range_check.iter().enumerate() { + std.range_check( + F::from_canonical_usize(value), + F::from_canonical_u64(multiplicity), + range_id, + ); + } + + // Compute the program multiplicity + let mem_align_rom_sm = self.mem_align_rom_sm.clone(); + 
mem_align_rom_sm.update_padding_row(padding_size as u64); + + info!( + "{}: ยทยทยท Creating Mem Align instance [{} / {} rows filled {:.2}%]", + Self::MY_NAME, + rows_len, + air_mem_align_rows, + rows_len as f64 / air_mem_align_rows as f64 * 100.0 + ); + + // Add a new Mem Align instance + let air_instance = AirInstance::new( + sctx, + ZISK_AIRGROUP_ID, + MEM_ALIGN_AIR_IDS[0], + None, + trace_buffer.buffer.unwrap(), + ); + pctx.air_instance_repo.add_air_instance(air_instance, None); + } +} + +impl WitnessComponent for MemAlignSM {} diff --git a/state-machines/mem/src/mem_aligned.rs b/state-machines/mem/src/mem_aligned.rs deleted file mode 100644 index 1a126e3c..00000000 --- a/state-machines/mem/src/mem_aligned.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{MemOp, OpResult, Provable}; -use zisk_pil::{MEM_AIRGROUP_ID, MEM_ALIGN_AIR_IDS}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct MemAlignedSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, -} - -#[allow(unused, unused_variables)] -impl MemAlignedSM { - pub fn new(wcm: Arc>) -> Arc { - let mem_aligned_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let mem_aligned_sm = Arc::new(mem_aligned_sm); - - wcm.register_component( - mem_aligned_sm.clone(), - Some(MEM_AIRGROUP_ID), - Some(MEM_ALIGN_AIR_IDS), - ); - - mem_aligned_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - } - } - - fn read( - &self, - _addr: u64, /* , _ctx: &mut ProofCtx, _ectx: 
&ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } - - fn write( - &self, - _addr: u64, - _val: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } -} - -impl WitnessComponent for MemAlignedSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for MemAlignedSM { - fn calculate(&self, operation: MemOp) -> Result> { - match operation { - MemOp::Read(addr) => self.read(addr), - MemOp::Write(addr, val) => self.write(addr, val), - } - } - - fn prove(&self, operations: &[MemOp], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: MemOp, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/mem/src/mem_constants.rs b/state-machines/mem/src/mem_constants.rs index 4e177ee3..9165edd1 100644 --- a/state-machines/mem/src/mem_constants.rs +++ b/state-machines/mem/src/mem_constants.rs @@ -1,12 +1,17 @@ -pub const MEM_ADDR_MASK: u64 = 0xFFFF_FFFF_FFFF_FFF8; -pub const MEM_BYTES: u64 = 8; +pub const MEM_ADDR_MASK: u32 = 0xFFFF_FFF8; +pub const MEM_BYTES_BITS: u32 = 3; +pub const MEM_BYTES: u32 = 1 << MEM_BYTES_BITS; +pub const MEM_STEP_BASE: u64 = 1; pub const MAX_MEM_STEP_OFFSET: u64 = 2; -pub const MAX_MEM_OPS_PER_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * 2; +pub const MAX_MEM_OPS_BY_STEP_OFFSET: u64 = 2; +pub const MAX_MEM_OPS_BY_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * 
MAX_MEM_OPS_BY_STEP_OFFSET; -pub const MEM_STEP_BITS: u64 = 34; // with step_slot = 8 => 2GB steps ( -pub const MEM_STEP_MASK: u64 = (1 << MEM_STEP_BITS) - 1; // 256 MB -pub const MEM_ADDR_BITS: u64 = 64 - MEM_STEP_BITS; +pub const MAX_MAIN_STEP: u64 = 0x1FFF_FFFF_FFFF_FFFF; +pub const MAX_MEM_STEP: u64 = MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * MAX_MAIN_STEP + + MAX_MEM_OPS_BY_STEP_OFFSET * MAX_MEM_STEP_OFFSET; -pub const MAX_MEM_STEP: u64 = (1 << MEM_STEP_BITS) - 1; -pub const MAX_MEM_ADDR: u64 = (1 << MEM_ADDR_BITS) - 1; +pub const MAX_MEM_ADDR: u64 = 0xFFFF_FFFF; + +pub const MEMORY_MAX_DIFF: u64 = 1 << 24; diff --git a/state-machines/mem/src/mem_helpers.rs b/state-machines/mem/src/mem_helpers.rs index ac4ca198..8e70b537 100644 --- a/state-machines/mem/src/mem_helpers.rs +++ b/state-machines/mem/src/mem_helpers.rs @@ -1,7 +1,10 @@ -use crate::MemAlignResponse; +use crate::{ + MemAlignResponse, MAX_MEM_OPS_BY_MAIN_STEP, MAX_MEM_OPS_BY_STEP_OFFSET, MEM_STEP_BASE, +}; use std::fmt; use zisk_core::ZiskRequiredMemory; +#[allow(dead_code)] fn format_u64_hex(value: u64) -> String { let hex_str = format!("{:016x}", value); hex_str @@ -12,54 +15,73 @@ fn format_u64_hex(value: u64) -> String { .join("_") } +#[derive(Debug, Clone)] +pub struct MemAlignInput { + pub addr: u32, + pub is_write: bool, + pub width: u8, + pub step: u64, + pub value: u64, + pub mem_values: [u64; 2], +} + +#[derive(Debug, Clone)] +pub struct MemInput { + pub addr: u32, // address in word native format means byte_address / MEM_BYTES + pub is_write: bool, // it's a write operation + pub is_internal: bool, // internal operation, don't send this operation to bus + pub step: u64, // mem_step = f(main_step, main_step_offset) + pub value: u64, // value to read or write +} + +impl MemAlignInput { + pub fn new( + addr: u32, + is_write: bool, + width: u8, + step: u64, + value: u64, + mem_values: [u64; 2], + ) -> Self { + MemAlignInput { addr, is_write, width, step, value, mem_values } + } + pub fn 
from(mem_external_op: &ZiskRequiredMemory, mem_values: &[u64; 2]) -> Self { + match mem_external_op { + ZiskRequiredMemory::Basic { step, value, address, is_write, width, step_offset } => { + MemAlignInput { + addr: *address, + is_write: *is_write, + step: MemHelpers::main_step_to_address_step(*step, *step_offset), + width: *width, + value: *value, + mem_values: [mem_values[0], mem_values[1]], + } + } + ZiskRequiredMemory::Extended { values: _, address: _ } => { + panic!("MemAlignInput::from() called with extended instance") + } + } + } +} + +pub struct MemHelpers {} + +impl MemHelpers { + pub fn main_step_to_address_step(step: u64, step_offset: u8) -> u64 { + MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * step + + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64 + } +} + impl fmt::Debug for MemAlignResponse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, "more:{0} step:{1} value:{2:016X}({2:})", - self.more_address, + self.more_addr, self.step, self.value.unwrap_or(0) ) } } - -pub fn mem_align_call( - mem_op: &ZiskRequiredMemory, - mem_values: [u64; 2], - phase: u8, -) -> MemAlignResponse { - // DEBUG: only for testing - let offset = (mem_op.address & 0x7) * 8; - let width = (mem_op.width as u64) * 8; - let double_address = (offset + width as u32) > 64; - let mem_value = mem_values[phase as usize]; - let mask = 0xFFFF_FFFF_FFFF_FFFFu64 >> (64 - width); - if mem_op.is_write { - if phase == 0 { - MemAlignResponse { - more_address: double_address, - step: mem_op.step + 1, - value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) - | ((mem_op.value & mask) << offset), - ), - } - } else { - MemAlignResponse { - more_address: false, - step: mem_op.step + 1, - value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width as u32 - 64))) - | ((mem_op.value & mask) >> (128 - (offset + width as u32))), - ), - } - } - } else { - MemAlignResponse { - more_address: double_address && phase == 0, - step: mem_op.step + 1, - value: 
None, - } - } -} diff --git a/state-machines/mem/src/mem_module.rs b/state-machines/mem/src/mem_module.rs new file mode 100644 index 00000000..59308fd3 --- /dev/null +++ b/state-machines/mem/src/mem_module.rs @@ -0,0 +1,31 @@ +use crate::{MemHelpers, MemInput, MEM_BYTES}; +use zisk_core::ZiskRequiredMemory; + +impl MemInput { + pub fn new(addr: u32, is_write: bool, step: u64, value: u64, is_internal: bool) -> Self { + MemInput { addr, is_write, step, value, is_internal } + } + pub fn from(mem_op: &ZiskRequiredMemory) -> Self { + match mem_op { + ZiskRequiredMemory::Basic { step, value, address, is_write, width, step_offset } => { + debug_assert_eq!(*width, MEM_BYTES as u8); + MemInput { + addr: address >> 3, + is_write: *is_write, + is_internal: false, + step: MemHelpers::main_step_to_address_step(*step, *step_offset), + value: *value, + } + } + ZiskRequiredMemory::Extended { values: _, address: _ } => { + panic!("MemInput::from() called with an extended instance"); + } + } + } +} + +pub trait MemModule: Send + Sync { + fn send_inputs(&self, mem_op: &[MemInput]); + fn get_addr_ranges(&self) -> Vec<(u32, u32)>; + fn get_flush_input_size(&self) -> u32; +} diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs new file mode 100644 index 00000000..a5bcf320 --- /dev/null +++ b/state-machines/mem/src/mem_proxy.rs @@ -0,0 +1,79 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, +}; + +use crate::{InputDataSM, MemAlignRomSM, MemAlignSM, MemProxyEngine, MemSM, RomDataSM}; +use p3_field::PrimeField; +use pil_std_lib::Std; +use zisk_core::ZiskRequiredMemory; + +use proofman::{WitnessComponent, WitnessManager}; + +pub struct MemProxy { + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Secondary State machines + mem_sm: Arc>, + mem_align_sm: Arc>, + mem_align_rom_sm: Arc>, + input_data_sm: Arc>, + rom_data_sm: Arc>, +} + +impl MemProxy { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { + let mem_align_rom_sm 
= MemAlignRomSM::new(wcm.clone()); + let mem_align_sm = MemAlignSM::new(wcm.clone(), std.clone(), mem_align_rom_sm.clone()); + let mem_sm = MemSM::new(wcm.clone(), std.clone()); + let input_data_sm = InputDataSM::new(wcm.clone(), std.clone()); + let rom_data_sm = RomDataSM::new(wcm.clone(), std.clone()); + + let mem_proxy = Self { + registered_predecessors: AtomicU32::new(0), + mem_align_sm, + mem_align_rom_sm, + mem_sm, + input_data_sm, + rom_data_sm, + }; + let mem_proxy = Arc::new(mem_proxy); + + wcm.register_component(mem_proxy.clone(), None, None); + + // For all the secondary state machines, register the main state machine as a predecessor + mem_proxy.mem_align_rom_sm.register_predecessor(); + mem_proxy.mem_align_sm.register_predecessor(); + mem_proxy.mem_sm.register_predecessor(); + mem_proxy.input_data_sm.register_predecessor(); + mem_proxy.rom_data_sm.register_predecessor(); + mem_proxy + } + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + self.mem_align_rom_sm.unregister_predecessor(); + self.mem_align_sm.unregister_predecessor(); + self.mem_sm.unregister_predecessor(); + self.input_data_sm.unregister_predecessor(); + self.rom_data_sm.unregister_predecessor(); + } + } + + pub fn prove( + &self, + mem_operations: &mut Vec, + ) -> Result<(), Box> { + let mut engine = MemProxyEngine::::new(self.mem_align_sm.clone()); + engine.add_module("mem", self.mem_sm.clone()); + engine.add_module("input_data", self.input_data_sm.clone()); + engine.add_module("row_data", self.rom_data_sm.clone()); + engine.prove(mem_operations) + } +} + +impl WitnessComponent for MemProxy {} diff --git a/state-machines/mem/src/mem_proxy_engine.rs b/state-machines/mem/src/mem_proxy_engine.rs new file mode 100644 index 00000000..3ca6bf5d --- /dev/null +++ b/state-machines/mem/src/mem_proxy_engine.rs @@ -0,0 +1,628 @@ 
+//! The `MemProxyEngine` module is designed to facilitate dividing the proxy logic into smaller, +//! more manageable pieces of code. +//! +//! The engine is created through MemProxy on a static call, which creates the `MemProxyEngine`. +//! `MemProxyEngine` has state, and this state allows the implementation of smaller, focused +//! methods, making the codebase easier to maintain and extend. +//! +//! +//! ## Creation and Setup of the `MemProxyEngine` +//! +//! When creating the `MemProxyEngine`, a state machine is provided to handle alignment of memory +//! accesses. This state machine is responsible for demonstrating unaligned accesses based on +//! aligned ones. +//! +//! Once the `MemProxyEngine` is created, all memory modules are registered. These modules must +//! implement the `MemModule` trait, which serves three purposes: +//! +//! 1. To define the range of addresses (regions) they are responsible for handling. +//! 2. To specify the frequency (number of inputs) at which they expect to receive inputs. +//! 3. To define the "callback" used to send inputs to the module +//! +//! +//! ## Inputs from `MemProxyEngine` +//! +//! The inputs to the `MemProxyEngine` are represented as an enumeration to optimize memory usage +//! and performance. This design ensures efficient handling of both common and rare cases, +//! balancing memory allocation and computational efficiency. +//! +//! The enumeration has two variants: +//! 1. `Basic`: The primary input type, used for the majority of memory accesses. This variant is +//! highly optimized to minimize overhead and ensure efficient processing in typical scenarios. +//! 2. `Extended`: A specialized input type used exclusively for handling unaligned memory +//! accesses. This variant is appended to the vector immediately after the corresponding `Basic` +//! instance that generates it. The `Extended` input contains the aligned memory values required +//! to process the unaligned access (in the word-crossing case, two values). +//! 
By adopting this design, the `MemProxyEngine` avoids penalizing the commonly used `Basic` type
+//! due to the less frequent unaligned cases that require the additional `Extended` type. This
+//! separation ensures that unaligned access handling introduces minimal overhead to the overall
+//! system, while still providing the flexibility to handle unaligned accesses.
+//!
+//!
+//! ## Logic of the `MemProxyEngine`
+//!
+//! Step 1. Sort the aligned memory accesses
+//! original vector is sorted by step, sort_by_key is stable, no reordering of elements with
+//! the same key.
+//!
+//! Step 2. Add a final mark mem_op to force flush of open_mem_align_ops, because always the
+//! last operation is mem_op.
+//!
+//! Step 3. Composing information for memory operation (access). In this step, all necessary
+//! information is gathered and composed to perform a memory operation. The process involves
+//! reading the next input from the input vector, which defines the nature of the operation.
+//!
+//! - For standard (aligned) operations, only the `Basic` input is required, and the operation
+//! proceeds directly.
+//! - For unaligned operations, the `Extended` input is also read. This additional input provides
+//! the extra values required to handle the unaligned operation.
+//!
+//! Step 4. Process each memory operation ordered by address and step. When a non-aligned
+//! memory access occurs, there are two possible situations:
+//!
+//! 1. The operation applies to only one memory address (read or read+write). In this case the
+//! mem_align helper returns the aligned operation for this address, and the loop continues.
+//!
+//! 2. The operation applies to two consecutive memory addresses; the mem_align helper returns the
+//! aligned operation involved for the current address, and the second part of the operation is
+//! enqueued to open_mem_align_ops; it will be processed when processing the next address.
+//!
+//! First, we verify if there are any "previous" open memory alignment operations
+//! 
(`open_mem_align_ops`) that need to be processed before handling the current `mem_op`. If such
+//! operations exist, they are processed first, and then the current `mem_op` is executed.
+//!
+//! At the end of Step 2, a final marker is used to ensure a forced flush of any remaining
+//! `open_mem_align_ops`. This guarantees that all pending alignment operations are completed,
+//! as the last operation in this step is always a `mem_op`.
+//!
+//!
+//! ## Handling Large Gaps Between Steps
+//!
+//! One challenge in the design is addressing cases where the distance between steps becomes
+//! larger than the maximum range check MEMORY_MAX_DIFF (currently 2^24). To solve this situation,
+//! the proxy adds extra intermediate internal reads (internal because they are not sent to the
+//! bus), each increasing the step by MEMORY_MAX_DIFF to arrive at the final step.
+
+use std::{collections::VecDeque, sync::Arc};
+
+use crate::{
+    MemAlignInput, MemAlignResponse, MemAlignSM, MemHelpers, MemInput, MemModule, MemUnmapped,
+    MAX_MAIN_STEP, MAX_MEM_ADDR, MAX_MEM_OPS_BY_MAIN_STEP, MAX_MEM_STEP, MAX_MEM_STEP_OFFSET,
+    MEMORY_MAX_DIFF, MEM_ADDR_MASK, MEM_BYTES, MEM_BYTES_BITS,
+};
+use log::info;
+
+use p3_field::PrimeField;
+use proofman_util::{timer_start_debug, timer_stop_and_log_debug};
+use zisk_core::ZiskRequiredMemory;
+
+#[cfg(feature = "debug_mem_proxy_engine")]
+const DEBUG_ADDR: u32 = 0x90000008;
+
+macro_rules! 
debug_info { + ($prefix:expr, $($arg:tt)*) => { + #[cfg(feature = "debug_mem_proxy_engine")] + { + info!(concat!("MemProxy: ",$prefix), $($arg)*); + } + }; +} + +struct MemModuleData { + pub name: String, + pub inputs: Vec, + pub flush_input_size: usize, +} + +#[derive(Debug)] +pub struct AddressRegion { + from_addr: u32, + to_addr: u32, + module_id: u8, +} +pub struct MemProxyEngine { + modules: Vec>>, + modules_data: Vec, + open_mem_align_ops: VecDeque, + addr_map: Vec, + addr_map_fetched: bool, + current_module_id: usize, + current_module: String, + module_end_addr: u32, + mem_align_sm: Arc>, + next_open_addr: u32, + next_open_step: u64, + last_addr: u32, + last_step: u64, + intermediate_cases: u32, + intermediate_steps: u32, +} + +const NO_OPEN_ADDR: u32 = 0xFFFF_FFFF; +const NO_OPEN_STEP: u64 = 0xFFFF_FFFF_FFFF_FFFF; + +impl MemProxyEngine { + pub fn new(mem_align_sm: Arc>) -> Self { + Self { + modules: Vec::new(), + modules_data: Vec::new(), + current_module_id: 0, + current_module: String::new(), + module_end_addr: 0, + open_mem_align_ops: VecDeque::new(), + addr_map: Vec::new(), + addr_map_fetched: false, + mem_align_sm, + next_open_addr: NO_OPEN_ADDR, + next_open_step: NO_OPEN_STEP, + last_addr: 0xFFFF_FFFF, + last_step: 0, + intermediate_cases: 0, + intermediate_steps: 0, + } + } + + pub fn add_module(&mut self, name: &str, module: Arc>) { + if self.modules.is_empty() { + self.current_module = String::from(name); + } + let module_id = self.modules.len() as u8; + self.modules.push(module.clone()); + + let ranges = module.get_addr_ranges(); + let flush_input_size = module.get_flush_input_size(); + + for range in ranges.iter() { + debug_info!("adding range 0x{:X} 0x{:X} to {}", range.0, range.1, name); + self.insert_address_range(range.0, range.1, module_id); + } + self.modules_data.push(MemModuleData { + name: String::from(name), + inputs: Vec::new(), + flush_input_size: if flush_input_size == 0 { + 0xFFFF_FFFF_FFFF_FFFF + } else { + flush_input_size as 
usize + }, + }); + } + /* insert in sort way the address map and verify that */ + fn insert_address_range(&mut self, from_addr: u32, to_addr: u32, module_id: u8) { + let region = AddressRegion { from_addr, to_addr, module_id }; + if let Some(index) = self.addr_map.iter().position(|x| x.from_addr >= from_addr) { + self.addr_map.insert(index, region); + } else { + self.addr_map.push(region); + } + } + + pub fn prove( + &mut self, + mem_operations: &mut Vec, + ) -> Result<(), Box> { + self.init_prove(); + + // Sort the aligned memory accesses + // original vector is sorted by step, sort_by_key is stable, no reordering of elements with + // the same key. + + timer_start_debug!(MEM_SORT); + mem_operations.sort_by_key(|mem| (mem.get_address() & 0xFFFF_FFF8)); + timer_stop_and_log_debug!(MEM_SORT); + + // Add a final mark mem_op to force flush of open_mem_align_ops, because always the + // last operation is mem_op. + + self.push_end_of_memory_mark(mem_operations); + + let mut index = 0; + let count = mem_operations.len(); + while index < count { + if let ZiskRequiredMemory::Basic { + step, + value, + address, + is_write, + width, + step_offset, + } = mem_operations[index] + { + let extend_values = if !Self::is_aligned(address, width) { + debug_assert!(index + 1 < count, "expected one element more extended !!"); + if let ZiskRequiredMemory::Extended { address: _, values } = + mem_operations[index + 1] + { + index += 1; + values + } else { + panic!("MemProxy::prove() unexpected Basic variant"); + } + } else { + [0, 0] + }; + index += 1; + if !self.prove_one( + address, + MemHelpers::main_step_to_address_step(step, step_offset), + value, + is_write, + width, + extend_values, + ) { + break; + } + } else { + panic!("MemProxy::prove() unexpected Extended variant"); + } + } + self.finish_prove(); + Ok(()) + } + + fn prove_one( + &mut self, + addr: u32, + mem_step: u64, + value: u64, + is_write: bool, + width: u8, + extend_values: [u64; 2], + ) -> bool { + let is_aligned: bool = 
Self::is_aligned(addr, width); + let aligned_mem_addr = Self::to_aligned_addr(addr); + + // Check if there are open mem align operations to be processed in this moment, + // with address (or step) less than the aligned of current + // mem_op. + self.process_all_previous_open_mem_align_ops(aligned_mem_addr, mem_step); + + // check if we are at end of loop + if self.check_if_end_of_memory_mark(addr, mem_step) { + return false; + } + + // all open mem align operations are processed, check if new mem operation is + // aligned + if !is_aligned { + // In this point found non-aligned memory access, phase-0 + let mem_align_input = MemAlignInput { + addr, + value, + width, + mem_values: extend_values, + is_write, + step: mem_step, + }; + let mem_align_response = self.mem_align_sm.get_mem_op(&mem_align_input, 0); + + #[cfg(feature = "debug_mem_proxy_engine")] + Self::debug_mem_align_api(&mem_align_input, &mem_align_response, 0); + + // if operation applies to two consecutive memory addresses, add the second + // part is enqueued to be processed in future when + // processing next address on phase-1 + self.push_mem_align_response_ops( + aligned_mem_addr, + extend_values[0], + &mem_align_input, + &mem_align_response, + ); + if mem_align_response.more_addr { + self.open_mem_align_ops.push_back(mem_align_input); + self.update_next_open_mem_align(); + } + } else { + self.push_aligned_op(is_write, addr, value, mem_step); + } + true + } + + fn update_next_open_mem_align(&mut self) { + if self.open_mem_align_ops.is_empty() { + self.next_open_addr = NO_OPEN_ADDR; + self.next_open_step = NO_OPEN_STEP; + } else if self.open_mem_align_ops.len() == 1 { + let mem_align_input = self.open_mem_align_ops.front().unwrap(); + self.next_open_addr = Self::next_aligned_addr(mem_align_input.addr); + self.next_open_step = mem_align_input.step; + } + } + + fn process_all_previous_open_mem_align_ops(&mut self, mem_addr: u32, mem_step: u64) { + // Two possible situations to process open mem align 
operations: + // + // 1) the address of open operation is less than the aligned address. + // 2) the address of open operation is equal to the aligned address, but the step of the + // open operation is less than the step of the current operation. + + while let Some(open_op) = self.get_next_open_mem_align_input(mem_addr, mem_step) { + // call to mem_align to get information of the aligned memory access needed + // to prove the unaligned open operation. + let mem_align_resp = self.mem_align_sm.get_mem_op(&open_op, 1); + + #[cfg(feature = "debug_mem_proxy_engine")] + Self::debug_mem_align_api(&open_op, &mem_align_resp, 1); + + // push the aligned memory operations for current address (read or read+write) and + // update last_address and last_value. + self.push_mem_align_response_ops( + Self::next_aligned_addr(open_op.addr), + open_op.mem_values[1], + &open_op, + &mem_align_resp, + ); + } + } + + pub fn main_step_to_mem_step(step: u64, step_offset: u8) -> u64 { + 1 + MAX_MEM_OPS_BY_MAIN_STEP * step + 2 * step_offset as u64 + } + + #[inline(always)] + fn is_aligned(address: u32, width: u8) -> bool { + ((address & 0x07) == 0) && (width == 8) + } + + fn push_aligned_op(&mut self, is_write: bool, addr: u32, value: u64, step: u64) { + self.update_mem_module(addr); + let w_addr = Self::to_aligned_word_addr(addr); + + // check if step difference is too large + if self.last_addr == w_addr && (step - self.last_step) > MEMORY_MAX_DIFF { + self.push_intermediate_internal_reads(w_addr, value, self.last_step, step); + } + + self.last_step = step; + self.last_addr = w_addr; + + let mem_op = MemInput { step, is_write, is_internal: false, addr: w_addr, value }; + debug_info!( + "route ==> {}[{:X}] {} {} #{}", + self.current_module, + mem_op.addr << MEM_BYTES_BITS, + if is_write { "W" } else { "R" }, + value, + step, + ); + self.internal_push_mem_op(mem_op); + } + + fn push_intermediate_internal_reads( + &mut self, + addr: u32, + value: u64, + last_step: u64, + final_step: u64, + ) { 
+ let mut step = last_step; + self.intermediate_cases += 1; + while (final_step - step) > MEMORY_MAX_DIFF { + self.intermediate_steps += 1; + step += MEMORY_MAX_DIFF; + let mem_op = MemInput { step, is_write: false, is_internal: true, addr, value }; + self.internal_push_mem_op(mem_op); + } + } + + fn internal_push_mem_op(&mut self, mem_op: MemInput) { + self.modules_data[self.current_module_id].inputs.push(mem_op); + self.check_flush_inputs(); + } + // method to add aligned read operation + #[inline(always)] + fn push_aligned_read(&mut self, addr: u32, value: u64, step: u64) { + self.push_aligned_op(false, addr, value, step); + } + // method to add aligned write operation + #[inline(always)] + fn push_aligned_write(&mut self, addr: u32, value: u64, step: u64) { + self.push_aligned_op(true, addr, value, step); + } + /// Process information of mem_op and mem_align_op to push mem_op operation. Only two possible + /// situations: + /// 1) read, only on single mem_op is pushed + /// 2) read+write, two mem_op are pushed, one read and one write. + /// + /// This process is used for each aligned memory address, means that the "second part" of non + /// aligned memory operation is processed on addr + MEM_BYTES. 
+ fn push_mem_align_response_ops( + &mut self, + mem_addr: u32, + mem_value: u64, + mem_align_input: &MemAlignInput, + mem_align_resp: &MemAlignResponse, + ) { + self.push_aligned_read(mem_addr, mem_value, mem_align_resp.step); + if mem_align_input.is_write { + self.push_aligned_write( + mem_addr, + mem_align_resp.value.unwrap(), + mem_align_resp.step + 1, + ); + } + } + fn set_active_region(&mut self, region_id: usize) { + self.current_module_id = self.addr_map[region_id].module_id as usize; + self.current_module = self.modules_data[self.current_module_id].name.clone(); + self.module_end_addr = self.addr_map[region_id].to_addr; + } + fn update_mem_module_id(&mut self, addr: u32) { + debug_info!("search module for address 0x{:X}", addr); + if let Some(index) = + self.addr_map.iter().position(|x| x.from_addr <= addr && x.to_addr >= addr) + { + self.set_active_region(index); + } else { + panic!("out-of-memory 0x{:X}", addr); + } + } + fn update_mem_module(&mut self, addr: u32) { + // check if need to reevaluate the module id + if addr > self.module_end_addr { + self.update_mem_module_id(addr); + } + } + fn check_flush_inputs(&mut self) { + // check if need to flush the inputs of the module + let mid = self.current_module_id; + let inputs = self.modules_data[mid].inputs.len(); + if inputs >= self.modules_data[mid].flush_input_size { + // TODO: optimize passing ownership of inputs to module, and creating a new input + // object + debug_info!("flush {} inputs => {}", inputs, self.current_module); + self.modules[mid].send_inputs(&self.modules_data[mid].inputs); + self.modules_data[mid].inputs.clear(); + } + } + + fn get_next_open_mem_align_input(&mut self, addr: u32, step: u64) -> Option { + if self.next_open_addr < addr || (self.next_open_addr == addr && self.next_open_step < step) + { + let open_op = self.open_mem_align_ops.pop_front().unwrap(); + self.update_next_open_mem_align(); + Some(open_op) + } else { + None + } + } + // method to process open mem align 
operations, second part of non aligned memory operations + // applies to two consecutive memory addresses. + + fn push_end_of_memory_mark(&mut self, mem_operations: &mut Vec) { + mem_operations.push(ZiskRequiredMemory::Basic { + step: MAX_MAIN_STEP, + step_offset: MAX_MEM_STEP_OFFSET as u8, + is_write: false, + address: MAX_MEM_ADDR as u32, + width: MEM_BYTES as u8, + value: 0, + }); + mem_operations + .push(ZiskRequiredMemory::Extended { address: MAX_MEM_ADDR as u32, values: [0, 0] }); + } + + /// Check if the address is the "special" address inserted at the end of the memory operations + #[inline(always)] + fn check_if_end_of_memory_mark(&self, addr: u32, mem_step: u64) -> bool { + if addr == MAX_MEM_ADDR as u32 && mem_step == MAX_MEM_STEP { + debug_assert!( + self.open_mem_align_ops.is_empty(), + "open_mem_align_ops not empty, has {} elements", + self.open_mem_align_ops.len() + ); + true + } else { + false + } + } + /// Encapsulates all tasks to be performed at the beginning of the witness computation (stage + /// 1). + /// + /// This method fetches the address map and sets the initial values to prepare for the + /// computation. + fn init_prove(&mut self) { + if !self.addr_map_fetched { + self.fetch_address_map(); + } + self.current_module_id = self.addr_map[0].module_id as usize; + self.current_module = self.modules_data[self.current_module_id].name.clone(); + self.module_end_addr = self.addr_map[0].to_addr; + } + /// Encapsulates all tasks to be performed at the end of the witness computation (stage 1). + /// + /// This method flushes all module inputs to ensure they are finalized and ready for further + /// processing. 
+ fn finish_prove(&self) { + for (module_id, module) in self.modules.iter().enumerate() { + debug_info!( + "{}: flush all({}) inputs", + self.modules_data[module_id].name, + self.modules_data[module_id].inputs.len() + ); + module.send_inputs(&self.modules_data[module_id].inputs); + } + info!( + "MemProxy: ยทยทยท Intermediate reads [cases:{} steps:{}]", + self.intermediate_cases, self.intermediate_steps + ); + } + /// Fetches the address map, defining and calculating all necessary structures to manage the + /// memory map. + /// + /// For undefined regions (such as memory between defined regions, or memory at the beginning or + /// end of the memory map), this method assigns an unmapped module. If any access occurs + /// within these unmapped memory regions, the method will trigger a panic. + /// + /// The unmapped module ensures that every address has an associated module to handle memory + /// access, providing a safety mechanism to prevent undefined behavior. + fn fetch_address_map(&mut self) { + let unmapped_regions: Vec<(u32, u32)> = self.get_unmapped_regions(); + if !unmapped_regions.is_empty() { + self.define_unmapped_module(&unmapped_regions); + } + self.addr_map_fetched = true; + } + + /// Get list of regions (from_addr, to_addr) that are not defined in the memory map + fn get_unmapped_regions(&self) -> Vec<(u32, u32)> { + let mut next_addr = 0; + let mut unmapped_regions: Vec<(u32, u32)> = Vec::new(); + for addr_region in self.addr_map.iter() { + if next_addr < addr_region.from_addr { + unmapped_regions.push((next_addr, addr_region.from_addr - 1)); + } + next_addr = addr_region.to_addr + 1; + } + unmapped_regions + } + + /// Define an unmapped module with all unmapped regions. 
+ fn define_unmapped_module(&mut self, unmapped_regions: &[(u32, u32)]) { + let mut unmapped_module = MemUnmapped::::new(); + for unmapped_region in unmapped_regions.iter() { + unmapped_module.add_range(unmapped_region.0, unmapped_region.1); + } + self.add_module("unmapped", Arc::new(unmapped_module)); + } + + /// Calculate aligned address from regular address (aligned or not) + #[inline(always)] + fn to_aligned_addr(addr: u32) -> u32 { + addr & MEM_ADDR_MASK + } + + /// Calculate the next aligned address from regular address (aligned or not) + #[inline(always)] + fn next_aligned_addr(addr: u32) -> u32 { + (addr & MEM_ADDR_MASK) + MEM_BYTES + } + + /// Calculate the word address where word is MEM_BYTES + #[inline(always)] + fn to_aligned_word_addr(addr: u32) -> u32 { + addr >> MEM_BYTES_BITS + } + + #[cfg(feature = "debug_mem_proxy_engine")] + fn debug_mem_align_api( + mem_align_input: &MemAlignInput, + mem_align_response: &MemAlignResponse, + phase: u8, + ) { + if mem_align_input.addr >= DEBUG_ADDR - 8 && mem_align_input.addr <= DEBUG_ADDR + 8 { + debug_info!( + "mem_align_input_{:X}: phase:{} {:?}", + mem_align_input.addr, + phase, + mem_align_input + ); + debug_info!( + "mem_align_response_{:X}: phase:{} {:?}", + mem_align_input.addr, + phase, + mem_align_response + ); + } + } +} diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs new file mode 100644 index 00000000..051277e6 --- /dev/null +++ b/state-machines/mem/src/mem_sm.rs @@ -0,0 +1,383 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use crate::{MemInput, MemModule, MEMORY_MAX_DIFF, MEM_BYTES_BITS}; +use num_bigint::BigInt; +use p3_field::PrimeField; +use pil_std_lib::Std; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; + +use zisk_core::{RAM_ADDR, RAM_SIZE}; +use zisk_pil::{MemTrace, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; + +const RAM_W_ADDR_INIT: u32 = RAM_ADDR as u32 >> MEM_BYTES_BITS; +const RAM_W_ADDR_END: u32 = 
(RAM_ADDR + RAM_SIZE - 1) as u32 >> MEM_BYTES_BITS; + +const _: () = { + assert!((RAM_SIZE - 1) >> MEM_BYTES_BITS <= MEMORY_MAX_DIFF, "RAM is too large"); + assert!( + (RAM_ADDR + RAM_SIZE - 1) <= 0xFFFF_FFFF, + "RAM memory exceeds the 32-bit addressable range" + ); +}; + +pub struct MemSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, +} + +#[derive(Default)] +pub struct MemAirValues { + pub segment_id: u32, + pub is_first_segment: bool, + pub is_last_segment: bool, + pub previous_segment_addr: u32, + pub previous_segment_step: u64, + pub previous_segment_value: [u32; 2], + pub segment_last_addr: u32, + pub segment_last_step: u64, + pub segment_last_value: [u32; 2], +} +#[derive(Debug)] +pub struct MemPreviousSegment { + pub addr: u32, + pub step: u64, + pub value: u64, +} + +#[allow(unused, unused_variables)] +impl MemSM { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + let mem_sm = + Self { wcm: wcm.clone(), std: std.clone(), registered_predecessors: AtomicU32::new(0) }; + let mem_sm = Arc::new(mem_sm); + + wcm.register_component(mem_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MEM_AIR_IDS)); + std.register_predecessor(); + + mem_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + self.std.unregister_predecessor(pctx, None); + } + } + + pub fn prove(&self, inputs: &[MemInput]) { + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); + + // PRE: proxy calculate if exists jmp on step out-of-range, adding internal inputs + // memory only need to process these special inputs, but inputs no change. 
At end of + // inputs proxy add an extra internal input to jump to last address + + let air_id = MEM_AIR_IDS[0]; + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, air_id); + let air_rows = air.num_rows(); + + // at least one row to go + let count = inputs.len(); + let count_rem = count % air_rows; + let num_segments = (count / air_rows) + if count_rem > 0 { 1 } else { 0 }; + + let mut prover_buffers = Mutex::new(vec![Vec::new(); num_segments]); + let mut global_idxs = vec![0; num_segments]; + + #[allow(clippy::needless_range_loop)] + for i in 0..num_segments { + // TODO: Review + if let (true, global_idx) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, air_id, 1) + { + let trace: MemTrace<'_, _> = MemTrace::new(air_rows); + let mut buffer = trace.buffer.unwrap(); + + prover_buffers.lock().unwrap()[i] = buffer; + global_idxs[i] = global_idx; + } + } + + #[allow(clippy::needless_range_loop)] + for segment_id in 0..num_segments { + let is_last_segment = segment_id == num_segments - 1; + let input_offset = segment_id * air_rows; + let previous_segment = if (segment_id == 0) { + MemPreviousSegment { addr: RAM_W_ADDR_INIT, step: 0, value: 0 } + } else { + MemPreviousSegment { + addr: inputs[input_offset - 1].addr, + step: inputs[input_offset - 1].step, + value: inputs[input_offset - 1].value, + } + }; + let input_end = + if (input_offset + air_rows) > count { count } else { input_offset + air_rows }; + let mem_ops = &inputs[input_offset..input_end]; + let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + + self.prove_instance( + mem_ops, + segment_id, + is_last_segment, + &previous_segment, + prover_buffer, + air_rows, + global_idxs[segment_id], + ); + } + } + + /// Finalizes the witness accumulation process and triggers the proof generation. + /// + /// This method is invoked by the executor when no further witness data remains to be added. 
+ /// + /// # Parameters + /// + /// - `mem_inputs`: A slice of all `MemoryInput` inputs + #[allow(clippy::too_many_arguments)] + pub fn prove_instance( + &self, + mem_ops: &[MemInput], + segment_id: usize, + is_last_segment: bool, + previous_segment: &MemPreviousSegment, + mut prover_buffer: Vec, + air_mem_rows: usize, + global_idx: usize, + ) -> Result<(), Box> { + assert!( + !mem_ops.is_empty() && mem_ops.len() <= air_mem_rows, + "MemSM: mem_ops.len()={} out of range {}", + mem_ops.len(), + air_mem_rows + ); + + // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR + // segments In a Memory AIR instance, the first row is reserved as a dummy row. + // This dummy row is used to facilitate the continuation state between different AIR + // segments. It ensures seamless transitions when multiple AIR segments are + // processed consecutively. This design avoids discontinuities in memory access + // patterns and ensures that the memory trace is continuous, For this reason we use + // AIR num_rows - 1 as the number of rows in each memory AIR instance + + // Create a vector of Mem0Row instances, one for each memory operation + // Recall that first row is a dummy row used for the continuations between AIR segments + // The length of the vector is the number of input memory operations plus one because + // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows + + let mut trace = MemTrace::::map_buffer(&mut prover_buffer, air_mem_rows, 0).unwrap(); + + let mut range_check_data: Vec = vec![0; MEMORY_MAX_DIFF as usize]; + + let mut air_values = MemAirValues { + segment_id: segment_id as u32, + is_first_segment: segment_id == 0, + is_last_segment, + previous_segment_addr: previous_segment.addr, + previous_segment_step: previous_segment.step, + previous_segment_value: [ + previous_segment.value as u32, + (previous_segment.value >> 32) as u32, + ], + ..MemAirValues::default() + }; + + // index it's value - 1, 
for this reason no add +1 + range_check_data[(previous_segment.addr - RAM_W_ADDR_INIT) as usize] += 1; // TODO + + // Fill the remaining rows + let mut last_addr: u32 = previous_segment.addr; + let mut last_step: u64 = previous_segment.step; + let mut last_value: u64 = previous_segment.value; + + for (i, mem_op) in mem_ops.iter().enumerate() { + trace[i].addr = F::from_canonical_u32(mem_op.addr); + trace[i].step = F::from_canonical_u64(mem_op.step); + trace[i].sel = F::from_bool(!mem_op.is_internal); + trace[i].wr = F::from_bool(mem_op.is_write); + + let (low_val, high_val) = self.get_u32_values(mem_op.value); + trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + + let addr_changes = last_addr != mem_op.addr; + trace[i].addr_changes = if addr_changes { F::one() } else { F::zero() }; + + let increment = if addr_changes { + // (mem_op.addr - last_addr + if i == 0 && segment_id == 0 { 1 } else { 0 }) as u64 + (mem_op.addr - last_addr) as u64 + } else { + mem_op.step - last_step + }; + trace[i].increment = F::from_canonical_u64(increment); + + // Store the value of incremenet so it can be range checked + if increment <= MEMORY_MAX_DIFF || increment == 0 { + range_check_data[(increment - 1) as usize] += 1; + } else { + panic!("MemSM: increment's out of range: {} i:{} addr_changes:{} mem_op.addr:0x{:X} last_addr:0x{:X} mem_op.step:{} last_step:{}", + increment, i, addr_changes as u8, mem_op.addr, last_addr, mem_op.step, last_step); + } + + last_addr = mem_op.addr; + last_step = mem_op.step; + last_value = mem_op.value; + } + + // STEP3. 
Add dummy rows to the output vector to fill the remaining rows + // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd + // = 1, wr = 0 + let last_row_idx = mem_ops.len() - 1; + let addr = trace[last_row_idx].addr; + let value = trace[last_row_idx].value; + + let padding_size = air_mem_rows - mem_ops.len(); + for i in mem_ops.len()..air_mem_rows { + last_step += 1; + trace[i].addr = addr; + trace[i].step = F::from_canonical_u64(last_step); + trace[i].sel = F::zero(); + trace[i].wr = F::zero(); + + trace[i].value = value; + + trace[i].addr_changes = F::zero(); + trace[i].increment = F::one(); + } + + air_values.segment_last_addr = last_addr; + air_values.segment_last_step = last_step; + air_values.segment_last_value[0] = last_value as u32; + air_values.segment_last_value[1] = (last_value >> 32) as u32; + + // Store the value of trivial increment so that they can be range checked + // value = 1 => index = 0 + range_check_data[0] += padding_size as u64; + + // no add extra +1 because index = value - 1 + // RAM_W_ADDR_END - last_addr + 1 - 1 = RAM_W_ADDR_END - last_addr + range_check_data[(RAM_W_ADDR_END - last_addr) as usize] += 1; // TODO + + // TODO: Perform the range checks + let range_id = self.std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); + for (value, &multiplicity) in range_check_data.iter().enumerate() { + if (multiplicity == 0) { + continue; + } + self.std.range_check( + F::from_canonical_usize(value + 1), + F::from_canonical_u64(multiplicity), + range_id, + ); + } + + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let sctx = wcm.get_sctx(); + + let mut air_instance = AirInstance::new( + sctx.clone(), + ZISK_AIRGROUP_ID, + MEM_AIR_IDS[0], + Some(segment_id), + prover_buffer, + ); + + self.set_airvalues("Mem", &mut air_instance, &air_values); + + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); + + Ok(()) + } + + fn get_u32_values(&self, value: u64) -> (u32, u32) 
{ + (value as u32, (value >> 32) as u32) + } + fn set_airvalues( + &self, + prefix: &str, + air_instance: &mut AirInstance, + air_values: &MemAirValues, + ) { + air_instance.set_airvalue( + format!("{}.segment_id", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_id), + ); + air_instance.set_airvalue( + format!("{}.is_first_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_first_segment), + ); + air_instance.set_airvalue( + format!("{}.is_last_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_last_segment), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.previous_segment_addr), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.previous_segment_step), + ); + air_instance.set_airvalue( + format!("{}.segment_last_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_last_addr), + ); + air_instance.set_airvalue( + format!("{}.segment_last_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.segment_last_step), + ); + let count = air_values.previous_segment_value.len(); + for i in 0..count { + air_instance.set_airvalue( + format!("{}.previous_segment_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.previous_segment_value[i]), + ); + air_instance.set_airvalue( + format!("{}.segment_last_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.segment_last_value[i]), + ); + } + } +} + +impl MemModule for MemSM { + fn send_inputs(&self, mem_op: &[MemInput]) { + self.prove(mem_op); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + vec![(RAM_ADDR as u32, (RAM_ADDR + RAM_SIZE - 1) as u32)] + } + fn get_flush_input_size(&self) -> u32 { + 0 + } +} + +impl WitnessComponent for MemSM {} diff --git a/state-machines/mem/src/mem_traces.rs 
b/state-machines/mem/src/mem_traces.rs deleted file mode 100644 index c80a8c74..00000000 --- a/state-machines/mem/src/mem_traces.rs +++ /dev/null @@ -1,5 +0,0 @@ -use proofman_common as common; -pub use proofman_macros::trace; - -trace!(MemALignedRow, MemALignedTrace { fake: F }); -trace!(MemUnaLignedRow, MemUnaLignedTrace { fake: F}); diff --git a/state-machines/mem/src/mem_unaligned.rs b/state-machines/mem/src/mem_unaligned.rs deleted file mode 100644 index fde238e3..00000000 --- a/state-machines/mem/src/mem_unaligned.rs +++ /dev/null @@ -1,114 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{MemUnalignedOp, OpResult, Provable}; -use zisk_pil::{MEM_AIRGROUP_ID, MEM_UNALIGNED_AIR_IDS}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct MemUnalignedSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, -} - -#[allow(unused, unused_variables)] -impl MemUnalignedSM { - pub fn new(wcm: Arc>) -> Arc { - let mem_aligned_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let mem_aligned_sm = Arc::new(mem_aligned_sm); - - wcm.register_component( - mem_aligned_sm.clone(), - Some(MEM_AIRGROUP_ID), - Some(MEM_UNALIGNED_AIR_IDS), - ); - - mem_aligned_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - } - } - - fn read( - &self, - _addr: u64, - _width: usize, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } - - fn write( - &self, - _addr: u64, - _width: usize, - _val: u64, /* , _ctx: &mut ProofCtx, _ectx: 
&ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } -} - -impl WitnessComponent for MemUnalignedSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for MemUnalignedSM { - fn calculate(&self, operation: MemUnalignedOp) -> Result> { - match operation { - MemUnalignedOp::Read(addr, width) => self.read(addr, width), - MemUnalignedOp::Write(addr, width, val) => self.write(addr, width, val), - } - } - - fn prove(&self, operations: &[MemUnalignedOp], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: MemUnalignedOp, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/mem/src/mem_unmapped.rs b/state-machines/mem/src/mem_unmapped.rs new file mode 100644 index 00000000..2ec61685 --- /dev/null +++ b/state-machines/mem/src/mem_unmapped.rs @@ -0,0 +1,35 @@ +use std::marker::PhantomData; + +use crate::{MemInput, MemModule}; +use p3_field::PrimeField; + +pub struct MemUnmapped { + ranges: Vec<(u32, u32)>, + __data: PhantomData, +} + +impl Default for MemUnmapped { + fn default() -> Self { + Self::new() + } +} + +impl MemUnmapped { + pub fn new() -> Self { + Self { ranges: Vec::new(), __data: PhantomData } + } + pub fn add_range(&mut self, _start: u32, _end: u32) { + self.ranges.push((_start, _end)); + } +} +impl MemModule for MemUnmapped { + fn send_inputs(&self, _mem_op: &[MemInput]) { + // panic!("[MemUnmapped] invalid 
access to addr {:x}", _mem_op[0].addr); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + self.ranges.to_vec() + } + fn get_flush_input_size(&self) -> u32 { + 1 + } +} diff --git a/state-machines/mem/src/rom_data.rs b/state-machines/mem/src/rom_data.rs new file mode 100644 index 00000000..57243430 --- /dev/null +++ b/state-machines/mem/src/rom_data.rs @@ -0,0 +1,339 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use crate::{ + MemAirValues, MemInput, MemModule, MemPreviousSegment, MEMORY_MAX_DIFF, MEM_BYTES_BITS, +}; +use num_bigint::BigInt; +use p3_field::PrimeField; +use pil_std_lib::Std; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; +use zisk_core::{ROM_ADDR, ROM_ADDR_MAX}; +use zisk_pil::{RomDataTrace, ROM_DATA_AIR_IDS, ZISK_AIRGROUP_ID}; + +const ROM_W_ADDR: u32 = ROM_ADDR as u32 >> MEM_BYTES_BITS; +const ROM_W_ADDR_END: u32 = ROM_ADDR_MAX as u32 >> MEM_BYTES_BITS; + +const _: () = { + assert!( + (ROM_ADDR_MAX - ROM_ADDR) >> MEM_BYTES_BITS as u64 <= MEMORY_MAX_DIFF, + "ROM_DATA is too large" + ); + assert!(ROM_ADDR_MAX <= 0xFFFF_FFFF, "ROM_DATA memory exceeds the 32-bit addressable range"); +}; + +pub struct RomDataSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, + + num_rows: usize, + // Count of registered predecessors + registered_predecessors: AtomicU32, +} + +#[allow(unused, unused_variables)] +impl RomDataSM { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, ROM_DATA_AIR_IDS[0]); + let rom_data_sm = Self { + wcm: wcm.clone(), + std: std.clone(), + num_rows: air.num_rows(), + registered_predecessors: AtomicU32::new(0), + }; + let rom_data_sm = Arc::new(rom_data_sm); + + wcm.register_component(rom_data_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(ROM_DATA_AIR_IDS)); + std.register_predecessor(); + + rom_data_sm + } + + pub fn register_predecessor(&self) { + 
self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + self.std.unregister_predecessor(pctx, None); + } + } + + pub fn prove(&self, inputs: &[MemInput]) { + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); + + // PRE: proxy calculate if exists jmp on step out-of-range, adding internal inputs + // memory only need to process these special inputs, but inputs no change. At end of + // inputs proxy add an extra internal input to jump to last address + + let air_id = ROM_DATA_AIR_IDS[0]; + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, air_id); + let air_rows = air.num_rows(); + + // at least one row to go + let count = inputs.len(); + let count_rem = count % air_rows; + let num_segments = (count / air_rows) + if count_rem > 0 { 1 } else { 0 }; + + let mut prover_buffers = Mutex::new(vec![Vec::new(); num_segments]); + let mut global_idxs = vec![0; num_segments]; + + #[allow(clippy::needless_range_loop)] + for i in 0..num_segments { + // TODO: Review + if let (true, global_idx) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, air_id, 1) + { + let trace: RomDataTrace<'_, _> = RomDataTrace::new(air_rows); + let mut buffer = trace.buffer.unwrap(); + prover_buffers.lock().unwrap()[i] = buffer; + global_idxs[i] = global_idx; + } + } + + #[allow(clippy::needless_range_loop)] + for segment_id in 0..num_segments { + let is_last_segment = segment_id == num_segments - 1; + let input_offset = segment_id * air_rows; + let previous_segment = if (segment_id == 0) { + MemPreviousSegment { addr: ROM_W_ADDR, step: 0, value: 0 } + } else { + MemPreviousSegment { + addr: inputs[input_offset - 1].addr, + step: inputs[input_offset - 1].step, + value: inputs[input_offset - 1].value, + } + }; + let input_end = + if (input_offset + air_rows) > count { count 
} else { input_offset + air_rows }; + let mem_ops = &inputs[input_offset..input_end]; + let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + + self.prove_instance( + mem_ops, + segment_id, + is_last_segment, + &previous_segment, + prover_buffer, + air_rows, + global_idxs[segment_id], + ); + } + } + + /// Finalizes the witness accumulation process and triggers the proof generation. + /// + /// This method is invoked by the executor when no further witness data remains to be added. + /// + /// # Parameters + /// + /// - `mem_inputs`: A slice of all `MemoryInput` inputs + #[allow(clippy::too_many_arguments)] + pub fn prove_instance( + &self, + mem_ops: &[MemInput], + segment_id: usize, + is_last_segment: bool, + previous_segment: &MemPreviousSegment, + mut prover_buffer: Vec, + air_mem_rows: usize, + global_idx: usize, + ) -> Result<(), Box> { + assert!( + !mem_ops.is_empty() && mem_ops.len() <= air_mem_rows, + "RomDataSM: mem_ops.len()={} out of range {}", + mem_ops.len(), + air_mem_rows + ); + + // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR + // segments In a Memory AIR instance, the first row is reserved as a dummy row. + // This dummy row is used to facilitate the continuation state between different AIR + // segments. It ensures seamless transitions when multiple AIR segments are + // processed consecutively. 
This design avoids discontinuities in memory access + // patterns and ensures that the memory trace is continuous, For this reason we use + // AIR num_rows - 1 as the number of rows in each memory AIR instance + + // Create a vector of Mem0Row instances, one for each memory operation + // Recall that first row is a dummy row used for the continuations between AIR segments + // The length of the vector is the number of input memory operations plus one because + // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows + + let mut trace = RomDataTrace::::map_buffer(&mut prover_buffer, air_mem_rows, 0).unwrap(); + + let mut air_values = MemAirValues { + segment_id: segment_id as u32, + is_first_segment: segment_id == 0, + is_last_segment, + previous_segment_addr: previous_segment.addr, + previous_segment_step: previous_segment.step, + previous_segment_value: [ + previous_segment.value as u32, + (previous_segment.value >> 32) as u32, + ], + ..MemAirValues::default() + }; + + // range of instance + let range_id = self.std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); + self.std.range_check( + F::from_canonical_u32(previous_segment.addr - ROM_W_ADDR + 1), + F::one(), + range_id, + ); + + // Fill the remaining rows + let mut last_addr: u32 = previous_segment.addr; + let mut last_step: u64 = previous_segment.step; + let mut last_value: u64 = previous_segment.value; + + for (i, mem_op) in mem_ops.iter().enumerate() { + trace[i].addr = F::from_canonical_u32(mem_op.addr); + trace[i].step = F::from_canonical_u64(mem_op.step); + trace[i].sel = F::from_bool(!mem_op.is_internal); + + let (low_val, high_val) = self.get_u32_values(mem_op.value); + trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + + let addr_changes = last_addr != mem_op.addr; + trace[i].addr_changes = + if addr_changes || (i == 0 && segment_id == 0) { F::one() } else { F::zero() }; + + last_addr = mem_op.addr; + last_step = 
mem_op.step; + last_value = mem_op.value; + } + + // STEP3. Add dummy rows to the output vector to fill the remaining rows + // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd + // = 1, wr = 0 + let last_row_idx = mem_ops.len() - 1; + let addr = trace[last_row_idx].addr; + let value = trace[last_row_idx].value; + + let padding_size = air_mem_rows - mem_ops.len(); + for i in mem_ops.len()..air_mem_rows { + last_step += 1; + trace[i].addr = addr; + trace[i].step = F::from_canonical_u64(last_step); + trace[i].sel = F::zero(); + + trace[i].value = value; + + trace[i].addr_changes = F::zero(); + } + + air_values.segment_last_addr = last_addr; + air_values.segment_last_step = last_step; + air_values.segment_last_value[0] = last_value as u32; + air_values.segment_last_value[1] = (last_value >> 32) as u32; + + self.std.range_check( + F::from_canonical_u32(ROM_W_ADDR_END - last_addr + 1), + F::one(), + range_id, + ); + + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let sctx = wcm.get_sctx(); + + let mut air_instance = AirInstance::new( + sctx.clone(), + ZISK_AIRGROUP_ID, + ROM_DATA_AIR_IDS[0], + Some(segment_id), + prover_buffer, + ); + + self.set_airvalues("RomData", &mut air_instance, &air_values); + + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); + + Ok(()) + } + + fn get_u32_values(&self, value: u64) -> (u32, u32) { + (value as u32, (value >> 32) as u32) + } + fn set_airvalues( + &self, + prefix: &str, + air_instance: &mut AirInstance, + air_values: &MemAirValues, + ) { + air_instance.set_airvalue( + format!("{}.segment_id", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_id), + ); + air_instance.set_airvalue( + format!("{}.is_first_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_first_segment), + ); + air_instance.set_airvalue( + format!("{}.is_last_segment", prefix).as_str(), + None, + F::from_bool(air_values.is_last_segment), + ); + 
air_instance.set_airvalue( + format!("{}.previous_segment_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.previous_segment_addr), + ); + air_instance.set_airvalue( + format!("{}.previous_segment_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.previous_segment_step), + ); + air_instance.set_airvalue( + format!("{}.segment_last_addr", prefix).as_str(), + None, + F::from_canonical_u32(air_values.segment_last_addr), + ); + air_instance.set_airvalue( + format!("{}.segment_last_step", prefix).as_str(), + None, + F::from_canonical_u64(air_values.segment_last_step), + ); + let count = air_values.previous_segment_value.len(); + for i in 0..count { + air_instance.set_airvalue( + format!("{}.previous_segment_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.previous_segment_value[i]), + ); + air_instance.set_airvalue( + format!("{}.segment_last_value", prefix).as_str(), + Some(vec![i as u64]), + F::from_canonical_u32(air_values.segment_last_value[i]), + ); + } + } +} + +impl MemModule for RomDataSM { + fn send_inputs(&self, mem_op: &[MemInput]) { + self.prove(mem_op); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + vec![(ROM_ADDR as u32, ROM_ADDR_MAX as u32)] + } + fn get_flush_input_size(&self) -> u32 { + self.num_rows as u32 + } +} + +impl WitnessComponent for RomDataSM {} diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 8d89c2dc..52a2b591 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -9,16 +9,18 @@ use rayon::prelude::*; use sm_arith::ArithSM; use sm_binary::BinarySM; use sm_main::{InstanceExtensionCtx, MainSM}; -use sm_mem::MemSM; +use sm_mem::MemProxy; use sm_rom::RomSM; use std::{ fs, path::{Path, PathBuf}, sync::Arc, + thread, }; use zisk_core::{Riscv2zisk, ZiskOperationType, ZiskRom, ZISK_OPERATION_TYPE_VARIANTS}; use zisk_pil::{ - ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, 
MAIN_AIR_IDS, ZISK_AIRGROUP_ID, + ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ROM_AIR_IDS, + ZISK_AIRGROUP_ID, }; use ziskemu::{EmuOptions, ZiskEmulator}; @@ -33,7 +35,7 @@ pub struct ZiskExecutor { pub rom_sm: Arc>, /// Memory State Machine - pub mem_sm: Arc, + pub mem_proxy_sm: Arc>, /// Binary State Machine pub binary_sm: Arc>, @@ -49,7 +51,7 @@ impl ZiskExecutor { let std = Std::new(wcm.clone()); let rom_sm = RomSM::new(wcm.clone()); - let mem_sm = MemSM::new(wcm.clone()); + let mem_proxy_sm = MemProxy::new(wcm.clone(), std.clone()); let binary_sm = BinarySM::new(wcm.clone(), std.clone()); let arith_sm = ArithSM::new(wcm.clone(), binary_sm.clone()); @@ -81,9 +83,10 @@ impl ZiskExecutor { // TODO - If there is more than one Main AIR available, the MAX_ACCUMULATED will be the one // with the highest num_rows. It has to be a power of 2. - let main_sm = MainSM::new(wcm.clone(), arith_sm.clone(), binary_sm.clone(), mem_sm.clone()); + let main_sm = + MainSM::new(wcm.clone(), mem_proxy_sm.clone(), arith_sm.clone(), binary_sm.clone()); - Self { zisk_rom, main_sm, rom_sm, mem_sm, binary_sm, arith_sm } + Self { zisk_rom, main_sm, rom_sm, mem_proxy_sm, binary_sm, arith_sm } } /// Executes the MainSM state machine and processes the inputs in batches when the maximum @@ -118,6 +121,7 @@ impl ZiskExecutor { let path = PathBuf::from(public_inputs_path.display().to_string()); fs::read(path).expect("Could not read inputs file") }; + let public_inputs = Arc::new(public_inputs); // During ROM processing, we gather execution data necessary for creating the AIR instances. // This data is collected by the emulator and includes the minimal execution trace, @@ -137,17 +141,36 @@ impl ZiskExecutor { op_sizes[ZiskOperationType::Binary as usize] = air_binary.num_rows() as u64; op_sizes[ZiskOperationType::BinaryE as usize] = air_binary_e.num_rows() as u64; + // STEP 1. 
Generate all inputs + // ============================================== + + // Memory State Machine + // ---------------------------------------------- + let mem_thread = thread::spawn({ + let zisk_rom = self.zisk_rom.clone(); + let public_inputs = public_inputs.clone(); + move || { + ZiskEmulator::par_process_rom_memory::(&zisk_rom, &public_inputs) + .expect("Failed in ZiskEmulator::par_process_rom_memory") + } + }); + // ROM State Machine // ---------------------------------------------- // Run the ROM to compute the ROM witness - let rom_sm = self.rom_sm.clone(); - let zisk_rom = self.zisk_rom.clone(); - let pc_histogram = - ZiskEmulator::process_rom_pc_histogram(&self.zisk_rom, &public_inputs, &emu_options) - .expect( - "MainSM::execute() failed calling ZiskEmulator::process_rom_pc_histogram()", - ); - let handle_rom = std::thread::spawn(move || rom_sm.prove(&zisk_rom, pc_histogram)); + let rom_thread = thread::spawn({ + let zisk_rom = self.zisk_rom.clone(); + let public_inputs = public_inputs.clone(); + let emu_options_cloned = emu_options.clone(); + move || { + ZiskEmulator::process_rom_pc_histogram( + &zisk_rom, + &public_inputs, + &emu_options_cloned, + ) + .expect("MainSM::execute() failed calling ZiskEmulator::process_rom_pc_histogram()") + } + }); // Main, Binary and Arith State Machines // ---------------------------------------------- @@ -164,10 +187,43 @@ impl ZiskExecutor { .expect("Error during emulator execution"); timer_stop_and_log_debug!(PAR_PROCESS_ROM); - emu_slices.points.sort_by(|a, b| a.op_type.partial_cmp(&b.op_type).unwrap()); + // STEP 2. Wait until all inputs are generated + // ============================================== + // Join all the threads to synchronize the execution + let mut mem_required = mem_thread.join().expect("Error during Memory witness computation"); + let rom_required = rom_thread.join().expect("Error during ROM witness computation"); + + // STEP 3. 
Generate AIRs and Prove + // ============================================== - // Join threads to synchronize the execution - handle_rom.join().unwrap().expect("Error during ROM witness computation"); + // Memory State Machine + // ---------------------------------------------- + let mem_thread = thread::spawn({ + let mem_proxy_sm = self.mem_proxy_sm.clone(); + move || { + mem_proxy_sm + .prove(&mut mem_required) + .expect("Error during Memory witness computation") + } + }); + + // ROM State Machine + // ---------------------------------------------- + let (rom_is_mine, _rom_instance_gid) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0], 1); + + let rom_thread = if rom_is_mine { + let rom_sm = self.rom_sm.clone(); + let zisk_rom = self.zisk_rom.clone(); + + Some(thread::spawn(move || rom_sm.prove(&zisk_rom, rom_required))) + } else { + None + }; + + // Main, Binary and Arith State Machines + // ---------------------------------------------- + emu_slices.points.sort_by(|a, b| a.op_type.partial_cmp(&b.op_type).unwrap()); // FIXME: Move InstanceExtensionCtx form main SM to another place let mut instances_extension_ctx: Vec> = @@ -232,7 +288,28 @@ impl ZiskExecutor { } timer_stop_and_log_debug!(ADD_INSTANCES_TO_THE_REPO); - // self.mem_sm.unregister_predecessor(scope); + mem_thread.join().expect("Error during Memory witness computation"); + + // match mem_thread.join() { + // Ok(_) => println!("El thread ha finalitzat correctament."), + // Err(e) => { + // println!("El thread ha fet panic!"); + // + // // Converteix l'error en una cadena llegible (opcional) + // if let Some(missatge) = e.downcast_ref::<&str>() { + // println!("Missatge d'error: {}", missatge); + // } else if let Some(missatge) = e.downcast_ref::() { + // println!("Missatge d'error: {}", missatge); + // } else { + // println!("No es pot determinar el tipus d'error."); + // } + // } + // } + if let Some(thread) = rom_thread { + let _ = thread.join().expect("Error during ROM 
witness computation"); + } + + self.mem_proxy_sm.unregister_predecessor(); self.binary_sm.unregister_predecessor(); self.arith_sm.unregister_predecessor(); }