From 2eef3cb81c0212f39f049a523a8e56647ba60325 Mon Sep 17 00:00:00 2001 From: Kip Date: Wed, 30 Oct 2024 17:53:15 -0700 Subject: [PATCH] hnswlib/{bruteforce,hnswalg}.h: Added new generic serialization / deserialization interfaces for HierarchicalNSW and BruteforceSearch... README.md: Noted new serialization / deserialization interfaces... setup.py: Bumped patch version because new interfaces introduced... --- README.md | 4 ++++ hnswlib/bruteforce.h | 25 +++++++++++++++++-------- hnswlib/hnswalg.h | 26 ++++++++++++++++++-------- setup.py | 2 +- 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index b0f015a9..76ed01d0 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,10 @@ Header-only C++ HNSW implementation with python bindings, insertions and updates **NEWS:** +**version 0.8.1** + +* Added generic serialization / deserialization interfaces for HierarchicalNSW and BruteforceSearch that take `std::ostream` / `std::istream` arguments + **version 0.8.0** * Multi-vector document search and epsilon search (for now, only in C++) diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 8727cc8a..f7e1d1cf 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -4,6 +4,7 @@ #include #include #include +#include namespace hnswlib { template @@ -135,24 +136,25 @@ class BruteforceSearch : public AlgorithmInterface { } - void saveIndex(const std::string &location) { - std::ofstream output(location, std::ios::binary); - std::streampos position; - + void saveIndex(std::ostream &output) { writeBinaryPOD(output, maxelements_); writeBinaryPOD(output, size_per_element_); writeBinaryPOD(output, cur_element_count); output.write(data_, maxelements_ * size_per_element_); + } + + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + + saveIndex(output); output.close(); } - void loadIndex(const std::string &location, SpaceInterface *s) { - std::ifstream input(location, std::ios::binary); - std::streampos 
position; - + void loadIndex(std::istream &input, SpaceInterface *s) { readBinaryPOD(input, maxelements_); readBinaryPOD(input, size_per_element_); readBinaryPOD(input, cur_element_count); @@ -166,6 +168,13 @@ class BruteforceSearch : public AlgorithmInterface { throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); input.read(data_, maxelements_ * size_per_element_); + } + + + void loadIndex(const std::string &location, SpaceInterface *s) { + std::ifstream input(location, std::ios::binary); + + loadIndex(input, s); input.close(); } diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e269ae69..dcd0766d 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -9,6 +9,7 @@ #include #include #include +#include namespace hnswlib { typedef unsigned int tableint; @@ -682,10 +683,8 @@ class HierarchicalNSW : public AlgorithmInterface { return size; } - void saveIndex(const std::string &location) { - std::ofstream output(location, std::ios::binary); - std::streampos position; + void saveIndex(std::ostream &output) { writeBinaryPOD(output, offsetLevel0_); writeBinaryPOD(output, max_elements_); writeBinaryPOD(output, cur_element_count); @@ -709,16 +708,17 @@ class HierarchicalNSW : public AlgorithmInterface { if (linkListSize) output.write(linkLists_[i], linkListSize); } - output.close(); } - void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { - std::ifstream input(location, std::ios::binary); + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + saveIndex(output); + output.close(); + } - if (!input.is_open()) - throw std::runtime_error("Cannot open file"); + void loadIndex(std::istream &input, SpaceInterface *s, size_t max_elements_i = 0) { clear(); // get file size: input.seekg(0, input.end); @@ -815,6 +815,16 @@ class HierarchicalNSW : public AlgorithmInterface { if (allow_replace_deleted_) deleted_elements.insert(i); } } + } + + + void loadIndex(const 
std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { + std::ifstream input(location, std::ios::binary); + + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + loadIndex(input, s, max_elements_i); input.close(); diff --git a/setup.py b/setup.py index d96aea49..33616a0a 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -__version__ = '0.8.0' +__version__ = '0.8.1' include_dirs = [