From d7b458a774e2e5552502a39a08201fed646c9542 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?SZIGETI=20J=C3=A1nos?= Date: Wed, 31 Jan 2024 12:51:55 +0100 Subject: [PATCH] [ISSUE-0089]: safe mul / div for bigdecimal --- src/bigdecimal128.c | 62 +++++++++ src/bigdecimal128.h | 2 + tests/bigdecimal128_mul_test.c | 225 ++++++++++++++++++++------------- 3 files changed, 198 insertions(+), 91 deletions(-) diff --git a/src/bigdecimal128.c b/src/bigdecimal128.c index cc287f5..f26f8f8 100644 --- a/src/bigdecimal128.c +++ b/src/bigdecimal128.c @@ -315,6 +315,22 @@ BigDecimal128 bigdecimal128_mul(const BigDecimal128 *a, const BigDecimal128 *b) return (BigDecimal128){biguint128_mul(&a->val, &b->val), a->prec + b->prec}; } +buint_bool bigdecimal128_mul_safe(BigDecimal128 *dest, const BigDecimal128 *a, const BigDecimal128 *b) { + buint_bool inv[2]; + BigUInt128 par[] = {bigint128_abs(&a->val, &inv[0]),bigint128_abs(&b->val, &inv[1])}; + BigUIntPair128 dmul = biguint128_dmul(&par[0], &par[1]); + buint_bool valid = biguint128_eqz(&dmul.second); + if (!!inv[0] != !!inv[1]) { // result should be not positive + bigint128_negate_assign(&dmul.first); + valid &= bigint128_ltz(&dmul.first) || biguint128_eqz(&dmul.first); + } else { + valid &= !bigint128_ltz(&dmul.first); + } + dest->val = dmul.first; + dest->prec = a->prec + b->prec; + return valid; +} + BigDecimal128 bigdecimal128_div_fast(const BigDecimal128 *a, const BigDecimal128 *b, UInt prec) { UInt ac_prec = a->prec + b->prec + prec; BigDecimal128 ac = bigdecimal128_ctor_prec(a, ac_prec); @@ -342,6 +358,52 @@ BigDecimal128 bigdecimal128_div(const BigDecimal128 *a, const BigDecimal128 *b, return retv; } +buint_bool bigdecimal128_div_safe(BigDecimal128 *dest, const BigDecimal128 *a, const BigDecimal128 *b, UInt prec) { + static buint_bool first = 1; + static BigUInt128 max; + if (first) { + BigUInt128 minusone = bigint128_value_of_uint(-1); + BigUInt128 maxx = biguint128_div5(&minusone).first; + max = biguint128_shr(&maxx, 2); + first = 0; + } + buint_bool ainv, binv; + BigUInt128 av = bigint128_abs(&a->val, &ainv); + BigUInt128 bv = bigint128_abs(&b->val, &binv); + + // div.prec = a.prec - b.prec + // mul10 iterations: prec - div.prec = prec + b.prec - a.prec + // 10 / 0.3 = 33 -- a.prec=0, b.prec=1 + // 100.00 / 3 = 33(.33) -- a.prec=2, b.prec=0 + if (a->prec < prec + b->prec) { + UInt i = a->prec; + dest->val = biguint128_ctor_default(); + while (1) { + BigUIntPair128 qq = biguint128_div(&av, &bv); + biguint128_add_assign(&dest->val, &qq.first); + av = qq.second; + if (i == prec + b->prec) break; + ++i; + if (biguint128_lt(&max, &dest->val)) return 0; + dest->val = biguint128_mul10(&dest->val); + av = biguint128_mul10(&av); + } + for (UInt i = a->prec; i < prec + b->prec; ++i) { + + } + } else { + for (UInt i = prec + b->prec; i < a->prec; ++i) { + av = biguint128_div10(&av).first; + } + dest->val = biguint128_div(&av, &bv).first; + } + dest->prec = prec; + if (ainv != binv) { + bigint128_negate_assign(&dest->val); + } + return 1; +} + // I/O BigDecimal128 bigdecimal128_ctor_cstream(const char *dec_digits, buint_size_t len) { diff --git a/src/bigdecimal128.h b/src/bigdecimal128.h index 08618c2..872aa10 100644 --- a/src/bigdecimal128.h +++ b/src/bigdecimal128.h @@ -81,6 +81,7 @@ buint_bool bigdecimal128_sub_safe(BigDecimal128 *dest, const BigDecimal128 *a, c @brief Multiplication. */ BigDecimal128 bigdecimal128_mul(const BigDecimal128 *a, const BigDecimal128 *b); +buint_bool bigdecimal128_mul_safe(BigDecimal128 *dest, const BigDecimal128 *a, const BigDecimal128 *b); /** @brief Fast division algorithm with narrow operating range. @@ -98,6 +99,7 @@ BigDecimal128 bigdecimal128_div_fast(const BigDecimal128 *a, const BigDecimal128 TODO: insert if (remainder==0) break. */ BigDecimal128 bigdecimal128_div(const BigDecimal128 *a, const BigDecimal128 *b, UInt prec); +buint_bool bigdecimal128_div_safe(BigDecimal128 *dest, const BigDecimal128 *a, const BigDecimal128 *b, UInt prec); /** @brief 'Less than' relation between signed decimal numbers. diff --git a/tests/bigdecimal128_mul_test.c b/tests/bigdecimal128_mul_test.c index 2cad669..b2186ee 100644 --- a/tests/bigdecimal128_mul_test.c +++ b/tests/bigdecimal128_mul_test.c @@ -2,6 +2,7 @@ #include #include #include +#include #include "bigdecimal128.h" #include "test_common.h" @@ -11,131 +12,173 @@ #define BUFLEN (BIGDECCAP + 3) // INTERNAL TYPES +typedef BigDecimal128(*BigDecimalDivFun)(const BigDecimal128 *a, const BigDecimal128 *b, UInt prec); +typedef BigDecimal128(*BigDecimalBinaryFun)(const BigDecimal128 *a, const BigDecimal128 *b); -typedef struct { - CStr num1; - CStr num2; - UInt resprec; -} TestInputType; - -typedef struct { - CStr num; -} TestOutputType; - -typedef BigDecimal128 (*BigDecimalFun)(const BigDecimal128 *a, const BigDecimal128 *b, UInt prec); +typedef union { + BigDecimalDivFun d; + BigDecimalBinaryFun m; +} BigDecimalXFun; // TEST DATA -const TestInputType any_in[] = { - {STR("20"),STR("0.5"),0}, - {STR("0.2"),STR("0.6"),3}, - {STR("10"),STR("30"),5}, - {STR("-1.5"),STR("1.5"),3}, - {STR("2.0"),STR("-0.5"),1}, - {STR("-10"),STR("0.001"),2}, - {STR("0.000"),STR("10.00"),4} +const CStr any_in[][3] = { + {STR("20"),STR("0.5"),STR("0")}, + {STR("0.2"),STR("0.6"),STR("3")}, + {STR("10"),STR("30"),STR("5")}, + {STR("-1.5"),STR("1.5"),STR("3")}, + {STR("2.0"),STR("-0.5"),STR("1")}, + {STR("-10"),STR("0.001"),STR("2")}, + {STR("100.00"),STR("3"),STR("0")}, + {STR("-200.00"),STR("7"),STR("1")}, + {STR("200.0"),STR("-9.00"),STR("0")}, + {STR("-300"),STR("-7.00"),STR("3")}, + {STR("0.000"),STR("10.00"),STR("4")} }; -const TestOutputType mul_out[] = { - {STR("10.0")}, - {STR("0.12")}, - {STR("300")}, - {STR("-2.25")}, - {STR("-1.00")}, - {STR("-0.010")}, - {STR("0.00000")} +const CStr mul_out[] = { + STR("10.0"), + STR("0.12"), + STR("300"), + STR("-2.25"), + STR("-1.00"), + STR("-0.010"), + STR("300.00"), + STR("-1400.00"), + STR("-1800.000"), + STR("2100.00"), + STR("0.00000") }; -const TestOutputType div_out[] = { - {STR("40")}, - {STR("0.333")}, - {STR("0.33333")}, - {STR("-1.000")}, - {STR("-4.0")}, - {STR("-10000.00")}, - {STR("0.0000")} +const CStr div_out[] = { + STR("40"), + STR("0.333"), + STR("0.33333"), + STR("-1.000"), + STR("-4.0"), + STR("-10000.00"), + STR("33"), + STR("-28.5"), + STR("-22"), + STR("42.857"), + STR("0.0000") }; -int input_len = sizeof(any_in) / sizeof(any_in[0]); - // INTERNAL FUNCTIONS -/** - * @return true: test failed, false: test passed. - */ -static inline bool eval_divtestcase_(const TestInputType *tin, const TestOutputType *expected, BigDecimalFun fun, char *funstr) { + +static inline bool eval_xtestcase_(const CStr *tin, const CStr *expected, bool is_mul, BigDecimalXFun fun, const char *funstr) { + bool pass = true; char buffer[BUFLEN + 1]; - if (BIGDECCAP < tin->num1.len || BIGDECCAP < tin->num2.len) + if (BIGDECCAP < tin[0].len || BIGDECCAP < tin[1].len) return false; - BigDecimal128 a = bigdecimal128_ctor_cstream(tin->num1.str, tin->num1.len); - BigDecimal128 b = bigdecimal128_ctor_cstream(tin->num2.str, tin->num2.len); + BigDecimal128 a = bigdecimal128_ctor_cstream(tin[0].str, tin[0].len); + BigDecimal128 b = bigdecimal128_ctor_cstream(tin[1].str, tin[1].len); + UInt p = (UInt) atoi(tin[2].str); - BigDecimal128 quotient = fun(&a, &b, tin->resprec); - buint_size_t len = bigdecimal128_print("ient, buffer, sizeof(buffer) / sizeof(char) - 1); + BigDecimal128 result = is_mul ? fun.m(&a, &b) : fun.d(&a, &b, p); + buint_size_t len = bigdecimal128_print(&result, buffer, sizeof (buffer) / sizeof (char) - 1); buffer[len] = 0; - int result = strcmp(expected->num.str, buffer); + int res_cmp = strcmp(expected->str, buffer); - if (result != 0) { - fprintf(stderr, "input: %s(%s, %s , %"PRIuint"); expected output: [%s], actual [%s]\n", - funstr, tin->num1.str, tin->num2.str, tin->resprec, expected->num.str, buffer); - return true; + if (res_cmp != 0) { + if (is_mul) { + fprintf(stderr, "input: %s(%s, %s); expected output: [%s], actual [%s]\n", + funstr, tin[0].str, tin[1].str, expected->str, buffer); + } else { + fprintf(stderr, "input: %s(%s, %s , %s); expected output: [%s], actual [%s]\n", + funstr, tin[0].str, tin[1].str, tin[2].str, expected->str, buffer); + } + pass = false; } - return false; + return pass; } -// TEST FUNCTIONS -bool test_mul0() { - bool fail = false; - char buffer[BUFLEN + 1]; - for (int i = 0; i < input_len; ++i) { - const TestInputType *ti = &any_in[i]; - const TestOutputType *expected = &mul_out[i]; - if (BIGDECCAP < ti->num1.len || BIGDECCAP < ti->num2.len) - continue; - - BigDecimal128 a = bigdecimal128_ctor_cstream(ti->num1.str, ti->num1.len); - BigDecimal128 b = bigdecimal128_ctor_cstream(ti->num2.str, ti->num2.len); - - BigDecimal128 prod = bigdecimal128_mul(&a, &b); - - buint_size_t len = bigdecimal128_print(&prod, buffer, sizeof(buffer) / sizeof(char) - 1); - buffer[len] = 0; +// WRAPPER FUNCTIONS +BigDecimal128 div_safe_(const BigDecimal128 *a, const BigDecimal128 *b, UInt prec) { + BigDecimal128 dest; + bigdecimal128_div_safe(&dest, a, b, prec); + return dest; +} - int result = strcmp(expected->num.str, buffer); +BigDecimal128 mul_safe_(const BigDecimal128 *a, const BigDecimal128 *b) { + BigDecimal128 dest; + bigdecimal128_mul_safe(&dest, a, b); + return dest; +} - if (result != 0) { - fprintf(stderr, "input: (%s * %s); expected output: [%s], actual [%s]\n", - ti->num1.str, ti->num2.str, expected->num.str, buffer); - fail = true; - } +bool test_x_regular(BigDecimalXFun fun, bool is_mul, const char *funname) { + bool pass = true; + for (unsigned int i = 0; i < ARRAYSIZE(any_in); ++i) { + pass &=eval_xtestcase_(&any_in[i][0], &(is_mul?mul_out :div_out)[i], is_mul, fun, funname); } - return !fail; + return pass; } -bool test_div0() { - bool fail = false; - for (int i = 0; i < input_len; ++i) { - bool testfailed = eval_divtestcase_(&any_in[i], &div_out[i], bigdecimal128_div, "div"); - fail|=testfailed; +bool test_div_oor_prec(UInt prec_a, UInt prec_b) { + bool pass = true; + + BigDecimal128 a = {biguint128_value_of_uint(1), prec_a}; + BigDecimal128 b = {biguint128_value_of_uint(3), prec_b}; + BigDecimal128 q; + for (UInt pi = 0; pi < 128; ++pi) { + buint_bool exp_res; + if (pi + prec_b < prec_a + (128 * 3) / 10 + 128/100) exp_res = 1; + else if (prec_a + (128 * 3) / 10 + 128/96 + 1 < pi + prec_b) exp_res = 0; + else continue; + buint_bool res = bigdecimal128_div_safe(&q, &a, &b, pi); + if (!!exp_res != !!res) { + fprintf(stderr, "div_safe failed at prec %u\n", (unsigned int) pi); + pass = false; + } } - return !fail; + return pass; } -bool test_div_fast0() { - bool fail = false; - for (int i = 0; i < input_len; ++i) { - bool testfailed = eval_divtestcase_(&any_in[i], &div_out[i], bigdecimal128_div_fast, "div"); - fail|=testfailed; +bool test_mul_oor_prec(unsigned int zeroes) { + bool pass = true; + + BigDecimal128 a = {biguint128_ctor_default(), 0}; + biguint128_sbit(&a.val, 128 - 1 - zeroes); + BigDecimal128 p_ab; // product a*b + BigDecimal128 p_ac; // product a*c + + for (UInt bi = 0; bi < 128 - 1; ++bi) { + BigDecimal128 b = {biguint128_ctor_default(), 0}; + biguint128_sbit(&b.val, bi); + BigDecimal128 c = {bigint128_negate(&b.val), 0}; + buint_bool exp_res_ab = (bi < zeroes); + buint_bool exp_res_ac = (bi < zeroes + 1); + buint_bool res_ab = bigdecimal128_mul_safe(&p_ab, &a, &b); + buint_bool res_ac = bigdecimal128_mul_safe(&p_ac, &a, &c); + + if (!!exp_res_ab != !!res_ab) { + fprintf(stderr, "positive mul_safe failed at bit %u\n", (unsigned int) bi); + pass = false; + } + if (!!exp_res_ac != !!res_ac) { + fprintf(stderr, "negative mul_safe failed at bit %u\n", (unsigned int) bi); + pass = false; + } } - return !fail; + return pass; } int main(int argc, char **argv) { + if (1 < argc) { + fprintf(stderr,"%s does not require any arguments\n",argv[0]); + } - assert(test_mul0()); - assert(test_div0()); - assert(test_div_fast0()); + assert(test_x_regular((BigDecimalXFun){.m=bigdecimal128_mul}, true, "mul")); + assert(test_x_regular((BigDecimalXFun){.m=mul_safe_}, true, "mul_safe")); + assert(test_x_regular((BigDecimalXFun){.d=bigdecimal128_div}, false, "div")); + assert(test_x_regular((BigDecimalXFun){.d=bigdecimal128_div_fast}, false, "div_fast")); + assert(test_x_regular((BigDecimalXFun){.d=div_safe_}, false, "div_safe")); + assert(test_div_oor_prec(0, 0)); + assert(test_div_oor_prec(2, 0)); + assert(test_div_oor_prec(0, 2)); + assert(test_mul_oor_prec(4)); return 0; }