From 1c48182e4f136c3e3d52ea5a66b16c1bd5fab50b Mon Sep 17 00:00:00 2001 From: Sameera Shah Date: Mon, 1 Jul 2024 18:19:58 +0800 Subject: [PATCH] returns training loss --- src/dynadojo/baselines/ode.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/dynadojo/baselines/ode.py b/src/dynadojo/baselines/ode.py index 8aa76e78..92a85c3a 100644 --- a/src/dynadojo/baselines/ode.py +++ b/src/dynadojo/baselines/ode.py @@ -26,6 +26,7 @@ def fit(self, x: np.ndarray, epochs=100, **kwargs): x = torch.tensor(x, dtype=torch.float32) state = x[:, 0, :] t = torch.linspace(0.0, self._timesteps, self._timesteps) + losses = [] for _ in range(epochs): self.opt.zero_grad() pred = odeint(self.forward, state, t, method='rk4') @@ -33,7 +34,10 @@ def fit(self, x: np.ndarray, epochs=100, **kwargs): loss = self.mse_loss(pred, x).float() loss.backward() self.opt.step() - + losses.append(loss) + return { + "train_losses": losses + } def predict(self, x0: np.ndarray, timesteps: int, **kwargs) -> np.ndarray: x0 = torch.tensor(x0, dtype=torch.float32) t = torch.linspace(0.0, timesteps, timesteps)