Skip to content

Commit

Permalink
Merge pull request #43 from narumiruna/ruff
Browse files Browse the repository at this point in the history
Ruff
  • Loading branch information
narumiruna authored Dec 21, 2023
2 parents bd0ad5d + 21aa5df commit f64af5c
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 1,196 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: poetry
- uses: chartboost/ruff-action@v1
- name: Install dependencies
run: |
poetry install
poetry run pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
- name: Lint
run: poetry run flake8 -v .
- name: Test
run: poetry run pytest -v -s --cov=. --cov-report=xml tests
- name: Upload coverage reports to Codecov
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ install:
poetry install

lint:
poetry run flake8 -v
poetry run ruff check .

test:
poetry run pytest -v -s --cov=. tests
Expand Down
1 change: 1 addition & 0 deletions configs/mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ job:
trainer:
name: MNISTTrainer
num_epochs: 20
num_classes: 10

dataset:
name: MNISTDataLoader
Expand Down
1,175 changes: 34 additions & 1,141 deletions poetry.lock

Large diffs are not rendered by default.

31 changes: 27 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@ torchmetrics = "^1.2.0"
tqdm = "^4.66.1"
loguru = "^0.7.2"
mlconfig = "^0.2.0"
mlflow = "^2.9.2"
mlflow-skinny = "^2.9.2"

[tool.poetry.group.dev.dependencies]
black = "^23.11.0"
flake8 = "^6.0.0"
isort = "^5.12.0"
pytest = "^7.3.1"
pytest-cov = "^4.1.0"
ruff = "^0.1.8"
toml = "^0.10.2"

[build-system]
Expand All @@ -28,3 +26,28 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
template = "template.cli:main"

[tool.ruff]
exclude = ["build"]
line-length = 120

[tool.ruff.lint]
select = [
"B", # flake8-bugbear
"C", # flake8-comprehensions
"E", # pycodestyle errors
"F", # pyflakes
"I", # isort
# "UP", # pyupgrade
"W", # pycodestyle warnings

]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401", "F403"]

[tool.ruff.isort]
force-single-line = true

[tool.pytest.ini_options]
filterwarnings = ["ignore::DeprecationWarning"]
17 changes: 0 additions & 17 deletions setup.cfg

This file was deleted.

67 changes: 39 additions & 28 deletions template/trainers/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import torch
import torch.nn.functional as F
from mlconfig import register
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchmetrics import MeanMetric
from tqdm import tqdm
Expand All @@ -13,48 +17,55 @@
@register
class MNISTTrainer(Trainer):
def __init__(
self, device, model, optimizer, scheduler, train_loader, test_loader, num_epochs
):
self,
device: torch.device,
model: Module,
optimizer: Optimizer,
scheduler: LRScheduler,
train_loader: DataLoader,
test_loader: DataLoader,
num_epochs: int,
num_classes: int,
) -> None:
self.device = device
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.train_loader = train_loader
self.test_loader = test_loader
self.num_epochs = num_epochs
self.num_classes = num_classes

self.epoch = 1
self.best_acc = 0
self.state = {"epoch": 1}

def fit(self):
for self.epoch in trange(self.epoch, self.num_epochs + 1):
def fit(self) -> None:
for epoch in trange(self.state["epoch"], self.num_epochs + 1):
train_loss, train_acc = self.train()
test_loss, test_acc = self.evaluate()
self.scheduler.step()

metrics = dict(
train_loss=train_loss,
train_acc=train_acc,
test_loss=test_loss,
test_acc=test_acc,
)
mlflow.log_metrics(metrics, step=self.epoch)

format_string = "Epoch: {}/{}, ".format(self.epoch, self.num_epochs)
format_string += "train loss: {:.4f}, train acc: {:.4f}, ".format(
train_loss, train_acc
)
format_string += "test loss: {:.4f}, test acc: {:.4f}, ".format(
test_loss, test_acc
)
metrics = {
"train_loss": train_loss,
"train_acc": train_acc,
"test_loss": test_loss,
"test_acc": test_acc,
}
mlflow.log_metrics(metrics, step=epoch)

format_string = "Epoch: {}/{}, ".format(epoch, self.num_epochs)
format_string += "train loss: {:.4f}, train acc: {:.4f}, ".format(train_loss, train_acc)
format_string += "test loss: {:.4f}, test acc: {:.4f}, ".format(test_loss, test_acc)
format_string += "best test acc: {:.4f}.".format(self.best_acc)
tqdm.write(format_string)

def train(self):
self.state["epoch"] = epoch

def train(self) -> None:
self.model.train()

loss_metric = MeanMetric()
acc_metric = Accuracy()
acc_metric = Accuracy(task="multiclass", num_classes=self.num_classes)

for x, y in tqdm(self.train_loader):
x = x.to(self.device)
Expand All @@ -73,11 +84,11 @@ def train(self):
return loss_metric.compute().item(), acc_metric.compute().item()

@torch.no_grad()
def evaluate(self):
def evaluate(self) -> None:
self.model.eval()

loss_metric = MeanMetric()
acc_metric = Accuracy()
acc_metric = Accuracy(task="multiclass", num_classes=self.num_classes)

for x, y in tqdm(self.test_loader):
x = x.to(self.device)
Expand All @@ -96,25 +107,25 @@ def evaluate(self):

return loss_metric.compute().item(), test_acc

def save_checkpoint(self, f):
def save_checkpoint(self, f) -> None:
self.model.eval()

checkpoint = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"scheduler": self.scheduler.state_dict(),
"epoch": self.epoch,
"state": self.state,
"best_acc": self.best_acc,
}

torch.save(checkpoint, f)
mlflow.log_artifact(f)

def resume(self, f):
def resume(self, f) -> None:
checkpoint = torch.load(f, map_location=self.device)

self.model.load_state_dict(checkpoint["model"])
self.optimizer.load_state_dict(checkpoint["optimizer"])
self.scheduler.load_state_dict(checkpoint["scheduler"])
self.epoch = checkpoint["epoch"] + 1
self.state = checkpoint["state"]
self.best_acc = checkpoint["best_acc"]
2 changes: 1 addition & 1 deletion template/trainers/trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
class Trainer:
def train(self):
def train(self) -> None:
raise NotImplementedError
3 changes: 1 addition & 2 deletions template/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
from pathlib import Path

import numpy as np
import torch
import yaml

from pathlib import Path


def manual_seed(seed=0):
"""https://pytorch.org/docs/stable/notes/randomness.html"""
Expand Down
1 change: 1 addition & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch

from template.models import LeNet


Expand Down

0 comments on commit f64af5c

Please sign in to comment.