Skip to content

Commit

Permalink
Compile _optimize_layout_euclidean_single_epoch once
Browse files Browse the repository at this point in the history
  • Loading branch information
kmkolasinski committed Oct 11, 2024
1 parent c72ac2f commit f8895b8
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions umap/layouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,22 @@ def _optimize_layout_euclidean_densmap_epoch_init(
re_sum[i] = np.log(epsilon + (re_sum[i] / phi_sum[i]))


_nb_optimize_layout_euclidean_single_epoch = numba.njit(
_optimize_layout_euclidean_single_epoch, fastmath=True, parallel=False
)

_nb_optimize_layout_euclidean_single_epoch_parallel = numba.njit(
_optimize_layout_euclidean_single_epoch, fastmath=True, parallel=True
)


def _get_optimize_layout_euclidean_single_epoch_fn(parallel: bool = False):
if parallel:
return _nb_optimize_layout_euclidean_single_epoch_parallel
else:
return _nb_optimize_layout_euclidean_single_epoch


def optimize_layout_euclidean(
head_embedding,
tail_embedding,
Expand Down Expand Up @@ -308,9 +324,10 @@ def optimize_layout_euclidean(
epoch_of_next_negative_sample = epochs_per_negative_sample.copy()
epoch_of_next_sample = epochs_per_sample.copy()

optimize_fn = numba.njit(
_optimize_layout_euclidean_single_epoch, fastmath=True, parallel=parallel
)
# Fix for calling UMAP many times for small datasets, otherwise we spend here
# a lot of time in compilation step (first call to numba function)
optimize_fn = _get_optimize_layout_euclidean_single_epoch_fn(parallel)

if densmap_kwds is None:
densmap_kwds = {}
if tqdm_kwds is None:
Expand Down Expand Up @@ -352,7 +369,6 @@ def optimize_layout_euclidean(
) + head_embedding[:, 0].astype(np.float64).view(np.int64).reshape(-1, 1)

for n in tqdm(range(n_epochs), **tqdm_kwds):

densmap_flag = (
densmap
and (densmap_kwds["lambda"] > 0)
Expand Down

0 comments on commit f8895b8

Please sign in to comment.