diff --git a/byteops.go b/byteops.go index fcfacd6..fed5740 100644 --- a/byteops.go +++ b/byteops.go @@ -60,11 +60,7 @@ func byteopsCbd(buf []byte, paramsK int) poly { // byteopsMontgomeryReduce computes a Montgomery reduction; given // a 32-bit integer `a`, returns `a * R^-1 mod Q` where `R=2^16`. func byteopsMontgomeryReduce(a int32) int16 { - u := int16(a * int32(paramsQInv)) - t := int32(u) * int32(paramsQ) - t = a - t - t >>= 16 - return int16(t) + return int16((a - int32(int16(a*int32(paramsQInv)))*int32(paramsQ)) >> 16) } // byteopsBarrettReduce computes a Barrett reduction; given diff --git a/indcpa.go b/indcpa.go index 70a4ec0..062b740 100644 --- a/indcpa.go +++ b/indcpa.go @@ -111,11 +111,9 @@ func indcpaGenMatrix(seed []byte, transposed bool, paramsK int) ([]polyvec, erro xof.Reset() var err error if transposed { - xof_buf := append(append(make([]byte, 0, len(seed)+2), seed...), byte(i), byte(j)) - _, err = xof.Write(xof_buf) + _, err = xof.Write(append(seed, byte(i), byte(j))) } else { - xof_buf := append(append(make([]byte, 0, len(seed)+2), seed...), byte(j), byte(i)) - _, err = xof.Write(xof_buf) + _, err = xof.Write(append(seed, byte(j), byte(i))) } if err != nil { return []polyvec{}, err diff --git a/ntt.go b/ntt.go index d42a739..28a9834 100644 --- a/ntt.go +++ b/ntt.go @@ -85,12 +85,7 @@ func nttBaseMul( a0 int16, a1 int16, b0 int16, b1 int16, zeta int16, -) [2]int16 { - var r [2]int16 - r[0] = nttFqMul(a1, b1) - r[0] = nttFqMul(r[0], zeta) - r[0] = r[0] + nttFqMul(a0, b0) - r[1] = nttFqMul(a0, b1) - r[1] = r[1] + nttFqMul(a1, b0) - return r +) (int16, int16) { + return nttFqMul(nttFqMul(a1, b1), zeta) + nttFqMul(a0, b0), + nttFqMul(a0, b1) + nttFqMul(a1, b0) } diff --git a/poly.go b/poly.go index c3bfe10..57dcf0c 100644 --- a/poly.go +++ b/poly.go @@ -166,20 +166,16 @@ func polyInvNttToMont(r poly) poly { // in the number-theoretic transform (NTT) domain. func polyBaseMulMontgomery(a poly, b poly) poly { for i := 0; i < paramsN/4; i++ { - rx := nttBaseMul( + a[4*i+0], a[4*i+1] = nttBaseMul( a[4*i+0], a[4*i+1], b[4*i+0], b[4*i+1], nttZetas[64+i], ) - ry := nttBaseMul( + a[4*i+2], a[4*i+3] = nttBaseMul( a[4*i+2], a[4*i+3], b[4*i+2], b[4*i+3], -nttZetas[64+i], ) - a[4*i+0] = rx[0] - a[4*i+1] = rx[1] - a[4*i+2] = ry[0] - a[4*i+3] = ry[1] } return a }