Skip to content

Commit

Permalink
[ISSUE-0089]: safe mul / div for bigdecimal
Browse files Browse the repository at this point in the history
  • Loading branch information
SzigetiJ committed Jan 31, 2024
1 parent 3930f5c commit d7b458a
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 91 deletions.
62 changes: 62 additions & 0 deletions src/bigdecimal128.c
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,22 @@ BigDecimal128 bigdecimal128_mul(const BigDecimal128 *a, const BigDecimal128 *b)
return (BigDecimal128){biguint128_mul(&a->val, &b->val), a->prec + b->prec};
}

buint_bool bigdecimal128_mul_safe(BigDecimal128 *dest, const BigDecimal128 *a, const BigDecimal128 *b) {
buint_bool inv[2];
BigUInt128 par[] = {bigint128_abs(&a->val, &inv[0]),bigint128_abs(&b->val, &inv[1])};
BigUIntPair128 dmul = biguint128_dmul(&par[0], &par[1]);
buint_bool valid = biguint128_eqz(&dmul.second);
if (!!inv[0] != !!inv[1]) { // result should be not positive
bigint128_negate_assign(&dmul.first);
valid &= bigint128_ltz(&dmul.first) || biguint128_eqz(&dmul.first);
} else {
valid &= !bigint128_ltz(&dmul.first);
}
dest->val = dmul.first;
dest->prec = a->prec + b->prec;
return valid;
}

BigDecimal128 bigdecimal128_div_fast(const BigDecimal128 *a, const BigDecimal128 *b, UInt prec) {
UInt ac_prec = a->prec + b->prec + prec;
BigDecimal128 ac = bigdecimal128_ctor_prec(a, ac_prec);
Expand Down Expand Up @@ -342,6 +358,52 @@ BigDecimal128 bigdecimal128_div(const BigDecimal128 *a, const BigDecimal128 *b,
return retv;
}

buint_bool bigdecimal128_div_safe(BigDecimal128 *dest, const BigDecimal128 *a, const BigDecimal128 *b, UInt prec) {
static buint_bool first = 1;
static BigUInt128 max;
if (first) {
BigUInt128 minusone = bigint128_value_of_uint(-1);
BigUInt128 maxx = biguint128_div5(&minusone).first;
max = biguint128_shr(&maxx, 2);
first = 0;
}
buint_bool ainv, binv;
BigUInt128 av = bigint128_abs(&a->val, &ainv);
BigUInt128 bv = bigint128_abs(&b->val, &binv);

// div.prec = a.prec - b.prec
// mul10 iterations: prec - div.prec = prec + b.prec - a.prec
// 10 / 0.3 = 33 -- a.prec=0, b.prec=1
// 100.00 / 3 = 33(.33) -- a.prec=2, b.prec=0
if (a->prec < prec + b->prec) {
UInt i = a->prec;
dest->val = biguint128_ctor_default();
while (1) {
BigUIntPair128 qq = biguint128_div(&av, &bv);
biguint128_add_assign(&dest->val, &qq.first);
av = qq.second;
if (i == prec + b->prec) break;
++i;
if (biguint128_lt(&max, &dest->val)) return 0;
dest->val = biguint128_mul10(&dest->val);
av = biguint128_mul10(&av);
}
for (UInt i = a->prec; i < prec + b->prec; ++i) {

}
} else {
for (UInt i = prec + b->prec; i < a->prec; ++i) {
av = biguint128_div10(&av).first;
}
dest->val = biguint128_div(&av, &bv).first;
}
dest->prec = prec;
if (ainv != binv) {
bigint128_negate_assign(&dest->val);
}
return 1;
}

// I/O

BigDecimal128 bigdecimal128_ctor_cstream(const char *dec_digits, buint_size_t len) {
Expand Down
2 changes: 2 additions & 0 deletions src/bigdecimal128.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ buint_bool bigdecimal128_sub_safe(BigDecimal128 *dest, const BigDecimal128 *a, c
@brief Multiplication.
*/
BigDecimal128 bigdecimal128_mul(const BigDecimal128 *a, const BigDecimal128 *b);
buint_bool bigdecimal128_mul_safe(BigDecimal128 *dest, const BigDecimal128 *a, const BigDecimal128 *b);

/**
@brief Fast division algorithm with narrow operating range.
Expand All @@ -98,6 +99,7 @@ BigDecimal128 bigdecimal128_div_fast(const BigDecimal128 *a, const BigDecimal128
TODO: insert if (remainder==0) break.
*/
BigDecimal128 bigdecimal128_div(const BigDecimal128 *a, const BigDecimal128 *b, UInt prec);
buint_bool bigdecimal128_div_safe(BigDecimal128 *dest, const BigDecimal128 *a, const BigDecimal128 *b, UInt prec);

/**
@brief 'Less than' relation between signed decimal numbers.
Expand Down
225 changes: 134 additions & 91 deletions tests/bigdecimal128_mul_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <stdbool.h>
#include <string.h>
#include <assert.h>
#include <stdlib.h>

#include "bigdecimal128.h"
#include "test_common.h"
Expand All @@ -11,131 +12,173 @@
#define BUFLEN (BIGDECCAP + 3)

// INTERNAL TYPES
typedef BigDecimal128(*BigDecimalDivFun)(const BigDecimal128 *a, const BigDecimal128 *b, UInt prec);
typedef BigDecimal128(*BigDecimalBinaryFun)(const BigDecimal128 *a, const BigDecimal128 *b);

typedef struct {
CStr num1;
CStr num2;
UInt resprec;
} TestInputType;

typedef struct {
CStr num;
} TestOutputType;

typedef BigDecimal128 (*BigDecimalFun)(const BigDecimal128 *a, const BigDecimal128 *b, UInt prec);
typedef union {
BigDecimalDivFun d;
BigDecimalBinaryFun m;
} BigDecimalXFun;

// TEST DATA
const TestInputType any_in[] = {
{STR("20"),STR("0.5"),0},
{STR("0.2"),STR("0.6"),3},
{STR("10"),STR("30"),5},
{STR("-1.5"),STR("1.5"),3},
{STR("2.0"),STR("-0.5"),1},
{STR("-10"),STR("0.001"),2},
{STR("0.000"),STR("10.00"),4}
const CStr any_in[][3] = {
{STR("20"),STR("0.5"),STR("0")},
{STR("0.2"),STR("0.6"),STR("3")},
{STR("10"),STR("30"),STR("5")},
{STR("-1.5"),STR("1.5"),STR("3")},
{STR("2.0"),STR("-0.5"),STR("1")},
{STR("-10"),STR("0.001"),STR("2")},
{STR("100.00"),STR("3"),STR("0")},
{STR("-200.00"),STR("7"),STR("1")},
{STR("200.0"),STR("-9.00"),STR("0")},
{STR("-300"),STR("-7.00"),STR("3")},
{STR("0.000"),STR("10.00"),STR("4")}
};

const TestOutputType mul_out[] = {
{STR("10.0")},
{STR("0.12")},
{STR("300")},
{STR("-2.25")},
{STR("-1.00")},
{STR("-0.010")},
{STR("0.00000")}
const CStr mul_out[] = {
STR("10.0"),
STR("0.12"),
STR("300"),
STR("-2.25"),
STR("-1.00"),
STR("-0.010"),
STR("300.00"),
STR("-1400.00"),
STR("-1800.000"),
STR("2100.00"),
STR("0.00000")
};

const TestOutputType div_out[] = {
{STR("40")},
{STR("0.333")},
{STR("0.33333")},
{STR("-1.000")},
{STR("-4.0")},
{STR("-10000.00")},
{STR("0.0000")}
const CStr div_out[] = {
STR("40"),
STR("0.333"),
STR("0.33333"),
STR("-1.000"),
STR("-4.0"),
STR("-10000.00"),
STR("33"),
STR("-28.5"),
STR("-22"),
STR("42.857"),
STR("0.0000")
};

int input_len = sizeof(any_in) / sizeof(any_in[0]);

// INTERNAL FUNCTIONS
/**
* @return true: test failed, false: test passed.
*/
static inline bool eval_divtestcase_(const TestInputType *tin, const TestOutputType *expected, BigDecimalFun fun, char *funstr) {

static inline bool eval_xtestcase_(const CStr *tin, const CStr *expected, bool is_mul, BigDecimalXFun fun, const char *funstr) {
bool pass = true;
char buffer[BUFLEN + 1];
if (BIGDECCAP < tin->num1.len || BIGDECCAP < tin->num2.len)
if (BIGDECCAP < tin[0].len || BIGDECCAP < tin[1].len)
return false;

BigDecimal128 a = bigdecimal128_ctor_cstream(tin->num1.str, tin->num1.len);
BigDecimal128 b = bigdecimal128_ctor_cstream(tin->num2.str, tin->num2.len);
BigDecimal128 a = bigdecimal128_ctor_cstream(tin[0].str, tin[0].len);
BigDecimal128 b = bigdecimal128_ctor_cstream(tin[1].str, tin[1].len);
UInt p = (UInt) atoi(tin[2].str);

BigDecimal128 quotient = fun(&a, &b, tin->resprec);
buint_size_t len = bigdecimal128_print(&quotient, buffer, sizeof(buffer) / sizeof(char) - 1);
BigDecimal128 result = is_mul ? fun.m(&a, &b) : fun.d(&a, &b, p);
buint_size_t len = bigdecimal128_print(&result, buffer, sizeof (buffer) / sizeof (char) - 1);
buffer[len] = 0;

int result = strcmp(expected->num.str, buffer);
int res_cmp = strcmp(expected->str, buffer);

if (result != 0) {
fprintf(stderr, "input: %s(%s, %s , %"PRIuint"); expected output: [%s], actual [%s]\n",
funstr, tin->num1.str, tin->num2.str, tin->resprec, expected->num.str, buffer);
return true;
if (res_cmp != 0) {
if (is_mul) {
fprintf(stderr, "input: %s(%s, %s); expected output: [%s], actual [%s]\n",
funstr, tin[0].str, tin[1].str, expected->str, buffer);
} else {
fprintf(stderr, "input: %s(%s, %s , %s); expected output: [%s], actual [%s]\n",
funstr, tin[0].str, tin[1].str, tin[2].str, expected->str, buffer);
}
pass = false;
}
return false;
return pass;
}

// TEST FUNCTIONS
bool test_mul0() {
bool fail = false;
char buffer[BUFLEN + 1];
for (int i = 0; i < input_len; ++i) {
const TestInputType *ti = &any_in[i];
const TestOutputType *expected = &mul_out[i];
if (BIGDECCAP < ti->num1.len || BIGDECCAP < ti->num2.len)
continue;

BigDecimal128 a = bigdecimal128_ctor_cstream(ti->num1.str, ti->num1.len);
BigDecimal128 b = bigdecimal128_ctor_cstream(ti->num2.str, ti->num2.len);

BigDecimal128 prod = bigdecimal128_mul(&a, &b);

buint_size_t len = bigdecimal128_print(&prod, buffer, sizeof(buffer) / sizeof(char) - 1);
buffer[len] = 0;
// WRAPPER FUNCTIONS
BigDecimal128 div_safe_(const BigDecimal128 *a, const BigDecimal128 *b, UInt prec) {
BigDecimal128 dest;
bigdecimal128_div_safe(&dest, a, b, prec);
return dest;
}

int result = strcmp(expected->num.str, buffer);
BigDecimal128 mul_safe_(const BigDecimal128 *a, const BigDecimal128 *b) {
BigDecimal128 dest;
bigdecimal128_mul_safe(&dest, a, b);
return dest;
}

if (result != 0) {
fprintf(stderr, "input: (%s * %s); expected output: [%s], actual [%s]\n",
ti->num1.str, ti->num2.str, expected->num.str, buffer);
fail = true;
}
bool test_x_regular(BigDecimalXFun fun, bool is_mul, const char *funname) {
bool pass = true;
for (unsigned int i = 0; i < ARRAYSIZE(any_in); ++i) {
pass &=eval_xtestcase_(&any_in[i][0], &(is_mul?mul_out :div_out)[i], is_mul, fun, funname);
}
return !fail;
return pass;
}

bool test_div0() {
bool fail = false;
for (int i = 0; i < input_len; ++i) {
bool testfailed = eval_divtestcase_(&any_in[i], &div_out[i], bigdecimal128_div, "div");
fail|=testfailed;
bool test_div_oor_prec(UInt prec_a, UInt prec_b) {
bool pass = true;

BigDecimal128 a = {biguint128_value_of_uint(1), prec_a};
BigDecimal128 b = {biguint128_value_of_uint(3), prec_b};
BigDecimal128 q;
for (UInt pi = 0; pi < 128; ++pi) {
buint_bool exp_res;
if (pi + prec_b < prec_a + (128 * 3) / 10 + 128/100) exp_res = 1;
else if (prec_a + (128 * 3) / 10 + 128/96 + 1 < pi + prec_b) exp_res = 0;
else continue;
buint_bool res = bigdecimal128_div_safe(&q, &a, &b, pi);
if (!!exp_res != !!res) {
fprintf(stderr, "div_safe failed at prec %u\n", (unsigned int) pi);
pass = false;
}
}
return !fail;
return pass;
}

bool test_div_fast0() {
bool fail = false;
for (int i = 0; i < input_len; ++i) {
bool testfailed = eval_divtestcase_(&any_in[i], &div_out[i], bigdecimal128_div_fast, "div");
fail|=testfailed;
bool test_mul_oor_prec(unsigned int zeroes) {
bool pass = true;

BigDecimal128 a = {biguint128_ctor_default(), 0};
biguint128_sbit(&a.val, 128 - 1 - zeroes);
BigDecimal128 p_ab; // product a*b
BigDecimal128 p_ac; // product a*c

for (UInt bi = 0; bi < 128 - 1; ++bi) {
BigDecimal128 b = {biguint128_ctor_default(), 0};
biguint128_sbit(&b.val, bi);
BigDecimal128 c = {bigint128_negate(&b.val), 0};
buint_bool exp_res_ab = (bi < zeroes);
buint_bool exp_res_ac = (bi < zeroes + 1);
buint_bool res_ab = bigdecimal128_mul_safe(&p_ab, &a, &b);
buint_bool res_ac = bigdecimal128_mul_safe(&p_ac, &a, &c);

if (!!exp_res_ab != !!res_ab) {
fprintf(stderr, "positive mul_safe failed at bit %u\n", (unsigned int) bi);
pass = false;
}
if (!!exp_res_ac != !!res_ac) {
fprintf(stderr, "negative mul_safe failed at bit %u\n", (unsigned int) bi);
pass = false;
}
}
return !fail;
return pass;
}

int main(int argc, char **argv) {
if (1 < argc) {
fprintf(stderr,"%s does not require any arguments\n",argv[0]);
}

assert(test_mul0());
assert(test_div0());
assert(test_div_fast0());
assert(test_x_regular((BigDecimalXFun){.m=bigdecimal128_mul}, true, "mul"));
assert(test_x_regular((BigDecimalXFun){.m=mul_safe_}, true, "mul_safe"));
assert(test_x_regular((BigDecimalXFun){.d=bigdecimal128_div}, false, "div"));
assert(test_x_regular((BigDecimalXFun){.d=bigdecimal128_div_fast}, false, "div_fast"));
assert(test_x_regular((BigDecimalXFun){.d=div_safe_}, false, "div_safe"));

assert(test_div_oor_prec(0, 0));
assert(test_div_oor_prec(2, 0));
assert(test_div_oor_prec(0, 2));
assert(test_mul_oor_prec(4));
return 0;
}

0 comments on commit d7b458a

Please sign in to comment.