Skip to content

Commit

Permalink
Merge pull request #14 from ThoughtRiver/bug-fix-support-raw-msgpack-…
Browse files Browse the repository at this point in the history
…deserialization

Bug-fix: Supporting `raw` deserialization
  • Loading branch information
DomHudson authored Feb 24, 2020
2 parents 152fd37 + 43af596 commit f45ecb9
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 9 deletions.
22 changes: 19 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ pip install lmdb-embeddings
```

## Reading vectors

```python
from lmdb_embeddings.reader import LmdbEmbeddingsReader
from lmdb_embeddings.exceptions import MissingWordError
Expand Down Expand Up @@ -57,6 +56,23 @@ writer = LmdbEmbeddingsWriter(iter_embeddings()).write(OUTPUT_DATABASE_FOLDER)
# These vectors can now be loaded with the LmdbEmbeddingsReader.
```

## LRU Cache
A reader with an LRU (Least Recently Used) cache is included. It caches the embeddings for the 50,000 most recently queried words and returns the cached object instead of querying the database each time. Its interface is the same as the standard reader's.
See [functools.lru_cache](https://docs.python.org/3/library/functools.html#functools.lru_cache) in the standard library.

```python
from lmdb_embeddings.reader import LruCachedLmdbEmbeddingsReader
from lmdb_embeddings.exceptions import MissingWordError

embeddings = LruCachedLmdbEmbeddingsReader('/path/to/word/vectors/eg/GoogleNews-vectors-negative300')

try:
    vector = embeddings.get_word_vector('google')
except MissingWordError:
    # 'google' is not in the database.
    pass
```

## Customisation
By default, LMDB Embeddings uses pickle to serialize the vectors to bytes (optimized and pickled with the highest available protocol). However, it is very easy to use an alternative approach - simply inject the serializer and unserializer as callables into the `LmdbEmbeddingsWriter` and `LmdbEmbeddingsReader`.

Expand All @@ -68,7 +84,7 @@ from lmdb_embeddings.serializers import MsgpackSerializer

writer = LmdbEmbeddingsWriter(
iter_embeddings(),
serializer=MsgpackSerializer.serialize
serializer=MsgpackSerializer().serialize
).write(OUTPUT_DATABASE_FOLDER)
```

Expand All @@ -78,7 +94,7 @@ from lmdb_embeddings.serializers import MsgpackSerializer

reader = LmdbEmbeddingsReader(
OUTPUT_DATABASE_FOLDER,
unserializer=MsgpackSerializer.unserialize
unserializer=MsgpackSerializer().unserialize
)
```

Expand Down
22 changes: 19 additions & 3 deletions lmdb_embeddings/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ def unserialize(serialized_vector):

class MsgpackSerializer:

def __init__(self, raw = False):
""" Constructor.
:param bool raw: If True, unpack msgpack raw to Python bytes. Otherwise, unpack to Python
str by decoding with UTF-8 encoding (default). This is a highly confusing aspect of
msgpack-python. They have gone through several iterations on approaches to handle both
strings and bytes. If you are unsure what you need, leave this as False. If you
serialized your data on an older version of msgpack than what you are currently using,
you may need to set this to True.
:return void:
"""
self._raw = raw

@staticmethod
def serialize(vector):
""" Serialize a vector using msgpack.
Expand All @@ -57,11 +70,14 @@ def serialize(vector):
"""
return msgpack.packb(vector, default = msgpack_numpy.encode)

@staticmethod
def unserialize(serialized_vector):
def unserialize(self, serialized_vector):
""" Unserialize a vector using msgpack.
:param bytes serialized_vector:
:return np.array:
"""
return msgpack.unpackb(serialized_vector, object_hook = msgpack_numpy.decode)
return msgpack.unpackb(
serialized_vector,
object_hook = msgpack_numpy.decode,
raw = self._raw
)
4 changes: 2 additions & 2 deletions lmdb_embeddings/tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def test_msgpack_serialization(self, tmp_path, reader_class):

LmdbEmbeddingsWriter(
[('the', the_vector), ('is', np.random.rand(10))],
serializer = MsgpackSerializer.serialize
serializer = MsgpackSerializer().serialize
).write(directory_path)

reader = reader_class(directory_path, unserializer = MsgpackSerializer.unserialize)
reader = reader_class(directory_path, unserializer = MsgpackSerializer().unserialize)
assert reader.get_word_vector('the').tolist() == the_vector.tolist()
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_readme():

setup(
name = 'lmdb_embeddings',
version = '0.3.0',
version = '0.4.0',
description = 'Fast querying of word embeddings using the LMDB "Lightning" Database.',
license = 'GNU General Public License v3.0',
long_description = get_readme(),
Expand Down

0 comments on commit f45ecb9

Please sign in to comment.