diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index cae8c8519e..4c6e3d0ed4 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -22,7 +22,7 @@ on: default: nightly concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }} cancel-in-progress: true jobs: diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index 123aeba87b..2af8b1b8ff 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -20,6 +20,8 @@ import os import sys +import git + SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) # Add the scripts dir for gitutils @@ -37,7 +39,12 @@ re.compile(r"setup[.]cfg$"), re.compile(r"meta[.]yaml$") ] -ExemptFiles = ["cpp/include/raft/neighbors/detail/faiss_select/"] +ExemptFiles = [ + re.compile("cpp/include/raft/neighbors/detail/faiss_select/"), + re.compile("cpp/include/raft/thirdparty/"), + re.compile("docs/source/sphinxext/github_link.py"), + re.compile("cpp/cmake/modules/FindAVX.cmake") +] # this will break starting at year 10000, which is probably OK :) CheckSimple = re.compile( @@ -48,10 +55,12 @@ def checkThisFile(f): - # This check covers things like symlinks which point to files that DNE - if not (os.path.exists(f)): - return False - if gitutils and gitutils.isFileEmpty(f): + if isinstance(f, git.Diff): + if f.deleted_file or f.b_blob.size == 0: + return False + f = f.b_path + elif not os.path.exists(f) or os.stat(f).st_size == 0: + # This check covers things like symlinks which point to files that DNE return False for exempt in ExemptFiles: if exempt.search(f): @@ -62,36 +71,90 @@ def checkThisFile(f): return False +def modifiedFiles(): + """Get a set of all modified files, as Diff objects. + + The files returned have been modified in git since the merge base of HEAD + and the upstream of the target branch. We return the Diff objects so that + we can read only the staged changes. + """ + repo = git.Repo() + # Use the environment variable TARGET_BRANCH or RAPIDS_BASE_BRANCH (defined in CI) if possible + target_branch = os.environ.get("TARGET_BRANCH", os.environ.get("RAPIDS_BASE_BRANCH")) + if target_branch is None: + # Fall back to the closest branch if not on CI + target_branch = repo.git.describe( + all=True, tags=True, match="branch-*", abbrev=0 + ).lstrip("heads/") + + upstream_target_branch = None + if target_branch in repo.heads: + # Use the tracking branch of the local reference if it exists. This + # returns None if no tracking branch is set. + upstream_target_branch = repo.heads[target_branch].tracking_branch() + if upstream_target_branch is None: + # Fall back to the remote with the newest target_branch. This code + # path is used on CI because the only local branch reference is + # current-pr-branch, and thus target_branch is not in repo.heads. + # This also happens if no tracking branch is defined for the local + # target_branch. We use the remote with the latest commit if + # multiple remotes are defined. + candidate_branches = [ + remote.refs[target_branch] for remote in repo.remotes + if target_branch in remote.refs + ] + if len(candidate_branches) > 0: + upstream_target_branch = sorted( + candidate_branches, + key=lambda branch: branch.commit.committed_datetime, + )[-1] + else: + # If no remotes are defined, try to use the local version of the + # target_branch. If this fails, the repo configuration must be very + # strange and we can fix this script on a case-by-case basis. 
+ upstream_target_branch = repo.heads[target_branch] + merge_base = repo.merge_base("HEAD", upstream_target_branch.commit)[0] + diff = merge_base.diff() + changed_files = {f for f in diff if f.b_path is not None} + return changed_files + + def getCopyrightYears(line): res = CheckSimple.search(line) if res: - return (int(res.group(1)), int(res.group(1))) + return int(res.group(1)), int(res.group(1)) res = CheckDouble.search(line) if res: - return (int(res.group(1)), int(res.group(2))) - return (None, None) + return int(res.group(1)), int(res.group(2)) + return None, None def replaceCurrentYear(line, start, end): # first turn a simple regex into double (if applicable). then update years res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line) res = CheckDouble.sub( - r"Copyright (c) {:04d}-{:04d}, NVIDIA CORPORATION".format(start, end), - res) + rf"Copyright (c) {start:04d}-{end:04d}, NVIDIA CORPORATION", + res, + ) return res def checkCopyright(f, update_current_year): - """ - Checks for copyright headers and their years - """ + """Checks for copyright headers and their years.""" errs = [] thisYear = datetime.datetime.now().year lineNum = 0 crFound = False yearMatched = False - with io.open(f, "r", encoding="utf-8") as fp: - lines = fp.readlines() + + if isinstance(f, git.Diff): + path = f.b_path + lines = f.b_blob.data_stream.read().decode().splitlines(keepends=True) + else: + path = f + with open(f, encoding="utf-8") as fp: + lines = fp.readlines() + for line in lines: lineNum += 1 start, end = getCopyrightYears(line) @@ -100,20 +163,19 @@ def checkCopyright(f, update_current_year): crFound = True if start > end: e = [ - f, + path, lineNum, "First year after second year in the copyright " "header (manual fix required)", - None + None, ] errs.append(e) - if thisYear < start or thisYear > end: + elif thisYear < start or thisYear > end: e = [ - f, + path, lineNum, - "Current year not included in the " - "copyright header", - None + "Current year not included in the copyright header", + None, ] if thisYear < start: e[-1] = replaceCurrentYear(line, thisYear, end) @@ -122,15 +184,14 @@ def checkCopyright(f, update_current_year): errs.append(e) else: yearMatched = True - fp.close() # copyright header itself not found if not crFound: e = [ - f, + path, 0, "Copyright header missing or formatted incorrectly " "(manual fix required)", - None + None, ] errs.append(e) # even if the year matches a copyright header, make the check pass @@ -140,21 +201,19 @@ def checkCopyright(f, update_current_year): if update_current_year: errs_update = [x for x in errs if x[-1] is not None] if len(errs_update) > 0: - print("File: {}. Changing line(s) {}".format( - f, ', '.join(str(x[1]) for x in errs if x[-1] is not None))) + lines_changed = ", ".join(str(x[1]) for x in errs_update) + print(f"File: {path}. 
Changing line(s) {lines_changed}") for _, lineNum, __, replacement in errs_update: lines[lineNum - 1] = replacement - with io.open(f, "w", encoding="utf-8") as out_file: - for new_line in lines: - out_file.write(new_line) - errs = [x for x in errs if x[-1] is None] + with open(path, "w", encoding="utf-8") as out_file: + out_file.writelines(lines) return errs def getAllFilesUnderDir(root, pathFilter=None): retList = [] - for (dirpath, dirnames, filenames) in os.walk(root): + for dirpath, dirnames, filenames in os.walk(root): for fn in filenames: filePath = os.path.join(dirpath, fn) if pathFilter(filePath): @@ -169,49 +228,37 @@ def checkCopyright_main(): it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch" """ retVal = 0 - global ExemptFiles argparser = argparse.ArgumentParser( - "Checks for a consistent copyright header in git's modified files") - argparser.add_argument("--update-current-year", - dest='update_current_year', - action="store_true", - required=False, - help="If set, " - "update the current year if a header " - "is already present and well formatted.") - argparser.add_argument("--git-modified-only", - dest='git_modified_only', - action="store_true", - required=False, - help="If set, " - "only files seen as modified by git will be " - "processed.") - argparser.add_argument("--exclude", - dest='exclude', - action="append", - required=False, - default=["python/cuml/_thirdparty/", - "cpp/include/raft/thirdparty/", - "cpp/cmake/modules/FindAVX.cmake"], - help=("Exclude the paths specified (regexp). " - "Can be specified multiple times.")) - - (args, dirs) = argparser.parse_known_args() - try: - ExemptFiles = ExemptFiles + [pathName for pathName in args.exclude] - ExemptFiles = [re.compile(file) for file in ExemptFiles] - except re.error as reException: - print("Regular expression error:") - print(reException) - return 1 + "Checks for a consistent copyright header in git's modified files" + ) + argparser.add_argument( + "--update-current-year", + dest="update_current_year", + action="store_true", + required=False, + help="If set, " + "update the current year if a header is already " + "present and well formatted.", + ) + argparser.add_argument( + "--git-modified-only", + dest="git_modified_only", + action="store_true", + required=False, + help="If set, " + "only files seen as modified by git will be " + "processed.", + ) + + args, dirs = argparser.parse_known_args() if args.git_modified_only: - files = gitutils.modifiedFiles(pathFilter=checkThisFile) + files = [f for f in modifiedFiles() if checkThisFile(f)] else: files = [] for d in [os.path.abspath(d) for d in dirs]: - if not (os.path.isdir(d)): + if not os.path.isdir(d): raise ValueError(f"{d} is not a directory.") files += getAllFilesUnderDir(d, pathFilter=checkThisFile) @@ -220,24 +267,24 @@ def checkCopyright_main(): errors += checkCopyright(f, args.update_current_year) if len(errors) > 0: - print("Copyright headers incomplete in some of the files!") + if any(e[-1] is None for e in errors): + print("Copyright headers incomplete in some of the files!") for e in errors: print(" %s:%d Issue: %s" % (e[0], e[1], e[2])) print("") n_fixable = sum(1 for e in errors if e[-1] is not None) path_parts = os.path.abspath(__file__).split(os.sep) - file_from_repo = os.sep.join(path_parts[path_parts.index("ci"):]) - if n_fixable > 0: - print(("You can run `python {} --git-modified-only " - "--update-current-year` to fix {} of these " - "errors.\n").format(file_from_repo, n_fixable)) + file_from_repo = 
os.sep.join(path_parts[path_parts.index("ci") :])
+    if n_fixable > 0 and not args.update_current_year:
+        print(
+            f"You can run `python {file_from_repo} --git-modified-only "
+            "--update-current-year` and stage the results in git to "
+            f"fix {n_fixable} of these errors.\n"
+        )
         retVal = 1
-    else:
-        print("Copyright check passed")
 
     return retVal
 
 
 if __name__ == "__main__":
-    import sys
     sys.exit(checkCopyright_main())
diff --git a/ci/wheel_smoke_test_pylibraft.py b/ci/wheel_smoke_test_pylibraft.py
index 7fee674691..c0df2fe45c 100644
--- a/ci/wheel_smoke_test_pylibraft.py
+++ b/ci/wheel_smoke_test_pylibraft.py
@@ -1,3 +1,18 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 import numpy as np
 from scipy.spatial.distance import cdist
diff --git a/conda/recipes/pylibraft/meta.yaml b/conda/recipes/pylibraft/meta.yaml
index 454cac0d77..b8a088d0f3 100644
--- a/conda/recipes/pylibraft/meta.yaml
+++ b/conda/recipes/pylibraft/meta.yaml
@@ -48,7 +48,6 @@ requirements:
     - cython >=3.0.0
     - libraft {{ version }}
     - libraft-headers {{ version }}
-    - numpy >=1.21
     - python x.x
     - rmm ={{ minor_version }}
     - scikit-build >=0.13.1
@@ -60,6 +59,7 @@ requirements:
     {% endif %}
     - libraft {{ version }}
     - libraft-headers {{ version }}
+    - numpy >=1.21
     - python x.x
     - rmm ={{ minor_version }}
diff --git a/cpp/include/raft/neighbors/detail/div_utils.hpp b/cpp/include/raft/neighbors/detail/div_utils.hpp
new file mode 100644
index 0000000000..0455d0ec9b
--- /dev/null
+++ b/cpp/include/raft/neighbors/detail/div_utils.hpp
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef _RAFT_HAS_CUDA
+#include <raft/util/pow2_utils.cuh>
+#else
+#include <raft/util/integer_utils.hpp>
+#endif
+
+/**
+ * @brief A simple wrapper for raft::Pow2 which uses Pow2 utils only when available and regular
+ * integer division otherwise. This is done to allow a common interface for division arithmetic for
+ * non CUDA headers.
+ *
+ * @tparam Value_ a compile-time value representable as a power-of-two.
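+ *
+ * A minimal usage sketch (this mirrors the interleaved-group arithmetic in
+ * ivf_flat_codepacker.hpp further below; `kIndexGroupSize` stands for any
+ * power-of-two compile-time constant):
+ * @code{.cpp}
+ * using interleaved_group = raft::neighbors::detail::div_utils<kIndexGroupSize>;
+ * auto group_offset = interleaved_group::roundDown(offset);
+ * auto ingroup_id   = interleaved_group::mod(offset) * veclen;
+ * @endcode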
+ */
+namespace raft::neighbors::detail {
+template <auto Value_>
+struct div_utils {
+  typedef decltype(Value_) Type;
+  static constexpr Type Value = Value_;
+
+  template <typename T>
+  static constexpr _RAFT_HOST_DEVICE inline auto roundDown(T x)
+  {
+#if defined(_RAFT_HAS_CUDA)
+    return Pow2<Value_>::roundDown(x);
+#else
+    return raft::round_down_safe(x, Value_);
+#endif
+  }
+
+  template <typename T>
+  static constexpr _RAFT_HOST_DEVICE inline auto mod(T x)
+  {
+#if defined(_RAFT_HAS_CUDA)
+    return Pow2<Value_>::mod(x);
+#else
+    return x % Value_;
+#endif
+  }
+
+  template <typename T>
+  static constexpr _RAFT_HOST_DEVICE inline auto div(T x)
+  {
+#if defined(_RAFT_HAS_CUDA)
+    return Pow2<Value_>::div(x);
+#else
+    return x / Value_;
+#endif
+  }
+};
+}  // namespace raft::neighbors::detail
\ No newline at end of file
diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
index 33ed51ad05..e57133fc23 100644
--- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
+++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
@@ -313,6 +313,59 @@ auto calculate_offsets_and_indices(IdxT n_rows,
   return max_cluster_size;
 }
 
+template <typename IdxT>
+void set_centers(raft::resources const& handle, index<IdxT>* index, const float* cluster_centers)
+{
+  auto stream         = resource::get_cuda_stream(handle);
+  auto* device_memory = resource::get_workspace_resource(handle);
+
+  // combine cluster_centers and their norms
+  RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(),
+                                  sizeof(float) * index->dim_ext(),
+                                  cluster_centers,
+                                  sizeof(float) * index->dim(),
+                                  sizeof(float) * index->dim(),
+                                  index->n_lists(),
+                                  cudaMemcpyDefault,
+                                  stream));
+
+  rmm::device_uvector<float> center_norms(index->n_lists(), stream, device_memory);
+  raft::linalg::rowNorm(center_norms.data(),
+                        cluster_centers,
+                        index->dim(),
+                        index->n_lists(),
+                        raft::linalg::L2Norm,
+                        true,
+                        stream);
+  RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle() + index->dim(),
+                                  sizeof(float) * index->dim_ext(),
+                                  center_norms.data(),
+                                  sizeof(float),
+                                  sizeof(float),
+                                  index->n_lists(),
+                                  cudaMemcpyDefault,
+                                  stream));
+
+  // Rotate cluster_centers
+  float alpha = 1.0;
+  float beta  = 0.0;
+  linalg::gemm(handle,
+               true,
+               false,
+               index->rot_dim(),
+               index->n_lists(),
+               index->dim(),
+               &alpha,
+               index->rotation_matrix().data_handle(),
+               index->dim(),
+               cluster_centers,
+               index->dim(),
+               &beta,
+               index->centers_rot().data_handle(),
+               index->rot_dim(),
+               resource::get_cuda_stream(handle));
+}
+
 template <typename IdxT>
 void transpose_pq_centers(const resources& handle,
                           index<IdxT>& index,
@@ -613,6 +666,100 @@ void unpack_list_data(raft::resources const& res,
                         resource::get_cuda_stream(res));
 }
 
+/**
+ * A consumer for the `run_on_vector` that just flattens PQ codes
+ * into a tightly packed matrix. That is, the codes are not expanded to one code-per-byte.
+ */
+template <uint32_t PqBits>
+struct unpack_contiguous {
+  uint8_t* codes;
+  uint32_t code_size;
+
+  /**
+   * Create a callable to be passed to `run_on_vector`.
+   *
+   * @param[in] codes flat compressed PQ codes
+   */
+  __host__ __device__ inline unpack_contiguous(uint8_t* codes, uint32_t pq_dim)
+    : codes{codes}, code_size{raft::ceildiv<uint32_t>(pq_dim * PqBits, 8)}
+  {
+  }
+
+  /** Write j-th component (code) of the i-th vector into the output array.
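+   *
+   * For example, with PqBits = 4 two consecutive codes share one output byte;
+   * bitfield_view_t computes the byte and bit offsets for any PqBits in [4, 8].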
*/ + __host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) + { + bitfield_view_t code_view{codes + i * code_size}; + code_view[j] = code; + } +}; + +template +__launch_bounds__(BlockSize) RAFT_KERNEL unpack_contiguous_list_data_kernel( + uint8_t* out_codes, + device_mdspan::list_extents, row_major> in_list_data, + uint32_t n_rows, + uint32_t pq_dim, + std::variant offset_or_indices) +{ + run_on_list( + in_list_data, offset_or_indices, n_rows, pq_dim, unpack_contiguous(out_codes, pq_dim)); +} + +/** + * Unpack flat PQ codes from an existing list by the given offset. + * + * @param[out] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)] + * @param[in] list_data the packed ivf::list data. + * @param[in] offset_or_indices how many records in the list to skip or the exact indices. + * @param[in] pq_bits codebook size (1 << pq_bits) + * @param[in] stream + */ +inline void unpack_contiguous_list_data( + uint8_t* codes, + device_mdspan::list_extents, row_major> list_data, + uint32_t n_rows, + uint32_t pq_dim, + std::variant offset_or_indices, + uint32_t pq_bits, + rmm::cuda_stream_view stream) +{ + if (n_rows == 0) { return; } + + constexpr uint32_t kBlockSize = 256; + dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); + dim3 threads(kBlockSize, 1, 1); + auto kernel = [pq_bits]() { + switch (pq_bits) { + case 4: return unpack_contiguous_list_data_kernel; + case 5: return unpack_contiguous_list_data_kernel; + case 6: return unpack_contiguous_list_data_kernel; + case 7: return unpack_contiguous_list_data_kernel; + case 8: return unpack_contiguous_list_data_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(); + kernel<<>>(codes, list_data, n_rows, pq_dim, offset_or_indices); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/** Unpack the list data; see the public interface for the api and usage. */ +template +void unpack_contiguous_list_data(raft::resources const& res, + const index& index, + uint8_t* out_codes, + uint32_t n_rows, + uint32_t label, + std::variant offset_or_indices) +{ + unpack_contiguous_list_data(out_codes, + index.lists()[label]->data.view(), + n_rows, + index.pq_dim(), + offset_or_indices, + index.pq_bits(), + resource::get_cuda_stream(res)); +} + /** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data. */ struct reconstruct_vectors { @@ -850,6 +997,101 @@ void pack_list_data(raft::resources const& res, resource::get_cuda_stream(res)); } +/** + * A producer for the `write_vector` reads tightly packed flat codes. That is, + * the codes are not expanded to one code-per-byte. + */ +template +struct pack_contiguous { + const uint8_t* codes; + uint32_t code_size; + + /** + * Create a callable to be passed to `write_vector`. + * + * @param[in] codes flat compressed PQ codes + */ + __host__ __device__ inline pack_contiguous(const uint8_t* codes, uint32_t pq_dim) + : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} + { + } + + /** Read j-th component (code) of the i-th vector from the source. 
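+   *
+   * (The const_cast in the implementation is needed because bitfield_view_t
+   * takes a mutable pointer; pack_contiguous only ever reads through the view.)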
*/ + __host__ __device__ inline auto operator()(uint32_t i, uint32_t j) -> uint8_t + { + bitfield_view_t code_view{const_cast(codes + i * code_size)}; + return uint8_t(code_view[j]); + } +}; + +template +__launch_bounds__(BlockSize) RAFT_KERNEL pack_contiguous_list_data_kernel( + device_mdspan::list_extents, row_major> list_data, + const uint8_t* codes, + uint32_t n_rows, + uint32_t pq_dim, + std::variant offset_or_indices) +{ + write_list( + list_data, offset_or_indices, n_rows, pq_dim, pack_contiguous(codes, pq_dim)); +} + +/** + * Write flat PQ codes into an existing list by the given offset. + * + * NB: no memory allocation happens here; the list must fit the data (offset + n_rows). + * + * @param[out] list_data the packed ivf::list data. + * @param[in] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)] + * @param[in] offset_or_indices how many records in the list to skip or the exact indices. + * @param[in] pq_bits codebook size (1 << pq_bits) + * @param[in] stream + */ +inline void pack_contiguous_list_data( + device_mdspan::list_extents, row_major> list_data, + const uint8_t* codes, + uint32_t n_rows, + uint32_t pq_dim, + std::variant offset_or_indices, + uint32_t pq_bits, + rmm::cuda_stream_view stream) +{ + if (n_rows == 0) { return; } + + constexpr uint32_t kBlockSize = 256; + dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); + dim3 threads(kBlockSize, 1, 1); + auto kernel = [pq_bits]() { + switch (pq_bits) { + case 4: return pack_contiguous_list_data_kernel; + case 5: return pack_contiguous_list_data_kernel; + case 6: return pack_contiguous_list_data_kernel; + case 7: return pack_contiguous_list_data_kernel; + case 8: return pack_contiguous_list_data_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(); + kernel<<>>(list_data, codes, n_rows, pq_dim, offset_or_indices); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +void pack_contiguous_list_data(raft::resources const& res, + index* index, + const uint8_t* new_codes, + uint32_t n_rows, + uint32_t label, + std::variant offset_or_indices) +{ + pack_contiguous_list_data(index->lists()[label]->data.view(), + new_codes, + n_rows, + index->pq_dim(), + offset_or_indices, + index->pq_bits(), + resource::get_cuda_stream(res)); +} + /** * * A producer for the `write_list` and `write_vector` that encodes level-1 input vector residuals @@ -1634,35 +1876,6 @@ auto build(raft::resources const& handle, labels_view, utils::mapping()); - { - // combine cluster_centers and their norms - RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle(), - sizeof(float) * index.dim_ext(), - cluster_centers, - sizeof(float) * index.dim(), - sizeof(float) * index.dim(), - index.n_lists(), - cudaMemcpyDefault, - stream)); - - rmm::device_uvector center_norms(index.n_lists(), stream, device_memory); - raft::linalg::rowNorm(center_norms.data(), - cluster_centers, - index.dim(), - index.n_lists(), - raft::linalg::L2Norm, - true, - stream); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle() + index.dim(), - sizeof(float) * index.dim_ext(), - center_norms.data(), - sizeof(float), - sizeof(float), - index.n_lists(), - cudaMemcpyDefault, - stream)); - } - // Make rotation matrix make_rotation_matrix(handle, params.force_random_rotation, @@ -1670,24 +1883,7 @@ auto build(raft::resources const& handle, index.dim(), index.rotation_matrix().data_handle()); - // Rotate cluster_centers - float alpha = 1.0; - float beta = 0.0; - linalg::gemm(handle, - true, - false, 
- index.rot_dim(), - index.n_lists(), - index.dim(), - &alpha, - index.rotation_matrix().data_handle(), - index.dim(), - cluster_centers, - index.dim(), - &beta, - index.centers_rot().data_handle(), - index.rot_dim(), - stream); + set_centers(handle, &index, cluster_centers); // Train PQ codebooks switch (index.codebook_kind()) { diff --git a/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp b/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp index 4594332fdf..5379788ab4 100644 --- a/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp @@ -19,36 +19,11 @@ #include #include #include +#include #include -#ifdef _RAFT_HAS_CUDA -#include -#else -#include -#endif - namespace raft::neighbors::ivf_flat::codepacker { -template -_RAFT_HOST_DEVICE inline auto roundDown(T x) -{ -#if defined(_RAFT_HAS_CUDA) - return Pow2::roundDown(x); -#else - return raft::round_down_safe(x, kIndexGroupSize); -#endif -} - -template -_RAFT_HOST_DEVICE inline auto mod(T x) -{ -#if defined(_RAFT_HAS_CUDA) - return Pow2::mod(x); -#else - return x % kIndexGroupSize; -#endif -} - /** * Write one flat code into a block by the given offset. The offset indicates the id of the record * in the list. This function interleaves the code and is intended to later copy the interleaved @@ -68,12 +43,12 @@ _RAFT_HOST_DEVICE void pack_1( const T* flat_code, T* block, uint32_t dim, uint32_t veclen, uint32_t offset) { // The data is written in interleaved groups of `index::kGroupSize` vectors - // using interleaved_group = Pow2; + using interleaved_group = neighbors::detail::div_utils; // Interleave dimensions of the source vector while recording it. // NB: such `veclen` is selected, that `dim % veclen == 0` - auto group_offset = roundDown(offset); - auto ingroup_id = mod(offset) * veclen; + auto group_offset = interleaved_group::roundDown(offset); + auto ingroup_id = interleaved_group::mod(offset) * veclen; for (uint32_t l = 0; l < dim; l += veclen) { for (uint32_t j = 0; j < veclen; j++) { @@ -100,11 +75,11 @@ _RAFT_HOST_DEVICE void unpack_1( const T* block, T* flat_code, uint32_t dim, uint32_t veclen, uint32_t offset) { // The data is written in interleaved groups of `index::kGroupSize` vectors - // using interleaved_group = Pow2; + using interleaved_group = neighbors::detail::div_utils; // NB: such `veclen` is selected, that `dim % veclen == 0` - auto group_offset = roundDown(offset); - auto ingroup_id = mod(offset) * veclen; + auto group_offset = interleaved_group::roundDown(offset); + auto ingroup_id = interleaved_group::mod(offset) * veclen; for (uint32_t l = 0; l < dim; l += veclen) { for (uint32_t j = 0; j < veclen; j++) { diff --git a/cpp/include/raft/neighbors/ivf_flat_helpers.cuh b/cpp/include/raft/neighbors/ivf_flat_helpers.cuh index 096e8051c3..7a05c9991c 100644 --- a/cpp/include/raft/neighbors/ivf_flat_helpers.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_helpers.cuh @@ -22,7 +22,10 @@ #include #include +#include + namespace raft::neighbors::ivf_flat::helpers { +using namespace raft::spatial::knn::detail; // NOLINT /** * @defgroup ivf_flat_helpers Helper functions for manipulationg IVF Flat Index * @{ @@ -106,5 +109,37 @@ void unpack( res, list_data, veclen, offset, codes); } } // namespace codepacker + +/** + * @brief Public helper API to reset the data and indices ptrs, and the list sizes. Useful for + * externally modifying the index without going through the build stage. The data and indices of the + * IVF lists will be lost. 
+ * + * Usage example: + * @code{.cpp} + * raft::resources res; + * using namespace raft::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // initialize an empty index + * ivf_flat::index index(res, index_params, D); + * // reset the index's state and list sizes + * ivf_flat::helpers::reset_index(res, &index); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + */ +template +void reset_index(const raft::resources& res, index* index) +{ + auto stream = resource::get_cuda_stream(res); + + utils::memzero(index->list_sizes().data_handle(), index->list_sizes().size(), stream); + utils::memzero(index->data_ptrs().data_handle(), index->data_ptrs().size(), stream); + utils::memzero(index->inds_ptrs().data_handle(), index->inds_ptrs().size(), stream); +} /** @} */ } // namespace raft::neighbors::ivf_flat::helpers diff --git a/cpp/include/raft/neighbors/ivf_pq_helpers.cuh b/cpp/include/raft/neighbors/ivf_pq_helpers.cuh index f00107f629..fec31f1c61 100644 --- a/cpp/include/raft/neighbors/ivf_pq_helpers.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_helpers.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -23,7 +24,10 @@ #include #include +#include + namespace raft::neighbors::ivf_pq::helpers { +using namespace raft::spatial::knn::detail; // NOLINT /** * @defgroup ivf_pq_helpers Helper functions for manipulationg IVF PQ Index * @{ @@ -71,6 +75,53 @@ inline void unpack( codes, list_data, offset, pq_bits, resource::get_cuda_stream(res)); } +/** + * @brief Unpack `n_rows` consecutive records of a single list (cluster) in the compressed index + * starting at given `offset`. The output codes of a single vector are contiguous, not expanded to + * one code per byte, which means the output has ceildiv(pq_dim * pq_bits, 8) bytes per PQ encoded + * vector. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * auto list_data = index.lists()[label]->data.view(); + * // allocate the buffer for the output + * uint32_t n_rows = 4; + * auto codes = raft::make_device_matrix( + * res, n_rows, raft::ceildiv(index.pq_dim() * index.pq_bits(), 8)); + * uint32_t offset = 0; + * // unpack n_rows elements from the list + * ivf_pq::helpers::codepacker::unpack_contiguous( + * res, list_data, index.pq_bits(), offset, n_rows, index.pq_dim(), codes.data_handle()); + * @endcode + * + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res raft resource + * @param[in] list_data block to read from + * @param[in] pq_bits bit length of encoded vector elements + * @param[in] offset + * How many records in the list to skip. + * @param[in] n_rows How many records to unpack + * @param[in] pq_dim The dimensionality of the PQ compressed records + * @param[out] codes + * the destination buffer [n_rows, ceildiv(pq_dim * pq_bits, 8)]. + * The length `n_rows` defines how many records to unpack, + * it must be smaller than the list size. + */ +inline void unpack_contiguous( + raft::resources const& res, + device_mdspan::list_extents, row_major> list_data, + uint32_t pq_bits, + uint32_t offset, + uint32_t n_rows, + uint32_t pq_dim, + uint8_t* codes) +{ + ivf_pq::detail::unpack_contiguous_list_data( + codes, list_data, n_rows, pq_dim, offset, pq_bits, resource::get_cuda_stream(res)); +} + /** * Write flat PQ codes into an existing list by the given offset. 
* @@ -87,7 +138,7 @@ inline void unpack( * res, make_const_mdspan(codes.view()), index.pq_bits(), 42, list_data); * @endcode * - * @param[in] res + * @param[in] res raft resource * @param[in] codes flat PQ codes, one code per byte [n_vec, pq_dim] * @param[in] pq_bits bit length of encoded vector elements * @param[in] offset how many records to skip before writing the data into the list @@ -102,6 +153,47 @@ inline void pack( { ivf_pq::detail::pack_list_data(list_data, codes, offset, pq_bits, resource::get_cuda_stream(res)); } + +/** + * Write flat PQ codes into an existing list by the given offset. The input codes of a single vector + * are contiguous (not expanded to one code per byte). + * + * NB: no memory allocation happens here; the list must fit the data (offset + n_rows records). + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * auto list_data = index.lists()[label]->data.view(); + * // allocate the buffer for the input codes + * auto codes = raft::make_device_matrix( + * res, n_rows, raft::ceildiv(index.pq_dim() * index.pq_bits(), 8)); + * ... prepare compressed vectors to pack into the list in codes ... + * // write codes into the list starting from the 42nd position. If the current size of the list + * // is greater than 42, this will overwrite the codes starting at this offset. + * ivf_pq::helpers::codepacker::pack_contiguous( + * res, codes.data_handle(), n_rows, index.pq_dim(), index.pq_bits(), 42, list_data); + * @endcode + * + * @param[in] res raft resource + * @param[in] codes flat PQ codes, [n_vec, ceildiv(pq_dim * pq_bits, 8)] + * @param[in] n_rows number of records + * @param[in] pq_dim + * @param[in] pq_bits bit length of encoded vector elements + * @param[in] offset how many records to skip before writing the data into the list + * @param[in] list_data block to write into + */ +inline void pack_contiguous( + raft::resources const& res, + const uint8_t* codes, + uint32_t n_rows, + uint32_t pq_dim, + uint32_t pq_bits, + uint32_t offset, + device_mdspan::list_extents, row_major> list_data) +{ + ivf_pq::detail::pack_contiguous_list_data( + list_data, codes, n_rows, pq_dim, offset, pq_bits, resource::get_cuda_stream(res)); +} } // namespace codepacker /** @@ -122,7 +214,7 @@ inline void pack( * ivf_pq::helpers::pack_list_data(res, &index, codes_to_pack, label, 42); * @endcode * - * @param[in] res + * @param[in] res raft resource * @param[inout] index IVF-PQ index. * @param[in] codes flat PQ codes, one code per byte [n_rows, pq_dim] * @param[in] label The id of the list (cluster) into which we write. @@ -138,6 +230,56 @@ void pack_list_data(raft::resources const& res, ivf_pq::detail::pack_list_data(res, index, codes, label, offset); } +/** + * Write flat PQ codes into an existing list by the given offset. Use this when the input + * vectors are PQ encoded and not expanded to one code per byte. + * + * The list is identified by its label. + * + * NB: no memory allocation happens here; the list into which the vectors are packed must fit offset + * + n_rows rows. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * raft::resources res; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(res, index_params, dataset, N, D); + * // allocate the buffer for n_rows input codes. Each vector occupies + * // raft::ceildiv(index.pq_dim() * index.pq_bits(), 8) bytes because + * // codes are compressed and without gaps. 
+ * auto codes = raft::make_device_matrix( + * res, n_rows, raft::ceildiv(index.pq_dim() * index.pq_bits(), 8)); + * ... prepare the compressed vectors to pack into the list in codes ... + * // the first n_rows codes in the fourth IVF list are to be overwritten. + * uint32_t label = 3; + * // write codes into the list starting from the 0th position + * ivf_pq::helpers::pack_contiguous_list_data( + * res, &index, codes.data_handle(), n_rows, label, 0); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + * @param[in] codes flat contiguous PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)] + * @param[in] n_rows how many records to pack + * @param[in] label The id of the list (cluster) into which we write. + * @param[in] offset how many records to skip before writing the data into the list + */ +template +void pack_contiguous_list_data(raft::resources const& res, + index* index, + uint8_t* codes, + uint32_t n_rows, + uint32_t label, + uint32_t offset) +{ + ivf_pq::detail::pack_contiguous_list_data(res, index, codes, n_rows, label, offset); +} + /** * @brief Unpack `n_take` consecutive records of a single list (cluster) in the compressed index * starting at given `offset`, one code per byte (independently of pq_bits). @@ -200,8 +342,8 @@ void unpack_list_data(raft::resources const& res, * * @tparam IdxT type of the indices in the source dataset * - * @param[in] res - * @param[in] index + * @param[in] res raft resource + * @param[in] index IVF-PQ index (passed by reference) * @param[in] in_cluster_indices * The offsets of the selected indices within the cluster. * @param[out] out_codes @@ -221,6 +363,53 @@ void unpack_list_data(raft::resources const& res, return ivf_pq::detail::unpack_list_data(res, index, out_codes, label, in_cluster_indices); } +/** + * @brief Unpack `n_rows` consecutive PQ encoded vectors of a single list (cluster) in the + * compressed index starting at given `offset`, not expanded to one code per byte. Each code in the + * output buffer occupies ceildiv(index.pq_dim() * index.pq_bits(), 8) bytes. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * // We will unpack the whole fourth cluster + * uint32_t label = 3; + * // Get the list size + * uint32_t list_size = 0; + * raft::update_host(&list_size, index.list_sizes().data_handle() + label, 1, + * raft::resource::get_cuda_stream(res)); raft::resource::sync_stream(res); + * // allocate the buffer for the output + * auto codes = raft::make_device_matrix(res, list_size, raft::ceildiv(index.pq_dim() * + * index.pq_bits(), 8)); + * // unpack the whole list + * ivf_pq::helpers::unpack_list_data(res, index, codes.data_handle(), list_size, label, 0); + * @endcode + * + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res raft resource + * @param[in] index IVF-PQ index (passed by reference) + * @param[out] out_codes + * the destination buffer [n_rows, ceildiv(index.pq_dim() * index.pq_bits(), 8)]. + * The length `n_rows` defines how many records to unpack, + * offset + n_rows must be smaller than or equal to the list size. + * @param[in] n_rows how many codes to unpack + * @param[in] label + * The id of the list (cluster) to decode. + * @param[in] offset + * How many records in the list to skip. 
+ */ +template +void unpack_contiguous_list_data(raft::resources const& res, + const index& index, + uint8_t* out_codes, + uint32_t n_rows, + uint32_t label, + uint32_t offset) +{ + return ivf_pq::detail::unpack_contiguous_list_data( + res, index, out_codes, n_rows, label, offset); +} + /** * @brief Decode `n_take` consecutive records of a single list (cluster) in the compressed index * starting at given `offset`. @@ -232,7 +421,7 @@ void unpack_list_data(raft::resources const& res, * // Get the list size * uint32_t list_size = 0; * raft::copy(&list_size, index.list_sizes().data_handle() + label, 1, - * resource::get_cuda_stream(res)); resource::sync_stream(res); + * resource::get_cuda_stream(res)); resource::sync_stream(res); * // allocate the buffer for the output * auto decoded_vectors = raft::make_device_matrix(res, list_size, index.dim()); * // decode the whole list @@ -397,6 +586,7 @@ void extend_list(raft::resources const& res, * @endcode * * @tparam IdxT + * * @param[in] res * @param[inout] index * @param[in] label the id of the target list (cluster). @@ -407,5 +597,197 @@ void erase_list(raft::resources const& res, index* index, uint32_t label) ivf_pq::detail::erase_list(res, index, label); } +/** + * @brief Public helper API to reset the data and indices ptrs, and the list sizes. Useful for + * externally modifying the index without going through the build stage. The data and indices of the + * IVF lists will be lost. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * using namespace raft::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // initialize an empty index + * ivf_pq::index index(res, index_params, D); + * // reset the index's state and list sizes + * ivf_pq::helpers::reset_index(res, &index); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + */ +template +void reset_index(const raft::resources& res, index* index) +{ + auto stream = resource::get_cuda_stream(res); + + utils::memzero( + index->accum_sorted_sizes().data_handle(), index->accum_sorted_sizes().size(), stream); + utils::memzero(index->list_sizes().data_handle(), index->list_sizes().size(), stream); + utils::memzero(index->data_ptrs().data_handle(), index->data_ptrs().size(), stream); + utils::memzero(index->inds_ptrs().data_handle(), index->inds_ptrs().size(), stream); +} + +/** + * @brief Public helper API exposing the computation of the index's rotation matrix. + * NB: This is to be used only when the rotation matrix is not already computed through + * raft::neighbors::ivf_pq::build. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * // use default index parameters + * ivf_pq::index_params index_params; + * // force random rotation + * index_params.force_random_rotation = true; + * // initialize an empty index + * raft::neighbors::ivf_pq::index index(res, index_params, D); + * // reset the index + * reset_index(res, &index); + * // compute the rotation matrix with random_rotation + * raft::neighbors::ivf_pq::helpers::make_rotation_matrix( + * res, &index, index_params.force_random_rotation); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + * @param[in] force_random_rotation whether to apply a random rotation matrix on the input data. See + * raft::neighbors::ivf_pq::index_params for more details. 
+ */ +template +void make_rotation_matrix(raft::resources const& res, + index* index, + bool force_random_rotation) +{ + raft::neighbors::ivf_pq::detail::make_rotation_matrix(res, + force_random_rotation, + index->rot_dim(), + index->dim(), + index->rotation_matrix().data_handle()); +} + +/** + * @brief Public helper API for externally modifying the index's IVF centroids. + * NB: The index must be reset before this. Use raft::neighbors::ivf_pq::extend to construct IVF + lists according to new centroids. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * // allocate the buffer for the input centers + * auto cluster_centers = raft::make_device_matrix(res, index.n_lists(), + index.dim()); + * ... prepare ivf centroids in cluster_centers ... + * // reset the index + * reset_index(res, &index); + * // recompute the state of the index + * raft::neighbors::ivf_pq::helpers::recompute_internal_state(res, index); + * // Write the IVF centroids + * raft::neighbors::ivf_pq::helpers::set_centers( + res, + &index, + cluster_centers); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + * @param[in] cluster_centers new cluster centers [index.n_lists(), index.dim()] + */ +template +void set_centers(raft::resources const& res, + index* index, + device_matrix_view cluster_centers) +{ + RAFT_EXPECTS(cluster_centers.extent(0) == index->n_lists(), + "Number of rows in the new centers must be equal to the number of IVF lists"); + RAFT_EXPECTS(cluster_centers.extent(1) == index->dim(), + "Number of columns in the new cluster centers and index dim are different"); + RAFT_EXPECTS(index->size() == 0, "Index must be empty"); + ivf_pq::detail::set_centers(res, index, cluster_centers.data_handle()); +} + +/** + * @brief Helper exposing the re-computation of list sizes and related arrays if IVF lists have been + * modified. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * raft::resources res; + * // use default index parameters + * ivf_pq::index_params index_params; + * // initialize an empty index + * ivf_pq::index index(res, index_params, D); + * ivf_pq::helpers::reset_index(res, &index); + * // resize the first IVF list to hold 5 records + * auto spec = list_spec{ + * index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; + * uint32_t new_size = 5; + * ivf::resize_list(res, list, spec, new_size, 0); + * raft::update_device(index.list_sizes(), &new_size, 1, stream); + * // recompute the internal state of the index + * ivf_pq::recompute_internal_state(res, &index); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + */ +template +void recompute_internal_state(const raft::resources& res, index* index) +{ + auto& list = index->lists()[0]; + ivf_pq::detail::recompute_internal_state(res, *index); +} + +/** + * @brief Public helper API for fetching a trained index's IVF centroids into a buffer that may be + * allocated on either host or device. 
+ * + * Usage example: + * @code{.cpp} + * raft::resources res; + * // allocate the buffer for the output centers + * auto cluster_centers = raft::make_device_matrix( + * res, index.n_lists(), index.dim()); + * // Extract the IVF centroids into the buffer + * raft::neighbors::ivf_pq::helpers::extract_centers(res, index, cluster_centers.data_handle()); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[in] index IVF-PQ index (passed by reference) + * @param[out] cluster_centers IVF cluster centers [index.n_lists(), index.dim] + */ +template +void extract_centers(raft::resources const& res, + const index& index, + raft::device_matrix_view cluster_centers) +{ + RAFT_EXPECTS(cluster_centers.extent(0) == index.n_lists(), + "Number of rows in the output buffer for cluster centers must be equal to the " + "number of IVF lists"); + RAFT_EXPECTS( + cluster_centers.extent(1) == index.dim(), + "Number of columns in the output buffer for cluster centers and index dim are different"); + auto stream = resource::get_cuda_stream(res); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers.data_handle(), + sizeof(float) * index.dim(), + index.centers().data_handle(), + sizeof(float) * index.dim_ext(), + sizeof(float) * index.dim(), + index.n_lists(), + cudaMemcpyDefault, + stream)); +} /** @} */ } // namespace raft::neighbors::ivf_pq::helpers diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 24df77b35a..45ab18c84f 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -487,6 +487,30 @@ struct index : ann::index { return centers_rot_.view(); } + /** fetch size of a particular IVF list in bytes using the list extents. + * Usage example: + * @code{.cpp} + * raft::resources res; + * // use default index params + * ivf_pq::index_params index_params; + * // extend the IVF lists while building the index + * index_params.add_data_on_build = true; + * // create and fill the index from a [N, D] dataset + * auto index = raft::neighbors::ivf_pq::build(res, index_params, dataset, N, D); + * // Fetch the size of the fourth list + * uint32_t size = index.get_list_size_in_bytes(3); + * @endcode + * + * @param[in] label list ID + */ + inline auto get_list_size_in_bytes(uint32_t label) -> uint32_t + { + RAFT_EXPECTS(label < this->n_lists(), + "Expected label to be less than number of lists in the index"); + auto list_data = this->lists()[label]->data; + return list_data.size(); + } + private: raft::distance::DistanceType metric_; codebook_gen codebook_kind_; diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index bdb83ecfdc..eb30b60eca 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -380,6 +380,83 @@ class ivf_pq_test : public ::testing::TestWithParam { list_data_size, Compare{})); } + void check_packing_contiguous(index* index, uint32_t label) + { + auto old_list = index->lists()[label]; + auto n_rows = old_list->size.load(); + + if (n_rows == 0) { return; } + + auto codes = make_device_matrix(handle_, n_rows, index->pq_dim()); + auto indices = make_device_vector(handle_, n_rows); + copy(indices.data_handle(), old_list->indices.data_handle(), n_rows, stream_); + + uint32_t code_size = ceildiv(index->pq_dim() * index->pq_bits(), 8); + + auto codes_compressed = make_device_matrix(handle_, n_rows, code_size); + + ivf_pq::helpers::unpack_contiguous_list_data( + handle_, *index, codes_compressed.data_handle(), 
n_rows, label, 0); + ivf_pq::helpers::erase_list(handle_, index, label); + ivf_pq::detail::extend_list_prepare(handle_, index, make_const_mdspan(indices.view()), label); + ivf_pq::helpers::pack_contiguous_list_data( + handle_, index, codes_compressed.data_handle(), n_rows, label, 0); + ivf_pq::helpers::recompute_internal_state(handle_, index); + + auto& new_list = index->lists()[label]; + ASSERT_NE(old_list.get(), new_list.get()) + << "The old list should have been shared and retained after ivf_pq index has erased the " + "corresponding cluster."; + auto list_data_size = (n_rows / ivf_pq::kIndexGroupSize) * new_list->data.extent(1) * + new_list->data.extent(2) * new_list->data.extent(3); + + ASSERT_TRUE(old_list->data.size() >= list_data_size); + ASSERT_TRUE(new_list->data.size() >= list_data_size); + ASSERT_TRUE(devArrMatch(old_list->data.data_handle(), + new_list->data.data_handle(), + list_data_size, + Compare{})); + + // Pack a few vectors back to the list. + uint32_t row_offset = 9; + uint32_t n_vec = 3; + ASSERT_TRUE(row_offset + n_vec < n_rows); + size_t offset = row_offset * code_size; + auto codes_to_pack = make_device_matrix_view( + codes_compressed.data_handle() + offset, n_vec, index->pq_dim()); + ivf_pq::helpers::pack_contiguous_list_data( + handle_, index, codes_to_pack.data_handle(), n_vec, label, row_offset); + ASSERT_TRUE(devArrMatch(old_list->data.data_handle(), + new_list->data.data_handle(), + list_data_size, + Compare{})); + + // // Another test with the API that take list_data directly + auto list_data = index->lists()[label]->data.view(); + uint32_t n_take = 4; + ASSERT_TRUE(row_offset + n_take < n_rows); + auto codes2 = raft::make_device_matrix(handle_, n_take, code_size); + ivf_pq::helpers::codepacker::unpack_contiguous(handle_, + list_data, + index->pq_bits(), + row_offset, + n_take, + index->pq_dim(), + codes2.data_handle()); + + // Write it back + ivf_pq::helpers::codepacker::pack_contiguous(handle_, + codes2.data_handle(), + n_vec, + index->pq_dim(), + index->pq_bits(), + row_offset, + list_data); + ASSERT_TRUE(devArrMatch(old_list->data.data_handle(), + new_list->data.data_handle(), + list_data_size, + Compare{})); + } template void run(BuildIndex build_index) @@ -398,6 +475,7 @@ class ivf_pq_test : public ::testing::TestWithParam { case 1: { // Dump and re-write codes for one label check_packing(&index, label); + check_packing_contiguous(&index, label); } break; default: { // check a small subset of data in a randomly chosen cluster to see if the data @@ -962,6 +1040,32 @@ inline auto special_cases() -> test_cases_t x.search_params.n_probes = 100; }); + ADD_CASE({ + x.num_db_vecs = 4335; + x.dim = 4; + x.num_queries = 100000; + x.k = 12; + x.index_params.metric = distance::DistanceType::L2Expanded; + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + x.index_params.pq_dim = 2; + x.index_params.pq_bits = 8; + x.index_params.n_lists = 69; + x.search_params.n_probes = 69; + }); + + ADD_CASE({ + x.num_db_vecs = 4335; + x.dim = 4; + x.num_queries = 100000; + x.k = 12; + x.index_params.metric = distance::DistanceType::L2Expanded; + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; + x.index_params.pq_dim = 2; + x.index_params.pq_bits = 8; + x.index_params.n_lists = 69; + x.search_params.n_probes = 69; + }); + return xs; } diff --git a/docs/source/raft_ann_benchmarks.md b/docs/source/raft_ann_benchmarks.md index e6c4eaedd0..dcdfc2cec9 100644 --- a/docs/source/raft_ann_benchmarks.md +++ b/docs/source/raft_ann_benchmarks.md @@ -198,27 
+198,33 @@ options: --dataset-path DATASET_PATH path to dataset folder (default: ${RAPIDS_DATASET_ROOT_DIR}) ``` -Build statistics CSV file is stored in `/result/build/` -and index search statistics CSV file in `/result/search/`. +Build statistics CSV file is stored in `/result/build/` +and index search statistics CSV file in `/result/search/`, where suffix has three values: +1. `raw`: All search results are exported +2. `throughput`: Pareto frontier of throughput results is exported +3. `latency`: Pareto frontier of latency results is exported + ### Step 4: Plot Results The script `raft-ann-bench.plot` will plot results for all algorithms found in index search statistics -CSV file in `/result/search/<-k{k}-batch_size{batch_size}>.csv`. +CSV files `/result/search/*.csv`. The usage of this script is: ```bash -usage: __main__.py [-h] [--dataset DATASET] [--dataset-path DATASET_PATH] [--output-filepath OUTPUT_FILEPATH] [--algorithms ALGORITHMS] [--groups GROUPS] [--algo-groups ALGO_GROUPS] [-k COUNT] - [-bs BATCH_SIZE] [--build] [--search] [--x-scale X_SCALE] [--y-scale {linear,log,symlog,logit}] [--raw] +usage: [-h] [--dataset DATASET] [--dataset-path DATASET_PATH] [--output-filepath OUTPUT_FILEPATH] [--algorithms ALGORITHMS] [--groups GROUPS] [--algo-groups ALGO_GROUPS] + [-k COUNT] [-bs BATCH_SIZE] [--build] [--search] [--x-scale X_SCALE] [--y-scale {linear,log,symlog,logit}] [--mode {throughput,latency}] [--time-unit {s,ms,us}] + [--raw] options: -h, --help show this help message and exit --dataset DATASET dataset to plot (default: glove-100-inner) --dataset-path DATASET_PATH - path to dataset folder (default: os.getcwd()/datasets/) + path to dataset folder (default: /home/coder/raft/datasets/) --output-filepath OUTPUT_FILEPATH - directory for PNG to be saved (default: os.getcwd()) + directory for PNG to be saved (default: /home/coder/raft) --algorithms ALGORITHMS - plot only comma separated list of named algorithms. If parameters `groups` and `algo-groups are both undefined, then group `base` is plot by default (default: None) + plot only comma separated list of named algorithms. If parameters `groups` and `algo-groups are both undefined, then group `base` is plot by default + (default: None) --groups GROUPS plot only comma separated groups of parameters (default: base) --algo-groups ALGO_GROUPS, --algo-groups ALGO_GROUPS add comma separated . to plot. Example usage: "--algo-groups=raft_cagra.large,hnswlib.large" (default: None) @@ -231,8 +237,14 @@ options: --x-scale X_SCALE Scale to use when drawing the X-axis. Typically linear, logit or a2 (default: linear) --y-scale {linear,log,symlog,logit} Scale to use when drawing the Y-axis (default: linear) - --raw Show raw results (not just Pareto frontier) in faded colours (default: False) + --mode {throughput,latency} + search mode whose Pareto frontier is used on the y-axis (default: throughput) + --time-unit {s,ms,us} + time unit to plot when mode is latency (default: ms) + --raw Show raw results (not just Pareto frontier) of mode arg (default: False) ``` +`mode`: plots pareto frontier of `throughput` or `latency` results exported in the previous step + `algorithms`: plots all algorithms that it can find results for the specified `dataset`. By default, only `base` group will be plotted. `groups`: plot only specific groups of parameters configurations for an algorithm. 
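For example, a search result `raft_cagra,base-k10-batch_size10000.json` (hypothetical name) is now exported as three CSVs in the same folder: `raft_cagra,base-k10-batch_size10000_raw.csv`, `..._throughput.csv`, and `..._latency.csv`.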
Groups are defined in YAML configs (see `configuration`), and by default run `base` group diff --git a/python/raft-ann-bench/src/raft-ann-bench/data_export/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/data_export/__main__.py index 4978c99d60..572b81bbe2 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/data_export/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/data_export/__main__.py @@ -43,9 +43,26 @@ ) skip_search_cols = ( - set(["recall", "qps", "items_per_second", "Recall"]) | skip_build_cols + set(["recall", "qps", "latency", "items_per_second", "Recall", "Latency"]) + | skip_build_cols ) +metrics = { + "k-nn": { + "description": "Recall", + "worst": float("-inf"), + "lim": [0.0, 1.03], + }, + "throughput": { + "description": "Queries per second (1/s)", + "worst": float("-inf"), + }, + "latency": { + "description": "Search Latency (s)", + "worst": float("inf"), + }, +} + def read_file(dataset, dataset_path, method): dir = os.path.join(dataset_path, dataset, "result", method) @@ -92,6 +109,31 @@ def convert_json_to_csv_build(dataset, dataset_path): traceback.print_exc() +def create_pointset(data, xn, yn): + xm, ym = (metrics[xn], metrics[yn]) + rev_y = -1 if ym["worst"] < 0 else 1 + rev_x = -1 if xm["worst"] < 0 else 1 + + y_idx = 3 if yn == "throughput" else 4 + data.sort(key=lambda t: (rev_y * t[y_idx], rev_x * t[2])) + + lines = [] + last_x = xm["worst"] + comparator = ( + (lambda xv, lx: xv > lx) if last_x < 0 else (lambda xv, lx: xv < lx) + ) + for d in data: + if comparator(d[2], last_x): + last_x = d[2] + lines.append(d) + return lines + + +def get_frontier(df, metric): + lines = create_pointset(df.values.tolist(), "k-nn", metric) + return pd.DataFrame(lines, columns=df.columns) + + def convert_json_to_csv_search(dataset, dataset_path): for file, algo_name, df in read_file(dataset, dataset_path, "search"): try: @@ -100,14 +142,21 @@ def convert_json_to_csv_search(dataset, dataset_path): ) algo_name = algo_name.replace("_base", "") df["name"] = df["name"].str.split("/").str[0] - write = pd.DataFrame( - { - "algo_name": [algo_name] * len(df), - "index_name": df["name"], - "recall": df["Recall"], - "qps": df["items_per_second"], - } - ) + try: + write = pd.DataFrame( + { + "algo_name": [algo_name] * len(df), + "index_name": df["name"], + "recall": df["Recall"], + "throughput": df["items_per_second"], + "latency": df["Latency"], + } + ) + except Exception as e: + print( + "Search file %s (%s) missing a key. Skipping..." 
+ % (file, e) + ) for name in df: if name not in skip_search_cols: write[name] = df[name] @@ -120,20 +169,29 @@ def convert_json_to_csv_search(dataset, dataset_path): write["build cpu_time"] = None write["build GPU"] = None - for col_idx in range(6, len(build_df.columns)): - col_name = build_df.columns[col_idx] - write[col_name] = None - - for s_index, search_row in write.iterrows(): - for b_index, build_row in build_df.iterrows(): - if search_row["index_name"] == build_row["index_name"]: - write.iloc[s_index, write_ncols] = build_df.iloc[ - b_index, 2 - ] - write.iloc[ - s_index, write_ncols + 1 : - ] = build_df.iloc[b_index, 3:] - break + try: + for col_idx in range(6, len(build_df.columns)): + col_name = build_df.columns[col_idx] + write[col_name] = None + + for s_index, search_row in write.iterrows(): + for b_index, build_row in build_df.iterrows(): + if ( + search_row["index_name"] + == build_row["index_name"] + ): + write.iloc[ + s_index, write_ncols + ] = build_df.iloc[b_index, 2] + write.iloc[ + s_index, write_ncols + 1 : + ] = build_df.iloc[b_index, 3:] + break + except Exception as e: + print( + "Build file %s (%s) missing a key. Skipping..." + % (build_file, e) + ) else: warnings.warn( f"Build CSV not found for {algo_name}, " @@ -141,7 +199,13 @@ def convert_json_to_csv_search(dataset, dataset_path): "appended in the Search CSV" ) - write.to_csv(file.replace(".json", ".csv"), index=False) + write.to_csv(file.replace(".json", "_raw.csv"), index=False) + throughput = get_frontier(write, "throughput") + throughput.to_csv( + file.replace(".json", "_throughput.csv"), index=False + ) + latency = get_frontier(write, "latency") + latency.to_csv(file.replace(".json", "_latency.csv"), index=False) except Exception as e: print( "An error occurred processing file %s (%s). Skipping..." 
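A quick illustration of the Pareto-frontier export above (toy data with hypothetical values; `get_frontier` and `create_pointset` are the functions added in this diff, and the five-column layout matches the frame built in `convert_json_to_csv_search`):

    import pandas as pd

    df = pd.DataFrame(
        [["a", "i1", 0.90, 1000.0, 0.010],
         ["a", "i2", 0.95,  800.0, 0.012],
         ["a", "i3", 0.85, 1200.0, 0.008],
         ["a", "i4", 0.80,  900.0, 0.011]],
        columns=["algo_name", "index_name", "recall", "throughput", "latency"],
    )
    # Rows are sorted by descending throughput and kept only while recall
    # strictly improves: i3, i1 and i2 survive, while i4 (recall 0.80,
    # dominated by i1) is dropped from the frontier.
    frontier = get_frontier(df, "throughput")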
diff --git a/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py
index c45ff5b14e..8bd54170c9 100644
--- a/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py
+++ b/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py
@@ -38,10 +38,14 @@
         "worst": float("-inf"),
         "lim": [0.0, 1.03],
     },
-    "qps": {
+    "throughput": {
         "description": "Queries per second (1/s)",
         "worst": float("-inf"),
     },
+    "latency": {
+        "description": "Search Latency (s)",
+        "worst": float("inf"),
+    },
 }
 
 
@@ -98,53 +102,20 @@ def create_linestyles(unique_algorithms):
     )
 
 
-def get_up_down(metric):
-    if metric["worst"] == float("inf"):
-        return "down"
-    return "up"
-
-
-def get_left_right(metric):
-    if metric["worst"] == float("inf"):
-        return "left"
-    return "right"
-
-
-def create_pointset(data, xn, yn):
-    xm, ym = (metrics[xn], metrics[yn])
-    rev_y = -1 if ym["worst"] < 0 else 1
-    rev_x = -1 if xm["worst"] < 0 else 1
-    data.sort(key=lambda t: (rev_y * t[-1], rev_x * t[-2]))
-
-    axs, ays, als, aidxs = [], [], [], []
-    # Generate Pareto frontier
-    xs, ys, ls, idxs = [], [], [], []
-    last_x = xm["worst"]
-    comparator = (
-        (lambda xv, lx: xv > lx) if last_x < 0 else (lambda xv, lx: xv < lx)
-    )
-    for algo_name, index_name, xv, yv in data:
-        if not xv or not yv:
-            continue
-        axs.append(xv)
-        ays.append(yv)
-        als.append(algo_name)
-        aidxs.append(algo_name)
-        if comparator(xv, last_x):
-            last_x = xv
-            xs.append(xv)
-            ys.append(yv)
-            ls.append(algo_name)
-            idxs.append(index_name)
-    return xs, ys, ls, idxs, axs, ays, als, aidxs
-
-
 def create_plot_search(
-    all_data, raw, x_scale, y_scale, fn_out, linestyles, dataset, k, batch_size
+    all_data,
+    x_scale,
+    y_scale,
+    fn_out,
+    linestyles,
+    dataset,
+    k,
+    batch_size,
+    mode,
+    time_unit,
 ):
     xn = "k-nn"
-    yn = "qps"
-    xm, ym = (metrics[xn], metrics[yn])
+    xm, ym = (metrics[xn], metrics[mode])
     # Now generate each plot
     handles = []
     labels = []
@@ -152,17 +123,15 @@
 
     # Sorting by mean y-value helps aligning plots with labels
     def mean_y(algo):
-        xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset(
-            all_data[algo], xn, yn
-        )
-        return -np.log(np.array(ys)).mean()
+        points = np.array(all_data[algo], dtype=object)
+        return -np.log(np.array(points[:, 3], dtype=np.float32)).mean()
 
     # Find range for logit x-scale
     min_x, max_x = 1, 0
     for algo in sorted(all_data.keys(), key=mean_y):
-        xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset(
-            all_data[algo], xn, yn
-        )
+        points = np.array(all_data[algo], dtype=object)
+        xs = points[:, 2]
+        ys = points[:, 3]
         min_x = min([min_x] + [x for x in xs if x > 0])
         max_x = max([max_x] + [x for x in xs if x < 1])
         color, faded, linestyle, marker = linestyles[algo]
@@ -178,23 +147,15 @@
             marker=marker,
         )
         handles.append(handle)
-        if raw:
-            (handle2,) = plt.plot(
-                axs,
-                ays,
-                "-",
-                label=algo,
-                color=faded,
-                ms=5,
-                mew=2,
-                lw=2,
-                marker=marker,
-            )
+
         labels.append(algo)
 
     ax = plt.gca()
-    ax.set_ylabel(ym["description"])
-    ax.set_xlabel(xm["description"])
+    y_description = ym["description"]
+    if mode == "latency":
+        y_description = y_description.replace("(s)", f"({time_unit})")
+    ax.set_ylabel(y_description)
+    ax.set_xlabel("Recall")
     # Custom scales of the type --x-scale a3
     if x_scale[0] == "a":
         alpha = float(x_scale[1:])
@@ -250,10 +211,8 @@
 
 
 def create_plot_build(
-    build_results, search_results, linestyles, fn_out, dataset, k, batch_size
+    build_results, search_results, linestyles, fn_out, dataset
 ):
-    xn = "k-nn"
-    yn = "qps"
"qps" qps_85 = [-1] * len(linestyles) bt_85 = [0] * len(linestyles) @@ -271,16 +230,17 @@ def create_plot_build( colors = OrderedDict() # Sorting by mean y-value helps aligning plots with labels + def mean_y(algo): - xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset( - search_results[algo], xn, yn - ) - return -np.log(np.array(ys)).mean() + points = np.array(search_results[algo], dtype=object) + return -np.log(np.array(points[:, 3], dtype=np.float32)).mean() for pos, algo in enumerate(sorted(search_results.keys(), key=mean_y)): - xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset( - search_results[algo], xn, yn - ) + points = np.array(search_results[algo], dtype=object) + xs = points[:, 2] + ys = points[:, 3] + ls = points[:, 0] + idxs = points[:, 1] # x is recall, y is qps, ls is algo_name, idxs is index_name for i in range(len(xs)): if xs[i] >= 0.85 and xs[i] < 0.9 and ys[i] > qps_85[pos]: @@ -311,11 +271,11 @@ def mean_y(algo): fig.savefig(fn_out) -def load_lines(results_path, result_files, method, index_key): +def load_lines(results_path, result_files, method, index_key, mode, time_unit): results = dict() for result_filename in result_files: - if result_filename.endswith(".csv"): + try: with open(os.path.join(results_path, result_filename), "r") as f: lines = f.readlines() lines = lines[:-1] if lines[-1] == "\n" else lines @@ -323,7 +283,8 @@ def load_lines(results_path, result_files, method, index_key): if method == "build": key_idx = [2] elif method == "search": - key_idx = [2, 3] + y_idx = 3 if mode == "throughput" else 4 + key_idx = [2, y_idx] for line in lines[1:]: split_lines = line.split(",") @@ -340,7 +301,22 @@ def load_lines(results_path, result_files, method, index_key): to_add = [algo_name, index_name] for key_i in key_idx: to_add.append(float(split_lines[key_i])) + if ( + mode == "latency" + and time_unit != "s" + and method == "search" + ): + to_add[-1] = ( + to_add[-1] * (10**3) + if time_unit == "ms" + else to_add[-1] * (10**6) + ) results[dict_key].append(to_add) + except Exception: + print( + f"An error occurred processing file {result_filename}. " + "Skipping..." 
@@ -354,12 +330,31 @@
     batch_size,
     method,
     index_key,
+    raw,
+    mode,
+    time_unit,
 ):
     results_path = os.path.join(dataset_path, "result", method)
     result_files = os.listdir(results_path)
-    result_files = [
-        result_file for result_file in result_files if ".csv" in result_file
-    ]
+    if method == "build":
+        result_files = [
+            result_file
+            for result_file in result_files
+            if ".csv" in result_file
+        ]
+    elif method == "search":
+        if raw:
+            suffix = "_raw"
+        else:
+            suffix = f"_{mode}"
+        result_files = [
+            result_file
+            for result_file in result_files
+            if f"{suffix}.csv" in result_file
+        ]
+    if len(result_files) == 0:
+        raise FileNotFoundError(f"No CSV result files found in {results_path}")
+
     if method == "search":
         result_files = [
             result_filename
@@ -407,7 +402,9 @@
         final_results = final_results + final_algo_groups
 
     final_results = set(final_results)
-    results = load_lines(results_path, final_results, method, index_key)
+    results = load_lines(
+        results_path, final_results, method, index_key, mode, time_unit
+    )
 
     return results
 
@@ -481,9 +478,21 @@
         choices=["linear", "log", "symlog", "logit"],
         default="linear",
     )
+    parser.add_argument(
+        "--mode",
+        help="search mode whose Pareto frontier is used on the y-axis",
+        choices=["throughput", "latency"],
+        default="throughput",
+    )
+    parser.add_argument(
+        "--time-unit",
+        help="time unit to plot when mode is latency",
+        choices=["s", "ms", "us"],
+        default="ms",
+    )
     parser.add_argument(
         "--raw",
-        help="Show raw results (not just Pareto frontier) in faded colours",
+        help="Show raw results (not just Pareto frontier) of mode arg",
         action="store_true",
     )
@@ -528,12 +537,14 @@
         batch_size,
         "search",
         "algo",
+        args.raw,
+        args.mode,
+        args.time_unit,
     )
     linestyles = create_linestyles(sorted(search_results.keys()))
     if search:
         create_plot_search(
             search_results,
-            args.raw,
             args.x_scale,
             args.y_scale,
             search_output_filepath,
@@ -541,6 +552,8 @@
             args.dataset,
             k,
             batch_size,
+            args.mode,
+            args.time_unit,
         )
     if build:
         build_results = load_all_results(
@@ -552,6 +565,9 @@
             batch_size,
             "build",
             "index",
+            args.raw,
+            args.mode,
+            args.time_unit,
        )
         create_plot_build(
             build_results,
@@ -559,8 +575,6 @@
             search_results,
             linestyles,
             build_output_filepath,
             args.dataset,
-            k,
-            batch_size,
         )
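
Aside on the overall flow: after this patch, `data_export` writes three CSVs per result JSON (`_raw`, `_throughput`, `_latency`), and `plot` selects among them via `--raw` and `--mode`. A small sketch of that naming contract; the file name here is a placeholder, not a real result file:

```python
# Sketch of the file-naming contract between data_export and plot after this
# patch; "algoA_base.json" is a hypothetical result file name.
file = "algoA_base.json"

print(file.replace(".json", "_raw.csv"))         # full results
print(file.replace(".json", "_throughput.csv"))  # throughput Pareto frontier
print(file.replace(".json", "_latency.csv"))     # latency Pareto frontier


def suffix_for(raw, mode):
    # Mirrors the selection logic added to load_all_results: --raw wins,
    # otherwise the frontier CSV for the chosen --mode is loaded.
    return "_raw" if raw else f"_{mode}"


print(suffix_for(raw=False, mode="latency"))  # _latency
```

Precomputing the frontier CSVs at export time means `--raw` can plot the unfiltered points directly, and switching between throughput and latency plots never requires re-running the export step.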