feat: add CLIP feature extractor class for behavior classification model
- Implemented a CLIPFeatureExtractor class to extract features using CLIP's vision encoder from the transformers library.
- Integrated the feature extractor with the BehaviorClassifier for improved input representation.
- Ensured compatibility with the existing model architecture by matching the output feature dimension.
healthonrails committed Nov 12, 2024
1 parent 83f04a8 commit 56a12c3
Showing 4 changed files with 65 additions and 39 deletions.
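
For orientation, here is a minimal usage sketch of the new extractor together with the classifier, based on the constructor calls visible in train.py below (the class count and the random input batch are illustrative assumptions, not values from this commit):

import torch
from annolid.behavior.models import BehaviorClassifier, CLIPFeatureExtractor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the extractor and wrap it in the classifier, mirroring train.py
feature_extractor = CLIPFeatureExtractor().to(device)
model = BehaviorClassifier(feature_extractor, num_classes=5).to(device)  # 5 classes is illustrative

# CLIP ViT-B/32 consumes 224x224 RGB inputs; random values stand in for real frames
frames = torch.randn(2, 3, 224, 224, device=device)
features = feature_extractor(frames)
print(features.shape)  # expected: torch.Size([2, 768]), ViT-B/32's pooled vision feature size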
annolid/behavior/data_loading/multimodal_data.py (4 changes: 2 additions & 2 deletions)
@@ -93,8 +93,8 @@ def main(input_path: str, output_path: str) -> None:
 
 if __name__ == '__main__':
     # Paths to input and output files
-    input_path = '/data/test_video_annotations.jsonl'
-    output_path = '/data/transformed_video_annotations.json'
+    input_path = 'train_video_annotations.jsonl'
+    output_path = 'transformed_train_video_annotations.json'
 
     # Run the main process
     main(input_path, output_path)
annolid/behavior/models/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,2 +1,2 @@
 from .classifier import BehaviorClassifier
-from .feature_extractors import ResNetFeatureExtractor
+from .feature_extractors import ResNetFeatureExtractor, CLIPFeatureExtractor
annolid/behavior/models/feature_extractors.py (31 changes: 31 additions & 0 deletions)
@@ -2,10 +2,41 @@
 import torch.nn as nn
 import torchvision.models as models
 import logging
+from transformers import CLIPModel, CLIPProcessor
 
 logger = logging.getLogger(__name__)
 
 
+class CLIPFeatureExtractor(nn.Module):
+    """
+    A class to extract features from images using the CLIP vision encoder.
+    Args:
+        model_name (str): The name of the pre-trained CLIP model (e.g., 'openai/clip-vit-base-patch32').
+    """
+
+    def __init__(self, model_name: str = 'openai/clip-vit-base-patch32'):
+        super().__init__()
+        self.clip_model = CLIPModel.from_pretrained(model_name)
+        self.vision_encoder = self.clip_model.vision_model
+        self.processor = CLIPProcessor.from_pretrained(model_name)
+
+    def forward(self, images: torch.Tensor) -> torch.Tensor:
+        """
+        Extract features from input images.
+        Args:
+            images (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).
+        Returns:
+            torch.Tensor: Extracted features of shape (batch_size, feature_dim).
+        """
+        # Run the images through CLIP's vision encoder
+        with torch.no_grad():  # Optional: avoid tracking gradients during feature extraction
+            features = self.vision_encoder(pixel_values=images).pooler_output
+
+        return features  # Output shape: (batch_size, feature_dim)
+
+
 class ResNetFeatureExtractor(nn.Module):
     """
     Extracts features from images using a ResNet backbone.
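
Note that forward expects pixel values that are already resized and normalized to CLIP's statistics; the stored self.processor is never called inside forward. A sketch of how a single frame could be prepared with the processor before being passed to the extractor (the frame path is a hypothetical placeholder):

from PIL import Image
from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
image = Image.open('frame_0001.png').convert('RGB')  # hypothetical frame path

# The processor resizes, center-crops, and normalizes to CLIP's expected statistics
inputs = processor(images=image, return_tensors='pt')
pixel_values = inputs['pixel_values']  # shape (1, 3, 224, 224), ready for CLIPFeatureExtractor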
annolid/behavior/training/train.py (67 changes: 31 additions & 36 deletions)
@@ -7,34 +7,22 @@
 from annolid.behavior.data_loading.datasets import BehaviorDataset
 from annolid.behavior.data_loading.transforms import ResizeCenterCropNormalize
 from annolid.behavior.models.classifier import BehaviorClassifier
-from annolid.behavior.models.feature_extractors import ResNetFeatureExtractor
+from annolid.behavior.models.feature_extractors import ResNetFeatureExtractor, CLIPFeatureExtractor
 
 # Configuration (Best practice: Move these to a separate configuration file or use command-line arguments)
 BATCH_SIZE = 1
 NUM_EPOCHS = 10
 LEARNING_RATE = 0.001
-# Replace with your video folder
-VIDEO_FOLDER = "behaivor_videos/"
+VIDEO_FOLDER = "behaivor_videos"  # Replace with your actual path
 CHECKPOINT_DIR = "checkpoints"  # Directory to save checkpoints
 VALIDATION_SPLIT = 0.2  # Proportion of the dataset to use for validation
 
 logger = logging.getLogger(__name__)
 
 
 def train_model(model, train_loader, val_loader, num_epochs, device, optimizer, criterion, checkpoint_dir):
-    """
-    Trains the behavior classification model and evaluates it on a validation set.
-    Args:
-        model: The model to train.
-        train_loader: DataLoader for the training data.
-        val_loader: DataLoader for the validation data.
-        num_epochs: The number of training epochs.
-        device: The device to use for training (e.g., "cuda" or "cpu").
-        optimizer: The optimizer.
-        criterion: The loss function.
-        checkpoint_dir: The directory to save model checkpoints.
-    """
+    """Trains the model and evaluates it on a validation set."""
 
     os.makedirs(checkpoint_dir, exist_ok=True)
     best_val_loss = float("inf")
@@ -44,7 +32,6 @@ def train_model(model, train_loader, val_loader, num_epochs, device, optimizer, criterion, checkpoint_dir):
         for i, batch in enumerate(train_loader):
             try:
                 inputs, labels, _ = batch
-                print(inputs, labels)
                 inputs, labels = inputs.to(device), labels.to(device)
             except Exception as e:
                 logger.error(f"Error processing batch: {e}. Skipping batch.")
@@ -56,15 +43,17 @@ def train_model(model, train_loader, val_loader, num_epochs, device, optimizer, criterion, checkpoint_dir):
             loss.backward()
             optimizer.step()
 
-            if (i + 1) % 10 == 0:
-                logger.info(
-                    f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}"
-                )
+            progress_info = f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}"
+            print(progress_info)
+
+            if (i + 1) % 10 == 0:  # Log every 10 steps
+                logger.info(progress_info)
 
         # Validation after each epoch
-        val_loss = validate_model(model, val_loader, criterion, device)
+        val_loss, val_accuracy = validate_model(
+            model, val_loader, criterion, device)
         logger.info(
-            f"Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_loss:.4f}")
+            f"Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
 
         # Save checkpoint if validation loss improves
         if val_loss < best_val_loss:
@@ -77,28 +66,25 @@ def train_model(model, train_loader, val_loader, num_epochs, device, optimizer, criterion, checkpoint_dir):
 
 
 def validate_model(model, val_loader, criterion, device):
-    """
-    Evaluates the model on the validation set.
-    Args:
-        model: The model to evaluate.
-        val_loader: DataLoader for the validation data.
-        criterion: The loss function.
-        device: The device to use for evaluation.
-    Returns:
-        The average validation loss.
-    """
+    """Evaluates the model on the validation set and calculates accuracy."""
     model.eval()
     val_loss = 0
+    correct = 0
+    total = 0
     with torch.no_grad():
         for inputs, labels, _ in val_loader:
             inputs, labels = inputs.to(device), labels.to(device)
             outputs = model(inputs)
             loss = criterion(outputs, labels)
             val_loss += loss.item()
 
-    return val_loss / len(val_loader)
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+    accuracy = 100 * correct / total
+    avg_val_loss = val_loss / len(val_loader)
+    return avg_val_loss, accuracy
 
 
 def main():
@@ -149,16 +135,25 @@ def main():
     val_loader = DataLoader(
         val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
 
-    feature_extractor = ResNetFeatureExtractor().to(device)
+    feature_extractor = CLIPFeatureExtractor().to(device)
     model = BehaviorClassifier(
         feature_extractor, num_classes=num_of_classes).to(device)
+    print(model)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
     criterion = nn.CrossEntropyLoss()
 
     train_model(model, train_loader, val_loader, args.num_epochs,
                 device, optimizer, criterion, args.checkpoint_dir)
 
+    # Load best model for final evaluation
+    best_model_path = os.path.join(args.checkpoint_dir, "best_model.pth")
+    model.load_state_dict(torch.load(best_model_path))
+    final_val_loss, final_val_accuracy = validate_model(
+        model, val_loader, criterion, device)
+    logger.info(
+        f"Final Validation Loss: {final_val_loss:.4f}, Final Validation Accuracy: {final_val_accuracy:.2f}%")
+
     logger.info("Training and validation completed.")
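
One caveat on the final-evaluation block above: torch.load without map_location restores tensors onto the device they were saved from, so a GPU-trained checkpoint would fail to load on a CPU-only machine. A defensive variant of the same two lines, under that assumption:

# Map the checkpoint onto the current device before restoring the weights
state_dict = torch.load(best_model_path, map_location=device)
model.load_state_dict(state_dict)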
