Getting Started

Installation

  1. Create an environment. I recommend installing PyTorch and K-FAC into a virtual environment of your choice to avoid global package installations. E.g.,
    $ python -m venv venv
    $ . venv/bin/activate
  2. Install PyTorch: https://pytorch.org/get-started/locally/
  3. Clone and install KFAC.
    $ git clone https://github.com/gpauloski/kfac-pytorch.git
    $ cd kfac-pytorch
    $ pip install .

Update your Training Script

K-FAC is designed to be a drop-in addition to your existing training scripts. You must 1) import the KFACPreconditioner, 2) initialize the preconditioner with your model, and 3) call the preconditioner before each optimizer step.

import torch
from torch import optim

from kfac.preconditioner import KFACPreconditioner

model = torch.nn.parallel.DistributedDataParallel(...)
optimizer = optim.SGD(model.parameters(), ...)

# Initialize KFAC
preconditioner = KFACPreconditioner(model, ...)

for data, target in train_loader:
    optimizer.zero_grad()
    output = model(data)

    loss = criterion(output, target)
    loss.backward()

    # Perform preconditioning before each optimizer step
    preconditioner.step()
    optimizer.step()

During initialization, your model is scanned for any PyTorch modules that can be preconditioned with K-FAC (e.g., Conv2d and Linear modules), and those modules are registered with the preconditioner. When a module is registered, hooks are attached to its forward and backward passes that save the intermediate data needed for the K-FAC computation.

After all layers are registered, the distributed strategy will be applied to determine where various K-FAC computations should take place. The KFACPreconditioner implements the adaptable distributed strategy referred to as KAISA, and the BaseKFACPreconditioner can be extended to implement custom strategies (see Custom Preconditioners).
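
As an illustration (a hedged sketch rather than a statement of exact library behavior), consider a small model in which only the supported layer types mentioned above, Conv2d and Linear, would be registered:

import torch

from kfac.preconditioner import KFACPreconditioner

# Hypothetical model: in this sketch, only the Conv2d and Linear layers are
# assumed to be registered by the preconditioner; ReLU has no parameters to
# precondition and BatchNorm2d is assumed to be unsupported, so both are skipped.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3),  # registered
    torch.nn.BatchNorm2d(16),               # assumed skipped
    torch.nn.ReLU(),                        # skipped (no parameters)
    torch.nn.Flatten(),
    torch.nn.Linear(16 * 30 * 30, 10),      # registered (e.g., for 3x32x32 inputs)
)

# Constructing the preconditioner scans the model and attaches the
# forward/backward hooks described above.
preconditioner = KFACPreconditioner(model)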

Now you should be ready to start training!

Additional Features

K-FAC has additional features that you may want to enable before starting training.

Resuming from Checkpoints

The KFACPreconditioner supports saving the preconditioning state just like PyTorch optimizers.

model = ...
optimizer = ...
preconditioner = KFACPreconditioner(model, ...)

# Save state to a file
state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'preconditioner': preconditioner.state_dict(),
}
torch.save(state, filepath)

# Restore state from a file
checkpoint = torch.load(filepath)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
preconditioner.load_state_dict(checkpoint['preconditioner'])

The preconditioner state includes the hyperparameters and the accumulated factors A and G for all layers. If you do not want to save the state of the factors, use KFACPreconditioner.state_dict(include_factors=False). When a state is loaded, KFACPreconditioner.load_state_dict will recompute the inverses/eigen decompositions of the factors.
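
For example, a smaller checkpoint can be written by excluding the factors (a sketch reusing the variables from the example above):

state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    # Exclude the accumulated A and G factors to reduce checkpoint size;
    # presumably the factors are then re-accumulated after training resumes.
    'preconditioner': preconditioner.state_dict(include_factors=False),
}
torch.save(state, filepath)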

Hyperparameter Schedulers

Just as PyTorch optimizers use LR schedulers to implement hyperparameter schedules, K-FAC hyperparameters can be adjusted throughout training in two ways.

Option 1: Each of the preconditioner hyperparameters can take either a constant value or a callable of the form Callable[[int], Any] that takes the K-FAC step number as input and returns a new value for the hyperparameter.

KFACPreconditioner(
    model,
    # KFAC will call the external function get_update_steps each iteration
    # to determine the update frequency
    factor_update_steps=get_update_steps,
    # Pass a custom lambda to compute a hyperparameter
    # (max(x, 1) guards against dividing by zero at step 0)
    damping=lambda x: 0.01 / max(x, 1),
    # Use a lambda to get a hyperparameter from elsewhere
    lr=lambda x: optimizer.param_groups[0]['lr'],
)
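
The get_update_steps function referenced in the example above is user-defined, not part of K-FAC; it simply maps the K-FAC step number to an update frequency. A hypothetical sketch:

def get_update_steps(step: int) -> int:
    """Hypothetical schedule: update the factors less frequently as training progresses."""
    if step < 1000:
        return 10
    if step < 5000:
        return 50
    return 100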

Option 2: Use the LambdaParamScheduler like a PyTorch LR scheduler.

from kfac.scheduler import LambdaParamScheduler

preconditioner = KFACPreconditioner(model, ...)
scheduler = LambdaParamScheduler(
    preconditioner,
    lr=lambda x: optimizer.param_groups[0]['lr'],
    ...
)

for data, target in train_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    preconditioner.step()
    optimizer.step()
    scheduler.step()

Mixed Precision

K-FAC supports mixed precision training with torch.cuda.amp and torch.nn.parallel.DistributedDataParallel. Factors for a given layer are stored in the data type used during the forward/backward pass. For most PyTorch layers, this will be float16; however, there are exceptions, such as layers that take integer inputs (often embeddings). Inverses of the factors are stored in float32 by default, but this can be overridden to float16 to save memory at the cost of potential numerical instability. The GradScaler object can be passed to K-FAC so that K-FAC can appropriately unscale the backward-pass data. If a GradScaler is provided, G factors are cast to float32 to prevent underflow when unscaling the gradients.

When using torch.cuda.amp for mixed precision training, be sure to call KFACPreconditioner.step() outside of an autocast() region. For example:

from torch.cuda.amp import GradScaler, autocast

from kfac.preconditioner import KFACPreconditioner

model = ...
optimizer = ...
scaler = GradScaler()
preconditioner = KFACPreconditioner(model, grad_scaler=scaler)

for i, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()

    # Unscale gradient before preconditioning
    scaler.unscale_(optimizer)
    preconditioner.step()
    scaler.step(optimizer)
    scaler.update()

See the PyTorch mixed-precision docs for more information.

K-FAC does not support NVIDIA AMP (Apex) because some operations used in K-FAC (torch.inverse and torch.symeig) do not support half-precision inputs, and NVIDIA AMP does not provide a way to disable autocasting in specific code regions.

Gradient Accumulation

If training with gradient accumulation, set the accumulation_steps parameter of the KFACPreconditioner accordingly. This ensures that the forward and backward pass hooks on each registered module correctly adjust how frequently they capture intermediate data, based on the factor_update_steps and accumulation_steps values.
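
Below is a minimal sketch of such a loop, assuming accumulation over a hypothetical accumulation_steps micro-batches and the usual convention of averaging the loss across the accumulated batches; model, criterion, and train_loader are placeholders as in the earlier examples.

from kfac.preconditioner import KFACPreconditioner

accumulation_steps = 4  # hypothetical number of micro-batches per optimizer step

model = ...
optimizer = ...
preconditioner = KFACPreconditioner(model, accumulation_steps=accumulation_steps)

for i, (data, target) in enumerate(train_loader):
    output = model(data)
    # Scale the loss so the accumulated gradient matches a full-batch gradient
    loss = criterion(output, target) / accumulation_steps
    loss.backward()

    # Precondition and step the optimizer only once per accumulation window
    if (i + 1) % accumulation_steps == 0:
        preconditioner.step()
        optimizer.step()
        optimizer.zero_grad()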