
Commit c3be642
Refactor torch_coherence function to include batch processing option. Massive speedup.
RichieHakim committed Apr 17, 2024
1 parent 073d194 commit c3be642
Showing 1 changed file with 36 additions and 30 deletions.
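As context for the change, a minimal usage sketch of the new option (hypothetical data and keyword values; assumes the call signature mirrors scipy.signal.coherence, plus the batch_size argument added below):

import torch
from bnpm import spectral

fs = 1000.0                   ## sampling rate in Hz (example value)
x = torch.randn(8, 100_000)   ## batch of signals, time along the last axis
y = torch.randn(8, 100_000)

## Cap peak memory by processing 16 overlapping segments at a time.
f, cxy = spectral.torch_coherence(
    x, y,
    fs=fs,
    nperseg=1024,
    batch_size=16,   ## None (the default) processes all segments at once
)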
bnpm/spectral.py: 36 additions & 30 deletions
@@ -13,6 +13,7 @@
 from . import misc
 from . import torch_helpers
 from . import timeSeries
+from . import indexing


 def design_butter_bandpass(lowcut, highcut, fs, order=5, plot_pref=True):
@@ -528,18 +529,12 @@ def torch_coherence(
     nfft: Optional[int] = None,
     detrend: str = 'constant',
     axis: int = -1,
+    batch_size: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Computes the magnitude-squared coherence between two signals using a PyTorch
     implementation. This function gives identical results to
     scipy.signal.coherence. \n
-    The primary difference in implementation between this and scipy's coherence
-    is that this uses an accumulation method for Welch's method, while scipy
-    just makes a large array with all the overlapping windows. Therefore, this
-    method uses less memory and is faster for large windows but is slower for
-    small windows and there is a very small amount of numerical error due to the
-    accumulation. \n
     Speed: The 'linear' detrending method is not fast on GPU, despite the
     implementation being similar. 'constant' is roughly 3x as fast as 'linear'
@@ -571,6 +566,11 @@
             'constant' or 'linear'. (Default is 'constant')
         axis (int):
             Axis along which the coherence is calculated. (Default is -1)
+        batch_size (Optional[int]):
+            Number of segments to process at once. Used to reduce memory usage.
+            If None, then all segments are processed at once. Note that
+            ``num_segments = (x.shape[axis] - nperseg) // (nperseg - noverlap) +
+            1``. (Default is None)

     Returns:
         Tuple[torch.Tensor, torch.Tensor]:
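To make the segment count above concrete, a quick worked example (illustrative numbers only):

## num_segments = (n_samples - nperseg) // (nperseg - noverlap) + 1
n_samples, nperseg = 10_000, 1024
noverlap = nperseg // 2   ## 50% overlap, scipy's convention when noverlap is None

num_segments = (n_samples - nperseg) // (nperseg - noverlap) + 1
print(num_segments)       ## (10000 - 1024) // 512 + 1 = 18

With batch_size=6, those 18 segments would be processed in 3 chunks of 6.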
@@ -591,7 +591,7 @@
     ## Check dimensions
     ### They should either be the same or one of them should be 1
     if not (x.shape == y.shape):
-        assert all([x.shape[ii] in [1, y.shape[ii]] for ii in range(len(x.shape))]), f"x and y should have the same shape or one of them should have shape 1 at each dimension. Found x.shape={x.shape} and y.shape={y.shape}"
+        assert all([(x.shape[ii] == y.shape[ii]) or (1 in [x.shape[ii], y.shape[ii]]) for ii in range(len(x.shape))]), f"x and y should have the same shape or one of them should have shape 1 at each dimension. Found x.shape={x.shape} and y.shape={y.shape}"

     if nperseg is None:
         nperseg = len(x) // 8
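The docstring's claim of identical results to scipy.signal.coherence suggests a quick equivalence check like the following (a sketch; the keyword names and the (frequencies, coherence) return order are assumed to match scipy's):

import numpy as np
import scipy.signal
import torch

x = np.random.randn(4096)
y = np.random.randn(4096)

f_sp, c_sp = scipy.signal.coherence(x, y, fs=2.0, nperseg=256, noverlap=128)
f_th, c_th = torch_coherence(
    torch.as_tensor(x), torch.as_tensor(y), fs=2.0, nperseg=256, noverlap=128,
)

## Agreement up to the small accumulation error of Welch averaging.
assert np.allclose(f_th.numpy(), f_sp)
assert np.allclose(c_th.numpy(), c_sp, atol=1e-8)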
@@ -619,18 +619,18 @@ def detrend_linear(y, axis):
         Uses least squares approach to remove linear trend.
         """
         ## Move axis to end
-        y_dims_to = [ii for ii in range(len(y.shape)) if ii != axis] + [axis]
-        y = y.permute(*y_dims_to)[..., None]
+        y = y.moveaxis(axis, -1)[..., None]
         ## Prepare the design matrix
         X = X_linearDetrendPrep[*([None] * (len(y.shape) - 2))]
         ## Compute the coefficients
-        beta = torch.linalg.lstsq(X, y)[0]
+        # beta = torch.linalg.lstsq(X, y)[0]
+        ### Use closed form solution for least squares
+        beta = torch.linalg.inv(X.transpose(-1, -2) @ X) @ X.transpose(-1, -2) @ y
         ## Remove the trend
         y = y - opt_einsum.contract('...ij, ...jk -> ...ik', X, beta)
         y = y[..., 0]
         ## Move axis back to original position (argsort y_dims_to)
-        y_dims_from = [y_dims_to.index(ii) for ii in range(len(y.shape))]
-        y = y.permute(*y_dims_from)
+        y = y.moveaxis(-1, axis)
         return y

     if detrend == 'constant':
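The closed-form solution swapped in above is the normal-equations form of least squares, beta = (XᵀX)⁻¹ Xᵀ y. A self-contained sketch showing it agrees with torch.linalg.lstsq (illustrative data, not the repo's code):

import torch

n = 1000
t = torch.arange(n, dtype=torch.float64)
X = torch.stack([t, torch.ones_like(t)], dim=1)           ## (n, 2) design matrix: slope and intercept
y = 0.5 * t + 3.0 + torch.randn(n, dtype=torch.float64)   ## linear trend plus noise

## Closed form: beta = (X^T X)^-1 X^T y
beta_closed = torch.linalg.inv(X.T @ X) @ X.T @ y[:, None]
beta_lstsq = torch.linalg.lstsq(X, y[:, None])[0]
assert torch.allclose(beta_closed, beta_lstsq, atol=1e-6)

detrended = y - (X @ beta_closed)[:, 0]                   ## residual after removing the fitted line

For a fixed two-column design matrix this form batches cheaply on GPU, which is presumably the motivation here, at the cost of some numerical stability on ill-conditioned problems.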
@@ -645,33 +645,39 @@ def detrend_linear(y, axis):
     x_shape = list(x.shape)
     y_shape = list(y.shape)
     out_shape = [max(x_shape[i], y_shape[i]) for i in range(len(x_shape))]
-    out_shape[axis] = nfft // 2 + 1 ## rfft returns only non-negative frequencies (0 to fs/2 inclusive )
+    out_shape[axis] = nfft // 2 + 1 ## rfft returns only non-negative frequencies (0 to fs/2 inclusive)

     ## Initialize sums for Welch's method
     ### Prepare complex dtype
     dtype_complex = x.dtype.to_complex()
     f_cross_sum = torch.zeros(out_shape, dtype=dtype_complex, device=x.device)
-    psd1_sum = torch.zeros(out_shape, dtype=dtype_complex, device=x.device)
-    psd2_sum = torch.zeros(out_shape, dtype=dtype_complex, device=x.device)
+    psd1_sum = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
+    psd2_sum = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

-    ## Perform Welch's averaging of FFT segments
+    ## Prepare batch generator
     num_segments = (x.shape[axis] - nperseg) // (nperseg - noverlap) + 1
-    ### Pad window with [None] dims to match x and y
-    window = window[(None,) * axis + (slice(None),) + (None,) * (len(x.shape) - axis - 1)]
-    for ii in range(num_segments):
-        start = ii * (nperseg - noverlap)
-        end = start + nperseg
-        fn_get_segment = lambda x, axis, start, end: torch.fft.rfft(fn_detrend(torch_helpers.slice_along_dim(x, axis, slice(start, end)), axis=axis) * window, n=nfft, dim=axis)
-        segment1 = fn_get_segment(x, axis, start, end)
-        segment2 = fn_get_segment(y, axis, start, end)
-        f_cross_sum += torch.conj(segment1) * segment2
-        psd1_sum += torch.conj(segment1) * segment1
-        psd2_sum += torch.conj(segment2) * segment2
+    batch_size = num_segments if batch_size is None else batch_size
+    x_batches, y_batches = (indexing.batched_unfold(
+        var,
+        dimension=axis,
+        size=nperseg,
+        step=nperseg - noverlap,
+        batch_size=batch_size,
+    ) for var in (x, y))
+
+    ## Perform Welch's averaging of FFT segments
+    for segs_x, segs_y in zip(x_batches, y_batches):
+        process_segment = lambda x: torch.fft.rfft(fn_detrend(x, axis=-1) * window, n=nfft, dim=-1) ## Note the broadcasting of 1-D window with last dimension of x
+        segs_x = process_segment(segs_x)
+        segs_y = process_segment(segs_y)
+        f_cross_sum += torch.sum(torch.conj(segs_x) * segs_y, dim=axis).moveaxis(-1, axis)
+        psd1_sum += torch.sum((torch.conj(segs_x) * segs_x).real, dim=axis).moveaxis(-1, axis)
+        psd2_sum += torch.sum((torch.conj(segs_y) * segs_y).real, dim=axis).moveaxis(-1, axis)

     ## Averaging the sums
     f_cross = f_cross_sum / num_segments
-    psd1 = psd1_sum.real / num_segments
-    psd2 = psd2_sum.real / num_segments
+    psd1 = psd1_sum / num_segments
+    psd2 = psd2_sum / num_segments

     ## Compute coherence
     coherence = torch.abs(f_cross) ** 2 / (psd1 * psd2)
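indexing.batched_unfold is this repository's own helper and its implementation is not shown in the diff; from the call site it appears to yield sliding windows in chunks of at most batch_size. A rough approximation of the assumed behavior, built on torch.Tensor.unfold:

import torch

def batched_unfold_sketch(x, dimension, size, step, batch_size):
    """Yield x.unfold(dimension, size, step) in chunks of at most batch_size windows."""
    num_windows = (x.shape[dimension] - size) // step + 1
    for start in range(0, num_windows, batch_size):
        stop = min(start + batch_size, num_windows)
        ## Slice only the samples these windows touch, then unfold the slice.
        lo = start * step
        hi = (stop - 1) * step + size
        chunk = x.narrow(dimension, lo, hi - lo)
        ## Windows stack at `dimension`; samples go to a new last dim of length `size`.
        yield chunk.unfold(dimension, size, step)

Each yielded tensor then has the layout the loop above expects (segments along axis, samples along the last dimension), so only batch_size segments are materialized at a time.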
