Skip to content

Commit

Permalink
Merge pull request #1327 from ampli/count-ovfl
Browse files Browse the repository at this point in the history
A fix for undetected count overflows
  • Loading branch information
linas authored Jul 16, 2022
2 parents 06f660d + 44e987e commit b4c303b
Showing 1 changed file with 76 additions and 12 deletions.
88 changes: 76 additions & 12 deletions link-grammar/parse/count.c
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,52 @@ static Count_bin do_count(int lineno, count_context_t *ctxt,
return r;
}

/*
* See do_parse() for the purpose of this function.
*
* The returned number of parses (called here "count") is a 32-bit
* integer. However, this count may sometimes be very big - much more than
* can be represented in 32-bits. In such a case it is just enough to know
* that such an "overflow" occurred. Internally, big counts are clamped to
* INT_MAX (2^31-1) - see parse_count_clamp() (we refer below to such
* values as "clamped"). If the top-level do_count() (the one that is
* called from do_parse()) returns this value, it means such an overflow
* has occurred.
*
* The function uses a 64-bit signed integer as a count accumulator - named
* "total". The maximum value it can hold is 2^63-1. If it becomes greater
* than INT_MAX, it is considered a count overflow. Care should be
* taken that this total itself does not overflow, else this detection
* mechanism would be rendered useless. To that end, each value from which
* this total is computed should be small enough so it would not overflow.
*
* The function has 4 code sections to calculate the count. Each of them,
* when entered, returns a value which is clamped (or doesn't need to be
* clamped). They are marked in the code with "Path 1a", "Path 1b",
* "Path 2", and "Path 3".
*
* Path 1a, Path 1b: If there is a possible linkage between the given
* words, return 1, else return 0. Here a count overflow cannot occur.
*
* Path 2: The total accumulates the results of the do_count() invocations
* that are done in a loop. The upper bound on the number of iterations is
* twice (outer loop) the maximum number of word disjuncts (inner loop).
* Assuming no more than 2^31 disjuncts per word, and considering that
* each value is a result of do_count() which is clamped, the total is
* less than (2*2^31)*(2^31-1), which is less than 2^63-1, and hence just
* needs to be clamped before returning.
*
* Path 3: The total is calculated as a sum of series of multiplications.
* To prevent its overflow, we ensure that each term (including the total
* itself) would not be greater than INT_MAX (2^31-1), so the result will
* not be more than (2^31-1)+((2^31-1)*(2^31-1)) which is less than
* 2^63-1. In this path, each multiplication term that may be greater than
* INT_MAX (leftcount and rightcount) is clamped before the
* multiplication, and the total is clamped after the multiplication.
* Multiplication terms that result from caching (or directly from
* do_count()) are already clamped.
*/

#define do_count do_count1
#else
#define TRACE_LABEL(l, do_count) (do_count)
Expand Down Expand Up @@ -968,6 +1014,8 @@ static Count_bin do_count(

unsigned int unparseable_len = rw-lw-1;

/* Path 1a. */

#if 1
/* This check is not necessary for correctness, as it is handled in
* the general case below. It looks like it should be slightly faster. */
Expand All @@ -982,12 +1030,15 @@ static Count_bin do_count(
}
#endif


/* The left and right connectors are null, but the two words are
* NOT next to each-other. */
if ((le == NULL) && (re == NULL))
{
int nopt_words = num_optional_words(ctxt, lw, rw);

/* Path 1b. */

if ((null_count == 0) ||
(!ctxt->islands_ok && (lw != -1) && (ctxt->sent->word[lw].d != NULL)))
{
Expand All @@ -1004,6 +1055,8 @@ static Count_bin do_count(
return table_store(ctxt, lw, rw, le, re, null_count, h, hist_zero());
}

/* Path 2. */

/* Here null_count != 0 and we allow islands (a set of words
* linked together but separate from the rest of the sentence).
* Because we don't know here if an optional word is just
Expand All @@ -1012,6 +1065,12 @@ static Count_bin do_count(
* rest of the sentence must contain one less null-word. Else
* the rest of the sentence still contains the required number
* of null words. */

/* total (w_Count_bin which is int64_t) cannot overflow in this
* loop since the number of disjuncts in the inner loop is
* surely < 2^31, the outer loop can be iterated at most twice,
* and do_count() may return at most 2^31-1. However, it may
* become > 2^31-1 and hence needs to be clamped after the loop. */
w = lw + 1;
for (int opt = 0; opt <= (int)ctxt->sent->word[w].optional; opt++)
{
Expand All @@ -1024,26 +1083,23 @@ static Count_bin do_count(
hist_accumv(&total, d->cost,
do_count(ctxt, w, rw, d->right, NULL, try_null_count-1));
}
if (parse_count_clamp(&total))
{
#if 0
printf("OVERFLOW 1\n");
#endif
}
}

hist_accumv(&total, 0.0,
do_count(ctxt, w, rw, NULL, NULL, try_null_count-1));
if (parse_count_clamp(&total))
{
}

if (parse_count_clamp(&total))
{
#if 0
printf("OVERFLOW 2\n");
printf("OVERFLOW 1\n");
#endif
}
}
return table_store(ctxt, lw, rw, le, re, null_count, h, total);
}

/* Path 3. */

/* The word range (lw, rw) gets split in all tentatively possible ways
* to LHS term and RHS term.
* There can be a total count > 0 only if one of the following
Expand Down Expand Up @@ -1130,7 +1186,6 @@ static Count_bin do_count(
Count_bin *l_cache = NULL;
Count_bin *r_cache = NULL;
unsigned int lcount_index = 0; /* Cached left count index */
#define S(c) (!c?"(nil)":connector_string(c))

if (ctxt->is_short)
{
Expand Down Expand Up @@ -1355,14 +1410,21 @@ static Count_bin do_count(

#define CACHE_COUNT(c, how_to_count, do_count) \
{ \
w_Count_bin count = (hist_total(&c) == NO_COUNT) ? \
Count_bin count = (hist_total(&c) == NO_COUNT) ? \
TRACE_LABEL(c, do_count) : c; \
how_to_count; \
}
/* If the pseudocounting above indicates one of the terms
* in the count multiplication is zero,
* we know that the true total is zero. So we don't
* bother counting the other term at all, in that case. */

/* To enable 31-bit overflow detection, total, leftcount and
* rightcount are signed 64-bit. Each multiplication term is either
* a direct do_count() result (already clamped), a clamped cached
* value, or is clamped below before it is used. total is
* initially 0 and is clamped at the end of each iteration.
* So the result will not be more than (2^31-1)+((2^31-1)*(2^31-1))
* which is less than 2^63-1. */
if (leftpcount &&
(!lcnt_optimize || rightpcount || (0 != hist_total(&l_bnr))))
{
Expand All @@ -1383,6 +1445,7 @@ static Count_bin do_count(

if (0 < hist_total(&leftcount))
{
parse_count_clamp(&leftcount); /* May be up to 4*2^31. */
lrcnt_found = true;
d->match_left = true;

Expand Down Expand Up @@ -1412,6 +1475,7 @@ static Count_bin do_count(

if (0 < hist_total(&rightcount))
{
parse_count_clamp(&rightcount); /* May be up to 4*INT_MAX. */
if (le == NULL)
{
lrcnt_found = true;
Expand Down

0 comments on commit b4c303b

Please sign in to comment.