- PyTorch Distributed Data Parallel (DDP) based training harness for training the network (post-pruning) as fast as possible.
- FFCV integration for super-fast training on ImageNet (1:09 mins/epoch on 4xA100 GPUs with ResNet18).
- Support for most (if not all) torchvision models; timm models are covered to a limited extent and only lightly tested.
- Multiple pruning techniques, listed below.
- Simple harness, with hydra -- easily extensible.
- Logging to CSV and wandb (nothing fancy, but you can easily swap in wandb/comet/your own system -- see the sketch below).
Another aim was to make the output easy to read, so I put decent effort into console logging via rich :D
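For reference, here is a minimal sketch of how you could hook in your own logger. The function and metric names below are made up for illustration and are not the harness's actual API (the real logging lives in utils/harness_utils.py):

```python
# Hypothetical logging hook: mirror metrics to a CSV file and to wandb.
import csv
import wandb

def log_metrics(metrics: dict, step: int, csv_path: str = "metrics.csv") -> None:
    """Append one row of metrics to a CSV file and mirror it to wandb."""
    with open(csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["step", *metrics.keys()])
        if f.tell() == 0:          # write the header only for a fresh file
            writer.writeheader()
        writer.writerow({"step": step, **metrics})
    wandb.log(metrics, step=step)  # assumes wandb.init(...) was called earlier

# Usage (names are illustrative):
# wandb.init(project="turboprune")
# log_metrics({"train/loss": 1.23, "val/acc": 0.87}, step=0)
```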
The numbers below were obtained on a cluster with a similar computational configuration -- the only variations were the dataloading method and whether AMP was enabled (noted where applicable); the GPU model was NVIDIA A100 (40GB). The model was ResNet50 and the effective batch size in each case was 512.
- CIFAR10
- CIFAR100
- ImageNet
As it stands, ResNets and VGG variants should work out of the box. If you run into issues with any other variant, I'm happy to look into it. For CIFAR-based datasets, there are modifications to the basic architecture based on tuning and references such as this repository.
There is additional support for Vision Transformers via timm; however, as of this commit, this is limited and has been tested only with DeiT.
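To sanity-check that a timm ViT variant even instantiates before wiring it into the harness, something like the following works (the model name and class count are just examples):

```python
import timm
import torch

# Instantiate a DeiT variant from timm -- the model name and num_classes
# below are illustrative; pass whatever your config specifies.
model = timm.create_model("deit_small_patch16_224", pretrained=False, num_classes=10)
dummy = torch.randn(1, 3, 224, 224)
print(model(dummy).shape)  # torch.Size([1, 10])
```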
| Name | Type of Pruning | Paper |
| --- | --- | --- |
| Iterative Magnitude Pruning (IMP) | Iterative | The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks |
| IMP with Weight Rewinding (IMP + WR) | Iterative | Stabilizing the Lottery Ticket Hypothesis |
| IMP with Learning Rate Rewinding (IMP + LRR) | Iterative | Comparing Rewinding and Fine-tuning in Neural Network Pruning |
| SNIP | Pruning at Initialization (PaI), One-shot | SNIP: Single-shot Network Pruning based on Connection Sensitivity |
| SynFlow | Pruning at Initialization (PaI), One-shot | Pruning Neural Networks Without Any Data by Iteratively Conserving Synaptic Flow |
| Random Balanced/ERK Pruning | Pruning at Initialization (PaI), One-shot + Iterative | Why Random Pruning Is All We Need to Start Sparse |
| Random Pruning | Iterative | The Unreasonable Effectiveness of Random Pruning: Return of the Most Naive Baseline for Sparse Training |
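For intuition on the magnitude criterion that IMP iterates on, here is a rough sketch of a single global magnitude-pruning step using PyTorch's built-in pruning utilities. This is illustrative only and not the implementation in utils/pruning_utils.py:

```python
import torch
import torch.nn.utils.prune as prune
from torchvision.models import resnet18

# Illustration: one global magnitude-pruning step (the criterion IMP iterates on),
# using torch's built-in pruning rather than this repo's pruning functions.
model = resnet18(num_classes=10)
parameters_to_prune = [
    (module, "weight")
    for module in model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]

# Remove the 20% of weights with the smallest magnitude, globally across layers.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# Sparsity check: fraction of zeroed weights after pruning.
zeros = sum(int(torch.sum(m.weight == 0)) for m, _ in parameters_to_prune)
total = sum(m.weight.nelement() for m, _ in parameters_to_prune)
print(f"global sparsity: {zeros / total:.2%}")
```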
- run_experiment.py: This is the main script for running pruning experiments. It uses the PruningHarness, which subclasses BaseHarness and supports all training configurations currently available in this repository. If you'd like to modify how experiments are run, I'd recommend starting here.
- harness_definitions/base_harness.py: Base training harness for running experiments; it can be re-used for non-pruning experiments as well, if that's relevant to you and you want the flexibility of modifying the forward pass and other components.
- utils/harness_params.py: I realized a hydra-based config system is more flexible, so all experiment parameters are now specified via hydra and are easily extensible via dataclasses (see the sketch after this list).
- utils/harness_utils.py: This contains a lot of functions used for making the code run, logging metrics and other misc stuff. Let me know if you know how to cut it down :)
- utils/custom_models.py: Model wrapper with additional functionalities that make your pruning experiments easier.
- utils/dataset.py: definitions for CIFAR10/CIFAR100 and ImageNet with FFCV; WebDatasets support is a WIP.
- utils/schedulers.py: learning rate schedulers, for when you need to use them.
- utils/pruning_utils.py: Pruning functions plus a simple helper to apply them.
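As a rough illustration of how hydra + dataclasses fit together (the real dataclasses live in utils/harness_params.py; the field names below are hypothetical, apart from dataset_params.data_root_dir which shows up in the commands later):

```python
# Minimal sketch of a hydra + dataclass config setup; not the repo's actual schema.
from dataclasses import dataclass, field

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf

@dataclass
class DatasetParams:
    dataset_name: str = "CIFAR10"
    data_root_dir: str = "./data"

@dataclass
class ExperimentConfig:
    dataset_params: DatasetParams = field(default_factory=DatasetParams)
    seed: int = 0

# Register the structured config so hydra can resolve it by name.
cs = ConfigStore.instance()
cs.store(name="base_config", node=ExperimentConfig)

@hydra.main(version_base=None, config_path=None, config_name="base_config")
def main(cfg: ExperimentConfig) -> None:
    # Any field can be overridden from the CLI, e.g.
    #   python this_script.py dataset_params.data_root_dir=/datasets
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()
```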
Where necessary, pruning will use a single GPU/dataset in the chosen training precision.
- To run ImageNet experiments, you obviously need ImageNet downloaded; in addition, since we use FFCV, you will need to generate .beton files as per the instructions here (a small writer sketch follows this list).
- CIFAR10, CIFAR100 and related datasets are handled using cifar10-airbench, and no change is required by the user. You do not need distributed training here, as it's faster on a single GPU (lol) -- so there is no support for distributed training with these datasets via airbench. But if you really want it, you can modify the harness and train loop to use the standard PyTorch loaders.
- Have a look at harness_params and the config structure to understand how to configure experiments. It's worth it.
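For reference, writing a .beton looks roughly like the following. The paths and field settings below are illustrative; follow the official FFCV ImageNet instructions for the exact settings used with this repo:

```python
# Rough sketch of writing an FFCV .beton from an ImageFolder-style dataset.
from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField
from torchvision.datasets import ImageFolder

dataset = ImageFolder("/path/to/imagenet/train")  # yields (PIL image, label)

writer = DatasetWriter(
    "/path/to/imagenet_train.beton",
    {
        "image": RGBImageField(max_resolution=256, jpeg_quality=90),
        "label": IntField(),
    },
    num_workers=16,
)
writer.from_indexed_dataset(dataset)
```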
Now to the fun part:
To start an experiment, ensure there is appropriate (sufficient) compute available (or it might take a while -- it's going to anyway) and, in the case of ImageNet, that the appropriate .beton files are available.
pip install -r requirements.txt
python run_experiment.py --config-name=cifar10_er_erk dataset_params.data_root_dir=<PATH_TO_FOLDER>
For DDP (ImageNet only):
torchrun --nproc_per_node=<num_gpus> run_experiment.py --config-name=imagenet_er_erk dataset_params.data_root_dir=<PATH_TO_FOLDER>
and it should start.
This part is a bit detailed and documentation is coming soon -- if you need any help in the meantime, open an issue or reach out.
The configs provided in conf/ are tuned baselines; if you find a better configuration, please feel free to open a pull request.
All baselines are coming soon!
If you use this code in your research, or find it useful in general, please consider citing:
@software{Nelaturu_TurboPrune_High-Speed_Distributed,
  author  = {Nelaturu, Sree Harsha and Gadhikar, Advait and Burkholz, Rebekka},
  license = {Apache-2.0},
  title   = {{TurboPrune: High-Speed Distributed Lottery Ticket Training}},
  url     = {https://github.com/nelaturuharsha/TurboPrune}
}
- This code was built with reference to the substantial hard work put in by Advait Gadhikar.
- Thank you to Dr. Rebekka Burkholz for the opportunity to build this :)
- I was heavily influenced by the code style here. Just a general thanks and shout-out to the FFCV team for all they've done!
- All credit/references for the original methods and reference implementations are due to the original authors of the work :)
- Thank you Andrej, Bhavnick, Akanksha for feedback :)