Skip to content

Commit

Permalink
Improve batched_unfold function and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Apr 18, 2024
1 parent c3be642 commit c1b49bd
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 38 deletions.
112 changes: 83 additions & 29 deletions bnpm/indexing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional, Union, Iterator
import math

import numpy as np
Expand Down Expand Up @@ -252,7 +253,11 @@ def make_batches(

## Make slices
offset = np.random.randint(0, batch_size) if randomize_batch_indices else 0
idx_slices = [slice(s, min(e, l), None) for s, e in (zip(np.arange(idx_start + offset, l, batch_size), np.arange(idx_start + batch_size + offset, l + batch_size + offset, batch_size)))]
idx_slices = [
slice(s, min(e, l), None) for s, e in (zip(
np.arange(idx_start + offset, l, batch_size),
np.arange(idx_start + batch_size + offset, l + batch_size + offset, batch_size)
))]
if offset > 0:
idx_slices = [slice(0, offset, None)] + idx_slices

Expand Down Expand Up @@ -464,47 +469,96 @@ def off_diagonal(x):
return x.reshape(-1)[:-1].reshape(n - 1, n + 1)[:, 1:].reshape(-1)


def batched_unfold(tensor, dimension, size, step, batch_size):
def batched_unfold(
tensor: torch.Tensor,
dimension: int,
size: int,
step: int,
batch_size: int = 10,
pad_value: Optional[Union[int, float]] = None,
) -> Iterator[torch.Tensor]:
"""
Generates batches of overlapping windows of indices from a tensor, mimicking
the behavior of torch.Tensor.unfold but in a batched manner. Using
``torch.cat(list(batched_unfold(tensor, dimension, size, step, batch_size)),
dim=dimension)`` should be equivalent to ``tensor.unfold(dimension, size, step)``.
Generates batches of overlapping windows from a tensor in a batched manner.
This function mimics the behavior of `torch.Tensor.unfold` but allows for
processing in smaller, more manageable batches. Using
`torch.cat(list(batched_unfold(...)), dim=dimension)` should reproduce the
result of `tensor.unfold(dimension, size, step)`. Note: The last batch may
be smaller than the specified batch size due to the remainder of the
division.
RH 2023
RH 2024
Args:
tensor (torch.Tensor):
The input tensor to be unfolded.
The input tensor from which to generate unfolded windows.
dimension (int):
The dimension along which to unfold the tensor.
size (int):
The size of each slice that is unfolded.
The size of each window to unfold.
step (int):
The step between slices.
The step size between the starts of each window.
batch_size (int):
Number of windows per batch.
The number of windows to include in each batch. (Default is 10)
pad_value (Optional[Union[int, float]]):
The value to use for padding the last batch if it is smaller than
the specified batch size. If None, the last batch will not be
padded. (Default is None)
Yields:
torch.Tensor:
A batch of unfolded tensors.
"""
total_windows = (tensor.size(dimension) - size) // step + 1

for i in range(0, total_windows, batch_size):
# Calculate the start of the slice to ensure correct overlap
start = i * step
# Adjust end index to get a full batch, while not exceeding tensor size
end = min(start + (batch_size - 1) * step + size, tensor.size(dimension))

# Check if the slice size is smaller than needed for a full unfold
if (end - start) < size:
continue

# Slice the tensor to get the current batch and then unfold
batch_slice = tensor.narrow(dimension, start, end - start)
unfolded_batch = batch_slice.unfold(dimension, size, step)
yield unfolded_batch
A batch of unfolded windows from the tensor.
"""
n_samples = tensor.shape[dimension]
if step >= n_samples:
raise ValueError("Step size must be less than the tensor size along the specified dimension.")
if size > n_samples:
raise ValueError("Window size must be less than or equal to the tensor size along the specified dimension.")

# Calculate end points
idx_start_last = ((n_samples - size) // step) * step
# idx_last_start = max(idx_last_start, 0)
# idx_last_end = idx_last_start + size - 1 ## Inclusive

idx_starts_all = range(0, idx_start_last + 1, step)
# idx_ends_all = list((idx + size for idx in idx_starts_all)) ## Exclusive

## Generate batches
for i_batch, idx_start in enumerate(idx_starts_all[::batch_size]):
idx_end = idx_start + (step * (batch_size - 1)) + size
n_samples_pad = 0
if idx_end > n_samples:
if pad_value is None:
idx_end = n_samples
else:
n_samples_batch = n_samples - idx_start
idx_start_last_batch = int((math.ceil((n_samples_batch - size) / step)) * step)
idx_end = idx_start + idx_start_last_batch + size
idx_end = min(idx_end, n_samples + size)
n_samples_pad = idx_end - n_samples

len_batch = idx_end - idx_start

if n_samples_pad > 0:
shape_pad = list(tensor.shape)
shape_pad[dimension] = n_samples_pad
x_batch = torch.cat((
tensor.narrow(
dim=dimension,
start=idx_start,
length=n_samples - idx_start,
),
torch.full(
size=shape_pad,
fill_value=pad_value,
dtype=tensor.dtype,
device=tensor.device,
)
))
else:
x_batch = tensor.narrow(dimension, idx_start, len_batch)

yield x_batch.unfold(dimension, size, step)


#######################################
Expand Down
19 changes: 15 additions & 4 deletions bnpm/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,12 @@ def torch_coherence(
detrend: str = 'constant',
axis: int = -1,
batch_size: Optional[int] = None,
pad_last_segment: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes the magnitude-squared coherence between two signals using a PyTorch
implementation. This function gives identical results to the
scipy.signal.coherence. \n
scipy.signal.coherence, but much faster.\n
Speed: The 'linear' detrending method is not fast on GPU, despite the
implementation being similar. 'constant' is roughly 3x as fast as 'linear'
Expand Down Expand Up @@ -571,6 +572,9 @@ def torch_coherence(
If None, then all segments are processed at once. Note that
``num_segments = (x.shape[axis] - nperseg) // (nperseg - noverlap) +
1``. (Default is None)
pad_last_segment (bool):
Whether to pad the last segment with a value defined by this
argument. If None, then no padding is done. (Default is None)
Returns:
Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -588,8 +592,11 @@ def torch_coherence(
## Convert axis to positive
axis = axis % len(x.shape)

## Check dimensions
### They should either be the same or one of them should be 1
## Checks
### Tensor checks
assert isinstance(x, torch.Tensor), "x should be a torch tensor"
assert isinstance(y, torch.Tensor), "y should be a torch tensor"
### Dims should either be the same or one of them should be 1
if not (x.shape == y.shape):
assert all([(x.shape[ii] == y.shape[ii]) or (1 in [x.shape[ii], y.shape[ii]]) for ii in range(len(x.shape))]), f"x and y should have the same shape or one of them should have shape 1 at each dimension. Found x.shape={x.shape} and y.shape={y.shape}"

Expand All @@ -606,6 +613,9 @@ def torch_coherence(
window = scipy.signal.get_window(window, nperseg)
window = torch.tensor(window, dtype=x.dtype, device=x.device)

## args should be greater than nfft
assert all([arg.shape[axis] >= nfft for arg in (x, y)]), f"Signal length along axis should be greater than nfft. Found x.shape={x.shape} and y.shape={y.shape} and nfft={nfft}"

## Detrend the signals
def detrend_constant(y, axis):
y = y - torch.mean(y, axis=axis, keepdim=True)
Expand Down Expand Up @@ -656,13 +666,14 @@ def detrend_linear(y, axis):

## Prepare batch generator
num_segments = (x.shape[axis] - nperseg) // (nperseg - noverlap) + 1
batch_size = num_segments if batch_size is None else batch_size
batch_size = max(num_segments, 1) if batch_size is None else batch_size
x_batches, y_batches = (indexing.batched_unfold(
var,
dimension=axis,
size=nperseg,
step=nperseg - noverlap,
batch_size=batch_size,
pad_value=pad_last_segment,
) for var in (x, y))

## Perform Welch's averaging of FFT segments
Expand Down
3 changes: 3 additions & 0 deletions bnpm/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
__all__ = [
'test_all',
'benchmark_all',
'test_coherence',
'test_PCA',
'test_regression',
]
58 changes: 53 additions & 5 deletions bnpm/tests/test_coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import scipy.signal
from hypothesis import given, strategies as st
from hypothesis import settings

from ..spectral import torch_coherence

Expand Down Expand Up @@ -109,7 +110,7 @@ def test_overlap_sizes(noverlap):
@pytest.mark.parametrize("nfft", [256, 512, 1024])
def test_fft_lengths(nfft):
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
t = np.linspace(0, 1, 2048, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) # 5 Hz sinusoid
y = np.sin(2 * np.pi * 5 * t + np.pi/4) # 5 Hz sinusoid, phase shifted

Expand Down Expand Up @@ -145,7 +146,7 @@ def test_detrending_methods(detrend):
# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), f"Coherence values do not match closely enough for detrend method={detrend}."

# Test multi-dimensional input
# Test 2D input
def test_multi_dimensional_input():
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
Expand All @@ -165,6 +166,53 @@ def test_multi_dimensional_input():
freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, nperseg=nperseg)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, nperseg=nperseg, axis=1)

# Check if the results are close enough, comparing each ensemble member's coherence
for i in range(10):
assert np.allclose(coherence_pytorch[i].numpy(), coherence_scipy[i], atol=1e-2), "Coherence values do not match for multi-dimensional input."
# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), "Coherence values do not match for multi-dimensional input."

# Test ND input
@given(
st.integers(min_value=1, max_value=4), # Number of dimensions
st.integers(min_value=16, max_value=40,), # Number of samples
)
@settings(
max_examples=10,
deadline=2000,
)
def test_nd_input(ndim, nsamples):
np.random.seed(0)
axis = np.random.choice(range(ndim))
## Make random data
### Make random shape with ndim and either 1 or nsamples features
shape = np.random.randint(nsamples, nsamples * 2, size=ndim)
### Select between 0 and ndim-1 integers with values between 0 and ndim
axes = np.random.choice(range(0, ndim), np.random.randint(0, ndim), replace=False)
### Make singleton dimensions at either x or y for the dims in axis
shape_x = list(shape)
shape_y = list(shape)
for i in axes:
if i == axis:
continue
### 50% chance of applying to x or y, 10% chance of applying to both
r = np.random.rand()
if r > 0.5:
shape_x[i] = 1
elif r < 0.1:
shape_x[i] = 1
shape_y[i] = 1
else:
shape_y[i] = 1

x = np.random.randn(*shape_x).astype(np.float32)
y = np.random.randn(*shape_y).astype(np.float32)

x_torch = torch.tensor(x)
y_torch = torch.tensor(y)

fs = 1.0
nperseg = 16

freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, nperseg=nperseg, axis=axis)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, nperseg=nperseg, axis=axis)

# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), "Coherence values do not match for ND input."
3 changes: 3 additions & 0 deletions bnpm/tests/test_indexing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__all__ = [
'test_indexing',
]
Loading

0 comments on commit c1b49bd

Please sign in to comment.