Skip to content

Commit

Permalink
Add float16 support in python for cagra/brute_force/ivf_pq and scalar…
Browse files Browse the repository at this point in the history
… quantizer (#637)

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #637
  • Loading branch information
benfred authored Feb 4, 2025
1 parent ddec762 commit 8c683b0
Show file tree
Hide file tree
Showing 15 changed files with 114 additions and 43 deletions.
6 changes: 3 additions & 3 deletions cpp/include/cuvs/neighbors/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ cuvsError_t cuvsBruteForceIndexDestroy(cuvsBruteForceIndex_t index);
* `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`,
* or `kDLCPU`. Also, acceptable underlying types are:
* 1. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
* 2. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* 3. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
* 2. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 16`
*
* @code {.c}
* #include <cuvs/core/c_api.h>
Expand Down Expand Up @@ -120,7 +119,8 @@ cuvsError_t cuvsBruteForceBuild(cuvsResources_t res,
* It is also important to note that the BRUTEFORCE index must have been built
* with the same type of `queries`, such that `index.dtype.code ==
* queries.dl_tensor.dtype.code`. Types for input are:
* 1. `queries`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
* 1. `queries`: `kDLDataType.code == kDLFloat` and either `kDLDataType.bits = 32`
*    or `kDLDataType.bits = 16`
* 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32`
* 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
*
Expand Down
10 changes: 6 additions & 4 deletions cpp/include/cuvs/neighbors/cagra.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,9 @@ cuvsError_t cuvsCagraIndexGetDims(cuvsCagraIndex_t index, int* dim);
* `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`,
* or `kDLCPU`. Also, acceptable underlying types are:
* 1. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
* 2. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* 3. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
* 2. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 16`
* 3. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* 4. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
*
* @code {.c}
* #include <cuvs/core/c_api.h>
Expand Down Expand Up @@ -421,8 +422,9 @@ cuvsError_t cuvsCagraExtend(cuvsResources_t res,
* queries.dl_tensor.dtype.code`. Types for input are:
* 1. `queries`:
* a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
* b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
* b. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 16`
* c. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* d. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
* 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32`
* 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
*
Expand Down
6 changes: 4 additions & 2 deletions cpp/include/cuvs/neighbors/ivf_pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,9 @@ cuvsError_t cuvsIvfPqIndexDestroy(cuvsIvfPqIndex_t index);
* `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`,
* or `kDLCPU`. Also, acceptable underlying types are:
* 1. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
* 2. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* 3. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
* 2. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 16`
* 3. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* 4. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
*
* @code {.c}
* #include <cuvs/core/c_api.h>
Expand Down Expand Up @@ -314,6 +315,7 @@ cuvsError_t cuvsIvfPqBuild(cuvsResources_t res,
* with the same type of `queries`, such that `index.dtype.code ==
* queries.dl_tensor.dtype.code`. Types for input are:
* 1. `queries`: `kDLDataType.code == kDLFloat` and either `kDLDataType.bits = 32`
*    or `kDLDataType.bits = 16`
* 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32`
* 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
*
Expand Down
57 changes: 40 additions & 17 deletions cpp/src/neighbors/brute_force_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

namespace {

template <typename T, typename LayoutT = raft::row_major>
template <typename T, typename LayoutT = raft::row_major, typename DistT = float>
void* _build(cuvsResources_t res,
DLManagedTensor* dataset_tensor,
cuvsDistanceType metric,
Expand All @@ -49,11 +49,11 @@ void* _build(cuvsResources_t res,
params.metric_arg = metric_arg;

auto index_on_stack = cuvs::neighbors::brute_force::build(*res_ptr, params, mds);
auto index_on_heap = new cuvs::neighbors::brute_force::index<T>(std::move(index_on_stack));
auto index_on_heap = new cuvs::neighbors::brute_force::index<T, DistT>(std::move(index_on_stack));
return index_on_heap;
}

template <typename T, typename QueriesLayoutT = raft::row_major>
template <typename T, typename QueriesLayoutT = raft::row_major, typename DistT = float>
void _search(cuvsResources_t res,
cuvsBruteForceIndex index,
DLManagedTensor* queries_tensor,
Expand All @@ -62,11 +62,11 @@ void _search(cuvsResources_t res,
cuvsFilter prefilter)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::brute_force::index<T>*>(index.addr);
auto index_ptr = reinterpret_cast<cuvs::neighbors::brute_force::index<T, DistT>*>(index.addr);

using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, QueriesLayoutT>;
using neighbors_mdspan_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>;
using distances_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
using distances_mdspan_type = raft::device_matrix_view<DistT, int64_t, raft::row_major>;
using prefilter_mds_type = raft::device_vector_view<uint32_t, int64_t>;
using prefilter_bmp_type = cuvs::core::bitmap_view<uint32_t, int64_t>;

Expand Down Expand Up @@ -98,19 +98,19 @@ void _search(cuvsResources_t res,
}
}

template <typename T>
template <typename T, typename DistT = float>
void _serialize(cuvsResources_t res, const char* filename, cuvsBruteForceIndex index)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::brute_force::index<T>*>(index.addr);
auto index_ptr = reinterpret_cast<cuvs::neighbors::brute_force::index<T, DistT>*>(index.addr);
cuvs::neighbors::brute_force::serialize(*res_ptr, std::string(filename), *index_ptr);
}

template <typename T>
template <typename T, typename DistT = float>
void* _deserialize(cuvsResources_t res, const char* filename)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index = new cuvs::neighbors::brute_force::index<T>(*res_ptr);
auto index = new cuvs::neighbors::brute_force::index<T, DistT>(*res_ptr);
cuvs::neighbors::brute_force::deserialize(*res_ptr, std::string(filename), index);
return index;
}
Expand All @@ -126,14 +126,13 @@ extern "C" cuvsError_t cuvsBruteForceIndexDestroy(cuvsBruteForceIndex_t index_c_
return cuvs::core::translate_exceptions([=] {
auto index = *index_c_ptr;

if (index.dtype.code == kDLFloat) {
auto index_ptr = reinterpret_cast<cuvs::neighbors::brute_force::index<float>*>(index.addr);
if ((index.dtype.code == kDLFloat) && index.dtype.bits == 32) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::brute_force::index<float, float>*>(index.addr);
delete index_ptr;
} else if (index.dtype.code == kDLInt) {
auto index_ptr = reinterpret_cast<cuvs::neighbors::brute_force::index<int8_t>*>(index.addr);
delete index_ptr;
} else if (index.dtype.code == kDLUInt) {
auto index_ptr = reinterpret_cast<cuvs::neighbors::brute_force::index<uint8_t>*>(index.addr);
} else if ((index.dtype.code == kDLFloat) && index.dtype.bits == 16) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::brute_force::index<half, float>*>(index.addr);
delete index_ptr;
}
delete index_c_ptr;
Expand All @@ -148,6 +147,7 @@ extern "C" cuvsError_t cuvsBruteForceBuild(cuvsResources_t res,
{
return cuvs::core::translate_exceptions([=] {
auto dataset = dataset_tensor->dl_tensor;
index->dtype = dataset.dtype;

if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
if (cuvs::core::is_c_contiguous(dataset_tensor)) {
Expand All @@ -159,7 +159,16 @@ extern "C" cuvsError_t cuvsBruteForceBuild(cuvsResources_t res,
} else {
RAFT_FAIL("dataset input to cuvsBruteForceBuild must be contiguous (non-strided)");
}
index->dtype = dataset.dtype;
} else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) {
if (cuvs::core::is_c_contiguous(dataset_tensor)) {
index->addr =
reinterpret_cast<uintptr_t>(_build<half>(res, dataset_tensor, metric, metric_arg));
} else if (cuvs::core::is_f_contiguous(dataset_tensor)) {
index->addr = reinterpret_cast<uintptr_t>(
_build<half, raft::col_major>(res, dataset_tensor, metric, metric_arg));
} else {
RAFT_FAIL("dataset input to cuvsBruteForceBuild must be contiguous (non-strided)");
}
} else {
RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d",
dataset.dtype.code,
Expand Down Expand Up @@ -204,6 +213,15 @@ extern "C" cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
} else {
RAFT_FAIL("queries input to cuvsBruteForceSearch must be contiguous (non-strided)");
}
} else if (queries.dtype.code == kDLFloat && queries.dtype.bits == 16) {
if (cuvs::core::is_c_contiguous(queries_tensor)) {
_search<half>(res, index, queries_tensor, neighbors_tensor, distances_tensor, prefilter);
} else if (cuvs::core::is_f_contiguous(queries_tensor)) {
_search<half, raft::col_major>(
res, index, queries_tensor, neighbors_tensor, distances_tensor, prefilter);
} else {
RAFT_FAIL("queries input to cuvsBruteForceSearch must be contiguous (non-strided)");
}
} else {
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d",
queries.dtype.code,
Expand All @@ -228,6 +246,9 @@ extern "C" cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res,
if (dtype.kind == 'f' && dtype.itemsize == 4) {
index->dtype.code = kDLFloat;
index->addr = reinterpret_cast<uintptr_t>(_deserialize<float>(res, filename));
} else if (dtype.kind == 'f' && dtype.itemsize == 2) {
index->dtype.code = kDLFloat;
index->addr = reinterpret_cast<uintptr_t>(_deserialize<half>(res, filename));
} else {
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits);
}
Expand All @@ -241,6 +262,8 @@ extern "C" cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res,
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
_serialize<float>(res, filename, *index);
} else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) {
_serialize<half>(res, filename, *index);
} else {
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits);
}
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
index->dtype = dataset.dtype;
if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
index->addr = reinterpret_cast<uintptr_t>(_build<float>(res, *params, dataset_tensor));
} else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) {
index->addr = reinterpret_cast<uintptr_t>(_build<half>(res, *params, dataset_tensor));
} else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) {
index->addr = reinterpret_cast<uintptr_t>(_build<int8_t>(res, *params, dataset_tensor));
} else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) {
Expand Down Expand Up @@ -321,6 +323,9 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
_search<float>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLFloat && queries.dtype.bits == 16) {
_search<half>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) {
_search<int8_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
Expand Down Expand Up @@ -433,6 +438,8 @@ extern "C" cuvsError_t cuvsCagraSerialize(cuvsResources_t res,
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
_serialize<float>(res, filename, index, include_dataset);
} else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) {
_serialize<half>(res, filename, index, include_dataset);
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
_serialize<int8_t>(res, filename, index, include_dataset);
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/neighbors/ivf_pq_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ extern "C" cuvsError_t cuvsIvfPqBuild(cuvsResources_t res,
if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
index->addr =
reinterpret_cast<uintptr_t>(_build<float, int64_t>(res, *params, dataset_tensor));
} else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) {
index->addr =
reinterpret_cast<uintptr_t>(_build<half, int64_t>(res, *params, dataset_tensor));
} else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) {
index->addr =
reinterpret_cast<uintptr_t>(_build<int8_t, int64_t>(res, *params, dataset_tensor));
Expand Down Expand Up @@ -197,6 +200,9 @@ extern "C" cuvsError_t cuvsIvfPqSearch(cuvsResources_t res,
if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
_search<float, int64_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
} else if (queries.dtype.code == kDLFloat && queries.dtype.bits == 16) {
_search<half, int64_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
} else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) {
_search<int8_t, int64_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
Expand Down Expand Up @@ -274,6 +280,8 @@ extern "C" cuvsError_t cuvsIvfPqExtend(cuvsResources_t res,

if (vectors.dtype.code == kDLFloat && vectors.dtype.bits == 32) {
_extend<float, int64_t>(res, new_vectors, new_indices, *index);
} else if (vectors.dtype.code == kDLFloat && vectors.dtype.bits == 16) {
_extend<half, int64_t>(res, new_vectors, new_indices, *index);
} else if (vectors.dtype.code == kDLInt && vectors.dtype.bits == 8) {
_extend<int8_t, int64_t>(res, new_vectors, new_indices, *index);
} else if (vectors.dtype.code == kDLUInt && vectors.dtype.bits == 8) {
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/preprocessing/quantize/scalar_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ extern "C" cuvsError_t cuvsScalarQuantizerTrain(cuvsResources_t res,
auto dataset = dataset_tensor->dl_tensor;
if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
_train<float>(res, *params, dataset_tensor, quantizer);
} else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) {
_train<half>(res, *params, dataset_tensor, quantizer);
} else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) {
_train<double>(res, *params, dataset_tensor, quantizer);
} else {
Expand All @@ -180,6 +182,8 @@ extern "C" cuvsError_t cuvsScalarQuantizerTransform(cuvsResources_t res,
auto dataset = dataset_tensor->dl_tensor;
if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
_transform<float>(res, quantizer, dataset_tensor, out_tensor);
} else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) {
_transform<half>(res, quantizer, dataset_tensor, out_tensor);
} else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 64) {
_transform<double>(res, quantizer, dataset_tensor, out_tensor);
} else {
Expand All @@ -199,6 +203,8 @@ cuvsError_t cuvsScalarQuantizerInverseTransform(cuvsResources_t res,
auto dtype = out->dl_tensor.dtype;
if (dtype.code == kDLFloat && dtype.bits == 32) {
_inverse_transform<float>(res, quantizer, dataset, out);
} else if (dtype.code == kDLFloat && dtype.bits == 16) {
_inverse_transform<half>(res, quantizer, dataset, out);
} else if (dtype.code == kDLFloat && dtype.bits == 64) {
_inverse_transform<double>(res, quantizer, dataset, out);
} else {
Expand Down
12 changes: 8 additions & 4 deletions python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def build(dataset, metric="sqeuclidean", metric_arg=2.0, resources=None):
Parameters
----------
dataset : CUDA array interface compliant matrix shape (n_samples, dim)
Supported dtype [float, int8, uint8]
Supported dtype [float32, float16]
metric : Distance metric to use. Default is sqeuclidean
metric_arg : value of 'p' for Minkowski distances
{resources_docstring}
Expand All @@ -102,7 +102,9 @@ def build(dataset, metric="sqeuclidean", metric_arg=2.0, resources=None):
"""

dataset_ai = wrap_array(dataset)
_check_input_array(dataset_ai, [np.dtype('float32')], exp_row_major=False)
_check_input_array(dataset_ai,
[np.dtype('float32'), np.dtype('float16')],
exp_row_major=False)

cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

Expand Down Expand Up @@ -141,7 +143,7 @@ def search(Index index,
index : Index
Trained Brute Force index.
queries : CUDA array interface compliant matrix shape (n_samples, dim)
Supported dtype [float, int8, uint8]
Supported dtype [float32, float16]
k : int
The number of neighbors.
neighbors : Optional CUDA array interface compliant matrix shape
Expand Down Expand Up @@ -218,7 +220,9 @@ def search(Index index,
cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

queries_cai = wrap_array(queries)
_check_input_array(queries_cai, [np.dtype('float32')], exp_row_major=False)
_check_input_array(queries_cai,
[np.dtype('float32'), np.dtype('float16')],
exp_row_major=False)

cdef uint32_t n_queries = queries_cai.shape[0]

Expand Down
8 changes: 6 additions & 2 deletions python/cuvs/cuvs/neighbors/cagra/cagra.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,9 @@ def build(IndexParams index_params, dataset, resources=None):
# todo(dgd): we can make the check of dtype a parameter of wrap_array
# in RAFT to make this a single call
dataset_ai = wrap_array(dataset)
_check_input_array(dataset_ai, [np.dtype('float32'), np.dtype('byte'),
_check_input_array(dataset_ai, [np.dtype('float32'),
np.dtype('float16'),
np.dtype('byte'),
np.dtype('ubyte')])

cdef Index idx = Index()
Expand Down Expand Up @@ -543,7 +545,9 @@ def search(SearchParams search_params,
# todo(dgd): we can make the check of dtype a parameter of wrap_array
# in RAFT to make this a single call
queries_cai = wrap_array(queries)
_check_input_array(queries_cai, [np.dtype('float32'), np.dtype('byte'),
_check_input_array(queries_cai, [np.dtype('float32'),
np.dtype('float16'),
np.dtype('byte'),
np.dtype('ubyte')])

cdef uint32_t n_queries = queries_cai.shape[0]
Expand Down
Loading

0 comments on commit 8c683b0

Please sign in to comment.