Skip to content

Commit

Permalink
support for SB3 callbacks in adversarial training
Browse files Browse the repository at this point in the history
  • Loading branch information
smanolloff committed Sep 14, 2023
1 parent cb93fb0 commit eb8f67e
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 16 deletions.
33 changes: 21 additions & 12 deletions src/imitation/algorithms/adversarial/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import abc
import dataclasses
import logging
from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload
from typing import Iterable, Iterator, List, Mapping, Optional, Type, overload

import numpy as np
import torch as th
import torch.utils.tensorboard as thboard
import tqdm
from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
from stable_baselines3.common.type_aliases import MaybeCallback
from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback
from stable_baselines3.sac import policies as sac_policies
from torch.nn import functional as F

Expand Down Expand Up @@ -386,6 +388,7 @@ def train_gen(
self,
total_timesteps: Optional[int] = None,
learn_kwargs: Optional[Mapping] = None,
callback: MaybeCallback = None,
) -> None:
"""Trains the generator to maximize the discriminator loss.
Expand All @@ -398,17 +401,27 @@ def train_gen(
`self.gen_train_timesteps`.
learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
method.
callback: additional callback(s) passed to the generator's `learn` method.
"""
if total_timesteps is None:
total_timesteps = self.gen_train_timesteps
if learn_kwargs is None:
learn_kwargs = {}

callbacks = [self.gen_callback]

if isinstance(callback, list):
callbacks.extend(callback)
elif isinstance(callback, BaseCallback):
callbacks.append(callback)
elif callback is not None:
callbacks.append(ConvertCallback(callback))

with self.logger.accumulate_means("gen"):
self.gen_algo.learn(
total_timesteps=total_timesteps,
reset_num_timesteps=False,
callback=self.gen_callback,
callback=callbacks,
**learn_kwargs,
)
self._global_step += 1
Expand All @@ -421,37 +434,33 @@ def train_gen(
def train(
self,
total_timesteps: int,
callback: Optional[Callable[[int], None]] = None,
callback: MaybeCallback = None,
) -> None:
"""Alternates between training the generator and discriminator.
Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
a call to `train_disc`, and finally a call to `callback(round)`.
Every "round" consists of a call to
`train_gen(self.gen_train_timesteps, callback)`, then a call to `train_disc`.
Training ends once an additional "round" would cause the number of transitions
sampled from the environment to exceed `total_timesteps`.
Args:
total_timesteps: An upper bound on the number of transitions to sample
from the environment during training.
callback: A function called at the end of every round which takes in a
single argument, the round number. Round numbers are in
`range(total_timesteps // self.gen_train_timesteps)`.
callback: callback(s) passed to the generator's `learn` method.
"""
n_rounds = total_timesteps // self.gen_train_timesteps
assert n_rounds >= 1, (
"No updates (need at least "
f"{self.gen_train_timesteps} timesteps, have only "
f"total_timesteps={total_timesteps})!"
)
for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
self.train_gen(self.gen_train_timesteps)
for _r in tqdm.tqdm(range(0, n_rounds), desc="round"):
self.train_gen(self.gen_train_timesteps, callback=callback)
for _ in range(self.n_disc_updates_per_round):
with networks.training(self.reward_train):
# switch to training mode (affects dropout, normalization)
self.train_disc()
if callback:
callback(r)
self.logger.dump(self._global_step)

@overload
Expand Down
28 changes: 24 additions & 4 deletions src/imitation/scripts/train_adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sacred.commands
import torch as th
from sacred.observers import FileStorageObserver
from stable_baselines3.common.callbacks import BaseCallback

from imitation.algorithms.adversarial import airl as airl_algo
from imitation.algorithms.adversarial import common
Expand All @@ -22,6 +23,28 @@
logger = logging.getLogger("imitation.scripts.train_adversarial")


class CheckpointCallback(BaseCallback):
    """SB3 callback that periodically checkpoints an adversarial trainer.

    The round counter is advanced in `_on_training_end`, so one "round" here
    corresponds to one completed `learn()` call of the algorithm this callback
    is attached to.
    """

    def __init__(
        self,
        trainer: common.AdversarialTrainer,
        log_dir: pathlib.Path,
        interval: int,
    ):
        """Builds the callback.

        Args:
            trainer: the trainer handed to `save` whenever a checkpoint is due.
            log_dir: base directory; checkpoints are written to
                `log_dir / "checkpoints" / <zero-padded round number>`.
            interval: save every `interval` rounds. Values `<= 0` disable
                checkpointing.
        """
        # NOTE: `BaseCallback.__init__` takes only `verbose`; passing `self`
        # here would store the callback object itself as the verbosity level.
        super().__init__()
        self.trainer = trainer
        self.log_dir = log_dir
        self.interval = interval
        self.round_num = 0

    def _on_step(self) -> bool:
        """Never interrupts training; this callback only acts between rounds."""
        return True

    def _on_training_end(self) -> None:
        """Counts the finished round and saves a checkpoint when one is due."""
        self.round_num += 1
        if self.interval > 0 and self.round_num % self.interval == 0:
            save(self.trainer, self.log_dir / "checkpoints" / f"{self.round_num:05d}")


def save(trainer: common.AdversarialTrainer, save_path: pathlib.Path):
"""Save discriminator and generator."""
# We implement this here and not in Trainer since we do not want to actually
Expand Down Expand Up @@ -153,10 +176,7 @@ def train_adversarial(
**algorithm_kwargs,
)

def callback(round_num: int, /) -> None:
if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
save(trainer, log_dir / "checkpoints" / f"{round_num:05d}")

callback = CheckpointCallback(trainer, log_dir, checkpoint_interval)
trainer.train(total_timesteps, callback)
imit_stats = policy_evaluation.eval_policy(trainer.policy, trainer.venv_train)

Expand Down
37 changes: 37 additions & 0 deletions tests/algorithms/test_adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import stable_baselines3
import torch as th
from stable_baselines3.common import policies
from stable_baselines3.common.callbacks import BaseCallback
from torch.utils import data as th_data

from imitation.algorithms.adversarial import airl, common, gail
Expand Down Expand Up @@ -464,3 +465,39 @@ def test_regression_gail_with_sac(
reward_net=reward_net,
)
gail_trainer.train(8)


def test_gen_callback(trainer: common.AdversarialTrainer):
    """Checks that `train` accepts function, SB3 and list-of-SB3 callbacks.

    Each callback increments its own counter on every step it is invoked for;
    after training for `n_steps` timesteps every counter must have reached at
    least `n_steps`.
    """

    def make_fn_callback(calls, key):
        """Returns a functional `(locals, globals)` callback counting its calls."""

        def cb(_locals, _globals):
            calls[key] += 1
            # The return value of a functional callback is used as the
            # "continue training" flag once it is wrapped for SB3, so report
            # success explicitly instead of implicitly returning `None`.
            return True

        return cb

    class SB3Callback(BaseCallback):
        """Minimal SB3 callback that counts how often `_on_step` is invoked."""

        def __init__(self, calls, key):
            # `BaseCallback.__init__` takes only `verbose`; do not pass `self`.
            super().__init__()
            self.calls = calls
            self.key = key

        def _on_step(self):
            self.calls[self.key] += 1
            return True

    n_steps = trainer.gen_train_timesteps * 2
    calls = {"fn": 0, "sb3": 0, "list.0": 0, "list.1": 0}

    trainer.train(n_steps, callback=make_fn_callback(calls, "fn"))
    trainer.train(n_steps, callback=SB3Callback(calls, "sb3"))
    trainer.train(
        n_steps,
        callback=[
            SB3Callback(calls, "list.0"),
            SB3Callback(calls, "list.1"),
        ],
    )

    # Env steps for off-policy algos (DQN) may exceed `total_timesteps`,
    # so we check if the callback was called *at least* that many times.
    assert calls["fn"] >= n_steps
    assert calls["sb3"] >= n_steps
    assert calls["list.0"] >= n_steps
    assert calls["list.1"] >= n_steps

0 comments on commit eb8f67e

Please sign in to comment.