diff --git a/micromind/core.py b/micromind/core.py index 835b1ec..35b7f5a 100644 --- a/micromind/core.py +++ b/micromind/core.py @@ -5,6 +5,7 @@ Authors: - Francesco Paissan, 2023 """ + from abc import ABC, abstractmethod from argparse import Namespace from dataclasses import dataclass @@ -331,7 +332,7 @@ def add_forward_to_modules(self): self.modules.device = self.device @torch.no_grad() - def compute_params(self): + def compute_params(self, str="total"): """Computes the number of parameters for the modules inside `self.modules`. Returns a dictionary with the parameter count for each module. @@ -341,8 +342,12 @@ def compute_params(self): """ self.eval() params = {} - for k, m in self.modules.items(): - params[k] = summary(m, verbose=0).total_params + if str == "total": + for k, m in self.modules.items(): + params[k] = summary(m, verbose=0).total_params + if str == "trainable": + for k, m in self.modules.items(): + params[k] = summary(m, verbose=0).trainable_params return params @@ -451,6 +456,10 @@ def on_train_end(self): """Runs at the end of each training. Cleans up before exiting.""" pass + def on_train_epoch_end(self): + """Runs at the end of each training epoch. Cleans up before exiting.""" + pass + def eval(self): self.modules.eval() @@ -460,6 +469,7 @@ def train( datasets: Dict = {}, metrics: List[Metric] = [], checkpointer: Optional[Checkpointer] = None, + max_norm=10.0, debug: Optional[bool] = False, ) -> None: """ @@ -525,12 +535,12 @@ def train( loss_epoch += loss.item() self.accelerator.backward(loss) + self.accelerator.clip_grad_norm_( + self.modules.parameters(), max_norm=max_norm + ) self.opt.step() loss_epoch += loss.item() - if hasattr(self, "lr_sched"): - # ok for cos_lr - self.lr_sched.step() for m in self.metrics: if ( @@ -563,21 +573,29 @@ def train( if "val" in datasets: val_metrics = self.validate() - if ( - self.accelerator.is_local_main_process - and self.checkpointer is not None - ): - self.checkpointer( - self, - train_metrics, - val_metrics, - ) else: - val_metrics = train_metrics.update({"val_loss": loss_epoch / (idx + 1)}) + train_metrics.update({"val_loss": loss_epoch / (idx + 1)}) + val_metrics = train_metrics + + self.on_train_epoch_end() + + if self.accelerator.is_local_main_process and self.checkpointer is not None: + self.checkpointer( + self, + train_metrics, + val_metrics, + ) if e >= 1 and self.debug: break + if hasattr(self, "lr_sched"): + # ok for cos_lr + # self.lr_sched.step(val_metrics["val_loss"]) + + self.lr_sched.step() + print(f"sched step - new LR={self.lr_sched.get_lr()}") + self.on_train_end() return None diff --git a/micromind/networks/yolo.py b/micromind/networks/yolo.py index f259f65..fbb8c51 100644 --- a/micromind/networks/yolo.py +++ b/micromind/networks/yolo.py @@ -464,6 +464,10 @@ def __init__( self.heads = heads self.up1 = Upsample(up[0], mode="nearest") self.up2 = Upsample(up[1], mode="nearest") + + # print(filters, heads) + # breakpoint() + self.n1 = XiConv( c_in=int(filters[1] + filters[2]), c_out=int(filters[1]), @@ -471,6 +475,7 @@ def __init__( gamma=3, skip_tensor_in=False, ) + self.n2 = XiConv( int(filters[0] + filters[1]), int(filters[0]), @@ -483,6 +488,10 @@ def __init__( the needed blocks. Otherwise the not needed blocks would be initialized (and thus would occupy space) but will never be used. 
""" + self.n3 = None + self.n4 = None + self.n5 = None + self.n6 = None if self.heads[1] or self.heads[2]: self.n3 = XiConv( int(filters[0]), @@ -519,6 +528,75 @@ def __init__( ) +class Yolov8NeckOpt_gamma2(Yolov8Neck): + def __init__( + self, filters=[256, 512, 768], up=[2, 2], heads=[True, True, True], d=1 + ): + super().__init__() + self.heads = heads + self.up1 = Upsample(up[0], mode="nearest") + self.up2 = Upsample(up[1], mode="nearest") + + self.n1 = XiConv( + c_in=int(filters[1] + filters[2]), + c_out=int(filters[1]), + kernel_size=3, + gamma=2, + skip_tensor_in=False, + ) + + self.n2 = XiConv( + int(filters[0] + filters[1]), + int(filters[0]), + kernel_size=3, + gamma=2, + skip_tensor_in=False, + ) + """ + Only if we decide to use the 2nd and 3rd detection head we define + the needed blocks. Otherwise the not needed blocks would be initialized + (and thus would occupy space) but will never be used. + """ + self.n3 = None + self.n4 = None + self.n5 = None + self.n6 = None + if self.heads[1] or self.heads[2]: + self.n3 = XiConv( + int(filters[0]), + int(filters[0]), + kernel_size=3, + gamma=2, + stride=2, + padding=1, + skip_tensor_in=False, + ) + self.n4 = XiConv( + int(filters[0] + filters[1]), + int(filters[1]), + kernel_size=3, + gamma=2, + skip_tensor_in=False, + ) + if self.heads[2]: + self.n5 = XiConv( + int(filters[1]), + int(filters[1]), + gamma=2, + kernel_size=3, + stride=2, + padding=1, + skip_tensor_in=False, + ) + self.n6 = XiConv( + int(filters[1] + filters[2]), + int(filters[2]), + gamma=2, + kernel_size=3, + skip_tensor_in=False, + ) + + class DetectionHead(nn.Module): """Implements YOLOv8's detection head. @@ -537,6 +615,7 @@ def __init__(self, nc=80, filters=(), heads=[True, True, True]): super().__init__() self.reg_max = 16 self.nc = nc + # filters = [f for f, h in zip(filters, heads) if h] self.nl = len(filters) self.no = nc + self.reg_max * 4 self.stride = torch.tensor([8.0, 16.0, 32.0], dtype=torch.float16) @@ -615,14 +694,16 @@ class YOLOv8(nn.Module): Number of classes to predict. """ - def __init__(self, w, r, d, num_classes=80): + def __init__(self, w, r, d, num_classes=80, heads=[True, True, True]): super().__init__() self.net = Darknet(w, r, d) self.fpn = Yolov8Neck( - filters=[int(256 * w), int(512 * w), int(512 * w * r)], d=d + filters=[int(256 * w), int(512 * w), int(512 * w * r)], heads=heads, d=d ) self.head = DetectionHead( - num_classes, filters=(int(256 * w), int(512 * w), int(512 * w * r)) + num_classes, + filters=(int(256 * w), int(512 * w), int(512 * w * r)), + heads=heads, ) def forward(self, x): diff --git a/recipes/object_detection/README.md b/recipes/object_detection/README.md index 8e20e9c..2a4ebd9 100644 --- a/recipes/object_detection/README.md +++ b/recipes/object_detection/README.md @@ -1,5 +1,6 @@ ## Object Detection using YOLO +**[16 Jan 2024]** Updated training code for better performance. Added ultralytics metrics calculation .
**[16 Jan 2024]** Added optimized YOLO neck using XiConv. Fixed compatibility with ultralytics weights.
**[17 Dec 2023]** Added VOC dataset, selective head option, and instructions for dataset download.
**[1 Dec 2023]** Fix DDP handling and computational graph. diff --git a/recipes/object_detection/cfg/data/VOC.yaml b/recipes/object_detection/cfg/data/VOC.yaml index 3b6e115..5249510 100644 --- a/recipes/object_detection/cfg/data/VOC.yaml +++ b/recipes/object_detection/cfg/data/VOC.yaml @@ -43,7 +43,7 @@ mixup: 0.0 # (float) image mixup (probability) copy_paste: 0.0 # (float) segment copy-paste (probability) # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: ../datasets/VOC +path: datasets/VOC train: # train images (relative to 'path') 16551 images - images/train2012 - images/train2007 diff --git a/recipes/object_detection/cfg/data/coco.yaml b/recipes/object_detection/cfg/data/coco.yaml index c19f3f9..aff1da1 100644 --- a/recipes/object_detection/cfg/data/coco.yaml +++ b/recipes/object_detection/cfg/data/coco.yaml @@ -44,7 +44,7 @@ copy_paste: 0.0 # (float) segment copy-paste (probability) # Dataset location -path: /mnt/data/coco # dataset root dir +path: datasets/coco # dataset root dir train: train2017.txt # train images (relative to 'path') 118287 images val: val2017.txt # val images (relative to 'path') 5000 images test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794 diff --git a/recipes/object_detection/cfg/data/coco8.yaml b/recipes/object_detection/cfg/data/coco8.yaml index 3e01c8c..596f837 100644 --- a/recipes/object_detection/cfg/data/coco8.yaml +++ b/recipes/object_detection/cfg/data/coco8.yaml @@ -44,7 +44,7 @@ copy_paste: 0.0 # (float) segment copy-paste (probability) # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] -path: /mnt/data/coco8 # dataset root dir +path: datasets/coco8 # dataset root dir train: images/train # train images (relative to 'path') 4 images val: images/val # val images (relative to 'path') 4 images test: # test images (optional) diff --git a/recipes/object_detection/cfg/yolo_phinet.py b/recipes/object_detection/cfg/yolo_phinet.py index 11d25d6..7dbf0a5 100644 --- a/recipes/object_detection/cfg/yolo_phinet.py +++ b/recipes/object_detection/cfg/yolo_phinet.py @@ -5,16 +5,18 @@ - Matteo Beltrami, 2023 - Francesco Paissan, 2023 """ + # Data configuration batch_size = 8 -data_cfg = "cfg/data/coco.yaml" -data_dir = "data/coco" -epochs = 200 +data_cfg = "cfg/data/VOC.yaml" +data_dir = "datasets/coco" +epochs = 350 +num_classes = 80 # Model configuration input_shape = [3, 640, 640] -alpha = 2.3 -num_layers = 7 +alpha = 1.1 +num_layers = 8 beta = 0.75 t_zero = 5 divisor = 8 diff --git a/recipes/object_detection/inference.py b/recipes/object_detection/inference.py index 9399c1f..5ebff90 100644 --- a/recipes/object_detection/inference.py +++ b/recipes/object_detection/inference.py @@ -27,25 +27,16 @@ preprocess, ) from train import YOLO +from micromind.utils.yolo import load_config class Inference(YOLO): - def __init__(self, hparams): - super().__init__(hparams=hparams, m_cfg={}) - - def forward(self, img): - """Executes the detection network. - - Arguments - --------- - bacth : List[torch.Tensor] - Input to the detection network. 
- - Returns - ------- - Output of the detection network : torch.Tensor - """ - backbone = self.modules["backbone"](img) + def __init__(self, m_cfg, hparams): + super().__init__(m_cfg, hparams=hparams) + + def forward(self, batch): + """Runs the forward method by calling every module.""" + backbone = self.modules["backbone"](batch) neck_input = backbone[1] neck_input.append(self.modules["sppf"](backbone[0])) neck = self.modules["neck"](*neck_input) @@ -73,6 +64,8 @@ def forward(self, img): img_paths = [sys.argv[2]] for img_path in img_paths: image = torchvision.io.read_image(img_path) + if image.shape[0] == 4: + image = image[:3, :, :] # Mantieni solo i primi 3 canali (RGB) out_paths = [ ( output_folder_path @@ -85,7 +78,8 @@ def forward(self, img): pre_processed_image = preprocess(image) - model = Inference(hparams) + m_cfg, data_cfg = load_config(hparams.data_cfg) + model = Inference(m_cfg, hparams=hparams) # Load pretrained if passed. if hparams.ckpt_pretrained != "": model.load_modules(hparams.ckpt_pretrained) @@ -97,11 +91,13 @@ def forward(self, img): with torch.no_grad(): st = time.time() - predictions = model(pre_processed_image) + predictions = model.forward(pre_processed_image) print(f"Inference took {int(round(((time.time() - st) * 1000)))}ms") + breakpoint() post_predictions = postprocess( preds=predictions[0], img=pre_processed_image, orig_imgs=image ) + breakpoint() class_labels = [s.strip() for s in open(hparams.coco_names, "r").readlines()] draw_bounding_boxes_and_save( @@ -112,4 +108,3 @@ def forward(self, img): ) # Exporting onnx model. - # model.export("model.onnx", "onnx", hparams.input_shape) diff --git a/recipes/object_detection/prepare_data.py b/recipes/object_detection/prepare_data.py index bb3b570..b4e4e6a 100644 --- a/recipes/object_detection/prepare_data.py +++ b/recipes/object_detection/prepare_data.py @@ -6,6 +6,7 @@ - Matteo Beltrami, 2023 - Francesco Paissan, 2023 """ + from typing import Dict import os diff --git a/recipes/object_detection/train.py b/recipes/object_detection/train.py index dfc6f91..7f36eb7 100644 --- a/recipes/object_detection/train.py +++ b/recipes/object_detection/train.py @@ -12,38 +12,43 @@ """ import torch +import torch.nn as nn +import torch.optim as optim from prepare_data import create_loaders -from ultralytics.utils.ops import scale_boxes, xywh2xyxy from yolo_loss import Loss +import math import micromind as mm from micromind.networks import PhiNet -from micromind.networks.yolo import SPPF, DetectionHead, Yolov8Neck, Yolov8NeckOpt +from micromind.networks.yolo import SPPF, Yolov8Neck, DetectionHead from micromind.utils import parse_configuration -from micromind.utils.yolo import ( - load_config, - mean_average_precision, - postprocess, -) +from micromind.utils.yolo import load_config import sys import os +from micromind.utils.yolo import get_variant_multiples +from validation.validator import DetectionValidator class YOLO(mm.MicroMind): def __init__(self, m_cfg, hparams, *args, **kwargs): """Initializes the YOLO model.""" super().__init__(*args, **kwargs) + self.m_cfg = m_cfg + w, r, d = get_variant_multiples("n") self.modules["backbone"] = PhiNet( input_shape=hparams.input_shape, alpha=hparams.alpha, - num_layers=hparams.num_layers, beta=hparams.beta, t_zero=hparams.t_zero, + num_layers=hparams.num_layers, + h_swish=False, + squeeze_excite=True, include_top=False, - compatibility=False, + num_classes=hparams.num_classes, divisor=hparams.divisor, + compatibility=False, downsampling_layers=hparams.downsampling_layers, 
return_layers=hparams.return_layers, ) @@ -53,11 +58,13 @@ def __init__(self, m_cfg, hparams, *args, **kwargs): ) self.modules["sppf"] = SPPF(*sppf_ch) - self.modules["neck"] = Yolov8NeckOpt( + self.modules["neck"] = Yolov8Neck( filters=neck_filters, up=up, heads=hparams.heads ) - self.modules["head"] = DetectionHead(filters=head_filters, heads=hparams.heads) + self.modules["head"] = DetectionHead( + hparams.num_classes, filters=head_filters, heads=hparams.heads + ) self.criterion = Loss(self.m_cfg, self.modules["head"], self.device) print("Number of parameters for each module:") @@ -131,8 +138,23 @@ def preprocess_batch(self, batch): def forward(self, batch): """Runs the forward method by calling every module.""" - preprocessed_batch = self.preprocess_batch(batch) - backbone = self.modules["backbone"](preprocessed_batch["img"].to(self.device)) + if self.modules.training: + preprocessed_batch = self.preprocess_batch(batch) + backbone = self.modules["backbone"]( + preprocessed_batch["img"].to(self.device) + ) + else: + + if torch.is_tensor(batch): + backbone = self.modules["backbone"](batch) + neck_input = backbone[1] + neck_input.append(self.modules["sppf"](backbone[0])) + neck = self.modules["neck"](*neck_input) + head = self.modules["head"](neck) + return head + + backbone = self.modules["backbone"](batch["img"] / 255) + neck_input = backbone[1] neck_input.append(self.modules["sppf"](backbone[0])) neck = self.modules["neck"](*neck_input) @@ -151,61 +173,159 @@ def compute_loss(self, pred, batch): return lossi_sum + def build_optimizer( + self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e6 + ): + """ + Constructs an optimizer for the given model, based on the specified optimizer + name, learning rate, momentum, weight decay, and number of iterations. + + Args: + model (torch.nn.Module): The model for which to build an optimizer. + name (str, optional): The name of the optimizer to use. If 'auto', the + optimizer is selected based on the number of iterations. + Default: 'auto'. + lr (float, optional): The learning rate for the optimizer. Default: 0.001. + momentum (float, optional): The momentum factor for the optimizer. + Default: 0.9. + decay (float, optional): The weight decay for the optimizer. Default: 1e-5. + iterations (float, optional): The number of iterations, which determines + the optimizer if name is 'auto'. Default: 1e5. + + Returns: + (torch.optim.Optimizer): The constructed optimizer. + """ + + g = [], [], [] # optimizer parameter groups + bn = tuple( + v for k, v in nn.__dict__.items() if "Norm" in k + ) # normalization layers, i.e. BatchNorm2d() + if name == "auto": + print( + f"optimizer: 'optimizer=auto' found, " + f"ignoring 'lr0={lr}' and 'momentum={momentum}' and " + f"determining best 'optimizer', 'lr0' and 'momentum' automatically... 
" + ) + nc = getattr(model, "nc", 80) # number of classes + lr_fit = round( + 0.002 * 5 / (4 + nc), 6 + ) # lr0 fit equation to 6 decimal places + name, lr, momentum = ("AdamW", lr_fit, 0.9) + lr *= 10 + # self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam + + for module_name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + fullname = f"{module_name}.{param_name}" if module_name else param_name + if "bias" in fullname: # bias (no decay) + g[2].append(param) + elif isinstance(module, bn): # weight (no decay) + g[1].append(param) + else: # weight (with decay) + g[0].append(param) + + if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"): + optimizer = getattr(optim, name, optim.Adam)( + g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0 + ) + elif name == "RMSProp": + optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum) + elif name == "SGD": + optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True) + else: + raise NotImplementedError( + f"Optimizer '{name}' not found in list of available optimizers " + f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]." + "To request support for addition optimizers please visit" + "https://github.com/ultralytics/ultralytics." + ) + + optimizer.add_param_group( + {"params": g[0], "weight_decay": decay} + ) # add g0 with weight_decay + optimizer.add_param_group( + {"params": g[1], "weight_decay": 0.0} + ) # add g1 (BatchNorm2d weights) + print( + f"{optimizer:} {type(optimizer).__name__}(lr={lr}, " + f"momentum={momentum}) with parameter groups" + f"{len(g[1])} weight(decay=0.0), {len(g[0])} " + f"weight(decay={decay}), {len(g[2])} bias(decay=0.0)" + ) + return optimizer, lr + + def _setup_scheduler(self, opt, lrf=0.01, lr0=0.01, cos_lr=True): + """Initialize training learning rate scheduler.""" + + def one_cycle(y1=0.0, y2=1.0, steps=100): + """Returns a lambda function for sinusoidal ramp from y1 to y2 + https://arxiv.org/pdf/1812.01187.pdf.""" + return ( + lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + + y1 + ) + + lrf *= lr0 + + if cos_lr: + self.lf = one_cycle(1, lrf, 350) # cosine 1->hyp['lrf'] + else: + self.lf = ( + lambda x: max(1 - x / self.epochs, 0) * (1.0 - lrf) + lrf + ) # linear + return optim.lr_scheduler.LambdaLR(opt, lr_lambda=self.lf) + def configure_optimizers(self): """Configures the optimizer and the scheduler.""" - opt = torch.optim.SGD(self.modules.parameters(), lr=1e-2, weight_decay=0.0005) - sched = torch.optim.lr_scheduler.CosineAnnealingLR( - opt, T_max=14000, eta_min=1e-3 - ) + # opt = torch.optim.SGD(self.modules.parameters(), lr=1e-2, weight_decay=0.0005) + # opt = torch.optim.AdamW( + # self.modules.parameters(), lr=0.000119, weight_decay=0.0 + # ) + opt, lr = self.build_optimizer(self.modules, name="auto", lr=0.01, momentum=0.9) + sched = self._setup_scheduler(opt, 0.01, lr) + return opt, sched @torch.no_grad() - def mAP(self, pred, batch): - """Compute the mean average precision (mAP) for a batch of predictions. - - Arguments - --------- - pred : torch.Tensor - Model predictions for the batch. - batch : dict - A dictionary containing batch information, including bounding boxes, - classes and shapes. - - Returns - ------- - torch.Tensor - A tensor containing the computed mean average precision (mAP) for the batch. 
+ def on_train_epoch_end(self): """ - preprocessed_batch = self.preprocess_batch(batch) - post_predictions = postprocess( - preds=pred[0], img=preprocessed_batch, orig_imgs=batch + Computes the mean average precision (mAP) at the end of the training epoch + and logs the metrics in `metrics.txt` inside the experiment folder. + The `verbose` argument if set to `True` prints details regarding the + number of images, instances and metrics for each class of the dataset. + The `plots` argument, if set to `True`, saves in the `runs/detect/train` + folder the plots of the confusion matrix, the F1-Confidence, + Precision-Confidence, Precision-Recall, Recall-Confidence curves and the + predictions and labels of the first three batches of images. + """ + args = dict( + model="yolov8n.pt", data=hparams.data_cfg, verbose=False, plots=False ) - - batch_bboxes_xyxy = xywh2xyxy(batch["bboxes"]) - dim = batch["resized_shape"][0][0] - batch_bboxes_xyxy[:, :4] *= dim - - batch_bboxes = [] - for i in range(len(batch["batch_idx"])): - for b in range(len(batch_bboxes_xyxy[batch["batch_idx"] == i, :])): - batch_bboxes.append( - scale_boxes( - batch["resized_shape"][i], - batch_bboxes_xyxy[batch["batch_idx"] == i, :][b], - batch["ori_shape"][i], - ) - ) - - batch_bboxes = torch.stack(batch_bboxes).to(self.device) - mmAP = mean_average_precision( - post_predictions, batch, batch_bboxes, data_cfg["nc"] + validator = DetectionValidator(args=args) + + validator(model=self) + + val_metrics = [ + validator.metrics.box.map * 100, + validator.metrics.box.map50 * 100, + validator.metrics.box.map75 * 100, + ] + metrics_file = os.path.join(exp_folder, "val_log.txt") + metrics_info = ( + f"Epoch {self.current_epoch}: " + f"mAP50-95(B): {round(val_metrics[0], 3)}%; " + f"mAP50(B): {round(val_metrics[1], 3)}%; " + f"mAP75(B): {round(val_metrics[2], 3)}%\n" ) - return torch.Tensor([mmAP]) + with open(metrics_file, "a") as file: + file.write(metrics_info) + return def replace_datafolder(hparams, data_cfg): """Replaces the data root folder, if told to do so from the configuration.""" + print(data_cfg["train"]) data_cfg["path"] = str(data_cfg["path"]) data_cfg["path"] = ( data_cfg["path"][:-1] if data_cfg["path"][-1] == "/" else data_cfg["path"] @@ -240,7 +360,7 @@ def replace_datafolder(hparams, data_cfg): m_cfg, data_cfg = load_config(hparams.data_cfg) # check if specified path for images is different, correct it in case - data_cfg = replace_datafolder(hparams, data_cfg) + # data_cfg = replace_datafolder(hparams, data_cfg) m_cfg.imgsz = hparams.input_shape[-1] # temp solution train_loader, val_loader = create_loaders(m_cfg, data_cfg, hparams.batch_size) @@ -255,12 +375,10 @@ def replace_datafolder(hparams, data_cfg): yolo_mind = YOLO(m_cfg, hparams=hparams) - mAP = mm.Metric("mAP", yolo_mind.mAP, eval_only=True, eval_period=1) - yolo_mind.train( epochs=hparams.epochs, datasets={"train": train_loader, "val": val_loader}, - metrics=[mAP], + metrics=[], checkpointer=checkpointer, debug=hparams.debug, ) diff --git a/recipes/object_detection/train_yolov8.py b/recipes/object_detection/train_yolov8.py new file mode 100644 index 0000000..58905f7 --- /dev/null +++ b/recipes/object_detection/train_yolov8.py @@ -0,0 +1,317 @@ +""" +YOLO training. + +This code allows you to train an object detection model with the YOLOv8 neck and loss. 
+ +To run this script, you can start it with: + python train_yolov8.py cfg/.py + +Authors: + - Matteo Beltrami, 2024 + - Francesco Paissan, 2024 +""" + +import torch +import torch.nn as nn +import torch.optim as optim +from prepare_data import create_loaders +from yolo_loss import Loss +import math + +import micromind as mm +from micromind.networks.yolo import Darknet, Yolov8Neck, DetectionHead +from micromind.utils import parse_configuration +from micromind.utils.yolo import get_variant_multiples, load_config +import sys +import os +from validation.validator import DetectionValidator + + +class YOLO(mm.MicroMind): + def __init__(self, m_cfg, hparams, *args, **kwargs): + """Initializes the YOLO model.""" + super().__init__(*args, **kwargs) + + self.m_cfg = m_cfg + w, r, d = get_variant_multiples("n") + + self.modules["backbone"] = Darknet(w, r, d) + self.modules["neck"] = Yolov8Neck( + filters=[int(256 * w), int(512 * w), int(512 * w * r)], + heads=hparams.heads, + d=d, + ) + self.modules["head"] = DetectionHead( + hparams.num_classes, + filters=(int(256 * w), int(512 * w), int(512 * w * r)), + heads=hparams.heads, + ) + self.criterion = Loss(self.m_cfg, self.modules["head"], self.device) + + print("Number of parameters for each module:") + print(self.compute_params()) + + def preprocess_batch(self, batch): + """Preprocesses a batch of images by scaling and converting to float.""" + preprocessed_batch = {} + preprocessed_batch["img"] = ( + batch["img"].to(self.device, non_blocking=True).float() / 255 + ) + for k in batch: + if isinstance(batch[k], torch.Tensor) and k != "img": + preprocessed_batch[k] = batch[k].to(self.device) + + return preprocessed_batch + + def forward(self, batch): + """Runs the forward method by calling every module.""" + if self.modules.training: + preprocessed_batch = self.preprocess_batch(batch) + backbone = self.modules["backbone"]( + preprocessed_batch["img"].to(self.device) + ) + else: + + if torch.is_tensor(batch): + backbone = self.modules["backbone"](batch) + if "sppf" in self.modules.keys(): + neck_input = backbone[1] + neck_input.append(self.modules["sppf"](backbone[0])) + else: + neck_input = backbone + neck = self.modules["neck"](*neck_input) + head = self.modules["head"](neck) + return head + + backbone = self.modules["backbone"](batch["img"] / 255) + + if "sppf" in self.modules.keys(): + neck_input = backbone[1] + neck_input.append(self.modules["sppf"](backbone[0])) + else: + neck_input = backbone + neck = self.modules["neck"](*neck_input) + head = self.modules["head"](neck) + + return head + + def compute_loss(self, pred, batch): + """Computes the loss.""" + preprocessed_batch = self.preprocess_batch(batch) + + lossi_sum, lossi = self.criterion( + pred, + preprocessed_batch, + ) + + return lossi_sum + + def build_optimizer( + self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e6 + ): + """ + Constructs an optimizer for the given model, based on the specified optimizer + name, learning rate, momentum, weight decay, and number of iterations. + + Args: + model (torch.nn.Module): The model for which to build an optimizer. + name (str, optional): The name of the optimizer to use. If 'auto', the + optimizer is selected based on the number of iterations. + Default: 'auto'. + lr (float, optional): The learning rate for the optimizer. Default: 0.001. + momentum (float, optional): The momentum factor for the optimizer. + Default: 0.9. + decay (float, optional): The weight decay for the optimizer. Default: 1e-5. 
+ iterations (float, optional): The number of iterations, which determines + the optimizer if name is 'auto'. Default: 1e5. + + Returns: + (torch.optim.Optimizer): The constructed optimizer. + """ + + g = [], [], [] # optimizer parameter groups + bn = tuple( + v for k, v in nn.__dict__.items() if "Norm" in k + ) # normalization layers, i.e. BatchNorm2d() + if name == "auto": + print( + f"optimizer: 'optimizer=auto' found, " + f"ignoring 'lr0={lr}' and 'momentum={momentum}' and " + f"determining best 'optimizer', 'lr0' and 'momentum' automatically... " + ) + nc = getattr(model, "nc", 80) # number of classes + lr_fit = round( + 0.002 * 5 / (4 + nc), 6 + ) # lr0 fit equation to 6 decimal places + name, lr, momentum = ("AdamW", lr_fit, 0.9) + lr *= 10 + # self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam + + for module_name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + fullname = f"{module_name}.{param_name}" if module_name else param_name + if "bias" in fullname: # bias (no decay) + g[2].append(param) + elif isinstance(module, bn): # weight (no decay) + g[1].append(param) + else: # weight (with decay) + g[0].append(param) + + if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"): + optimizer = getattr(optim, name, optim.Adam)( + g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0 + ) + elif name == "RMSProp": + optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum) + elif name == "SGD": + optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True) + else: + raise NotImplementedError( + f"Optimizer '{name}' not found in list of available optimizers " + f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]." + "To request support for addition optimizers please visit" + "https://github.com/ultralytics/ultralytics." + ) + + optimizer.add_param_group( + {"params": g[0], "weight_decay": decay} + ) # add g0 with weight_decay + optimizer.add_param_group( + {"params": g[1], "weight_decay": 0.0} + ) # add g1 (BatchNorm2d weights) + print( + f"{optimizer:} {type(optimizer).__name__}(lr={lr}, " + f"momentum={momentum}) with parameter groups" + f"{len(g[1])} weight(decay=0.0), {len(g[0])} " + f"weight(decay={decay}), {len(g[2])} bias(decay=0.0)" + ) + return optimizer, lr + + def _setup_scheduler(self, opt, lrf=0.01, lr0=0.01, cos_lr=True): + """Initialize training learning rate scheduler.""" + + def one_cycle(y1=0.0, y2=1.0, steps=100): + """Returns a lambda function for sinusoidal ramp from y1 to y2 + https://arxiv.org/pdf/1812.01187.pdf.""" + return ( + lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + + y1 + ) + + lrf *= lr0 + + if cos_lr: + self.lf = one_cycle(1, lrf, 350) # cosine 1->hyp['lrf'] + else: + self.lf = ( + lambda x: max(1 - x / self.epochs, 0) * (1.0 - lrf) + lrf + ) # linear + return optim.lr_scheduler.LambdaLR(opt, lr_lambda=self.lf) + + def configure_optimizers(self): + """Configures the optimizer and the scheduler.""" + # opt = torch.optim.SGD(self.modules.parameters(), lr=1e-2, weight_decay=0.0005) + # opt = torch.optim.AdamW( + # self.modules.parameters(), lr=0.000119, weight_decay=0.0 + # ) + opt, lr = self.build_optimizer(self.modules, name="auto", lr=0.01, momentum=0.9) + sched = self._setup_scheduler(opt, 0.01, lr) + + return opt, sched + + @torch.no_grad() + def on_train_epoch_end(self): + """ + Computes the mean average precision (mAP) at the end of the training epoch + and logs the metrics in `metrics.txt` inside the experiment folder. 
+ The `verbose` argument if set to `True` prints details regarding the + number of images, instances and metrics for each class of the dataset. + The `plots` argument, if set to `True`, saves in the `runs/detect/train` + folder the plots of the confusion matrix, the F1-Confidence, + Precision-Confidence, Precision-Recall, Recall-Confidence curves and the + predictions and labels of the first three batches of images. + """ + args = dict( + model="yolov8n.pt", data=hparams.data_cfg, verbose=False, plots=False + ) + validator = DetectionValidator(args=args) + + validator(model=self) + + val_metrics = [ + validator.metrics.box.map * 100, + validator.metrics.box.map50 * 100, + validator.metrics.box.map75 * 100, + ] + metrics_file = os.path.join(exp_folder, "val_log.txt") + metrics_info = ( + f"Epoch {self.current_epoch}: " + f"mAP50-95(B): {round(val_metrics[0], 3)}%; " + f"mAP50(B): {round(val_metrics[1], 3)}%; " + f"mAP75(B): {round(val_metrics[2], 3)}%\n" + ) + + with open(metrics_file, "a") as file: + file.write(metrics_info) + return + + +def replace_datafolder(hparams, data_cfg): + """Replaces the data root folder, if told to do so from the configuration.""" + print(data_cfg["train"]) + data_cfg["path"] = str(data_cfg["path"]) + data_cfg["path"] = ( + data_cfg["path"][:-1] if data_cfg["path"][-1] == "/" else data_cfg["path"] + ) + for key in ["train", "val"]: + if not isinstance(data_cfg[key], list): + data_cfg[key] = [data_cfg[key]] + new_list = [] + for tmp in data_cfg[key]: + if hasattr(hparams, "data_dir"): + if hparams.data_dir != data_cfg["path"]: + tmp = str(tmp).replace(data_cfg["path"], "") + tmp = tmp[1:] if tmp[0] == "/" else tmp + tmp = os.path.join(hparams.data_dir, tmp) + new_list.append(tmp) + data_cfg[key] = new_list + + data_cfg["path"] = hparams.data_dir + + return data_cfg + + +if __name__ == "__main__": + assert len(sys.argv) > 1, "Please pass the configuration file to the script." + hparams = parse_configuration(sys.argv[1]) + if len(hparams.input_shape) != 3: + hparams.input_shape = [ + int(x) for x in "".join(hparams.input_shape).split(",") + ] # temp solution + print(f"Setting input shape to {hparams.input_shape}.") + + m_cfg, data_cfg = load_config(hparams.data_cfg) + + # check if specified path for images is different, correct it in case + # data_cfg = replace_datafolder(hparams, data_cfg) + m_cfg.imgsz = hparams.input_shape[-1] # temp solution + + train_loader, val_loader = create_loaders(m_cfg, data_cfg, hparams.batch_size) + + exp_folder = mm.utils.checkpointer.create_experiment_folder( + hparams.output_folder, hparams.experiment_name + ) + + checkpointer = mm.utils.checkpointer.Checkpointer( + exp_folder, hparams=hparams, key="loss" + ) + + yolo_mind = YOLO(m_cfg, hparams=hparams) + + yolo_mind.train( + epochs=hparams.epochs, + datasets={"train": train_loader, "val": val_loader}, + metrics=[], + checkpointer=checkpointer, + debug=hparams.debug, + ) diff --git a/recipes/object_detection/validate.py b/recipes/object_detection/validate.py new file mode 100644 index 0000000..8cf1a56 --- /dev/null +++ b/recipes/object_detection/validate.py @@ -0,0 +1,67 @@ +""" +YOLO training. + +This code allows you to validate an object detection model. 
+ +To run this script, you can start it with: + python validate.py cfg/.py + +Authors: + - Matteo Beltrami, 2024 + - Francesco Paissan, 2024 +""" + +from micromind.utils import parse_configuration +from micromind.utils.yolo import ( + load_config, +) +import sys +from validation.validator import DetectionValidator + +from train import YOLO, replace_datafolder + + +class YOLO(YOLO): + def __init__(self, m_cfg, hparams, *args, **kwargs): + """Initializes the YOLO model.""" + super().__init__(m_cfg, hparams, *args, **kwargs) + self.m_cfg = m_cfg + self.device = "cuda" + + def forward(self, img): + """Runs the forward method by calling every module.""" + backbone = self.modules["backbone"](img) + if "sppf" in self.modules.keys(): + neck_input = backbone[1] + neck_input.append(self.modules["sppf"](backbone[0])) + else: + neck_input = backbone + neck = self.modules["neck"](*neck_input) + head = self.modules["head"](neck) + + return head + + +if __name__ == "__main__": + assertion_msg = "Usage: python validate.py " + assert len(sys.argv) >= 3, assertion_msg + + hparams = parse_configuration(sys.argv[1]) + m_cfg, data_cfg = load_config(hparams.data_cfg) + data_cfg = replace_datafolder(hparams, data_cfg) + + m_cfg.imgsz = hparams.input_shape[-1] # temp solution + + model_weights_path = sys.argv[2] + args = dict(model="yolov8n.pt", data=hparams.data_cfg, verbose=False, plots=False) + validator = DetectionValidator(args=args) + + model = YOLO(m_cfg, hparams) + model.load_modules(model_weights_path) + + val = validator(model=model) + + print("METRICS:") + print("Box map50-95:", round(validator.metrics.box.map * 100, 3), "%") + print("Box map50:", round(validator.metrics.box.map50 * 100, 3), "%") + print("Box map75:", round(validator.metrics.box.map75 * 100, 3), "%") diff --git a/recipes/object_detection/validation/autobackend.py b/recipes/object_detection/validation/autobackend.py new file mode 100644 index 0000000..64ad83c --- /dev/null +++ b/recipes/object_detection/validation/autobackend.py @@ -0,0 +1,661 @@ +import ast +import contextlib +import json +import platform +import zipfile +from collections import OrderedDict, namedtuple +from pathlib import Path + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from PIL import Image + +from ultralytics.utils import ARM64, LINUX, LOGGER, ROOT, yaml_load +from ultralytics.utils.checks import ( + check_requirements, + check_suffix, + check_version, + check_yaml, +) +from ultralytics.utils.downloads import attempt_download_asset, is_url + + +def check_class_names(names): + """ + Check class names. + + Map imagenet class codes to human-readable names if required. + Convert lists to dicts. + """ + if isinstance(names, list): # names is a list + names = dict(enumerate(names)) # convert to dict + if isinstance(names, dict): + # Convert 1) string keys to int, i.e. '0' to 0, + # and non-string values to strings, i.e. True to 'True' + names = {int(k): str(v) for k, v in names.items()} + n = len(names) + if max(names.keys()) >= n: + raise KeyError( + f"{n}-class dataset requires class indices 0-{n - 1}, \ + but you have invalid class indices" + f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML." + ) + if isinstance(names[0], str) and names[0].startswith( + "n0" + ): # imagenet class codes, i.e. 
'n01440764' + names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")[ + "map" + ] # human-readable names + names = {k: names_map[v] for k, v in names.items()} + return names + + +def default_class_names(data=None): + """Applies default class names to an input YAML file or + returns numerical class names.""" + if data: + with contextlib.suppress(Exception): + return yaml_load(check_yaml(data))["names"] + return {i: f"class{i}" for i in range(999)} # return default if above errors + + +class AutoBackend(nn.Module): + """ + Handles dynamic backend selection for running inference using + Ultralytics YOLO models. + + The AutoBackend class is designed to provide an abstraction layer for various + inference engines. It supports a wide + range of formats, each with specific naming conventions as outlined below: + + Supported Formats and Naming Conventions: + | Format | File Suffix | + |-----------------------|------------------| + | PyTorch | *.pt | + | TorchScript | *.torchscript | + | ONNX Runtime | *.onnx | + | ONNX OpenCV DNN | *.onnx (dnn=True)| + | OpenVINO | *openvino_model/ | + | CoreML | *.mlpackage | + | TensorRT | *.engine | + | TensorFlow SavedModel | *_saved_model | + | TensorFlow GraphDef | *.pb | + | TensorFlow Lite | *.tflite | + | TensorFlow Edge TPU | *_edgetpu.tflite | + | PaddlePaddle | *_paddle_model | + | ncnn | *_ncnn_model | + + This class offers dynamic backend switching capabilities based on the input + model format, making it easier to deploy models across various platforms. + """ + + @torch.no_grad() + def __init__( + self, + weights="yolov8n.pt", + device=torch.device("cpu"), + dnn=False, + data=None, + fp16=False, + fuse=True, + verbose=True, + ): + """ + Initialize the AutoBackend for inference. + + Args: + weights (str): Path to the model weights file. Defaults to 'yolov8n.pt'. + device (torch.device): Device to run the model on. Defaults to CPU. + dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False. + data (str | Path | optional): Path to the additional data.yaml file + containing class names. Optional. + fp16 (bool): Enable half-precision inference. Supported only on specific + backends. Defaults to False. + fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. + Defaults to True. + verbose (bool): Enable verbose logging. Defaults to True. 
+ """ + super().__init__() + self.modell = weights + weights = weights.modules + w = str(weights[0] if isinstance(weights, list) else weights) + nn_module = isinstance(weights, torch.nn.Module) + ( + pt, + jit, + onnx, + xml, + engine, + coreml, + saved_model, + pb, + tflite, + edgetpu, + tfjs, + paddle, + ncnn, + triton, + ) = self._model_type(w) + fp16 &= pt or jit or onnx or xml or engine or nn_module # or triton # FP16 + nhwc = ( + coreml or saved_model or pb or tflite or edgetpu + ) # BHWC formats (vs torch BCWH) + stride = 32 # default stride + model, metadata = None, None + + # Set device + cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA + if cuda and not any( + [nn_module, pt, jit, engine, onnx] + ): # GPU dataloader formats + device = torch.device("cpu") + cuda = False + + # Download if not local + if not (pt or triton or nn_module): + w = attempt_download_asset(w) + + # Load model + if nn_module: # in-memory PyTorch model + model = weights.to(device) + model = model.fuse(verbose=verbose) if fuse else model + if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + # stride = max(int(model.stride.max()), 32) + # names = model.module.names if hasattr(model, "module") else model.names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + pt = True + elif pt: # PyTorch + from ultralytics.nn.tasks import attempt_load_weights + + model = attempt_load_weights( + weights if isinstance(weights, list) else w, + device=device, + inplace=True, + fuse=fuse, + ) + if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = ( + model.module.names if hasattr(model, "module") else model.names + ) # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + elif jit: # TorchScript + LOGGER.info(f"Loading {w} for TorchScript inference...") + extra_files = {"config.txt": ""} # model metadata + model = torch.jit.load(w, _extra_files=extra_files, map_location=device) + model.half() if fp16 else model.float() + if extra_files["config.txt"]: # load metadata dict + metadata = json.loads( + extra_files["config.txt"], object_hook=lambda x: dict(x.items()) + ) + elif dnn: # ONNX OpenCV DNN + LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...") + check_requirements("opencv-python>=4.5.4") + net = cv2.dnn.readNetFromONNX(w) + elif onnx: # ONNX Runtime + LOGGER.info(f"Loading {w} for ONNX Runtime inference...") + check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) + import onnxruntime + + providers = ( + ["CUDAExecutionProvider", "CPUExecutionProvider"] + if cuda + else ["CPUExecutionProvider"] + ) + session = onnxruntime.InferenceSession(w, providers=providers) + output_names = [x.name for x in session.get_outputs()] + metadata = session.get_modelmeta().custom_metadata_map # metadata + elif xml: # OpenVINO + LOGGER.info(f"Loading {w} for OpenVINO inference...") + check_requirements( + "openvino>=2023.0" + ) # requires openvino-dev: https://pypi.org/project/openvino-dev/ + from openvino.runtime import Core, Layout, get_batch # noqa + + core = Core() + w = Path(w) + if not w.is_file(): # if not *.xml + w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir + ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin")) + if ov_model.get_parameters()[0].get_layout().empty: + 
ov_model.get_parameters()[0].set_layout(Layout("NCHW")) + batch_dim = get_batch(ov_model) + if batch_dim.is_static: + batch_size = batch_dim.get_length() + ov_compiled_model = core.compile_model( + ov_model, device_name="AUTO" + ) # AUTO selects best available device + metadata = w.parent / "metadata.yaml" + elif engine: # TensorRT + LOGGER.info(f"Loading {w} for TensorRT inference...") + try: + import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download + except ImportError: + if LINUX: + check_requirements( + "nvidia-tensorrt", + cmds="-U --index-url https://pypi.ngc.nvidia.com", + ) + import tensorrt as trt # noqa + check_version( + trt.__version__, "7.0.0", hard=True + ) # require tensorrt>=7.0.0 + if device.type == "cpu": + device = torch.device("cuda:0") + Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) + logger = trt.Logger(trt.Logger.INFO) + # Read file + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + meta_len = int.from_bytes( + f.read(4), byteorder="little" + ) # read metadata length + metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata + model = runtime.deserialize_cuda_engine(f.read()) # read engine + context = model.create_execution_context() + bindings = OrderedDict() + output_names = [] + fp16 = False # default updated below + dynamic = False + for i in range(model.num_bindings): + name = model.get_binding_name(i) + dtype = trt.nptype(model.get_binding_dtype(i)) + if model.binding_is_input(i): + if -1 in tuple(model.get_binding_shape(i)): # dynamic + dynamic = True + context.set_binding_shape( + i, tuple(model.get_profile_shape(0, i)[2]) + ) + if dtype == np.float16: + fp16 = True + else: # output + output_names.append(name) + shape = tuple(context.get_binding_shape(i)) + im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) + binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) + batch_size = bindings["images"].shape[ + 0 + ] # if dynamic, this is instead max batch size + elif coreml: # CoreML + LOGGER.info(f"Loading {w} for CoreML inference...") + import coremltools as ct + + model = ct.models.MLModel(w) + metadata = dict(model.user_defined_metadata) + elif saved_model: # TF SavedModel + LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...") + import tensorflow as tf + + keras = False # assume TF1 saved_model + model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) + metadata = Path(w) / "metadata.yaml" + elif ( + pb + ): # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt + LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...") + import tensorflow as tf + + from ultralytics.engine.exporter import gd_outputs + + def wrap_frozen_graph(gd, inputs, outputs): + """Wrap frozen graphs for deployment.""" + x = tf.compat.v1.wrap_function( + lambda: tf.compat.v1.import_graph_def(gd, name=""), [] + ) # wrapped + ge = x.graph.as_graph_element + return x.prune( + tf.nest.map_structure(ge, inputs), + tf.nest.map_structure(ge, outputs), + ) + + gd = tf.Graph().as_graph_def() # TF GraphDef + with open(w, "rb") as f: + gd.ParseFromString(f.read()) + frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) + elif tflite or edgetpu: + try: + from tflite_runtime.interpreter import Interpreter, load_delegate + except ImportError: + import tensorflow as tf + + Interpreter, load_delegate = ( + tf.lite.Interpreter, + 
tf.lite.experimental.load_delegate, + ) + if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime + LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...") + delegate = { + "Linux": "libedgetpu.so.1", + "Darwin": "libedgetpu.1.dylib", + "Windows": "edgetpu.dll", + }[platform.system()] + interpreter = Interpreter( + model_path=w, experimental_delegates=[load_delegate(delegate)] + ) + else: # TFLite + LOGGER.info(f"Loading {w} for TensorFlow Lite inference...") + interpreter = Interpreter(model_path=w) # load TFLite model + interpreter.allocate_tensors() # allocate + input_details = interpreter.get_input_details() # inputs + output_details = interpreter.get_output_details() # outputs + # Load metadata + with contextlib.suppress(zipfile.BadZipFile): + with zipfile.ZipFile(w, "r") as model: + meta_file = model.namelist()[0] + metadata = ast.literal_eval(model.read(meta_file).decode("utf-8")) + elif tfjs: # TF.js + raise NotImplementedError( + "YOLOv8 TF.js inference is not currently supported." + ) + elif paddle: # PaddlePaddle + LOGGER.info(f"Loading {w} for PaddlePaddle inference...") + check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle") + import paddle.inference as pdi # noqa + + w = Path(w) + if not w.is_file(): # if not *.pdmodel + w = next( + w.rglob("*.pdmodel") + ) # get *.pdmodel file from *_paddle_model dir + config = pdi.Config(str(w), str(w.with_suffix(".pdiparams"))) + if cuda: + config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0) + predictor = pdi.create_predictor(config) + input_handle = predictor.get_input_handle(predictor.get_input_names()[0]) + output_names = predictor.get_output_names() + metadata = w.parents[1] / "metadata.yaml" + elif ncnn: # ncnn + LOGGER.info(f"Loading {w} for ncnn inference...") + check_requirements( + "git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn" + ) # requires ncnn + import ncnn as pyncnn + + net = pyncnn.Net() + net.opt.use_vulkan_compute = cuda + w = Path(w) + if not w.is_file(): # if not *.param + w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir + net.load_param(str(w)) + net.load_model(str(w.with_suffix(".bin"))) + metadata = w.parent / "metadata.yaml" + elif triton: # NVIDIA Triton Inference Server + check_requirements("tritonclient[all]") + from ultralytics.utils.triton import TritonRemoteModel + + model = TritonRemoteModel(w) + else: + from ultralytics.engine.exporter import export_formats + + raise TypeError( + f"model='{w}' is not a supported model format. " + "See https://docs.ultralytics.com/modes/predict for help." 
+ f"\n\n{export_formats()}" + ) + + # Load external metadata YAML + if isinstance(metadata, (str, Path)) and Path(metadata).exists(): + metadata = yaml_load(metadata) + if metadata: + for k, v in metadata.items(): + if k in ("stride", "batch"): + metadata[k] = int(v) + elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str): + metadata[k] = eval(v) + stride = metadata["stride"] + task = metadata["task"] + batch = metadata["batch"] + imgsz = metadata["imgsz"] + names = metadata["names"] + kpt_shape = metadata.get("kpt_shape") + elif not (pt or triton or nn_module): + LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'") + + # Check names + if "names" not in locals(): # names missing + names = default_class_names(data) + names = check_class_names(names) + + """Disable gradients since this is only used in validation.""" + # if pt: + # for p in model.parameters(): + # p.requires_grad = False + + self.__dict__.update(locals()) # assign all variables to self + + def forward(self, im, augment=False, visualize=False, embed=None): + """ + Runs inference on the YOLOv8 MultiBackend model. + + Args: + im (torch.Tensor): The image tensor to perform inference on. + augment (bool): whether to perform data augmentation during inference, + defaults to False + visualize (bool): whether to visualize the output predictions, + defaults to False + embed (list, optional): A list of feature vectors/embeddings to return. + + Returns: + (tuple): Tuple containing the raw output tensor, + and processed output for visualization (if visualize=True) + """ + b, ch, h, w = im.shape # batch, channel, height, width + if self.fp16 and im.dtype != torch.float16: + im = im.half() # to FP16 + if self.nhwc: + im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3) + + if self.pt or self.nn_module: # PyTorch + y = self.modell(im) + elif self.jit: # TorchScript + y = self.model(im) + elif self.dnn: # ONNX OpenCV DNN + im = im.cpu().numpy() # torch to numpy + self.net.setInput(im) + y = self.net.forward() + elif self.onnx: # ONNX Runtime + im = im.cpu().numpy() # torch to numpy + y = self.session.run( + self.output_names, {self.session.get_inputs()[0].name: im} + ) + elif self.xml: # OpenVINO + im = im.cpu().numpy() # FP32 + y = list(self.ov_compiled_model(im).values()) + elif self.engine: # TensorRT + if self.dynamic and im.shape != self.bindings["images"].shape: + i = self.model.get_binding_index("images") + self.context.set_binding_shape(i, im.shape) # reshape if dynamic + self.bindings["images"] = self.bindings["images"]._replace( + shape=im.shape + ) + for name in self.output_names: + i = self.model.get_binding_index(name) + self.bindings[name].data.resize_( + tuple(self.context.get_binding_shape(i)) + ) + s = self.bindings["images"].shape + assert ( + im.shape == s + ), f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} \ + max model size {s}" + self.binding_addrs["images"] = int(im.data_ptr()) + self.context.execute_v2(list(self.binding_addrs.values())) + y = [self.bindings[x].data for x in sorted(self.output_names)] + elif self.coreml: # CoreML + im = im[0].cpu().numpy() + im_pil = Image.fromarray((im * 255).astype("uint8")) + # im = im.resize((192, 320), Image.BILINEAR) + y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized + if "confidence" in y: + raise TypeError( + "Ultralytics only supports inference of non-pipelined CoreML \ + models exported with " + f"'nms=False', but 'model={w}' has an NMS pipeline created by \ + an 'nms=True' 
export." + ) + elif len(y) == 1: # classification model + y = list(y.values()) + elif len(y) == 2: # segmentation model + y = list( + reversed(y.values()) + ) # reversed for segmentation models (pred, proto) + elif self.paddle: # PaddlePaddle + im = im.cpu().numpy().astype(np.float32) + self.input_handle.copy_from_cpu(im) + self.predictor.run() + y = [ + self.predictor.get_output_handle(x).copy_to_cpu() + for x in self.output_names + ] + elif self.ncnn: # ncnn + mat_in = self.pyncnn.Mat(im[0].cpu().numpy()) + ex = self.net.create_extractor() + input_names, output_names = self.net.input_names(), self.net.output_names() + ex.input(input_names[0], mat_in) + y = [] + for output_name in output_names: + mat_out = self.pyncnn.Mat() + ex.extract(output_name, mat_out) + y.append(np.array(mat_out)[None]) + elif self.triton: # NVIDIA Triton Inference Server + im = im.cpu().numpy() # torch to numpy + y = self.model(im) + else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) + im = im.cpu().numpy() + if self.saved_model: # SavedModel + y = self.model(im, training=False) if self.keras else self.model(im) + if not isinstance(y, list): + y = [y] + elif self.pb: # GraphDef + y = self.frozen_func(x=self.tf.constant(im)) + if ( + len(y) == 2 and len(self.names) == 999 + ): # segments and names not defined + ip, ib = ( + (0, 1) if len(y[0].shape) == 4 else (1, 0) + ) # index of protos, boxes + nc = ( + y[ib].shape[1] - y[ip].shape[3] - 4 + ) # y = (1, 160, 160, 32), (1, 116, 8400) + self.names = {i: f"class{i}" for i in range(nc)} + else: # Lite or Edge TPU + details = self.input_details[0] + integer = details["dtype"] in ( + np.int8, + np.int16, + ) # is TFLite quantized int8 or int16 model + if integer: + scale, zero_point = details["quantization"] + im = (im / scale + zero_point).astype(details["dtype"]) # de-scale + self.interpreter.set_tensor(details["index"], im) + self.interpreter.invoke() + y = [] + for output in self.output_details: + x = self.interpreter.get_tensor(output["index"]) + if integer: + scale, zero_point = output["quantization"] + x = (x.astype(np.float32) - zero_point) * scale # re-scale + if x.ndim > 2: # if task is not classification + x[:, [0, 2]] *= w + x[:, [1, 3]] *= h + y.append(x) + # TF segment fixes: export is reversed vs ONNX export + # and protos are transposed + if len(y) == 2: # segment with (det, proto) output order reversed + if len(y[1].shape) != 4: + y = list( + reversed(y) + ) # should be y = (1, 116, 8400), (1, 160, 160, 32) + y[1] = np.transpose( + y[1], (0, 3, 1, 2) + ) # should be y = (1, 116, 8400), (1, 32, 160, 160) + y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y] + + if isinstance(y, (list, tuple)): + return ( + self.from_numpy(y[0]) + if len(y) == 1 + else [self.from_numpy(x) for x in y] + ) + else: + return self.from_numpy(y) + + def from_numpy(self, x): + """ + Convert a numpy array to a tensor. + + Args: + x (np.ndarray): The array to be converted. + + Returns: + (torch.Tensor): The converted tensor + """ + return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x + + def warmup(self, imgsz=(1, 3, 640, 640)): + """ + Warm up the model by running one forward pass with a dummy input. 
+ + Args: + imgsz (tuple): The shape of the dummy input tensor in the format + (batch_size, channels, height, width) + """ + warmup_types = ( + self.pt, + self.jit, + self.onnx, + self.engine, + self.saved_model, + self.pb, + self.triton, + self.nn_module, + ) + if any(warmup_types) and (self.device.type != "cpu" or self.triton): + im = torch.empty( + *imgsz, + dtype=torch.half if self.fp16 else torch.float, + device=self.device, + ) # input + for _ in range(2 if self.jit else 1): + self.forward(im) # warmup + + @staticmethod + def _model_type(p="path/to/model.pt"): + """ + This function takes a path to a model file and returns the model type. + Possibles types are pt, jit, onnx, xml, engine, coreml, saved_model, pb, + tflite, edgetpu, tfjs, ncnn or paddle. + + Args: + p: path to the model file. Defaults to path/to/model.pt + + Examples: + >>> model = AutoBackend(weights="path/to/model.onnx") + >>> model_type = model._model_type() # returns "onnx" + """ + from ultralytics.engine.exporter import export_formats + + sf = list(export_formats().Suffix) # export suffixes + if not is_url(p, check=False) and not isinstance(p, str): + check_suffix(p, sf) # checks + name = Path(p).name + types = [s in name for s in sf] + types[5] |= name.endswith( + ".mlmodel" + ) # retain support for older Apple CoreML *.mlmodel formats + types[8] &= not types[9] # tflite &= not edgetpu + if any(types): + triton = False + else: + from urllib.parse import urlsplit + + url = urlsplit(p) + triton = url.netloc and url.path and url.scheme in {"http", "grpc"} + + return types + [triton] diff --git a/recipes/object_detection/validation/validator.py b/recipes/object_detection/validation/validator.py new file mode 100644 index 0000000..32a0ea7 --- /dev/null +++ b/recipes/object_detection/validation/validator.py @@ -0,0 +1,756 @@ +import json +import time +import os +from pathlib import Path + +import numpy as np +import torch + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data.utils import check_cls_dataset, check_det_dataset +from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis +from ultralytics.utils.checks import check_imgsz +from ultralytics.utils.ops import Profile +from ultralytics.utils.torch_utils import ( + de_parallel, + select_device, + smart_inference_mode, +) + +from ultralytics.data import build_dataloader, build_yolo_dataset, converter +from ultralytics.utils import ops +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou +from ultralytics.utils.plotting import output_to_target, plot_images + +from .autobackend import AutoBackend + + +class BaseValidator: + """ + BaseValidator. + + A base class for creating validators. + + Attributes: + args (SimpleNamespace): Configuration for the validator. + dataloader (DataLoader): Dataloader to use for validation. + pbar (tqdm): Progress bar to update during validation. + model (nn.Module): Model to validate. + data (dict): Data dictionary. + device (torch.device): Device to use for validation. + batch_i (int): Current batch index. + training (bool): Whether the model is in training mode. + names (dict): Class names. + seen: Records the number of images seen so far during validation. + stats: Placeholder for statistics during validation. + confusion_matrix: Placeholder for a confusion matrix. + nc: Number of classes. + iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05. 
+ jdict (dict): Dictionary to store JSON validation results. + speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', + 'postprocess' and their respective batch processing times in milliseconds. + save_dir (Path): Directory to save results. + plots (dict): Dictionary to store plots for visualization. + callbacks (dict): Dictionary to store various callback functions. + """ + + def __init__( + self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None + ): + """ + Initializes a BaseValidator instance. + + Args: + dataloader (torch.utils.data.DataLoader): Dataloader to be used for + validation. + save_dir (Path, optional): Directory to save results. + pbar (tqdm.tqdm): Progress bar for displaying progress. + args (SimpleNamespace): Configuration for the validator. + _callbacks (dict): Dictionary to store various callback functions. + """ + self.args = get_cfg(overrides=args) + self.dataloader = dataloader + self.args.plot = False + self.pbar = pbar + self.stride = None + self.data = None + self.device = None + self.batch_i = None + self.training = True + self.names = None + self.seen = None + self.stats = None + self.confusion_matrix = None + self.nc = None + self.iouv = None + self.jdict = None + self.device = "cuda" + self.speed = { + "preprocess": 0.0, + "inference": 0.0, + "loss": 0.0, + "postprocess": 0.0, + } + + self.save_dir = save_dir or get_save_dir(self.args) + + """This creates a folder `runs/detect/train` in which it's saved + the result of every validation.""" + if self.args.plots: + (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir( + parents=True, exist_ok=True + ) + if self.args.conf is None: + self.args.conf = 0.001 # default conf=0.001 + self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1) + + self.plots = {} + self.callbacks = _callbacks or callbacks.get_default_callbacks() + + @smart_inference_mode() + def __call__(self, trainer=None, model=None): + """Supports validation of a pre-trained model if passed or a model being + trained if trainer is passed (trainer gets priority). 
+ """ + self.training = trainer is not None + augment = self.args.augment and (not self.training) + if self.training: + self.device = trainer.device + self.data = trainer.data + self.args.half = self.device.type != "cpu" # force FP16 val during training + model = trainer.ema.ema or trainer.model + model = model.half() if self.args.half else model.float() + # self.model = model + self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device) + self.args.plots &= trainer.stopper.possible_stop or ( + trainer.epoch == trainer.epochs - 1 + ) + model.eval() + else: + callbacks.add_integration_callbacks(self) + model = AutoBackend( + model or self.args.model, + device=select_device(self.args.device, self.args.batch), + dnn=self.args.dnn, + data=self.args.data, + fp16=self.args.half, + fuse=False, + ) + # self.model = model + self.device = model.device # update device + self.args.half = model.fp16 # update half + stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine + imgsz = check_imgsz(self.args.imgsz, stride=stride) + if engine: + self.args.batch = model.batch_size + elif not pt and not jit: + self.args.batch = 1 # export.py models default to batch-size 1 + LOGGER.info( + f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) \ + for non-PyTorch models" + ) + + if str(self.args.data).split(".")[-1] in ("yaml", "yml"): + self.data = check_det_dataset(self.args.data) + elif self.args.task == "classify": + self.data = check_cls_dataset(self.args.data, split=self.args.split) + else: + raise FileNotFoundError( + emojis( + f"Dataset '{self.args.data}' for task={self.args.task} not \ + found ❌" + ) + ) + + if self.device.type in ("cpu", "mps"): + self.args.workers = ( + 0 # faster CPU val as time dominated by inference, not dataloading + ) + if not pt: + self.args.rect = False + self.stride = model.stride # used in get_dataloader() for padding + self.dataloader = self.dataloader or self.get_dataloader( + self.data.get(self.args.split), self.args.batch + ) + + model.eval() + model.warmup( + imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz) + ) # warmup + + self.run_callbacks("on_val_start") + dt = ( + Profile(), + Profile(), + Profile(), + Profile(), + ) + bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader)) + self.init_metrics(de_parallel(model)) + self.jdict = [] # empty before each val + for batch_i, batch in enumerate(bar): + self.run_callbacks("on_val_batch_start") + self.batch_i = batch_i + # Preprocess + with dt[0]: + batch = self.preprocess(batch) + + # Inference + with dt[1]: + preds = model(batch["img"], augment=augment) + + # Loss + with dt[2]: + if self.training: + self.loss += model.loss(batch, preds)[1] + + # Postprocess + with dt[3]: + preds = self.postprocess(preds) + + self.update_metrics(preds, batch) + if self.args.plots and batch_i < 3: + self.plot_val_samples(batch, batch_i) + self.plot_predictions(batch, preds, batch_i) + + self.run_callbacks("on_val_batch_end") + stats = self.get_stats() + self.check_stats(stats) + self.speed = dict( + zip( + self.speed.keys(), + (x.t / len(self.dataloader.dataset) * 1e3 for x in dt), + ) + ) + self.finalize_metrics() + self.print_results() + self.run_callbacks("on_val_end") + if self.training: + model.float() + results = { + **stats, + **trainer.label_loss_items( + self.loss.cpu() / len(self.dataloader), prefix="val" + ), + } + return { + k: round(float(v), 5) for k, v in results.items() + } # return results as 5 decimal place floats + else: + LOGGER.info( + "Speed: %.1fms preprocess, 
%.1fms inference, %.1fms loss, %.1fms \ + postprocess per image" + % tuple(self.speed.values()) + ) + if self.args.save_json and self.jdict: + with open(str(self.save_dir / "predictions.json"), "w") as f: + LOGGER.info(f"Saving {f.name}...") + json.dump(self.jdict, f) # flatten and save + stats = self.eval_json(stats) # update stats + if self.args.plots or self.args.save_json: + LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") + return stats + + def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False): + """ + Matches predictions to ground truth objects (pred_classes, true_classes) + using IoU. + + Args: + pred_classes (torch.Tensor): Predicted class indices of shape(N,). + true_classes (torch.Tensor): Target class indices of shape(M,). + iou (torch.Tensor): An NxM tensor containing the pairwise IoU values + for predictions and ground of truth + use_scipy (bool): Whether to use scipy for matching (more precise). + + Returns: + (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds. + """ + # Dx10 matrix, where D - detections, 10 - IoU thresholds + correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool) + # LxD matrix where L - labels (rows), D - detections (columns) + correct_class = true_classes[:, None] == pred_classes + iou = iou * correct_class # zero out the wrong classes + iou = iou.cpu().numpy() + for i, threshold in enumerate(self.iouv.cpu().tolist()): + if use_scipy: + import scipy # scope import to avoid importing for all commands + + cost_matrix = iou * (iou >= threshold) + if cost_matrix.any(): + labels_idx, detections_idx = scipy.optimize.linear_sum_assignment( + cost_matrix, maximize=True + ) + valid = cost_matrix[labels_idx, detections_idx] > 0 + if valid.any(): + correct[detections_idx[valid], i] = True + else: + matches = np.nonzero( + iou >= threshold + ) # IoU > threshold and classes match + matches = np.array(matches).T + if matches.shape[0]: + if matches.shape[0] > 1: + matches = matches[ + iou[matches[:, 0], matches[:, 1]].argsort()[::-1] + ] + matches = matches[ + np.unique(matches[:, 1], return_index=True)[1] + ] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[ + np.unique(matches[:, 0], return_index=True)[1] + ] + correct[matches[:, 1].astype(int), i] = True + return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device) + + def add_callback(self, event: str, callback): + """Appends the given callback.""" + self.callbacks[event].append(callback) + + def run_callbacks(self, event: str): + """Runs all callbacks associated with a specified event.""" + for callback in self.callbacks.get(event, []): + callback(self) + + def get_dataloader(self, dataset_path, batch_size): + """Get data loader from dataset path and batch size.""" + raise NotImplementedError( + "get_dataloader function not implemented for this validator" + ) + + def build_dataset(self, img_path): + """Build dataset.""" + raise NotImplementedError("build_dataset function not implemented in validator") + + def preprocess(self, batch): + """Preprocesses an input batch.""" + return batch + + def postprocess(self, preds): + """Describes and summarizes the purpose of 'postprocess()' but no + details mentioned.""" + return preds + + def init_metrics(self, model): + """Initialize performance metrics for the YOLO model.""" + pass + + def update_metrics(self, preds, batch): + """Updates metrics based on predictions and batch.""" + pass + + def finalize_metrics(self, *args, **kwargs): + """Finalizes and 
returns all metrics.""" + pass + + def get_stats(self): + """Returns statistics about the model's performance.""" + return {} + + def check_stats(self, stats): + """Checks statistics.""" + pass + + def print_results(self): + """Prints the results of the model's predictions.""" + pass + + def get_desc(self): + """Get description of the YOLO model.""" + pass + + @property + def metric_keys(self): + """Returns the metric keys used in YOLO training/validation.""" + return [] + + def on_plot(self, name, data=None): + """Registers plots (e.g. to be consumed in callbacks)""" + self.plots[Path(name)] = {"data": data, "timestamp": time.time()} + + # TODO: may need to put these following functions into callback + def plot_val_samples(self, batch, ni): + """Plots validation samples during training.""" + pass + + def plot_predictions(self, batch, preds, ni): + """Plots YOLO model predictions on batch images.""" + pass + + def pred_to_json(self, preds, batch): + """Convert predictions to JSON format.""" + pass + + def eval_json(self, stats): + """Evaluate and return JSON format of prediction statistics.""" + pass + + +class DetectionValidator(BaseValidator): + """ + A class extending the BaseValidator class for validation based on a detection model. + + Example: + ```python + from ultralytics.models.yolo.detect import DetectionValidator + + args = dict(model='yolov8n.pt', data='coco8.yaml') + validator = DetectionValidator(args=args) + validator() + ``` + """ + + def __init__( + self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None + ): + """Initialize detection model with necessary variables and settings.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.nt_per_class = None + self.is_coco = False + self.class_map = None + self.args.task = "detect" + self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) + self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95 + self.niou = self.iouv.numel() + self.lb = [] # for autolabelling + + def preprocess(self, batch): + """Preprocesses batch of images for YOLO training.""" + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = ( + batch["img"].half() if self.args.half else batch["img"].float() + ) / 255 + for k in ["batch_idx", "cls", "bboxes"]: + batch[k] = batch[k].to(self.device) + + if self.args.save_hybrid: + height, width = batch["img"].shape[2:] + nb = len(batch["img"]) + bboxes = batch["bboxes"] * torch.tensor( + (width, height, width, height), device=self.device + ) + self.lb = ( + [ + torch.cat( + [ + batch["cls"][batch["batch_idx"] == i], + bboxes[batch["batch_idx"] == i], + ], + dim=-1, + ) + for i in range(nb) + ] + if self.args.save_hybrid + else [] + ) # for autolabelling + + return batch + + def init_metrics(self, model): + """Initialize evaluation metrics for YOLO.""" + val = self.data.get(self.args.split, "") # validation path + self.is_coco = ( + isinstance(val, str) + and "coco" in val + and val.endswith(f"{os.sep}val2017.txt") + ) # is COCO + self.class_map = ( + converter.coco80_to_coco91_class() if self.is_coco else list(range(1000)) + ) + self.args.save_json |= ( + self.is_coco and not self.training + ) # run on final val if training COCO + self.names = model.names + self.nc = len(model.names) + self.metrics.names = self.names + self.metrics.plot = self.args.plots + self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf) + self.seen = 0 + self.jdict = [] + self.stats = dict(tp=[], conf=[], pred_cls=[], 
target_cls=[]) + + def get_desc(self): + """Return a formatted string summarizing class metrics of YOLO model.""" + return ("%22s" + "%11s" * 6) % ( + "Class", + "Images", + "Instances", + "Box(P", + "R", + "mAP50", + "mAP50-95)", + ) + + def postprocess(self, preds): + """Apply Non-maximum suppression to prediction outputs.""" + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + ) + + def _prepare_batch(self, si, batch): + """Prepares a batch of images and annotations for validation.""" + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] + if len(cls): + bbox = ( + ops.xywh2xyxy(bbox) + * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] + ) # target boxes + ops.scale_boxes( + imgsz, bbox, ori_shape, ratio_pad=ratio_pad + ) # native-space labels + return dict( + cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad + ) + + def _prepare_pred(self, pred, pbatch): + """Prepares a batch of images and annotations for validation.""" + predn = pred.clone() + ops.scale_boxes( + pbatch["imgsz"], + predn[:, :4], + pbatch["ori_shape"], + ratio_pad=pbatch["ratio_pad"], + ) # native-space pred + return predn + + def update_metrics(self, preds, batch): + """Metrics.""" + for si, pred in enumerate(preds): + self.seen += 1 + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls + if npr == 0: + if nl: + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + # TODO: obb has not supported confusion_matrix yet. + # if self.args.plots and self.args.task != "obb": + # self.confusion_matrix.process_batch( + # detections=None, gt_bboxes=bbox, gt_cls=cls + # ) + continue + + # Predictions + if self.args.single_cls: + pred[:, 5] = 0 + predn = self._prepare_pred(pred, pbatch) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] + + # Evaluate + if nl: + stat["tp"] = self._process_batch(predn, bbox, cls) + # TODO: obb has not supported confusion_matrix yet. 
+ # if self.args.plots and self.args.task != "obb": + # self.confusion_matrix.process_batch(predn, bbox, cls) + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + + # Save + if self.args.save_json: + self.pred_to_json(predn, batch["im_file"][si]) + if self.args.save_txt: + file = ( + self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt' + ) + self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file) + + def finalize_metrics(self, *args, **kwargs): + """Set final values for metrics speed and confusion matrix.""" + self.metrics.speed = self.speed + self.metrics.confusion_matrix = self.confusion_matrix + + def get_stats(self): + """Returns metrics statistics and results dictionary.""" + stats = { + k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items() + } # to numpy + if len(stats) and stats["tp"].any(): + self.metrics.process(**stats) + self.nt_per_class = np.bincount( + stats["target_cls"].astype(int), minlength=self.nc + ) # number of targets per class + return self.metrics.results_dict + + def print_results(self): + """Prints training/validation set metrics per class.""" + pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format + LOGGER.info( + pf + % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()) + ) + if self.nt_per_class.sum() == 0: + LOGGER.warning( + f"WARNING ⚠️ no labels found in {self.args.task} set, can not \ + compute metrics without labels" + ) + + # Print results per class + if self.args.verbose and not self.training and self.nc > 1 and len(self.stats): + for i, c in enumerate(self.metrics.ap_class_index): + LOGGER.info( + pf + % ( + self.names[c], + self.seen, + self.nt_per_class[c], + *self.metrics.class_result(i), + ) + ) + + if self.args.plots: + for normalize in True, False: + self.confusion_matrix.plot( + save_dir=self.save_dir, + names=self.names.values(), + normalize=normalize, + on_plot=self.on_plot, + ) + + def _process_batch(self, detections, gt_bboxes, gt_cls): + """ + Return correct prediction matrix. + + Args: + detections (torch.Tensor): Tensor of shape [N, 6] representing detections. + Each detection is of the format: x1, y1, x2, y2, conf, class. + labels (torch.Tensor): Tensor of shape [M, 5] representing labels. + Each label is of the format: class, x1, y1, x2, y2. + + Returns: + (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 + IoU levels. + """ + iou = box_iou(gt_bboxes, detections[:, :4]) + return self.match_predictions(detections[:, 5], gt_cls, iou) + + def build_dataset(self, img_path, mode="val", batch=None): + """ + Build YOLO Dataset. + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize + different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. + Defaults to None. 
+ """ + return build_yolo_dataset( + self.args, img_path, batch, self.data, mode=mode, stride=self.stride + ) + + def get_dataloader(self, dataset_path, batch_size): + """Construct and return dataloader.""" + dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val") + return build_dataloader( + dataset, batch_size, self.args.workers, shuffle=False, rank=-1 + ) # return dataloader + + def plot_val_samples(self, batch, ni): + """Plot validation image samples.""" + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def plot_predictions(self, batch, preds, ni): + """Plots predicted bounding boxes on input images and saves the result.""" + plot_images( + batch["img"], + *output_to_target(preds, max_det=self.args.max_det), + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred + + def save_one_txt(self, predn, save_conf, shape, file): + """Save YOLO detections to a txt file in normalized coordinates in a + specific format.""" + gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh + for *xyxy, conf, cls in predn.tolist(): + xywh = ( + (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() + ) # normalized xywh + line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format + with open(file, "a") as f: + f.write(("%g " * len(line)).rstrip() % line + "\n") + + def pred_to_json(self, predn, filename): + """Serialize YOLO predictions to COCO json format.""" + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + box = ops.xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + for p, b in zip(predn.tolist(), box.tolist()): + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "score": round(p[4], 5), + } + ) + + def eval_json(self, stats): + """Evaluates YOLO output in JSON format and returns performance statistics.""" + if self.args.save_json and self.is_coco and len(self.jdict): + anno_json = ( + self.data["path"] / "annotations/instances_val2017.json" + ) # annotations + pred_json = self.save_dir / "predictions.json" # predictions + LOGGER.info( + f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}..." 
+ ) + try: + check_requirements("pycocotools>=2.0.6") + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + for x in anno_json, pred_json: + assert x.is_file(), f"{x} file not found" + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes( + str(pred_json) + ) # init predictions api (must pass string, not Path) + eval = COCOeval(anno, pred, "bbox") + if self.is_coco: + eval.params.imgIds = [ + int(Path(x).stem) for x in self.dataloader.dataset.im_files + ] # images to eval + eval.evaluate() + eval.accumulate() + eval.summarize() + stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[ + :2 + ] # update mAP50-95 and mAP50 + except Exception as e: + LOGGER.warning(f"pycocotools unable to run: {e}") + return stats diff --git a/recipes/object_detection/yolo_loss.py b/recipes/object_detection/yolo_loss.py index 29650a6..4810468 100644 --- a/recipes/object_detection/yolo_loss.py +++ b/recipes/object_detection/yolo_loss.py @@ -7,6 +7,7 @@ - Matteo Beltrami, 2023 - Francesco Paissan, 2023 """ + import torch import torch.nn as nn from ultralytics.utils.loss import BboxLoss, v8DetectionLoss @@ -76,8 +77,11 @@ def __call__(self, preds, batch): [xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2 ).split((self.reg_max * 4, self.nc), 1) + # breakpoint() pred_scores = pred_scores.permute(0, 2, 1).contiguous() pred_distri = pred_distri.permute(0, 2, 1).contiguous() + # print("pred scores shape ", pred_scores.shape) # x, 8400, 80 + # print("pred distri shape ", pred_distri.shape) # x, 8400, 64 (reg_max * 4) dtype = pred_scores.dtype batch_size = pred_scores.shape[0] @@ -85,6 +89,7 @@ def __call__(self, preds, batch): torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] ) # image size (h,w) + # print(imgsz) anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) # Targets @@ -95,21 +100,29 @@ def __call__(self, preds, batch): targets = self.preprocess( targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]] ) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) # Pboxes pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) + # print("pred bboxes") + # print(pred_bboxes.shape) + # print(pred_bboxes[0, 0]) _, target_bboxes, target_scores, fg_mask, _ = self.assigner( - pred_scores.detach().sigmoid(), - (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + pred_scores.clone().detach().sigmoid(), + (pred_bboxes.clone().detach() * stride_tensor).type(gt_bboxes.dtype), anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt, ) + # print("target bboxes") + # print(target_bboxes.shape) + # print(target_bboxes[0, 0]) + target_scores_sum = max(target_scores.sum(), 1) # Cls loss @@ -117,9 +130,14 @@ def __call__(self, preds, batch): self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum ) # BCE + # print("classification loss", loss[1]) # Bbox loss if fg_mask.sum(): target_bboxes /= stride_tensor + # print("target bboxes w norm") + # print(target_bboxes.shape) + # print(target_bboxes[0, 0]) + loss[0], loss[2] = self.bbox_loss( pred_distri, pred_bboxes, @@ -134,4 +152,21 @@ def __call__(self, preds, batch): loss[1] *= self.hyp.cls # cls gain loss[2] *= self.hyp.dfl # dfl gain + # with open("dump.txt", "a") as f: + # f.write("bbox loss {}".format(loss[0].item())) + # f.write("\n") + # f.write("cls loss {}".format(loss[1].item())) + # f.write("\n") + # f.write("dfl 
loss {}".format(loss[2].item())) + # f.write("\n") + # f.write("total {}".format(loss.sum().item())) + # f.write("\n") + # f.write("total * batch_size {}".format(loss.sum().item() * batch_size)) + # f.write("\n") + # + # breakpoint() + + # print(torch.std_mean(batch["img"])) + # breakpoint() + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)