Geero is a simple neural network library built on Google's JAX and inspired by Stax. It is intended mainly for educational purposes and for experimenting with deep learning methods.
```shell
python -m venv venvGeero

# Linux (WSL2)
. venvGeero/bin/activate

# Alternate Linux method
source venvGeero/bin/activate
```
VSCode sometimes activates the venv automatically and sometimes does not. If it does not, use one of the commands above to activate it; alternatively, closing and reopening your terminal will often cause it to source the venv properly.
If the venv will not activate at all even though it has worked properly in the past, try reloading VSCode:

CTRL + SHIFT + P -> Search 'Reload' -> Click 'Python: Clear Cache and Reload Window'
```shell
pip install --upgrade pip
pip install -r requirements.txt
```
If using WSL2 on Windows, please install python3-tk. This provides a backend for GUIs; without it, matplotlib and similar features will not work.
```shell
sudo apt-get install python3-tk
```
By default, the requirements.txt file installs JAX for CPU only. To take advantage of a GPU, you must install both the CUDA and JAX GPU dependencies. The easiest way to install them is through pip; instructions can be found in the README of the JAX repo. Once these dependencies are installed, the GPU is used automatically by any JAX or Geero functionality.
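For reference, a typical CUDA-enabled install via pip looks like the following. The exact extras name (`cuda12` here) depends on your CUDA version and the current JAX release, so verify the up-to-date command against the JAX README:

```shell
# Install JAX with bundled CUDA 12 support.
# NOTE: the extras name changes between JAX releases; check the JAX README.
pip install --upgrade "jax[cuda12]"
```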
| File | Dataset | Description |
|---|---|---|
| mnist.py | MNIST | Classic MNIST training example using a neural network with Dense layers |
| mnist_conv.py | MNIST | MNIST training example using a convolutional neural network |
| mnist_aug.py | MNIST | MNIST training example using data augmentation |
| mnist_fcnn.py | MNIST | MNIST training example using a fully convolutional neural network |
| fashion_mnist.py | FASHION MNIST | Fashion MNIST training example using a convolutional neural network |
| cifar10.py | CIFAR-10 | CIFAR-10 training example using a convolutional neural network |
| resnet.py | CIFAR-10 | CIFAR-10 training example using a ResNet of my own creation; test set accuracy is around 90% |
Please see env_vars for a list of available environment variables and logging levels.
Please see decorators for a list of currently implemented decorators.
This project uses the Google docstring format. It is recommended to configure this format in the autoDocstring VSCode extension.
Examples of the docstring format can be found below:
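For instance, a Google-style docstring on a small helper (a hypothetical function, shown only to illustrate the format) looks like this:

```python
def scale(values, factor=2.0):
    """Multiplies every element of a list by a constant factor.

    Args:
        values: A list of numbers to scale.
        factor: The multiplier applied to each element. Defaults to 2.0.

    Returns:
        A new list containing each input element multiplied by factor.

    Raises:
        TypeError: If values contains non-numeric elements.
    """
    return [v * factor for v in values]
```

Note the three standard sections (Args, Returns, Raises), each indented under its header; autoDocstring generates this skeleton automatically when configured for the Google format.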
Currently, all initializers used within Geero come from the initializers implementation in JAX. I may port them over at some point for visibility, but for now you can reference and use the JAX documentation for initializers.
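For example, sampling a weight matrix with JAX's Glorot-normal initializer looks like this (this calls jax.nn.initializers directly; how Geero wires initializers into its layers may differ):

```python
import jax
from jax.nn import initializers

key = jax.random.PRNGKey(0)
init = initializers.glorot_normal()  # returns an init function: f(key, shape)
w = init(key, (784, 128))            # sample a (fan_in, fan_out) weight matrix
print(w.shape)                       # (784, 128)
```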
The -1 in the input_shape tuple is a wildcard for the batch size. The wildcard exists because the batch size can vary between forward passes after initialization, and layer initialization does not depend on the batch size in any way.
```python
input_shape = (input_size,)
# or...
input_shape = (-1, input_size)
# or, if you have a fixed batch size (such as 128)...
input_shape = (128, input_size)
```
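Since Geero follows Stax conventions, the effect of the batch-size wildcard can be demonstrated with jax.example_libraries.stax (used here as a stand-in, since I am not assuming Geero's exact import paths): the -1 simply passes through to the output shape.

```python
import jax
from jax.example_libraries import stax

# A small Dense network in the Stax style
init_fun, apply_fun = stax.serial(
    stax.Dense(128), stax.Relu, stax.Dense(10)
)

rng = jax.random.PRNGKey(0)
# -1 is the batch-size wildcard; initialization ignores it
out_shape, params = init_fun(rng, (-1, 784))
print(out_shape)  # (-1, 10)
```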
These variations only affect the output of some of the debug shape print statements: the more shape information you give Geero, the more accurate the shapes it displays in the debug statements.
The initialization (init_fun) of the Conv layer does not depend on the batch size. It can also be independent of input_height and input_width, but only in a purely convolutional network. If Dense layers are mixed in, they need to know the flattened output shape of the preceding Conv layer, because, unlike Conv, the Dense layer IS dependent on its input size.
```python
input_shape = (-1, input_height, input_width, input_channels)
# ONLY if your network is fully convolutional can you do...
input_shape = (-1, -1, -1, input_channels)
# This is not recommended if you DO have a fixed input size,
# because the debugging information will be less useful.
```
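To see why Dense layers force the spatial dimensions to be known, consider the shape arithmetic (a plain-Python sketch assuming 'VALID' padding and example sizes): the conv kernel's parameter shape is independent of the input height and width, but the flattened size fed to a following Dense layer is not.

```python
def conv_out_dim(in_dim, kernel, stride=1):
    """Output spatial size of a 'VALID'-padded convolution."""
    return (in_dim - kernel) // stride + 1

# Conv kernel parameters depend only on (kernel_h, kernel_w, in_ch, out_ch):
kernel_shape = (3, 3, 1, 32)   # independent of input height/width and batch

# ...but a following Dense layer needs the flattened conv output size,
# which DOES depend on the input height/width:
h = conv_out_dim(28, 3)        # 26
w = conv_out_dim(28, 3)        # 26
flat_size = h * w * 32         # 21632 inputs to the Dense layer
print(flat_size)
```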