Skip to content

Commit

Permalink
Merge pull request #791 from ThomasNickerson/master
Browse files Browse the repository at this point in the history
Fix UMAP.update for large data sets
  • Loading branch information
lmcinnes authored Oct 27, 2021
2 parents f6172cf + 9342d97 commit 9f6d1a6
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 3 deletions.
20 changes: 20 additions & 0 deletions umap/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,33 @@ def iris_model(iris):
return UMAP(n_neighbors=10, min_dist=0.01, random_state=42).fit(iris.data)


@pytest.fixture(scope="session")
def iris_model_large(iris):
    """Session-wide UMAP model fitted on the full iris data set.

    The estimator is built with ``force_approximation_algorithm=True``.
    NOTE(review): presumably this makes UMAP take the approximate
    nearest-neighbour path it would use on large inputs, even though iris
    is small -- confirm against the UMAP parameter documentation.
    """
    estimator = UMAP(
        n_neighbors=10,
        min_dist=0.01,
        random_state=42,
        force_approximation_algorithm=True,
    )
    return estimator.fit(iris.data)


@pytest.fixture(scope="session")
def iris_subset_model(iris, iris_selection):
    """Session-wide UMAP model fitted only on the selected iris rows.

    ``iris_selection`` is used to index ``iris.data``; the rows it picks
    are the only ones the returned model is fitted on.
    """
    selected_rows = iris.data[iris_selection]
    estimator = UMAP(n_neighbors=10, min_dist=0.01, random_state=42)
    return estimator.fit(selected_rows)


@pytest.fixture(scope="session")
def iris_subset_model_large(iris, iris_selection):
    """Session-wide UMAP model fitted on the selected iris rows.

    The estimator is built with ``force_approximation_algorithm=True``.
    NOTE(review): presumably this forces the approximate nearest-neighbour
    path used for large inputs -- confirm against the UMAP documentation.
    """
    selected_rows = iris.data[iris_selection]
    estimator = UMAP(
        n_neighbors=10,
        min_dist=0.01,
        random_state=42,
        force_approximation_algorithm=True,
    )
    return estimator.fit(selected_rows)


@pytest.fixture(scope="session")
def supervised_iris_model(iris):
return UMAP(n_neighbors=10, min_dist=0.01, n_epochs=200, random_state=42).fit(
Expand Down
4 changes: 3 additions & 1 deletion umap/tests/test_umap_on_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def test_umap_trustworthiness_on_sphere_iris(
r * np.cos(embedding[:, 0]),
]
).T
trust = trustworthiness(iris.data, projected_embedding, n_neighbors=10, metric="cosine")
trust = trustworthiness(
iris.data, projected_embedding, n_neighbors=10, metric="cosine"
)
assert (
trust >= 0.80
), "Insufficiently trustworthy spherical embedding for iris dataset: {}".format(
Expand Down
23 changes: 23 additions & 0 deletions umap/tests/test_umap_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,29 @@ def test_umap_update(iris, iris_subset_model, iris_selection, iris_model):
assert error < 1.0


def test_umap_update_large(
    iris, iris_subset_model_large, iris_selection, iris_model_large
):
    """Updating a subset-fitted model should approximate a full fit.

    The subset model is updated with the rows excluded by
    ``iris_selection``; its graph is then compared with the graph of the
    model fitted on all rows, after permuting that graph so the selected
    rows/columns come first and the remaining ones follow.
    """
    # NOTE(review): ``update`` is invoked on the fixture object itself, so
    # the fixture instance is modified in place -- confirm no other test
    # depends on its pre-update state.
    held_out = iris.data[~iris_selection]
    iris_subset_model_large.update(held_out)

    # Reorder rows, then columns, of the full-fit graph into
    # "selected first, held-out second" order.
    full_graph = iris_model_large.graph_
    row_blocks = [full_graph[iris_selection], full_graph[~iris_selection]]
    rows_reordered = scipy.sparse.vstack(row_blocks)
    column_blocks = [
        rows_reordered[:, iris_selection],
        rows_reordered[:, ~iris_selection],
    ]
    expected_graph = scipy.sparse.hstack(column_blocks)

    # Total absolute difference over the stored entries of the graph delta.
    graph_delta = iris_subset_model_large.graph_ - expected_graph
    error = np.sum(np.abs(graph_delta.data))

    assert error < 1.5


# -----------------
# UMAP Graph output
# -----------------
Expand Down
6 changes: 4 additions & 2 deletions umap/umap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def nearest_neighbors(
low_memory=low_memory,
n_jobs=n_jobs,
verbose=verbose,
compressed=False,
)
knn_indices, knn_dists = knn_search_index.neighbor_graph

Expand Down Expand Up @@ -1527,8 +1528,8 @@ class UMAP(BaseEstimator):
target_weight: float (optional, default 0.5)
weighting factor between data topology and target topology. A value of
0.0 weights predominantly on data, a value of 1.0 places a strong emphasis on
target. The default of 0.5 balances the weighting equally between data and
0.0 weights predominantly on data, a value of 1.0 places a strong emphasis on
target. The default of 0.5 balances the weighting equally between data and
target.
transform_seed: int (optional, default 42)
Expand Down Expand Up @@ -3310,6 +3311,7 @@ def update(self, X):
)

else:
self._knn_search_index.prepare()
self._knn_search_index.update(X)
self._raw_data = self._knn_search_index._raw_data
(
Expand Down

0 comments on commit 9f6d1a6

Please sign in to comment.