[Feature] VideoMAE #663

Open · wants to merge 4 commits into dev-1.x
17 changes: 14 additions & 3 deletions mmselfsup/datasets/transforms/formatting.py
@@ -68,9 +68,20 @@ def transform(self,
         if not isinstance(img, List):
             img = [img]
         for i, img_ in enumerate(img):
-            if len(img_.shape) < 3:
-                img_ = np.expand_dims(img_, -1)
-            img_ = np.ascontiguousarray(img_.transpose(2, 0, 1))
+            # to handle the single channel image
+            img_ = np.expand_dims(img_, -1) \
+                if len(img_.shape) == 2 else img_
+
+            if len(img_.shape) == 3:
+                img_ = np.ascontiguousarray(img_.transpose(2, 0, 1))
+            elif len(img_.shape) == 5:
+                # for video data from mmaction with the shape
+                # (M, C, T, H, W), M = num_crops x num_clips
+                img_ = img_
+            else:
+                raise ValueError('img should be 2, 3 or 5 dimensional, '
+                                 f'instead of {len(img_.shape)} dimensional.')
+
             img[i] = to_tensor(img_)
         packed_results['inputs'] = img

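For context, here is a minimal standalone sketch of the shape dispatch this hunk introduces (our illustration with hypothetical inputs, not part of the diff):

```python
import numpy as np

def format_input(img_: np.ndarray) -> np.ndarray:
    # (H, W) single-channel image -> (H, W, 1)
    if img_.ndim == 2:
        img_ = np.expand_dims(img_, -1)
    if img_.ndim == 3:
        # (H, W, C) image -> contiguous (C, H, W)
        return np.ascontiguousarray(img_.transpose(2, 0, 1))
    if img_.ndim == 5:
        # (M, C, T, H, W) video from mmaction passes through unchanged,
        # M = num_crops x num_clips
        return img_
    raise ValueError(f'img should be 2, 3 or 5 dimensional, '
                     f'instead of {img_.ndim} dimensional.')

print(format_input(np.zeros((224, 224))).shape)            # (1, 224, 224)
print(format_input(np.zeros((1, 3, 16, 224, 224))).shape)  # (1, 3, 16, 224, 224)
```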
9 changes: 4 additions & 5 deletions mmselfsup/models/losses/reconstruction_loss.py
@@ -54,14 +54,13 @@ def forward(self,
"""
loss = self.penalty(pred, target)

# if the dim of the loss is 3, take the average of the loss
# along the last dim
if len(loss.shape) == 3:
loss = loss.mean(dim=-1)

if mask is None:
loss = loss.mean()
else:
# if the dim of the loss is 3, take the average of the loss
# along the last dim
if len(loss.shape) == 3:
loss = loss.mean(dim=-1)
loss = (loss * mask).sum() / mask.sum() / self.channel

return loss
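A small sketch of the effect of this move (hypothetical shapes in the style of the VideoMAE decoder output): the per-patch averaging now happens only on the masked branch, so the unmasked branch keeps a plain global mean:

```python
import torch

B, L, D = 2, 1568, 1536        # batch, patches, pixels per patch (hypothetical)
pred = torch.rand(B, L, D)
target = torch.rand(B, L, D)
loss = (pred - target)**2      # element-wise L2 penalty, shape (B, L, D)

mask = (torch.rand(B, L) < 0.9).float()  # 1 marks a masked patch (~90%)
channel = 1                              # self.channel in the loss module

loss = loss.mean(dim=-1)                 # (B, L): mean over each patch
loss = (loss * mask).sum() / mask.sum() / channel
```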
17 changes: 10 additions & 7 deletions mmselfsup/models/utils/__init__.py
@@ -4,13 +4,15 @@
                                 RelativeLocDataPreprocessor,
                                 RotationPredDataPreprocessor,
                                 SelfSupDataPreprocessor,
-                                TwoNormDataPreprocessor)
+                                TwoNormDataPreprocessor,
+                                VideoMAEDataPreprocessor)
 from .ema import CosineEMA
 from .extractor import Extractor
 from .gather_layer import GatherLayer
 from .multi_pooling import MultiPooling
 from .multi_prototypes import MultiPrototypes
-from .position_embedding import build_2d_sincos_position_embedding
+from .position_embedding import (build_1d_sincos_position_embedding,
+                                 build_2d_sincos_position_embedding)
 from .sobel import Sobel
 from .transformer_blocks import (CAETransformerRegressorLayer,
                                  MultiheadAttention,
@@ -26,9 +28,10 @@
 __all__ = [
     'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
     'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
-    'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'CosineEMA',
-    'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor',
-    'RotationPredDataPreprocessor', 'CAEDataPreprocessor', 'ResLayerExtraNorm',
-    'NormEMAVectorQuantizer', 'TwoNormDataPreprocessor',
-    'PromptTransformerEncoderLayer', 'build_clip_model'
+    'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
+    'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor',
+    'RotationPredDataPreprocessor', 'CAEDataPreprocessor',
+    'VideoMAEDataPreprocessor', 'ResLayerExtraNorm', 'NormEMAVectorQuantizer',
+    'TwoNormDataPreprocessor', 'PromptTransformerEncoderLayer',
+    'build_clip_model', 'build_1d_sincos_position_embedding'
 ]
75 changes: 74 additions & 1 deletion mmselfsup/models/utils/data_preprocessor.py
@@ -2,7 +2,7 @@
 from typing import List, Optional, Sequence, Tuple, Union

 import torch
-from mmengine.model import ImgDataPreprocessor
+from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor

 from mmselfsup.registry import MODELS

@@ -290,3 +290,76 @@ def forward(
         ]

         return batch_inputs, batch_data_samples


@MODELS.register_module()
class VideoMAEDataPreprocessor(BaseDataPreprocessor):
    """Data pre-processor for VideoMAE, which casts the input batch, applies
    optional BGR-to-RGB conversion and normalizes video inputs."""

    def __init__(self,
                 mean: Optional[Sequence[Union[float, int]]] = None,
                 std: Optional[Sequence[Union[float, int]]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 to_rgb: bool = False,
                 format_shape: str = 'NCHW') -> None:
        super().__init__()
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value
        self.to_rgb = to_rgb
        self.format_shape = format_shape

        if mean is not None:
            assert std is not None, 'To enable the normalization in ' \
                                    'preprocessing, please specify both ' \
                                    '`mean` and `std`.'
            # Enable the normalization in preprocessing.
            self._enable_normalize = True
            if self.format_shape == 'NCHW':
                normalizer_shape = (-1, 1, 1)
            elif self.format_shape in ('NCTHW', 'NCTVM'):
                normalizer_shape = (-1, 1, 1, 1)
            else:
                raise ValueError(f'Invalid format shape: {format_shape}')

            self.register_buffer(
                'mean',
                torch.tensor(mean, dtype=torch.float32).view(normalizer_shape),
                False)
            self.register_buffer(
                'std',
                torch.tensor(std, dtype=torch.float32).view(normalizer_shape),
                False)
        else:
            self._enable_normalize = False

    def forward(self, data: dict, training: bool = False):
        data = [val for _, val in data.items()]
        batch_inputs, batch_data_samples = self.cast_data(data)

        # ------ To RGB ------
        if self.to_rgb:
            if self.format_shape == 'NCHW':
                batch_inputs = [
                    batch_input[..., [2, 1, 0], :, :]
                    for batch_input in batch_inputs
                ]
            elif self.format_shape == 'NCTHW':
                batch_inputs = [
                    batch_input[..., [2, 1, 0], :, :, :]
                    for batch_input in batch_inputs
                ]
            else:
                raise ValueError(f'Invalid format shape: {self.format_shape}')

        # -- Normalization ---
        if self._enable_normalize:
            batch_inputs = [(batch_input - self.mean) / self.std
                            for batch_input in batch_inputs]
        else:
            batch_inputs = [
                batch_input.to(torch.float32) for batch_input in batch_inputs
            ]

        return batch_inputs, batch_data_samples
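A hedged usage sketch (dummy tensors; real inputs come from PackSelfSupInputs): with format_shape='NCTHW' the mean/std buffers are reshaped to (-1, 1, 1, 1) so they broadcast over the (C, T, H, W) dimensions of each clip:

```python
import torch

preprocessor = VideoMAEDataPreprocessor(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    format_shape='NCTHW')

data = {
    'inputs': [torch.randint(0, 256, (1, 3, 16, 224, 224), dtype=torch.uint8)],
    'data_samples': [None],  # stands in for SelfSupDataSample instances
}
inputs, data_samples = preprocessor(data)
print(inputs[0].shape, preprocessor.mean.shape)
# torch.Size([1, 3, 16, 224, 224]) torch.Size([3, 1, 1, 1])
```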
29 changes: 29 additions & 0 deletions mmselfsup/models/utils/position_embedding.py
@@ -56,3 +56,32 @@ def build_2d_sincos_position_embedding(
        pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)

    return pos_emb


def build_1d_sincos_position_embedding(
        num_patches: int,
        embed_dims: int,
        temperature: Optional[float] = 10000.) -> torch.Tensor:
    """Build the 1d sine-cosine position embedding, which provides position
    information of the input patches to the model.

    Sinusoid encoding is a fixed position encoding introduced in `Attention
    Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Args:
        num_patches (int): The number of the input patches.
        embed_dims (int): The dimension of the embedding vector.
        temperature (float, optional): The temperature parameter. Defaults to
            10000.

    Returns:
        torch.Tensor: The sinusoid table of shape (num_patches, embed_dims).
    """
    vector = torch.arange(embed_dims, dtype=torch.float64)
    vector = (vector - vector % 2) / embed_dims
    vector = torch.pow(temperature, -vector).view(1, -1)

    sinusoid_table = torch.arange(num_patches).view(-1, 1) * vector
    sinusoid_table[:, 0::2].sin_()  # dim 2i
    sinusoid_table[:, 1::2].cos_()  # dim 2i+1

    sinusoid_table = sinusoid_table.to(torch.float32)

    return sinusoid_table
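The table implements the familiar pair PE(pos, 2i) = sin(pos / T^(2i/d)) and PE(pos, 2i+1) = cos(pos / T^(2i/d)). A quick sanity check (our own, not part of the PR):

```python
import torch

table = build_1d_sincos_position_embedding(num_patches=4, embed_dims=8)
print(table.shape)  # torch.Size([4, 8])

# Element [pos, 2i] should equal sin(pos / 10000**(2i / embed_dims)).
pos, i = 3, 1
expected = torch.sin(torch.tensor(pos / 10000**(2 * i / 8)))
assert torch.isclose(table[pos, 2 * i], expected, atol=1e-6)
```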
5 changes: 3 additions & 2 deletions mmselfsup/utils/__init__.py
@@ -7,10 +7,11 @@
 from .gather import concat_all_gather
 from .misc import get_model
 from .setup_env import register_all_modules
+from .typing import *  # noqa: F401, F403

 __all__ = [
     'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
     'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
-    'sync_random_seed', 'distributed_sinkhorn', 'concat_all_gather',
-    'register_all_modules', 'get_model'
+    'distributed_sinkhorn', 'concat_all_gather', 'register_all_modules',
+    'get_model'
 ]
9 changes: 9 additions & 0 deletions mmselfsup/utils/typing.py
@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Collecting some commonly used type hint in mmselfsup."""
from typing import Optional, Union

from mmengine.config import ConfigDict

# Type hint of config data
ConfigType = Union[ConfigDict, dict]
OptConfigType = Optional[ConfigType]
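For illustration, a hypothetical signature using these aliases (not part of the diff):

```python
from mmselfsup.utils.typing import OptConfigType

def build_module(cfg: OptConfigType = None) -> None:
    """`cfg` may be a ConfigDict, a plain dict, or None."""
```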
53 changes: 53 additions & 0 deletions projects/videomae/configs/_base_/datasets/k400_videomae.py
@@ -0,0 +1,53 @@
# dataset settings

dataset_type = 'mmaction.VideoDataset'
data_root = 'data/kinetics400/videos_train'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'

file_client_args = dict(
    io_backend='petrel',
    path_mapping=dict(
        {'data/kinetics400': 's3://openmmlab/datasets/action/Kinetics400'}))

# file_client_args = dict(io_backend='disk')
train_pipeline = [
    dict(type='mmaction.DecordInit', **file_client_args),
    dict(
        type='mmaction.SampleFrames',
        clip_len=16,
        frame_interval=4,
        num_clips=1),
    dict(type='mmaction.DecordDecode'),
    dict(
        type='mmaction.MultiScaleCrop',
        input_size=224,
        scales=(1, 0.875, 0.75, 0.66),
        random_crop=False,
        max_wh_scale_gap=1),
    dict(type='mmaction.Resize', scale=(224, 224), keep_ratio=False),
    dict(type='mmaction.FormatShape', input_format='NCTHW'),
    dict(
        type='VideoMAEMaskGenerator',
        input_size=(16, 224, 224),
        patch_size=16,
        tubelet_size=2,
        mask_ratio=0.9,
        mask_mode='tube'),
    dict(
        type='PackSelfSupInputs',
        key='imgs',
        algorithm_keys=['mask'],
        meta_keys=['img_shape', 'label'])
]

train_dataloader = dict(
    batch_size=32,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=dict(video=data_root),
        pipeline=train_pipeline))
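As a sanity check on the mask settings (our arithmetic, not part of the config): with clip_len=16, tubelet_size=2 and patch_size=16 on 224x224 frames, tokenization yields (16/2) x (224/16)^2 = 1568 tubelet tokens, of which 90% are tube-masked:

```python
clip_len, tubelet_size = 16, 2
img_size, patch_size = 224, 16
mask_ratio = 0.9

num_tokens = (clip_len // tubelet_size) * (img_size // patch_size)**2
num_masked = int(mask_ratio * num_tokens)
print(num_tokens, num_masked)  # 1568 1411
```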
40 changes: 40 additions & 0 deletions projects/videomae/configs/_base_/models/videomae_vit-small-p16.py
@@ -0,0 +1,40 @@
# model settings
model = dict(
    type='VideoMAE',
    data_preprocessor=dict(
        type='VideoMAEDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        format_shape='NCTHW'),
    backbone=dict(
        type='VideoMAEViT',
        img_size=224,
        embed_dims=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        num_frames=16,
        norm_cfg=dict(type='LN', eps=1e-6),
        patch_size=16,
        mask_ratio=0.9),
    neck=dict(
        type='VideoMAEPretrainDecoder',
        img_size=224,
        num_frames=16,
        num_classes=1536,
        num_heads=3,
        input_dims=384,
        embed_dims=192,
        patch_size=16,
        depth=4,
    ),
    head=dict(
        type='VideoMAEPretrainHead',
        norm_pix=True,
        patch_size=16,
        loss=dict(type='PixelReconstructionLoss', criterion='L2')),
    init_cfg=[
        dict(type='Xavier', distribution='uniform', layer='Linear'),
        dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
    ])
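The decoder's num_classes=1536 is the number of pixel values each token reconstructs, matching one tubelet: 3 (RGB) x 2 (tubelet_size) x 16 x 16 (patch). A quick check of the arithmetic (ours):

```python
channels, tubelet_size, patch_size = 3, 2, 16
num_classes = channels * tubelet_size * patch_size * patch_size
print(num_classes)  # 1536 pixel values per reconstructed tubelet
```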
@@ -0,0 +1,19 @@
# optimizer_wrapper
optimizer = dict(type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

# learning rate scheduler
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-4,
        by_epoch=True,
        begin=0,
        end=40,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR', T_max=160, by_epoch=True, begin=40, end=200)
]

# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=200)
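Read together, the two schedulers give a 40-epoch linear warmup followed by cosine decay until epoch 200. A rough per-epoch approximation (ours; the actual warmup is converted to iteration-based):

```python
import math

def approx_lr(epoch, base_lr=1.5e-4, start_factor=1e-4, warmup=40, total=200):
    if epoch < warmup:  # LinearLR: factor ramps from start_factor to 1
        factor = start_factor + (1 - start_factor) * epoch / warmup
        return base_lr * factor
    # CosineAnnealingLR with T_max=160, annealing towards 0
    return base_lr * 0.5 * (
        1 + math.cos(math.pi * (epoch - warmup) / (total - warmup)))

print(approx_lr(0), approx_lr(40), approx_lr(120), approx_lr(200))
# 1.5e-08 0.00015 7.5e-05 ~0.0
```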
5 changes: 5 additions & 0 deletions projects/videomae/configs/videomae/README.md
@@ -0,0 +1,5 @@
# VideoMAE

> [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602)

<!-- [ALGORITHM] -->
@@ -0,0 +1,26 @@
# default_scope = 'mmaction'

default_hooks = dict(
    runtime_info=dict(type='mmaction.RuntimeInfoHook'),
    timer=dict(type='mmaction.IterTimerHook'),
    logger=dict(type='mmaction.LoggerHook', interval=20, ignore_last=False),
    param_scheduler=dict(type='mmaction.ParamSchedulerHook'),
    checkpoint=dict(
        type='mmaction.CheckpointHook', interval=1, save_best='auto'),
    sampler_seed=dict(type='mmaction.DistSamplerSeedHook'),
    sync_buffers=dict(type='mmaction.SyncBuffersHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))

log_processor = dict(
    type='mmaction.LogProcessor', window_size=20, by_epoch=True)

vis_backends = [dict(type='mmaction.LocalVisBackend')]
visualizer = dict(type='mmaction.ActionVisualizer', vis_backends=vis_backends)

log_level = 'INFO'
load_from = None
resume = False