diff --git a/umap/layouts.py b/umap/layouts.py index 43b633c4..1bf0feee 100644 --- a/umap/layouts.py +++ b/umap/layouts.py @@ -219,6 +219,22 @@ def _optimize_layout_euclidean_densmap_epoch_init( re_sum[i] = np.log(epsilon + (re_sum[i] / phi_sum[i])) +_nb_optimize_layout_euclidean_single_epoch = numba.njit( + _optimize_layout_euclidean_single_epoch, fastmath=True, parallel=False +) + +_nb_optimize_layout_euclidean_single_epoch_parallel = numba.njit( + _optimize_layout_euclidean_single_epoch, fastmath=True, parallel=True +) + + +def _get_optimize_layout_euclidean_single_epoch_fn(parallel: bool = False): + if parallel: + return _nb_optimize_layout_euclidean_single_epoch_parallel + else: + return _nb_optimize_layout_euclidean_single_epoch + + def optimize_layout_euclidean( head_embedding, tail_embedding, @@ -308,9 +324,10 @@ def optimize_layout_euclidean( epoch_of_next_negative_sample = epochs_per_negative_sample.copy() epoch_of_next_sample = epochs_per_sample.copy() - optimize_fn = numba.njit( - _optimize_layout_euclidean_single_epoch, fastmath=True, parallel=parallel - ) + # Fix for calling UMAP many times for small datasets, otherwise we spend here + # a lot of time in compilation step (first call to numba function) + optimize_fn = _get_optimize_layout_euclidean_single_epoch_fn(parallel) + if densmap_kwds is None: densmap_kwds = {} if tqdm_kwds is None: @@ -352,7 +369,6 @@ def optimize_layout_euclidean( ) + head_embedding[:, 0].astype(np.float64).view(np.int64).reshape(-1, 1) for n in tqdm(range(n_epochs), **tqdm_kwds): - densmap_flag = ( densmap and (densmap_kwds["lambda"] > 0)