Skip to content

Commit

Permalink
modified fit function to return losses
Browse files Browse the repository at this point in the history
  • Loading branch information
sameerashahh committed Jul 3, 2024
1 parent 55cfcf0 commit 51301a4
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/dynadojo/baselines/dnn_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ def fit(self, x: np.ndarray, epochs=2000, verbose=0, **kwargs):
head = x[:, :-1, :]
tail = x[:, 1:, :]
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
self.model.fit(head, tail, validation_split=0.2, epochs=epochs, callbacks=[callback], verbose=verbose)
history = self.model.fit(head, tail, validation_split=0.2, epochs=epochs, callbacks=[callback], verbose=verbose)
# print(history.history.keys())
train_losses = history.history['loss']
val_losses = history.history['val_loss']
return {
"train_loss": train_losses,
"val_loss": val_losses
}

def predict(self, x0: np.ndarray, timesteps: int, **kwargs) -> np.ndarray:
preds = [x0]
Expand Down

0 comments on commit 51301a4

Please sign in to comment.