Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/ccf #2122

Merged
merged 5 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

### For users of the library:
**Improved**
- Added `darts.utils.statistics.plot_ccf` that can be used to plot the cross correlation between a time series (e.g. target series) and the lagged values of another time series (e.g. covariates series). [#2122](https://github.com/unit8co/darts/pull/2122) by [Dennis Bader](https://github.com/dennisbader).
- Improvements to `TimeSeries`:
- Improved the time series frequency inference when using slices or pandas DatetimeIndex as keys for `__getitem__`. [#2152](https://github.com/unit8co/darts/pull/2152) by [DavidKleindienst](https://github.com/DavidKleindienst).

Expand Down
4 changes: 4 additions & 0 deletions darts/tests/utils/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
check_seasonality,
extract_trend_and_seasonality,
granger_causality_tests,
plot_acf,
plot_ccf,
plot_pacf,
plot_residuals_analysis,
remove_seasonality,
Expand Down Expand Up @@ -235,5 +237,7 @@ def test_statistics_plot(self):
plt.close()
plot_residuals_analysis(self.series[:10])
plt.close()
plot_acf(self.series)
plot_pacf(self.series)
plot_ccf(self.series, self.series)
plt.close()
136 changes: 131 additions & 5 deletions darts/utils/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,16 @@
import numpy as np
from scipy.signal import argrelmax
from scipy.stats import norm
from statsmodels.compat.python import lzip
from statsmodels.tsa.seasonal import MSTL, STL, seasonal_decompose
from statsmodels.tsa.stattools import acf, adfuller, grangercausalitytests, kpss, pacf
from statsmodels.tsa.stattools import (
acf,
adfuller,
ccovf,
grangercausalitytests,
kpss,
pacf,
)

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
Expand Down Expand Up @@ -599,8 +607,8 @@ def plot_acf(
default_formatting: bool = True,
) -> None:
"""
Plots the ACF of `ts`, highlighting it at lag `m`, with corresponding significance interval.
Uses :func:`statsmodels.tsa.stattools.acf` [1]_
Plots the Autocorrelation Function (ACF) of `ts`, highlighting it at lag `m`, with corresponding significance
interval. Uses :func:`statsmodels.tsa.stattools.acf` [1]_

Parameters
----------
Expand Down Expand Up @@ -695,8 +703,8 @@ def plot_pacf(
default_formatting: bool = True,
) -> None:
"""
Plots the Partial ACF of `ts`, highlighting it at lag `m`, with corresponding significance interval.
Uses :func:`statsmodels.tsa.stattools.pacf` [1]_
Plots the Partial Autocorrelation Function (PACF) of `ts`, highlighting it at lag `m`, with corresponding
significance interval. Uses :func:`statsmodels.tsa.stattools.pacf` [1]_

Parameters
----------
Expand Down Expand Up @@ -785,6 +793,124 @@ def plot_pacf(
axis.plot((0, max_lag + 1), (0, 0), color="black" if default_formatting else None)


def plot_ccf(
ts: TimeSeries,
ts_other: TimeSeries,
m: Optional[int] = None,
max_lag: int = 24,
alpha: float = 0.05,
bartlett_confint: bool = True,
fig_size: Tuple[int, int] = (10, 5),
axis: Optional[plt.axis] = None,
default_formatting: bool = True,
) -> None:
"""
Plots the Cross Correlation Function (CCF) between `ts` and `ts_other`, highlighting it at lag `m`, with
corresponding significance interval. Uses :func:`statsmodels.tsa.stattools.ccf` [1]_

This can be used to find the cross correlation between the target and different covariates lags.
If `ts_other` is identical `ts`, it corresponds to `plot_acf()`.

Parameters
----------
ts
The TimeSeries whose CCF with `ts_other` should be plotted.
ts_other
The TimeSeries which to compare against `ts` in the CCF. E.g. check the cross correlation of different
covariate lags with the target.
m
Optionally, a time lag to highlight on the plot.
max_lag
The maximal lag order to consider.
alpha
The confidence interval to display.
bartlett_confint
The boolean value indicating whether the confidence interval should be
calculated using Bartlett's formula.
fig_size
The size of the figure to be displayed.
axis
Optionally, an axis object to plot the CCF on.
default_formatting
Whether to use the darts default scheme.

References
----------
.. [1] https://www.statsmodels.org/dev/generated/statsmodels.tsa.stattools.ccf.html
"""

ts._assert_univariate()
ts_other._assert_univariate()
raise_if(
max_lag is None or not (1 <= max_lag < len(ts)),
"max_lag must be greater than or equal to 1 and less than len(ts).",
)
raise_if(
m is not None and not (0 <= m <= max_lag),
"m must be greater than or equal to 0 and less than or equal to max_lag.",
)
raise_if(
alpha is None or not (0 < alpha < 1),
"alpha must be greater than 0 and less than 1.",
)
ts_other = ts_other.slice_intersect(ts)
if len(ts_other) != len(ts):
raise_log(
ValueError("`ts_other` must contain at least the full time index of `ts`."),
logger=logger,
)

x = ts.values()
y = ts_other.values()
cvf = ccovf(x=x, y=y, adjusted=True, demean=True, fft=False)

ccf = cvf / (np.std(x) * np.std(y))
ccf = ccf[: max_lag + 1]

n_obs = len(x)
if bartlett_confint:
varccf = np.ones_like(ccf) / n_obs
varccf[0] = 0
varccf[1] = 1.0 / n_obs
varccf[2:] *= 1 + 2 * np.cumsum(ccf[1:-1] ** 2)
else:
varccf = 1.0 / n_obs

interval = norm.ppf(1.0 - alpha / 2.0) * np.sqrt(varccf)
confint = np.array(lzip(ccf - interval, ccf + interval))

if axis is None:
plt.figure(figsize=fig_size)
axis = plt

for i in range(len(ccf)):
axis.plot(
(i, i),
(0, ccf[i]),
color=("#b512b8" if m is not None and i == m else "black")
if default_formatting
else None,
lw=(1 if m is not None and i == m else 0.5),
)

# Adjusts the upper band of the confidence interval to center it on the x axis.
upp_band = [confint[lag][1] - ccf[lag] for lag in range(1, max_lag + 1)]

# Setting color t0 None overrides custom settings
extra_arguments = {}
if default_formatting:
extra_arguments["color"] = "#003DFD"

axis.fill_between(
np.arange(1, max_lag + 1),
upp_band,
[-x for x in upp_band],
alpha=0.25 if default_formatting else None,
**extra_arguments,
)
axis.plot((0, max_lag + 1), (0, 0), color="black" if default_formatting else None)


def plot_hist(
data: Union[TimeSeries, List[float], np.ndarray],
bins: Optional[Union[int, np.ndarray, List[float]]] = None,
Expand Down
Loading