Harness for training/finding lottery tickets in PyTorch, with support for multiple pruning techniques, distributed training, FFCV, and AMP.

TurboPrune: High-Speed Distributed Lottery Ticket Training


  • PyTorch Distributed Data Parallel (DDP) based harness for training the network (post-pruning) as fast as possible.
  • FFCV integration for super-fast training on ImageNet (1:09 mins/epoch on 4xA100 GPUs with ResNet18).
  • Support for most (if not all) torchvision models, with limited testing of timm coverage.
  • Multiple pruning techniques, listed below.
  • Simple, easily extensible harness configured with hydra.
  • Logging to CSV and wandb (nothing fancy, but you can integrate wandb/comet/your own system easily -- see the sketch below).

Another aim was to make the code easy to read through, and I put decent effort into console logging via rich :D
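
For a feel of how minimal the logging hook is, here is a rough sketch using wandb -- the project name and metric keys below are placeholders, not the harness's actual keys; a comet or CSV logger slots in the same way:

```python
import wandb

# Placeholder project/metric names -- this only shows the shape of the integration.
run = wandb.init(project="turboprune-example", config={"dataset": "cifar10"})
for epoch in range(3):
    run.log({"epoch": epoch, "train_loss": 0.0, "val_acc": 0.0})  # dummy values
run.finish()
```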

Timing Comparison

The numbers below were obtained on a cluster with the same computational configuration -- the only variations were the dataloading method and AMP (enabled where specified). The GPU model used was the NVIDIA A100 (40GB).

The model used was ResNet50 and the effective batch size in each case was 512.

[Figure: per-epoch timing comparison across dataloading methods, with and without AMP]

Datasets Supported

  1. CIFAR10
  2. CIFAR100
  3. ImageNet

Networks supported

As it stands, ResNet and VGG variants should work out of the box. If you run into issues with any other variant, I'm happy to look into it. For CIFAR-based datasets, there are modifications to the base architectures based on tuning and references such as this repository.
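
For illustration, the usual CIFAR-style tweak to a torchvision ResNet looks roughly like the sketch below (3x3 stem, no max-pool); the exact modifications in this repository may differ:

```python
import torch.nn as nn
from torchvision.models import resnet18

def cifar_resnet18(num_classes: int = 10) -> nn.Module:
    # Standard CIFAR adaptation: replace the 7x7/stride-2 stem with a 3x3/stride-1
    # conv and drop the initial max-pool so 32x32 inputs aren't downsampled too early.
    model = resnet18(num_classes=num_classes)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    return model
```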

There is additional support for Vision Transformers via timm; however, as of this commit, this is limited and has been tested only with DeiT.
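
Instantiating a DeiT through timm is a one-liner; the variant below is just an example, since the README does not pin down which DeiT size was tested:

```python
import timm

# "deit_small_patch16_224" is one of the DeiT variants registered in timm;
# any other registered model name can be substituted.
model = timm.create_model("deit_small_patch16_224", pretrained=False, num_classes=100)
```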

Pruning Algorithms included:

Repository structure:

  1. run_experiment.py: the main script for running pruning experiments. It uses the PruningHarness, which subclasses BaseHarness and supports all configurations currently possible in this repository. If you would like to modify how experiments are run, I'd recommend starting here.
  2. harness_definitions/base_harness.py: base training harness for running experiments; it can be re-used for non-pruning experiments as well, if you think it's relevant and want the flexibility of modifying the forward pass and other components.
  3. utils/harness_params.py: I realized a hydra-based config system is more flexible, so all experiment parameters are now specified via hydra and easily extensible via dataclasses (see the sketch after this list).
  4. utils/harness_utils.py: this contains a lot of functions used for making the code run, logging metrics and other misc stuff. Let me know if you know how to cut it down :)
  5. utils/custom_models.py: model wrapper with additional functionality that makes pruning experiments easier.
  6. utils/dataset.py: dataset definitions for CIFAR10/CIFAR100 and ImageNet with FFCV; WebDataset support is a WIP.
  7. utils/schedulers.py: learning rate schedulers, for when you need to use them.
  8. utils/pruning_utils.py: pruning functions + a simple utility to apply them.
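
As a rough sketch of the dataclass-backed config idea (field names below are hypothetical except data_root_dir, which appears in the run commands later; the real definitions live in utils/harness_params.py):

```python
from dataclasses import dataclass, field
from hydra.core.config_store import ConfigStore

@dataclass
class DatasetParams:
    dataset_name: str = "cifar10"   # hypothetical field
    data_root_dir: str = "???"      # missing by default, overridden on the command line

@dataclass
class ExperimentConfig:
    dataset_params: DatasetParams = field(default_factory=DatasetParams)

# Registering the schema lets hydra validate YAML configs against it.
cs = ConfigStore.instance()
cs.store(name="base_config", node=ExperimentConfig)
```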

Where necessary, pruning will use a single GPU/dataset in the chosen training precision.
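
The pruning functions themselves are repo-specific, but for a feel of what layer-wise magnitude pruning looks like, here is a sketch using PyTorch's built-in torch.nn.utils.prune (not the API in utils/pruning_utils.py):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

def magnitude_prune(model: nn.Module, amount: float = 0.2) -> None:
    # Zero out the `amount` fraction of smallest-magnitude weights in every
    # conv/linear layer; prune.remove() afterwards would bake the mask in.
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.l1_unstructured(module, name="weight", amount=amount)
```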

Important Pre-requisites

  • To run ImageNet experiments, you obviously need ImageNet downloaded -- in addition, since we use FFCV, you will need to generate .beton files as per the instructions here (a sketch follows after this list).
  • CIFAR10, CIFAR100 and related datasets are handled using cifar10-airbench, and no change is required by the user. You do not need distributed training, as it's faster on a single GPU (lol) -- so there is no support for distributed training with these datasets via airbench. But if you really want to, you can modify the harness and train loop and use the standard PyTorch loaders.
  • Have a look at harness_params and the config structure to understand how to configure experiments. It's worth it.
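
Generating the .beton files follows the FFCV docs; roughly like the sketch below -- paths and field settings are placeholders, the linked instructions are authoritative:

```python
import torchvision
from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField

# Any indexed dataset returning (PIL image, int label) works here.
dataset = torchvision.datasets.ImageFolder("/path/to/imagenet/train")

writer = DatasetWriter(
    "/path/to/imagenet_train.beton",
    {"image": RGBImageField(max_resolution=512, jpeg_quality=90), "label": IntField()},
)
writer.from_indexed_dataset(dataset)
```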

Usage

Now to the fun part:

Running an Experiment

To start an experiment, ensure there is appropriate (sufficient) compute available (or it might take a while -- it's going to anyway) and, in the case of ImageNet, that the appropriate .beton files are available.

pip install -r requirements.txt
python run_experiment.py --config-name=cifar10_er_erk dataset_params.data_root_dir=<PATH_TO_FOLDER>

For DDP (Only ImageNet)

torchrun --nproc_per_node=<num_gpus> run_experiment.py --config-name=imagenet_er_erk dataset_params.data_root_dir=<PATH_TO_FOLDER>

and it should start.

Hydra Configuration

This is a bit detailed and a fuller write-up is coming soon -- if you need any help, open an issue or reach out.
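
In the meantime, the gist: each --config-name points at a config under conf/, and any field can be overridden on the command line (as in the run commands above). Here is a minimal sketch of loading the same config programmatically with Hydra's compose API -- the override key comes from the commands above, the rest is boilerplate:

```python
from hydra import compose, initialize

# config_path is relative to this file; version_base=None assumes Hydra >= 1.2.
with initialize(config_path="conf", version_base=None):
    cfg = compose(
        config_name="cifar10_er_erk",
        overrides=["dataset_params.data_root_dir=/path/to/data"],
    )
print(cfg.dataset_params.data_root_dir)
```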

Baselines

The configs provided in conf/ are for some tuned baselines, but if you find a better configuration -- please feel free to make a pull request.

ImageNet Baseline

CIFAR10 Baseline

CIFAR100 Baseline

All baselines are coming soon!

If you use this code in your research and find it useful in general, please consider citing it using:

@software{Nelaturu_TurboPrune_High-Speed_Distributed,
  author  = {Nelaturu, Sree Harsha and Gadhikar, Advait and Burkholz, Rebekka},
  license = {Apache-2.0},
  title   = {{TurboPrune: High-Speed Distributed Lottery Ticket Training}},
  url     = {https://github.com/nelaturuharsha/TurboPrune}
}

Footnotes and Acknowledgments:

  • This code is built using references to the substantial hard work put in by Advait Gadhikar.
  • Thank you to Dr. Rebekka Burkholz for the opportunity to build this :)
  • I was heavily influenced by the code style here. Just a general thanks and shout-out to the FFCV team for all they've done!
  • All credit/references for the original methods and reference implementations are due to the original authors of the work :)
  • Thank you Andrej, Bhavnick, Akanksha for feedback :)
