From f8895b8ee0d10c96a1c4ecada1416de246f4867f Mon Sep 17 00:00:00 2001 From: Krzysztof Date: Fri, 11 Oct 2024 14:20:34 +0200 Subject: [PATCH] Compile _optimize_layout_euclidean_single_epoch once --- umap/layouts.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/umap/layouts.py b/umap/layouts.py index 43b633c4..1bf0feee 100644 --- a/umap/layouts.py +++ b/umap/layouts.py @@ -219,6 +219,22 @@ def _optimize_layout_euclidean_densmap_epoch_init( re_sum[i] = np.log(epsilon + (re_sum[i] / phi_sum[i])) +_nb_optimize_layout_euclidean_single_epoch = numba.njit( + _optimize_layout_euclidean_single_epoch, fastmath=True, parallel=False +) + +_nb_optimize_layout_euclidean_single_epoch_parallel = numba.njit( + _optimize_layout_euclidean_single_epoch, fastmath=True, parallel=True +) + + +def _get_optimize_layout_euclidean_single_epoch_fn(parallel: bool = False): + if parallel: + return _nb_optimize_layout_euclidean_single_epoch_parallel + else: + return _nb_optimize_layout_euclidean_single_epoch + + def optimize_layout_euclidean( head_embedding, tail_embedding, @@ -308,9 +324,10 @@ def optimize_layout_euclidean( epoch_of_next_negative_sample = epochs_per_negative_sample.copy() epoch_of_next_sample = epochs_per_sample.copy() - optimize_fn = numba.njit( - _optimize_layout_euclidean_single_epoch, fastmath=True, parallel=parallel - ) + # Fix for calling UMAP many times for small datasets, otherwise we spend here + # a lot of time in compilation step (first call to numba function) + optimize_fn = _get_optimize_layout_euclidean_single_epoch_fn(parallel) + if densmap_kwds is None: densmap_kwds = {} if tqdm_kwds is None: @@ -352,7 +369,6 @@ def optimize_layout_euclidean( ) + head_embedding[:, 0].astype(np.float64).view(np.int64).reshape(-1, 1) for n in tqdm(range(n_epochs), **tqdm_kwds): - densmap_flag = ( densmap and (densmap_kwds["lambda"] > 0)