From f7b4c2bdc43b1ea2c6bbbd11146d8760fabc0888 Mon Sep 17 00:00:00 2001
From: Jack O'Connor
Date: Mon, 18 Sep 2023 08:09:33 -0700
Subject: [PATCH] riscv universal_hash passing all tests

---
 rust/guts/src/riscv64gcv.S | 134 +++++++++++++++++++++++++++++--------
 1 file changed, 105 insertions(+), 29 deletions(-)

diff --git a/rust/guts/src/riscv64gcv.S b/rust/guts/src/riscv64gcv.S
index 2682619b9..424b95a90 100644
--- a/rust/guts/src/riscv64gcv.S
+++ b/rust/guts/src/riscv64gcv.S
@@ -1618,17 +1618,7 @@ blake3_guts_riscv64gcv_xof_xor_partial_block:
 .global blake3_guts_riscv64gcv_universal_hash
 blake3_guts_riscv64gcv_universal_hash:
 	// t0 := full_blocks := input_len / 64
-	// TODO: handle the partial block at the end
 	srli t0, a1, 6
-	// Load the counter.
-	vsetvli zero, t0, e64, m2, ta, ma
-	vmv.v.x v8, a3
-	vid.v v10
-	vadd.vv v8, v8, v10
-	vsetvli zero, t0, e32, m1, ta, ma
-	vncvt.x.x.w v12, v8
-	li t1, 32
-	vnsrl.wx v13, v8, t1
 	// Load and transpose full message blocks. These are "strided segment
 	// loads". Each vlsseg8e32 instruction transposes 8 words from multiple
 	// message blocks into 8 registers, so we need two vlsseg8e32
@@ -1639,30 +1629,44 @@ blake3_guts_riscv64gcv_universal_hash:
 	// RISC-V ABI allows misaligned loads and stores. If we need to support
 	// an environment that doesn't allow them (or where they're
 	// unacceptably slow), we could add a fallback here.
+	vsetvli zero, t0, e32, m1, ta, ma
 	li t1, 64
 	addi t2, a0, 32
 	vlsseg8e32.v v16, (a0), t1
 	vlsseg8e32.v v24, (t2), t1
-	// Broadcast the key to v0-7.
-	lw t0, 0(a2)
-	vmv.v.x v0, t0
-	lw t0, 4(a2)
-	vmv.v.x v1, t0
-	lw t0, 8(a2)
-	vmv.v.x v2, t0
-	lw t0, 12(a2)
-	vmv.v.x v3, t0
-	lw t0, 16(a2)
-	vmv.v.x v4, t0
-	lw t0, 20(a2)
-	vmv.v.x v5, t0
-	lw t0, 24(a2)
-	vmv.v.x v6, t0
-	lw t0, 28(a2)
-	vmv.v.x v7, t0
 	// Broadcast the block length.
 	li t1, 64
 	vmv.v.x v14, t1
+	// If there's a partial block, handle it in an out-of-line branch.
+	andi t1, a1, 63
+	bnez t1, universal_hash_handle_partial_block
+universal_hash_partial_block_finished:
+	// Broadcast the key to v0-7.
+	lw t1, 0(a2)
+	vmv.v.x v0, t1
+	lw t1, 4(a2)
+	vmv.v.x v1, t1
+	lw t1, 8(a2)
+	vmv.v.x v2, t1
+	lw t1, 12(a2)
+	vmv.v.x v3, t1
+	lw t1, 16(a2)
+	vmv.v.x v4, t1
+	lw t1, 20(a2)
+	vmv.v.x v5, t1
+	lw t1, 24(a2)
+	vmv.v.x v6, t1
+	lw t1, 28(a2)
+	vmv.v.x v7, t1
+	// Load the counter.
+	vsetvli zero, t0, e64, m2, ta, ma
+	vmv.v.x v8, a3
+	vid.v v10
+	vadd.vv v8, v8, v10
+	vsetvli zero, t0, e32, m1, ta, ma
+	vncvt.x.x.w v12, v8
+	li t1, 32
+	vnsrl.wx v13, v8, t1
 	// Broadcast the flags.
 	li t1, CHUNK_START | CHUNK_END | ROOT | KEYED_HASH
 	vmv.v.x v15, t1
@@ -1670,7 +1674,7 @@
 	mv t6, ra
 	call blake3_guts_riscv64gcv_kernel
 	mv ra, t6
-	// XOR the first four words. The rest are dropped.
+	// Finish the first four state vectors. The rest are dropped.
 	vxor.vv v0, v0, v8
 	vxor.vv v1, v1, v9
 	vxor.vv v2, v2, v10
@@ -1690,5 +1694,77 @@
 	sw t0, 8(a4)
 	vmv.x.s t0, v3
 	sw t0, 12(a4)
 	ret
+universal_hash_handle_partial_block:
+	// Load the partial block into v8-v11. With LMUL=4, v8 is guaranteed to
+	// hold at least 64 bytes. Zero all 64 bytes first, for block padding.
+	// The block length is already in t1.
+	li t2, 64
+	vsetvli zero, t2, e8, m4, ta, ma
+	vmv.v.i v8, 0
+	vsetvli zero, t1, e8, m4, ta, ma
+	add t2, a0, a1
+	sub t2, t2, t1
+	vle8.v v8, (t2)
+	// If VLEN is longer than 128 bits (16 bytes), then half or all of the
+	// block bytes will be in v8. Make sure they're split evenly across
+	// v8-v11.
+	csrr t2, vlenb
+	li t3, 64
+	bltu t2, t3, universal_hash_vlenb_less_than_64
+	vsetivli zero, 8, e32, m1, ta, ma
+	vslidedown.vi v9, v8, 8
+universal_hash_vlenb_less_than_64:
+	li t3, 32
+	bltu t2, t3, universal_hash_vlenb_less_than_32
+	vsetivli zero, 4, e32, m1, ta, ma
+	vmv.v.v v10, v9
+	vslidedown.vi v11, v9, 4
+	vslidedown.vi v9, v8, 4
+universal_hash_vlenb_less_than_32:
+	// Shift each of the words of the padded partial block to the end of
+	// the corresponding message vector. t0 was previously the number of
+	// full blocks. Now we increment it, so that it's the number of all
+	// blocks (both full and partial).
+	mv t2, t0
+	addi t0, t0, 1
+	// Set vl to at least 4, because v8-v11 each have 4 message words.
+	// Setting vl shorter will make vslide1down clobber those words.
+	li t3, 4
+	maxu t3, t0, t3
+	vsetvli zero, t3, e32, m1, ta, ma
+	vslideup.vx v16, v8, t2
+	vslide1down.vx v8, v8, zero
+	vslideup.vx v17, v8, t2
+	vslide1down.vx v8, v8, zero
+	vslideup.vx v18, v8, t2
+	vslide1down.vx v8, v8, zero
+	vslideup.vx v19, v8, t2
+	vslideup.vx v20, v9, t2
+	vslide1down.vx v9, v9, zero
+	vslideup.vx v21, v9, t2
+	vslide1down.vx v9, v9, zero
+	vslideup.vx v22, v9, t2
+	vslide1down.vx v9, v9, zero
+	vslideup.vx v23, v9, t2
+	vslideup.vx v24, v10, t2
+	vslide1down.vx v10, v10, zero
+	vslideup.vx v25, v10, t2
+	vslide1down.vx v10, v10, zero
+	vslideup.vx v26, v10, t2
+	vslide1down.vx v10, v10, zero
+	vslideup.vx v27, v10, t2
+	vslideup.vx v28, v11, t2
+	vslide1down.vx v11, v11, zero
+	vslideup.vx v29, v11, t2
+	vslide1down.vx v11, v11, zero
+	vslideup.vx v30, v11, t2
+	vslide1down.vx v11, v11, zero
+	vslideup.vx v31, v11, t2
+	// Set the updated VL.
+	vsetvli zero, t0, e32, m1, ta, ma
+	// Append the final block length, still in t1.
+	vmv.v.x v8, t1
+	addi t2, t0, -1
+	vslideup.vx v14, v8, t2
+	j universal_hash_partial_block_finished
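
Note for reviewers: the block and block-length bookkeeping that the new partial-block path performs is summarized by the scalar sketch below. This is illustrative Rust, not code from this patch; the function name and return type are invented for the example, and the real code keeps the message blocks transposed across v16-v31 and the per-block lengths in v14 rather than materializing anything in memory.

// Hypothetical helper, for illustration only: split `input` into 64-byte
// blocks, zero-padding the final partial block and recording each block's
// real length, mirroring how the vector code sets up v16-v31 and v14.
fn split_into_padded_blocks(input: &[u8]) -> Vec<([u8; 64], u32)> {
    let mut blocks = Vec::new();
    let full_blocks = input.len() / 64; // srli t0, a1, 6
    for i in 0..full_blocks {
        let mut block = [0u8; 64];
        block.copy_from_slice(&input[i * 64..][..64]);
        // Full blocks use the broadcast block length of 64 (vmv.v.x v14, t1).
        blocks.push((block, 64));
    }
    let tail_len = input.len() % 64; // andi t1, a1, 63
    if tail_len != 0 {
        // Zero the whole block first (vmv.v.i v8, 0), then copy the tail
        // bytes from the end of the input (vle8.v v8, (t2)).
        let mut block = [0u8; 64];
        block[..tail_len].copy_from_slice(&input[input.len() - tail_len..]);
        // The partial block's true length replaces the broadcast 64 in the
        // last lane of the block-length vector (vslideup.vx v14, v8, t2).
        blocks.push((block, tail_len as u32));
    }
    blocks
}

Each (block, length) pair is then compressed with the broadcast key, its own counter value, and the CHUNK_START | CHUNK_END | ROOT | KEYED_HASH flags, and the first sixteen bytes of the per-block outputs are folded into the 16-byte result; that reduction is outside this diff.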