ARM64: add standalone bigint multiplication
mratsim committed Jan 6, 2025
1 parent e9c179d commit e8b550f
Showing 5 changed files with 159 additions and 2 deletions.
136 changes: 136 additions & 0 deletions constantine/math/arithmetic/assembly/limbs_asm_mul_arm64.nim
@@ -0,0 +1,136 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
# Standard library
std/macros,
# Internal
constantine/platforms/abstractions

# ############################################################
#
# Assembly implementation of bigint multiplication
#
# ############################################################

static: doAssert UseASM_ARM64

macro mul_gen[rLen, aLen, bLen: static int](
r_PIR: var Limbs[rLen],
a_PIR: Limbs[aLen],
b_PIR: Limbs[bLen]) =
## `a`, `b`, `r` can have a different number of limbs
## if `r`.limbs.len < a.limbs.len + b.limbs.len
## The result will be truncated, i.e. it will be
## a * b (mod (2^WordBitWidth)^r.limbs.len)
##
## Assumes r doesn't alias a or b

result = newStmtList()

var ctx = init(Assembler_arm64, BaseType)
let
r = asmArray(r_PIR, rLen, PointerInReg, asmInput, memIndirect = memWrite)
a = asmArray(a_PIR, aLen, PointerInReg, asmInput, memIndirect = memRead)
b = asmArray(b_PIR, bLen, PointerInReg, asmInput, memIndirect = memRead)

tSym = ident"t"
tSlots = aLen+1 # Extra for high words

biSym = ident"bi"
bi = asmValue(biSym, Reg, asmOutputEarlyClobber)

aaSym = ident"aa"
aa = asmArray(aaSym, aLen, ElemsInReg, asmInputOutput)

uSym = ident"u"
vSym = ident"v"

var t = asmArray(tSym, tSlots, ElemsInReg, asmOutputEarlyClobber)

var # Break dependency chains
u = asmValue(uSym, Reg, asmOutputEarlyClobber)
v = asmValue(vSym, Reg, asmOutputEarlyClobber)

# Prologue
result.add quote do:
var `tSym`{.noInit, used.}: array[`tSlots`, BaseType]
var `uSym`{.noinit.}, `vSym`{.noInit.}: BaseType
var `biSym`{.noInit.}: BaseType
var `aaSym`{.noInit, used.}: typeof(`a_PIR`)
`aaSym` = `a_PIR`

template mulloadd_co(ctx, dst, lhs, rhs, addend) {.dirty.} =
ctx.mul u, lhs, rhs
ctx.adds dst, addend, u
swap(u, v)
template mulloadd_cio(ctx, dst, lhs, rhs, addend) {.dirty.} =
ctx.mul u, lhs, rhs
ctx.adcs dst, addend, u
swap(u, v)

template mulhiadd_co(ctx, dst, lhs, rhs, addend) {.dirty.} =
ctx.umulh u, lhs, rhs
ctx.adds dst, addend, u
swap(u, v)
template mulhiadd_cio(ctx, dst, lhs, rhs, addend) {.dirty.} =
ctx.umulh u, lhs, rhs
ctx.adcs dst, addend, u
swap(u, v)
template mulhiadd_ci(ctx, dst, lhs, rhs, addend) {.dirty.} =
ctx.umulh u, lhs, rhs
ctx.adc dst, addend, u
swap(u, v)
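# The templates above wrap the mul/umulh + add carry chains:
#   *_co  starts a chain   (adds: sets the carry flag),
#   *_cio continues it     (adcs: consumes and produces a carry),
#   *_ci  ends it          (adc: consumes the carry, flags untouched).
# u and v are two scratch registers swapped after every step, so the multiply
# feeding the next addition is not serialized behind the current addition.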

doAssert aLen >= 2
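# Operand scanning: one outer iteration per word of b.
# i == 0: a*b[0] is computed directly; the lowest limb goes straight to r[0]
#         and the upper limbs are accumulated in t.
# i > 0:  the low halves of a*b[i] are added into t with one carry chain
#         (t[0] is complete as soon as lo(a[0]*b[i]) lands, so it is stored
#         to r[i] immediately), the final carry is captured in t[aLen], then
#         the high halves are added with a second carry chain and t is
#         rotated for the next word of b.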

for i in 0 ..< min(rLen, bLen):
ctx.ldr bi, b[i]
if i == 0:
ctx.mul u, aa[0], bi
ctx.str u, r[i]
ctx.umulh t[0], aa[0], bi
swap(u, v)
for j in 1 ..< aLen:
ctx.mul u, aa[j], bi
ctx.umulh t[j], aa[j], bi
if j == 1:
ctx.adds t[j-1], t[j-1], u
else:
ctx.adcs t[j-1], t[j-1], u
ctx.adc t[aLen-1], t[aLen-1], xzr
swap(u, v)
else:
ctx.mulloadd_co(t[0], aa[0], bi, t[0])
ctx.str t[0], r[i]
for j in 1 ..< aLen:
ctx.mulloadd_cio(t[j], aa[j], bi, t[j])
ctx.adc t[aLen], xzr, xzr # assumes aLen > 1

ctx.mulhiadd_co(t[1], aa[0], bi, t[1])
for j in 2 ..< aLen:
ctx.mulhiadd_cio(t[j], aa[j-1], bi, t[j])
ctx.mulhiadd_ci(t[aLen], aa[aLen-1], bi, t[aLen])

t.rotateLeft()
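# t.rotateLeft() is evaluated at macro time: it only renames which register
# plays the role of t[0] for the next word of b. The register holding the
# limb just stored to r[i] rotates to the top slot t[aLen], where it is
# overwritten by the next carry capture. No data moves at run time.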

# Copy upper limbs to result
for i in b.len ..< min(a.len+b.len, rLen):
ctx.str t[i-b.len], r[i]

# Zero the extra
for i in aLen+bLen ..< rLen:
ctx.str xzr, r[i]

result.add ctx.generate()

func mul_asm*[rLen, aLen, bLen: static int](
r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) =
## Multi-precision Multiplication
## Assumes r doesn't alias a or b
mul_gen(r, a, b)
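
For reference, here is a minimal pure-Nim sketch (not part of this commit) of the truncated school-book multiplication that mul_gen emits as assembly. It uses 32-bit limbs with 64-bit intermediates so that the high half of every product is expressible in portable Nim, whereas the generated code works on full machine words and obtains the high half with umulh; the name mulRef and the limb width are illustrative only.

func mulRef[rLen, aLen, bLen: static int](
    r: var array[rLen, uint32],
    a: array[aLen, uint32],
    b: array[bLen, uint32]) =
  ## r <- a * b (mod 2^(32*rLen)), operand-scanning school-book multiplication
  for k in 0 ..< rLen: r[k] = 0
  for i in 0 ..< min(rLen, bLen):         # one row per word of b, as in mul_gen
    var carry = 0'u64
    for j in 0 ..< aLen:
      if i + j >= rLen: break             # truncate limbs beyond rLen
      let acc = uint64(r[i+j]) + uint64(a[j]) * uint64(b[i]) + carry
      r[i+j] = uint32(acc and 0xFFFF_FFFF'u64)  # low word of the column sum
      carry = acc shr 32                        # high word feeds the next column
    var k = i + aLen
    while carry != 0 and k < rLen:        # propagate the row's final carry upward
      let acc = uint64(r[k]) + carry
      r[k] = uint32(acc and 0xFFFF_FFFF'u64)
      carry = acc shr 32
      inc k

when isMainModule:
  var r: array[4, uint32]
  mulRef(r, [high(uint32), high(uint32)], [2'u32, 0'u32])
  doAssert r == [0xFFFF_FFFE'u32, 0xFFFF_FFFF'u32, 1'u32, 0'u32]
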
@@ -173,6 +173,7 @@ macro mulMont_CIOS_sparebit_gen[N: static int](
for j in 1 ..< N:
ctx.mulloadd_cio(t[j], aa[j], bi, t[j])
ctx.adc A, xzr, xzr # assumes N > 1

ctx.mulhiadd_co(t[1], aa[0], bi, t[1]) # assumes N > 1
for j in 2 ..< N:
ctx.mulhiadd_cio(t[j], aa[j-1], bi, t[j])
@@ -14,7 +14,7 @@ import

# ############################################################
#
-# Assembly implementation of finite fields
+# Assembly implementation of bigint multiplication
#
# ############################################################

@@ -111,7 +111,7 @@ macro mulx_gen[rLen, aLen, bLen: static int](r_PIR: var Limbs[rLen], a_MEM: Limb
## The result will be truncated, i.e. it will be
## a * b (mod (2^WordBitWidth)^r.limbs.len)
##
-## Assumes r doesn't aliases a or b
+## Assumes r doesn't alias a or b

result = newStmtList()

4 changes: 4 additions & 0 deletions constantine/math/arithmetic/limbs_extmul.nim
@@ -13,6 +13,8 @@ import
when UseASM_X86_64:
import ./assembly/limbs_asm_mul_x86
import ./assembly/limbs_asm_mul_x86_adx_bmi2
when UseASM_ARM64:
import ./assembly/limbs_asm_mul_arm64

# ############################################################
#
@@ -78,6 +80,8 @@ func prod*[rLen, aLen, bLen: static int](r{.noalias.}: var Limbs[rLen], a: Limbs
mul_asm(r, a, b)
elif UseASM_X86_64:
mul_asm(r, a, b)
elif UseASM_ARM64 and aLen in {2..8}:
mul_asm(r, a, b)
else:
prod_comba(r, a, b)
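# Illustrative call (not part of this commit), assuming the usual
# Limbs[N] = array[N, SecretWord] definition, for a 4x4 -> 8 limb product.
# On an ARM64 build with aLen in {2..8} this dispatches to the new mul_asm;
# x86-64 builds keep their own assembly path and everything else falls back
# to prod_comba:
#
#   var r {.noInit.}: Limbs[8]
#   let
#     a = [SecretWord(3), SecretWord(0), SecretWord(0), SecretWord(0)]
#     b = [SecretWord(5), SecretWord(0), SecretWord(0), SecretWord(0)]
#   r.prod(a, b)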

16 changes: 16 additions & 0 deletions constantine/platforms/isa_arm64/macro_assembler_arm64.nim
@@ -520,6 +520,17 @@ func codeFragment(a: var Assembler_arm64, instr: string, op: Operand, reg: Register
if reg != xzr:
a.regClobbers.incl reg

func codeFragment(a: var Assembler_arm64, instr: string, reg: Register, op: Operand) =
# Generate a code fragment
let off = a.getStrOffset(op)

a.code &= instr & " " & $reg & ", " & off & '\n'

if op.desc.constraint != asmClobberedRegister:
a.operands.incl op.desc
if reg != xzr:
a.regClobbers.incl reg

func codeFragment(a: var Assembler_arm64, instr: string, op0, op1: Operand) =
# Generate a code fragment
let off0 = a.getStrOffset(op0)
@@ -687,6 +698,11 @@ func str*(a: var Assembler_arm64, src, dst: Operand) =
doAssert dst.isOutput(), $dst.repr
a.codeFragment("str", src, dst)

func str*(a: var Assembler_arm64, src: Register, dst: Operand) =
## Store register: src -> dst
doAssert dst.isOutput(), $dst.repr
a.codeFragment("str", src, dst)

func stp*(a: var Assembler_arm64, src0, src1, dst: Operand) =
## Store pair: (src0, src1) -> dst
doAssert dst.isOutput(), $dst.repr