From 8ac0b5cc9d435d62a077f0ed77df675044103afc Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Wed, 14 Aug 2024 22:00:36 -0400 Subject: [PATCH] CHANGES rewrote find_nearest, removed find_nearest_idx and _array --- bnpm/indexing.py | 118 +++++++++++++++++++++++------------------------ 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/bnpm/indexing.py b/bnpm/indexing.py index 1f525d0..d602e52 100644 --- a/bnpm/indexing.py +++ b/bnpm/indexing.py @@ -287,71 +287,71 @@ def make_batches( yield iterable[order[idx]] -@njit -def find_nearest_idx(array, value): - ''' - Finds the value and index of the nearest - value in an array. - RH 2021, 2024 - +def find_nearest( + array: Union[np.ndarray, torch.Tensor], + values: Union[float, int, np.ndarray, torch.Tensor], + presorted: bool = False, + return_idx: bool = True, + return_values: bool = False, + return_diff: bool = False, +) -> Union[np.ndarray, tuple]: + """ + Find the nearest value in a 1D array (or along a dimension) to a given + value. + RH 2024 + Args: - array (np.ndarray): - Array of values to search through. - value (scalar): - Value to search for. + array (Union[np.ndarray, torch.Tensor]): + Array to search for nearest values. Must be 1D. + values (Union[float, int, np.ndarray, torch.Tensor]): + Value or values to search for. If an array, then must be 1D. + presorted (bool): + Whether `array` is already sorted. If False, the function will sort + the array before searching for the nearest value. + return_idx (bool): + Whether to return the index of the nearest value. + return_values (bool): + Whether to return the nearest value. Returns: - array_idx (int): - Index of the nearest value in array. - ''' - idx = np.searchsorted(array, value, side="left") - if idx > 0 and (idx == len(array) or np.abs(value - array[idx-1]) < np.abs(value - array[idx])): - return idx-1 + outputs (Union[np.ndarray, tuple]): + idx (np.ndarray or torch.Tensor): + Index of the nearest value. Shape is the same as `values`. + nearest_values (np.ndarray or torch.Tensor): + Nearest values. Shape is the same as `values`. + diff (np.ndarray or torch.Tensor): + Difference between the value and the nearest value. Shape is the same as `values`. + """ + if isinstance(array, np.ndarray): + kind = 'numpy' + elif isinstance(array, torch.Tensor): + kind = 'torch' else: - return idx -@njit(parallel=True) -def find_nearest_array(array, values, max_diff=None): - ''' - Finds the values and indices of the nearest - values in an array. - RH 2021, 2024 + raise ValueError('array must be a numpy array or torch tensor') + array, values = (torch.as_tensor(v) for v in (array, values)) + + assert array.ndim == 1, 'array must be 1D' + if values.ndim == 0: + values = values.unsqueeze(0) + assert values.ndim == 1, 'values must be 1D' + + array_sorted = array if presorted else torch.sort(array)[0] + + idx_nearest = torch.searchsorted(array_sorted, values, side="left") + idx_nearest = torch.where( + (idx_nearest > 0) & \ + ((idx_nearest == len(array_sorted)) | (torch.abs(values - array_sorted[idx_nearest-1]) < torch.abs(values - array_sorted[idx_nearest]))), + idx_nearest-1, + idx_nearest, + ) + vals_nearest = array_sorted[idx_nearest] if return_values else None + diff_nearest = torch.abs(vals_nearest - values) if return_diff else None - Args: - array (np.ndarray): - Array of values to search through. - values (np.ndarray): - Values to search for. + out = tuple((v for t, v in zip([return_idx, return_values, return_diff], [idx_nearest, vals_nearest, diff_nearest]) if t)) - Returns: - array_idx (np.ndarray): - Indices of the nearest values in array. - array_val (np.ndarray): - Values of the nearest values in array. - diff (np.ndarray): - Differences between the values and the - nearest values in array. - ''' - assert array.ndim == 1, 'array must be 1-D' - assert values.ndim == 1, 'values must be 1-D' - - vals_nearest = np.zeros(values.shape if array.size > 0 else (0,), dtype=array.dtype) - idx_nearest = np.zeros(values.shape if array.size > 0 else (0,), dtype=np.int64) - diff_nearest = np.zeros(values.shape if array.size > 0 else (0,), dtype=array.dtype) - - if array.size > 0: - for ii in prange(len(values)): - idx_nearest[ii] = find_nearest_idx(array , values[ii]) - - vals_nearest = array[idx_nearest] - diff_nearest = np.abs(vals_nearest - values) - - if max_diff is not None: - bool_keep = diff_nearest <= max_diff - vals_nearest = vals_nearest[bool_keep] - idx_nearest = idx_nearest[bool_keep] - diff_nearest = diff_nearest[bool_keep] - - return vals_nearest, idx_nearest, diff_nearest + if kind == 'numpy': + out = tuple((v.numpy() for v in out)) + return out def pad_with_singleton_dims(array, n_dims_pre=0, n_dims_post=0):