Skip to content

Commit

Permalink
loop ok in mcl_c5_vmul
Browse files Browse the repository at this point in the history
  • Loading branch information
herumi committed Aug 29, 2024
1 parent 437ec49 commit f4fbeec
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 79 deletions.
160 changes: 93 additions & 67 deletions src/asm/bint-x64-amd64.S
Original file line number Diff line number Diff line change
Expand Up @@ -710,45 +710,71 @@ SIZE(mcl_c5_vsubA)
PRE(mcl_c5_vmul):
TYPE(mcl_c5_vmul)
mov $4503599627370495, %rax
vpbroadcastq %rax, %zmm25
vpbroadcastq PRE(rp)(%rip), %zmm26
vmovdqa64 (%rdx), %zmm28
vpbroadcastq %rax, %zmm9
vpbroadcastq PRE(rp)(%rip), %zmm10
vmovdqa64 (%rdx), %zmm12
add $64, %rdx
vpxorq %zmm0, %zmm0, %zmm0
vpmadd52luq (%rsi), %zmm28, %zmm0
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq (%rsi), %zmm28, %zmm29
vmovdqa64 %zmm29, %zmm1
vpmadd52luq 64(%rsi), %zmm28, %zmm1
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 64(%rsi), %zmm28, %zmm29
vmovdqa64 %zmm29, %zmm2
vpmadd52luq 128(%rsi), %zmm28, %zmm2
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 128(%rsi), %zmm28, %zmm29
vmovdqa64 %zmm29, %zmm3
vpmadd52luq 192(%rsi), %zmm28, %zmm3
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 192(%rsi), %zmm28, %zmm29
vmovdqa64 %zmm29, %zmm4
vpmadd52luq 256(%rsi), %zmm28, %zmm4
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 256(%rsi), %zmm28, %zmm29
vmovdqa64 %zmm29, %zmm5
vpmadd52luq 320(%rsi), %zmm28, %zmm5
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 320(%rsi), %zmm28, %zmm29
vmovdqa64 %zmm29, %zmm6
vpmadd52luq 384(%rsi), %zmm28, %zmm6
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 384(%rsi), %zmm28, %zmm29
vmovdqa64 %zmm29, %zmm7
vpmadd52luq 448(%rsi), %zmm28, %zmm7
vpmadd52luq (%rsi), %zmm12, %zmm0
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq (%rsi), %zmm12, %zmm13
vmovdqa64 %zmm13, %zmm1
vpmadd52luq 64(%rsi), %zmm12, %zmm1
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 64(%rsi), %zmm12, %zmm13
vmovdqa64 %zmm13, %zmm2
vpmadd52luq 128(%rsi), %zmm12, %zmm2
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 128(%rsi), %zmm12, %zmm13
vmovdqa64 %zmm13, %zmm3
vpmadd52luq 192(%rsi), %zmm12, %zmm3
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 192(%rsi), %zmm12, %zmm13
vmovdqa64 %zmm13, %zmm4
vpmadd52luq 256(%rsi), %zmm12, %zmm4
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 256(%rsi), %zmm12, %zmm13
vmovdqa64 %zmm13, %zmm5
vpmadd52luq 320(%rsi), %zmm12, %zmm5
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 320(%rsi), %zmm12, %zmm13
vmovdqa64 %zmm13, %zmm6
vpmadd52luq 384(%rsi), %zmm12, %zmm6
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 384(%rsi), %zmm12, %zmm13
vmovdqa64 %zmm13, %zmm7
vpmadd52luq 448(%rsi), %zmm12, %zmm7
vpxorq %zmm8, %zmm8, %zmm8
vpmadd52huq 448(%rsi), %zmm28, %zmm8
vpxorq %zmm30, %zmm30, %zmm30
vpmadd52luq %zmm26, %zmm0, %zmm30
vpmadd52huq 448(%rsi), %zmm12, %zmm8
vpxorq %zmm14, %zmm14, %zmm14
vpmadd52luq %zmm10, %zmm0, %zmm14
lea PRE(ap)(%rip), %rax
call .L3
call .L4
mov $7, %ecx
.align 32
.L3:
mov %rsi, %rax
vmovdqa64 (%rdx), %zmm14
add $64, %rdx
vmovdqa64 %zmm0, %zmm15
vmovdqa64 %zmm1, %zmm0
vmovdqa64 %zmm2, %zmm1
vmovdqa64 %zmm3, %zmm2
vmovdqa64 %zmm4, %zmm3
vmovdqa64 %zmm5, %zmm4
vmovdqa64 %zmm6, %zmm5
vmovdqa64 %zmm7, %zmm6
vmovdqa64 %zmm8, %zmm7
vpxorq %zmm8, %zmm8, %zmm8
call .L4
vpsrlq $52, %zmm15, %zmm14
vpaddq %zmm14, %zmm0, %zmm0
vpxorq %zmm14, %zmm14, %zmm14
vpmadd52luq %zmm10, %zmm0, %zmm14
lea PRE(ap)(%rip), %rax
call .L4
dec %ecx
jnz .L3
vmovdqa64 %zmm0, (%rdi)
vmovdqa64 %zmm1, 64(%rdi)
vmovdqa64 %zmm2, 128(%rdi)
Expand All @@ -761,37 +787,37 @@ vmovdqa64 %zmm8, 512(%rdi)
vzeroupper
ret
.align 32
.L3:
vpmadd52luq (%rax), %zmm30, %zmm0
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq (%rax), %zmm30, %zmm29
vpmadd52luq 64(%rax), %zmm30, %zmm1
vpaddq %zmm29, %zmm1, %zmm1
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 64(%rax), %zmm30, %zmm29
vpmadd52luq 128(%rax), %zmm30, %zmm2
vpaddq %zmm29, %zmm2, %zmm2
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 128(%rax), %zmm30, %zmm29
vpmadd52luq 192(%rax), %zmm30, %zmm3
vpaddq %zmm29, %zmm3, %zmm3
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 192(%rax), %zmm30, %zmm29
vpmadd52luq 256(%rax), %zmm30, %zmm4
vpaddq %zmm29, %zmm4, %zmm4
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 256(%rax), %zmm30, %zmm29
vpmadd52luq 320(%rax), %zmm30, %zmm5
vpaddq %zmm29, %zmm5, %zmm5
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 320(%rax), %zmm30, %zmm29
vpmadd52luq 384(%rax), %zmm30, %zmm6
vpaddq %zmm29, %zmm6, %zmm6
vpxorq %zmm29, %zmm29, %zmm29
vpmadd52huq 384(%rax), %zmm30, %zmm29
vpmadd52luq 448(%rax), %zmm30, %zmm7
vpaddq %zmm29, %zmm7, %zmm7
vpmadd52huq 448(%rax), %zmm30, %zmm8
.L4:
vpmadd52luq (%rax), %zmm14, %zmm0
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq (%rax), %zmm14, %zmm13
vpmadd52luq 64(%rax), %zmm14, %zmm1
vpaddq %zmm13, %zmm1, %zmm1
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 64(%rax), %zmm14, %zmm13
vpmadd52luq 128(%rax), %zmm14, %zmm2
vpaddq %zmm13, %zmm2, %zmm2
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 128(%rax), %zmm14, %zmm13
vpmadd52luq 192(%rax), %zmm14, %zmm3
vpaddq %zmm13, %zmm3, %zmm3
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 192(%rax), %zmm14, %zmm13
vpmadd52luq 256(%rax), %zmm14, %zmm4
vpaddq %zmm13, %zmm4, %zmm4
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 256(%rax), %zmm14, %zmm13
vpmadd52luq 320(%rax), %zmm14, %zmm5
vpaddq %zmm13, %zmm5, %zmm5
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 320(%rax), %zmm14, %zmm13
vpmadd52luq 384(%rax), %zmm14, %zmm6
vpaddq %zmm13, %zmm6, %zmm6
vpxorq %zmm13, %zmm13, %zmm13
vpmadd52huq 384(%rax), %zmm14, %zmm13
vpmadd52luq 448(%rax), %zmm14, %zmm7
vpaddq %zmm13, %zmm7, %zmm7
vpmadd52huq 448(%rax), %zmm14, %zmm8
ret
SIZE(mcl_c5_vmul)
.align 16
Expand Down
49 changes: 39 additions & 10 deletions src/gen_bint_x64.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,24 +245,31 @@ def vmulUnitAddBroadcast(z, px, y, N, H):
else:
vmulH(z[N], y, ptr_b(px+i*8))

def shift(v, s):
vmovdqa64(s, v[0])
for i in range(1, len(v)):
vmovdqa64(v[i-1], v[i])
vpxorq(v[-1], v[-1], v[-1])

def gen_vmul(mont):
with FuncProc(MSM_PRE+'vmul'):
with StackFrame(3, 0, vNum=mont.N*3+7, vType=T_ZMM) as sf:
with StackFrame(3, 0, useRCX=True, vNum=mont.N+8, vType=T_ZMM) as sf:
regs = list(reversed(sf.v))
W = mont.W
N = mont.N
pz = sf.p[0]
px = sf.p[1]
py = sf.p[2]

t = pops(regs, N*2)
t2 = pops(regs, N+1)
t = pops(regs, N+1)
vmask = pops(regs, 1)[0]
rp = pops(regs, 1)[0]
c = pops(regs, 1)[0]
y = pops(regs, 1)[0]
H = pops(regs, 1)[0]
q = pops(regs, 1)[0]
s = pops(regs, 1)[0]
lpL = Label()

vmulUnitAddL = Label()

Expand All @@ -272,23 +279,45 @@ def gen_vmul(mont):

un = genUnrollFunc()

vmovdqa64(y, ptr_b(py))
vmovdqa64(y, ptr(py))
add(py, 64)
vmulUnit(t, px, y, N, H)

vpxorq(q, q, q)
vmulL(q, t[0], rp)

lea(rax, ptr(rip+C_ap))
#vmulUnitAdd(t, rax, q, N, H)
call(vmulUnitAddL)
un(vmovdqa64)(ptr(pz), t[0:N+1])

#un(vmovdqa64)(ptr(pz), t[0:N])
call(vmulUnitAddL) # t += p * q


mov(ecx, N-1)
align(32)
L(lpL)

mov(rax, px)
vmovdqa64(q, ptr(py))
add(py, 64)
shift(t, s)
call(vmulUnitAddL) # t += x * py[i]
vpsrlq(q, s, W)
vpaddq(t[0], t[0], q)

vpxorq(q, q, q)
vmulL(q, t[0], rp)

lea(rax, ptr(rip+C_ap))
call(vmulUnitAddL) # t += p * q

dec(ecx)
jnz(lpL)

un(vmovdqa64)(ptr(pz), t)

sf.close()
# out of vmul
align(32)
L(vmulUnitAddL)
#rax = px
#set rax(= px) and q(= y)
vmulUnitAdd(t, rax, q, N, H)
ret()

Expand Down
7 changes: 5 additions & 2 deletions src/msm_avx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,14 +311,16 @@ inline void vmul2(V *z, const V *x, const U *y)
vmulUnit(t, x, broadcast<V>(y[0]));
q = vmulL(t[0], G::rp());
t[N] = vpaddq(t[N], vmulUnitAdd(t, G::ap(), q));
mcl::bint::copyT<N+1>(z, t);
return;
for (size_t i = 1; i < N; i++) {
t[N+i] = vmulUnitAdd(t+i, x, broadcast<V>(y[i]));
t[i] = vpaddq(t[i], vpsrlq(t[i-1], W));
q = vmulL(t[i], G::rp());
t[N+i] = vpaddq(t[N+i], vmulUnitAdd(t+i, G::ap(), q));
//mcl::bint::copyT<N+1>(z, t+i);
//return;
}
mcl::bint::copyT<N+1>(z, t+N-1);
return;
for (size_t i = N; i < N*2; i++) {
t[i] = vpaddq(t[i], vpsrlq(t[i-1], W));
t[i-1] = vpandq(t[i-1], G::mask());
Expand Down Expand Up @@ -1737,6 +1739,7 @@ for (int i = 0; i < 9; i++) {
printf("i=%d\n", i);
dump(v[i], "v");
dump(w[i], "w");
if (memcmp(&v[i], &w[i], sizeof(v[i])) != 0) printf("ERR");
}
exit(1);
#endif
Expand Down

0 comments on commit f4fbeec

Please sign in to comment.