diff --git a/Makefile b/Makefile index dabe95a3..4cf29baa 100644 --- a/Makefile +++ b/Makefile @@ -16,6 +16,7 @@ TEST_SRC+=ecdsa_test.cpp ecdsa_c_test.cpp TEST_SRC+=mul_test.cpp TEST_SRC+=bint_test.cpp TEST_SRC+=low_func_test.cpp +TEST_SRC+=smallmodp_test.cpp ifneq ($(MCL_USE_GMP),1) TEST_SRC+=static_init_test.cpp endif diff --git a/include/mcl/bint.hpp b/include/mcl/bint.hpp index 339265cb..7a601923 100644 --- a/include/mcl/bint.hpp +++ b/include/mcl/bint.hpp @@ -8,7 +8,7 @@ */ #include -#include +#include #include #ifndef MCL_STANDALONE #include @@ -476,5 +476,133 @@ inline Unit getMontgomeryCoeff(Unit pLow, size_t bitSize = sizeof(Unit) * 8) return pp; } +struct SmallModP { + static const size_t d = 16; // d = 26 if use double in approx + static const size_t MAX_MUL_N = 1; // not used because mulSmallUnit is call at first. + static const size_t maxE_ = d - 2; + const Unit *p_; + Unit tbl_[MAX_MUL_N][MCL_MAX_UNIT_SIZE+1]; + size_t n_; + size_t l_; + uint32_t p0_; + + SmallModP() + : n_(0) + , l_(0) + , p0_(0) + { + } + // p must not be temporary. + void init(const Unit *p, size_t n) + { + p_ = p; + n_ = n; + l_ = mcl::fp::getBitSize(p, n); + Unit *t = (Unit*)CYBOZU_ALLOCA((n_+1)*sizeof(Unit)); + mcl::bint::clearN(t, n_+1); + size_t pos = d + l_ - 1; + { + size_t q = pos / MCL_UNIT_BIT_SIZE; + size_t r = pos % MCL_UNIT_BIT_SIZE; + t[q] = Unit(1) << r; + } + // p0 = 2**(d+l-1)/p + Unit q[2]; + mcl::bint::div(q, 2, t, n_+1, p, n_); + assert(q[1] == 0); + p0_ = uint32_t(q[0]); + for (size_t i = 0; i < MAX_MUL_N; i++) { + tbl_[i][n_] = mcl::bint::mulUnitN(tbl_[i], p_, Unit(i+1), n_); // 1~MAX_MUL_N + } + } + Unit approx(Unit x0, size_t a) const + { +// uint64_t t = uint64_t(double(x0) * double(p0_)); // for d = 26 + uint32_t t = uint32_t(x0 * p0_); + return Unit(t >> (2 * d + l_ - 1 - a)); + } + // x[xn] %= p + // the effective range of return value is [0, n_) + bool quot(Unit *pQ, const Unit *x, size_t xn) const + { + size_t a = mcl::fp::getBitSize(x, xn); + if (a < l_) { + *pQ = 0; + return true; + } + size_t e = a - l_ + 1; + if (e > maxE_) return false; + Unit x0 = mcl::fp::getUnitAt(x, xn, a - d); + *pQ = approx(x0, a); + return true; + } + // return false if x[0, xn) is large + bool mod(Unit *z, const Unit *x, size_t xn) const + { + assert(xn <= n_ + 1); + Unit Q; + if (!quot(&Q, x, xn)) return false; + if (Q == 0) { + mcl::bint::copyN(z, x, n_); + return true; + } + Unit *t = (Unit*)CYBOZU_ALLOCA((n_+1)*sizeof(Unit)); + const Unit *pQ = 0; + if (Q <= MAX_MUL_N) { + assert(Q > 0); + pQ = tbl_[Q-1]; + } else { + t[n_] = mcl::bint::mulUnitN(t, p_, Q, n_); + pQ = t; + } + bool b = mcl::bint::subN(t, x, pQ, xn); + assert(!b); (void)b; + if (mcl::bint::cmpGeN(t, tbl_[0], xn)) { // tbl_[0] == p and tbl_[n_] = 0 + mcl::bint::subN(z, t, p_, n_); + } else { + mcl::bint::copyN(z, t, n_); + } + return true; + } +#if 1 + // return false if x[0, xn) is large + template + bool modT(Unit z[N], const Unit *x, size_t xn) const + { + assert(xn <= N + 1); + Unit Q; + if (!quot(&Q, x, xn)) return false; + if (Q == 0) { + mcl::bint::copyT(z, x); + return true; + } + Unit t[N+1]; + const Unit *pQ = 0; + if (Q <= MAX_MUL_N) { + pQ = tbl_[Q-1]; + } else { + t[N] = mcl::bint::mulUnitT(t, p_, Q); + pQ = t; + } + bool b = mcl::bint::subT(t, x, pQ); + assert(!b); (void)b; + if (mcl::bint::cmpGeT(t, tbl_[0])) { + mcl::bint::subT(z, t, p_); + } else { + mcl::bint::copyT(z, t); + } + return true; + } +#endif + template + static bool mulUnit(const SmallModP& smp, Unit z[N], const Unit x[N], Unit y) + { + Unit xy[N+1]; + xy[N] = mulUnitT(xy, x, y); + return smp.modT(z, xy, N+1); +// return smp.mod(z, xy, N+1); + } +}; + } } // mcl::bint diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp index 3d03c1e8..d1f9f9a0 100644 --- a/include/mcl/fp.hpp +++ b/include/mcl/fp.hpp @@ -114,21 +114,6 @@ class FpT : public fp::Serializable, op_.fp_add(y, x, x, op_.p); } #endif - static inline void mul9A(Unit *y, const Unit *x) - { - mulSmall(y, x, 9); -// op_.fp_mul9(y, x, op_.p); - } - static inline void mulSmall(Unit *z, const Unit *x, const uint32_t y) - { - assert(y <= op_.smallModp.maxMulN); - Unit xy[maxSize + 1]; - op_.fp_mulUnitPre(xy, x, y); - int v = op_.smallModp.approxMul(xy); - const Unit *pv = op_.smallModp.getPmul(v); - op_.fp_subPre(z, xy, pv); - op_.fp_sub(z, z, op_.p, op_.p); - } public: typedef FpT BaseFp; // return pointer to array v_[] @@ -187,9 +172,6 @@ class FpT : public fp::Serializable, if (op_.fp_mul2A_ == 0) { op_.fp_mul2A_ = mul2A; } - if (op_.fp_mul9A_ == 0) { - op_.fp_mul9A_ = mul9A; - } #endif *pb = true; } @@ -608,23 +590,18 @@ class FpT : public fp::Serializable, } static void mul9(FpT& y, const FpT& x) { -#ifdef MCL_XBYAK_DIRECT_CALL - op_.fp_mul9A_(y.v_, x.v_); -#else - mul9A(y.v_, x.v_); -#endif + mulUnit(y, x, 9); } static inline void addPre(FpT& z, const FpT& x, const FpT& y) { op_.fp_addPre(z.v_, x.v_, y.v_); } static inline void subPre(FpT& z, const FpT& x, const FpT& y) { op_.fp_subPre(z.v_, x.v_, y.v_); } - static inline void mulSmall(FpT& z, const FpT& x, const uint32_t y) - { - mulSmall(z.v_, x.v_, y); - } static inline void mulUnit(FpT& z, const FpT& x, const Unit y) { - if (mulSmallUnit(z, x, y)) return; + if (mcl::fp::mulSmallUnit(z, x, y)) return; + if (op_.mulSmallUnit(op_.smallModP, z.v_, x.v_, y)) return; op_.fp_mulUnit(z.v_, x.v_, y, op_.p); } + // alias of mulUnit + static inline void mulSmall(FpT& z, const FpT& x, const uint32_t y) { mulUnit(z, x, y); } static inline void inv(FpT& y, const FpT& x) { assert(!x.isZero()); diff --git a/include/mcl/gmp_util.hpp b/include/mcl/gmp_util.hpp index afb50192..b0e731fc 100644 --- a/include/mcl/gmp_util.hpp +++ b/include/mcl/gmp_util.hpp @@ -949,85 +949,6 @@ class SquareRoot { #endif }; -/* - x mod p for a small value x < (pMulTblN * p). -*/ -struct SmallModp { - static const size_t unitBitSize = sizeof(Unit) * 8; - static const size_t maxTblSize = (MCL_MAX_BIT_SIZE + unitBitSize - 1) / unitBitSize + 1; - static const size_t maxMulN = 9; - static const size_t pMulTblN = maxMulN + 1; - uint32_t N_; - uint32_t shiftL_; - uint32_t shiftR_; - uint32_t maxIdx_; - // pMulTbl_[i] = (p * i) >> (pBitSize_ - 1) - Unit pMulTbl_[pMulTblN][maxTblSize]; - // idxTbl_[x] = (x << (pBitSize_ - 1)) / p - uint8_t idxTbl_[pMulTblN * 2]; - // return x >> (pBitSize_ - 1) - SmallModp() - : N_(0) - , shiftL_(0) - , shiftR_(0) - , maxIdx_(0) - , pMulTbl_() - , idxTbl_() - { - } - // return argmax { i : x > i * p } - uint32_t approxMul(const Unit *x) const - { - uint32_t top = getTop(x); - assert(top <= maxIdx_); - return idxTbl_[top]; - } - const Unit *getPmul(size_t v) const - { - assert(v < pMulTblN); - return pMulTbl_[v]; - } - uint32_t getTop(const Unit *x) const - { - if (shiftR_ == 0) return x[N_ - 1]; - return (x[N_ - 1] >> shiftR_) | (x[N_] << shiftL_); - } - uint32_t cvtInt(const mpz_class& x) const - { - assert(mcl::gmp::getUnitSize(x) <= 1); - if (x == 0) { - return 0; - } else { - return uint32_t(mcl::gmp::getUnit(x)[0]); - } - } - void init(const mpz_class& p) - { - size_t pBitSize = mcl::gmp::getBitSize(p); - N_ = uint32_t((pBitSize + unitBitSize - 1) / unitBitSize); - shiftR_ = (pBitSize - 1) % unitBitSize; - shiftL_ = unitBitSize - shiftR_; - mpz_class t = 0; - for (size_t i = 0; i < pMulTblN; i++) { - bool b; - mcl::gmp::getArray(&b, pMulTbl_[i], maxTblSize, t); - assert(b); - (void)b; - if (i == pMulTblN - 1) { - maxIdx_ = getTop(pMulTbl_[i]); - assert(maxIdx_ < CYBOZU_NUM_OF_ARRAY(idxTbl_)); - break; - } - t += p; - } - - for (uint32_t i = 0; i <= maxIdx_; i++) { - idxTbl_[i] = cvtInt((mpz_class(int(i)) << (pBitSize - 1)) / p); - } - } -}; - - /* Barrett Reduction for non GMP version diff --git a/include/mcl/op.hpp b/include/mcl/op.hpp index abc3ffa8..e11151a3 100644 --- a/include/mcl/op.hpp +++ b/include/mcl/op.hpp @@ -29,7 +29,7 @@ namespace mcl { -static const int version = 0x193; /* 0xABC = A.BC */ +static const int version = 0x194; /* 0xABC = A.BC */ /* specifies available string format mode for X::setIoMode() @@ -185,7 +185,8 @@ struct Op { mcl::SquareRoot sq; CYBOZU_ALIGN(8) char im[sizeof(mcl::inv::InvModT)]; mcl::Modp modp; - mcl::SmallModp smallModp; +// mcl::SmallModp smallModp; + mcl::bint::SmallModP smallModP; Unit half[maxUnitSize]; // (p + 1) / 2 Unit oneRep[maxUnitSize]; // 1(=inv R if Montgomery) /* @@ -210,7 +211,6 @@ struct Op { void3u fp_mulA_; void2u fp_sqrA_; void2u fp_mul2A_; - void2u fp_mul9A_; void3u fp2_addA_; void3u fp2_subA_; void2u fp2_negA_; @@ -238,6 +238,7 @@ struct Op { void3u fp_mul2; void2uOp fp_invOp; void2uIu fp_mulUnit; // fp_mulUnitPre + bool (*mulSmallUnit)(const mcl::bint::SmallModP&, Unit *z, const Unit *x, Unit y); void3u fpDbl_mulPre; void2u fpDbl_sqrPre; @@ -300,7 +301,6 @@ struct Op { fp_mulA_ = 0; fp_sqrA_ = 0; fp_mul2A_ = 0; - fp_mul9A_ = 0; fp2_addA_ = 0; fp2_subA_ = 0; fp2_negA_ = 0; @@ -328,6 +328,7 @@ struct Op { fp_mul2 = 0; fp_invOp = 0; fp_mulUnit = 0; + mulSmallUnit = 0; fpDbl_mulPre = 0; fpDbl_sqrPre = 0; diff --git a/include/mcl/util.hpp b/include/mcl/util.hpp index 6ec62b52..36706928 100644 --- a/include/mcl/util.hpp +++ b/include/mcl/util.hpp @@ -61,6 +61,18 @@ T getUnitAt(const T *x, size_t xN, size_t bitPos) return (x[q] >> r) | (x[q + 1] << (TbitSize - r)); } +template +size_t getBitSize(const T *x, size_t n) +{ + while (n > 0 && (x[n - 1] == 0)) { + n--; + } + if (n == 0) { + return 0; + } + return (n - 1) * sizeof(T) * 8 + 1 + cybozu::bsr(x[n - 1]); +} + template class BitIterator { const T *x_; @@ -78,14 +90,7 @@ class BitIterator { { x_ = x; bitPos_ = 0; - while (n > 0 && (x[n - 1] == 0)) { - n--; - } - if (n == 0) { - bitSize_ = 0; - return; - } - bitSize_ = (n - 1) * sizeof(T) * 8 + 1 + cybozu::bsr(x_[n - 1]); + bitSize_ = mcl::fp::getBitSize(x, n); } size_t getBitSize() const { return bitSize_; } bool hasNext() const { return bitPos_ < bitSize_; } @@ -135,6 +140,7 @@ class BitIterator { /* shortcut of multiplication by Unit + remark : support b times where y^2=x^3+a x + b. */ template bool mulSmallUnit(T& z, const T& x, U y) @@ -149,10 +155,14 @@ bool mulSmallUnit(T& z, const T& x, U y) case 6: { T t; T::add(t, x, x); T::add(t, t, x); T::add(z, t, t); break; } case 7: { T t; T::add(t, x, x); T::add(t, t, t); T::add(t, t, t); T::sub(z, t, x); break; } case 8: T::add(z, x, x); T::add(z, z, z); T::add(z, z, z); break; + // require FpDblT::mulPre for xi.a = 9 case 9: { T t; T::add(t, x, x); T::add(t, t, t); T::add(t, t, t); T::add(z, t, x); break; } + // slower than SmallModP +#if 0 case 10: { T t; T::add(t, x, x); T::add(t, t, t); T::add(t, t, x); T::add(z, t, t); break; } case 11: { T t; T::add(t, x, x); T::add(t, t, x); T::add(t, t, t); T::add(t, t, t); T::sub(z, t, x); break; } case 12: { T t; T::add(t, x, x); T::add(t, t, t); T::add(z, t, t); T::add(z, z, t); break; } +#endif default: return false; } diff --git a/include/mcl/vint.hpp b/include/mcl/vint.hpp index de1481e8..e26e0b3a 100644 --- a/include/mcl/vint.hpp +++ b/include/mcl/vint.hpp @@ -429,9 +429,8 @@ class Vint { { if (isZero()) return 1; size_t n = size(); - Unit v = buf_[n - 1]; - assert(v); - return (n - 1) * sizeof(Unit) * 8 + 1 + cybozu::bsr(v); + assert(buf_[n-1]); + return mcl::fp::getBitSize(buf_, n); } // ignore sign bool testBit(size_t i) const diff --git a/misc/mul-approx.py b/misc/mul-approx.py new file mode 100644 index 00000000..c6ea1f75 --- /dev/null +++ b/misc/mul-approx.py @@ -0,0 +1,139 @@ +""" +how to find the quotient of p/x. + +# Notation +d = 26 # half of the double-precision bit length (52) +l = p.bit_length(), then 2**(l-1) <= p < 2**l +a = x.bit_length(), then 2**(a-1) <= x < 2**a + +assume d < l <= a <= l + d - 3 +e = a - l + 1, then e <= d - 2 +assume max(e) = 9 + +# Preparation +(p0, p1) = divmod(2**(d+l-1), p) +# 2**(d+l-1) = p0 * p + p1, p1 < p + +# Quotient +input : x < 2**(l+e-1) +a = x.bit_length() +(x0, x1) = divmod(x, 2**(a-d)) +# x = x0 * 2**(a-d) + x1, x1 < 2**(a-d) + +s = 2 * d - e +S = 2**s + +Q'=(x0 * p0) >> s +# (Q', R') = divmod(x0 * p0, S), x0 * p0 = Q' S + R', R' < S + +(Q, R) = divmod(x, p), x = Q * p + R, R < p + +# Theorem +0 <= Q - Q' <= 1 + +S p (Q - Q') = S (p Q) - p (S Q') = S(x - R) - p(x0 p0 - R') + = S(x0 * 2**(a-d) + x1 - R) - x0 p p0 + p R' + = 2**(2 d - e + a - d) x0 + S x1 - S R - x0 (2**(d+l-1) - p1)) + p R' + = 2**(d + l -1) x0 + S x1 - S R - x0 2**(d+l-1) + x0 p1 + p R' + = S x1 + p1 x0 + p R' - S R + +Q - Q' = (x1/p) + (p1/p)(x0/S) + (R'/S) - (R/p) + +0 <= x1/p < 2**(a-d)/2**(l-1) = 1/2**(d-(a-l+1))=1/2**(d-e) <= 1/4 +0 <= p1/p < 1 +0 <= x0/S < 2**d / 2**(2d-e) = 1/2**(d-e) <= 1/4 +0 <= R'/S < 1 +0 <= R/p < 1 + +-1 < Q - Q' < 1+1/2 +""" +class ApproxMul: + def __init__(self, p, d): + self.p = p + self.d = d + self.l = p.bit_length() + t = 1<<(d+self.l-1) + (q, r) = divmod(t, self.p) + self.p0 = q + self.p1 = r + + def __str__(self): + return f'''p={self.p} +d={self.d} +l={self.l} +p0={self.p0} +p1={self.p1}''' + + def getTop(self, x): + """ + return (x0, x1) such that x = x0 * 2**(a-d) + x1 where a = x.bit_length() + """ + if x < self.p: + return (0, x) + a = x.bit_length() + t = 1<<(a-self.d) + return divmod(x, t) + + def quot(self, x): + (x0, x1) = self.getTop(x) + a = x.bit_length() + s= 2*self.d -(a - self.l + 1) + return (x0 * self.p0) >> s + + def check(self, x): + (x0, x1) = self.getTop(x) + a = x.bit_length() + (Q, R) = divmod(x, self.p) + S = 1<<(2*self.d - (a - self.l + 1)) + (Q2, R2) = divmod(x0 * self.p0, S) + lhs = S * self.p * (Q - Q2) + rhs = S * x1 + self.p1 * x0 - S * R + self.p * R2 + if Q == Q2 or Q == Q2 + 1: + return + if Q != Q2: + print('rare case') + print(f'{x=}') + print(f'{x0=}') + print(f'{x1=}') + print(f'{Q=}') + print(f'{R=}') + print(f'{Q2=}') + print(f'{R2=}') + if lhs != rhs: + print(f'check err {x=}') + print(f'{Q=} {R=}') + print(f'{Q2=} {R2=}') + print(f'{lhs=} {rhs=} {lhs==rhs=}') + ERR + + + +def test(p): + print(f'test {p=}') + import random + app = ApproxMul(p, 26) + print(app) + + MAX = p * 256 + random.seed(a=12345) + + app.check(715409340372908097786544094000490505679080949411292527675476747206857849744375764344129765863746114129605942739419060) + ERR + for i in range(0, 100): + x = p * i + app.check(x) + x = p * i + p-1 + app.check(x) + + for i in range(1000000): + x = random.randint(p, p*2) *random.randint(1, 256) + app.check(x) + +def __main__(): + r = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 + p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab + test(p) + test(r) + + +__main__() diff --git a/src/fp.cpp b/src/fp.cpp index 650f8954..e430be47 100644 --- a/src/fp.cpp +++ b/src/fp.cpp @@ -300,6 +300,7 @@ void setOp(Op& op) op.fp_shr1 = shr1T; op.fp_neg = negT; op.fp_mulUnitPre = mulUnitPreT; + op.mulSmallUnit = bint::SmallModP::mulUnit; op.fp_addPre = bint::get_add(N); op.fp_subPre = bint::get_sub(N); op.fpDbl_addPre = bint::get_add(N * 2); @@ -580,7 +581,8 @@ bool Op::init(const mpz_class& _p, size_t maxBitSize, int _xi_a, Mode mode, size if (!b) return false; } modp.init(mp); - smallModp.init(mp); +// smallModp.init(mp); + smallModP.init(p, N); return fp::initForMont(*this, p, mode); } diff --git a/test/bint_test.cpp b/test/bint_test.cpp index 9023765d..111333bd 100644 --- a/test/bint_test.cpp +++ b/test/bint_test.cpp @@ -921,6 +921,98 @@ CYBOZU_TEST_AUTO(sqr) testSqr<9>(); } +template +void setAndModT(const mcl::bint::SmallModP& smp, Unit x[N+1]) +{ + x[N-1] = mcl::bint::mulUnit1(&x[N], x[N-1], x[0] & 0x3f); + size_t xn = x[N] == 0 ? N : N+1; + if (!smp.modT(x, x, xn)) { + puts("ERR2"); + exit(1); + } +} + +template +void setAndMod(const mcl::bint::SmallModP& smp, Unit x[N+1]) +{ + x[N-1] = mcl::bint::mulUnit1(&x[N], x[N-1], x[0] & 0x3f); + size_t xn = x[N] == 0 ? N : N+1; + if (!smp.mod(x, x, xn)) { + puts("ERR1"); + exit(1); + } +} + +template +void testSmallModP(const char *pStr) +{ + printf("p=%s\n", pStr); + Unit p[N]; + const size_t FACTOR = 128; + size_t xn = mcl::fp::hexToArray(p, N, pStr, strlen(pStr)); + CYBOZU_TEST_EQUAL(xn, N); + mcl::bint::SmallModP smp; + smp.init(p, N); + cybozu::XorShift rg; + Unit x[N+1]; + mcl::bint::copyT(x, p); + for (size_t i = 0; i < 10; i++) { + uint32_t a = rg.get32() % FACTOR; + x[N-1] = mcl::bint::mulUnit1(&x[N], x[N-1], a); + size_t xn = x[N] == 0 ? N : N+1; + Unit q[2], r[N+1]; + mcl::bint::copyN(r, x, xn); + mcl::bint::div(q, 2, r, xn, p, N); + CYBOZU_TEST_ASSERT(q[0] <= FACTOR && q[1] == 0); + for (int mode = 0; mode < 2; mode++) { + Unit r2[N]; + bool b = false; + switch (mode) { + case 0: b = smp.mod(r2, x, xn); break; + case 1: b = smp.modT(r2, x, xn); break; + } + CYBOZU_TEST_ASSERT(b); + CYBOZU_TEST_EQUAL_ARRAY(r2, r, N); + } + mcl::bint::copyT(x, r); + } +#ifdef NDEBUG + { + if ((smp.p_[N-1] >> (MCL_UNIT_BIT_SIZE - 8)) == 0) return; // top 8-bit must be not zero + CYBOZU_BENCH_C("mod ", 1000, setAndMod, smp, x); + CYBOZU_BENCH_C("modT", 1000, setAndModT, smp, x); + } +#endif +} + +CYBOZU_TEST_AUTO(SmallModP) +{ + const size_t adj = 8 / sizeof(Unit); + const char *tbl4[] = { + "2523648240000001ba344d80000000086121000000000013a700000000000013", + "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001", // BLS12-381 r + "7523648240000001ba344d80000000086121000000000013a700000000000017", + "800000000000000000000000000000000000000000000000000000000000005f", + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", // secp256k1 + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff43", // max prime + // not primes + "ffffffffffffffffffffffffffffffffffffffffffffffff0000000000000001", + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000001", + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + }; + for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl4); i++) { + testSmallModP<4 * adj>(tbl4[i]); + } + const char *tbl6[] = { + "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab", // BLS12-381 p + "240026400f3d82b2e42de125b00158405b710818ac000007e0042f008e3e00000000001080046200000000000000000d", // BN381 r + "240026400f3d82b2e42de125b00158405b710818ac00000840046200950400000000001380052e000000000000000013", // BN381 p + }; + for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl6); i++) { + testSmallModP<6 * adj>(tbl6[i]); + } +} + #if 0 template void testMulLow() diff --git a/test/low_func_test.cpp b/test/low_func_test.cpp index 26b54b65..b9ee62b9 100644 --- a/test/low_func_test.cpp +++ b/test/low_func_test.cpp @@ -2,6 +2,7 @@ #include #include #include "../src/low_func.hpp" +#include #define PUT(x) std::cout << #x "=" << (x) << std::endl; @@ -73,10 +74,12 @@ void testEdge(const mpz_class& p) CYBOZU_TEST_AUTO(limit) { + const size_t adj = 8 / sizeof(Unit); std::cout << std::hex; const char *tbl4[] = { "0000000000000001000000000000000000000000000000000000000000000085", // min prime "2523648240000001ba344d80000000086121000000000013a700000000000013", + "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001", // BLS12-381 r "7523648240000001ba344d80000000086121000000000013a700000000000017", "800000000000000000000000000000000000000000000000000000000000005f", "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", // secp256k1 @@ -90,7 +93,18 @@ CYBOZU_TEST_AUTO(limit) printf("p=%s\n", tbl4[i]); mpz_class p; mcl::gmp::setStr(p, tbl4[i], 16); - testEdge<4 * (8 / sizeof(Unit))>(p); + testEdge<4 * adj>(p); + } + const char *tbl6[] = { + "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab", // BLS12-381 p + "240026400f3d82b2e42de125b00158405b710818ac000007e0042f008e3e00000000001080046200000000000000000d", // BN381 r + "240026400f3d82b2e42de125b00158405b710818ac00000840046200950400000000001380052e000000000000000013", // BN381 p + }; + for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl6); i++) { + printf("p=%s\n", tbl6[i]); + mpz_class p; + mcl::gmp::setStr(p, tbl6[i], 16); + testEdge<6 * adj>(p); } } diff --git a/test/smallmodp_test.cpp b/test/smallmodp_test.cpp new file mode 100644 index 00000000..666302b5 --- /dev/null +++ b/test/smallmodp_test.cpp @@ -0,0 +1,42 @@ +#include +#include +#include +#include + +using namespace mcl; +using namespace mcl::bn; + +template +void test(const char *s) +{ + puts(s); + cybozu::XorShift rg; + F x, z1, z2; + for (size_t i = 0; i < 1000; i++) { + x.setByCSPRNG(rg); + uint32_t y = uint32_t(i); + z1 = x * y; + F::mulUnit(z2, x, y); + CYBOZU_TEST_EQUAL(z1, z2); + } +#ifdef NDEBUG + const int C = 100000; + for (uint32_t i = 1; i < 10; i++) { + printf("i=% 2d ", i); + CYBOZU_BENCH_C("mulUnit", C, F::mulUnit, x, x, i); + } + CYBOZU_BENCH_C("mulUnit [1-256]", C, F::mulUnit, x, x, (*x.getUnit() % 256) + 1); + CYBOZU_BENCH_C("mulUnit [1000-1255]", C, F::mulUnit, x, x, (*x.getUnit() % 256) + 1000); + CYBOZU_BENCH_C("mulUnit all", C, F::mulUnit, x, x, uint32_t(*x.getUnit())); + CYBOZU_BENCH_C("mul(F, u32)", C, F::mul, x, x, uint32_t(*x.getUnit())); + CYBOZU_BENCH_C("mul(F, F)", C, F::mul, x, x, x); + CYBOZU_TEST_ASSERT(x != 0); +#endif +} + +CYBOZU_TEST_AUTO(main) +{ + initPairing(BLS12_381); + test("Fr"); + test("Fp"); +}