Use minimal perfect hashing for lookups #37
@@ -18,7 +18,7 @@
 # Since this should not require frequent updates, we just store this
 # out-of-line and check the unicode.rs file into git.
 import collections
-import requests
+import urllib.request

 UNICODE_VERSION = "9.0.0"
 UCD_URL = "https://www.unicode.org/Public/%s/ucd/" % UNICODE_VERSION
@@ -68,9 +68,9 @@ def __init__(self):

         def stats(name, table):
             count = sum(len(v) for v in table.values())
-            print "%s: %d chars => %d decomposed chars" % (name, len(table), count)
+            print("%s: %d chars => %d decomposed chars" % (name, len(table), count))

-        print "Decomposition table stats:"
+        print("Decomposition table stats:")
         stats("Canonical decomp", self.canon_decomp)
         stats("Compatible decomp", self.compat_decomp)
         stats("Canonical fully decomp", self.canon_fully_decomp)
@@ -79,8 +79,8 @@ def stats(name, table):
         self.ss_leading, self.ss_trailing = self._compute_stream_safe_tables()

     def _fetch(self, filename):
-        resp = requests.get(UCD_URL + filename)
-        return resp.text
+        resp = urllib.request.urlopen(UCD_URL + filename)
+        return resp.read().decode('utf-8')

     def _load_unicode_data(self):
         self.combining_classes = {}
@@ -234,7 +234,7 @@ def _decompose(char_int, compatible):
         # need to store their overlap when they agree. When they don't agree,
         # store the decomposition in the compatibility table since we'll check
         # that first when normalizing to NFKD.
-        assert canon_fully_decomp <= compat_fully_decomp
+        assert set(canon_fully_decomp) <= set(compat_fully_decomp)

         for ch in set(canon_fully_decomp) & set(compat_fully_decomp):
             if canon_fully_decomp[ch] == compat_fully_decomp[ch]:
@@ -284,52 +284,69 @@ def _compute_stream_safe_tables(self):

         return leading_nonstarters, trailing_nonstarters

-hexify = lambda c: hex(c)[2:].upper().rjust(4, '0')
+hexify = lambda c: '{:04X}'.format(c)

 def gen_combining_class(combining_classes, out):
     out.write("#[inline]\n")
+    (salt, keys) = minimal_perfect_hash(combining_classes)
     out.write("pub fn canonical_combining_class(c: char) -> u8 {\n")
-    out.write("    match c {\n")
-
-    for char, combining_class in sorted(combining_classes.items()):
-        out.write("        '\u{%s}' => %s,\n" % (hexify(char), combining_class))
-
-    out.write("        _ => 0,\n")
-    out.write("    }\n")
+    out.write("    mph_lookup(c.into(), &[\n")
+    for s in salt:
+        out.write("        0x{:x},\n".format(s))
+    out.write("    ],\n")
+    out.write("    &[\n")
+    for k in keys:
+        kv = int(combining_classes[k]) | (k << 8)
+        out.write("        0x{:x},\n".format(kv))
+    out.write("    ],\n")
+    out.write("    u8_lookup_fk, u8_lookup_fv, 0)\n")
     out.write("}\n")

 def gen_composition_table(canon_comp, out):
     out.write("#[inline]\n")
+    table = {}
+    for (c1, c2), c3 in canon_comp.items():
+        if c1 < 0x10000 and c2 < 0x10000:
+            table[(c1 << 16) | c2] = c3
+    (salt, keys) = minimal_perfect_hash(table)
     out.write("pub fn composition_table(c1: char, c2: char) -> Option<char> {\n")
-    out.write("    match (c1, c2) {\n")
+    out.write("    if c1 < '\\u{10000}' && c2 < '\\u{10000}' {\n")
+    out.write("        mph_lookup((c1 as u32) << 16 | (c2 as u32), &[\n")
+    for s in salt:
+        out.write("            0x{:x},\n".format(s))
+    out.write("        ],\n")
+    out.write("        &[\n")
+    for k in keys:
+        out.write("            (0x%s, '\\u{%s}'),\n" % (hexify(k), hexify(table[k])))
+    out.write("        ],\n")
+    out.write("        pair_lookup_fk, pair_lookup_fv_opt, None)\n")
+    out.write("    } else {\n")
+    out.write("        match (c1, c2) {\n")

     for (c1, c2), c3 in sorted(canon_comp.items()):
-        out.write("        ('\u{%s}', '\u{%s}') => Some('\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3)))
+        if c1 >= 0x10000 and c2 >= 0x10000:
+            out.write("            ('\\u{%s}', '\\u{%s}') => Some('\\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3)))

-    out.write("        _ => None,\n")
+    out.write("            _ => None,\n")
+    out.write("        }\n")
     out.write("    }\n")
     out.write("}\n")

 def gen_decomposition_tables(canon_decomp, compat_decomp, out):
     tables = [(canon_decomp, 'canonical'), (compat_decomp, 'compatibility')]
     for table, name in tables:
-        out.write("#[inline]\n")
+        (salt, keys) = minimal_perfect_hash(table)
+        out.write("const {}_DECOMPOSED_KV: &[(u32, &'static [char])] = &[\n".format(name.upper()))
+        for char in keys:
+            d = ", ".join("'\\u{%s}'" % hexify(c) for c in table[char])
+            out.write("    (0x{:x}, &[{}]),\n".format(char, d))
+        out.write("];\n")
         out.write("pub fn %s_fully_decomposed(c: char) -> Option<&'static [char]> {\n" % name)
-        # The "Some" constructor is around the match statement here, because
-        # putting it into the individual arms would make the item_bodies
-        # checking of rustc takes almost twice as long, and it's already pretty
-        # slow because of the huge number of match arms and the fact that there
-        # is a borrow inside each arm
-        out.write("    Some(match c {\n")
-
-        for char, chars in sorted(table.items()):
-            d = ", ".join("'\u{%s}'" % hexify(c) for c in chars)
-            out.write("        '\u{%s}' => &[%s],\n" % (hexify(char), d))
-
-        out.write("        _ => return None,\n")
-        out.write("    })\n")
+        out.write("    mph_lookup(c.into(), &[\n")
+        for s in salt:
+            out.write("        0x{:x},\n".format(s))
+        out.write("    ],\n")
+        out.write("    {}_DECOMPOSED_KV,\n".format(name.upper()))
+        out.write("    pair_lookup_fk, pair_lookup_fv_opt, None)\n")
         out.write("}\n")
         out.write("\n")

 def gen_qc_match(prop_table, out):
     out.write("    match c {\n")
@@ -371,39 +388,45 @@ def gen_nfkd_qc(prop_tables, out):
     out.write("}\n")

 def gen_combining_mark(general_category_mark, out):
     out.write("#[inline]\n")
+    (salt, keys) = minimal_perfect_hash(general_category_mark)
     out.write("pub fn is_combining_mark(c: char) -> bool {\n")
-    out.write("    match c {\n")
-
-    for char in general_category_mark:
-        out.write("        '\u{%s}' => true,\n" % hexify(char))
-
-    out.write("        _ => false,\n")
-    out.write("    }\n")
+    out.write("    mph_lookup(c.into(), &[\n")
+    for s in salt:
+        out.write("        0x{:x},\n".format(s))
+    out.write("    ],\n")
+    out.write("    &[\n")
+    for k in keys:
+        out.write("        0x{:x},\n".format(k))
+    out.write("    ],\n")
+    out.write("    bool_lookup_fk, bool_lookup_fv, false)\n")
     out.write("}\n")

 def gen_stream_safe(leading, trailing, out):
+    # This could be done as a hash but the table is very small.
     out.write("#[inline]\n")
    out.write("pub fn stream_safe_leading_nonstarters(c: char) -> usize {\n")
     out.write("    match c {\n")

-    for char, num_leading in leading.items():
-        out.write("        '\u{%s}' => %d,\n" % (hexify(char), num_leading))
+    for char, num_leading in sorted(leading.items()):
+        out.write("        '\\u{%s}' => %d,\n" % (hexify(char), num_leading))

     out.write("        _ => 0,\n")
     out.write("    }\n")
     out.write("}\n")
     out.write("\n")

     out.write("#[inline]\n")
+    (salt, keys) = minimal_perfect_hash(trailing)
     out.write("pub fn stream_safe_trailing_nonstarters(c: char) -> usize {\n")
-    out.write("    match c {\n")
-
-    for char, num_trailing in trailing.items():
-        out.write("        '\u{%s}' => %d,\n" % (hexify(char), num_trailing))
-
-    out.write("        _ => 0,\n")
-    out.write("    }\n")
+    out.write("    mph_lookup(c.into(), &[\n")
+    for s in salt:
+        out.write("        0x{:x},\n".format(s))
+    out.write("    ],\n")
+    out.write("    &[\n")
+    for k in keys:
+        kv = int(trailing[k]) | (k << 8)
+        out.write("        0x{:x},\n".format(kv))
+    out.write("    ],\n")
+    out.write("    u8_lookup_fk, u8_lookup_fv, 0) as usize\n")
     out.write("}\n")

 def gen_tests(tests, out):
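A note on the packed u8 tables emitted above: each entry stores the value in the low byte and the character key in the upper 24 bits (kv = value | (key << 8)), matching the u8_lookup_fk/u8_lookup_fv extractors in the new Rust module. A quick sketch of the round trip (the pack/unpack helpers here are illustrative, not part of the patch):

# Illustrative only: mirrors the kv packing written by gen_combining_class
# and the unpacking done by u8_lookup_fk / u8_lookup_fv on the Rust side.
def pack(key, value):
    assert 0 <= key < (1 << 24) and 0 <= value < (1 << 8)
    return value | (key << 8)

def unpack_key(kv):
    return kv >> 8          # what u8_lookup_fk computes

def unpack_value(kv):
    return kv & 0xff        # what u8_lookup_fv computes

kv = pack(0x0300, 230)      # U+0300 COMBINING GRAVE ACCENT has ccc 230
assert unpack_key(kv) == 0x0300 and unpack_value(kv) == 230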
@@ -419,7 +442,7 @@ def gen_tests(tests, out):
     """)

     out.write("pub const NORMALIZATION_TESTS: &[NormalizationTest] = &[\n")
-    str_literal = lambda s: '"%s"' % "".join("\u{%s}" % c for c in s)
+    str_literal = lambda s: '"%s"' % "".join("\\u{%s}" % c for c in s)

     for test in tests:
         out.write("    NormalizationTest {\n")
@@ -432,13 +455,61 @@ def gen_tests(tests, out):

     out.write("];\n")

+def my_hash(x, salt, n):

[Review comment] probably should have a comment saying "guaranteed to be less than n"

+    # This is a hash based on the theory that multiplication is efficient
+    mask_32 = 0xffffffff
+    y = ((x + salt) * 2654435769) & mask_32
+    y ^= (x * 0x31415926) & mask_32
+    return (y * n) >> 32
+
+# Compute minimal perfect hash function, d can be either a dict or list of keys.
+def minimal_perfect_hash(d, singleton_buckets = False):

[Review comment] I'd prefer if this function had more comments.

+    n = len(d)
+    buckets = dict((h, []) for h in range(n))
+    for key in d:
+        h = my_hash(key, 0, n)
+        buckets[h].append(key)
+    bsorted = [(len(buckets[h]), h) for h in range(n)]
+    bsorted.sort(reverse = True)
+    claimed = [False] * n
+    salts = [0] * n
+    keys = [0] * n
+    for (bucket_size, h) in bsorted:
+        if bucket_size == 0:
+            break
+        elif singleton_buckets and bucket_size == 1:

[Review comment] Do we use the singleton_buckets case at all?
[Reply] No, I can remove it, especially as it seems to perform worse in benchmarks. The main reason I left it in is that it's more robust; without it there's a much greater probability the hashing will fail.

+            for i in range(n):
+                if not claimed[i]:
+                    salts[h] = -(i + 1)
+                    claimed[i] = True
+                    keys[i] = buckets[h][0]
+                    break
+        else:
+            for salt in range(1, 32768):
+                rehashes = [my_hash(key, salt, n) for key in buckets[h]]
+                if all(not claimed[hash] for hash in rehashes):

[Review comment] Is there a guarantee that we won't have a collision amongst the rehashes? Is it just really unlikely? (I suspect it's the latter but want to confirm)
[Reply] Yes, if it finds a suitable salt, that comes with a guarantee the rehash won't have a collision.
[Reply] Oh, wait, the `len(set(rehashes)) < bucket_size` check below is what rules out collisions within the bucket. (worth leaving a comment saying that)

+                    if len(set(rehashes)) < bucket_size:
+                        continue
+                    salts[h] = salt
+                    for key in buckets[h]:
+                        rehash = my_hash(key, salt, n)
+                        claimed[rehash] = True
+                        keys[rehash] = key
+                    break
+        if salts[h] == 0:
+            print("minimal perfect hashing failed")
+            exit(1)
+    return (salts, keys)
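To see the scheme end to end, here is a small, hypothetical driver (not part of the patch) that runs minimal_perfect_hash over a toy table and replays the two-hash probe that the generated Rust performs:

# Hypothetical smoke test; my_hash and minimal_perfect_hash are the
# functions defined above.
table = {0x0300: 230, 0x0301: 230, 0x031B: 216, 0x0334: 1}
(salts, keys) = minimal_perfect_hash(table)

for key in table:
    # The first hash (salt 0) selects a bucket's salt; the salted second
    # hash must then land on the slot holding exactly this key.
    salt = salts[my_hash(key, 0, len(table))]
    slot = my_hash(key, salt, len(table))
    assert keys[slot] == key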

 if __name__ == '__main__':
     data = UnicodeData()
-    with open("tables.rs", "w") as out:
+    with open("tables.rs", "w", newline = "\n") as out:
         out.write(PREAMBLE)
         out.write("use quick_check::IsNormalized;\n")
         out.write("use quick_check::IsNormalized::*;\n")
         out.write("\n")
+        out.write("use perfect_hash::*;\n")
+        out.write("\n")

         version = "(%s, %s, %s)" % tuple(UNICODE_VERSION.split("."))
         out.write("#[allow(unused)]\n")
@@ -470,6 +541,6 @@ def gen_tests(tests, out):
         gen_stream_safe(data.ss_leading, data.ss_trailing, out)
         out.write("\n")

-    with open("normalization_tests.rs", "w") as out:
+    with open("normalization_tests.rs", "w", newline = "\n") as out:
         out.write(PREAMBLE)
         gen_tests(data.norm_tests, out)
perfect_hash.rs (new file)
@@ -0,0 +1,78 @@
+// Copyright 2019 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+//! Support for lookups based on minimal perfect hashing.
+
+// This function is based on multiplication being fast and is "good enough". Also
+// it can share some work between the unsalted and salted versions.
+#[inline]
+fn my_hash(key: u32, salt: u32, n: usize) -> usize {
+    let y = key.wrapping_add(salt).wrapping_mul(2654435769);
+    let y = y ^ key.wrapping_mul(0x31415926);
+    (((y as u64) * (n as u64)) >> 32) as usize
+}
+
+/// Do a lookup using minimal perfect hashing.
+///
+/// The table is stored as a sequence of "salt" values, then a sequence of
+/// values that contain packed key/value pairs. The strategy is to hash twice.
+/// The first hash retrieves a salt value that makes the second hash unique.
+/// The hash function doesn't have to be very good, just good enough that the
+/// resulting map is unique.
+#[inline]
+pub(crate) fn mph_lookup<KV, V, FK, FV>(x: u32, salt: &[u16], kv: &[KV], fk: FK, fv: FV,
+    default: V) -> V
+    where KV: Copy, FK: Fn(KV) -> u32, FV: Fn(KV) -> V
+{
+    let s = salt[my_hash(x, 0, salt.len())] as u32;
+    let key_val = kv[my_hash(x, s, salt.len())];
+    if x == fk(key_val) {
+        fv(key_val)
+    } else {
+        default
+    }
+}
+
+/// Extract the key in a 24 bit key and 8 bit value packed in a u32.
+#[inline]
+pub(crate) fn u8_lookup_fk(kv: u32) -> u32 {
+    kv >> 8
+}
+
+/// Extract the value in a 24 bit key and 8 bit value packed in a u32.
+#[inline]
+pub(crate) fn u8_lookup_fv(kv: u32) -> u8 {
+    (kv & 0xff) as u8
+}
+
+/// Extract the key for a boolean lookup.
+#[inline]
+pub(crate) fn bool_lookup_fk(kv: u32) -> u32 {
+    kv
+}
+
+/// Extract the value for a boolean lookup.
+#[inline]
+pub(crate) fn bool_lookup_fv(_kv: u32) -> bool {
+    true
+}
+
+/// Extract the key in a pair.
+#[inline]
+pub(crate) fn pair_lookup_fk<T>(kv: (u32, T)) -> u32 {
+    kv.0
+}
+
+/// Extract the value in a pair, returning an option.
+#[inline]
+pub(crate) fn pair_lookup_fv_opt<T>(kv: (u32, T)) -> Option<T> {
+    Some(kv.1)
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could the code outputting
mph_lookup
calls be factored out into a function?
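One shape such a helper could take (a hypothetical sketch; the name gen_mph_data and its parameters are invented here, so the eventual refactor may differ):

# Hypothetical refactor: emit one mph_lookup call site given a table and
# a formatter for its packed entries.
def gen_mph_data(table, out, arg, entry_fmt, fk, fv, default):
    (salt, keys) = minimal_perfect_hash(table)
    out.write("    mph_lookup(%s, &[\n" % arg)
    for s in salt:
        out.write("        0x{:x},\n".format(s))
    out.write("    ],\n")
    out.write("    &[\n")
    for k in keys:
        out.write("        %s,\n" % entry_fmt(k))
    out.write("    ],\n")
    out.write("    %s, %s, %s)\n" % (fk, fv, default))

With that in place, gen_combining_class would reduce to a single call: gen_mph_data(combining_classes, out, "c.into()", lambda k: "0x{:x}".format(int(combining_classes[k]) | (k << 8)), "u8_lookup_fk", "u8_lookup_fv", "0").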