Skip to content


Improve batched_unfold function and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Apr 18, 2024
1 parent c3be642 commit c1b49bd
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 38 deletions.
112 changes: 83 additions & 29 deletions bnpm/
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional, Union, Iterator
import math

import numpy as np
Expand Down Expand Up @@ -252,7 +253,11 @@ def make_batches(

## Make slices
offset = np.random.randint(0, batch_size) if randomize_batch_indices else 0
idx_slices = [slice(s, min(e, l), None) for s, e in (zip(np.arange(idx_start + offset, l, batch_size), np.arange(idx_start + batch_size + offset, l + batch_size + offset, batch_size)))]
idx_slices = [
slice(s, min(e, l), None) for s, e in (zip(
np.arange(idx_start + offset, l, batch_size),
np.arange(idx_start + batch_size + offset, l + batch_size + offset, batch_size)
if offset > 0:
idx_slices = [slice(0, offset, None)] + idx_slices

Expand Down Expand Up @@ -464,47 +469,96 @@ def off_diagonal(x):
return x.reshape(-1)[:-1].reshape(n - 1, n + 1)[:, 1:].reshape(-1)

def batched_unfold(tensor, dimension, size, step, batch_size):
def batched_unfold(
tensor: torch.Tensor,
dimension: int,
size: int,
step: int,
batch_size: int = 10,
pad_value: Optional[Union[int, float]] = None,
) -> Iterator[torch.Tensor]:
Generates batches of overlapping windows of indices from a tensor, mimicking
the behavior of torch.Tensor.unfold but in a batched manner. Using
``torch.cat(list(batched_unfold(tensor, dimension, size, step, batch_size)),
dim=dimension)`` should be equivalent to ``tensor.unfold(dimension, size, step)``.
Generates batches of overlapping windows from a tensor in a batched manner.
This function mimics the behavior of `torch.Tensor.unfold` but allows for
processing in smaller, more manageable batches. Using
`torch.cat(list(batched_unfold(tensor, dimension, size, step, batch_size)), dim=dimension)` should reproduce the
result of `tensor.unfold(dimension, size, step)`. Note: The last batch may
be smaller than the specified batch size due to the remainder of the
RH 2023
RH 2024
tensor (torch.Tensor):
The input tensor to be unfolded.
The input tensor from which to generate unfolded windows.
dimension (int):
The dimension along which to unfold the tensor.
size (int):
The size of each slice that is unfolded.
The size of each window to unfold.
step (int):
The step between slices.
The step size between the starts of each window.
batch_size (int):
Number of windows per batch.
The number of windows to include in each batch. (Default is 10)
pad_value (Optional[Union[int, float]]):
The value to use for padding the last batch if it is smaller than
the specified batch size. If None, the last batch will not be
padded. (Default is None)
A batch of unfolded tensors.
total_windows = (tensor.size(dimension) - size) // step + 1

for i in range(0, total_windows, batch_size):
# Calculate the start of the slice to ensure correct overlap
start = i * step
# Adjust end index to get a full batch, while not exceeding tensor size
end = min(start + (batch_size - 1) * step + size, tensor.size(dimension))

# Check if the slice size is smaller than needed for a full unfold
if (end - start) < size:

# Slice the tensor to get the current batch and then unfold
batch_slice = tensor.narrow(dimension, start, end - start)
unfolded_batch = batch_slice.unfold(dimension, size, step)
yield unfolded_batch
A batch of unfolded windows from the tensor.
n_samples = tensor.shape[dimension]
if step >= n_samples:
raise ValueError("Step size must be less than the tensor size along the specified dimension.")
if size > n_samples:
raise ValueError("Window size must be less than or equal to the tensor size along the specified dimension.")

# Calculate end points
idx_start_last = ((n_samples - size) // step) * step
# idx_last_start = max(idx_last_start, 0)
# idx_last_end = idx_last_start + size - 1 ## Inclusive

idx_starts_all = range(0, idx_start_last + 1, step)
# idx_ends_all = list((idx + size for idx in idx_starts_all)) ## Exclusive

## Generate batches
for i_batch, idx_start in enumerate(idx_starts_all[::batch_size]):
idx_end = idx_start + (step * (batch_size - 1)) + size
n_samples_pad = 0
if idx_end > n_samples:
if pad_value is None:
idx_end = n_samples
n_samples_batch = n_samples - idx_start
idx_start_last_batch = int((math.ceil((n_samples_batch - size) / step)) * step)
idx_end = idx_start + idx_start_last_batch + size
idx_end = min(idx_end, n_samples + size)
n_samples_pad = idx_end - n_samples

len_batch = idx_end - idx_start

if n_samples_pad > 0:
shape_pad = list(tensor.shape)
shape_pad[dimension] = n_samples_pad
x_batch =
length=n_samples - idx_start,
x_batch = tensor.narrow(dimension, idx_start, len_batch)

yield x_batch.unfold(dimension, size, step)

Expand Down
19 changes: 15 additions & 4 deletions bnpm/
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,12 @@ def torch_coherence(
detrend: str = 'constant',
axis: int = -1,
batch_size: Optional[int] = None,
pad_last_segment: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
Computes the magnitude-squared coherence between two signals using a PyTorch
implementation. This function gives identical results to the
scipy.signal.coherence. \n
scipy.signal.coherence, but much faster.\n
Speed: The 'linear' detrending method is not fast on GPU, despite the
implementation being similar. 'constant' is roughly 3x as fast as 'linear'
Expand Down Expand Up @@ -571,6 +572,9 @@ def torch_coherence(
If None, then all segments are processed at once. Note that
``num_segments = (x.shape[axis] - nperseg) // (nperseg - noverlap) +
1``. (Default is None)
pad_last_segment (Optional[float]):
Value with which to pad the last segment. If None, then no padding
is done. (Default is None)
Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -588,8 +592,11 @@ def torch_coherence(
## Convert axis to positive
axis = axis % len(x.shape)

## Check dimensions
### They should either be the same or one of them should be 1
## Checks
### Tensor checks
assert isinstance(x, torch.Tensor), "x should be a torch tensor"
assert isinstance(y, torch.Tensor), "y should be a torch tensor"
### Dims should either be the same or one of them should be 1
if not (x.shape == y.shape):
assert all([(x.shape[ii] == y.shape[ii]) or (1 in [x.shape[ii], y.shape[ii]]) for ii in range(len(x.shape))]), f"x and y should have the same shape or one of them should have shape 1 at each dimension. Found x.shape={x.shape} and y.shape={y.shape}"

Expand All @@ -606,6 +613,9 @@ def torch_coherence(
window = scipy.signal.get_window(window, nperseg)
window = torch.tensor(window, dtype=x.dtype, device=x.device)

## Signal length along axis should be greater than or equal to nfft
assert all([arg.shape[axis] >= nfft for arg in (x, y)]), f"Signal length along axis should be greater than nfft. Found x.shape={x.shape} and y.shape={y.shape} and nfft={nfft}"

## Detrend the signals
def detrend_constant(y, axis):
y = y - torch.mean(y, axis=axis, keepdim=True)
Expand Down Expand Up @@ -656,13 +666,14 @@ def detrend_linear(y, axis):

## Prepare batch generator
num_segments = (x.shape[axis] - nperseg) // (nperseg - noverlap) + 1
batch_size = num_segments if batch_size is None else batch_size
batch_size = max(num_segments, 1) if batch_size is None else batch_size
x_batches, y_batches = (indexing.batched_unfold(
step=nperseg - noverlap,
) for var in (x, y))

## Perform Welch's averaging of FFT segments
Expand Down
3 changes: 3 additions & 0 deletions bnpm/tests/
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
__all__ = [
58 changes: 53 additions & 5 deletions bnpm/tests/
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import scipy.signal
from hypothesis import given, strategies as st
from hypothesis import settings

from ..spectral import torch_coherence

Expand Down Expand Up @@ -109,7 +110,7 @@ def test_overlap_sizes(noverlap):
@pytest.mark.parametrize("nfft", [256, 512, 1024])
def test_fft_lengths(nfft):
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
t = np.linspace(0, 1, 2048, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) # 5 Hz sinusoid
y = np.sin(2 * np.pi * 5 * t + np.pi/4) # 5 Hz sinusoid, phase shifted

Expand Down Expand Up @@ -145,7 +146,7 @@ def test_detrending_methods(detrend):
# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), f"Coherence values do not match closely enough for detrend method={detrend}."

# Test multi-dimensional input
# Test 2D input
def test_multi_dimensional_input():
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
Expand All @@ -165,6 +166,53 @@ def test_multi_dimensional_input():
freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, nperseg=nperseg)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, nperseg=nperseg, axis=1)

# Check if the results are close enough, comparing each ensemble member's coherence
for i in range(10):
assert np.allclose(coherence_pytorch[i].numpy(), coherence_scipy[i], atol=1e-2), "Coherence values do not match for multi-dimensional input."
# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), "Coherence values do not match for multi-dimensional input."

# Test ND input
st.integers(min_value=1, max_value=4), # Number of dimensions
st.integers(min_value=16, max_value=40,), # Number of samples
def test_nd_input(ndim, nsamples):
axis = np.random.choice(range(ndim))
## Make random data
### Make random shape with ndim dimensions, each of size in [nsamples, 2 * nsamples)
shape = np.random.randint(nsamples, nsamples * 2, size=ndim)
### Select between 0 and ndim-1 distinct axis indices, each in [0, ndim)
axes = np.random.choice(range(0, ndim), np.random.randint(0, ndim), replace=False)
### Make singleton dimensions in either x or y for the dims listed in axes
shape_x = list(shape)
shape_y = list(shape)
for i in axes:
if i == axis:
### 50% chance of applying to x or y, 10% chance of applying to both
r = np.random.rand()
if r > 0.5:
shape_x[i] = 1
elif r < 0.1:
shape_x[i] = 1
shape_y[i] = 1
shape_y[i] = 1

x = np.random.randn(*shape_x).astype(np.float32)
y = np.random.randn(*shape_y).astype(np.float32)

x_torch = torch.tensor(x)
y_torch = torch.tensor(y)

fs = 1.0
nperseg = 16

freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, nperseg=nperseg, axis=axis)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, nperseg=nperseg, axis=axis)

# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), "Coherence values do not match for ND input."
3 changes: 3 additions & 0 deletions bnpm/tests/test_indexing/
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__all__ = [

0 comments on commit c1b49bd

Please sign in to comment.