Skip to content

Commit

Permalink
CHANGES rewrote find_nearest, removed find_nearest_idx and _array
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Aug 15, 2024
1 parent 2362ad3 commit 8ac0b5c
Showing 1 changed file with 59 additions and 59 deletions.
118 changes: 59 additions & 59 deletions bnpm/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,71 +287,71 @@ def make_batches(
yield iterable[order[idx]]


def find_nearest(
    array: Union[np.ndarray, torch.Tensor],
    values: Union[float, int, np.ndarray, torch.Tensor],
    presorted: bool = False,
    return_idx: bool = True,
    return_values: bool = False,
    return_diff: bool = False,
) -> tuple:
    """
    Find, for each query in ``values``, the nearest element of a 1D ``array``.
    Ties between two equally distant neighbors resolve to the larger
    neighbor.
    RH 2024

    Args:
        array (Union[np.ndarray, torch.Tensor]):
            Array to search for nearest values. Must be 1D.
        values (Union[float, int, np.ndarray, torch.Tensor]):
            Value or values to search for. If an array, then must be 1D. A
            scalar is treated as a length-1 array.
        presorted (bool):
            Whether ``array`` is already sorted in ascending order. If False,
            the function sorts a copy of the array before searching.
        return_idx (bool):
            Whether to return the index of the nearest value.
        return_values (bool):
            Whether to return the nearest value.
        return_diff (bool):
            Whether to return the absolute difference between each query and
            its nearest value. Does not require ``return_values``.

    Returns:
        outputs (tuple):
            Tuple holding only the requested outputs, always in the order
            (idx, nearest_values, diff). Elements are numpy arrays if
            ``array`` is a numpy array, otherwise torch tensors. Each has one
            entry per query (shape ``(1,)`` for a scalar query).
            idx (np.ndarray or torch.Tensor):
                Index into the ``array`` that was passed in (not into the
                internally sorted copy) of the nearest value.
            nearest_values (np.ndarray or torch.Tensor):
                Nearest values.
            diff (np.ndarray or torch.Tensor):
                Absolute difference between the value and the nearest value.

    Raises:
        ValueError: If ``array`` is neither a numpy array nor a torch tensor.
        IndexError: If ``array`` is empty while ``values`` is not.
    """
    if isinstance(array, np.ndarray):
        kind = 'numpy'
    elif isinstance(array, torch.Tensor):
        kind = 'torch'
    else:
        raise ValueError('array must be a numpy array or torch tensor')

    array = torch.as_tensor(array)
    ## Keep the queries on the same device as the array so that scalar /
    ##  numpy queries also work against non-CPU tensors.
    values = torch.as_tensor(values, device=array.device)

    assert array.ndim == 1, 'array must be 1D'
    if values.ndim == 0:
        values = values.unsqueeze(0)
    assert values.ndim == 1, 'values must be 1D'

    n_elements = array.shape[0]
    if n_elements == 0 and values.numel() > 0:
        raise IndexError('array is empty; there is no nearest value to find')

    if presorted:
        array_sorted, order = array, None
    else:
        ## Stable sort: an already-sorted input yields the identity
        ##  permutation, even when it contains duplicates.
        array_sorted, order = torch.sort(array, stable=True)

    ## Insertion point: first position whose element is >= the query. The
    ##  nearest element is either at that position or immediately before it.
    idx_right = torch.searchsorted(array_sorted, values, side="left")
    ## Clamp both candidates into the valid range. torch.where evaluates both
    ##  branches eagerly, so unclamped indices (== n_elements for queries
    ##  above the maximum) would raise instead of being masked out.
    idx_left = (idx_right - 1).clamp(min=0)
    idx_right = idx_right.clamp(max=n_elements - 1)

    ## Strict '<' so that ties resolve to the right-hand neighbor. At either
    ##  boundary the two clamped candidates coincide and the choice is moot.
    use_left = torch.abs(values - array_sorted[idx_left]) < torch.abs(values - array_sorted[idx_right])
    idx_sorted = torch.where(use_left, idx_left, idx_right)

    outputs = []
    if return_idx:
        ## Map positions in the sorted copy back to positions in the input.
        outputs.append(idx_sorted if order is None else order[idx_sorted])
    if return_values or return_diff:
        vals_nearest = array_sorted[idx_sorted]
        if return_values:
            outputs.append(vals_nearest)
        if return_diff:
            outputs.append(torch.abs(vals_nearest - values))

    if kind == 'numpy':
        return tuple(v.detach().cpu().numpy() for v in outputs)
    return tuple(outputs)


def pad_with_singleton_dims(array, n_dims_pre=0, n_dims_post=0):
Expand Down

0 comments on commit 8ac0b5c

Please sign in to comment.