Skip to content

Commit

Permalink
Add support for TensorFlow 2.16
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Mar 28, 2024
1 parent 83b9eec commit 6e92b97
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 160 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Release history
0.7.1 (unreleased)
==================

*Compatible with TensorFlow 2.4 - 2.13*
*Compatible with TensorFlow 2.4 - 2.16*

0.7.0 (July 20, 2023)
=====================
Expand Down
3 changes: 2 additions & 1 deletion docs/basic-usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ a new LMU layer:

.. testcode::

import keras
import keras_lmu

lmu_layer = keras_lmu.LMU(
memory_d=1,
order=256,
theta=784,
hidden_cell=tf.keras.layers.SimpleRNNCell(units=10),
hidden_cell=keras.layers.SimpleRNNCell(units=10),
)

Note that the values used above for ``memory_d``, ``order``,
Expand Down
13 changes: 7 additions & 6 deletions docs/examples/psMNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"source": [
"%matplotlib inline\n",
"\n",
"import keras\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import Image, display\n",
Expand Down Expand Up @@ -261,22 +262,22 @@
" memory_d=1,\n",
" order=256,\n",
" theta=n_pixels,\n",
" hidden_cell=tf.keras.layers.SimpleRNNCell(212),\n",
" hidden_cell=keras.layers.SimpleRNNCell(212),\n",
" hidden_to_memory=False,\n",
" memory_to_memory=False,\n",
" input_to_hidden=True,\n",
" kernel_initializer=\"ones\",\n",
")\n",
"\n",
"# TensorFlow layer definition\n",
"inputs = tf.keras.Input((n_pixels, 1))\n",
"inputs = keras.Input((n_pixels, 1))\n",
"lmus = lmu_layer(inputs)\n",
"outputs = tf.keras.layers.Dense(10)(lmus)\n",
"outputs = keras.layers.Dense(10)(lmus)\n",
"\n",
"# TensorFlow model definition\n",
"model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
"model = keras.Model(inputs=inputs, outputs=outputs)\n",
"model.compile(\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" optimizer=\"adam\",\n",
" metrics=[\"accuracy\"],\n",
")\n",
Expand Down Expand Up @@ -318,7 +319,7 @@
"\n",
"saved_weights_fname = \"./psMNIST-weights.hdf5\"\n",
"callbacks = [\n",
" tf.keras.callbacks.ModelCheckpoint(\n",
" keras.callbacks.ModelCheckpoint(\n",
" filepath=saved_weights_fname, monitor=\"val_loss\", verbose=1, save_best_only=True\n",
" ),\n",
"]\n",
Expand Down
Loading

0 comments on commit 6e92b97

Please sign in to comment.