From e8b550fbb53f388fe4274a3c2a726d35adbf7ab7 Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Tue, 7 Jan 2025 00:13:06 +0100 Subject: [PATCH] ARM64: add standalone bigint multiplication --- .../assembly/limbs_asm_mul_arm64.nim | 136 ++++++++++++++++++ .../assembly/limbs_asm_mul_mont_arm64.nim | 1 + .../assembly/limbs_asm_mul_x86_adx_bmi2.nim | 4 +- constantine/math/arithmetic/limbs_extmul.nim | 4 + .../isa_arm64/macro_assembler_arm64.nim | 16 +++ 5 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 constantine/math/arithmetic/assembly/limbs_asm_mul_arm64.nim diff --git a/constantine/math/arithmetic/assembly/limbs_asm_mul_arm64.nim b/constantine/math/arithmetic/assembly/limbs_asm_mul_arm64.nim new file mode 100644 index 000000000..0fe794d1a --- /dev/null +++ b/constantine/math/arithmetic/assembly/limbs_asm_mul_arm64.nim @@ -0,0 +1,136 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + # Standard library + std/macros, + # Internal + constantine/platforms/abstractions + +# ############################################################ +# +# Assembly implementation of bigint multiplication +# +# ############################################################ + +static: doAssert UseASM_ARM64 + +macro mul_gen[rLen, aLen, bLen: static int]( + r_PIR: var Limbs[rLen], + a_PIR: Limbs[aLen], + b_PIR: Limbs[bLen]) = + ## `a`, `b`, `r` can have a different number of limbs + ## if `r`.limbs.len < a.limbs.len + b.limbs.len + ## The result will be truncated, i.e. 
it will be + ## a * b (mod (2^WordBitWidth)^r.limbs.len) + ## + ## Assumes r doesn't alias a or b + + result = newStmtList() + + var ctx = init(Assembler_arm64, BaseType) + let + r = asmArray(r_PIR, rLen, PointerInReg, asmInput, memIndirect = memWrite) + a = asmArray(a_PIR, aLen, PointerInReg, asmInput, memIndirect = memRead) + b = asmArray(b_PIR, bLen, PointerInReg, asmInput, memIndirect = memRead) + + tSym = ident"t" + tSlots = aLen+1 # Extra for high words + + biSym = ident"bi" + bi = asmValue(biSym, Reg, asmOutputEarlyClobber) + + aaSym = ident"aa" + aa = asmArray(aaSym, aLen, ElemsInReg, asmInputOutput) + + uSym = ident"u" + vSym = ident"v" + + var t = asmArray(tSym, tSlots, ElemsInReg, asmOutputEarlyClobber) + + var # Break dependencies chain + u = asmValue(uSym, Reg, asmOutputEarlyClobber) + v = asmValue(vSym, Reg, asmOutputEarlyClobber) + + # Prologue + result.add quote do: + var `tSym`{.noInit, used.}: array[`tSlots`, BaseType] + var `uSym`{.noinit.}, `vSym`{.noInit.}: BaseType + var `biSym`{.noInit.}: BaseType + var `aaSym`{.noInit, used.}: typeof(`a_PIR`) + `aaSym` = `a_PIR` + + template mulloadd_co(ctx, dst, lhs, rhs, addend) {.dirty.} = + ctx.mul u, lhs, rhs + ctx.adds dst, addend, u + swap(u, v) + template mulloadd_cio(ctx, dst, lhs, rhs, addend) {.dirty.} = + ctx.mul u, lhs, rhs + ctx.adcs dst, addend, u + swap(u, v) + + template mulhiadd_co(ctx, dst, lhs, rhs, addend) {.dirty.} = + ctx.umulh u, lhs, rhs + ctx.adds dst, addend, u + swap(u, v) + template mulhiadd_cio(ctx, dst, lhs, rhs, addend) {.dirty.} = + ctx.umulh u, lhs, rhs + ctx.adcs dst, addend, u + swap(u, v) + template mulhiadd_ci(ctx, dst, lhs, rhs, addend) {.dirty.} = + ctx.umulh u, lhs, rhs + ctx.adc dst, addend, u + swap(u, v) + + doAssert aLen >= 2 + + for i in 0 ..< min(rLen, bLen): + ctx.ldr bi, b[i] + if i == 0: + ctx.mul u, aa[0], bi + ctx.str u, r[i] + ctx.umulh t[0], aa[0], bi + swap(u, v) + for j in 1 ..< aLen: + ctx.mul u, aa[j], bi + ctx.umulh t[j], aa[j], bi + if j == 1: + ctx.adds t[j-1], t[j-1], u + else: + ctx.adcs t[j-1], t[j-1], u + ctx.adc t[aLen-1], t[aLen-1], xzr + swap(u, v) + else: + ctx.mulloadd_co(t[0], aa[0], bi, t[0]) + ctx.str t[0], r[i] + for j in 1 ..< aLen: + ctx.mulloadd_cio(t[j], aa[j], bi, t[j]) + ctx.adc t[aLen], xzr, xzr # assumes N > 1 + + ctx.mulhiadd_co(t[1], aa[0], bi, t[1]) + for j in 2 ..< aLen: + ctx.mulhiadd_cio(t[j], aa[j-1], bi, t[j]) + ctx.mulhiadd_ci(t[aLen], aa[aLen-1], bi, t[aLen]) + + t.rotateLeft() + + # Copy upper-limbs to result + for i in b.len ..< min(a.len+b.len, rLen): + ctx.str t[i-b.len], r[i] + + # Zero the extra + for i in aLen+bLen ..< rLen: + ctx.str xzr, r[i] + + result.add ctx.generate() + +func mul_asm*[rLen, aLen, bLen: static int]( + r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) = + ## Multi-precision Multiplication + ## Assumes r doesn't alias a or b + mul_gen(r, a, b) \ No newline at end of file diff --git a/constantine/math/arithmetic/assembly/limbs_asm_mul_mont_arm64.nim b/constantine/math/arithmetic/assembly/limbs_asm_mul_mont_arm64.nim index 047537291..e2ec3ab90 100644 --- a/constantine/math/arithmetic/assembly/limbs_asm_mul_mont_arm64.nim +++ b/constantine/math/arithmetic/assembly/limbs_asm_mul_mont_arm64.nim @@ -173,6 +173,7 @@ macro mulMont_CIOS_sparebit_gen[N: static int]( for j in 1 ..< N: ctx.mulloadd_cio(t[j], aa[j], bi, t[j]) ctx.adc A, xzr, xzr # assumes N > 1 + ctx.mulhiadd_co(t[1], aa[0], bi, t[1]) # assumes N > 1 for j in 2 ..< N: ctx.mulhiadd_cio(t[j], aa[j-1], bi, t[j]) diff --git 
a/constantine/math/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim b/constantine/math/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim index d7d6aabd0..447a88c0f 100644 --- a/constantine/math/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim +++ b/constantine/math/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim @@ -14,7 +14,7 @@ import # ############################################################ # -# Assembly implementation of finite fields +# Assembly implementation of bigint multiplication # # ############################################################ @@ -111,7 +111,7 @@ macro mulx_gen[rLen, aLen, bLen: static int](r_PIR: var Limbs[rLen], a_MEM: Limb ## The result will be truncated, i.e. it will be ## a * b (mod (2^WordBitWidth)^r.limbs.len) ## - ## Assumes r doesn't aliases a or b + ## Assumes r doesn't alias a or b result = newStmtList() diff --git a/constantine/math/arithmetic/limbs_extmul.nim b/constantine/math/arithmetic/limbs_extmul.nim index 0063d09b7..acaeac040 100644 --- a/constantine/math/arithmetic/limbs_extmul.nim +++ b/constantine/math/arithmetic/limbs_extmul.nim @@ -13,6 +13,8 @@ import when UseASM_X86_64: import ./assembly/limbs_asm_mul_x86 import ./assembly/limbs_asm_mul_x86_adx_bmi2 +when UseASM_ARM64: + import ./assembly/limbs_asm_mul_arm64 # ############################################################ # @@ -78,6 +80,8 @@ func prod*[rLen, aLen, bLen: static int](r{.noalias.}: var Limbs[rLen], a: Limbs mul_asm(r, a, b) elif UseASM_X86_64: mul_asm(r, a, b) + elif UseASM_ARM64 and aLen in {2..8}: + mul_asm(r, a, b) else: prod_comba(r, a, b) diff --git a/constantine/platforms/isa_arm64/macro_assembler_arm64.nim b/constantine/platforms/isa_arm64/macro_assembler_arm64.nim index 36ca01981..a46eefe67 100644 --- a/constantine/platforms/isa_arm64/macro_assembler_arm64.nim +++ b/constantine/platforms/isa_arm64/macro_assembler_arm64.nim @@ -520,6 +520,17 @@ func codeFragment(a: var Assembler_arm64, instr: string, op: Operand, reg: Regis if reg != xzr: a.regClobbers.incl reg +func codeFragment(a: var Assembler_arm64, instr: string, reg: Register, op: Operand) = + # Generate a code fragment + let off = a.getStrOffset(op) + + a.code &= instr & " " & $reg & ", " & off & '\n' + + if op.desc.constraint != asmClobberedRegister: + a.operands.incl op.desc + if reg != xzr: + a.regClobbers.incl reg + func codeFragment(a: var Assembler_arm64, instr: string, op0, op1: Operand) = # Generate a code fragment let off0 = a.getStrOffset(op0) @@ -687,6 +698,11 @@ func str*(a: var Assembler_arm64, src, dst: Operand) = doAssert dst.isOutput(), $dst.repr a.codeFragment("str", src, dst) +func str*(a: var Assembler_arm64, src: Register, dst: Operand) = + ## Store register: src -> dst + doAssert dst.isOutput(), $dst.repr + a.codeFragment("str", src, dst) + func stp*(a: var Assembler_arm64, src0, src1, dst: Operand) = ## Store pair: (src0, src1) -> dst doAssert dst.isOutput(), $dst.repr
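For reference, below is a minimal pure-Nim sketch of the truncated schoolbook (operand-scanning) multiplication that the new mul_gen macro emits as assembly. It is illustrative only: it uses 32-bit limbs so the double-width products fit in a uint64, and the names (Limb, prodTrunc) are hypothetical, not Constantine's API.

# Sketch of truncated schoolbook multiplication (illustrative, not Constantine's API).
# 32-bit limbs so a 32x32 product plus accumulator limb plus carry fits in a uint64.
type Limb = uint32

proc prodTrunc[rLen, aLen, bLen: static int](
      r: var array[rLen, Limb],
      a: array[aLen, Limb],
      b: array[bLen, Limb]) =
  ## r <- a * b (mod (2^32)^rLen), assumes r doesn't alias a or b
  var t: array[rLen, Limb]               # zero-initialized accumulator
  for i in 0 ..< min(bLen, rLen):        # one pass per limb of b, like `ldr bi, b[i]`
    var carry = 0'u64
    for j in 0 ..< aLen:
      if i + j >= rLen: break            # truncate: drop limbs beyond rLen
      # low half accumulates (mul + adds/adcs), high half feeds the next carry (umulh)
      let acc = uint64(a[j]) * uint64(b[i]) + uint64(t[i+j]) + carry
      t[i+j] = Limb(acc and 0xFFFF_FFFF'u64)
      carry  = acc shr 32
    if i + aLen < rLen:
      t[i + aLen] = Limb(carry)          # leftover carry, the t[aLen] slot in the asm
  r = t                                  # limbs of r beyond aLen+bLen stay zero

# Usage: 2x2 limbs -> 4-limb result
var r: array[4, Limb]
prodTrunc(r, [0xFFFF_FFFF'u32, 0x1'u32], [0xFFFF_FFFF'u32, 0x0'u32])
echo r   # [1, 4294967293, 1, 0], i.e. 0x1_FFFF_FFFD_0000_0001

The assembly version follows the same accumulate-and-carry structure but splits each outer pass in two: the low product halves are added with an adds/adcs chain, the high halves with a second adds/adcs/adc chain, and the lowest accumulator limb is stored to r[i] before the accumulator is rotated.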