From 7d81f15ab46cf775c014501d33ff84536bc4afa0 Mon Sep 17 00:00:00 2001
From: Owen Rodley
Date: Mon, 18 Nov 2024 15:29:48 +1100
Subject: [PATCH] Add APX register support

---
 gematria/datasets/block_wrapper.S             | 55 +++++++++++--------
 gematria/datasets/find_accessed_addrs.cc      |  6 +-
 gematria/datasets/find_accessed_addrs.h       | 52 +++++++++++-------
 gematria/datasets/find_accessed_addrs_test.cc | 22 ++++++++
 4 files changed, 92 insertions(+), 43 deletions(-)

diff --git a/gematria/datasets/block_wrapper.S b/gematria/datasets/block_wrapper.S
index 60fdee78..176b736f 100644
--- a/gematria/datasets/block_wrapper.S
+++ b/gematria/datasets/block_wrapper.S
@@ -26,10 +26,10 @@
 // See the "WrappedFunc" typedef for the function signature this code has. Since
 // it doesn't return we make no guarantees about preserving registers / stack
 // frame, but we do use the normal calling convention for input parameters.
-// TODO(orodley): Update to support r16-r31.
 gematria_prologue:
   mov rax, [rdi] // rax = vector_reg_width
   mov rbx, [rdi + 8] // rbx = uses_upper_vector_regs
+  mov rcx, [rdi + 16] // rcx = uses_apx_regs
   cmp rax, 0
   je set_int_registers
   cmp rax, 1
@@ -40,55 +40,64 @@ gematria_prologue:
 set_xmm_registers:
 .irp n, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
-  vpbroadcastq xmm\n, [rdi + 0x90 + (8 * \n)]
+  vpbroadcastq xmm\n, [rdi + 0x118 + (8 * \n)]
 .endr
   cmp rbx, 0
   je set_int_registers
 
 set_upper_xmm_registers:
 .irp n, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
-  vpbroadcastq xmm\n, [rdi + 0x90 + (8 * \n)]
+  vpbroadcastq xmm\n, [rdi + 0x118 + (8 * \n)]
 .endr
   jmp set_int_registers
 
 set_ymm_registers:
 .irp n, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
-  vpbroadcastq ymm\n, [rdi + 0x90 + (8 * \n)]
+  vpbroadcastq ymm\n, [rdi + 0x118 + (8 * \n)]
 .endr
   cmp rbx, 0
   je set_int_registers
 .irp n, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
-  vpbroadcastq ymm\n, [rdi + 0x90 + (8 * \n)]
+  vpbroadcastq ymm\n, [rdi + 0x118 + (8 * \n)]
 .endr
   jmp set_int_registers
 
 set_zmm_registers:
 .irp n, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
-  vpbroadcastq zmm\n, [rdi + 0x90 + (8 * \n)]
+  vpbroadcastq zmm\n, [rdi + 0x118 + (8 * \n)]
 .endr
   cmp rbx, 0
   je set_int_registers
 .irp n, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
-  vpbroadcastq zmm\n, [rdi + 0x90 + (8 * \n)]
+  vpbroadcastq zmm\n, [rdi + 0x118 + (8 * \n)]
 .endr
 
 set_int_registers:
   mov r15, rdi
-  mov rax, [r15 + 0x10]
-  mov rbx, [r15 + 0x18]
-  mov rcx, [r15 + 0x20]
-  mov rdx, [r15 + 0x28]
-  mov rsi, [r15 + 0x30]
-  mov rdi, [r15 + 0x38]
-  mov rsp, [r15 + 0x40]
-  mov rbp, [r15 + 0x48]
-  mov r8, [r15 + 0x50]
-  mov r9, [r15 + 0x58]
-  mov r10, [r15 + 0x60]
-  mov r11, [r15 + 0x68]
-  mov r12, [r15 + 0x70]
-  mov r13, [r15 + 0x78]
-  mov r14, [r15 + 0x80]
-  mov r15, [r15 + 0x88]
+  // Set APX registers first if necessary, as we need to overwrite RDI as part
+  // of setting the base registers.
+  cmp rcx, 0
+  je set_base_registers
+.irp n, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
+  mov r\n, [rdi + 0x98 + (8 * (\n - 16))] // apx_regs is indexed from r16
+.endr
+
+set_base_registers:
+  mov rax, [r15 + 0x18]
+  mov rbx, [r15 + 0x20]
+  mov rcx, [r15 + 0x28]
+  mov rdx, [r15 + 0x30]
+  mov rsi, [r15 + 0x38]
+  mov rdi, [r15 + 0x40]
+  mov rsp, [r15 + 0x48]
+  mov rbp, [r15 + 0x50]
+  mov r8, [r15 + 0x58]
+  mov r9, [r15 + 0x60]
+  mov r10, [r15 + 0x68]
+  mov r11, [r15 + 0x70]
+  mov r12, [r15 + 0x78]
+  mov r13, [r15 + 0x80]
+  mov r14, [r15 + 0x88]
+  mov r15, [r15 + 0x90]
 
 _gematria_prologue_size = . - gematria_prologue
 .size gematria_prologue, _gematria_prologue_size
diff --git a/gematria/datasets/find_accessed_addrs.cc b/gematria/datasets/find_accessed_addrs.cc
index a70c5f02..dc164f3e 100644
--- a/gematria/datasets/find_accessed_addrs.cc
+++ b/gematria/datasets/find_accessed_addrs.cc
@@ -180,6 +180,7 @@ RawX64Regs ToRawRegs(
   RawX64Regs raw_regs;
   raw_regs.max_vector_reg_width = VectorRegWidth::NONE;
   raw_regs.uses_upper_vector_regs = 0;
+  raw_regs.uses_apx_regs = 0;
 
   for (const RegisterAndValue& reg_and_value : regs) {
     if (reg_and_value.register_name() == "RAX") {
@@ -229,7 +230,10 @@ RawX64Regs ToRawRegs(
     }
 
     VectorRegWidth vector_width = VectorRegWidth::NONE;
-    if (absl::StartsWith(reg_and_value.register_name(), "XMM")) {
+    if (reg_and_value.register_name()[0] == 'R') {
+      raw_regs.apx_regs[number_suffix - 16] = reg_and_value.register_value();
+      raw_regs.uses_apx_regs = 1;
+    } else if (absl::StartsWith(reg_and_value.register_name(), "XMM")) {
       vector_width = VectorRegWidth::XMM;
       raw_regs.vector_regs[number_suffix] = reg_and_value.register_value();
     } else if (absl::StartsWith(reg_and_value.register_name(), "YMM")) {
diff --git a/gematria/datasets/find_accessed_addrs.h b/gematria/datasets/find_accessed_addrs.h
index 457725f0..29c55b69 100644
--- a/gematria/datasets/find_accessed_addrs.h
+++ b/gematria/datasets/find_accessed_addrs.h
@@ -39,26 +39,40 @@ enum class VectorRegWidth : uint64_t {
 // bytes large, so that there will be no padding and calculating offsets by hand
 // is easy (as is required in our assembly prologue code).
 struct RawX64Regs {
-  VectorRegWidth max_vector_reg_width;
-  uint64_t uses_upper_vector_regs;
-  int64_t rax;
-  int64_t rbx;
-  int64_t rcx;
-  int64_t rdx;
-  int64_t rsi;
-  int64_t rdi;
-  int64_t rsp;
-  int64_t rbp;
-  int64_t r8;
-  int64_t r9;
-  int64_t r10;
-  int64_t r11;
-  int64_t r12;
-  int64_t r13;
-  int64_t r14;
-  int64_t r15;
+  VectorRegWidth max_vector_reg_width;  // offset 0x0
+  // If true, the code uses at least one of the 16 extra vector registers
+  // defined in AVX-512. This is interpreted in combination with the max width.
+  // For example, if max_vector_reg_width is XMM and uses_upper_vector_regs is
+  // true, then the code uses XMM0-XMM31 but no YMM or ZMM registers.
+  //
+  // If this is false, then the latter 16 elements of vector_regs are unset and
+  // should be ignored.
+  uint64_t uses_upper_vector_regs;  // offset 0x8
+  // If true, the code uses at least one of the 16 extra general purpose
+  // registers defined in APX.
+  //
+  // If this is false, then the elements of apx_regs are unset and should be
+  // ignored.
+  uint64_t uses_apx_regs;  // offset 0x10
+  int64_t rax;             // offset 0x18
+  int64_t rbx;             // offset 0x20
+  int64_t rcx;             // offset 0x28
+  int64_t rdx;             // offset 0x30
+  int64_t rsi;             // offset 0x38
+  int64_t rdi;             // offset 0x40
+  int64_t rsp;             // offset 0x48
+  int64_t rbp;             // offset 0x50
+  int64_t r8;              // offset 0x58
+  int64_t r9;              // offset 0x60
+  int64_t r10;             // offset 0x68
+  int64_t r11;             // offset 0x70
+  int64_t r12;             // offset 0x78
+  int64_t r13;             // offset 0x80
+  int64_t r14;             // offset 0x88
+  int64_t r15;             // offset 0x90
 
-  int64_t vector_regs[32];
+  int64_t apx_regs[16];     // offset 0x98
+  int64_t vector_regs[32];  // offset 0x118
 };
 
 // Given a basic block of code, attempt to determine what addresses that code
diff --git a/gematria/datasets/find_accessed_addrs_test.cc b/gematria/datasets/find_accessed_addrs_test.cc
index 56c871a6..afa0751d 100644
--- a/gematria/datasets/find_accessed_addrs_test.cc
+++ b/gematria/datasets/find_accessed_addrs_test.cc
@@ -289,5 +289,27 @@ TEST_F(FindAccessedAddrsAvx512Test, UpperZmmRegister) {
   )pb"))));
 }
 
+class FindAccessedAddrsApxTest : public FindAccessedAddrsTest {
+ protected:
+  void SetUp() override {
+    if (!__builtin_cpu_supports("apxf")) {
+      GTEST_SKIP() << "Host doesn't support APX";
+    }
+
+    FindAccessedAddrsTest::SetUp();
+  }
+};
+
+TEST_F(FindAccessedAddrsApxTest, UpperGpr) {
+  EXPECT_THAT(
+      FindAccessedAddrsAsm(R"asm(
+        mov r23, [r30]
+      )asm"),
+      IsOkAndHolds(Partially(EqualsProto(R"pb(
+        accessed_blocks: 0x15000
+        initial_registers: { register_name: "R30" register_value: 0x15000 }
+      )pb"))));
+}
+
 }  // namespace
 }  // namespace gematria