Skip to content

Commit

Permalink
better docs
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobeltrami committed Nov 28, 2023
1 parent 3beed61 commit 4393860
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
13 changes: 5 additions & 8 deletions micromind/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def reduce(self, stage, clear=False):
Compute and return the metric for a given prediction and batch data.
Arguments
-------
---------
pred : torch.Tensor
The model's prediction.
batch : torch.Tensor
Expand Down Expand Up @@ -273,10 +273,11 @@ def configure_optimizers(self):
"""Configures and defines the optimizer for the task. Defaults to adam
with lr=0.001; It can be overwritten by either passing arguments from the
command line, or by overwriting this entire method.
Scheduler step is called every optimization step.
Returns
---------
Optimizer and learning rate scheduler.
-------
Optimizer and learning rate scheduler.
: Union[Tuple[torch.optim.Adam, None], torch.optim.Adam]
"""
Expand All @@ -289,11 +290,7 @@ def configure_optimizers(self):
elif self.hparams.opt == "sgd":
opt = torch.optim.SGD(self.modules.parameters(), self.hparams.lr)

sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
opt, "min", factor=0.1, patience=10, threshold=5
)

return opt, sched
return opt

def __call__(self, *x, **xv):
"""Just forwards everything to the forward method."""
Expand Down
20 changes: 19 additions & 1 deletion recipes/objection_detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

class YOLO(mm.MicroMind):
def __init__(self, m_cfg, hparams, *args, **kwargs):
"""Initializes the YOLO model."""
super().__init__(*args, **kwargs)

self.modules["phinet"] = PhiNet(
Expand Down Expand Up @@ -94,7 +95,6 @@ def get_parameters(self):
out_neck[1].shape[1],
out_neck[2].shape[1],
)
# head = DetectionHead(filters=head_filters)

return (c1, c2), neck_filters, up, head_filters

Expand All @@ -111,6 +111,7 @@ def preprocess_batch(self, batch):
return preprocessed_batch

def forward(self, batch):
"""Runs the forward method by calling every module."""
preprocessed_batch = self.preprocess_batch(batch)
backbone = self.modules["phinet"](preprocessed_batch["img"].to(self.device))[1]
backbone[-1] = self.modules["sppf"](backbone[-1])
Expand All @@ -120,6 +121,7 @@ def forward(self, batch):
return head

def compute_loss(self, pred, batch):
"""Computes the loss."""
self.criterion = Loss(self.m_cfg, self.modules["head"], self.device)
preprocessed_batch = self.preprocess_batch(batch)

Expand All @@ -131,6 +133,7 @@ def compute_loss(self, pred, batch):
return lossi_sum

def configure_optimizers(self):
"""Configures the optimizer and the scheduler."""
opt = torch.optim.SGD(self.modules.parameters(), lr=1e-2, weight_decay=0.0005)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(
opt, T_max=14000, eta_min=1e-3
Expand All @@ -139,6 +142,21 @@ def configure_optimizers(self):

@torch.no_grad()
def mAP(self, pred, batch):
"""Compute the mean average precision (mAP) for a batch of predictions.
Arguments
---------
pred : torch.Tensor
Model predictions for the batch.
batch : dict
A dictionary containing batch information, including bounding boxes,
classes and shapes.
Returns
-------
torch.Tensor
A tensor containing the computed mean average precision (mAP) for the batch.
"""
preprocessed_batch = self.preprocess_batch(batch)
post_predictions = postprocess(
preds=pred[0], img=preprocessed_batch, orig_imgs=batch
Expand Down

0 comments on commit 4393860

Please sign in to comment.