Skip to content

Commit

Permalink
feat: Integrate TensorBoard for training visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
healthonrails committed Jan 16, 2025
1 parent 21bded7 commit af72db3
Showing 1 changed file with 33 additions and 9 deletions.
42 changes: 33 additions & 9 deletions annolid/behavior/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from annolid.behavior.data_loading.datasets import BehaviorDataset
from annolid.behavior.data_loading.transforms import ResizeCenterCropNormalize
from annolid.behavior.models.classifier import BehaviorClassifier
from annolid.behavior.models.feature_extractors import ResNetFeatureExtractor,CLIPFeatureExtractor
from annolid.behavior.models.feature_extractors import ResNetFeatureExtractor, CLIPFeatureExtractor
from torch.utils.tensorboard import SummaryWriter # Import TensorBoard

# Configuration (Best practice: Move these to a separate configuration file or use command-line arguments)
BATCH_SIZE = 1
Expand All @@ -17,15 +18,16 @@
VIDEO_FOLDER = "behaivor_videos" # Replace with your actual path
CHECKPOINT_DIR = "checkpoints" # Directory to save checkpoints
VALIDATION_SPLIT = 0.2 # Proportion of the dataset to use for validation
TENSORBOARD_LOG_DIR = "runs" # Directory for TensorBoard logs

logger = logging.getLogger(__name__)


def train_model(model, train_loader, val_loader, num_epochs, device, optimizer, criterion, checkpoint_dir):
"""Trains the model and evaluates it on a validation set."""
def train_model(model, train_loader, val_loader, num_epochs, device, optimizer, criterion, checkpoint_dir, writer):
"""Trains the model and evaluates it on a validation set, logging to TensorBoard."""

os.makedirs(checkpoint_dir, exist_ok=True)
best_val_loss = float("inf")
global_step = 0 # Track training steps for TensorBoard

for epoch in range(num_epochs):
model.train()
Expand All @@ -46,15 +48,24 @@ def train_model(model, train_loader, val_loader, num_epochs, device, optimizer,
progress_info = f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}"
print(progress_info)

# Log training loss to TensorBoard
writer.add_scalar('Loss/train', loss.item(), global_step)

if (i + 1) % 10 == 0: # Log every 10 steps
logger.info(progress_info)

global_step += 1

# Validation after each epoch
val_loss, val_accuracy = validate_model(
model, val_loader, criterion, device)
logger.info(
f"Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

# Log validation loss and accuracy to TensorBoard
writer.add_scalar('Loss/validation', val_loss, epoch)
writer.add_scalar('Accuracy/validation', val_accuracy / 100.0, epoch) # Scale to 0-1

# Save checkpoint if validation loss improves
if val_loss < best_val_loss:
best_val_loss = val_loss
Expand All @@ -64,7 +75,6 @@ def train_model(model, train_loader, val_loader, num_epochs, device, optimizer,

logger.info("Training finished.")


def validate_model(model, val_loader, criterion, device):
"""Evaluates the model on the validation set and calculates accuracy."""
model.eval()
Expand All @@ -86,7 +96,6 @@ def validate_model(model, val_loader, criterion, device):
avg_val_loss = val_loss / len(val_loader)
return avg_val_loss, accuracy


def main():
parser = argparse.ArgumentParser(
description="Train animal behavior classifier.")
Expand All @@ -100,6 +109,8 @@ def main():
default=LEARNING_RATE, help="Learning rate.")
parser.add_argument("--checkpoint_dir", type=str,
default=CHECKPOINT_DIR, help="Checkpoint directory.")
parser.add_argument("--tensorboard_log_dir", type=str,
default=TENSORBOARD_LOG_DIR, help="Directory for TensorBoard logs.")

args = parser.parse_args()

Expand Down Expand Up @@ -143,8 +154,21 @@ def main():
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
criterion = nn.CrossEntropyLoss()

# Initialize TensorBoard SummaryWriter
writer = SummaryWriter(args.tensorboard_log_dir)
# Add graph to TensorBoard
try:
# Get a sample input from the DataLoader to determine the correct shape
sample_batch = next(iter(train_loader))
# Get the first item from the batch and add a batch dimension
dummy_input = sample_batch[0][0].unsqueeze(0).to(device)
writer.add_graph(model, dummy_input)
except Exception as e:
logger.warning(f"Failed to add graph to TensorBoard: {e}")


train_model(model, train_loader, val_loader, args.num_epochs,
device, optimizer, criterion, args.checkpoint_dir)
device, optimizer, criterion, args.checkpoint_dir, writer)

# Load best model for final evaluation
best_model_path = os.path.join(args.checkpoint_dir, "best_model.pth")
Expand All @@ -155,7 +179,7 @@ def main():
f"Final Validation Loss: {final_val_loss:.4f}, Final Validation Accuracy: {final_val_accuracy:.2f}%")

logger.info("Training and validation completed.")

writer.close() # Close the TensorBoard writer

if __name__ == "__main__":
main()
main()

0 comments on commit af72db3

Please sign in to comment.