diff --git a/micromind/core.py b/micromind/core.py
index 835b1ec..35b7f5a 100644
--- a/micromind/core.py
+++ b/micromind/core.py
@@ -5,6 +5,7 @@
Authors:
- Francesco Paissan, 2023
"""
+
from abc import ABC, abstractmethod
from argparse import Namespace
from dataclasses import dataclass
@@ -331,7 +332,7 @@ def add_forward_to_modules(self):
self.modules.device = self.device
@torch.no_grad()
- def compute_params(self):
+ def compute_params(self, which="total"):
"""Computes the number of parameters for the modules inside `self.modules`.
Returns a dictionary with the parameter count for each module.
@@ -341,8 +342,12 @@ def compute_params(self):
"""
self.eval()
params = {}
- for k, m in self.modules.items():
- params[k] = summary(m, verbose=0).total_params
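+ # "total" counts every parameter; "trainable" counts only those with requires_grad=True (torchinfo's trainable_params).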
+ if which == "total":
+ for k, m in self.modules.items():
+ params[k] = summary(m, verbose=0).total_params
+ elif which == "trainable":
+ for k, m in self.modules.items():
+ params[k] = summary(m, verbose=0).trainable_params
return params
@@ -451,6 +456,10 @@ def on_train_end(self):
"""Runs at the end of each training. Cleans up before exiting."""
pass
+ def on_train_epoch_end(self):
+ """Runs at the end of each training epoch. Cleans up before exiting."""
+ pass
+
def eval(self):
self.modules.eval()
@@ -460,6 +469,7 @@ def train(
datasets: Dict = {},
metrics: List[Metric] = [],
checkpointer: Optional[Checkpointer] = None,
+ max_norm: float = 10.0,
debug: Optional[bool] = False,
) -> None:
"""
@@ -525,12 +535,12 @@ def train(
loss_epoch += loss.item()
self.accelerator.backward(loss)
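+ # clip the global gradient norm (through Accelerate) to keep updates stable; the threshold is set by max_norm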
+ self.accelerator.clip_grad_norm_(
+ self.modules.parameters(), max_norm=max_norm
+ )
self.opt.step()
loss_epoch += loss.item()
- if hasattr(self, "lr_sched"):
- # ok for cos_lr
- self.lr_sched.step()
for m in self.metrics:
if (
@@ -563,21 +573,29 @@ def train(
if "val" in datasets:
val_metrics = self.validate()
- if (
- self.accelerator.is_local_main_process
- and self.checkpointer is not None
- ):
- self.checkpointer(
- self,
- train_metrics,
- val_metrics,
- )
else:
- val_metrics = train_metrics.update({"val_loss": loss_epoch / (idx + 1)})
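+ # dict.update() returns None, so update train_metrics in place and reuse it as val_metrics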
+ train_metrics.update({"val_loss": loss_epoch / (idx + 1)})
+ val_metrics = train_metrics
+
+ self.on_train_epoch_end()
+
+ if self.accelerator.is_local_main_process and self.checkpointer is not None:
+ self.checkpointer(
+ self,
+ train_metrics,
+ val_metrics,
+ )
if e >= 1 and self.debug:
break
+ if hasattr(self, "lr_sched"):
+ # step once per epoch; appropriate for the cosine schedule
+ self.lr_sched.step()
+ print(f"sched step - new LR={self.lr_sched.get_last_lr()}")
+
self.on_train_end()
return None
diff --git a/micromind/networks/yolo.py b/micromind/networks/yolo.py
index f259f65..fbb8c51 100644
--- a/micromind/networks/yolo.py
+++ b/micromind/networks/yolo.py
@@ -464,6 +464,10 @@ def __init__(
self.heads = heads
self.up1 = Upsample(up[0], mode="nearest")
self.up2 = Upsample(up[1], mode="nearest")
+
self.n1 = XiConv(
c_in=int(filters[1] + filters[2]),
c_out=int(filters[1]),
@@ -471,6 +475,7 @@ def __init__(
gamma=3,
skip_tensor_in=False,
)
+
self.n2 = XiConv(
int(filters[0] + filters[1]),
int(filters[0]),
@@ -483,6 +488,10 @@ def __init__(
the needed blocks. Otherwise the not needed blocks would be initialized
(and thus would occupy space) but will never be used.
"""
+ self.n3 = None
+ self.n4 = None
+ self.n5 = None
+ self.n6 = None
if self.heads[1] or self.heads[2]:
self.n3 = XiConv(
int(filters[0]),
@@ -519,6 +528,75 @@ def __init__(
)
+class Yolov8NeckOpt_gamma2(Yolov8Neck):
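+ """Variant of the optimized YOLOv8 neck above in which every XiConv uses
+ gamma=2 instead of gamma=3; the topology is otherwise unchanged."""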
+ def __init__(
+ self, filters=[256, 512, 768], up=[2, 2], heads=[True, True, True], d=1
+ ):
+ super().__init__()
+ self.heads = heads
+ self.up1 = Upsample(up[0], mode="nearest")
+ self.up2 = Upsample(up[1], mode="nearest")
+
+ self.n1 = XiConv(
+ c_in=int(filters[1] + filters[2]),
+ c_out=int(filters[1]),
+ kernel_size=3,
+ gamma=2,
+ skip_tensor_in=False,
+ )
+
+ self.n2 = XiConv(
+ int(filters[0] + filters[1]),
+ int(filters[0]),
+ kernel_size=3,
+ gamma=2,
+ skip_tensor_in=False,
+ )
+ """
+ The blocks needed by the 2nd and 3rd detection heads are defined only
+ when those heads are enabled. Otherwise the unused blocks would be
+ initialized (and thus occupy space) without ever being used.
+ """
+ self.n3 = None
+ self.n4 = None
+ self.n5 = None
+ self.n6 = None
+ if self.heads[1] or self.heads[2]:
+ self.n3 = XiConv(
+ int(filters[0]),
+ int(filters[0]),
+ kernel_size=3,
+ gamma=2,
+ stride=2,
+ padding=1,
+ skip_tensor_in=False,
+ )
+ self.n4 = XiConv(
+ int(filters[0] + filters[1]),
+ int(filters[1]),
+ kernel_size=3,
+ gamma=2,
+ skip_tensor_in=False,
+ )
+ if self.heads[2]:
+ self.n5 = XiConv(
+ int(filters[1]),
+ int(filters[1]),
+ gamma=2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ skip_tensor_in=False,
+ )
+ self.n6 = XiConv(
+ int(filters[1] + filters[2]),
+ int(filters[2]),
+ gamma=2,
+ kernel_size=3,
+ skip_tensor_in=False,
+ )
+
+
class DetectionHead(nn.Module):
"""Implements YOLOv8's detection head.
@@ -537,6 +615,7 @@ def __init__(self, nc=80, filters=(), heads=[True, True, True]):
super().__init__()
self.reg_max = 16
self.nc = nc
+ # filters = [f for f, h in zip(filters, heads) if h]
self.nl = len(filters)
self.no = nc + self.reg_max * 4
self.stride = torch.tensor([8.0, 16.0, 32.0], dtype=torch.float16)
@@ -615,14 +694,16 @@ class YOLOv8(nn.Module):
Number of classes to predict.
"""
- def __init__(self, w, r, d, num_classes=80):
+ def __init__(self, w, r, d, num_classes=80, heads=[True, True, True]):
super().__init__()
self.net = Darknet(w, r, d)
self.fpn = Yolov8Neck(
- filters=[int(256 * w), int(512 * w), int(512 * w * r)], d=d
+ filters=[int(256 * w), int(512 * w), int(512 * w * r)], heads=heads, d=d
)
self.head = DetectionHead(
- num_classes, filters=(int(256 * w), int(512 * w), int(512 * w * r))
+ num_classes,
+ filters=(int(256 * w), int(512 * w), int(512 * w * r)),
+ heads=heads,
)
def forward(self, x):
diff --git a/recipes/object_detection/README.md b/recipes/object_detection/README.md
index 8e20e9c..2a4ebd9 100644
--- a/recipes/object_detection/README.md
+++ b/recipes/object_detection/README.md
@@ -1,5 +1,6 @@
## Object Detection using YOLO
+**[16 Jan 2024]** Updated training code for better performance. Added ultralytics metrics calculation.
**[16 Jan 2024]** Added optimized YOLO neck, using XiConv. Fixed compatibility with ultralytics weights.
**[17 Dec 2023]** Add VOC dataset, selective head option, and instructions for dataset download.
**[1 Dec 2023]** Fix DDP handling and computational graph.
diff --git a/recipes/object_detection/cfg/data/VOC.yaml b/recipes/object_detection/cfg/data/VOC.yaml
index 3b6e115..5249510 100644
--- a/recipes/object_detection/cfg/data/VOC.yaml
+++ b/recipes/object_detection/cfg/data/VOC.yaml
@@ -43,7 +43,7 @@ mixup: 0.0 # (float) image mixup (probability)
copy_paste: 0.0 # (float) segment copy-paste (probability)
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
-path: ../datasets/VOC
+path: datasets/VOC
train: # train images (relative to 'path') 16551 images
- images/train2012
- images/train2007
diff --git a/recipes/object_detection/cfg/data/coco.yaml b/recipes/object_detection/cfg/data/coco.yaml
index c19f3f9..aff1da1 100644
--- a/recipes/object_detection/cfg/data/coco.yaml
+++ b/recipes/object_detection/cfg/data/coco.yaml
@@ -44,7 +44,7 @@ copy_paste: 0.0 # (float) segment copy-paste (probability)
# Dataset location
-path: /mnt/data/coco # dataset root dir
+path: datasets/coco # dataset root dir
train: train2017.txt # train images (relative to 'path') 118287 images
val: val2017.txt # val images (relative to 'path') 5000 images
test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
diff --git a/recipes/object_detection/cfg/data/coco8.yaml b/recipes/object_detection/cfg/data/coco8.yaml
index 3e01c8c..596f837 100644
--- a/recipes/object_detection/cfg/data/coco8.yaml
+++ b/recipes/object_detection/cfg/data/coco8.yaml
@@ -44,7 +44,7 @@ copy_paste: 0.0 # (float) segment copy-paste (probability)
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
-path: /mnt/data/coco8 # dataset root dir
+path: datasets/coco8 # dataset root dir
train: images/train # train images (relative to 'path') 4 images
val: images/val # val images (relative to 'path') 4 images
test: # test images (optional)
diff --git a/recipes/object_detection/cfg/yolo_phinet.py b/recipes/object_detection/cfg/yolo_phinet.py
index 11d25d6..7dbf0a5 100644
--- a/recipes/object_detection/cfg/yolo_phinet.py
+++ b/recipes/object_detection/cfg/yolo_phinet.py
@@ -5,16 +5,18 @@
- Matteo Beltrami, 2023
- Francesco Paissan, 2023
"""
+
# Data configuration
batch_size = 8
-data_cfg = "cfg/data/coco.yaml"
-data_dir = "data/coco"
-epochs = 200
+data_cfg = "cfg/data/VOC.yaml"
+data_dir = "datasets/coco"
+epochs = 350
+num_classes = 80
# Model configuration
input_shape = [3, 640, 640]
-alpha = 2.3
-num_layers = 7
+alpha = 1.1
+num_layers = 8
beta = 0.75
t_zero = 5
divisor = 8
diff --git a/recipes/object_detection/inference.py b/recipes/object_detection/inference.py
index 9399c1f..5ebff90 100644
--- a/recipes/object_detection/inference.py
+++ b/recipes/object_detection/inference.py
@@ -27,25 +27,16 @@
preprocess,
)
from train import YOLO
+from micromind.utils.yolo import load_config
class Inference(YOLO):
- def __init__(self, hparams):
- super().__init__(hparams=hparams, m_cfg={})
-
- def forward(self, img):
- """Executes the detection network.
-
- Arguments
- ---------
- bacth : List[torch.Tensor]
- Input to the detection network.
-
- Returns
- -------
- Output of the detection network : torch.Tensor
- """
- backbone = self.modules["backbone"](img)
+ def __init__(self, m_cfg, hparams):
+ super().__init__(m_cfg, hparams=hparams)
+
+ def forward(self, batch):
+ """Runs the forward method by calling every module."""
+ backbone = self.modules["backbone"](batch)
neck_input = backbone[1]
neck_input.append(self.modules["sppf"](backbone[0]))
neck = self.modules["neck"](*neck_input)
@@ -73,6 +64,8 @@ def forward(self, img):
img_paths = [sys.argv[2]]
for img_path in img_paths:
image = torchvision.io.read_image(img_path)
+ if image.shape[0] == 4:
+ image = image[:3, :, :] # keep only the first 3 channels (RGB)
out_paths = [
(
output_folder_path
@@ -85,7 +78,8 @@ def forward(self, img):
pre_processed_image = preprocess(image)
- model = Inference(hparams)
+ m_cfg, data_cfg = load_config(hparams.data_cfg)
+ model = Inference(m_cfg, hparams=hparams)
# Load pretrained if passed.
if hparams.ckpt_pretrained != "":
model.load_modules(hparams.ckpt_pretrained)
@@ -97,11 +91,13 @@ def forward(self, img):
with torch.no_grad():
st = time.time()
- predictions = model(pre_processed_image)
+ predictions = model.forward(pre_processed_image)
print(f"Inference took {int(round(((time.time() - st) * 1000)))}ms")
post_predictions = postprocess(
preds=predictions[0], img=pre_processed_image, orig_imgs=image
)
class_labels = [s.strip() for s in open(hparams.coco_names, "r").readlines()]
draw_bounding_boxes_and_save(
@@ -112,4 +108,3 @@ def forward(self, img):
)
# Exporting onnx model.
- # model.export("model.onnx", "onnx", hparams.input_shape)
diff --git a/recipes/object_detection/prepare_data.py b/recipes/object_detection/prepare_data.py
index bb3b570..b4e4e6a 100644
--- a/recipes/object_detection/prepare_data.py
+++ b/recipes/object_detection/prepare_data.py
@@ -6,6 +6,7 @@
- Matteo Beltrami, 2023
- Francesco Paissan, 2023
"""
+
from typing import Dict
import os
diff --git a/recipes/object_detection/train.py b/recipes/object_detection/train.py
index dfc6f91..7f36eb7 100644
--- a/recipes/object_detection/train.py
+++ b/recipes/object_detection/train.py
@@ -12,38 +12,43 @@
"""
import torch
+import torch.nn as nn
+import torch.optim as optim
from prepare_data import create_loaders
-from ultralytics.utils.ops import scale_boxes, xywh2xyxy
from yolo_loss import Loss
+import math
import micromind as mm
from micromind.networks import PhiNet
-from micromind.networks.yolo import SPPF, DetectionHead, Yolov8Neck, Yolov8NeckOpt
+from micromind.networks.yolo import SPPF, Yolov8Neck, DetectionHead
from micromind.utils import parse_configuration
-from micromind.utils.yolo import (
- load_config,
- mean_average_precision,
- postprocess,
-)
+from micromind.utils.yolo import load_config
import sys
import os
+from micromind.utils.yolo import get_variant_multiples
+from validation.validator import DetectionValidator
class YOLO(mm.MicroMind):
def __init__(self, m_cfg, hparams, *args, **kwargs):
"""Initializes the YOLO model."""
super().__init__(*args, **kwargs)
+
self.m_cfg = m_cfg
+ w, r, d = get_variant_multiples("n")
self.modules["backbone"] = PhiNet(
input_shape=hparams.input_shape,
alpha=hparams.alpha,
- num_layers=hparams.num_layers,
beta=hparams.beta,
t_zero=hparams.t_zero,
+ num_layers=hparams.num_layers,
+ h_swish=False,
+ squeeze_excite=True,
include_top=False,
- compatibility=False,
+ num_classes=hparams.num_classes,
divisor=hparams.divisor,
+ compatibility=False,
downsampling_layers=hparams.downsampling_layers,
return_layers=hparams.return_layers,
)
@@ -53,11 +58,13 @@ def __init__(self, m_cfg, hparams, *args, **kwargs):
)
self.modules["sppf"] = SPPF(*sppf_ch)
- self.modules["neck"] = Yolov8NeckOpt(
+ self.modules["neck"] = Yolov8Neck(
filters=neck_filters, up=up, heads=hparams.heads
)
- self.modules["head"] = DetectionHead(filters=head_filters, heads=hparams.heads)
+ self.modules["head"] = DetectionHead(
+ hparams.num_classes, filters=head_filters, heads=hparams.heads
+ )
self.criterion = Loss(self.m_cfg, self.modules["head"], self.device)
print("Number of parameters for each module:")
@@ -131,8 +138,23 @@ def preprocess_batch(self, batch):
def forward(self, batch):
"""Runs the forward method by calling every module."""
- preprocessed_batch = self.preprocess_batch(batch)
- backbone = self.modules["backbone"](preprocessed_batch["img"].to(self.device))
+ if self.modules.training:
+ preprocessed_batch = self.preprocess_batch(batch)
+ backbone = self.modules["backbone"](
+ preprocessed_batch["img"].to(self.device)
+ )
+ else:
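+ # at evaluation time the validator may pass either a raw image tensor or a batch dict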
+ if torch.is_tensor(batch):
+ backbone = self.modules["backbone"](batch)
+ neck_input = backbone[1]
+ neck_input.append(self.modules["sppf"](backbone[0]))
+ neck = self.modules["neck"](*neck_input)
+ head = self.modules["head"](neck)
+ return head
+
+ backbone = self.modules["backbone"](batch["img"] / 255)
+
neck_input = backbone[1]
neck_input.append(self.modules["sppf"](backbone[0]))
neck = self.modules["neck"](*neck_input)
@@ -151,61 +173,159 @@ def compute_loss(self, pred, batch):
return lossi_sum
+ def build_optimizer(
+ self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e6
+ ):
+ """
+ Constructs an optimizer for the given model, based on the specified optimizer
+ name, learning rate, momentum, weight decay, and number of iterations.
+
+ Args:
+ model (torch.nn.Module): The model for which to build an optimizer.
+ name (str, optional): The name of the optimizer to use. If 'auto', the
+ optimizer is selected based on the number of iterations.
+ Default: 'auto'.
+ lr (float, optional): The learning rate for the optimizer. Default: 0.001.
+ momentum (float, optional): The momentum factor for the optimizer.
+ Default: 0.9.
+ decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
+ iterations (float, optional): The number of iterations, which determines
+ the optimizer if name is 'auto'. Default: 1e6.
+
+ Returns:
+ (torch.optim.Optimizer): The constructed optimizer.
+ """
+
+ g = [], [], [] # optimizer parameter groups
+ bn = tuple(
+ v for k, v in nn.__dict__.items() if "Norm" in k
+ ) # normalization layers, i.e. BatchNorm2d()
+ if name == "auto":
+ print(
+ f"optimizer: 'optimizer=auto' found, "
+ f"ignoring 'lr0={lr}' and 'momentum={momentum}' and "
+ f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
+ )
+ nc = getattr(model, "nc", 80) # number of classes
+ lr_fit = round(
+ 0.002 * 5 / (4 + nc), 6
+ ) # lr0 fit equation to 6 decimal places
+ name, lr, momentum = ("AdamW", lr_fit, 0.9)
+ lr *= 10
+ # self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
+
+ for module_name, module in model.named_modules():
+ for param_name, param in module.named_parameters(recurse=False):
+ fullname = f"{module_name}.{param_name}" if module_name else param_name
+ if "bias" in fullname: # bias (no decay)
+ g[2].append(param)
+ elif isinstance(module, bn): # weight (no decay)
+ g[1].append(param)
+ else: # weight (with decay)
+ g[0].append(param)
+
+ if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
+ optimizer = getattr(optim, name, optim.Adam)(
+ g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0
+ )
+ elif name == "RMSProp":
+ optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
+ elif name == "SGD":
+ optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+ else:
+ raise NotImplementedError(
+ f"Optimizer '{name}' not found in list of available optimizers "
+ f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
+ "To request support for addition optimizers please visit"
+ "https://github.com/ultralytics/ultralytics."
+ )
+
+ optimizer.add_param_group(
+ {"params": g[0], "weight_decay": decay}
+ ) # add g0 with weight_decay
+ optimizer.add_param_group(
+ {"params": g[1], "weight_decay": 0.0}
+ ) # add g1 (BatchNorm2d weights)
+ print(
+ f"{optimizer:} {type(optimizer).__name__}(lr={lr}, "
+ f"momentum={momentum}) with parameter groups"
+ f"{len(g[1])} weight(decay=0.0), {len(g[0])} "
+ f"weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
+ )
+ return optimizer, lr
+
+ def _setup_scheduler(self, opt, lrf=0.01, lr0=0.01, cos_lr=True):
+ """Initialize training learning rate scheduler."""
+
+ def one_cycle(y1=0.0, y2=1.0, steps=100):
+ """Returns a lambda function for sinusoidal ramp from y1 to y2
+ https://arxiv.org/pdf/1812.01187.pdf."""
+ return (
+ lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1)
+ + y1
+ )
+
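+ # note: lrf is rescaled by lr0, so the scheduler's final multiplier is lrf * lr0 rather than lrf alone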
+ lrf *= lr0
+
+ if cos_lr:
+ self.lf = one_cycle(1, lrf, 350) # cosine 1->hyp['lrf']
+ else:
+ self.lf = (
+ lambda x: max(1 - x / self.epochs, 0) * (1.0 - lrf) + lrf
+ ) # linear
+ return optim.lr_scheduler.LambdaLR(opt, lr_lambda=self.lf)
+
def configure_optimizers(self):
"""Configures the optimizer and the scheduler."""
- opt = torch.optim.SGD(self.modules.parameters(), lr=1e-2, weight_decay=0.0005)
- sched = torch.optim.lr_scheduler.CosineAnnealingLR(
- opt, T_max=14000, eta_min=1e-3
- )
+ # opt = torch.optim.SGD(self.modules.parameters(), lr=1e-2, weight_decay=0.0005)
+ # opt = torch.optim.AdamW(
+ # self.modules.parameters(), lr=0.000119, weight_decay=0.0
+ # )
+ opt, lr = self.build_optimizer(self.modules, name="auto", lr=0.01, momentum=0.9)
+ sched = self._setup_scheduler(opt, 0.01, lr)
+
return opt, sched
@torch.no_grad()
- def mAP(self, pred, batch):
- """Compute the mean average precision (mAP) for a batch of predictions.
-
- Arguments
- ---------
- pred : torch.Tensor
- Model predictions for the batch.
- batch : dict
- A dictionary containing batch information, including bounding boxes,
- classes and shapes.
-
- Returns
- -------
- torch.Tensor
- A tensor containing the computed mean average precision (mAP) for the batch.
+ def on_train_epoch_end(self):
"""
- preprocessed_batch = self.preprocess_batch(batch)
- post_predictions = postprocess(
- preds=pred[0], img=preprocessed_batch, orig_imgs=batch
+ Computes the mean average precision (mAP) at the end of the training epoch
+ and logs the metrics in `val_log.txt` inside the experiment folder.
+ The `verbose` argument, if set to `True`, prints details regarding the
+ number of images, instances and metrics for each class of the dataset.
+ The `plots` argument, if set to `True`, saves in the `runs/detect/train`
+ folder the plots of the confusion matrix, the F1-Confidence,
+ Precision-Confidence, Precision-Recall, Recall-Confidence curves and the
+ predictions and labels of the first three batches of images.
+ """
+ args = dict(
+ model="yolov8n.pt", data=hparams.data_cfg, verbose=False, plots=False
)
-
- batch_bboxes_xyxy = xywh2xyxy(batch["bboxes"])
- dim = batch["resized_shape"][0][0]
- batch_bboxes_xyxy[:, :4] *= dim
-
- batch_bboxes = []
- for i in range(len(batch["batch_idx"])):
- for b in range(len(batch_bboxes_xyxy[batch["batch_idx"] == i, :])):
- batch_bboxes.append(
- scale_boxes(
- batch["resized_shape"][i],
- batch_bboxes_xyxy[batch["batch_idx"] == i, :][b],
- batch["ori_shape"][i],
- )
- )
-
- batch_bboxes = torch.stack(batch_bboxes).to(self.device)
- mmAP = mean_average_precision(
- post_predictions, batch, batch_bboxes, data_cfg["nc"]
+ validator = DetectionValidator(args=args)
+
+ validator(model=self)
+
+ val_metrics = [
+ validator.metrics.box.map * 100,
+ validator.metrics.box.map50 * 100,
+ validator.metrics.box.map75 * 100,
+ ]
+ metrics_file = os.path.join(exp_folder, "val_log.txt")
+ metrics_info = (
+ f"Epoch {self.current_epoch}: "
+ f"mAP50-95(B): {round(val_metrics[0], 3)}%; "
+ f"mAP50(B): {round(val_metrics[1], 3)}%; "
+ f"mAP75(B): {round(val_metrics[2], 3)}%\n"
)
- return torch.Tensor([mmAP])
+ with open(metrics_file, "a") as file:
+ file.write(metrics_info)
+ return
def replace_datafolder(hparams, data_cfg):
"""Replaces the data root folder, if told to do so from the configuration."""
+ print(data_cfg["train"])
data_cfg["path"] = str(data_cfg["path"])
data_cfg["path"] = (
data_cfg["path"][:-1] if data_cfg["path"][-1] == "/" else data_cfg["path"]
@@ -240,7 +360,7 @@ def replace_datafolder(hparams, data_cfg):
m_cfg, data_cfg = load_config(hparams.data_cfg)
# check if specified path for images is different, correct it in case
- data_cfg = replace_datafolder(hparams, data_cfg)
+ # data_cfg = replace_datafolder(hparams, data_cfg)
m_cfg.imgsz = hparams.input_shape[-1] # temp solution
train_loader, val_loader = create_loaders(m_cfg, data_cfg, hparams.batch_size)
@@ -255,12 +375,10 @@ def replace_datafolder(hparams, data_cfg):
yolo_mind = YOLO(m_cfg, hparams=hparams)
- mAP = mm.Metric("mAP", yolo_mind.mAP, eval_only=True, eval_period=1)
-
yolo_mind.train(
epochs=hparams.epochs,
datasets={"train": train_loader, "val": val_loader},
- metrics=[mAP],
+ metrics=[],
checkpointer=checkpointer,
debug=hparams.debug,
)
diff --git a/recipes/object_detection/train_yolov8.py b/recipes/object_detection/train_yolov8.py
new file mode 100644
index 0000000..58905f7
--- /dev/null
+++ b/recipes/object_detection/train_yolov8.py
@@ -0,0 +1,317 @@
+"""
+YOLO training.
+
+This code allows you to train an object detection model with the YOLOv8 neck and loss.
+
+To run this script, you can start it with:
+ python train_yolov8.py cfg/.py
+
+Authors:
+ - Matteo Beltrami, 2024
+ - Francesco Paissan, 2024
+"""
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from prepare_data import create_loaders
+from yolo_loss import Loss
+import math
+
+import micromind as mm
+from micromind.networks.yolo import Darknet, Yolov8Neck, DetectionHead
+from micromind.utils import parse_configuration
+from micromind.utils.yolo import get_variant_multiples, load_config
+import sys
+import os
+from validation.validator import DetectionValidator
+
+
+class YOLO(mm.MicroMind):
+ def __init__(self, m_cfg, hparams, *args, **kwargs):
+ """Initializes the YOLO model."""
+ super().__init__(*args, **kwargs)
+
+ self.m_cfg = m_cfg
+ w, r, d = get_variant_multiples("n")
+
+ self.modules["backbone"] = Darknet(w, r, d)
+ self.modules["neck"] = Yolov8Neck(
+ filters=[int(256 * w), int(512 * w), int(512 * w * r)],
+ heads=hparams.heads,
+ d=d,
+ )
+ self.modules["head"] = DetectionHead(
+ hparams.num_classes,
+ filters=(int(256 * w), int(512 * w), int(512 * w * r)),
+ heads=hparams.heads,
+ )
+ self.criterion = Loss(self.m_cfg, self.modules["head"], self.device)
+
+ print("Number of parameters for each module:")
+ print(self.compute_params())
+
+ def preprocess_batch(self, batch):
+ """Preprocesses a batch of images by scaling and converting to float."""
+ preprocessed_batch = {}
+ preprocessed_batch["img"] = (
+ batch["img"].to(self.device, non_blocking=True).float() / 255
+ )
+ for k in batch:
+ if isinstance(batch[k], torch.Tensor) and k != "img":
+ preprocessed_batch[k] = batch[k].to(self.device)
+
+ return preprocessed_batch
+
+ def forward(self, batch):
+ """Runs the forward method by calling every module."""
+ if self.modules.training:
+ preprocessed_batch = self.preprocess_batch(batch)
+ backbone = self.modules["backbone"](
+ preprocessed_batch["img"].to(self.device)
+ )
+ else:
+ if torch.is_tensor(batch):
+ backbone = self.modules["backbone"](batch)
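+ # keep the optional SPPF step so this forward also works when a separate "sppf" module is registered (as in the PhiNet recipe)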
+ if "sppf" in self.modules.keys():
+ neck_input = backbone[1]
+ neck_input.append(self.modules["sppf"](backbone[0]))
+ else:
+ neck_input = backbone
+ neck = self.modules["neck"](*neck_input)
+ head = self.modules["head"](neck)
+ return head
+
+ backbone = self.modules["backbone"](batch["img"] / 255)
+
+ if "sppf" in self.modules.keys():
+ neck_input = backbone[1]
+ neck_input.append(self.modules["sppf"](backbone[0]))
+ else:
+ neck_input = backbone
+ neck = self.modules["neck"](*neck_input)
+ head = self.modules["head"](neck)
+
+ return head
+
+ def compute_loss(self, pred, batch):
+ """Computes the loss."""
+ preprocessed_batch = self.preprocess_batch(batch)
+
+ lossi_sum, lossi = self.criterion(
+ pred,
+ preprocessed_batch,
+ )
+
+ return lossi_sum
+
+ def build_optimizer(
+ self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e6
+ ):
+ """
+ Constructs an optimizer for the given model, based on the specified optimizer
+ name, learning rate, momentum, weight decay, and number of iterations.
+
+ Args:
+ model (torch.nn.Module): The model for which to build an optimizer.
+ name (str, optional): The name of the optimizer to use. If 'auto', the
+ optimizer is selected based on the number of iterations.
+ Default: 'auto'.
+ lr (float, optional): The learning rate for the optimizer. Default: 0.001.
+ momentum (float, optional): The momentum factor for the optimizer.
+ Default: 0.9.
+ decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
+ iterations (float, optional): The number of iterations, which determines
+ the optimizer if name is 'auto'. Default: 1e6.
+
+ Returns:
+ (torch.optim.Optimizer): The constructed optimizer.
+ """
+
+ g = [], [], [] # optimizer parameter groups
+ bn = tuple(
+ v for k, v in nn.__dict__.items() if "Norm" in k
+ ) # normalization layers, i.e. BatchNorm2d()
+ if name == "auto":
+ print(
+ f"optimizer: 'optimizer=auto' found, "
+ f"ignoring 'lr0={lr}' and 'momentum={momentum}' and "
+ f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
+ )
+ nc = getattr(model, "nc", 80) # number of classes
+ lr_fit = round(
+ 0.002 * 5 / (4 + nc), 6
+ ) # lr0 fit equation to 6 decimal places
+ name, lr, momentum = ("AdamW", lr_fit, 0.9)
+ lr *= 10
+ # self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
+
+ for module_name, module in model.named_modules():
+ for param_name, param in module.named_parameters(recurse=False):
+ fullname = f"{module_name}.{param_name}" if module_name else param_name
+ if "bias" in fullname: # bias (no decay)
+ g[2].append(param)
+ elif isinstance(module, bn): # weight (no decay)
+ g[1].append(param)
+ else: # weight (with decay)
+ g[0].append(param)
+
+ if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
+ optimizer = getattr(optim, name, optim.Adam)(
+ g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0
+ )
+ elif name == "RMSProp":
+ optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
+ elif name == "SGD":
+ optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+ else:
+ raise NotImplementedError(
+ f"Optimizer '{name}' not found in list of available optimizers "
+ f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
+ "To request support for addition optimizers please visit"
+ "https://github.com/ultralytics/ultralytics."
+ )
+
+ optimizer.add_param_group(
+ {"params": g[0], "weight_decay": decay}
+ ) # add g0 with weight_decay
+ optimizer.add_param_group(
+ {"params": g[1], "weight_decay": 0.0}
+ ) # add g1 (BatchNorm2d weights)
+ print(
+ f"{optimizer:} {type(optimizer).__name__}(lr={lr}, "
+ f"momentum={momentum}) with parameter groups"
+ f"{len(g[1])} weight(decay=0.0), {len(g[0])} "
+ f"weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
+ )
+ return optimizer, lr
+
+ def _setup_scheduler(self, opt, lrf=0.01, lr0=0.01, cos_lr=True):
+ """Initialize training learning rate scheduler."""
+
+ def one_cycle(y1=0.0, y2=1.0, steps=100):
+ """Returns a lambda function for sinusoidal ramp from y1 to y2
+ https://arxiv.org/pdf/1812.01187.pdf."""
+ return (
+ lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1)
+ + y1
+ )
+
+ lrf *= lr0
+
+ if cos_lr:
+ self.lf = one_cycle(1, lrf, 350) # cosine 1->hyp['lrf']
+ else:
+ self.lf = (
+ lambda x: max(1 - x / self.epochs, 0) * (1.0 - lrf) + lrf
+ ) # linear
+ return optim.lr_scheduler.LambdaLR(opt, lr_lambda=self.lf)
+
+ def configure_optimizers(self):
+ """Configures the optimizer and the scheduler."""
+ # opt = torch.optim.SGD(self.modules.parameters(), lr=1e-2, weight_decay=0.0005)
+ # opt = torch.optim.AdamW(
+ # self.modules.parameters(), lr=0.000119, weight_decay=0.0
+ # )
+ opt, lr = self.build_optimizer(self.modules, name="auto", lr=0.01, momentum=0.9)
+ sched = self._setup_scheduler(opt, 0.01, lr)
+
+ return opt, sched
+
+ @torch.no_grad()
+ def on_train_epoch_end(self):
+ """
+ Computes the mean average precision (mAP) at the end of the training epoch
+ and logs the metrics in `val_log.txt` inside the experiment folder.
+ The `verbose` argument, if set to `True`, prints details regarding the
+ number of images, instances and metrics for each class of the dataset.
+ The `plots` argument, if set to `True`, saves in the `runs/detect/train`
+ folder the plots of the confusion matrix, the F1-Confidence,
+ Precision-Confidence, Precision-Recall, Recall-Confidence curves and the
+ predictions and labels of the first three batches of images.
+ """
+ args = dict(
+ model="yolov8n.pt", data=hparams.data_cfg, verbose=False, plots=False
+ )
+ validator = DetectionValidator(args=args)
+
+ validator(model=self)
+
+ val_metrics = [
+ validator.metrics.box.map * 100,
+ validator.metrics.box.map50 * 100,
+ validator.metrics.box.map75 * 100,
+ ]
+ metrics_file = os.path.join(exp_folder, "val_log.txt")
+ metrics_info = (
+ f"Epoch {self.current_epoch}: "
+ f"mAP50-95(B): {round(val_metrics[0], 3)}%; "
+ f"mAP50(B): {round(val_metrics[1], 3)}%; "
+ f"mAP75(B): {round(val_metrics[2], 3)}%\n"
+ )
+
+ with open(metrics_file, "a") as file:
+ file.write(metrics_info)
+ return
+
+
+def replace_datafolder(hparams, data_cfg):
+ """Replaces the data root folder, if told to do so from the configuration."""
+ print(data_cfg["train"])
+ data_cfg["path"] = str(data_cfg["path"])
+ data_cfg["path"] = (
+ data_cfg["path"][:-1] if data_cfg["path"][-1] == "/" else data_cfg["path"]
+ )
+ for key in ["train", "val"]:
+ if not isinstance(data_cfg[key], list):
+ data_cfg[key] = [data_cfg[key]]
+ new_list = []
+ for tmp in data_cfg[key]:
+ if hasattr(hparams, "data_dir"):
+ if hparams.data_dir != data_cfg["path"]:
+ tmp = str(tmp).replace(data_cfg["path"], "")
+ tmp = tmp[1:] if tmp[0] == "/" else tmp
+ tmp = os.path.join(hparams.data_dir, tmp)
+ new_list.append(tmp)
+ data_cfg[key] = new_list
+
+ data_cfg["path"] = hparams.data_dir
+
+ return data_cfg
+
+
+if __name__ == "__main__":
+ assert len(sys.argv) > 1, "Please pass the configuration file to the script."
+ hparams = parse_configuration(sys.argv[1])
+ if len(hparams.input_shape) != 3:
+ hparams.input_shape = [
+ int(x) for x in "".join(hparams.input_shape).split(",")
+ ] # temp solution
+ print(f"Setting input shape to {hparams.input_shape}.")
+
+ m_cfg, data_cfg = load_config(hparams.data_cfg)
+
+ # check if specified path for images is different, correct it in case
+ # data_cfg = replace_datafolder(hparams, data_cfg)
+ m_cfg.imgsz = hparams.input_shape[-1] # temp solution
+
+ train_loader, val_loader = create_loaders(m_cfg, data_cfg, hparams.batch_size)
+
+ exp_folder = mm.utils.checkpointer.create_experiment_folder(
+ hparams.output_folder, hparams.experiment_name
+ )
+
+ checkpointer = mm.utils.checkpointer.Checkpointer(
+ exp_folder, hparams=hparams, key="loss"
+ )
+
+ yolo_mind = YOLO(m_cfg, hparams=hparams)
+
+ yolo_mind.train(
+ epochs=hparams.epochs,
+ datasets={"train": train_loader, "val": val_loader},
+ metrics=[],
+ checkpointer=checkpointer,
+ debug=hparams.debug,
+ )
diff --git a/recipes/object_detection/validate.py b/recipes/object_detection/validate.py
new file mode 100644
index 0000000..8cf1a56
--- /dev/null
+++ b/recipes/object_detection/validate.py
@@ -0,0 +1,67 @@
+"""
+YOLO validation.
+
+This code allows you to validate an object detection model.
+
+To run this script, you can start it with:
+ python validate.py cfg/.py
+
+Authors:
+ - Matteo Beltrami, 2024
+ - Francesco Paissan, 2024
+"""
+
+from micromind.utils import parse_configuration
+from micromind.utils.yolo import (
+ load_config,
+)
+import sys
+from validation.validator import DetectionValidator
+
+from train import YOLO, replace_datafolder
+
+
+class YOLO(YOLO):
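+ """Thin wrapper over the training YOLO class whose forward takes a plain
+ image tensor, matching what the DetectionValidator feeds the model."""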
+ def __init__(self, m_cfg, hparams, *args, **kwargs):
+ """Initializes the YOLO model."""
+ super().__init__(m_cfg, hparams, *args, **kwargs)
+ self.m_cfg = m_cfg
+ self.device = "cuda"
+
+ def forward(self, img):
+ """Runs the forward method by calling every module."""
+ backbone = self.modules["backbone"](img)
+ if "sppf" in self.modules.keys():
+ neck_input = backbone[1]
+ neck_input.append(self.modules["sppf"](backbone[0]))
+ else:
+ neck_input = backbone
+ neck = self.modules["neck"](*neck_input)
+ head = self.modules["head"](neck)
+
+ return head
+
+
+if __name__ == "__main__":
+ assertion_msg = "Usage: python validate.py "
+ assert len(sys.argv) >= 3, assertion_msg
+
+ hparams = parse_configuration(sys.argv[1])
+ m_cfg, data_cfg = load_config(hparams.data_cfg)
+ data_cfg = replace_datafolder(hparams, data_cfg)
+
+ m_cfg.imgsz = hparams.input_shape[-1] # temp solution
+
+ model_weights_path = sys.argv[2]
+ args = dict(model="yolov8n.pt", data=hparams.data_cfg, verbose=False, plots=False)
+ validator = DetectionValidator(args=args)
+
+ model = YOLO(m_cfg, hparams)
+ model.load_modules(model_weights_path)
+
+ val = validator(model=model)
+
+ print("METRICS:")
+ print("Box map50-95:", round(validator.metrics.box.map * 100, 3), "%")
+ print("Box map50:", round(validator.metrics.box.map50 * 100, 3), "%")
+ print("Box map75:", round(validator.metrics.box.map75 * 100, 3), "%")
diff --git a/recipes/object_detection/validation/autobackend.py b/recipes/object_detection/validation/autobackend.py
new file mode 100644
index 0000000..64ad83c
--- /dev/null
+++ b/recipes/object_detection/validation/autobackend.py
@@ -0,0 +1,661 @@
+import ast
+import contextlib
+import json
+import platform
+import zipfile
+from collections import OrderedDict, namedtuple
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+
+from ultralytics.utils import ARM64, LINUX, LOGGER, ROOT, yaml_load
+from ultralytics.utils.checks import (
+ check_requirements,
+ check_suffix,
+ check_version,
+ check_yaml,
+)
+from ultralytics.utils.downloads import attempt_download_asset, is_url
+
+
+def check_class_names(names):
+ """
+ Check class names.
+
+ Map imagenet class codes to human-readable names if required.
+ Convert lists to dicts.
+ """
+ if isinstance(names, list): # names is a list
+ names = dict(enumerate(names)) # convert to dict
+ if isinstance(names, dict):
+ # Convert 1) string keys to int, i.e. '0' to 0,
+ # and non-string values to strings, i.e. True to 'True'
+ names = {int(k): str(v) for k, v in names.items()}
+ n = len(names)
+ if max(names.keys()) >= n:
+ raise KeyError(
+ f"{n}-class dataset requires class indices 0-{n - 1}, \
+ but you have invalid class indices"
+ f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML."
+ )
+ if isinstance(names[0], str) and names[0].startswith(
+ "n0"
+ ): # imagenet class codes, i.e. 'n01440764'
+ names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")[
+ "map"
+ ] # human-readable names
+ names = {k: names_map[v] for k, v in names.items()}
+ return names
+
+
+def default_class_names(data=None):
+ """Applies default class names to an input YAML file or
+ returns numerical class names."""
+ if data:
+ with contextlib.suppress(Exception):
+ return yaml_load(check_yaml(data))["names"]
+ return {i: f"class{i}" for i in range(999)} # return default if above errors
+
+
+class AutoBackend(nn.Module):
+ """
+ Handles dynamic backend selection for running inference using
+ Ultralytics YOLO models.
+
+ The AutoBackend class is designed to provide an abstraction layer for various
+ inference engines. It supports a wide
+ range of formats, each with specific naming conventions as outlined below:
+
+ Supported Formats and Naming Conventions:
+ | Format | File Suffix |
+ |-----------------------|------------------|
+ | PyTorch | *.pt |
+ | TorchScript | *.torchscript |
+ | ONNX Runtime | *.onnx |
+ | ONNX OpenCV DNN | *.onnx (dnn=True)|
+ | OpenVINO | *openvino_model/ |
+ | CoreML | *.mlpackage |
+ | TensorRT | *.engine |
+ | TensorFlow SavedModel | *_saved_model |
+ | TensorFlow GraphDef | *.pb |
+ | TensorFlow Lite | *.tflite |
+ | TensorFlow Edge TPU | *_edgetpu.tflite |
+ | PaddlePaddle | *_paddle_model |
+ | ncnn | *_ncnn_model |
+
+ This class offers dynamic backend switching capabilities based on the input
+ model format, making it easier to deploy models across various platforms.
+ """
+
+ @torch.no_grad()
+ def __init__(
+ self,
+ weights="yolov8n.pt",
+ device=torch.device("cpu"),
+ dnn=False,
+ data=None,
+ fp16=False,
+ fuse=True,
+ verbose=True,
+ ):
+ """
+ Initialize the AutoBackend for inference.
+
+ Args:
+ weights (str): Path to the model weights file. Defaults to 'yolov8n.pt'.
+ device (torch.device): Device to run the model on. Defaults to CPU.
+ dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False.
+ data (str | Path | optional): Path to the additional data.yaml file
+ containing class names. Optional.
+ fp16 (bool): Enable half-precision inference. Supported only on specific
+ backends. Defaults to False.
+ fuse (bool): Fuse Conv2D + BatchNorm layers for optimization.
+ Defaults to True.
+ verbose (bool): Enable verbose logging. Defaults to True.
+ """
+ super().__init__()
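+ # micromind adaptation: keep the full MicroMind object for inference and treat its .modules container as the weights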
+ self.modell = weights
+ weights = weights.modules
+ w = str(weights[0] if isinstance(weights, list) else weights)
+ nn_module = isinstance(weights, torch.nn.Module)
+ (
+ pt,
+ jit,
+ onnx,
+ xml,
+ engine,
+ coreml,
+ saved_model,
+ pb,
+ tflite,
+ edgetpu,
+ tfjs,
+ paddle,
+ ncnn,
+ triton,
+ ) = self._model_type(w)
+ fp16 &= pt or jit or onnx or xml or engine or nn_module # or triton # FP16
+ nhwc = (
+ coreml or saved_model or pb or tflite or edgetpu
+ ) # BHWC formats (vs torch BCWH)
+ stride = 32 # default stride
+ model, metadata = None, None
+
+ # Set device
+ cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA
+ if cuda and not any(
+ [nn_module, pt, jit, engine, onnx]
+ ): # GPU dataloader formats
+ device = torch.device("cpu")
+ cuda = False
+
+ # Download if not local
+ if not (pt or triton or nn_module):
+ w = attempt_download_asset(w)
+
+ # Load model
+ if nn_module: # in-memory PyTorch model
+ model = weights.to(device)
+ model = model.fuse(verbose=verbose) if fuse else model
+ if hasattr(model, "kpt_shape"):
+ kpt_shape = model.kpt_shape # pose-only
+ # stride = max(int(model.stride.max()), 32)
+ # names = model.module.names if hasattr(model, "module") else model.names
+ model.half() if fp16 else model.float()
+ self.model = model # explicitly assign for to(), cpu(), cuda(), half()
+ pt = True
+ elif pt: # PyTorch
+ from ultralytics.nn.tasks import attempt_load_weights
+
+ model = attempt_load_weights(
+ weights if isinstance(weights, list) else w,
+ device=device,
+ inplace=True,
+ fuse=fuse,
+ )
+ if hasattr(model, "kpt_shape"):
+ kpt_shape = model.kpt_shape # pose-only
+ stride = max(int(model.stride.max()), 32) # model stride
+ names = (
+ model.module.names if hasattr(model, "module") else model.names
+ ) # get class names
+ model.half() if fp16 else model.float()
+ self.model = model # explicitly assign for to(), cpu(), cuda(), half()
+ elif jit: # TorchScript
+ LOGGER.info(f"Loading {w} for TorchScript inference...")
+ extra_files = {"config.txt": ""} # model metadata
+ model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
+ model.half() if fp16 else model.float()
+ if extra_files["config.txt"]: # load metadata dict
+ metadata = json.loads(
+ extra_files["config.txt"], object_hook=lambda x: dict(x.items())
+ )
+ elif dnn: # ONNX OpenCV DNN
+ LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
+ check_requirements("opencv-python>=4.5.4")
+ net = cv2.dnn.readNetFromONNX(w)
+ elif onnx: # ONNX Runtime
+ LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
+ check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
+ import onnxruntime
+
+ providers = (
+ ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ if cuda
+ else ["CPUExecutionProvider"]
+ )
+ session = onnxruntime.InferenceSession(w, providers=providers)
+ output_names = [x.name for x in session.get_outputs()]
+ metadata = session.get_modelmeta().custom_metadata_map # metadata
+ elif xml: # OpenVINO
+ LOGGER.info(f"Loading {w} for OpenVINO inference...")
+ check_requirements(
+ "openvino>=2023.0"
+ ) # requires openvino-dev: https://pypi.org/project/openvino-dev/
+ from openvino.runtime import Core, Layout, get_batch # noqa
+
+ core = Core()
+ w = Path(w)
+ if not w.is_file(): # if not *.xml
+ w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir
+ ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
+ if ov_model.get_parameters()[0].get_layout().empty:
+ ov_model.get_parameters()[0].set_layout(Layout("NCHW"))
+ batch_dim = get_batch(ov_model)
+ if batch_dim.is_static:
+ batch_size = batch_dim.get_length()
+ ov_compiled_model = core.compile_model(
+ ov_model, device_name="AUTO"
+ ) # AUTO selects best available device
+ metadata = w.parent / "metadata.yaml"
+ elif engine: # TensorRT
+ LOGGER.info(f"Loading {w} for TensorRT inference...")
+ try:
+ import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
+ except ImportError:
+ if LINUX:
+ check_requirements(
+ "nvidia-tensorrt",
+ cmds="-U --index-url https://pypi.ngc.nvidia.com",
+ )
+ import tensorrt as trt # noqa
+ check_version(
+ trt.__version__, "7.0.0", hard=True
+ ) # require tensorrt>=7.0.0
+ if device.type == "cpu":
+ device = torch.device("cuda:0")
+ Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
+ logger = trt.Logger(trt.Logger.INFO)
+ # Read file
+ with open(w, "rb") as f, trt.Runtime(logger) as runtime:
+ meta_len = int.from_bytes(
+ f.read(4), byteorder="little"
+ ) # read metadata length
+ metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata
+ model = runtime.deserialize_cuda_engine(f.read()) # read engine
+ context = model.create_execution_context()
+ bindings = OrderedDict()
+ output_names = []
+ fp16 = False # default updated below
+ dynamic = False
+ for i in range(model.num_bindings):
+ name = model.get_binding_name(i)
+ dtype = trt.nptype(model.get_binding_dtype(i))
+ if model.binding_is_input(i):
+ if -1 in tuple(model.get_binding_shape(i)): # dynamic
+ dynamic = True
+ context.set_binding_shape(
+ i, tuple(model.get_profile_shape(0, i)[2])
+ )
+ if dtype == np.float16:
+ fp16 = True
+ else: # output
+ output_names.append(name)
+ shape = tuple(context.get_binding_shape(i))
+ im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
+ bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
+ binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
+ batch_size = bindings["images"].shape[
+ 0
+ ] # if dynamic, this is instead max batch size
+ elif coreml: # CoreML
+ LOGGER.info(f"Loading {w} for CoreML inference...")
+ import coremltools as ct
+
+ model = ct.models.MLModel(w)
+ metadata = dict(model.user_defined_metadata)
+ elif saved_model: # TF SavedModel
+ LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
+ import tensorflow as tf
+
+ keras = False # assume TF1 saved_model
+ model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
+ metadata = Path(w) / "metadata.yaml"
+ elif (
+ pb
+ ): # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+ LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
+ import tensorflow as tf
+
+ from ultralytics.engine.exporter import gd_outputs
+
+ def wrap_frozen_graph(gd, inputs, outputs):
+ """Wrap frozen graphs for deployment."""
+ x = tf.compat.v1.wrap_function(
+ lambda: tf.compat.v1.import_graph_def(gd, name=""), []
+ ) # wrapped
+ ge = x.graph.as_graph_element
+ return x.prune(
+ tf.nest.map_structure(ge, inputs),
+ tf.nest.map_structure(ge, outputs),
+ )
+
+ gd = tf.Graph().as_graph_def() # TF GraphDef
+ with open(w, "rb") as f:
+ gd.ParseFromString(f.read())
+ frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
+ elif tflite or edgetpu:
+ try:
+ from tflite_runtime.interpreter import Interpreter, load_delegate
+ except ImportError:
+ import tensorflow as tf
+
+ Interpreter, load_delegate = (
+ tf.lite.Interpreter,
+ tf.lite.experimental.load_delegate,
+ )
+ if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
+ LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...")
+ delegate = {
+ "Linux": "libedgetpu.so.1",
+ "Darwin": "libedgetpu.1.dylib",
+ "Windows": "edgetpu.dll",
+ }[platform.system()]
+ interpreter = Interpreter(
+ model_path=w, experimental_delegates=[load_delegate(delegate)]
+ )
+ else: # TFLite
+ LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
+ interpreter = Interpreter(model_path=w) # load TFLite model
+ interpreter.allocate_tensors() # allocate
+ input_details = interpreter.get_input_details() # inputs
+ output_details = interpreter.get_output_details() # outputs
+ # Load metadata
+ with contextlib.suppress(zipfile.BadZipFile):
+ with zipfile.ZipFile(w, "r") as model:
+ meta_file = model.namelist()[0]
+ metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
+ elif tfjs: # TF.js
+ raise NotImplementedError(
+ "YOLOv8 TF.js inference is not currently supported."
+ )
+ elif paddle: # PaddlePaddle
+ LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
+ check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
+ import paddle.inference as pdi # noqa
+
+ w = Path(w)
+ if not w.is_file(): # if not *.pdmodel
+ w = next(
+ w.rglob("*.pdmodel")
+ ) # get *.pdmodel file from *_paddle_model dir
+ config = pdi.Config(str(w), str(w.with_suffix(".pdiparams")))
+ if cuda:
+ config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
+ predictor = pdi.create_predictor(config)
+ input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
+ output_names = predictor.get_output_names()
+ metadata = w.parents[1] / "metadata.yaml"
+ elif ncnn: # ncnn
+ LOGGER.info(f"Loading {w} for ncnn inference...")
+ check_requirements(
+ "git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn"
+ ) # requires ncnn
+ import ncnn as pyncnn
+
+ net = pyncnn.Net()
+ net.opt.use_vulkan_compute = cuda
+ w = Path(w)
+ if not w.is_file(): # if not *.param
+ w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir
+ net.load_param(str(w))
+ net.load_model(str(w.with_suffix(".bin")))
+ metadata = w.parent / "metadata.yaml"
+ elif triton: # NVIDIA Triton Inference Server
+ check_requirements("tritonclient[all]")
+ from ultralytics.utils.triton import TritonRemoteModel
+
+ model = TritonRemoteModel(w)
+ else:
+ from ultralytics.engine.exporter import export_formats
+
+ raise TypeError(
+ f"model='{w}' is not a supported model format. "
+ "See https://docs.ultralytics.com/modes/predict for help."
+ f"\n\n{export_formats()}"
+ )
+
+ # Load external metadata YAML
+ if isinstance(metadata, (str, Path)) and Path(metadata).exists():
+ metadata = yaml_load(metadata)
+ if metadata:
+ for k, v in metadata.items():
+ if k in ("stride", "batch"):
+ metadata[k] = int(v)
+ elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str):
+ metadata[k] = eval(v)
+ stride = metadata["stride"]
+ task = metadata["task"]
+ batch = metadata["batch"]
+ imgsz = metadata["imgsz"]
+ names = metadata["names"]
+ kpt_shape = metadata.get("kpt_shape")
+ elif not (pt or triton or nn_module):
+ LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
+
+ # Check names
+ if "names" not in locals(): # names missing
+ names = default_class_names(data)
+ names = check_class_names(names)
+
+ """Disable gradients since this is only used in validation."""
+ # if pt:
+ # for p in model.parameters():
+ # p.requires_grad = False
+
+ self.__dict__.update(locals()) # assign all variables to self
+
+ def forward(self, im, augment=False, visualize=False, embed=None):
+ """
+ Runs inference on the YOLOv8 MultiBackend model.
+
+ Args:
+ im (torch.Tensor): The image tensor to perform inference on.
+ augment (bool): whether to perform data augmentation during inference,
+ defaults to False
+ visualize (bool): whether to visualize the output predictions,
+ defaults to False
+ embed (list, optional): A list of feature vectors/embeddings to return.
+
+ Returns:
+ (tuple): Tuple containing the raw output tensor,
+ and processed output for visualization (if visualize=True)
+ """
+ b, ch, h, w = im.shape # batch, channel, height, width
+ if self.fp16 and im.dtype != torch.float16:
+ im = im.half() # to FP16
+ if self.nhwc:
+ im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
+
+ if self.pt or self.nn_module: # PyTorch
+ y = self.modell(im)
+ elif self.jit: # TorchScript
+ y = self.model(im)
+ elif self.dnn: # ONNX OpenCV DNN
+ im = im.cpu().numpy() # torch to numpy
+ self.net.setInput(im)
+ y = self.net.forward()
+ elif self.onnx: # ONNX Runtime
+ im = im.cpu().numpy() # torch to numpy
+ y = self.session.run(
+ self.output_names, {self.session.get_inputs()[0].name: im}
+ )
+ elif self.xml: # OpenVINO
+ im = im.cpu().numpy() # FP32
+ y = list(self.ov_compiled_model(im).values())
+ elif self.engine: # TensorRT
+ if self.dynamic and im.shape != self.bindings["images"].shape:
+ i = self.model.get_binding_index("images")
+ self.context.set_binding_shape(i, im.shape) # reshape if dynamic
+ self.bindings["images"] = self.bindings["images"]._replace(
+ shape=im.shape
+ )
+ for name in self.output_names:
+ i = self.model.get_binding_index(name)
+ self.bindings[name].data.resize_(
+ tuple(self.context.get_binding_shape(i))
+ )
+ s = self.bindings["images"].shape
+ assert (
+ im.shape == s
+ ), f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} \
+ max model size {s}"
+ self.binding_addrs["images"] = int(im.data_ptr())
+ self.context.execute_v2(list(self.binding_addrs.values()))
+ y = [self.bindings[x].data for x in sorted(self.output_names)]
+ elif self.coreml: # CoreML
+ im = im[0].cpu().numpy()
+ im_pil = Image.fromarray((im * 255).astype("uint8"))
+ # im = im.resize((192, 320), Image.BILINEAR)
+ y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized
+ if "confidence" in y:
+ raise TypeError(
+ "Ultralytics only supports inference of non-pipelined CoreML \
+ models exported with "
+ f"'nms=False', but 'model={w}' has an NMS pipeline created by \
+ an 'nms=True' export."
+ )
+ elif len(y) == 1: # classification model
+ y = list(y.values())
+ elif len(y) == 2: # segmentation model
+ y = list(
+ reversed(y.values())
+ ) # reversed for segmentation models (pred, proto)
+ elif self.paddle: # PaddlePaddle
+ im = im.cpu().numpy().astype(np.float32)
+ self.input_handle.copy_from_cpu(im)
+ self.predictor.run()
+ y = [
+ self.predictor.get_output_handle(x).copy_to_cpu()
+ for x in self.output_names
+ ]
+ elif self.ncnn: # ncnn
+ mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
+ ex = self.net.create_extractor()
+ input_names, output_names = self.net.input_names(), self.net.output_names()
+ ex.input(input_names[0], mat_in)
+ y = []
+ for output_name in output_names:
+ mat_out = self.pyncnn.Mat()
+ ex.extract(output_name, mat_out)
+ y.append(np.array(mat_out)[None])
+ elif self.triton: # NVIDIA Triton Inference Server
+ im = im.cpu().numpy() # torch to numpy
+ y = self.model(im)
+ else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+ im = im.cpu().numpy()
+ if self.saved_model: # SavedModel
+ y = self.model(im, training=False) if self.keras else self.model(im)
+ if not isinstance(y, list):
+ y = [y]
+ elif self.pb: # GraphDef
+ y = self.frozen_func(x=self.tf.constant(im))
+ if (
+ len(y) == 2 and len(self.names) == 999
+ ): # segments and names not defined
+ ip, ib = (
+ (0, 1) if len(y[0].shape) == 4 else (1, 0)
+ ) # index of protos, boxes
+ nc = (
+ y[ib].shape[1] - y[ip].shape[3] - 4
+ ) # y = (1, 160, 160, 32), (1, 116, 8400)
+ self.names = {i: f"class{i}" for i in range(nc)}
+ else: # Lite or Edge TPU
+ details = self.input_details[0]
+ integer = details["dtype"] in (
+ np.int8,
+ np.int16,
+ ) # is TFLite quantized int8 or int16 model
+ if integer:
+ scale, zero_point = details["quantization"]
+ im = (im / scale + zero_point).astype(details["dtype"]) # de-scale
+ self.interpreter.set_tensor(details["index"], im)
+ self.interpreter.invoke()
+ y = []
+ for output in self.output_details:
+ x = self.interpreter.get_tensor(output["index"])
+ if integer:
+ scale, zero_point = output["quantization"]
+ x = (x.astype(np.float32) - zero_point) * scale # re-scale
+ if x.ndim > 2: # if task is not classification
+ x[:, [0, 2]] *= w
+ x[:, [1, 3]] *= h
+ y.append(x)
+ # TF segment fixes: export is reversed vs ONNX export
+ # and protos are transposed
+ if len(y) == 2: # segment with (det, proto) output order reversed
+ if len(y[1].shape) != 4:
+ y = list(
+ reversed(y)
+ ) # should be y = (1, 116, 8400), (1, 160, 160, 32)
+ y[1] = np.transpose(
+ y[1], (0, 3, 1, 2)
+ ) # should be y = (1, 116, 8400), (1, 32, 160, 160)
+ y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
+
+ if isinstance(y, (list, tuple)):
+ return (
+ self.from_numpy(y[0])
+ if len(y) == 1
+ else [self.from_numpy(x) for x in y]
+ )
+ else:
+ return self.from_numpy(y)
+
+ def from_numpy(self, x):
+ """
+ Convert a numpy array to a tensor.
+
+ Args:
+ x (np.ndarray): The array to be converted.
+
+ Returns:
+ (torch.Tensor): The converted tensor
+ """
+ return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
+
+ def warmup(self, imgsz=(1, 3, 640, 640)):
+ """
+ Warm up the model by running one forward pass with a dummy input.
+
+ Args:
+ imgsz (tuple): The shape of the dummy input tensor in the format
+ (batch_size, channels, height, width)
+ """
+ warmup_types = (
+ self.pt,
+ self.jit,
+ self.onnx,
+ self.engine,
+ self.saved_model,
+ self.pb,
+ self.triton,
+ self.nn_module,
+ )
+ if any(warmup_types) and (self.device.type != "cpu" or self.triton):
+ im = torch.empty(
+ *imgsz,
+ dtype=torch.half if self.fp16 else torch.float,
+ device=self.device,
+ ) # input
+ for _ in range(2 if self.jit else 1):
+ self.forward(im) # warmup
+
+ @staticmethod
+ def _model_type(p="path/to/model.pt"):
+ """
+ This function takes a path to a model file and returns the model type.
+ Possibles types are pt, jit, onnx, xml, engine, coreml, saved_model, pb,
+ tflite, edgetpu, tfjs, ncnn or paddle.
+
+ Args:
+ p: path to the model file. Defaults to path/to/model.pt
+
+ Examples:
+ >>> model = AutoBackend(weights="path/to/model.onnx")
+ >>> model_type = model._model_type() # returns "onnx"
+ """
+ from ultralytics.engine.exporter import export_formats
+
+ sf = list(export_formats().Suffix) # export suffixes
+ if not is_url(p, check=False) and not isinstance(p, str):
+ check_suffix(p, sf) # checks
+ name = Path(p).name
+ types = [s in name for s in sf]
+ types[5] |= name.endswith(
+ ".mlmodel"
+ ) # retain support for older Apple CoreML *.mlmodel formats
+ types[8] &= not types[9] # tflite &= not edgetpu
+ if any(types):
+ triton = False
+ else:
+ from urllib.parse import urlsplit
+
+ url = urlsplit(p)
+ triton = url.netloc and url.path and url.scheme in {"http", "grpc"}
+
+ return types + [triton]
diff --git a/recipes/object_detection/validation/validator.py b/recipes/object_detection/validation/validator.py
new file mode 100644
index 0000000..32a0ea7
--- /dev/null
+++ b/recipes/object_detection/validation/validator.py
@@ -0,0 +1,756 @@
+import json
+import time
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from ultralytics.cfg import get_cfg, get_save_dir
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
+from ultralytics.utils.checks import check_imgsz
+from ultralytics.utils.ops import Profile
+from ultralytics.utils.torch_utils import (
+ de_parallel,
+ select_device,
+ smart_inference_mode,
+)
+
+from ultralytics.data import build_dataloader, build_yolo_dataset, converter
+from ultralytics.utils import ops
+from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
+from ultralytics.utils.plotting import output_to_target, plot_images
+
+from .autobackend import AutoBackend
+
+
+class BaseValidator:
+ """
+ BaseValidator.
+
+ A base class for creating validators.
+
+ Attributes:
+ args (SimpleNamespace): Configuration for the validator.
+ dataloader (DataLoader): Dataloader to use for validation.
+ pbar (tqdm): Progress bar to update during validation.
+ model (nn.Module): Model to validate.
+ data (dict): Data dictionary.
+ device (torch.device): Device to use for validation.
+ batch_i (int): Current batch index.
+ training (bool): Whether the model is in training mode.
+ names (dict): Class names.
+ seen: Records the number of images seen so far during validation.
+ stats: Placeholder for statistics during validation.
+ confusion_matrix: Placeholder for a confusion matrix.
+ nc: Number of classes.
+ iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.
+ jdict (list): Accumulates JSON-formatted prediction records during validation.
+ speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss',
+ 'postprocess' and their respective batch processing times in milliseconds.
+ save_dir (Path): Directory to save results.
+ plots (dict): Dictionary to store plots for visualization.
+ callbacks (dict): Dictionary to store various callback functions.
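+
+ Example (illustrative sketch of the subclassing contract; concrete validators
+ such as `DetectionValidator` below override more of these hooks):
+ >>> class MyValidator(BaseValidator):
+ ... def preprocess(self, batch): # e.g. move tensors to self.device
+ ... return batch
+ ... def postprocess(self, preds): # e.g. apply NMS
+ ... return preds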
+ """
+
+ def __init__(
+ self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None
+ ):
+ """
+ Initializes a BaseValidator instance.
+
+ Args:
+ dataloader (torch.utils.data.DataLoader): Dataloader to be used for
+ validation.
+ save_dir (Path, optional): Directory to save results.
+ pbar (tqdm.tqdm): Progress bar for displaying progress.
+ args (SimpleNamespace): Configuration for the validator.
+ _callbacks (dict): Dictionary to store various callback functions.
+ """
+ self.args = get_cfg(overrides=args)
+ self.dataloader = dataloader
+ self.args.plot = False
+ self.pbar = pbar
+ self.stride = None
+ self.data = None
+ self.device = None
+ self.batch_i = None
+ self.training = True
+ self.names = None
+ self.seen = None
+ self.stats = None
+ self.confusion_matrix = None
+ self.nc = None
+ self.iouv = None
+ self.jdict = None
+ self.device = "cuda"
+ self.speed = {
+ "preprocess": 0.0,
+ "inference": 0.0,
+ "loss": 0.0,
+ "postprocess": 0.0,
+ }
+
+ self.save_dir = save_dir or get_save_dir(self.args)
+
+ """This creates a folder `runs/detect/train` in which it's saved
+ the result of every validation."""
+ if self.args.plots:
+ (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(
+ parents=True, exist_ok=True
+ )
+ if self.args.conf is None:
+ self.args.conf = 0.001 # default conf=0.001
+ self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
+
+ self.plots = {}
+ self.callbacks = _callbacks or callbacks.get_default_callbacks()
+
+ @smart_inference_mode()
+ def __call__(self, trainer=None, model=None):
+ """Supports validation of a pre-trained model if passed or a model being
+ trained if trainer is passed (trainer gets priority).
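+
+ Example (sketch of the standalone path; assumes `args` points at a model and
+ a dataset, as in the `DetectionValidator` docstring below):
+ >>> validator = DetectionValidator(args=dict(model="yolov8n.pt", data="coco8.yaml"))
+ >>> stats = validator() # AutoBackend path; returns the metrics dict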
+ """
+ self.training = trainer is not None
+ augment = self.args.augment and (not self.training)
+ if self.training:
+ self.device = trainer.device
+ self.data = trainer.data
+ self.args.half = self.device.type != "cpu" # force FP16 val during training
+ model = trainer.ema.ema or trainer.model
+ model = model.half() if self.args.half else model.float()
+ # self.model = model
+ self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
+ self.args.plots &= trainer.stopper.possible_stop or (
+ trainer.epoch == trainer.epochs - 1
+ )
+ model.eval()
+ else:
+ callbacks.add_integration_callbacks(self)
+ model = AutoBackend(
+ model or self.args.model,
+ device=select_device(self.args.device, self.args.batch),
+ dnn=self.args.dnn,
+ data=self.args.data,
+ fp16=self.args.half,
+ fuse=False,
+ )
+ # self.model = model
+ self.device = model.device # update device
+ self.args.half = model.fp16 # update half
+ stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
+ imgsz = check_imgsz(self.args.imgsz, stride=stride)
+ if engine:
+ self.args.batch = model.batch_size
+ elif not pt and not jit:
+ self.args.batch = 1 # export.py models default to batch-size 1
+ LOGGER.info(
+ f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) \
+ for non-PyTorch models"
+ )
+
+ if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
+ self.data = check_det_dataset(self.args.data)
+ elif self.args.task == "classify":
+ self.data = check_cls_dataset(self.args.data, split=self.args.split)
+ else:
+ raise FileNotFoundError(
+ emojis(
+ f"Dataset '{self.args.data}' for task={self.args.task} not \
+ found ❌"
+ )
+ )
+
+ if self.device.type in ("cpu", "mps"):
+ self.args.workers = (
+ 0 # faster CPU val as time dominated by inference, not dataloading
+ )
+ if not pt:
+ self.args.rect = False
+ self.stride = model.stride # used in get_dataloader() for padding
+ self.dataloader = self.dataloader or self.get_dataloader(
+ self.data.get(self.args.split), self.args.batch
+ )
+
+ model.eval()
+ model.warmup(
+ imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)
+ ) # warmup
+
+ self.run_callbacks("on_val_start")
+ dt = (
+ Profile(),
+ Profile(),
+ Profile(),
+ Profile(),
+ )
+ bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
+ self.init_metrics(de_parallel(model))
+ self.jdict = [] # empty before each val
+ for batch_i, batch in enumerate(bar):
+ self.run_callbacks("on_val_batch_start")
+ self.batch_i = batch_i
+ # Preprocess
+ with dt[0]:
+ batch = self.preprocess(batch)
+
+ # Inference
+ with dt[1]:
+ preds = model(batch["img"], augment=augment)
+
+ # Loss
+ with dt[2]:
+ if self.training:
+ self.loss += model.loss(batch, preds)[1]
+
+ # Postprocess
+ with dt[3]:
+ preds = self.postprocess(preds)
+
+ self.update_metrics(preds, batch)
+ if self.args.plots and batch_i < 3:
+ self.plot_val_samples(batch, batch_i)
+ self.plot_predictions(batch, preds, batch_i)
+
+ self.run_callbacks("on_val_batch_end")
+ stats = self.get_stats()
+ self.check_stats(stats)
+ self.speed = dict(
+ zip(
+ self.speed.keys(),
+ (x.t / len(self.dataloader.dataset) * 1e3 for x in dt),
+ )
+ )
+ self.finalize_metrics()
+ self.print_results()
+ self.run_callbacks("on_val_end")
+ if self.training:
+ model.float()
+ results = {
+ **stats,
+ **trainer.label_loss_items(
+ self.loss.cpu() / len(self.dataloader), prefix="val"
+ ),
+ }
+ return {
+ k: round(float(v), 5) for k, v in results.items()
+ } # return results as 5 decimal place floats
+ else:
+ LOGGER.info(
+ "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms \
+ postprocess per image"
+ % tuple(self.speed.values())
+ )
+ if self.args.save_json and self.jdict:
+ with open(str(self.save_dir / "predictions.json"), "w") as f:
+ LOGGER.info(f"Saving {f.name}...")
+ json.dump(self.jdict, f) # flatten and save
+ stats = self.eval_json(stats) # update stats
+ if self.args.plots or self.args.save_json:
+ LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
+ return stats
+
+ def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
+ """
+ Matches predictions to ground truth objects (pred_classes, true_classes)
+ using IoU.
+
+ Args:
+ pred_classes (torch.Tensor): Predicted class indices of shape(N,).
+ true_classes (torch.Tensor): Target class indices of shape(M,).
+ iou (torch.Tensor): An MxN tensor of pairwise IoU values between the M
+ ground-truth boxes (rows) and the N predictions (columns).
+ use_scipy (bool): Whether to use scipy for matching (more precise).
+
+ Returns:
+ (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
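+
+ Example (illustrative; mirrors how `_process_batch` calls this method):
+ >>> iou = box_iou(gt_bboxes, detections[:, :4]) # (M, N)
+ >>> correct = self.match_predictions(detections[:, 5], gt_cls, iou)
+ >>> correct.shape # torch.Size([N, 10])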
+ """
+ # Dx10 matrix, where D - detections, 10 - IoU thresholds
+ correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
+ # LxD matrix where L - labels (rows), D - detections (columns)
+ correct_class = true_classes[:, None] == pred_classes
+ iou = iou * correct_class # zero out the wrong classes
+ iou = iou.cpu().numpy()
+ for i, threshold in enumerate(self.iouv.cpu().tolist()):
+ if use_scipy:
+ import scipy # scope import to avoid importing for all commands
+
+ cost_matrix = iou * (iou >= threshold)
+ if cost_matrix.any():
+ labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(
+ cost_matrix, maximize=True
+ )
+ valid = cost_matrix[labels_idx, detections_idx] > 0
+ if valid.any():
+ correct[detections_idx[valid], i] = True
+ else:
+ matches = np.nonzero(
+ iou >= threshold
+ ) # IoU > threshold and classes match
+ matches = np.array(matches).T
+ if matches.shape[0]:
+ if matches.shape[0] > 1:
+ matches = matches[
+ iou[matches[:, 0], matches[:, 1]].argsort()[::-1]
+ ]
+ matches = matches[
+ np.unique(matches[:, 1], return_index=True)[1]
+ ]
+ # matches = matches[matches[:, 2].argsort()[::-1]]
+ matches = matches[
+ np.unique(matches[:, 0], return_index=True)[1]
+ ]
+ correct[matches[:, 1].astype(int), i] = True
+ return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
+
+ def add_callback(self, event: str, callback):
+ """Appends the given callback."""
+ self.callbacks[event].append(callback)
+
+ def run_callbacks(self, event: str):
+ """Runs all callbacks associated with a specified event."""
+ for callback in self.callbacks.get(event, []):
+ callback(self)
+
+ def get_dataloader(self, dataset_path, batch_size):
+ """Get data loader from dataset path and batch size."""
+ raise NotImplementedError(
+ "get_dataloader function not implemented for this validator"
+ )
+
+ def build_dataset(self, img_path):
+ """Build dataset."""
+ raise NotImplementedError("build_dataset function not implemented in validator")
+
+ def preprocess(self, batch):
+ """Preprocesses an input batch."""
+ return batch
+
+ def postprocess(self, preds):
+ """Describes and summarizes the purpose of 'postprocess()' but no
+ details mentioned."""
+ return preds
+
+ def init_metrics(self, model):
+ """Initialize performance metrics for the YOLO model."""
+ pass
+
+ def update_metrics(self, preds, batch):
+ """Updates metrics based on predictions and batch."""
+ pass
+
+ def finalize_metrics(self, *args, **kwargs):
+ """Finalizes and returns all metrics."""
+ pass
+
+ def get_stats(self):
+ """Returns statistics about the model's performance."""
+ return {}
+
+ def check_stats(self, stats):
+ """Checks statistics."""
+ pass
+
+ def print_results(self):
+ """Prints the results of the model's predictions."""
+ pass
+
+ def get_desc(self):
+ """Get description of the YOLO model."""
+ pass
+
+ @property
+ def metric_keys(self):
+ """Returns the metric keys used in YOLO training/validation."""
+ return []
+
+ def on_plot(self, name, data=None):
+ """Registers plots (e.g. to be consumed in callbacks)"""
+ self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
+
+ # TODO: may need to move the following functions into callbacks
+ def plot_val_samples(self, batch, ni):
+ """Plots validation samples during training."""
+ pass
+
+ def plot_predictions(self, batch, preds, ni):
+ """Plots YOLO model predictions on batch images."""
+ pass
+
+ def pred_to_json(self, preds, batch):
+ """Convert predictions to JSON format."""
+ pass
+
+ def eval_json(self, stats):
+ """Evaluate and return JSON format of prediction statistics."""
+ pass
+
+
+class DetectionValidator(BaseValidator):
+ """
+ A class extending the BaseValidator class for validation based on a detection model.
+
+ Example:
+ ```python
+ from ultralytics.models.yolo.detect import DetectionValidator
+
+ args = dict(model='yolov8n.pt', data='coco8.yaml')
+ validator = DetectionValidator(args=args)
+ validator()
+ ```
+ """
+
+ def __init__(
+ self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None
+ ):
+ """Initialize detection model with necessary variables and settings."""
+ super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+ self.nt_per_class = None
+ self.is_coco = False
+ self.class_map = None
+ self.args.task = "detect"
+ self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+ self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
+ self.niou = self.iouv.numel()
+ self.lb = [] # for autolabelling
+
+ def preprocess(self, batch):
+ """Preprocesses batch of images for YOLO training."""
+ batch["img"] = batch["img"].to(self.device, non_blocking=True)
+ batch["img"] = (
+ batch["img"].half() if self.args.half else batch["img"].float()
+ ) / 255
+ for k in ["batch_idx", "cls", "bboxes"]:
+ batch[k] = batch[k].to(self.device)
+
+ if self.args.save_hybrid:
+ height, width = batch["img"].shape[2:]
+ nb = len(batch["img"])
+ bboxes = batch["bboxes"] * torch.tensor(
+ (width, height, width, height), device=self.device
+ )
+ self.lb = (
+ [
+ torch.cat(
+ [
+ batch["cls"][batch["batch_idx"] == i],
+ bboxes[batch["batch_idx"] == i],
+ ],
+ dim=-1,
+ )
+ for i in range(nb)
+ ]
+ if self.args.save_hybrid
+ else []
+ ) # for autolabelling
+
+ return batch
+
+ def init_metrics(self, model):
+ """Initialize evaluation metrics for YOLO."""
+ val = self.data.get(self.args.split, "") # validation path
+ self.is_coco = (
+ isinstance(val, str)
+ and "coco" in val
+ and val.endswith(f"{os.sep}val2017.txt")
+ ) # is COCO
+ self.class_map = (
+ converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
+ )
+ self.args.save_json |= (
+ self.is_coco and not self.training
+ ) # run on final val if training COCO
+ self.names = model.names
+ self.nc = len(model.names)
+ self.metrics.names = self.names
+ self.metrics.plot = self.args.plots
+ self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
+ self.seen = 0
+ self.jdict = []
+ self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])
+
+ def get_desc(self):
+ """Return a formatted string summarizing class metrics of YOLO model."""
+ return ("%22s" + "%11s" * 6) % (
+ "Class",
+ "Images",
+ "Instances",
+ "Box(P",
+ "R",
+ "mAP50",
+ "mAP50-95)",
+ )
+
+ def postprocess(self, preds):
+ """Apply Non-maximum suppression to prediction outputs."""
+ return ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ labels=self.lb,
+ multi_label=True,
+ agnostic=self.args.single_cls,
+ max_det=self.args.max_det,
+ )
+
+ def _prepare_batch(self, si, batch):
+ """Prepares a batch of images and annotations for validation."""
+ idx = batch["batch_idx"] == si
+ cls = batch["cls"][idx].squeeze(-1)
+ bbox = batch["bboxes"][idx]
+ ori_shape = batch["ori_shape"][si]
+ imgsz = batch["img"].shape[2:]
+ ratio_pad = batch["ratio_pad"][si]
+ if len(cls):
+ bbox = (
+ ops.xywh2xyxy(bbox)
+ * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]
+ ) # target boxes
+ ops.scale_boxes(
+ imgsz, bbox, ori_shape, ratio_pad=ratio_pad
+ ) # native-space labels
+ return dict(
+ cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad
+ )
+
+ def _prepare_pred(self, pred, pbatch):
+ """Prepares a batch of images and annotations for validation."""
+ predn = pred.clone()
+ ops.scale_boxes(
+ pbatch["imgsz"],
+ predn[:, :4],
+ pbatch["ori_shape"],
+ ratio_pad=pbatch["ratio_pad"],
+ ) # native-space pred
+ return predn
+
+ def update_metrics(self, preds, batch):
+ """Metrics."""
+ for si, pred in enumerate(preds):
+ self.seen += 1
+ npr = len(pred)
+ stat = dict(
+ conf=torch.zeros(0, device=self.device),
+ pred_cls=torch.zeros(0, device=self.device),
+ tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ )
+ pbatch = self._prepare_batch(si, batch)
+ cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
+ nl = len(cls)
+ stat["target_cls"] = cls
+ if npr == 0:
+ if nl:
+ for k in self.stats.keys():
+ self.stats[k].append(stat[k])
+ # TODO: obb has not supported confusion_matrix yet.
+ # if self.args.plots and self.args.task != "obb":
+ # self.confusion_matrix.process_batch(
+ # detections=None, gt_bboxes=bbox, gt_cls=cls
+ # )
+ continue
+
+ # Predictions
+ if self.args.single_cls:
+ pred[:, 5] = 0
+ predn = self._prepare_pred(pred, pbatch)
+ stat["conf"] = predn[:, 4]
+ stat["pred_cls"] = predn[:, 5]
+
+ # Evaluate
+ if nl:
+ stat["tp"] = self._process_batch(predn, bbox, cls)
+ # TODO: obb has not supported confusion_matrix yet.
+ # if self.args.plots and self.args.task != "obb":
+ # self.confusion_matrix.process_batch(predn, bbox, cls)
+ for k in self.stats.keys():
+ self.stats[k].append(stat[k])
+
+ # Save
+ if self.args.save_json:
+ self.pred_to_json(predn, batch["im_file"][si])
+ if self.args.save_txt:
+ file = (
+ self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
+ )
+ self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)
+
+ def finalize_metrics(self, *args, **kwargs):
+ """Set final values for metrics speed and confusion matrix."""
+ self.metrics.speed = self.speed
+ self.metrics.confusion_matrix = self.confusion_matrix
+
+ def get_stats(self):
+ """Returns metrics statistics and results dictionary."""
+ stats = {
+ k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()
+ } # to numpy
+ if len(stats) and stats["tp"].any():
+ self.metrics.process(**stats)
+ self.nt_per_class = np.bincount(
+ stats["target_cls"].astype(int), minlength=self.nc
+ ) # number of targets per class
+ return self.metrics.results_dict
+
+ def print_results(self):
+ """Prints training/validation set metrics per class."""
+ pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
+ LOGGER.info(
+ pf
+ % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())
+ )
+ if self.nt_per_class.sum() == 0:
+ LOGGER.warning(
+ f"WARNING ⚠️ no labels found in {self.args.task} set, can not \
+ compute metrics without labels"
+ )
+
+ # Print results per class
+ if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
+ for i, c in enumerate(self.metrics.ap_class_index):
+ LOGGER.info(
+ pf
+ % (
+ self.names[c],
+ self.seen,
+ self.nt_per_class[c],
+ *self.metrics.class_result(i),
+ )
+ )
+
+ if self.args.plots:
+ for normalize in True, False:
+ self.confusion_matrix.plot(
+ save_dir=self.save_dir,
+ names=self.names.values(),
+ normalize=normalize,
+ on_plot=self.on_plot,
+ )
+
+ def _process_batch(self, detections, gt_bboxes, gt_cls):
+ """
+ Return correct prediction matrix.
+
+ Args:
+ detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
+ Each detection is of the format: x1, y1, x2, y2, conf, class.
+ gt_bboxes (torch.Tensor): Tensor of shape [M, 4] of ground-truth boxes in
+ x1, y1, x2, y2 format.
+ gt_cls (torch.Tensor): Tensor of shape [M] of ground-truth class indices.
+
+ Returns:
+ (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10
+ IoU levels.
+ """
+ iou = box_iou(gt_bboxes, detections[:, :4])
+ return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+ def build_dataset(self, img_path, mode="val", batch=None):
+ """
+ Build YOLO Dataset.
+
+ Args:
+ img_path (str): Path to the folder containing images.
+ mode (str): `train` mode or `val` mode, users are able to customize
+ different augmentations for each mode.
+ batch (int, optional): Size of batches, this is for `rect`.
+ Defaults to None.
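+
+ Example (illustrative; the split path comes from the dataset YAML):
+ >>> dataset = self.build_dataset(self.data.get(self.args.split), mode="val", batch=self.args.batch)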
+ """
+ return build_yolo_dataset(
+ self.args, img_path, batch, self.data, mode=mode, stride=self.stride
+ )
+
+ def get_dataloader(self, dataset_path, batch_size):
+ """Construct and return dataloader."""
+ dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
+ return build_dataloader(
+ dataset, batch_size, self.args.workers, shuffle=False, rank=-1
+ ) # return dataloader
+
+ def plot_val_samples(self, batch, ni):
+ """Plot validation image samples."""
+ plot_images(
+ batch["img"],
+ batch["batch_idx"],
+ batch["cls"].squeeze(-1),
+ batch["bboxes"],
+ paths=batch["im_file"],
+ fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+ names=self.names,
+ on_plot=self.on_plot,
+ )
+
+ def plot_predictions(self, batch, preds, ni):
+ """Plots predicted bounding boxes on input images and saves the result."""
+ plot_images(
+ batch["img"],
+ *output_to_target(preds, max_det=self.args.max_det),
+ paths=batch["im_file"],
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+ names=self.names,
+ on_plot=self.on_plot,
+ ) # pred
+
+ def save_one_txt(self, predn, save_conf, shape, file):
+ """Save YOLO detections to a txt file in normalized coordinates in a
+ specific format."""
+ gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
+ for *xyxy, conf, cls in predn.tolist():
+ xywh = (
+ (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
+ ) # normalized xywh
+ line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
+ with open(file, "a") as f:
+ f.write(("%g " * len(line)).rstrip() % line + "\n")
+
+ def pred_to_json(self, predn, filename):
+ """Serialize YOLO predictions to COCO json format."""
+ stem = Path(filename).stem
+ image_id = int(stem) if stem.isnumeric() else stem
+ box = ops.xyxy2xywh(predn[:, :4]) # xywh
+ box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
+ for p, b in zip(predn.tolist(), box.tolist()):
+ self.jdict.append(
+ {
+ "image_id": image_id,
+ "category_id": self.class_map[int(p[5])],
+ "bbox": [round(x, 3) for x in b],
+ "score": round(p[4], 5),
+ }
+ )
+
+ def eval_json(self, stats):
+ """Evaluates YOLO output in JSON format and returns performance statistics."""
+ if self.args.save_json and self.is_coco and len(self.jdict):
+ anno_json = (
+ self.data["path"] / "annotations/instances_val2017.json"
+ ) # annotations
+ pred_json = self.save_dir / "predictions.json" # predictions
+ LOGGER.info(
+ f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}..."
+ )
+ try:
+ check_requirements("pycocotools>=2.0.6")
+ from pycocotools.coco import COCO # noqa
+ from pycocotools.cocoeval import COCOeval # noqa
+
+ for x in anno_json, pred_json:
+ assert x.is_file(), f"{x} file not found"
+ anno = COCO(str(anno_json)) # init annotations api
+ pred = anno.loadRes(
+ str(pred_json)
+ ) # init predictions api (must pass string, not Path)
+ eval = COCOeval(anno, pred, "bbox")
+ if self.is_coco:
+ eval.params.imgIds = [
+ int(Path(x).stem) for x in self.dataloader.dataset.im_files
+ ] # images to eval
+ eval.evaluate()
+ eval.accumulate()
+ eval.summarize()
+ stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[
+ :2
+ ] # update mAP50-95 and mAP50
+ except Exception as e:
+ LOGGER.warning(f"pycocotools unable to run: {e}")
+ return stats
diff --git a/recipes/object_detection/yolo_loss.py b/recipes/object_detection/yolo_loss.py
index 29650a6..4810468 100644
--- a/recipes/object_detection/yolo_loss.py
+++ b/recipes/object_detection/yolo_loss.py
@@ -7,6 +7,7 @@
- Matteo Beltrami, 2023
- Francesco Paissan, 2023
"""
+
import torch
import torch.nn as nn
from ultralytics.utils.loss import BboxLoss, v8DetectionLoss
@@ -76,8 +77,11 @@ def __call__(self, preds, batch):
[xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2
).split((self.reg_max * 4, self.nc), 1)
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+ # print("pred scores shape ", pred_scores.shape) # x, 8400, 80
+ # print("pred distri shape ", pred_distri.shape) # x, 8400, 64 (reg_max * 4)
dtype = pred_scores.dtype
batch_size = pred_scores.shape[0]
@@ -85,6 +89,7 @@ def __call__(self, preds, batch):
torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype)
* self.stride[0]
) # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
@@ -95,21 +100,29 @@ def __call__(self, preds, batch):
targets = self.preprocess(
targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]
)
+
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
+ # print("pred bboxes")
+ # print(pred_bboxes.shape)
+ # print(pred_bboxes[0, 0])
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
- pred_scores.detach().sigmoid(),
- (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+ pred_scores.clone().detach().sigmoid(),
+ (pred_bboxes.clone().detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
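+ # Note: targets returned by the assigner are in input-image pixels here
+ # (the decoded boxes were passed scaled by stride); they are divided by
+ # stride_tensor again before computing the box loss below.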
+ # print("target bboxes")
+ # print(target_bboxes.shape)
+ # print(target_bboxes[0, 0])
+
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
@@ -117,9 +130,14 @@ def __call__(self, preds, batch):
self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum
) # BCE
+ # print("classification loss", loss[1])
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
+ # print("target bboxes w norm")
+ # print(target_bboxes.shape)
+ # print(target_bboxes[0, 0])
+
loss[0], loss[2] = self.bbox_loss(
pred_distri,
pred_bboxes,
@@ -134,4 +152,21 @@ def __call__(self, preds, batch):
loss[1] *= self.hyp.cls # cls gain
loss[2] *= self.hyp.dfl # dfl gain
+ # with open("dump.txt", "a") as f:
+ # f.write("bbox loss {}".format(loss[0].item()))
+ # f.write("\n")
+ # f.write("cls loss {}".format(loss[1].item()))
+ # f.write("\n")
+ # f.write("dfl loss {}".format(loss[2].item()))
+ # f.write("\n")
+ # f.write("total {}".format(loss.sum().item()))
+ # f.write("\n")
+ # f.write("total * batch_size {}".format(loss.sum().item() * batch_size))
+ # f.write("\n")
+ #
+ # breakpoint()
+
+ # print(torch.std_mean(batch["img"]))
+ # breakpoint()
+
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)