Skip to content

Commit

Permalink
Fixed compilation to get the tests working again.
Browse files Browse the repository at this point in the history
Distances are now specified via std::function objects in the Options.
  • Loading branch information
LTLA committed Jun 13, 2024
1 parent 8eb6ad9 commit 9bf7fd3
Show file tree
Hide file tree
Showing 6 changed files with 458 additions and 66 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ else()
find_package(hnswlib CONFIG REQUIRED)
endif()

target_link_libraries(knncolle_hnsw INTERFACE knncolle::knncolle Annoy::Annoy)
target_link_libraries(knncolle_hnsw INTERFACE knncolle::knncolle hnswlib::hnswlib)

# Tests
if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
Expand Down
7 changes: 7 additions & 0 deletions extern/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
# Load the FetchContent module that provides the commands used below.
include(FetchContent)

# knncolle, fetched from its GitHub repository.
# NOTE(review): the tag tracks `master`; the trailing "^2.0.0" looks like the intended compatible version range - confirm.
FetchContent_Declare(
knncolle
GIT_REPOSITORY https://github.com/knncolle/knncolle
GIT_TAG master # ^2.0.0
)

# hnswlib, fetched from its GitHub repository.
# NOTE(review): also tracks `master`; "^0.8.0" is presumably the intended compatible version range - confirm.
FetchContent_Declare(
hnswlib
GIT_REPOSITORY https://github.com/nmslib/hnswlib
GIT_TAG master # ^0.8.0
)

# Make both declared dependencies available to the rest of the build.
FetchContent_MakeAvailable(knncolle)
FetchContent_MakeAvailable(hnswlib)
73 changes: 64 additions & 9 deletions include/knncolle_hnsw/distances.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define KNNCOLLE_HNSW_DISTANCES_HPP

#include <cmath>
#include <functional>

/**
* @file distances.hpp
Expand All @@ -10,6 +11,29 @@

namespace knncolle_hnsw {

/**
* @brief Distance options for the HNSW index.
*
* @tparam Dim_ Integer type for the number of dimensions.
* @tparam InternalData_ Floating point type for the HNSW index.
*/
template<typename Dim_, typename InternalData_>
struct DistanceOptions {
/**
* Factory that creates the `hnswlib::SpaceInterface` object defining the distance metric, given the number of dimensions.
* The factory returns a raw pointer; NOTE(review): which side is responsible for deleting it is not visible here - confirm against the caller.
* If not provided, this defaults to `hnswlib::L2Space` if `InternalData_ = float`,
* otherwise it defaults to `SquaredEuclideanDistance`.
*/
std::function<hnswlib::SpaceInterface<InternalData_>*(Dim_)> create;

/**
* Normalization function to convert the raw values from `hnswlib::SpaceInterface::get_dist_func()` into actual distances.
* If not provided and `create` is provided, this defaults to a no-op.
* If not provided and `create` is not provided, this defaults to the square root function
* (i.e., to convert the squared L2 norm into a Euclidean distance).
*/
std::function<InternalData_(InternalData_)> normalize;
};

/**
* @brief Manhattan distance.
*
Expand All @@ -25,7 +49,7 @@ class ManhattanDistance : public hnswlib::SpaceInterface<InternalData_> {
/**
* @param dim Number of dimensions over which to compute the distance.
*/
ManhattanDistance(size_t dim) : my_data_size(num_dim * sizeof(InternalData_)), my_dim(ndim) {}
// Record the byte size of one vector (`dim` elements of `InternalData_`) alongside the dimension count.
ManhattanDistance(size_t dim) : my_data_size(dim * sizeof(InternalData_)), my_dim(dim) {}

/**
* @cond
Expand Down Expand Up @@ -60,22 +84,53 @@ class ManhattanDistance : public hnswlib::SpaceInterface<InternalData_> {
};

/**
* @brief Euclidean distance with single-precision floats.
* @brief Squared Euclidean distance.
*
* @tparam InternalData_ Type of data in the HNSW index, usually floating-point.
*/
class EuclideanFloatDistnace : public hnswlib::L2Space {
template<typename InternalData_>
class SquaredEuclideanDistance : public hnswlib::SpaceInterface<InternalData_> {
private:
size_t my_data_size;
size_t my_dim;

public:
/**
* @param dim Number of dimensions.
* @param dim Number of dimensions over which to compute the distance.
*/
EuclideanFloatDistance(size_t dim) : hnswlib::L2Space(ndim) {}
// Record the byte size of one vector (`dim` elements of `InternalData_`) alongside the dimension count;
// these are the values later returned by get_data_size() and pointed to by get_dist_func_param().
SquaredEuclideanDistance(size_t dim) : my_data_size(dim * sizeof(InternalData_)), my_dim(dim) {}

/**
* @param raw Squared distance.
* @return Euclidean distance.
* @cond
*/
static float normalize(float raw) {
return std::sqrt(raw);
public:
// Size in bytes of a single vector, i.e. `dim * sizeof(InternalData_)` as computed in the constructor.
size_t get_data_size() {
return my_data_size;
}

// Distance callback: the static `L2` function, which sums squared element-wise differences without taking a square root.
hnswlib::DISTFUNC<InternalData_> get_dist_func() {
return L2;
}

// Opaque parameter for the distance callback: the address of the stored dimension count.
// `L2` reads its third argument as a `size_t`, so this pointer is presumably what hnswlib hands back to it - confirm against hnswlib.
// The pointer refers to a member of this object, so it is only usable while the object is alive.
void * get_dist_func_param() {
return &my_dim;
}

private:
// Squared Euclidean distance between two vectors: the sum of squared element-wise differences.
// No square root is taken here.
//
// pVect1v - pointer to the first array of `InternalData_` values.
// pVect2v - pointer to the second array of `InternalData_` values.
// qty_ptr - pointer to a `size_t` holding the number of elements to read from each array.
//
// Returns the accumulated sum as `InternalData_`; the result is zero when the element count is zero.
static InternalData_ L2(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const InternalData_* lhs = static_cast<const InternalData_*>(pVect1v);
    const InternalData_* rhs = static_cast<const InternalData_*>(pVect2v);

    // Named cast that keeps the pointee const: the dimension count is only read, never written.
    const size_t num_dim = *static_cast<const size_t*>(qty_ptr);

    InternalData_ res = 0;
    for (size_t d = 0; d < num_dim; ++d) {
        // `auto` preserves whatever type the subtraction yields (e.g. after integer promotion).
        const auto delta = lhs[d] - rhs[d];
        res += delta * delta;
    }
    return res;
}
/**
* @endcond
*/
};

}
Expand Down
Loading

0 comments on commit 9bf7fd3

Please sign in to comment.