Skip to content

Commit

Permalink
Merge pull request #6 from benoitmartin88/rc
Browse files Browse the repository at this point in the history
release 0.2.0
  • Loading branch information
benoitmartin88 authored Sep 20, 2019
2 parents 9277077 + aa7a15b commit 2a4d0a3
Show file tree
Hide file tree
Showing 14 changed files with 236 additions and 94 deletions.
35 changes: 35 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
Changelog
=========
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

# Unreleased


# # [0.2.0] - 2019-09-20
## New
- Add `ModuleTrainer.evaluate` method
- Add CsvWriter to evaluate method
- Add `filename_transform_function` argument to `SaveBestCheckpointCallback`
- Metric step method now return the intermediate computed values
- Rename `CsvWriter`'s `extra` argument to `extra_data_function`
- `CsvWriter` can now be called with the `extra_data` extra argument to define the extra data that will be logged
- Add `ModuleTrainer.load` method

## Change
- Rename `dateset_loader` to `dataloader`


# [0.1.0] - 2019-09-16
## New
- `ModuleTrainer` object
- `EarlyStopping`: stop training after a configurable period of stagnation
- Checkpointing: save model and estimator at regular intervals
- CSV file writer to output logs
- Several metrics are available: all default PyTorch loss functions, Accuracy, MAE
- Progress bar from console
- SIGINT handling: handle CTRL-C
- Model's data type (float32, float64)
- Full use of Pytorch's Cuda support
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ trainer.train(train_loader, max_epochs=100)
## Dependencies

- python > 3.5
- pytorch 1.0.1 (install instructions from the official [PyTorch website](https://pytorch.org/get-started/locally))
- pytorch > 1.0.0 (install instructions from the official [PyTorch website](https://pytorch.org/get-started/locally))


## Contributing
Expand Down
2 changes: 1 addition & 1 deletion pytorchtrainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

__version__ = '0.1.0'
__version__ = '0.2.0'


from .trainer import create_default_trainer, ModuleTrainer, State
Expand Down
18 changes: 12 additions & 6 deletions pytorchtrainer/callback/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,31 +68,37 @@ def _load_checkpoint(self, model: nn.Module, optimizer: optim.Optimizer, state):

class SaveBestCheckpointCallback(SaveCheckpointCallback):
def __init__(self, state_metric_name: str, saves_to_keep=5, comparison_function=lambda metric, best: metric < best,
save_directory=default_save_diretory, filename=default_best_filename):
save_directory=default_save_diretory, filename=default_best_filename,
filename_transform_function=None):
super().__init__(save_directory, filename)
self.state_metric_name = state_metric_name
# self.saves_to_keep = saves_to_keep # TODO
self.comparison_function = comparison_function
self.current_best = None

if filename_transform_function is None:
self.filename_transform_function = self._default_filename_transform_function

os.makedirs(self.save_directory, exist_ok=True)

def __call__(self, trainer):
"""
best.pt.tar -> best_METRIC_EPOCH_1.pt.tar
best.pt.tar -> best_EPOCH_METRIC_1.pt.tar
:param trainer:
:return:
"""
if self.current_best is None or self.comparison_function(trainer.state.last_train_loss, self.current_best):
self.current_best = trainer.state.get(self.state_metric_name)

old_filename = self.filename
self.filename = self._get_filename(trainer.state)
self.filename = self.filename_transform_function(self.filename, trainer.state)

self._save_checkpoint(trainer.model, trainer.optimizer, trainer.state)
self.filename = old_filename

def _get_filename(self, state):
c = self.filename.count('.')
base, *ext = self.filename.rsplit('.', c)
@staticmethod
def _default_filename_transform_function(filename, state):
c = filename.count('.')
base, *ext = filename.rsplit('.', c)
return base + "_%d_%.2f_%d." % (state.current_epoch, state.last_train_loss, 1) + '.'.join(ext)

33 changes: 21 additions & 12 deletions pytorchtrainer/callback/file_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

class CsvWriter(Callback):
def __init__(self, save_directory=default_save_directory, filename=default_filename, delimiter=';',
extra_header=None, callback=None):
extra_header=None, extra_data_function=None):
super().__init__()
file, ext = filename.rsplit('.', 1)
self.log_file_path = os.path.join(save_directory, file + '_' + time.strftime("%Y%m%d_%H%M%S") + '.' + ext)
self.delimiter = delimiter
self.callback = callback
self.extra_data_function = extra_data_function

os.makedirs(save_directory, exist_ok=True)

Expand All @@ -29,17 +29,26 @@ def __init__(self, save_directory=default_save_directory, filename=default_filen
writer = csv.writer(writer, delimiter=delimiter, quotechar='"', quoting=csv.QUOTE_MINIMAL)
writer.writerow(header)

def __call__(self, trainer):
self.__save(trainer.state)
def __call__(self, trainer, extra_data: list = None):
self.__save(trainer.state, extra_data)

def __save(self, trainer_state):
def __save(self, trainer_state, extra_data: list = None):
with open(self.log_file_path, mode='a') as writer:
writer = csv.writer(writer, delimiter=self.delimiter, quotechar='"', quoting=csv.QUOTE_MINIMAL)

extra = []
if self.callback is not None:
extra = self.callback(trainer_state)
if not isinstance(extra, list):
raise TypeError("callback should return a list.")

writer.writerow([time.time(), trainer_state.current_epoch+1, trainer_state.current_iteration+1, trainer_state.last_train_loss] + extra)
if extra_data is None:
extra_data = []
if self.extra_data_function is not None:
extra_data = self.extra_data_function(trainer_state)
if not isinstance(extra_data, list):
raise TypeError("callback should return a list.")

if len(extra_data) >= 1 and not isinstance(extra_data[0], list):
# extra_data -> [0, 42, 51]
writer.writerow([time.time(), trainer_state.current_epoch+1, trainer_state.current_iteration+1, trainer_state.last_train_loss] + extra_data)
else:
# extra_data -> [[], [], []]
for data in extra_data:
if not isinstance(data, list):
raise TypeError("callback should return a list.")
writer.writerow([time.time(), trainer_state.current_epoch + 1, trainer_state.current_iteration + 1, trainer_state.last_train_loss] + data)
31 changes: 3 additions & 28 deletions pytorchtrainer/callback/validation.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,14 @@
import torch
from .callback import Callback


class ValidationCallback(Callback):
def __init__(self, dataset_loader, metric, device=None, dtype=None, non_blocking=False):
def __init__(self, dataloader, metric, device=None, dtype=None, non_blocking=False):
super().__init__(state_attribute_name="last_validation_%s" % metric.name, state_attribute_default_value=metric.default_value)
self.dataset_loader = dataset_loader
self.dataloader = dataloader
self.metric = metric
self.device = device
self.dtype = dtype
self.non_blocking = non_blocking

def __call__(self, trainer):
setattr(trainer.state, self.state_attribute_name, self._validation_function(trainer.model, trainer.prepare_batch_function))

def _validation_function(self, model, prepare_batch_function):
model.eval()

device_to_use = self.device
models_device = next(model.parameters()).device

if self.device is None:
# use the model's device
device_to_use = models_device

model.to(device_to_use)

self.metric.reset()

with torch.no_grad():
for batch in self.dataset_loader:
x, y, model_args = prepare_batch_function(batch, device=device_to_use, dtype=self.dtype, non_blocking=self.non_blocking)
y_pred = model(x, **model_args)
self.metric.step(y, y_pred)

model.to(models_device) # this will be a no-op if the device has not changed
return self.metric.compute()

setattr(trainer.state, self.state_attribute_name, trainer.evaluate(self.dataloader, self.metric))
2 changes: 1 addition & 1 deletion pytorchtrainer/metric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def __init__(self, name: str, default_value=None):
self.name = name.replace(' ', '_')
self.default_value = default_value

def step(self, y: torch.Tensor, y_pred: torch.Tensor):
def step(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()

def compute(self):
Expand Down
1 change: 1 addition & 0 deletions pytorchtrainer/metric/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def step(self, y: torch.Tensor, y_pred: torch.Tensor):

self._total_correct += torch.sum(correct).item()
self._total += correct.size(dim=0) # dim 0 should be batch size
return torch.sum(correct).float() / correct.size(dim=0)

def compute(self):
if self._total == 0:
Expand Down
1 change: 1 addition & 0 deletions pytorchtrainer/metric/mean_absolute_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def step(self, y: torch.Tensor, y_pred: torch.Tensor):
absolute_errors = torch.abs(y - y_pred)
self._absolute_error_sum += torch.sum(absolute_errors).item()
self._total += y.size(dim=0) # dim 0 should be batch size
return torch.sum(absolute_errors)

def compute(self):
if self._total == 0:
Expand Down
4 changes: 3 additions & 1 deletion pytorchtrainer/metric/torch_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ def __init__(self, loss_function: torch.nn.modules.loss):
self._total = 0

def step(self, y: torch.Tensor, y_pred: torch.Tensor):
self._loss_sum += self.loss_function(y_pred, y).item()
loss = self.loss_function(y_pred, y)
self._loss_sum += loss.item()
self._total += 1
return loss

def compute(self):
if self._total == 0:
Expand Down
Loading

0 comments on commit 2a4d0a3

Please sign in to comment.