Skip to content

Commit

Permalink
Support bit_hamming and bit_jaccard distances.
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcinnes committed Mar 28, 2024
1 parent 8bfc563 commit 4cc131e
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions umap/umap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"cosine": 2,
"hellinger": 1,
"jaccard": 1,
"bit_jaccard": 1,
"dice": 1,
}

Expand Down Expand Up @@ -2351,8 +2352,10 @@ def fit(self, X, y=None, force_all_finite=True):
- 'allow-nan': accepts only np.nan and pd.NA values in array.
Values cannot be infinite.
"""

X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite)
if self.metric in ("bit_hamming", "bit_jaccard"):
X = check_array(X, dtype=np.uint8, order="C", force_all_finite=force_all_finite)
else:
X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite)
self._raw_data = X

# Handle all the optional arguments, setting default
Expand Down Expand Up @@ -2926,7 +2929,10 @@ def transform(self, X, force_all_finite=True):
"Transform unavailable when model was fit with only a single data sample."
)
# If we just have the original input then short circuit things
X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite)
if self.metric in ("bit_hamming", "bit_jaccard"):
X = check_array(X, dtype=np.uint8, order="C", force_all_finite=force_all_finite)
else:
X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite)
x_hash = joblib.hash(X)
if x_hash == self._input_hash:
if self.transform_mode == "embedding":
Expand Down Expand Up @@ -3297,7 +3303,10 @@ def _output_dist_only(x, y, *kwds):
return inv_transformed_points

def update(self, X, force_all_finite=True):
X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite)
if self.metric in ("bit_hamming", "bit_jaccard"):
X = check_array(X, dtype=np.uint8, order="C", force_all_finite=force_all_finite)
else:
X = check_array(X, dtype=np.float32, accept_sparse="csr", order="C", force_all_finite=force_all_finite)
random_state = check_random_state(self.transform_seed)
rng_state = random_state.randint(INT32_MIN, INT32_MAX, 3).astype(np.int64)

Expand Down

0 comments on commit 4cc131e

Please sign in to comment.