diff --git a/annolid/behavior/training/train.py b/annolid/behavior/training/train.py
index 65d3e0b..c34d797 100644
--- a/annolid/behavior/training/train.py
+++ b/annolid/behavior/training/train.py
@@ -7,7 +7,8 @@
 from annolid.behavior.data_loading.datasets import BehaviorDataset
 from annolid.behavior.data_loading.transforms import ResizeCenterCropNormalize
 from annolid.behavior.models.classifier import BehaviorClassifier
-from annolid.behavior.models.feature_extractors import ResNetFeatureExtractor,CLIPFeatureExtractor
+from annolid.behavior.models.feature_extractors import ResNetFeatureExtractor, CLIPFeatureExtractor
+from torch.utils.tensorboard import SummaryWriter  # Import TensorBoard
 
 # Configuration (Best practice: Move these to a separate configuration file or use command-line arguments)
 BATCH_SIZE = 1
@@ -17,15 +18,16 @@
 VIDEO_FOLDER = "behaivor_videos"  # Replace with your actual path
 CHECKPOINT_DIR = "checkpoints"  # Directory to save checkpoints
 VALIDATION_SPLIT = 0.2  # Proportion of the dataset to use for validation
+TENSORBOARD_LOG_DIR = "runs"  # Directory for TensorBoard logs
 
 
 logger = logging.getLogger(__name__)
 
-
-def train_model(model, train_loader, val_loader, num_epochs, device, optimizer, criterion, checkpoint_dir):
-    """Trains the model and evaluates it on a validation set."""
+def train_model(model, train_loader, val_loader, num_epochs, device, optimizer, criterion, checkpoint_dir, writer):
+    """Trains the model and evaluates it on a validation set, logging to TensorBoard."""
     os.makedirs(checkpoint_dir, exist_ok=True)
     best_val_loss = float("inf")
+    global_step = 0  # Track training steps for TensorBoard
 
     for epoch in range(num_epochs):
         model.train()
@@ -46,15 +48,24 @@
             progress_info = f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}"
             print(progress_info)
 
+            # Log training loss to TensorBoard
+            writer.add_scalar('Loss/train', loss.item(), global_step)
+
             if (i + 1) % 10 == 0:  # Log every 10 steps
                 logger.info(progress_info)
 
+            global_step += 1
+
         # Validation after each epoch
         val_loss, val_accuracy = validate_model(
             model, val_loader, criterion, device)
         logger.info(
             f"Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
 
+        # Log validation loss and accuracy to TensorBoard
+        writer.add_scalar('Loss/validation', val_loss, epoch)
+        writer.add_scalar('Accuracy/validation', val_accuracy / 100.0, epoch)  # Scale to 0-1
+
         # Save checkpoint if validation loss improves
         if val_loss < best_val_loss:
             best_val_loss = val_loss
@@ -64,7 +75,6 @@
 
     logger.info("Training finished.")
 
-
 def validate_model(model, val_loader, criterion, device):
     """Evaluates the model on the validation set and calculates accuracy."""
     model.eval()
@@ -86,7 +96,6 @@
     avg_val_loss = val_loss / len(val_loader)
     return avg_val_loss, accuracy
 
-
 def main():
     parser = argparse.ArgumentParser(
         description="Train animal behavior classifier.")
@@ -100,6 +109,8 @@
                         default=LEARNING_RATE, help="Learning rate.")
     parser.add_argument("--checkpoint_dir", type=str,
                         default=CHECKPOINT_DIR, help="Checkpoint directory.")
+    parser.add_argument("--tensorboard_log_dir", type=str,
+                        default=TENSORBOARD_LOG_DIR, help="Directory for TensorBoard logs.")
 
     args = parser.parse_args()
 
@@ -143,8 +154,21 @@
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
     criterion = nn.CrossEntropyLoss()
 
+    # Initialize TensorBoard SummaryWriter
+    writer = SummaryWriter(args.tensorboard_log_dir)
+    # Add graph to TensorBoard
+    try:
+        # Get a sample input from the DataLoader to determine the correct shape
+        sample_batch = next(iter(train_loader))
+        # Get the first item from the batch and add a batch dimension
+        dummy_input = sample_batch[0][0].unsqueeze(0).to(device)
+        writer.add_graph(model, dummy_input)
+    except Exception as e:
+        logger.warning(f"Failed to add graph to TensorBoard: {e}")
+
+
     train_model(model, train_loader, val_loader, args.num_epochs,
-                device, optimizer, criterion, args.checkpoint_dir)
+                device, optimizer, criterion, args.checkpoint_dir, writer)
 
     # Load best model for final evaluation
     best_model_path = os.path.join(args.checkpoint_dir, "best_model.pth")
@@ -155,7 +179,7 @@
         f"Final Validation Loss: {final_val_loss:.4f}, Final Validation Accuracy: {final_val_accuracy:.2f}%")
 
     logger.info("Training and validation completed.")
-
+    writer.close()  # Close the TensorBoard writer
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
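
For reference, a minimal standalone sketch of the SummaryWriter pattern this patch wires into train.py, assuming only that torch and tensorboard are installed; the "runs" directory matches the TENSORBOARD_LOG_DIR default above, and the loss values are placeholders, not real training output:

# Minimal sketch of the logging calls used in the patch above.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs")  # same default as TENSORBOARD_LOG_DIR

for step in range(100):
    fake_loss = 1.0 / (step + 1)  # stand-in for the real training loss
    writer.add_scalar("Loss/train", fake_loss, step)

writer.add_scalar("Accuracy/validation", 0.85, 0)  # one point per epoch, scaled to 0-1
writer.close()  # flush event files to disk

The resulting event files can then be inspected with tensorboard --logdir runs (or whatever --tensorboard_log_dir was passed to the training script).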