Use minimal perfect hashing for lookups #37

Merged: 8 commits, Apr 16, 2019
(The diff below shows changes from 5 of the 8 commits.)
181 changes: 126 additions & 55 deletions scripts/unicode.py
@@ -18,7 +18,7 @@
# Since this should not require frequent updates, we just store this
# out-of-line and check the unicode.rs file into git.
import collections
import requests
import urllib.request

UNICODE_VERSION = "9.0.0"
UCD_URL = "https://www.unicode.org/Public/%s/ucd/" % UNICODE_VERSION
@@ -68,9 +68,9 @@ def __init__(self):

def stats(name, table):
count = sum(len(v) for v in table.values())
print "%s: %d chars => %d decomposed chars" % (name, len(table), count)
print("%s: %d chars => %d decomposed chars" % (name, len(table), count))

print "Decomposition table stats:"
print("Decomposition table stats:")
stats("Canonical decomp", self.canon_decomp)
stats("Compatible decomp", self.compat_decomp)
stats("Canonical fully decomp", self.canon_fully_decomp)
@@ -79,8 +79,8 @@ def stats(name, table):
self.ss_leading, self.ss_trailing = self._compute_stream_safe_tables()

def _fetch(self, filename):
resp = requests.get(UCD_URL + filename)
return resp.text
resp = urllib.request.urlopen(UCD_URL + filename)
return resp.read().decode('utf-8')

def _load_unicode_data(self):
self.combining_classes = {}
@@ -234,7 +234,7 @@ def _decompose(char_int, compatible):
# need to store their overlap when they agree. When they don't agree,
# store the decomposition in the compatibility table since we'll check
# that first when normalizing to NFKD.
assert canon_fully_decomp <= compat_fully_decomp
assert set(canon_fully_decomp) <= set(compat_fully_decomp)

for ch in set(canon_fully_decomp) & set(compat_fully_decomp):
if canon_fully_decomp[ch] == compat_fully_decomp[ch]:
@@ -284,52 +284,69 @@ def _compute_stream_safe_tables(self):

return leading_nonstarters, trailing_nonstarters

hexify = lambda c: hex(c)[2:].upper().rjust(4, '0')
hexify = lambda c: '{:04X}'.format(c)

def gen_combining_class(combining_classes, out):
out.write("#[inline]\n")
(salt, keys) = minimal_perfect_hash(combining_classes)
out.write("pub fn canonical_combining_class(c: char) -> u8 {\n")
out.write(" match c {\n")

for char, combining_class in sorted(combining_classes.items()):
out.write(" '\u{%s}' => %s,\n" % (hexify(char), combining_class))

out.write(" _ => 0,\n")
out.write(" }\n")
out.write(" mph_lookup(c.into(), &[\n")
for s in salt:
out.write(" 0x{:x},\n".format(s))
out.write(" ],\n")
out.write(" &[\n")
for k in keys:
kv = int(combining_classes[k]) | (k << 8)
out.write(" 0x{:x},\n".format(kv))
out.write(" ],\n")
out.write(" u8_lookup_fk, u8_lookup_fv, 0)\n")
out.write("}\n")

def gen_composition_table(canon_comp, out):
out.write("#[inline]\n")
table = {}
for (c1, c2), c3 in canon_comp.items():
if c1 < 0x10000 and c2 < 0x10000:
table[(c1 << 16) | c2] = c3
(salt, keys) = minimal_perfect_hash(table)
out.write("pub fn composition_table(c1: char, c2: char) -> Option<char> {\n")
out.write(" match (c1, c2) {\n")
out.write(" if c1 < '\\u{10000}' && c2 < '\\u{10000}' {\n")
out.write(" mph_lookup((c1 as u32) << 16 | (c2 as u32), &[\n")
Member: Could the code outputting mph_lookup calls be factored out into a function?

for s in salt:
out.write(" 0x{:x},\n".format(s))
out.write(" ],\n")
out.write(" &[\n")
for k in keys:
out.write(" (0x%s, '\\u{%s}'),\n" % (hexify(k), hexify(table[k])))
out.write(" ],\n")
out.write(" pair_lookup_fk, pair_lookup_fv_opt, None)\n")
out.write(" } else {\n")
out.write(" match (c1, c2) {\n")

for (c1, c2), c3 in sorted(canon_comp.items()):
out.write(" ('\u{%s}', '\u{%s}') => Some('\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3)))
if c1 >= 0x10000 and c2 >= 0x10000:
out.write(" ('\\u{%s}', '\\u{%s}') => Some('\\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3)))

out.write(" _ => None,\n")
out.write(" _ => None,\n")
out.write(" }\n")
out.write(" }\n")
out.write("}\n")

def gen_decomposition_tables(canon_decomp, compat_decomp, out):
tables = [(canon_decomp, 'canonical'), (compat_decomp, 'compatibility')]
for table, name in tables:
out.write("#[inline]\n")
(salt, keys) = minimal_perfect_hash(table)
out.write("const {}_DECOMPOSED_KV: &[(u32, &'static [char])] = &[\n".format(name.upper()))
for char in keys:
d = ", ".join("'\\u{%s}'" % hexify(c) for c in table[char])
out.write(" (0x{:x}, &[{}]),\n".format(char, d))
out.write("];\n")
out.write("pub fn %s_fully_decomposed(c: char) -> Option<&'static [char]> {\n" % name)
# The "Some" constructor is around the match statement here, because
# putting it into the individual arms would make the item_bodies
# checking of rustc takes almost twice as long, and it's already pretty
# slow because of the huge number of match arms and the fact that there
# is a borrow inside each arm
out.write(" Some(match c {\n")

for char, chars in sorted(table.items()):
d = ", ".join("'\u{%s}'" % hexify(c) for c in chars)
out.write(" '\u{%s}' => &[%s],\n" % (hexify(char), d))

out.write(" _ => return None,\n")
out.write(" })\n")
out.write(" mph_lookup(c.into(), &[\n")
for s in salt:
out.write(" 0x{:x},\n".format(s))
out.write(" ],\n")
out.write(" {}_DECOMPOSED_KV,\n".format(name.upper()))
out.write(" pair_lookup_fk, pair_lookup_fv_opt, None)\n")
out.write("}\n")
out.write("\n")

def gen_qc_match(prop_table, out):
out.write(" match c {\n")
@@ -371,39 +388,45 @@ def gen_nfkd_qc(prop_tables, out):
out.write("}\n")

def gen_combining_mark(general_category_mark, out):
out.write("#[inline]\n")
(salt, keys) = minimal_perfect_hash(general_category_mark)
out.write("pub fn is_combining_mark(c: char) -> bool {\n")
out.write(" match c {\n")

for char in general_category_mark:
out.write(" '\u{%s}' => true,\n" % hexify(char))

out.write(" _ => false,\n")
out.write(" }\n")
out.write(" mph_lookup(c.into(), &[\n")
for s in salt:
out.write(" 0x{:x},\n".format(s))
out.write(" ],\n")
out.write(" &[\n")
for k in keys:
out.write(" 0x{:x},\n".format(k))
out.write(" ],\n")
out.write(" bool_lookup_fk, bool_lookup_fv, false)\n")
out.write("}\n")

def gen_stream_safe(leading, trailing, out):
# This could be done as a hash but the table is very small.
out.write("#[inline]\n")
out.write("pub fn stream_safe_leading_nonstarters(c: char) -> usize {\n")
out.write(" match c {\n")

for char, num_leading in leading.items():
out.write(" '\u{%s}' => %d,\n" % (hexify(char), num_leading))
for char, num_leading in sorted(leading.items()):
out.write(" '\\u{%s}' => %d,\n" % (hexify(char), num_leading))

out.write(" _ => 0,\n")
out.write(" }\n")
out.write("}\n")
out.write("\n")

out.write("#[inline]\n")
(salt, keys) = minimal_perfect_hash(trailing)
out.write("pub fn stream_safe_trailing_nonstarters(c: char) -> usize {\n")
out.write(" match c {\n")

for char, num_trailing in trailing.items():
out.write(" '\u{%s}' => %d,\n" % (hexify(char), num_trailing))

out.write(" _ => 0,\n")
out.write(" }\n")
out.write(" mph_lookup(c.into(), &[\n")
for s in salt:
out.write(" 0x{:x},\n".format(s))
out.write(" ],\n")
out.write(" &[\n")
for k in keys:
kv = int(trailing[k]) | (k << 8)
out.write(" 0x{:x},\n".format(kv))
out.write(" ],\n")
out.write(" u8_lookup_fk, u8_lookup_fv, 0) as usize\n")
out.write("}\n")

def gen_tests(tests, out):
@@ -419,7 +442,7 @@ def gen_tests(tests, out):
""")

out.write("pub const NORMALIZATION_TESTS: &[NormalizationTest] = &[\n")
str_literal = lambda s: '"%s"' % "".join("\u{%s}" % c for c in s)
str_literal = lambda s: '"%s"' % "".join("\\u{%s}" % c for c in s)

for test in tests:
out.write(" NormalizationTest {\n")
@@ -432,13 +455,61 @@ def gen_tests(tests, out):

out.write("];\n")

def my_hash(x, salt, n):
Member: probably should have a comment saying "guaranteed to be less than n"

# This hash is based on the theory that multiplication is efficient
mask_32 = 0xffffffff
y = ((x + salt) * 2654435769) & mask_32
y ^= (x * 0x31415926) & mask_32
return (y * n) >> 32
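On the reviewer's point above: the result is indeed guaranteed to be less than n, because y is masked (and XORed) down to 32 bits, so (y * n) >> 32 computes floor(y * n / 2**32) <= floor((2**32 - 1) * n / 2**32) = n - 1 for any 1 <= n < 2**32. A quick check of the extreme case:

# Even the largest 32-bit y maps to the last valid bucket index, never to n.
n = 1000
y_max = 0xffffffff
assert (y_max * n) >> 32 == n - 1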

# Compute a minimal perfect hash function; d can be either a dict or a list of keys.
def minimal_perfect_hash(d, singleton_buckets = False):
Member: I'd prefer if this function had more comments

n = len(d)
buckets = dict((h, []) for h in range(n))
for key in d:
h = my_hash(key, 0, n)
buckets[h].append(key)
bsorted = [(len(buckets[h]), h) for h in range(n)]
bsorted.sort(reverse = True)
claimed = [False] * n
salts = [0] * n
keys = [0] * n
for (bucket_size, h) in bsorted:
if bucket_size == 0:
break
elif singleton_buckets and bucket_size == 1:
Member: Do we use the singleton_buckets case at all?

Contributor (author): No, I can remove it, especially as it seems to perform worse in benchmarks. The main reason I left it in is that it's more robust; without it there's a much greater probability the hashing will fail.

for i in range(n):
if not claimed[i]:
salts[h] = -(i + 1)
claimed[i] = True
keys[i] = buckets[h][0]
break
else:
for salt in range(1, 32768):
rehashes = [my_hash(key, salt, n) for key in buckets[h]]
if all(not claimed[hash] for hash in rehashes):
Member: Is there a guarantee that we won't have a collision amongst the rehashes? Is it just really unlikely? (I suspect it's the latter but want to confirm.)

Contributor (author): Yes, if it finds a suitable salt, that comes with a guarantee that the rehash won't have a collision (this is what the claimed bool array keeps track of). On the other hand, it's possible that no salt can be found that satisfies that, but I believe that to be quite a low probability. There are things that can be done to make it more robust; I'll try to add a comment outlining that in case somebody does run into it with a data update.

Member: Oh, wait, the set check deals with this; I'd forgotten it was there 😄. To be clear, I was specifically worried about cases where a single run of rehashes has collisions, which claimed won't catch since we update it later. (Worth leaving a comment saying that.)

if len(set(rehashes)) < bucket_size:
continue
salts[h] = salt
for key in buckets[h]:
rehash = my_hash(key, salt, n)
claimed[rehash] = True
keys[rehash] = key
break
if salts[h] == 0:
print("minimal perfect hashing failed")
exit(1)
return (salts, keys)
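As a sanity check of the guarantees discussed in the thread above, the (salts, keys) tables this function returns can be round-tripped in Python with the same two-level lookup that mph_lookup in src/perfect_hash.rs performs: hash with salt 0 to pick a salt, hash again with that salt to pick a slot, then compare the stored key. A hypothetical self-test (mph_get is not part of the script; the combining classes below are standard Unicode values used only as sample data):

def mph_get(salts, keys, values, x, default=None):
    # Python mirror of mph_lookup: the first hash selects the salt, the
    # salted second hash selects the slot, and the stored key rejects misses.
    n = len(salts)
    slot = my_hash(x, salts[my_hash(x, 0, n)], n)
    return values[keys[slot]] if keys[slot] == x else default

sample = {0x0300: 230, 0x0301: 230, 0x0323: 220, 0x0328: 202}
salts, keys = minimal_perfect_hash(sample)
assert all(mph_get(salts, keys, sample, cp) == cc for cp, cc in sample.items())
assert mph_get(salts, keys, sample, 0x0041) is None  # 'A' is not in the table

If singleton_buckets were enabled, negative salts would instead encode the slot directly as -(index + 1), which this sketch does not handle.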

if __name__ == '__main__':
data = UnicodeData()
with open("tables.rs", "w") as out:
with open("tables.rs", "w", newline = "\n") as out:
out.write(PREAMBLE)
out.write("use quick_check::IsNormalized;\n")
out.write("use quick_check::IsNormalized::*;\n")
out.write("\n")
out.write("use perfect_hash::*;\n")
out.write("\n")

version = "(%s, %s, %s)" % tuple(UNICODE_VERSION.split("."))
out.write("#[allow(unused)]\n")
@@ -470,6 +541,6 @@ def gen_tests(tests, out):
gen_stream_safe(data.ss_leading, data.ss_trailing, out)
out.write("\n")

with open("normalization_tests.rs", "w") as out:
with open("normalization_tests.rs", "w", newline = "\n") as out:
out.write(PREAMBLE)
gen_tests(data.norm_tests, out)
1 change: 1 addition & 0 deletions src/lib.rs
@@ -66,6 +66,7 @@ use std::str::Chars;

mod decompose;
mod normalize;
mod perfect_hash;
mod recompose;
mod quick_check;
mod stream_safe;
78 changes: 78 additions & 0 deletions src/perfect_hash.rs
@@ -0,0 +1,78 @@
// Copyright 2019 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Support for lookups based on minimal perfect hashing.

// This function is based on multiplication being fast and is "good enough". Also
// it can share some work between the unsalted and salted versions.
#[inline]
fn my_hash(key: u32, salt: u32, n: usize) -> usize {
let y = key.wrapping_add(salt).wrapping_mul(2654435769);
let y = y ^ key.wrapping_mul(0x31415926);
(((y as u64) * (n as u64)) >> 32) as usize
}

/// Do a lookup using minimal perfect hashing.
///
/// The table is stored as a sequence of "salt" values, then a sequence of
/// values that contain packed key/value pairs. The strategy is to hash twice.
/// The first hash retrieves a salt value that makes the second hash unique.
/// The hash function doesn't have to be very good, just good enough that the
/// resulting map is unique.
#[inline]
pub(crate) fn mph_lookup<KV, V, FK, FV>(x: u32, salt: &[u16], kv: &[KV], fk: FK, fv: FV,
default: V) -> V
where KV: Copy, FK: Fn(KV) -> u32, FV: Fn(KV) -> V
{
let s = salt[my_hash(x, 0, salt.len())] as u32;
let key_val = kv[my_hash(x, s, salt.len())];
if x == fk(key_val) {
fv(key_val)
} else {
default
}
}

/// Extract the key in a 24 bit key and 8 bit value packed in a u32.
#[inline]
pub(crate) fn u8_lookup_fk(kv: u32) -> u32 {
kv >> 8
}

/// Extract the value in a 24 bit key and 8 bit value packed in a u32.
#[inline]
pub(crate) fn u8_lookup_fv(kv: u32) -> u8 {
(kv & 0xff) as u8
}

/// Extract the key for a boolean lookup.
#[inline]
pub(crate) fn bool_lookup_fk(kv: u32) -> u32 {
kv
}

/// Extract the value for a boolean lookup.
#[inline]
pub(crate) fn bool_lookup_fv(_kv: u32) -> bool {
true
}

/// Extract the key in a pair.
#[inline]
pub(crate) fn pair_lookup_fk<T>(kv: (u32, T)) -> u32 {
kv.0
}

/// Extract the value in a pair, returning an option.
#[inline]
pub(crate) fn pair_lookup_fv_opt<T>(kv: (u32, T)) -> Option<T> {
Some(kv.1)
}
