From 8de2e360b428990013351fdd001e7bb6790fd4a4 Mon Sep 17 00:00:00 2001 From: MITSUNARI Shigeo Date: Thu, 20 Jun 2024 11:35:35 +0900 Subject: [PATCH] add mod and modT --- src/low_func.hpp | 33 +++++++++++++++++++++++++++------ test/low_func_test.cpp | 42 +++++++++++++++++++++++++++++------------- 2 files changed, 56 insertions(+), 19 deletions(-) diff --git a/src/low_func.hpp b/src/low_func.hpp index d7e29dfe..a20a3bfb 100644 --- a/src/low_func.hpp +++ b/src/low_func.hpp @@ -288,19 +288,22 @@ static void sqrModT(Unit *y, const Unit *x, const Unit *p) fpDblModT(y, xx, p); } -template struct SmallModP { + const size_t d = 16; // d = 26 if use double in approx const size_t maxE_ = d - 2; const Unit *p_; + const size_t n_; const size_t l_; uint32_t p0_; // p must not be temporary. - explicit SmallModP(const Unit *p) + explicit SmallModP(const Unit *p, size_t n) : p_(p) - , l_(getBitSize(p, N)) + , n_(n) + , l_(getBitSize(p, n)) { - Unit t[N+1] = {}; + Unit *t = (Unit*)CYBOZU_ALLOCA((n_+1)*sizeof(Unit)); + mcl::bint::clearN(t, n_+1); size_t pos = d + l_ - 1; { size_t q = pos / MCL_UNIT_BIT_SIZE; @@ -309,7 +312,7 @@ struct SmallModP { } // p0 = 2**(d+l-1)/p Unit q[2]; - mcl::bint::div(q, 2, t, N+1, p, N); + mcl::bint::div(q, 2, t, n_+1, p, n_); assert(q[1] == 0); p0_ = uint32_t(q[0]); } @@ -320,7 +323,7 @@ struct SmallModP { return Unit(t >> (2 * d + l_ - 1 - a)); } // x[xn] %= p - // the effective range of return value is [0, N) + // the effective range of return value is [0, n_) bool quot(Unit *pQ, const Unit *x, size_t xn) const { size_t a = getBitSize(x, xn); @@ -336,6 +339,24 @@ struct SmallModP { } // return false if x[0, xn) is large bool mod(Unit *x, size_t xn) const + { + assert(xn <= n_ + 1); + Unit Q; + if (!quot(&Q, x, xn)) return false; + if (Q == 0) return true; + Unit *t = (Unit*)CYBOZU_ALLOCA((n_+1)*sizeof(Unit)); + t[n_] = mcl::bint::mulUnitN(t, p_, Q, n_); + mcl::bint::subN(t, x, t, n_+1); + if (mcl::bint::cmpGeN(t, p_, n_)) { + mcl::bint::subN(x, t, p_, n_); + } else { + mcl::bint::copyN(x, t, n_); + } + return true; + } + template + // return false if x[0, xn) is large + bool modT(Unit *x, size_t xn) const { assert(xn <= N + 1); Unit Q; diff --git a/test/low_func_test.cpp b/test/low_func_test.cpp index ec837376..05a7e2df 100644 --- a/test/low_func_test.cpp +++ b/test/low_func_test.cpp @@ -73,21 +73,31 @@ void testEdge(const mpz_class& p) } template -void setAndMod(const mcl::fp::SmallModP& smp, Unit *x) +void setAndModT(const mcl::fp::SmallModP& smp, Unit *x, size_t xn) { - x[N] = x[0] & 0x3f; - if (!smp.mod(x, N+1)) { - puts("ERR"); + x[smp.n_] = x[0] & 0x3f; + if (!smp.modT(x, xn)) { + puts("ERR2"); + exit(1); + } +} + +void setAndMod(const mcl::fp::SmallModP& smp, Unit *x, size_t xn) +{ + x[smp.n_] = x[0] & 0x3f; + if (!smp.mod(x, xn)) { + puts("ERR1"); + exit(1); } } template void testSmallModP(const mpz_class& p) { - mcl::fp::SmallModP smp(p.getUnit()); + mcl::fp::SmallModP smp(p.getUnit(), N); cybozu::XorShift rg; mpz_class x; - for (size_t i = 0; i < 1000; i++) { + for (size_t i = 0; i < 10; i++) { x.setRand(p, rg); x += p; x *= int(rg.get32() % 128) + 1; @@ -99,19 +109,25 @@ void testSmallModP(const mpz_class& p) if (b) { CYBOZU_TEST_ASSERT(Q1 == Q2 || Q1 == Q2 + 1); } - Unit x2[N+1] = {}; - mcl::bint::copyN(x2, x.getUnit(), x.getUnitSize()); - b = smp.mod(x2, x.getUnitSize()); - mpz_class x3 = x % p; - CYBOZU_TEST_ASSERT(b); - CYBOZU_TEST_EQUAL_ARRAY(x2, x3.getUnit(), x3.getUnitSize()); + for (int mode = 0; mode < 2; mode++) { + Unit x2[N+1] = {}; + mcl::bint::copyN(x2, x.getUnit(), x.getUnitSize()); + switch (mode) { + case 0: b = smp.mod(x2, x.getUnitSize()); break; + case 1: b = smp.modT(x2, x.getUnitSize()); break; + } + mpz_class x3 = x % p; + CYBOZU_TEST_ASSERT(b); + CYBOZU_TEST_EQUAL_ARRAY(x2, x3.getUnit(), x3.getUnitSize()); + } } #ifdef NDEBUG { if ((smp.p_[N-1] >> (MCL_UNIT_BIT_SIZE - 8)) == 0) return; // top 8-bit must be not zero Unit x[N+1]; mcl::gmp::getArray(x, N+1, p); - CYBOZU_BENCH_C("mod", 1000, setAndMod, smp, x); + CYBOZU_BENCH_C("mod ", 1000, setAndMod, smp, x, N+1); + CYBOZU_BENCH_C("modT", 1000, setAndModT, smp, x, N+1); } #endif }