arm64: add assembly for sumprod
mratsim committed Jan 8, 2025
1 parent 3340914 commit ec80d57
Showing 3 changed files with 201 additions and 5 deletions.
199 changes: 198 additions & 1 deletion constantine/math/arithmetic/assembly/limbs_asm_mul_mont_arm64.nim
@@ -244,4 +244,201 @@ func mulMont_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType,
  ## otherwise the result is in the range [0, M)
  ##
  ## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation
  r.mulMont_CIOS_sparebit_gen(a, b, M, m0ninv, lazyReduce)

# Montgomery Sum of Products
# ------------------------------------------------------------

macro sumprodMont_CIOS_spare2bits_gen[N, K: static int](
        r_PIR: var Limbs[N], a_PIR, b_PIR: array[K, Limbs[N]],
        M_REG: Limbs[N], m0ninv_REG: BaseType,
        lazyReduce: static bool): untyped =
  ## Generate an optimized Montgomery merged sum of products ⅀aᵢ.bᵢ kernel
  ## using the CIOS method
  ##
  ## This requires 2 spare bits in the most significant word
  ## so that we can skip the intermediate reductions
  # No register spilling handling
  doAssert N <= 6, "The assembly-optimized Montgomery multiplication requires at most 6 limbs."

  doAssert K <= 8, "we cannot sum more than 8 products"
  # Bounds:
  # 1. To ensure mapping in [0, 2p), we need ⅀aᵢ.bᵢ <= p.R
  #    For all intents and purposes this is true since aᵢ.bᵢ is:
  #    - if reduced inputs:   (p-1).(p-1) = p²-2p+1, which would allow more than p sums
  #    - if unreduced inputs: (2p-1).(2p-1) = 4p²-4p+1,
  #      with 4p < R due to the 2-unused-bits constraint, so more than p sums are allowed
  # 2. We have a high word tN to accumulate overflows.
  #    With 2 unused bits in the last word,
  #    the multiplication of the two last words leaves 4 unused bits,
  #    enough to accumulate 8 additions and the overflow.
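  # Illustration: BN254's 254-bit prime stored in 4×64 = 256-bit limbs
  # leaves exactly the 2 spare bits required, so up to K = 8 products
  # can be accumulated before a single final reduction.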

  result = newStmtList()

  var ctx = init(Assembler_arm64, BaseType)
  let
    scratchSlots = 8

    r = asmArray(r_PIR, N, PointerInReg, asmInput, memIndirect = memWrite)
    M = asmArray(M_REG, N, ElemsInReg, asmInput)

    akSym = ident "ak"
    ak = asmArray(akSym, N, ElemsInReg, asmOutputEarlyClobber) # buffer for a[k]

    tSym = ident"t"
    t = asmArray(tSym, N, ElemsInReg, asmOutputEarlyClobber)
    m0ninv = asmValue(m0ninv_REG, Reg, asmInput)

    # MultiPurpose Register slots
    scratchSym = ident"scratch"
    scratch = asmArray(scratchSym, scratchSlots, ElemsInReg, asmInputOutputEarlyClobber)

    a = scratch[0].as2dArrayAddr(a_PIR, rows = K, cols = N, memIndirect = memRead) # Store the `a` operand
    b = scratch[1].as2dArrayAddr(b_PIR, rows = K, cols = N, memIndirect = memRead) # Store the `b` operand
    tN = scratch[2] # High part of extended precision multiplication
    A = scratch[3]  # Carry during mul step (A)
    bi = scratch[4] # Stores b[i] during mul and u during reduction
    m = scratch[5]  # Red step: (t[0] * m0ninv) mod 2ʷ

  var # break dependency chains
    u = scratch[6]
    v = scratch[7]

  template mulloadd_co(ctx, dst, lhs, rhs, addend) {.dirty.} =
    ctx.mul u, lhs, rhs
    ctx.adds dst, addend, u
    swap(u, v)
  template mulloadd_cio(ctx, dst, lhs, rhs, addend) {.dirty.} =
    ctx.mul u, lhs, rhs
    ctx.adcs dst, addend, u
    swap(u, v)

  template mulhiadd_co(ctx, dst, lhs, rhs, addend) {.dirty.} =
    ctx.umulh u, lhs, rhs
    ctx.adds dst, addend, u
    swap(u, v)
  template mulhiadd_cio(ctx, dst, lhs, rhs, addend) {.dirty.} =
    ctx.umulh u, lhs, rhs
    ctx.adcs dst, addend, u
    swap(u, v)
  template mulhiadd_ci(ctx, dst, lhs, rhs, addend) {.dirty.} =
    ctx.umulh u, lhs, rhs
    ctx.adc dst, addend, u
    swap(u, v)
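  # Each helper expands to one multiply and one carry-chained addition:
  # e.g. mulloadd_co(ctx, t[0], ak[0], bi, t[0]) emits `mul u, ak0, bi`
  # then `adds t0, t0, u`. Swapping u and v afterwards lets the next
  # multiply target the other scratch register, so it can issue before
  # the previous addition completes.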

  result.add quote do:
    static: doAssert: sizeof(SecretWord) == sizeof(ByteAddress)

    var `tSym`{.noInit, used.}: typeof(`r_PIR`)
    # Assumes 64-bit limbs on 64-bit arch (or you can't store an address)
    var `scratchSym` {.noInit.}: Limbs[`scratchSlots`]
    `scratchSym`[0] = cast[SecretWord](`a_PIR`[0][0].unsafeAddr)
    `scratchSym`[1] = cast[SecretWord](`b_PIR`[0][0].unsafeAddr)

    var `akSym` {.noInit.}: typeof(`a_PIR`[0])

  # Algorithm
  # -----------------------------------------
  # for i=0 to N-1
  #   tN := 0
  #   for k=0 to K-1
  #     A := 0
  #     for j=0 to N-1
  #       (A,t[j]) := t[j] + a[k][j]*b[k][i] + A
  #     tN += A
  #   m := t[0]*m0ninv mod W
  #   C,_ := t[0] + m*M[0]
  #   for j=1 to N-1
  #     (C,t[j-1]) := t[j] + m*M[j] + C
  #   t[N-1] = tN + C
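  # arm64 has no flag-setting multiply-accumulate, so each
  # (A,t[j]) := t[j] + a[k][j]*b[k][i] + A line above is implemented as two
  # passes over j: a low-half pass (mul + adds/adcs carry chain) followed
  # by a high-half pass (umulh + adds/adcs).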

  for i in 0 ..< N:
    # Multiplication step
    # -------------------------------
    ctx.comment " Multiplication step"
    ctx.comment " tN = 0"
    ctx.mov tN, xzr
    for k in 0 ..< K:
      ctx.comment " A = 0"
      ctx.mov A, xzr

      ctx.comment " bi <- b[k, i]"
      ctx.ldr bi, b[k, i]

      ctx.comment " load a[k] in registers"
      let lastEven = N.round_step_down(2)
      for j in countup(0, lastEven-1, 2):
        ctx.ldp ak[j], ak[j+1], a[k, j]
      if lastEven != N:
        ctx.ldr ak[N-1], a[k, N-1]

      ctx.comment " (A,t[0]) := t[0] + a[k][0]*b[k][i] + A"
      if k == 0 and i == 0: # First accumulation, overwrite t[0]
        for j in 0 ..< N:
          ctx.mul t[j], ak[j], bi
      else:
        ctx.mulloadd_co(t[0], ak[0], bi, t[0])
        for j in 1 ..< N:
          ctx.mulloadd_cio(t[j], ak[j], bi, t[j])
        ctx.adc A, xzr, xzr # assumes N > 1

      ctx.mulhiadd_co(t[1], ak[0], bi, t[1]) # assumes N > 1
      for j in 2 ..< N:
        ctx.mulhiadd_cio(t[j], ak[j-1], bi, t[j])
      ctx.mulhiadd_ci(A, ak[N-1], bi, A)

      ctx.add tN, tN, A

    # Reduction step
    # -------------------------------
    ctx.comment " Reduction step"

    ctx.mul m, t[0], m0ninv
    ctx.mul u, m, M[0]
    ctx.cmn t[0], u # TODO: bad latency chain, hopefully done parallel to prev loop
    swap(u, v)
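    # The `cmn` above computes t[0] + m*M[0] for its flags only: m is chosen
    # so that the low word cancels to 0 mod 2ʷ, hence no destination register
    # is needed and only the carry propagates into the loop below.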

    for j in 1 ..< N:
      ctx.mulloadd_cio(t[j-1], m, M[j], t[j])
    ctx.adc t[N-1], tN, xzr

    # assumes N > 1
    ctx.mulhiadd_co(t[0], m, M[0], t[0])
    for j in 1 ..< N-1:
      ctx.mulhiadd_cio(t[j], m, M[j], t[j])
    ctx.mulhiadd_ci(t[N-1], m, M[N-1], t[N-1])


  if lazyReduce:
    for i in 0 ..< N:
      ctx.str t[i], r[i]
  else:
    # Final subtraction
    # we reuse the ak buffer
    template s: untyped = ak

    for i in 0 ..< N:
      if i == 0:
        ctx.subs s[i], t[i], M[i]
      else:
        ctx.sbcs s[i], t[i], M[i]

    # if carry clear, t < M, so pick t
    for i in 0 ..< N:
      ctx.csel t[i], t[i], s[i], cc
      ctx.str t[i], r[i]
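    # Both the subtraction and the selection always execute: `csel` keeps
    # the final reduction branchless, i.e. free of secret-dependent branches.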

  result.add ctx.generate()

func sumprodMont_CIOS_spare2bits_asm*[N, K: static int](
        r: var Limbs[N], a, b: array[K, Limbs[N]],
        M: Limbs[N], m0ninv: BaseType,
        lazyReduce: static bool) =
  ## Sum of products ⅀aᵢ.bᵢ in the Montgomery domain
  ## If "lazyReduce" is set
  ## the result is in the range [0, 2M)
  ## otherwise the result is in the range [0, M)
  ##
  ## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation
  r.sumprodMont_CIOS_spare2bits_gen(a, b, M, m0ninv, lazyReduce)
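
A minimal usage sketch of the new entry point (not part of the commit; a0, a1, b0, b1, M and m0ninv are placeholder operands assumed to be valid Montgomery-form values for a modulus with 2 spare bits, e.g. BN254 on 4×64-bit limbs):

  # Sketch only, with placeholder operands:
  # computes r <- (a0*b0 + a1*b1) mod M, all values in the Montgomery domain.
  var r {.noInit.}: Limbs[4]
  let aa = [a0, a1] # array[2, Limbs[4]], inputs in [0, M)
  let bb = [b0, b1]
  r.sumprodMont_CIOS_spare2bits_asm(aa, bb, M, m0ninv, lazyReduce = false)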
5 changes: 1 addition & 4 deletions constantine/math/arithmetic/assembly/limbs_asm_mul_mont_x86_adx_bmi2.nim
@@ -347,8 +347,7 @@ macro sumprodMont_CIOS_spare2bits_adx_gen[N, K: static int](
    tN = scratch[2] # High part of extended precision multiplication
    C = scratch[3]  # Carry during reduction step
    r = scratch[4]  # Stores the `r` operand
-   S = scratch[5]  # Mul step: Stores the carry A
-                   # Red step: Stores (t[0] * m0ninv) mod 2ʷ
+   A = scratch[5]  # Stores the carry A

  # Registers used:
  # - 1 for `M`
@@ -394,8 +393,6 @@ macro sumprodMont_CIOS_spare2bits_adx_gen[N, K: static int](
      ctx.comment " tN = 0"
      ctx.`xor` tN, tN
      for k in 0 ..< K:
-       template A: untyped = S
-
        ctx.comment " A = 0"
        ctx.`xor` A, A
        ctx.comment " (A,t[0]) := t[0] + a[k][0]*b[k][i] + A"
2 changes: 2 additions & 0 deletions constantine/math/arithmetic/limbs_montgomery.nim
@@ -561,6 +561,8 @@ func sumprodMont*[N: static int](
        r.sumprodMont_CIOS_spare2bits_asm_adx(a, b, M, m0ninv, lazyReduce)
      else:
        r.sumprodMont_CIOS_spare2bits_asm(a, b, M, m0ninv, lazyReduce)
+   elif UseASM_ARM64 and r.len in {2 .. 6}:
+     r.sumprodMont_CIOS_spare2bits_asm(a, b, M, m0ninv, lazyReduce)
    else:
      r.sumprodMont_CIOS_spare2bits(a, b, M, m0ninv, lazyReduce)
  else:
