Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/onnx support #2620

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**

- New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl).
- Added ONNX support for torch-based models, and an example of export and loading for inference in the User Guide. [#2620](https://github.com/unit8co/darts/pull/2620) by [Antoine Madrona](https://github.com/madtoinou).

**Fixed**
- Fixed a bug when performing optimized historical forecasts with `stride=1` using a `RegressionModel` with `output_chunk_shift>=1` and `output_chunk_length=1`, where the forecast time index was not properly shifted. [#2634](https://github.com/unit8co/darts/pull/2634) by [Mattias De Charleroy](https://github.com/MattiasDC).
Expand Down
55 changes: 45 additions & 10 deletions darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
When subclassing this class, please make sure to add the following methods with the given signatures:
- :func:`PLForecastingModule.__init__()`
- :func:`PLForecastingModule.forward()`
- :func:`PLForecastingModule._process_input_batch()`
- :func:`PLForecastingModule._produce_train_output()`
- :func:`PLForecastingModule._get_batch_prediction()`

Expand Down Expand Up @@ -583,6 +584,13 @@ def to_dtype(self, dtype):
logger,
)

def to_onnx(self, file_path, input_sample=None, **kwargs):
    """Export the module to ONNX by delegating to the parent class' ``to_onnx()``.

    Parameters
    ----------
    file_path
        Forwarded unchanged to the parent implementation as ``file_path``.
    input_sample
        Forwarded unchanged to the parent implementation as ``input_sample``. When it is ``None``,
        a warning recommends going through `TorchForecastingModel.to_onnx` instead.
    **kwargs
        Additional keyword arguments forwarded to the parent implementation.
    """
    # Compare against `None` explicitly: the truth value of a tensor holding more than one
    # element is ambiguous and raises, so `not input_sample` is unsafe on a user-supplied sample.
    if input_sample is None:
        logger.warning(
            "It is recommended to use `TorchForecastingModel.to_onnx` method instead."
        )
    super().to_onnx(file_path=file_path, input_sample=input_sample, **kwargs)

@property
def epochs_trained(self):
current_epoch = self.current_epoch
Expand Down Expand Up @@ -632,17 +640,48 @@ def _produce_train_output(self, input_batch: tuple):
input_batch
``(past_target, past_covariates, static_covariates)``
"""
past_target, past_covariates, static_covariates = input_batch
return self(self._process_input_batch(input_batch))

def _process_input_batch(
self, input_batch: tuple
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Converts a batch produced by a past covariates dataset (training or prediction) into the
past input chunk and the static covariates expected by the model.

Parameters
----------
input_batch
``(past_target, past_covariates, static_covariates)`` or ``(past_target, past_covariates, future_past_covariates, static_covariates)``.

Returns
-------
tuple
``(x_past, x_static)`` the past input chunk (past target concatenated with the past covariates, if any) and the static covariates.
"""
# because of future past covariates, the batch shape is different during training and prediction
if len(input_batch) == 3:
(
past_target,
past_covariates,
static_covariates,
) = input_batch
else:
(
past_target,
past_covariates,
future_past_covariates,
static_covariates,
) = input_batch
# Currently all our PastCovariates models require past target and covariates concatenated
inpt = (
return (
(
torch.cat([past_target, past_covariates], dim=2)
if past_covariates is not None
else past_target
),
static_covariates,
)
return self(inpt)

def _get_batch_prediction(
self, n: int, input_batch: tuple, roll_size: int
Expand Down Expand Up @@ -674,12 +713,9 @@ def _get_batch_prediction(
past_covariates.shape[dim_component] if past_covariates is not None else 0
)

input_past = torch.cat(
[ds for ds in [past_target, past_covariates] if ds is not None],
dim=dim_component,
)
input_past, input_static = self._process_input_batch(input_batch)

out = self._produce_predict_output(x=(input_past, static_covariates))[
out = self._produce_predict_output(x=(input_past, input_static))[
:, self.first_prediction_index :, :
]

Expand Down Expand Up @@ -796,7 +832,6 @@ def _process_input_batch(
future_covariates,
static_covariates,
) = input_batch
dim_variable = 2

x_past = torch.cat(
[
Expand All @@ -808,7 +843,7 @@ def _process_input_batch(
]
if tensor is not None
],
dim=dim_variable,
dim=2,
)
return x_past, future_covariates, static_covariates

Expand Down
9 changes: 7 additions & 2 deletions darts/models/forecasting/rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ def forward(
pass

def _produce_train_output(self, input_batch: tuple) -> torch.Tensor:
# only return the forecast, not the hidden state
return self(self._process_input_batch(input_batch))[0]

def _process_input_batch(
self, input_batch: tuple
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
(
past_target,
historic_future_covariates,
Expand All @@ -112,15 +118,14 @@ def _produce_train_output(self, input_batch: tuple) -> torch.Tensor:
) = input_batch
# For the RNN we concatenate the past_target with the future_covariates
# (they have the same length because we enforce a Shift dataset for RNNs)
model_input = (
return (
(
torch.cat([past_target, future_covariates], dim=2)
if future_covariates is not None
else past_target
),
static_covariates,
)
return self(model_input)[0]

def _produce_predict_output(
self, x: tuple, last_hidden_state: Optional[torch.Tensor] = None
Expand Down
88 changes: 88 additions & 0 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,94 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates):
logger=logger,
)

def to_onnx(
    self,
    path: Optional[str] = None,
    input_sample: Optional[tuple] = None,
    randomize_input_sample: bool = False,
    **kwargs,
):
    """Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's
    :func:`torch.onnx.export` method (`official documentation <https://lightning.ai/docs/pytorch/
    stable/common/lightning_module.html#to-onnx>`_).

    Note: the onnx library (optional dependency) must be installed in order to call this method.

    Example for exporting a :class:`DLinearModel`:

    .. highlight:: python
    .. code-block:: python

        from darts.models import DLinearModel
        from darts import TimeSeries
        import numpy as np

        train_ts = TimeSeries.from_values(np.arange(0,100))
        model = DLinearModel(input_chunk_length=4, output_chunk_length=1)
        model.fit(train_ts, epochs=1)
        model.to_onnx("my_model.onnx")
    ..

    Parameters
    ----------
    path
        Path under which to save the model at its current state. If no path is specified, the model
        is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.onnx"``.
    input_sample
        Tuple of Tensor corresponding to the inputs of the model forward pass. If not provided, a
        random sample is generated from the shapes of the training sample.
    randomize_input_sample
        Whether to randomize the values in the `input_sample` to avoid leaking data.
    **kwargs
        Additional kwargs for PyTorch's :func:`torch.onnx.export` method, such as ``verbose`` prints a
        description of the model being exported to stdout.
        For more information, read the `official documentation <https://pytorch.org/docs/master/
        onnx.html#torch.onnx.export>`_.

    Raises
    ------
    ValueError
        If `fit()` was not called before, or if neither `input_sample` nor the `train_sample`
        attribute is available.
    """
    if not self._fit_called:
        raise_log(
            ValueError("`fit()` needs to be called before `to_onnx()`."), logger
        )

    # `input_sample` is compared against `None` explicitly: relying on truthiness would raise
    # if a tensor with more than one element were passed.
    if not self.train_sample and input_sample is None:
        raise_log(
            ValueError(
                "Either the `input_sample` argument or the `train_sample` attribute must be provided."
            ),
            logger,
        )

    if path is None:
        path = self._default_save_path() + ".onnx"

    if input_sample is None:
        # last dimension in train_sample_shape is the expected target
        mock_batch = tuple(
            torch.rand((1,) + shape, dtype=self.model.dtype) if shape else None
            for shape in self.model.train_sample_shape[:-1]
        )
        input_sample = self.model._process_input_batch(mock_batch)
    elif randomize_input_sample:
        # keep the shapes of the provided sample but replace the values with random ones
        input_sample = tuple(
            torch.rand(tensor.shape, dtype=self.model.dtype)
            if tensor is not None
            else None
            for tensor in input_sample
        )

    # torch models necessarily use historic target values as features in current implementation
    input_names = ["x_past"]
    if self._uses_future_covariates:
        input_names.append("x_future")
    if self._uses_static_covariates:
        input_names.append("x_static")

    # the sample is wrapped in a 1-tuple so that it is passed as a single positional argument
    self.model.to_onnx(
        file_path=path,
        input_sample=(input_sample,),
        input_names=input_names,
        **kwargs,
    )

@random_method
def fit(
self,
Expand Down
82 changes: 82 additions & 0 deletions docs/userguide/torch_forecasting_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ We assume that you already know about covariates in Darts. If you're new to the
- [Manual saving / loading](#manual-saving--loading)
- [Train & save on GPU, load on CPU](#trainingsaving-on-gpu-and-loading-on-cpu)
- [Load pre-trained model for fine-tuning](#re-training-or-fine-tuning-a-pre-trained-model)
- [Exporting model to ONNX format for inference](#exporting-model-to-onnx-format-for-inference)
- [Callbacks](#callbacks)
- [Early Stopping](#example-with-early-stopping)
- [Custom Callback](#example-of-custom-callback-to-store-losses)
Expand Down Expand Up @@ -350,6 +351,87 @@ model_finetune = SomeTorchForecastingModel(..., # use identical parameters & va
model_finetune.load_weights("/your/path/to/save/model.pt")
```

#### Exporting model to ONNX format for inference

It is also possible to export the model weights to the ONNX format to run inference in a lightweight environment. This example assumes that the model is trained using future covariates that extend far enough into the future. Note that the user must align and slice the series manually, and that it will not be possible to forecast `n > output_chunk_length` without implementing the auto-regression logic.

```python
model = SomeTorchForecastingModel(...)
model.fit(...)

# make sure to have onnx installed
onnx_filename = "example_onnx.onnx"
model.to_onnx(onnx_filename, export_params=True)
```

Now, to load the model and predict steps after the end of the series:

```python
import onnx
import onnxruntime as ort

def prepare_onnx_inputs(
    model,
    series: TimeSeries,
    past_covariates: Optional[TimeSeries] = None,
    future_covariates: Optional[TimeSeries] = None,
) -> tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Helper function to slice and concatenate the input features.

    Returns ``(past_feats, future_feats, static_feats)``, each with a leading batch
    dimension of size 1; ``future_feats`` and ``static_feats`` are ``None`` when not available.
    """
    input_chunk_length = model.input_chunk_length

    # convert and concatenate the historic features (target, past and future covariates)
    past_feats = series.values()[-input_chunk_length:]
    if past_covariates is not None:
        past_feats = np.concatenate(
            [past_feats, past_covariates.values()[-input_chunk_length:]],
            axis=1,
        )
    if future_covariates is not None:
        past_feats = np.concatenate(
            [past_feats, future_covariates.values()[-input_chunk_length:]],
            axis=1,
        )
    past_feats = np.expand_dims(past_feats, axis=0)

    # convert the future covariates
    future_feats = None
    if model._uses_future_covariates and future_covariates is not None:
        # positional slice: only matches the steps following `series` if
        # `future_covariates` starts at the same time step as `series`
        horizon_start = len(series)
        horizon_end = horizon_start + model.output_chunk_length
        future_feats = np.expand_dims(
            future_covariates.values()[horizon_start:horizon_end], axis=0
        )

    # convert static covariates
    static_feats = None
    if series.has_static_covariates:
        static_feats = np.expand_dims(series.static_covariates_values(), axis=0)

    return past_feats, future_feats, static_feats

# load the exported model and validate it before creating the inference session
onnx_model = onnx.load(onnx_filename)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_filename)

# use helper function to extract the features from the series
past_feats, future_feats, static_feats = prepare_onnx_inputs(
    model=model,
    series=series,
    past_covariates=None,
    future_covariates=ts_future,
)

# extract only the features expected by the model
expected_names = {inp.name for inp in ort_session.get_inputs()}
ort_inputs = {
    name: feats
    for name, feats in zip(
        ["x_past", "x_future", "x_static"],
        [past_feats, future_feats, static_feats],
    )
    if name in expected_names
}
ort_outs = ort_session.run(None, ort_inputs)
```

Note that the forecasts might be slightly different due to rounding errors. Also, due to its specificities, `RNNModel` requires different pre-processing of the series to obtain the input arrays (notably because of `training_length`).

### Callbacks

Callbacks are a powerful way to monitor or control the behavior of the model during the training process. Some examples:
Expand Down
Loading