Skip to content

Commit

Permalink
Merge branch 'branch-23.12' into huge_page
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet authored Oct 18, 2023
2 parents 3e9079d + 889e9f5 commit 15539ea
Show file tree
Hide file tree
Showing 176 changed files with 2,639 additions and 2,331 deletions.
5 changes: 4 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ repos:
hooks:
- id: rapids-dependency-file-generator
args: ["--clean"]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-json

default_language_version:
python: python3
112 changes: 30 additions & 82 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,106 +277,54 @@ pairwise_distance(in1, in2, out=output, metric="euclidean")

## Installing

RAFT itself can be installed through conda, [CMake Package Manager (CPM)](https://github.com/cpm-cmake/CPM.cmake), pip, or by building the repository from source. Please refer to the [build instructions](docs/source/build.md) for more a comprehensive guide on installing and building RAFT and using it in downstream projects.
RAFT's C++ and Python libraries can both be installed through Conda and the Python libraries through Pip.

### Conda

### Installing C++ and Python through Conda

The easiest way to install RAFT is through conda and several packages are provided.
- `libraft-headers` RAFT headers
- `libraft` (optional) shared library of pre-compiled template instantiations and runtime APIs.
- `pylibraft` (optional) Python wrappers around RAFT algorithms and primitives.
- `raft-dask` (optional) enables deployment of multi-node multi-GPU algorithms that use RAFT `raft::comms` in Dask clusters.
- `libraft-headers` C++ headers
- `libraft` (optional) C++ shared library containing pre-compiled template instantiations and runtime API.
- `pylibraft` (optional) Python library
- `raft-dask` (optional) Python library for deployment of multi-node multi-GPU algorithms that use the RAFT `raft::comms` abstraction layer in Dask clusters.
- `raft-ann-bench` (optional) Benchmarking tool for easily producing benchmarks that compare RAFT's vector search algorithms against other state-of-the-art implementations.
- `raft-ann-bench-cpu` (optional) Reproducible benchmarking tool similar to above, but doesn't require CUDA to be installed on the machine. Can be used to test in environments with competitive CPUs.

Use the following command, depending on your CUDA version, to install all of the RAFT packages with conda (replace `rapidsai` with `rapidsai-nightly` to install more up-to-date but less stable nightly packages). `mamba` is preferred over the `conda` command.
```bash
# for CUDA 11.8
mamba install -c rapidsai -c conda-forge -c nvidia raft-dask pylibraft cuda-version=11.8
```

Use the following command to install all of the RAFT packages with conda (replace `rapidsai` with `rapidsai-nightly` to install more up-to-date but less stable nightly packages). `mamba` is preferred over the `conda` command.
```bash
mamba install -c rapidsai -c conda-forge -c nvidia raft-dask pylibraft
# for CUDA 12.0
mamba install -c rapidsai -c conda-forge -c nvidia raft-dask pylibraft cuda-version=12.0
```

You can also install the conda packages individually using the `mamba` command above.
Note that the above commands will also install `libraft-headers` and `libraft`.

You can also install the conda packages individually using the `mamba` command above. For example, if you'd like to install RAFT's headers and pre-compiled shared library to use in your project:
```bash
# for CUDA 12.0
mamba install -c rapidsai -c conda-forge -c nvidia libraft libraft-headers cuda-version=12.0
```

After installing RAFT, `find_package(raft COMPONENTS compiled distributed)` can be used in your CUDA/C++ cmake build to compile and/or link against needed dependencies in your raft target. `COMPONENTS` are optional and will depend on the packages installed.
If installing the C++ APIs please see [using libraft](https://docs.rapids.ai/api/raft/nightly/using_libraft/) for more information on using the pre-compiled shared library. You can also refer to the [example C++ template project](https://github.com/rapidsai/raft/tree/branch-23.12/cpp/template) for a ready-to-go CMake configuration that you can drop into your project and build against installed RAFT development artifacts above.

### Pip
### Installing Python through Pip

pylibraft and raft-dask both have experimental packages that can be [installed through pip](https://rapids.ai/pip.html#install):
`pylibraft` and `raft-dask` both have experimental packages that can be [installed through pip](https://rapids.ai/pip.html#install):
```bash
pip install pylibraft-cu11 --extra-index-url=https://pypi.nvidia.com
pip install raft-dask-cu11 --extra-index-url=https://pypi.nvidia.com
```

### CMake & CPM

RAFT uses the [RAPIDS-CMake](https://github.com/rapidsai/rapids-cmake) library, which makes it easy to include in downstream cmake projects. RAPIDS-CMake provides a convenience layer around CPM. Please refer to [these instructions](https://github.com/rapidsai/rapids-cmake#installation) to install and use rapids-cmake in your project.

#### Example Template Project
These packages statically build RAFT's pre-compiled instantiations and so the C++ headers and pre-compiled shared library won't be readily available to use in your code.

You can find an [example RAFT](cpp/template/README.md) project template in the `cpp/template` directory, which demonstrates how to build a new application with RAFT or incorporate RAFT into an existing cmake project.
The [build instructions](https://docs.rapids.ai/api/raft/nightly/build/) contain more details on building RAFT from source and including it in downstream projects. You can also find a more comprehensive version of the above CPM code snippet the [Building RAFT C++ and Python from source](https://docs.rapids.ai/api/raft/nightly/build/#building-c-and-python-from-source) section of the build instructions.

#### CMake Targets

Additional CMake targets can be made available by adding components in the table below to the `RAFT_COMPONENTS` list above, separated by spaces. The `raft::raft` target will always be available. RAFT headers require, at a minimum, the CUDA toolkit libraries and RMM dependencies.

| Component | Target | Description | Base Dependencies |
|-------------|---------------------|----------------------------------------------------------|----------------------------------------|
| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit, RMM, NVTX, CCCL, CUTLASS |
| compiled | `raft::compiled` | Pre-compiled template instantiations and runtime library | raft::raft |
| distributed | `raft::distributed` | Dependencies for `raft::comms` APIs | raft::raft, UCX, NCCL |

### Source

The easiest way to build RAFT from source is to use the `build.sh` script at the root of the repository:
1. Create an environment with the needed dependencies:
```
mamba env create --name raft_dev_env -f conda/environments/all_cuda-118_arch-x86_64.yaml
mamba activate raft_dev_env
```
```
./build.sh raft-dask pylibraft libraft tests bench --compile-lib
```
You can find an example [RAFT project template](cpp/template/README.md) in the `cpp/template` directory, which demonstrates how to build a new application with RAFT or incorporate RAFT into an existing CMake project.

The [build](docs/source/build.md) instructions contain more details on building RAFT from source and including it in downstream projects. You can also find a more comprehensive version of the above CPM code snippet the [Building RAFT C++ from source](docs/source/build.md#building-raft-c-from-source-in-cmake) section of the build instructions.

## Folder Structure and Contents

The folder structure mirrors other RAPIDS repos, with the following folders:

- `bench/ann`: Python scripts for running ANN benchmarks
- `ci`: Scripts for running CI in PRs
- `conda`: Conda recipes and development conda environments
- `cpp`: Source code for C++ libraries.
- `bench`: Benchmarks source code
- `cmake`: CMake modules and templates
- `doxygen`: Doxygen configuration
- `include`: The C++ API headers are fully-contained here (deprecated directories are excluded from the listing below)
- `cluster`: Basic clustering primitives and algorithms.
- `comms`: A multi-node multi-GPU communications abstraction layer for NCCL+UCX and MPI+NCCL, which can be deployed in Dask clusters using the `raft-dask` Python package.
- `core`: Core API headers which require minimal dependencies aside from RMM and Cudatoolkit. These are safe to expose on public APIs and do not require `nvcc` to build. This is the same for any headers in RAFT which have the suffix `*_types.hpp`.
- `distance`: Distance primitives
- `linalg`: Dense linear algebra
- `matrix`: Dense matrix operations
- `neighbors`: Nearest neighbors and knn graph construction
- `random`: Random number generation, sampling, and data generation primitives
- `solver`: Iterative and combinatorial solvers for optimization and approximation
- `sparse`: Sparse matrix operations
- `convert`: Sparse conversion functions
- `distance`: Sparse distance computations
- `linalg`: Sparse linear algebra
- `neighbors`: Sparse nearest neighbors and knn graph construction
- `op`: Various sparse operations such as slicing and filtering (Note: this will soon be renamed to `sparse/matrix`)
- `solver`: Sparse solvers for optimization and approximation
- `stats`: Moments, summary statistics, model performance measures
- `util`: Various reusable tools and utilities for accelerated algorithm development
- `internal`: A private header-only component that hosts the code shared between benchmarks and tests.
- `scripts`: Helpful scripts for development
- `src`: Compiled APIs and template instantiations for the shared libraries
- `template`: A skeleton template containing the bare-bones file structure and cmake configuration for writing applications with RAFT.
- `test`: Googletests source code
- `docs`: Source code and scripts for building library documentation (Uses breath, doxygen, & pydocs)
- `notebooks`: IPython notebooks with usage examples and tutorials
- `python`: Source code for Python libraries.
- `pylibraft`: Python build and source code for pylibraft library
- `raft-dask`: Python build and source code for raft-dask library
- `thirdparty`: Third-party licenses

## Contributing

Expand Down
8 changes: 4 additions & 4 deletions cpp/bench/prims/distance/masked_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ struct Params {
AdjacencyPattern pattern;
}; // struct Params

__global__ void init_adj(AdjacencyPattern pattern,
int n,
raft::device_matrix_view<bool, int, raft::layout_c_contiguous> adj,
raft::device_vector_view<int, int, raft::layout_c_contiguous> group_idxs)
RAFT_KERNEL init_adj(AdjacencyPattern pattern,
int n,
raft::device_matrix_view<bool, int, raft::layout_c_contiguous> adj,
raft::device_vector_view<int, int, raft::layout_c_contiguous> group_idxs)
{
int m = adj.extent(0);
int num_groups = adj.extent(1);
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/sparse/convert_csr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct bench_param {
};

template <typename index_t>
__global__ void init_adj_kernel(bool* adj, index_t num_rows, index_t num_cols, index_t divisor)
RAFT_KERNEL init_adj_kernel(bool* adj, index_t num_rows, index_t num_cols, index_t divisor)
{
index_t r = blockDim.y * blockIdx.y + threadIdx.y;
index_t c = blockDim.x * blockIdx.x + threadIdx.x;
Expand Down
16 changes: 7 additions & 9 deletions cpp/include/raft/cluster/detail/agglomerative.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,7 @@ void build_dendrogram_host(raft::resources const& handle,
}

template <typename value_idx>
__global__ void write_levels_kernel(const value_idx* children,
value_idx* parents,
value_idx n_vertices)
RAFT_KERNEL write_levels_kernel(const value_idx* children, value_idx* parents, value_idx n_vertices)
{
value_idx tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < n_vertices) {
Expand All @@ -179,12 +177,12 @@ __global__ void write_levels_kernel(const value_idx* children,
* @param labels
*/
template <typename value_idx>
__global__ void inherit_labels(const value_idx* children,
const value_idx* levels,
std::size_t n_leaves,
value_idx* labels,
int cut_level,
value_idx n_vertices)
RAFT_KERNEL inherit_labels(const value_idx* children,
const value_idx* levels,
std::size_t n_leaves,
value_idx* labels,
int cut_level,
value_idx n_vertices)
{
value_idx tid = blockDim.x * blockIdx.x + threadIdx.x;

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/connectivities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct distance_graph_impl<raft::cluster::LinkageDistance::KNN_GRAPH, value_idx,
};

template <typename value_idx>
__global__ void fill_indices2(value_idx* indices, size_t m, size_t nnz)
RAFT_KERNEL fill_indices2(value_idx* indices, size_t m, size_t nnz)
{
value_idx tid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (tid >= nnz) return;
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ template <uint32_t BlockDimY,
typename LabelT,
typename CounterT,
typename MappingOpT>
__global__ void __launch_bounds__((WarpSize * BlockDimY))
__launch_bounds__((WarpSize * BlockDimY)) RAFT_KERNEL
adjust_centers_kernel(MathT* centers, // [n_clusters, dim]
IdxT n_clusters,
IdxT dim,
Expand Down
46 changes: 23 additions & 23 deletions cpp/include/raft/cluster/detail/kmeans_deprecated.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ constexpr unsigned int BSIZE_DIV_WSIZE = (BLOCK_SIZE / WARP_SIZE);
* initialized to zero.
*/
template <typename index_type_t, typename value_type_t>
static __global__ void computeDistances(index_type_t n,
index_type_t d,
index_type_t k,
const value_type_t* __restrict__ obs,
const value_type_t* __restrict__ centroids,
value_type_t* __restrict__ dists)
RAFT_KERNEL computeDistances(index_type_t n,
index_type_t d,
index_type_t k,
const value_type_t* __restrict__ obs,
const value_type_t* __restrict__ centroids,
value_type_t* __restrict__ dists)
{
// Loop index
index_type_t i;
Expand Down Expand Up @@ -173,11 +173,11 @@ static __global__ void computeDistances(index_type_t n,
* cluster. Entries must be initialized to zero.
*/
template <typename index_type_t, typename value_type_t>
static __global__ void minDistances(index_type_t n,
index_type_t k,
value_type_t* __restrict__ dists,
index_type_t* __restrict__ codes,
index_type_t* __restrict__ clusterSizes)
RAFT_KERNEL minDistances(index_type_t n,
index_type_t k,
value_type_t* __restrict__ dists,
index_type_t* __restrict__ codes,
index_type_t* __restrict__ clusterSizes)
{
// Loop index
index_type_t i, j;
Expand Down Expand Up @@ -233,11 +233,11 @@ static __global__ void minDistances(index_type_t n,
* @param code_new Index associated with new centroid.
*/
template <typename index_type_t, typename value_type_t>
static __global__ void minDistances2(index_type_t n,
value_type_t* __restrict__ dists_old,
const value_type_t* __restrict__ dists_new,
index_type_t* __restrict__ codes_old,
index_type_t code_new)
RAFT_KERNEL minDistances2(index_type_t n,
value_type_t* __restrict__ dists_old,
const value_type_t* __restrict__ dists_new,
index_type_t* __restrict__ codes_old,
index_type_t code_new)
{
// Loop index
index_type_t i = threadIdx.x + blockIdx.x * blockDim.x;
Expand Down Expand Up @@ -275,9 +275,9 @@ static __global__ void minDistances2(index_type_t n,
* cluster. Entries must be initialized to zero.
*/
template <typename index_type_t>
static __global__ void computeClusterSizes(index_type_t n,
const index_type_t* __restrict__ codes,
index_type_t* __restrict__ clusterSizes)
RAFT_KERNEL computeClusterSizes(index_type_t n,
const index_type_t* __restrict__ codes,
index_type_t* __restrict__ clusterSizes)
{
index_type_t i = threadIdx.x + blockIdx.x * blockDim.x;
while (i < n) {
Expand Down Expand Up @@ -308,10 +308,10 @@ static __global__ void computeClusterSizes(index_type_t n,
* column is the mean position of a cluster).
*/
template <typename index_type_t, typename value_type_t>
static __global__ void divideCentroids(index_type_t d,
index_type_t k,
const index_type_t* __restrict__ clusterSizes,
value_type_t* __restrict__ centroids)
RAFT_KERNEL divideCentroids(index_type_t d,
index_type_t k,
const index_type_t* __restrict__ clusterSizes,
value_type_t* __restrict__ centroids)
{
// Global indices
index_type_t gidx, gidy;
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/common/detail/scatter.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,7 +22,7 @@
namespace raft::detail {

template <typename DataT, int VecLen, typename Lambda, typename IdxT>
__global__ void scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op)
RAFT_KERNEL scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op)
{
typedef TxN_t<DataT, VecLen> DataVec;
typedef TxN_t<IdxT, VecLen> IdxVec;
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/core/detail/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ __device__ auto increment_indices(IdxType* indices,
* parameters.
*/
template <typename DstType, typename SrcType>
__global__ mdspan_copyable_with_kernel_t<DstType, SrcType> mdspan_copy_kernel(DstType dst,
SrcType src)

RAFT_KERNEL mdspan_copy_kernel(DstType dst, SrcType src)
{
using config = mdspan_copyable<true, DstType, SrcType>;

Expand Down
32 changes: 32 additions & 0 deletions cpp/include/raft/core/detail/macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,38 @@
// as a weak symbol rather than a global."
#define RAFT_WEAK_FUNCTION __attribute__((weak))

// The RAFT_HIDDEN_FUNCTION specificies that the function will be hidden
// and therefore not callable by consumers of raft when compiled as
// a shared library.
//
// Hidden visibility also ensures that the linker doesn't de-duplicate the
// symbol across multiple `.so`. This allows multiple libraries to embed raft
// without issue
#define RAFT_HIDDEN_FUNCTION __attribute__((visibility("hidden")))

// The RAFT_KERNEL specificies that a kernel has hidden visibility
//
// Raft needs to ensure that the visibility of its __global__ function
// templates have hidden visibility ( default is weak visibility).
//
// When kernls have weak visibility it means that if two dynamic libraries
// both contain identical instantiations of a RAFT template, then the linker
// will discard one of the two instantiations and use only one of them.
//
// Do to unique requirements of how the CUDA works this de-deduplication
// can lead to the wrong kernels being called ( SM version being wrong ),
// silently no kernel being called at all, or cuda runtime errors being
// thrown.
//
// https://github.com/rapidsai/raft/issues/1722
#if defined(__CUDACC_RDC__)
#define RAFT_KERNEL RAFT_HIDDEN_FUNCTION __global__ void
#elif defined(_RAFT_HAS_CUDA)
#define RAFT_KERNEL static __global__ void
#else
#define RAFT_KERNEL static void
#endif

/**
* Some macro magic to remove optional parentheses of a macro argument.
* See https://stackoverflow.com/a/62984543
Expand Down
Loading

0 comments on commit 15539ea

Please sign in to comment.