diff --git a/bnpm/spectral.py b/bnpm/spectral.py index e267759..3a9d7bb 100644 --- a/bnpm/spectral.py +++ b/bnpm/spectral.py @@ -13,6 +13,7 @@ from . import misc from . import torch_helpers from . import timeSeries +from . import indexing def design_butter_bandpass(lowcut, highcut, fs, order=5, plot_pref=True): @@ -528,18 +529,12 @@ def torch_coherence( nfft: Optional[int] = None, detrend: str = 'constant', axis: int = -1, + batch_size: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes the magnitude-squared coherence between two signals using a PyTorch implementation. This function gives identical results to the scipy.signal.coherence. \n - - The primary difference in implementation between this and scipy's coherence - is that this uses an accumulation method for Welch's method, while scipy - just makes a large array with all the overlapping windows. Therefore, this - method uses less memory and is faster for large windows but is slower for - small windows and there is a very small amount of numerical error due to the - accumulation. \n Speed: The 'linear' detrending method is not fast on GPU, despite the implementation being similar. 'constant' is roughly 3x as fast as 'linear' @@ -571,6 +566,11 @@ def torch_coherence( 'constant' or 'linear'. (Default is 'constant') axis (int): Axis along which the coherence is calculated. (Default is -1) + batch_size (Optional[int]): + Number of segments to process at once. Used to reduce memory usage. + If None, then all segments are processed at once. Note that + ``num_segments = (x.shape[axis] - nperseg) // (nperseg - noverlap) + + 1``. (Default is None) Returns: Tuple[torch.Tensor, torch.Tensor]: @@ -591,7 +591,7 @@ def torch_coherence( ## Check dimensions ### They should either be the same or one of them should be 1 if not (x.shape == y.shape): - assert all([x.shape[ii] in [1, y.shape[ii]] for ii in range(len(x.shape))]), f"x and y should have the same shape or one of them should have shape 1 at each dimension. Found x.shape={x.shape} and y.shape={y.shape}" + assert all([(x.shape[ii] == y.shape[ii]) or (1 in [x.shape[ii], y.shape[ii]]) for ii in range(len(x.shape))]), f"x and y should have the same shape or one of them should have shape 1 at each dimension. Found x.shape={x.shape} and y.shape={y.shape}" if nperseg is None: nperseg = len(x) // 8 @@ -619,18 +619,18 @@ def detrend_linear(y, axis): Uses least squares approach to remove linear trend. """ ## Move axis to end - y_dims_to = [ii for ii in range(len(y.shape)) if ii != axis] + [axis] - y = y.permute(*y_dims_to)[..., None] + y = y.moveaxis(axis, -1)[..., None] ## Prepare the design matrix X = X_linearDetrendPrep[*([None] * (len(y.shape) - 2))] ## Compute the coefficients - beta = torch.linalg.lstsq(X, y)[0] + # beta = torch.linalg.lstsq(X, y)[0] + ### Use closed form solution for least squares + beta = torch.linalg.inv(X.transpose(-1, -2) @ X) @ X.transpose(-1, -2) @ y ## Remove the trend y = y - opt_einsum.contract('...ij, ...jk -> ...ik', X, beta) y = y[..., 0] ## Move axis back to original position (argsort y_dims_to) - y_dims_from = [y_dims_to.index(ii) for ii in range(len(y.shape))] - y = y.permute(*y_dims_from) + y = y.moveaxis(-1, axis) return y if detrend == 'constant': @@ -645,33 +645,39 @@ def detrend_linear(y, axis): x_shape = list(x.shape) y_shape = list(y.shape) out_shape = [max(x_shape[i], y_shape[i]) for i in range(len(x_shape))] - out_shape[axis] = nfft // 2 + 1 ## rfft returns only non-negative frequencies (0 to fs/2 inclusive ) + out_shape[axis] = nfft // 2 + 1 ## rfft returns only non-negative frequencies (0 to fs/2 inclusive) ## Initialize sums for Welch's method ### Prepare complex dtype dtype_complex = x.dtype.to_complex() f_cross_sum = torch.zeros(out_shape, dtype=dtype_complex, device=x.device) - psd1_sum = torch.zeros(out_shape, dtype=dtype_complex, device=x.device) - psd2_sum = torch.zeros(out_shape, dtype=dtype_complex, device=x.device) + psd1_sum = torch.zeros(out_shape, dtype=x.dtype, device=x.device) + psd2_sum = torch.zeros(out_shape, dtype=x.dtype, device=x.device) - ## Perform Welch's averaging of FFT segments + ## Prepare batch generator num_segments = (x.shape[axis] - nperseg) // (nperseg - noverlap) + 1 - ### Pad window with [None] dims to match x and y - window = window[(None,) * axis + (slice(None),) + (None,) * (len(x.shape) - axis - 1)] - for ii in range(num_segments): - start = ii * (nperseg - noverlap) - end = start + nperseg - fn_get_segment = lambda x, axis, start, end: torch.fft.rfft(fn_detrend(torch_helpers.slice_along_dim(x, axis, slice(start, end)), axis=axis) * window, n=nfft, dim=axis) - segment1 = fn_get_segment(x, axis, start, end) - segment2 = fn_get_segment(y, axis, start, end) - f_cross_sum += torch.conj(segment1) * segment2 - psd1_sum += torch.conj(segment1) * segment1 - psd2_sum += torch.conj(segment2) * segment2 + batch_size = num_segments if batch_size is None else batch_size + x_batches, y_batches = (indexing.batched_unfold( + var, + dimension=axis, + size=nperseg, + step=nperseg - noverlap, + batch_size=batch_size, + ) for var in (x, y)) + + ## Perform Welch's averaging of FFT segments + for segs_x, segs_y in zip(x_batches, y_batches): + process_segment = lambda x: torch.fft.rfft(fn_detrend(x, axis=-1) * window, n=nfft, dim=-1) ## Note the broadcasting of 1-D window with last dimension of x + segs_x = process_segment(segs_x) + segs_y = process_segment(segs_y) + f_cross_sum += torch.sum(torch.conj(segs_x) * segs_y, dim=axis).moveaxis(-1, axis) + psd1_sum += torch.sum((torch.conj(segs_x) * segs_x).real, dim=axis).moveaxis(-1, axis) + psd2_sum += torch.sum((torch.conj(segs_y) * segs_y).real, dim=axis).moveaxis(-1, axis) ## Averaging the sums f_cross = f_cross_sum / num_segments - psd1 = psd1_sum.real / num_segments - psd2 = psd2_sum.real / num_segments + psd1 = psd1_sum / num_segments + psd2 = psd2_sum / num_segments ## Compute coherence coherence = torch.abs(f_cross) ** 2 / (psd1 * psd2)