Skip to content

Commit

Permalink
rnndescent support (#120)
Browse files Browse the repository at this point in the history
* initial rnndescent support

* nndescent tests

* Pass args to knn/build/query

* doc clean up: add HNSW reference

* Reference for nn_method = "nndescent"

* document nn_args for nn_method = nndescent

* fix missing function

* clean up hnsw metric check

* mention other nn_methods in metric doc

* regenerate Rd files
  • Loading branch information
jlmelville authored Mar 22, 2024
1 parent 1e48e0f commit 772634a
Show file tree
Hide file tree
Showing 14 changed files with 1,610 additions and 112 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Suggests:
knitr,
RcppHNSW,
rmarkdown,
rnndescent,
RSpectra,
testthat
LinkingTo:
Expand Down
15 changes: 14 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,20 @@ for details on these parameters. Although typically faster than Annoy (for a
given accuracy), be aware that the only supported `metric` values are
`"euclidean"`, `"cosine"` and `"correlation"`. Finally, RcppHNSW is only a
suggested package, not a requirement, so you need to install it yourself (e.g.
via `install.packages("RcppHNSW")`).
via `install.packages("RcppHNSW")`). Also see the
[article on HNSW in uwot](https://jlmelville.github.io/uwot/articles/hnsw-umap.html)
in the documentation.
* The nearest neighbor descent approximate nearest neighbor search algorithm is
now supported via the
[rnndescent](https://cran.r-project.org/package=rnndescent) package. Set
`nn_method = "nndescent"` to use it. The behavior of the method can be
controlled by the new `nn_args` parameter. There are many supported metrics and
possible parameters that can be set in `nn_args`, so please see the
[article on nearest neighbor descent in uwot](https://jlmelville.github.io/uwot/articles/rnndescent-umap.html)
in the documentation, and also the rnndescent package's
[documentation](https://jlmelville.github.io/rnndescent/index.html) for details.
`rnndescent` is only a suggested package, not a requirement, so you need to
install it yourself (e.g. via `install.packages("rnndescent")`).

## Bug fixes and minor improvements

Expand Down
39 changes: 12 additions & 27 deletions R/neighbors.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,21 @@ find_nn <- function(X, k, include_self = TRUE, method = "fnn",
)
},
"hnsw" = {
nn_args_names <- names(nn_args)

if ("M" %in% nn_args_names) {
M <- nn_args$M
}
else {
M <- 16
}

if ("ef_construction" %in% nn_args_names) {
ef_construction <- nn_args$ef_construction
}
else {
ef_construction <- 200
}

if ("ef" %in% nn_args_names) {
ef <- nn_args$ef
}
else {
ef <- 10
}

res <- hnsw_nn(
nn_args$X <- X
nn_args$k <- k
nn_args$metric <- metric
nn_args$n_threads <- n_threads
nn_args$verbose <- verbose
nn_args$ret_index <- ret_index

res <- do.call(hnsw_nn, nn_args)
},
"nndescent" = {
res <- nndescent_nn(
X,
k = k,
metric = metric,
M = M,
ef_construction = ef_construction,
ef = ef,
nn_args = nn_args,
n_threads = n_threads,
ret_index = ret_index,
verbose = verbose
Expand Down
5 changes: 5 additions & 0 deletions R/nn_hnsw.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,8 @@ hnsw_load <- function(name, ndim, filename) {
)
methods::new(class_name, ndim, filename)
}

is_ok_hnsw_metric <- function(metric) {
hnsw_metrics <- c("euclidean", "cosine", "correlation")
metric %in% hnsw_metrics
}
157 changes: 157 additions & 0 deletions R/nn_nndescent.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
nndescent_nn <- function(X,
k = 10,
metric = "euclidean",
nn_args = list(),
n_threads = NULL,
ret_index = FALSE,
verbose = FALSE) {
if (is.null(n_threads)) {
n_threads <- default_num_threads()
}

if (!ret_index) {
nn_knn_args <- get_nndescent_knn_args(nn_args)
nn_knn_args <- lmerge(
nn_knn_args,
list(
data = X,
k = k,
metric = metric,
n_threads = n_threads,
verbose = verbose
)
)
return(do.call(rnndescent::rnnd_knn, nn_knn_args))
}

ann <- nndescent_build(
X,
k,
metric,
nn_args = nn_args,
n_threads = n_threads,
verbose = verbose
)
res <-
list(
idx = ann$ann$graph$idx,
dist = ann$ann$graph$dist,
index = ann
)
res$index$ann$ann$graph <- NULL
res
}

nndescent_build <- function(X,
k,
metric,
nn_args = list(),
n_threads = NULL,
verbose = FALSE) {
nn_build_args <- get_nndescent_build_args(nn_args)
nn_build_args <- lmerge(
nn_build_args,
list(
data = X,
k = k,
metric = metric,
n_threads = n_threads,
verbose = verbose
)
)

index <- do.call(rnndescent::rnnd_build, nn_build_args)
list(
ann = index,
type = "nndescentv1",
metric = metric,
ndim = ncol(X)
)
}


nndescent_search <- function(X,
k,
ann,
nn_args = list(),
n_threads = NULL,
verbose = FALSE) {
nn_query_args <- get_nndescent_query_args(nn_args)
nn_query_args <- lmerge(
nn_query_args,
list(
index = ann$ann,
query = X,
k = k,
n_threads = n_threads,
verbose = verbose
)
)

do.call(rnndescent::rnnd_query, nn_query_args)
}

get_nndescent_knn_args <- function(nn_args) {
nn_knn_args <- list()
nnd_knn_names <- c(
"use_alt_metric",
"init",
"n_trees",
"leaf_size",
"max_tree_depth",
"margin",
"n_iters",
"delta",
"max_candidates",
"weight_by_degree",
"low_memory"
)
for (name in nnd_knn_names) {
if (name %in% names(nn_args)) {
nn_knn_args[[name]] <- nn_args[[name]]
}
}
nn_knn_args
}

get_nndescent_build_args <- function(nn_args) {
# prune_reverse should probably always be TRUE
nn_build_args <- list(prune_reverse = TRUE)
nnd_build_names <- c(
"use_alt_metric",
"init",
"n_trees",
"leaf_size",
"max_tree_depth",
"margin",
"n_iters",
"delta",
"max_candidates",
"weight_by_degree",
"low_memory",
"n_search_trees",
"pruning_degree_multiplier",
"diversify_prob",
"prune_reverse"
)
for (name in nnd_build_names) {
if (name %in% names(nn_args)) {
nn_build_args[[name]] <- nn_args[[name]]
}
}
nn_build_args
}

get_nndescent_query_args <- function(nn_args) {
nn_query_args <- list()
nnd_query_names <- c(
"epsilon",
"max_search_fraction"
)
for (name in nnd_query_names) {
if (name %in% names(nn_args)) {
nn_query_args[[name]] <- nn_args[[name]]
}
}
nn_query_args
}
18 changes: 18 additions & 0 deletions R/transform.R
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,13 @@ umap_transform <- function(X = NULL, model = NULL,
)
}

if (is.character(model$nn_method) &&
model$nn_method == "nndescent" && !is_installed("rnndescent")) {
stop(
"This model requires the rnndescent package to be installed."
)
}

if (is.null(n_epochs)) {
n_epochs <- model$n_epochs
if (is.null(n_epochs)) {
Expand Down Expand Up @@ -562,6 +569,17 @@ umap_transform <- function(X = NULL, model = NULL,
nn$dist <- sqrt(nn$dist)
}
}
else if (startsWith(ann$type, "nndescent")) {
nn <-
nndescent_search(
X,
k = n_neighbors,
ann = ann,
nn_args = model$nn_args,
n_threads = n_threads,
verbose = verbose
)
}
else {
stop("Unknown nn method: ", ann$type)
}
Expand Down
Loading

0 comments on commit 772634a

Please sign in to comment.