diff --git a/direct/nn/transformers/__init__.py b/direct/nn/transformers/__init__.py
new file mode 100644
index 00000000..c36ca0ba
--- /dev/null
+++ b/direct/nn/transformers/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) DIRECT Contributors
+
+"""DIRECT transformers models."""
diff --git a/direct/nn/transformers/config.py b/direct/nn/transformers/config.py
new file mode 100644
index 00000000..61be2426
--- /dev/null
+++ b/direct/nn/transformers/config.py
@@ -0,0 +1,73 @@
+# Copyright (c) DIRECT Contributors
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+from omegaconf import MISSING
+
+from direct.config.defaults import ModelConfig
+from direct.nn.transformers.uformer import AttentionTokenProjectionType, LeWinTransformerMLPTokenType
+
+
+@dataclass
+class UFormerModelConfig(ModelConfig):
+    patch_size: int = 256
+    embedding_dim: int = 32
+    encoder_depths: Tuple[int, ...] = (2, 2, 2, 2)
+    encoder_num_heads: Tuple[int, ...] = (1, 2, 4, 8)
+    bottleneck_depth: int = 2
+    bottleneck_num_heads: int = 16
+    win_size: int = 8
+    mlp_ratio: float = 4.0
+    qkv_bias: bool = True
+    qk_scale: Optional[float] = None
+    drop_rate: float = 0.0
+    attn_drop_rate: float = 0.0
+    drop_path_rate: float = 0.1
+    patch_norm: bool = True
+    token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR
+    token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF
+    shift_flag: bool = True
+    modulator: bool = False
+    cross_modulator: bool = False
+    normalized: bool = True
+
+
+@dataclass
+class VisionTransformer2DConfig(ModelConfig):
+    average_img_size: int | tuple[int, int] = MISSING
+    patch_size: int | tuple[int, int] = 16
+    embedding_dim: int = 64
+    depth: int = 8
+    num_heads: int = 9
+    mlp_ratio: float = 4.0
+    qkv_bias: bool = False
+    qk_scale: Optional[float] = None
+    drop_rate: float = 0.0
+    attn_drop_rate: float = 0.0
+    dropout_path_rate: float = 0.0
+    use_gpsa: bool = True
+    locality_strength: float = 1.0
+    use_pos_embedding: bool = True
+    normalized: bool = True
+
+
+@dataclass
+class VisionTransformer3DConfig(ModelConfig):
+    average_img_size: int | tuple[int, int, int] = MISSING
+    patch_size: int | tuple[int, int, int] = 16
+    embedding_dim: int = 64
+    depth: int = 8
+    num_heads: int = 9
+    mlp_ratio: float = 4.0
+    qkv_bias: bool = False
+    qk_scale: Optional[float] = None
+    drop_rate: float = 0.0
+    attn_drop_rate: float = 0.0
+    dropout_path_rate: float = 0.0
+    use_gpsa: bool = True
+    locality_strength: float = 1.0
+    use_pos_embedding: bool = True
+    normalized: bool = True
diff --git a/direct/nn/transformers/uformer.py b/direct/nn/transformers/uformer.py
new file mode 100644
index 00000000..d2268792
--- /dev/null
+++ b/direct/nn/transformers/uformer.py
@@ -0,0 +1,2006 @@
+# Copyright (c) DIRECT Contributors
+
+"""U-Former model [1]_ implementation.
+
+Adapted from [2]_.
+
+References
+----------
+.. [1] Wang, Zhendong, et al. "Uformer: A general u-shaped transformer for image restoration." Proceedings of the
+   IEEE/CVF conference on computer vision and pattern recognition. 2022.
+..
[2] https://github.com/ZhendongWang6/Uformer + +""" + +from __future__ import annotations + +import math +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from torch.nn.init import trunc_normal_ + +from direct.nn.transformers.utils import DropoutPath, init_weights, norm, pad_to_square, unnorm, unpad_to_original +from direct.types import DirectEnum + +__all__ = ["AttentionTokenProjectionType", "LeWinTransformerMLPTokenType", "UFormer", "UFormerModel"] + + +class ECALayer1d(nn.Module): + """Efficient Channel Attention (ECA) module for 1D data. + + Parameters + ---------- + channel : int + Number of channels of the input feature map. + k_size : int + Adaptive selection of kernel size. Default: 3. + """ + + def __init__(self, channel: int, k_size: int = 3) -> None: + """Inits :class:`ECALayer1d`. + + Parameters + ---------- + channel : int + Number of channels of the input feature map. + k_size : int + Adaptive selection of kernel size. Default: 3. + """ + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool1d(1) + self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) + self.sigmoid = nn.Sigmoid() + self.channel = channel + self.k_size = k_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Computes the output of the ECA layer. + + Parameters + ---------- + x : torch.Tensor + Input feature map. + + Returns + ------- + y : torch.Tensor + Output of the ECA layer. + """ + # feature descriptor on the global spatial information + y = self.avg_pool(x.transpose(-1, -2)) + + # Two different branches of ECA module + y = self.conv(y.transpose(-1, -2)) + + # Multi-scale information fusion + y = self.sigmoid(y) + + return x * y.expand_as(x) + + +class SepConv2d(torch.nn.Module): + """A 2D Separable Convolutional layer. + + Applies a depthwise convolution followed by a pointwise convolution. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int or tuple of ints + Size of the convolution kernel. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + padding : int or tuple of ints + Padding added to all four sides of the input. Default: 0. + dilation : int or tuple of ints + Spacing between kernel elements. Default: 1. + act_layer : torch.nn.Module + Activation layer applied after depthwise convolution. Default: nn.ReLU. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int | tuple[int, int] = 1, + padding: int | tuple[int, int] = 0, + dilation: int | tuple[int, int] = 1, + act_layer: nn.Module = nn.ReLU, + ) -> None: + """Inits :class:`SepConv2d`. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int or tuple of ints + Size of the convolution kernel. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + padding : int or tuple of ints + Padding added to all four sides of the input. Default: 0. + dilation : int or tuple of ints + Spacing between kernel elements. Default: 1. + act_layer : torch.nn.Module + Activation layer applied after depthwise convolution. Default: nn.ReLU. 
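+
+        Examples
+        --------
+        A minimal shape check; the sizes below are illustrative and not tied to the model defaults.
+
+        >>> import torch
+        >>> conv = SepConv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)  # illustrative sizes
+        >>> conv(torch.randn(1, 8, 32, 32)).shape
+        torch.Size([1, 16, 32, 32])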
+ """ + super().__init__() + self.depthwise = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + ) + self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1) + self.act_layer = act_layer() if act_layer is not None else nn.Identity() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`SepConv2d`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor after applying depthwise and pointwise convolutions with activation. + """ + x = self.depthwise(x) + x = self.act_layer(x) + x = self.pointwise(x) + return x + + +class ConvProjectionModule(nn.Module): + """Convolutional projection layer used in the window attention mechanism. + + The projection layer consists of three convolutional layers for queries, keys, and values. + + Parameters + ---------- + dim : int + Number of channels in the input tensor. + heads : int + Number of heads in multi-head attention. Default: 8. + dim_head : int + Dimension of each head. Default: 64. + kernel_size : int + Size of convolutional kernel. Default: 3. + q_stride : int + Stride of the convolutional kernel for queries. Default: 1. + k_stride : int + Stride of the convolutional kernel for keys. Default: 1. + v_stride : int + Stride of the convolutional kernel for values. Default: 1. + bias : bool + Whether to include a bias term. Default: True. + """ + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + kernel_size: int = 3, + q_stride: int = 1, + k_stride: int = 1, + v_stride: int = 1, + bias: bool = True, + ): + """Inits :class:`ConvProjectionModule`. + + Parameters + ---------- + dim : int + Number of channels in the input tensor. + heads : int + Number of heads in multi-head attention. Default: 8. + dim_head : int + Dimension of each head. Default: 64. + kernel_size : int + Size of convolutional kernel. Default: 3. + q_stride : int + Stride of the convolutional kernel for queries. Default: 1. + k_stride : int + Stride of the convolutional kernel for keys. Default: 1. + v_stride : int + Stride of the convolutional kernel for values. Default: 1. + bias : bool + Whether to include a bias term. Default: True. + """ + super().__init__() + + inner_dim = dim_head * heads + self.heads = heads + pad = (kernel_size - q_stride) // 2 + self.to_q = SepConv2d( + in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=q_stride, padding=pad + ) + self.to_k = SepConv2d( + in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=k_stride, padding=pad + ) + self.to_v = SepConv2d( + in_channels=dim, out_channels=inner_dim, kernel_size=kernel_size, stride=v_stride, padding=pad + ) + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass of :class:`ConvProjectionModule`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + attn_kv : torch.Tensor, optional + Attention key/value tensor. Default None. + + Returns + ------- + q : torch.Tensor + Query tensor. + k : torch.Tensor + Key tensor. + v : torch.Tensor + Value tensor. 
+ """ + b, n, c, h = *x.shape, self.heads + l = int(math.sqrt(n)) + w = int(math.sqrt(n)) + + attn_kv = x if attn_kv is None else attn_kv + x = rearrange(x, "b (l w) c -> b c l w", l=l, w=w) + attn_kv = rearrange(attn_kv, "b (l w) c -> b c l w", l=l, w=w) + q = self.to_q(x) + q = rearrange(q, "b (h d) l w -> b h (l w) d", h=h) + + k = self.to_k(attn_kv) + v = self.to_v(attn_kv) + k = rearrange(k, "b (h d) l w -> b h (l w) d", h=h) + v = rearrange(v, "b (h d) l w -> b h (l w) d", h=h) + return q, k, v + + +class LinearProjectionModule(nn.Module): + """Linear projection layer used in the window attention mechanism. + + Parameters + ---------- + dim : int + The input feature dimension. + heads : int + The number of heads in the multi-head attention mechanism. Default: 8. + dim_head : int, optional + The feature dimension of each head. Default: 64. + bias : bool, optional + Whether to use bias in the linear projections. Default: True. + """ + + def __init__(self, dim: int, heads: int = 8, dim_head: int = 64, bias: bool = True) -> None: + """Inits :class:LinearProjectionModule`. + + Parameters + ---------- + dim : int + The input feature dimension. + heads : int + The number of heads in the multi-head attention mechanism. Default: 8. + dim_head : int, optional + The feature dimension of each head. Default: 64. + bias : bool, optional + Whether to use bias in the linear projections. Default: True. + """ + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.to_q = nn.Linear(dim, inner_dim, bias=bias) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias) + self.dim = dim + self.inner_dim = inner_dim + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Performs forward pass of :class:`LinearProjectionModule`. + + Parameters + ---------- + x : torch.Tensor of shape (batch_size, seq_length, dim) + The input tensor. + attn_kv : torch.Tensor of shape (batch_size, seq_length, dim), optional + The tensor to be used for computing the attention scores. If None, the input tensor is used. Default: None. + + Returns + ------- + q : torch.Tensor of shape (batch_size, seq_length, heads, dim_head) + The tensor resulting from the linear projection of x used for computing the queries. + k : torch.Tensor of shape (batch_size, seq_length, heads, dim_head) + The tensor resulting from the linear projection of attn_kv used for computing the keys. + v : torch.Tensor of shape (batch_size, seq_length, heads, dim_head) + The tensor resulting from the linear projection of attn_kv used for computing the values. + + """ + B_, N, C = x.shape + if attn_kv is not None: + attn_kv = attn_kv.unsqueeze(0).repeat(B_, 1, 1) + else: + attn_kv = x + N_kv = attn_kv.size(1) + q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) + kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) + q = q[0] + k, v = kv[0], kv[1] + return q, k, v + + +class AttentionTokenProjectionType(DirectEnum): + CONV = "conv" + LINEAR = "linear" + + +class WindowAttentionModule(nn.Module): + """A window-based multi-head self-attention module. + + Parameters + ---------- + dim : int + Input feature dimension. + win_size : tuple[int, int] + The window size (height and width). + num_heads : int + Number of heads for multi-head self-attention. + token_projection : AttentionTokenProjectionType + Type of projection for token-level queries, keys, and values. 
Either "conv" or "linear". + qkv_bias : bool + Whether to use bias in the linear projection layer for queries, keys, and values. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout rate for attention weights. + proj_drop : float + Dropout rate for the output of the last linear projection layer. + """ + + def __init__( + self, + dim: int, + win_size: tuple[int, int], + num_heads: int, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ) -> None: + """Inits :class:`WindowAttentionModule`. + + Parameters + ---------- + dim : int + Input feature dimension. + win_size : tuple[int, int] + The window size (height and width). + num_heads : int + Number of heads for multi-head self-attention. + token_projection : AttentionTokenProjectionType + Type of projection for token-level queries, keys, and values. Either "conv" or "linear". + qkv_bias : bool + Whether to use bias in the linear projection layer for queries, keys, and values. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout rate for attention weights. + proj_drop : float + Dropout rate for the output of the last linear projection layer. + """ + super().__init__() + self.dim = dim + self.win_size = win_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] + coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.win_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + trunc_normal_(self.relative_position_bias_table, std=0.02) + + if token_projection == "conv": + self.qkv = ConvProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias) + elif token_projection == "linear": + self.qkv = LinearProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias) + else: + raise Exception("Projection error!") + + self.token_projection = token_projection + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Performs forward pass of :class:`WindowAttentionModule`. + + Parameters + ---------- + x : torch.Tensor + A tensor of shape `(B, N, C)` representing the input features, where `B` is the batch size, `N` is the + sequence length, and `C` is the input feature dimension. 
+ attn_kv : torch.Tensor, optional + An optional tensor of shape `(B, N, C)` representing the key-value pairs used for attention computation. + If `None`, the key-value pairs are computed from `x` itself. Default: None. + mask : torch.Tensor, optional + An optional tensor of shape representing the binary mask for the input sequence. + If `None`, no masking is applied. Default: None. + + Returns + ------- + torch.Tensor + A tensor of shape `(B, N, C)` representing the output features after attention computation. + """ + B_, N, C = x.shape + q, k, v = self.qkv(x, attn_kv) + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + ratio = attn.size(-1) // relative_position_bias.size(-1) + relative_position_bias = repeat(relative_position_bias, "nH l c -> nH l (c d)", d=ratio) + + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + mask = repeat(mask, "nW m n -> nW m (n d)", d=ratio) + attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N * ratio) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}" + + +class AttentionModule(nn.Module): + """Self-attention module. + + Parameters + ---------- + dim : int + The input feature dimension. + num_heads : int + The number of attention heads. + qkv_bias : bool + Whether to include biases in the query, key, and value projections. Default: True. + qk_scale : float, optional + Scaling factor for the query and key projections. Default: None. + attn_drop : float + Dropout probability for the attention weights. Default: 0.0. + proj_drop : float + Dropout probability for the output of the attention module. Default: 0.0. + """ + + def __init__( + self, + dim: int, + num_heads: int, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """Inits :class:`AttentionModule`. + + Parameters + ---------- + dim : int + The input feature dimension. + num_heads : int + The number of attention heads. + qkv_bias : bool + Whether to include biases in the query, key, and value projections. Default: True. + qk_scale : float, optional + Scaling factor for the query and key projections. Default: None. + attn_drop : float + Dropout probability for the attention weights. Default: 0.0. + proj_drop : float + Dropout probability for the output of the attention module. Default: 0.0. 
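+
+        Examples
+        --------
+        A small, illustrative shape check (arbitrary sizes):
+
+        >>> import torch
+        >>> attn = AttentionModule(dim=32, num_heads=4)  # illustrative sizes
+        >>> attn(torch.randn(1, 64, 32)).shape
+        torch.Size([1, 64, 32])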
+ """ + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = LinearProjectionModule(dim, num_heads, dim // num_heads, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward( + self, x: torch.Tensor, attn_kv: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Performs the forward pass of :class:`AttentionModule`. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + attn_kv : torch.Tensor, optional + The attention key/value tensor. + mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + B_, N, C = x.shape + q, k, v = self.qkv(x, attn_kv) + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, num_heads={self.num_heads}" + + +class MLP(nn.Module): + """Multi-layer perceptron with optional dropout regularization. + + Parameters + ---------- + in_features : int + Number of input features. + hidden_features : int, optional + Number of output features in the hidden layer. If not specified, `in_features` is used. + out_features : int, optional + Number of output features. If not specified, `in_features` is used. + act_layer : nn.Module + Activation layer. Default: GeLU. + drop : float + Dropout probability. Default: 0.0. + """ + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + ) -> None: + """Inits :class:`MLP`. + + Parameters + ---------- + in_features : int + Number of input features. + hidden_features : int, optional + Number of output features in the hidden layer. If not specified, `in_features` is used. + out_features : int, optional + Number of output features. If not specified, `in_features` is used. + act_layer : nn.Module + Activation layer. Default: GeLU. + drop : float + Dropout probability. Default: 0.0. + """ + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.in_features = in_features + self.hidden_features = hidden_features + self.out_features = out_features + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the :class:`MLP`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + output : torch.Tensor + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class LeFF(nn.Module): + """Locally-enhanced Feed-Forward Network module. + + Parameters + ---------- + dim : int + Dimension of the input and output features. Default: 32. + hidden_dim : int + Dimension of the hidden features. Default: 128. 
+ act_layer : nn.Module + Activation layer to apply after the first linear layer and the depthwise convolution. Default: GELU. + use_eca : bool + If True, adds a 1D ECA layer after the second linear layer. Default: False. + """ + + def __init__( + self, dim: int = 32, hidden_dim: int = 128, act_layer: nn.Module = nn.GELU, use_eca: bool = False + ) -> None: + """Inits :class:`LeFF`. + + Parameters + ---------- + dim : int + Dimension of the input and output features. Default: 32. + hidden_dim : int + Dimension of the hidden features. Default: 128. + act_layer : nn.Module + Activation layer to apply after the first linear layer and the depthwise convolution. Default: GELU. + use_eca : bool + If True, adds a 1D ECA layer after the second linear layer. Default: False. + """ + super().__init__() + self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), act_layer()) + self.dwconv = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1), act_layer() + ) + self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim)) + self.dim = dim + self.hidden_dim = hidden_dim + self.eca = ECALayer1d(dim) if use_eca else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`LeFF`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + """ + # bs x hw x c + _, hw, _ = x.size() + hh = int(math.sqrt(hw)) + + x = self.linear1(x) + + # spatial restore + x = rearrange(x, " b (h w) (c) -> b c h w ", h=hh, w=hh) + # bs,hidden_dim,32x32 + + x = self.dwconv(x) + + # flatten + x = rearrange(x, " b c h w -> b (h w) c", h=hh, w=hh) + + x = self.linear2(x) + x = self.eca(x) + + return x + + +def window_partition(x: torch.Tensor, win_size: int, dilation_rate: int = 1) -> torch.Tensor: + """Partition the input tensor into windows of specified size. + + Parameters + ---------- + x : torch.Tensor + The input tensor to be partitioned into windows. + win_size : int + The size of the square windows to partition the tensor into. + dilation_rate : int + The dilation rate for convolution. Default: 1. + + Returns + ------- + windows : torch.Tensor + The tensor representing windows partitioned from input tensor. + """ + B, H, W, C = x.shape + if dilation_rate != 1: + x = x.permute(0, 3, 1, 2) # B, C, H, W + assert type(dilation_rate) is int, "dilation_rate should be a int" + x = F.unfold( + x, kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size + ) # B, C*Wh*Ww, H/Wh*W/Ww + windows = x.permute(0, 2, 1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww + windows = windows.permute(0, 2, 3, 1).contiguous() # B' ,Wh ,Ww ,C + else: + x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C + return windows + + +def window_reverse(windows: torch.Tensor, win_size: int, H: int, W: int, dilation_rate: int = 1) -> torch.Tensor: + """Rearrange the partitioned tensor back to the original tensor. + + Parameters + ---------- + windows : torch.Tensor + The tensor representing windows partitioned from input tensor. + win_size : int + The size of the square windows used to partition the tensor. + H : int + The height of the original tensor before partitioning. + W : int + The width of the original tensor before partitioning. + dilation_rate : int + The dilation rate for convolution. Default 1. 
+ + Returns + ------- + x: torch.Tensor + The original tensor rearranged from the partitioned tensor. + + """ + # B' ,Wh ,Ww ,C + B = int(windows.shape[0] / (H * W / win_size / win_size)) + x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) + if dilation_rate != 1: + x = windows.permute(0, 5, 3, 4, 1, 2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww + x = F.fold( + x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1), stride=win_size + ) + else: + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class DownSampleBlock(nn.Module): + """Convolution based downsample block. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolution. + """ + + def __init__(self, in_channels: int, out_channels: int) -> None: + """Inits :class:`DownSampleBlock`. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolution. + """ + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1), + ) + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`DownSampleBlock`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Downsampled output. + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + out = self.conv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C + return out + + +class UpSampleBlock(nn.Module): + """Convolution based upsample block. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolution. + """ + + def __init__(self, in_channels: int, out_channels: int) -> None: + """Inits :class:`UpSampleBlock`. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the convolution. + """ + super().__init__() + self.deconv = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + ) + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`UpSampleBlock`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Upsampled output. + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + x = x.transpose(1, 2).contiguous().view(B, C, H, W) + out = self.deconv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C + return out + + +class InputProjection(nn.Module): + """Input convolutional projection used in the U-Former model. + + Parameters + ---------- + in_channels : int + Number of input channels. Default: 3. + out_channels : int + Number of output channels after the projection. Default: 64. + kernel_size : int or tuple of ints + Convolution kernel size. Default: 3. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + norm_layer : nn.Module, optional + Normalization layer to apply after the projection. Default: None. + act_layer : nn.Module + Activation layer to apply after the projection. Default: nn.LeakyReLU. 
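+
+    Examples
+    --------
+    The projection maps an image-shaped input to a flattened token sequence; the sizes below are illustrative.
+
+    >>> import torch
+    >>> proj = InputProjection(in_channels=2, out_channels=16)  # illustrative sizes
+    >>> proj(torch.randn(1, 2, 32, 32)).shape
+    torch.Size([1, 1024, 16])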
+ """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 64, + kernel_size: int | tuple[int, int] = 3, + stride: int | tuple[int, int] = 1, + norm_layer: Optional[nn.Module] = None, + act_layer: nn.Module = nn.LeakyReLU, + ) -> None: + """Inits :class:`InputProjection`. + + Parameters + ---------- + in_channels : int + Number of input channels. Default: 3. + out_channels : int + Number of output channels after the projection. Default: 64. + kernel_size : int or tuple of ints + Convolution kernel size. Default: 3. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + norm_layer : nn.Module, optional + Normalization layer to apply after the projection. Default: None. + act_layer : nn.Module + Activation layer to apply after the projection. Default: nn.LeakyReLU. + """ + super().__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=kernel_size // 2), + act_layer(inplace=True), + ) + if norm_layer is not None: + self.norm = norm_layer(out_channels) + else: + self.norm = None + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`InputProjection`. + + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + """ + x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C + if self.norm is not None: + x = self.norm(x) + return x + + +class OutputProjection(nn.Module): + """Output convolutional projection used in the U-Former model. + + Parameters + ---------- + in_channels : int + Number of input channels. Default: 64. + out_channels : int + Number of output channels after the projection. Default: 3. + kernel_size : int or tuple of ints + Convolution kernel size. Default: 3. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + norm_layer : nn.Module, optional + Normalization layer to apply after the projection. Default: None. + act_layer : nn.Module, optional + Activation layer to apply after the projection. Default: None. + """ + + def __init__( + self, + in_channels: int = 64, + out_channels: int = 3, + kernel_size: int | tuple[int, int] = 3, + stride: int | tuple[int, int] = 1, + norm_layer: Optional[nn.Module] = None, + act_layer: Optional[nn.Module] = None, + ): + """Inits :class:`InputProjection`. + + Parameters + ---------- + in_channels : int + Number of input channels. Default: 64. + out_channels : int + Number of output channels after the projection. Default: 3. + kernel_size : int or tuple of ints + Convolution kernel size. Default: 3. + stride : int or tuple of ints + Stride of the convolution. Default: 1. + norm_layer : nn.Module, optional + Normalization layer to apply after the projection. Default: None. + act_layer : nn.Module, optional + Activation layer to apply after the projection. Default: None. + """ + super().__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=kernel_size // 2), + ) + if act_layer is not None: + self.proj.add_module(act_layer(inplace=True)) + if norm_layer is not None: + self.norm = norm_layer(out_channels) + else: + self.norm = None + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`OutputProjection`. 
+ + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + x = x.transpose(1, 2).view(B, C, H, W) + x = self.proj(x) + if self.norm is not None: + x = self.norm(x) + return x + + +class LeWinTransformerMLPTokenType(DirectEnum): + MLP = "mlp" + FFN = "ffn" + LEFF = "leff" + + +class LeWinTransformerBlock(nn.Module): + """Applies a window-based multi-head self-attention and MLP or LeFF on the input tensor. + + Parameters + ---------- + dim : int + Number of input channels. + input_resolution : tuple of ints + Input resolution. + num_heads : int + Number of attention heads. + win_size : int + Window size for the attention mechanism. Default: 8. + shift_size : int + The number of pixels to shift the window. Default: 0. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float, optional + Scale factor for the query and key projection vectors. + If set to None, will use the default value of :math`1 / \sqrt(dim)`. Default: None. + drop : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path : float + Dropout rate for the stochastic depth regularization. Default: 0.0. + act_layer : nn.Module + The activation function to use. Default: nn.GELU. + norm_layer : nn.Module + The normalization layer to use. Default: nn.LayerNorm. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + """ + + def __init__( + self, + dim: int, + input_resolution: tuple[int, int], + num_heads: int, + win_size: int = 8, + shift_size: int = 0, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF, + modulator: bool = False, + cross_modulator: bool = False, + ) -> None: + r"""Inits :class:`LeWinTransformerBlock`. + + Parameters + ---------- + dim : int + Number of input channels. + input_resolution : tuple of ints + Input resolution. + num_heads : int + Number of attention heads. + win_size : int + Window size for the attention mechanism. Default: 8. + shift_size : int + The number of pixels to shift the window. Default: 0. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float, optional + Scale factor for the query and key projection vectors. + If set to None, will use the default value of :math`1 / \sqrt(dim)`. Default: None. 
+ drop : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path : float + Dropout rate for the stochastic depth regularization. Default: 0.0. + act_layer : nn.Module + The activation function to use. Default: nn.GELU. + norm_layer : nn.Module + The normalization layer to use. Default: nn.LayerNorm. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.win_size = win_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.token_mlp = token_mlp + if min(self.input_resolution) <= self.win_size: + self.shift_size = 0 + self.win_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size" + + if modulator: + self.modulator = nn.Embedding(win_size * win_size, dim) # modulator + else: + self.modulator = None + + if cross_modulator: + self.cross_modulator = nn.Embedding(win_size * win_size, dim) # cross_modulator + self.cross_attn = AttentionModule( + dim, + num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.norm_cross = norm_layer(dim) + else: + self.cross_modulator = None + + self.norm1 = norm_layer(dim) + self.attn = WindowAttentionModule( + dim, + win_size=(self.win_size, self.win_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + token_projection=token_projection, + ) + + self.drop_path = DropoutPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + if token_mlp in ["ffn", "mlp"]: + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + elif token_mlp == "leff": + self.mlp = LeFF(dim, mlp_hidden_dim, act_layer=act_layer) + else: + raise Exception("FFN error!") + + def with_pos_embed(self, tensor: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor: + """Add positional embeddings to the input tensor. + + Parameters + ---------- + tensor : torch.Tensor + The input tensor. + pos : torch.Tensor, optional + The positional embeddings to add to the input tensor. Default: None. + + Returns + ------- + torch.Tensor + """ + return tensor if pos is None else tensor + pos + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio},modulator={self.modulator}" + ) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs the forward pass of :class:`LeWinTransformerBlock`. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + mask : torch.Tensor, optional + The mask tensor indicating which elements should be ignored. Default: None. 
+ + Returns + ------- + torch.Tensor + """ + B, L, C = x.shape + H = int(math.sqrt(L)) + W = int(math.sqrt(L)) + + ## input mask + if mask != None: + input_mask = F.interpolate(mask, size=(H, W)).permute(0, 2, 3, 1) + input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1 + attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size + attn_mask = attn_mask.unsqueeze(2) * attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + ## shift mask + if self.shift_size > 0: + # calculate attention mask for SW-MSA + shift_mask = torch.zeros((1, H, W, 1)).type_as(x) + h_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + shift_mask[:, h, w, :] = cnt + cnt += 1 + shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1 + shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size + shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze( + 2 + ) # nW, win_size*win_size, win_size*win_size + shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill( + shift_attn_mask == 0, float(0.0) + ) + attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask + if self.cross_modulator is not None: + shortcut = x + x_cross = self.norm_cross(x) + x_cross = self.cross_attn(x, self.cross_modulator.weight) + x = shortcut + x_cross + shortcut = x + + x = self.norm1(x) + x = x.view(B, H, W, C) + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + # partition windows + x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C + x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C + # with_modulator + if self.modulator is not None: + wmsa_in = self.with_pos_embed(x_windows, self.modulator.weight) + else: + wmsa_in = x_windows + + # W-MSA/SW-MSA + attn_windows = self.attn(wmsa_in, mask=attn_mask) # nW*B, win_size*win_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C) + shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + del attn_mask + return x + + +class BasicUFormerLayer(nn.Module): + """Basic layer of U-Former. + + Parameters + ---------- + dim : int + Number of input channels. + input_resolution : tuple of ints + Input resolution. + num_heads : int + Number of attention heads. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. 
Default: True. + qk_scale : float, optional + Scale factor for the query and key projection vectors. + If set to None, will use the default value of :math`1 / \sqrt(dim)`. Default: None. + drop : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path : float + Dropout rate for the stochastic depth regularization. Default: 0.0. + norm_layer : nn.Module + The normalization layer to use. Default: nn.LayerNorm. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift in the attention sliding windows or not. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + """ + + def __init__( + self, + dim: int, + input_resolution: tuple[int, int], + depth: int, + num_heads: int, + win_size: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[bool] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: List[float] | float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.FFN, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + ) -> None: + r"""Inits :class:`BasicUFormerLayer`. + + Parameters + ---------- + dim : int + Number of input channels. + input_resolution : tuple of ints + Input resolution. + num_heads : int + Number of attention heads. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float, optional + Scale factor for the query and key projection vectors. + If set to None, will use the default value of :math`1 / \sqrt(dim)`. Default: None. + drop : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path : float + Dropout rate for the stochastic depth regularization. Default: 0.0. + norm_layer : nn.Module + The normalization layer to use. Default: nn.LayerNorm. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift in the attention sliding windows or not. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. 
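+
+        Examples
+        --------
+        An illustrative forward pass; the input is a flattened square token grid, so `L = H * W` with
+        `H = W = 16` here, and the remaining sizes are arbitrary.
+
+        >>> import torch
+        >>> layer = BasicUFormerLayer(dim=32, input_resolution=(16, 16), depth=2, num_heads=4, win_size=8)
+        >>> layer(torch.randn(1, 256, 32)).shape
+        torch.Size([1, 256, 32])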
+ """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + + # build blocks + self.blocks = nn.ModuleList( + [ + LeWinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + win_size=win_size, + shift_size=(0 if (i % 2 == 0) else win_size // 2) if shift_flag else 0, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + token_projection=token_projection, + token_mlp=token_mlp, + modulator=modulator, + cross_modulator=cross_modulator, + ) + for i in range(depth) + ] + ) + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs forward pass of :class:`BasicUFormerLayer`. + + Parameters + ---------- + x : torch.Tensor + mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + for blk in self.blocks: + x = blk(x, mask) + return x + + +class UFormer(nn.Module): + """U-Former model based on [1]_, code originally implemented in [2]_. + + Parameters + ---------- + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + **kwargs: Other keyword arguments to pass to the parent constructor. + + References + ---------- + .. [1] Wang, Zhendong, et al. "Uformer: A general u-shaped transformer for image restoration." 
Proceedings of the + IEEE/CVF conference on computer vision and pattern recognition. 2022. + .. [2] https://github.com/ZhendongWang6/Uformer + """ + + def __init__( + self, + patch_size: int = 256, + in_channels: int = 2, + out_channels: Optional[int] = None, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 16, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + ) -> None: + """Inits :class:`UFormer`. + + Parameters + ---------- + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + **kwargs: Other keyword arguments to pass to the parent constructor. + """ + super().__init__() + if len(encoder_num_heads) != len(encoder_depths): + raise ValueError( + f"The number of heads for each layer should be the same as the number of layers. " + f"Got {len(encoder_num_heads)} for {len(encoder_depths)} layers." 
+ ) + if patch_size < (2 ** len(encoder_depths) * win_size): + raise ValueError( + f"Patch size must be greater or equal than 2 ** number of scales * window size." + f" Received: patch_size={patch_size}, number of scales=={len(encoder_depths)}," + f" and window_size={win_size}." + ) + self.num_enc_layers = len(encoder_num_heads) + self.num_dec_layers = len(encoder_num_heads) + depths = (*encoder_depths, bottleneck_depth, *encoder_depths[::-1]) + num_heads = (*encoder_num_heads, bottleneck_num_heads, bottleneck_num_heads, *encoder_num_heads[::-1][:-1]) + self.embedding_dim = embedding_dim + self.patch_norm = patch_norm + self.mlp_ratio = mlp_ratio + self.token_projection = token_projection + self.mlp = token_mlp + self.win_size = win_size + self.reso = patch_size + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[: self.num_enc_layers]))] + conv_dpr = [drop_path_rate] * depths[self.num_enc_layers + 1] + dec_dpr = enc_dpr[::-1] + + # Build layers + + # Input + self.input_proj = InputProjection( + in_channels=in_channels, out_channels=embedding_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU + ) + out_channels = out_channels if out_channels else in_channels + # Output + self.output_proj = OutputProjection( + in_channels=2 * embedding_dim, out_channels=out_channels, kernel_size=3, stride=1 + ) + if in_channels != out_channels: + self.conv_out = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0) + self.in_channels = in_channels + self.out_channels = out_channels + + # Encoder + self.encoder_layers = nn.ModuleList() + self.downsamples = nn.ModuleList() + for i in range(self.num_enc_layers): + layer_name = f"encoderlayer_{i}" + layer_input_resolution = (patch_size // (2**i), patch_size // (2**i)) + layer_dim = embedding_dim * (2**i) + layer_depth = depths[i] + layer_drop_path = enc_dpr[sum(depths[:i]) : sum(depths[: i + 1])] + layer = BasicUFormerLayer( + dim=layer_dim, + input_resolution=layer_input_resolution, + depth=layer_depth, + num_heads=num_heads[i], + win_size=win_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=layer_drop_path, + norm_layer=nn.LayerNorm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + ) + self.encoder_layers.add_module(layer_name, layer) + + downsample_layer_name = f"downsample_{i}" + downsample_layer = DownSampleBlock(layer_dim, embedding_dim * (2 ** (i + 1))) + self.downsamples.add_module(downsample_layer_name, downsample_layer) + # Bottleneck + self.bottleneck = BasicUFormerLayer( + dim=embedding_dim * (2**self.num_enc_layers), + input_resolution=(patch_size // (2**self.num_enc_layers), patch_size // (2**self.num_enc_layers)), + depth=depths[self.num_enc_layers], + num_heads=num_heads[self.num_enc_layers], + win_size=win_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=conv_dpr, + norm_layer=nn.LayerNorm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + ) + # Decoder + self.upsamples = nn.ModuleList() + self.decoder_layers = nn.ModuleList() + for i in range(self.num_dec_layers, 0, -1): + upsample_layer_name = f"upsample_{self.num_dec_layers - i}" + if i == self.num_dec_layers: + upsample_in_channels = embedding_dim * (2**i) + else: + upsample_in_channels = embedding_dim * (2 ** (i + 1)) + 
upsample_out_channels = embedding_dim * (2 ** (i - 1)) + upsample_layer = UpSampleBlock(upsample_in_channels, upsample_out_channels) + self.upsamples.add_module(upsample_layer_name, upsample_layer) + + layer_name = f"decoderlayer_{self.num_dec_layers - i}" + layer_input_resolution = (patch_size // (2 ** (i - 1)), patch_size // (2 ** (i - 1))) + layer_dim = embedding_dim * (2**i) + layer_num = self.num_enc_layers + self.num_dec_layers - i + 1 + layer_depth = depths[layer_num] + if i == self.num_dec_layers: + layer_drop_path = dec_dpr[: depths[layer_num]] + else: + start = self.num_enc_layers + 1 + layer_drop_path = dec_dpr[sum(depths[start:layer_num]) : sum(depths[start : layer_num + 1])] + layer = BasicUFormerLayer( + dim=layer_dim, + input_resolution=layer_input_resolution, + depth=layer_depth, + num_heads=num_heads[layer_num], + win_size=win_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=layer_drop_path, + norm_layer=nn.LayerNorm, + token_projection=token_projection, + token_mlp=token_mlp, + shift_flag=shift_flag, + modulator=modulator, + cross_modulator=cross_modulator, + ) + self.decoder_layers.add_module(layer_name, layer) + + self.apply(init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"relative_position_bias_table"} + + def extra_repr(self) -> str: + return f"embedding_dim={self.embedding_dim}, token_projection={self.token_projection}, token_mlp={self.mlp},win_size={self.win_size}" + + def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs forward pass of :class:`UFormer`. + + Parameters + ---------- + input : torch.Tensor + mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + # Input Projection + output = self.input_proj(input) + output = self.pos_drop(output) + + # Encoder + stack = [] + for encoder_layer, downsample in zip(self.encoder_layers, self.downsamples): + output = encoder_layer(output, mask=mask) + stack.append(output) + output = downsample(output) + # Bottleneck + output = self.bottleneck(output, mask=mask) + + # Decoder + for decoder_layer, upsample in zip(self.decoder_layers, self.upsamples): + downsampled_output = stack.pop() + output = upsample(output) + + output = torch.cat([output, downsampled_output], -1) + output = decoder_layer(output, mask=mask) + + # Output Projection + output = self.output_proj(output) + if self.in_channels != self.out_channels: + input = self.conv_out(input) + return input + output + + +class UFormerModel(nn.Module): + """U-Former model with normalization and padding operations. + + Parameters + ---------- + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. 
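For orientation, a small standalone sketch (plain Python/PyTorch, not part of the model code itself) of how the constructor arguments of ``UFormer`` translate into per-stage depths, head counts and the stochastic-depth schedule; the values mirror the defaults shown above and nothing here depends on the rest of the package:

import torch

encoder_depths = (2, 2, 2, 2)        # depths of the four encoder stages
encoder_num_heads = (1, 2, 4, 8)     # attention heads per encoder stage
bottleneck_depth, bottleneck_num_heads = 2, 16
drop_path_rate = 0.1

# Per-stage depths and heads as assembled in UFormer.__init__:
# encoder stages, bottleneck, then mirrored decoder stages.
depths = (*encoder_depths, bottleneck_depth, *encoder_depths[::-1])
num_heads = (*encoder_num_heads, bottleneck_num_heads, bottleneck_num_heads, *encoder_num_heads[::-1][:-1])

# Stochastic depth increases linearly over the encoder blocks and is reversed for the decoder.
num_enc_layers = len(encoder_depths)
enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:num_enc_layers]))]

print(depths)      # (2, 2, 2, 2, 2, 2, 2, 2, 2)
print(num_heads)   # (1, 2, 4, 8, 16, 16, 8, 4, 2)
print([round(r, 3) for r in enc_dpr])  # 0.0 ... 0.1 spread over the 8 encoder blocks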
+ mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + normalized : bool + Whether to apply normalization before and denormalization after the forward pass. Default: True. + **kwargs: Other keyword arguments to pass to the parent constructor. + """ + + def __init__( + self, + patch_size: int = 256, + in_channels: int = 2, + out_channels: Optional[int] = None, + embedding_dim: int = 32, + encoder_depths: tuple[int, ...] = (2, 2, 2, 2), + encoder_num_heads: tuple[int, ...] = (1, 2, 4, 8), + bottleneck_depth: int = 2, + bottleneck_num_heads: int = 16, + win_size: int = 8, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + patch_norm: bool = True, + token_projection: AttentionTokenProjectionType = AttentionTokenProjectionType.LINEAR, + token_mlp: LeWinTransformerMLPTokenType = LeWinTransformerMLPTokenType.LEFF, + shift_flag: bool = True, + modulator: bool = False, + cross_modulator: bool = False, + normalized: bool = True, + ) -> None: + """Inits :class:`UFormer`. + + Parameters + ---------- + patch_size : int + Size of the patch. Default: 256. + in_channels : int + Number of input channels. Default: 2. + out_channels : int, optional + Number of output channels. Default: None. + embedding_dim : int + Size of the feature embedding. Default: 32. + encoder_depths : tuple + Number of layers for each stage of the encoder of the U-former, from top to bottom. Default: (2, 2, 2, 2). + encoder_num_heads : tuple + Number of attention heads for each layer of the encoder of the U-former, from top to bottom. + Default: (1, 2, 4, 8). + bottleneck_depth : int + Default: 16. + bottleneck_num_heads : int + Default: 2. + win_size : int + Window size for the attention mechanism. Default: 8. + mlp_ratio : float + Ratio of the hidden dimension size to the embedding dimension size in the MLP layers. Default: 4.0. + qkv_bias : bool + Whether to use bias in the query, key, and value projections of the attention mechanism. Default: True. + qk_scale : float + Scale factor for the query and key projection vectors. + If set to None, will use the default value of 1 / sqrt(embedding_dim). 
Default: None. + drop_rate : float + Dropout rate for the token-level dropout layer. Default: 0.0. + attn_drop_rate : float + Dropout rate for the attention score matrix. Default: 0.0. + drop_path_rate : float + Dropout rate for the stochastic depth regularization. Default: 0.1. + patch_norm : bool + Whether to use normalization for the patch embeddings. Default: True. + token_projection : AttentionTokenProjectionType + Type of token projection. Must be one of ["linear", "conv"]. Default: AttentionTokenProjectionType.LINEAR. + token_mlp : LeWinTransformerMLPTokenType + Type of token-level MLP. Must be one of ["leff", "mlp", "ffn"]. Default: LeWinTransformerMLPTokenType.LEFF. + shift_flag : bool + Whether to use shift operation in the local attention mechanism. Default: True. + modulator : bool + Whether to use a modulator in the attention mechanism. Default: False. + cross_modulator : bool + Whether to use cross-modulation in the attention mechanism. Default: False. + normalized : bool + Whether to apply normalization before and denormalization after the forward pass. Default: True. + **kwargs: Other keyword arguments to pass to the parent constructor. + """ + super().__init__() + + self.uformer = UFormer( + patch_size, + in_channels, + out_channels, + embedding_dim, + encoder_depths, + encoder_num_heads, + bottleneck_depth, + bottleneck_num_heads, + win_size, + mlp_ratio, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + drop_path_rate, + patch_norm, + token_projection, + token_mlp, + shift_flag, + modulator, + cross_modulator, + ) + self.normalized = normalized + self.padding_factor = win_size * (2 ** len(encoder_depths)) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Performs forward pass of :class:`UFormer`. + + Parameters + ---------- + x : torch.Tensor + mask : torch.Tensor, optional + + Returns + ------- + torch.Tensor + """ + x, _, wpad, hpad = pad_to_square(x, self.padding_factor) + if self.normalized: + x, mean, std = norm(x) + x = self.uformer(x, mask) + if self.normalized: + x = unnorm(x, mean, std) + x = unpad_to_original(x, hpad, wpad) + return x diff --git a/direct/nn/transformers/utils.py b/direct/nn/transformers/utils.py new file mode 100644 index 00000000..c0c8d12d --- /dev/null +++ b/direct/nn/transformers/utils.py @@ -0,0 +1,221 @@ +# Copyright (c) DIRECT Contributors + +"""DIRECT module containing utility functions for the transformers models.""" + +from __future__ import annotations + +from math import ceil, floor + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + +__all__ = ["init_weights", "norm", "pad_to_divisible", "pad_to_square", "unnorm", "unpad_to_original", "DropoutPath"] + + +def pad_to_divisible(x: torch.Tensor, pad_size: tuple[int, ...]) -> tuple[torch.Tensor, tuple[tuple[int, int], ...]]: + """Pad the input tensor with zeros to make its spatial dimensions divisible by the specified pad size. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (*, spatial_1, spatial_2, ..., spatial_N), where spatial dimensions can vary in number. + pad_size : tuple[int, ...] + Patch size to make each spatial dimension divisible by. This is a tuple of integers for each spatial dimension. + + Returns + ------- + tuple + Containing the padded tensor and a tuple of tuples indicating the number of pixels padded in each spatial dimension. 
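A minimal usage sketch for the two padding helpers (sizes chosen arbitrarily; the import path assumes the module layout introduced in this diff):

import torch
from direct.nn.transformers.utils import pad_to_divisible, unpad_to_original

x = torch.rand(1, 2, 30, 50)                    # (batch, channel, height, width)
padded, pads = pad_to_divisible(x, (16, 16))    # make H and W divisible by 16
print(padded.shape)    # torch.Size([1, 2, 32, 64])
print(pads)            # ((1, 1), (7, 7)): (before, after) padding per spatial dimension

restored = unpad_to_original(padded, *pads)     # exact inverse of the padding
print(restored.shape)  # torch.Size([1, 2, 30, 50])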
+ """ + pads = [] + for dim, p_dim in zip(x.shape[-len(pad_size) :], pad_size): + pad_before = (p_dim - dim % p_dim) % p_dim / 2 + pads.append((floor(pad_before), ceil(pad_before))) + + # Reverse and flatten pads to match torch's expected (pad_n_before, pad_n_after, ..., pad_1_before, pad_1_after) format + flat_pads = tuple(val for sublist in pads[::-1] for val in sublist) + x = F.pad(x, flat_pads) + + return x, tuple(pads) + + +def unpad_to_original(x: torch.Tensor, *pads: tuple[int, int]) -> torch.Tensor: + """Remove the padding added to the input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor with padded spatial dimensions. + pads : tuple[int, int] + A tuple of (pad_before, pad_after) for each spatial dimension. + + Returns + ------- + torch.Tensor + Tensor with the padding removed, matching the shape of the original input tensor before padding. + """ + slices = [slice(None)] * (x.ndim - len(pads)) # Keep the batch and channel dimensions + for i, (pad_before, pad_after) in enumerate(pads): + slices.append(slice(pad_before, x.shape[-len(pads) + i] - pad_after)) + + return x[tuple(slices)] + + +def pad_to_square( + inp: torch.Tensor, factor: float +) -> tuple[torch.Tensor, torch.Tensor, tuple[int, int], tuple[int, int]]: + """Pad a tensor to a square shape with a given factor. + + Parameters + ---------- + inp : torch.Tensor + The input tensor to pad to square shape. Expected shape is (\*, height, width). + factor : float + The factor to which the input tensor will be padded. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, tuple[int, int], tuple[int, int]] + A tuple of two tensors, the first is the input tensor padded to a square shape, and the + second is the corresponding mask for the padded tensor. + + Examples + -------- + 1. + >>> x = torch.rand(1, 3, 224, 192) + >>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0) + >>> padded_x.shape, mask.shape + (torch.Size([1, 3, 224, 224]), torch.Size([1, 1, 224, 224])) + 2. + >>> x = torch.rand(3, 13, 2, 234, 180) + >>> padded_x, mask, wpad, hpad = pad_to_square(x, factor=16.0) + >>> padded_x.shape, wpad, hpad + (torch.Size([3, 13, 2, 240, 240]), (30, 30), (3, 3)) + """ + channels, h, w = inp.shape[-3:] + + # Calculate the maximum size and pad to the next multiple of the factor + x = int(ceil(max(h, w) / float(factor)) * factor) + + # Create a tensor of zeros with the maximum size and copy the input tensor into the center + img = torch.zeros(*inp.shape[:-3], channels, x, x, device=inp.device).type_as(inp) + mask = torch.zeros(*((1,) * (img.ndim - 3)), 1, x, x, device=inp.device).type_as(inp) + + # Compute the offset and copy the input tensor into the center of the zero tensor + offset_h = (x - h) // 2 + offset_w = (x - w) // 2 + hpad = (offset_h, offset_h + h) + wpad = (offset_w, offset_w + w) + img[..., hpad[0] : hpad[1], wpad[0] : wpad[1]] = inp.clone() + mask[..., hpad[0] : hpad[1], wpad[0] : wpad[1]].fill_(1.0) + # Return the padded tensor and the corresponding mask, and padding in spatial dimensions + return ( + img, + 1 - mask, + (wpad[0], wpad[1] - w + (1 if w % 2 != 0 else 0)), + (hpad[0], hpad[1] - h + (1 if h % 2 != 0 else 0)), + ) + + +def norm(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Normalize the input tensor by subtracting the mean and dividing by the standard deviation + across each channel and pixel for arbitrary spatial dimensions. 
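A short round-trip sketch of the ``norm``/``unnorm`` pair defined in this file (the import path assumes the layout introduced in this diff):

import torch
from direct.nn.transformers.utils import norm, unnorm

x = torch.rand(2, 3, 16, 16) * 5 + 1          # arbitrary input with non-zero mean
xn, mean, std = norm(x)                       # per-(batch, channel) statistics
print(mean.shape, std.shape)                  # torch.Size([2, 3, 1, 1]) torch.Size([2, 3, 1, 1])
print(torch.allclose(unnorm(xn, mean, std), x, atol=1e-5))  # True: unnorm inverts norm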
+ + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, *spatial_dims), where spatial_dims can vary in number (e.g., 2D, 3D, etc.). + + Returns + ------- + tuple + Containing the normalized tensor, mean tensor, and standard deviation tensor. + """ + # Flatten spatial dimensions and compute mean and std across them + spatial_dims = x.shape[2:] # Get all spatial dimensions + flattened = x.view(x.shape[0], x.shape[1], -1) # Flatten the spatial dimensions for mean/std calculation + + mean = flattened.mean(-1, keepdim=True).view(x.shape[0], x.shape[1], *([1] * len(spatial_dims))) + std = flattened.std(-1, keepdim=True).view(x.shape[0], x.shape[1], *([1] * len(spatial_dims))) + + # Normalize + x = (x - mean) / std + + return x, mean, std + + +def unnorm(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: + """Denormalize the input tensor by multiplying by the standard deviation and adding the mean + for arbitrary spatial dimensions. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, *spatial_dims), where spatial_dims can vary in number. + mean : torch.Tensor + Mean tensor obtained during normalization. + std : torch.Tensor + Standard deviation tensor obtained during normalization. + + Returns + ------- + torch.Tensor + Tensor with the same shape as the original input tensor, but denormalized. + """ + return x * std + mean + + +def init_weights(m: nn.Module) -> None: + """Initializes the weights of the network using a truncated normal distribution. + + Parameters + ---------- + m : nn.Module + A module of the network whose weights need to be initialized. + """ + + if isinstance(m, nn.Linear): + init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + init.constant_(m.bias, 0) + init.constant_(m.weight, 1.0) + + +class DropoutPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + """Inits :class:`DropoutPath`. + + Parameters + ---------- + drop_prob : float + Probability of dropping a residual connection. Default: 0.0. + scale_by_keep : bool + Whether to scale the remaining activations by 1 / (1 - drop_prob) to maintain the expected value of + the activations. Default: True. + """ + super(DropoutPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + @staticmethod + def _dropout_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def forward(self, x): + return self._dropout_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"dropout_prob={round(self.drop_prob, 3):0.3f}" diff --git a/direct/nn/transformers/vit.py b/direct/nn/transformers/vit.py new file mode 100644 index 00000000..a9163517 --- /dev/null +++ b/direct/nn/transformers/vit.py @@ -0,0 +1,1333 @@ +# Copyright (c) DIRECT Contributors + +"""DIRECT Vision Transformer module. + +Implementation of Vision Transformer model [1, 2]_ in PyTorch. 
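To illustrate the ``DropoutPath`` utility defined above in ``utils.py``, a small sketch of its train/eval behaviour (shapes and probability are arbitrary):

import torch
from direct.nn.transformers.utils import DropoutPath

torch.manual_seed(0)
drop = DropoutPath(drop_prob=0.25)
drop.train()
x = torch.ones(8, 4)
out = drop(x)                   # roughly 25% of the samples (rows) are zeroed,
print(out[:, 0])                # survivors are scaled by 1 / 0.75
drop.eval()
print(torch.equal(drop(x), x))  # True: identity at inference time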
+ +Code borrowed from [3]_ which uses code from timm [4]_. + +References +---------- +.. [1] Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, + M., Heigold, G., Gelly, S., Uszkoreit, J., Houlsby, N.: An Image is Worth 16x16 Words: + Transformers for Image Recognition at Scale, http://arxiv.org/abs/2010.11929, (2021). +.. [2] Steiner, A., Kolesnikov, A., Zhai, X., Wightman, R., Uszkoreit, J., Beyer, L.: How to train your ViT? Data, + Augmentation, and Regularization in Vision Transformers, http://arxiv.org/abs/2106.10270, (2022). +.. [3] https://github.com/facebookresearch/convit +.. [4] https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + +from direct.constants import COMPLEX_SIZE +from direct.nn.transformers.utils import DropoutPath, init_weights, norm, pad_to_divisible, unnorm, unpad_to_original +from direct.types import DirectEnum + +__all__ = ["VisionTransformer2D", "VisionTransformer3D"] + + +class VisionTransformerDimensionality(DirectEnum): + + TWO_DIMENSIONAL = "2D" + THREE_DIMENSIONAL = "3D" + + +class MLP(nn.Module): + """MLP layer with dropout and activation for Vision Transformer. + + Parameters + ---------- + in_features : int + Size of the input feature. + hidden_features : int, optional + Size of the hidden layer feature. If None, then hidden_features = in_features. Default: None. + out_features : int, optional + Size of the output feature. If None, then out_features = in_features. Default: None. + act_layer : nn.Module, optional + Activation layer to be used. Default: nn.GELU. + drop : float, optional + Dropout probability. Default: 0. + """ + + def __init__( + self, + in_features: int, + hidden_features: int = None, + out_features: int = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + ) -> None: + """Inits :class:`MLP`. + + Parameters + ---------- + in_features : int + Size of the input feature. + hidden_features : int, optional + Size of the hidden layer feature. If None, then hidden_features = in_features. Default: None. + out_features : int, optional + Size of the output feature. If None, then out_features = in_features. Default: None. + act_layer : nn.Module, optional + Activation layer to be used. Default: nn.GELU. + drop : float, optional + Dropout probability. Default: 0. + """ + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.apply(init_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`MLP`. + + Parameters + ---------- + x : torch.Tensor + Input tensor to the network. + + Returns + ------- + torch.Tensor + Output tensor of the network. + + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class GPSA(nn.Module): + """Gated Positional Self-Attention module for Vision Transformer. + + Parameters + ---------- + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + dim : int + Dimensionality of the input embeddings. 
+ num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. + locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int,int], optional + The size of the grid (height, width) for relative position encoding. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + locality_strength: float = 1.0, + use_local_init: bool = True, + grid_size=None, + ) -> None: + """Inits :class:`GPSA`. + + Parameters + ---------- + dim : int + Dimensionality of the input embeddings. + num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. + locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int,int], optional + The size of the grid (height, width) for relative position encoding. + """ + super().__init__() + self.num_heads = num_heads + self.dim = dim + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.pos_proj = nn.Linear(3, num_heads) + self.proj_drop = nn.Dropout(proj_drop) + self.locality_strength = locality_strength + self.gating_param = nn.Parameter(torch.ones(self.num_heads)) + self.apply(init_weights) + if use_local_init: + self.local_init(locality_strength=locality_strength) + self.current_grid_size = grid_size + + def get_attention(self, x: torch.Tensor) -> torch.Tensor: + """Compute the attention scores for each patch in x. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, N, C). + + Returns + ------- + torch.Tensor + Attention scores for each patch in x. + """ + B, N, C = x.shape + + k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + pos_score = self.pos_proj(self.get_rel_indices()).expand(B, -1, -1, -1).permute(0, 3, 1, 2) + patch_score = (q @ k.transpose(-2, -1)) * self.scale + patch_score = patch_score.softmax(dim=-1) + pos_score = pos_score.softmax(dim=-1) + + gating = self.gating_param.view(1, -1, 1, 1) + attn = (1.0 - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score + attn = attn / attn.sum(dim=-1).unsqueeze(-1) + attn = self.attn_drop(attn) + return attn + + @abstractmethod + def local_init(self, locality_strength: Optional[float] = 1.0) -> None: + pass + + @abstractmethod + def get_rel_indices(self) -> None: + pass + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`GPSA`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. 
+ + Returns + ------- + torch.Tensor: + """ + B, N, C = x.shape + + attn = self.get_attention(x) + v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GPSA2D(GPSA): + """Gated Positional Self-Attention module for Vision Transformer. + + Parameters + ---------- + dim : int + Dimensionality of the input embeddings. + num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. + locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int,int], optional + The size of the grid (height, width) for relative position encoding. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + locality_strength: float = 1.0, + use_local_init: bool = True, + grid_size=None, + ) -> None: + """Inits :class:`GPSA`. + + Parameters + ---------- + dim : int + Dimensionality of the input embeddings. + num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. + locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int,int], optional + The size of the grid (height, width) for relative position encoding. + """ + super().__init__( + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + locality_strength=locality_strength, + use_local_init=use_local_init, + grid_size=grid_size, + ) + + def local_init(self, locality_strength: Optional[float] = 1.0) -> None: + """Initializes the parameters for a locally connected attention mechanism. + + Parameters + ---------- + locality_strength : float, optional + A scalar multiplier for the locality distance. Default: 1.0. + + Returns + ------- + None + """ + self.v.weight.data.copy_(torch.eye(self.dim)) + locality_distance = 1 # max(1,1/locality_strength**.5) + + kernel_size = int(self.num_heads**0.5) + center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2 + + # compute the positional projection weights with locality distance + for h1 in range(kernel_size): + for h2 in range(kernel_size): + position = h1 + kernel_size * h2 + self.pos_proj.weight.data[position, 2] = -1 + self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance + self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance + self.pos_proj.weight.data *= locality_strength + + def get_rel_indices(self) -> None: + """Generates relative positional indices for each patch in the input. 
+ + Returns + ------- + None + """ + H, W = self.current_grid_size + N = H * W + rel_indices = torch.zeros(1, N, N, 3) + indx = torch.arange(W).view(1, -1) - torch.arange(W).view(-1, 1) + indx = indx.repeat(H, H) + indy = torch.arange(H).view(1, -1) - torch.arange(H).view(-1, 1) + indy = indy.repeat_interleave(W, dim=0).repeat_interleave(W, dim=1) + indd = indx**2 + indy**2 + rel_indices[:, :, :, 2] = indd.unsqueeze(0) + rel_indices[:, :, :, 1] = indy.unsqueeze(0) + rel_indices[:, :, :, 0] = indx.unsqueeze(0) + + return rel_indices.to(self.v.weight.device) + + +class GPSA3D(GPSA): + """Gated Positional Self-Attention module for Vision Transformer (3D variant). + + Parameters + ---------- + dim : int + Dimensionality of the input embeddings. + num_heads : int + Number of attention heads. + qkv_bias : bool + If True, include bias terms in the query, key, and value projections. + qk_scale : float + Scale factor for query and key. + attn_drop : float + Dropout probability for attention weights. + proj_drop : float + Dropout probability for output tensor. + locality_strength : float + Strength of locality assumption in initialization. + use_local_init : bool + If True, use the locality-based initialization. + grid_size : tuple[int, int, int], optional + The size of the grid (depth, height, width) for relative position encoding. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + locality_strength: float = 1.0, + use_local_init: bool = True, + grid_size=None, + ) -> None: + super().__init__( + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + locality_strength=locality_strength, + use_local_init=use_local_init, + grid_size=grid_size, + ) + + def local_init(self, locality_strength: Optional[float] = 1.0) -> None: + self.v.weight.data.copy_(torch.eye(self.dim)) + locality_distance = 1 + + kernel_size = int(self.num_heads ** (1 / 3)) + center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2 + + for h1 in range(kernel_size): + for h2 in range(kernel_size): + for h3 in range(kernel_size): + position = h1 + kernel_size * (h2 + kernel_size * h3) + self.pos_proj.weight.data[position, 2] = -1 + self.pos_proj.weight.data[position, 1] = 2 * (h2 - center) * locality_distance + self.pos_proj.weight.data[position, 0] = 2 * (h3 - center) * locality_distance + self.pos_proj.weight.data *= locality_strength + + def get_rel_indices(self) -> torch.Tensor: + D, H, W = self.current_grid_size + N = D * H * W + rel_indices = torch.zeros(1, N, N, 3) + + indz = torch.arange(D).view(1, -1) - torch.arange(D).view(-1, 1) + indz = indz.repeat(H * W, H * W) + + indx = torch.arange(W).view(1, -1) - torch.arange(W).view(-1, 1) + indx = indx.repeat(D * H, D * H) + + indy = torch.arange(H).view(1, -1) - torch.arange(H).view(-1, 1) + indy = indy.repeat(D * W, D * W) + + indd = indz**2 + indx**2 + indy**2 + rel_indices[:, :, :, 2] = indd.unsqueeze(0) + rel_indices[:, :, :, 1] = indy.unsqueeze(0) + rel_indices[:, :, :, 0] = indx.unsqueeze(0) + + return rel_indices.to(self.v.weight.device) + + +class MHSA(nn.Module): + """Multi-Head Self-Attention (MHSA) module. + + Parameters + ---------- + dim : int + Number of input features. + num_heads : int + Number of heads in the attention mechanism. Default is 8. + qkv_bias : bool + If True, bias is added to the query, key and value projections. Default is False. 
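As a shape-level sanity check for the gated positional self-attention defined above, a sketch instantiating ``GPSA2D`` directly (dimensions are arbitrary but must satisfy ``dim % num_heads == 0``; the import path assumes this diff's layout):

import torch
from direct.nn.transformers.vit import GPSA2D

attn = GPSA2D(dim=18, num_heads=9, grid_size=(4, 4))   # 4 x 4 grid of patch tokens
x = torch.rand(1, 16, 18)                              # (batch, N = 4 * 4 tokens, dim)
out = attn(x)
print(out.shape)                      # torch.Size([1, 16, 18])
print(attn.get_rel_indices().shape)   # torch.Size([1, 16, 16, 3]): (dx, dy, squared distance) per token pair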
+ qk_scale : float or None + Scaling factor for the query-key dot product. If None, it is set to + head_dim ** -0.5 where head_dim = dim // num_heads. Default is None. + attn_drop : float + Dropout rate for the attention weights. Default is 0. + proj_drop : float + Dropout rate for the output of the module. Default is 0. + grid_size : tuple[int, int] or None + If not None, the module is designed to work with a grid of + patches. grid_size is a tuple of the form (H, W) where H and W are the number of patches in + the vertical and horizontal directions respectively. Default is None. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + grid_size: tuple[int, int] = None, + ) -> None: + """Inits :class:`MHSA`. + + Parameters + ---------- + dim : int + Number of input features. + num_heads : int + Number of heads in the attention mechanism. Default is 8. + qkv_bias : bool + If True, bias is added to the query, key and value projections. Default is False. + qk_scale : float or None + Scaling factor for the query-key dot product. If None, it is set to + head_dim ** -0.5 where head_dim = dim // num_heads. Default is None. + attn_drop : float + Dropout rate for the attention weights. Default is 0. + proj_drop : float + Dropout rate for the output of the module. Default is 0. + grid_size : tuple[int, int] or None + If not None, the module is designed to work with a grid of + patches. grid_size is a tuple of the form (H, W) where H and W are the number of patches in + the vertical and horizontal directions respectively. Default is None. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.apply(init_weights) + self.current_grid_size = grid_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`MHSA`. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, N, C). + + Returns + ------- + torch.Tensor + Output tensor of shape (B, N, C). + """ + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class VisionTransformerBlock(nn.Module): + """A single transformer block used in the VisionTransformer model. + + Parameters + ---------- + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + dim : int + The feature dimension. + num_heads : int + The number of attention heads. + mlp_ratio : float, optional + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool, optional + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float, optional + The scale factor for the query-key dot product. Default: None. + drop : float, optional + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop : float, optional + The dropout probability for the attention layer. Default: 0.0. 
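Similarly, the plain multi-head self-attention module can be exercised on its own; a brief sketch (arbitrary sizes, import path as introduced in this diff):

import torch
from direct.nn.transformers.vit import MHSA

attn = MHSA(dim=16, num_heads=4)   # head_dim = 16 / 4 = 4
x = torch.rand(2, 10, 16)          # (batch, tokens, dim)
print(attn(x).shape)               # torch.Size([2, 10, 16]): shape-preserving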
+ dropout_path : float, optional + The dropout probability for the dropout path. Default: 0.0. + act_layer : nn.Module, optional + The activation layer used in the MLP. Default: nn.GELU. + norm_layer : nn.Module, optional + The normalization layer used in the block. Default: nn.LayerNorm. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + **kwargs: Additional arguments for the attention layer. + """ + + def __init__( + self, + dimensionality: VisionTransformerDimensionality, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop: float = 0.0, + attn_drop: float = 0.0, + dropout_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + use_gpsa: bool = True, + **kwargs, + ) -> None: + """Inits :class:`VisionTransformerBlock`. + + Parameters + ---------- + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + dim : int + The feature dimension. + num_heads : int + The number of attention heads. + mlp_ratio : float, optional + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool, optional + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float, optional + The scale factor for the query-key dot product. Default: None. + drop : float, optional + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop : float, optional + The dropout probability for the attention layer. Default: 0.0. + dropout_path : float, optional + The dropout probability for the dropout path. Default: 0.0. + act_layer : nn.Module, optional + The activation layer used in the MLP. Default: nn.GELU. + norm_layer : nn.Module, optional + The normalization layer used in the block. Default: nn.LayerNorm. + use_gpsa : bool, optional + Whether to use the GPSA attention layer. If set to False, the MHSA layer will be used. Default: True. + **kwargs: Additional arguments for the attention layer. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.use_gpsa = use_gpsa + if self.use_gpsa: + self.attn = (GPSA2D if dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL else GPSA3D)( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + **kwargs, + ) + else: + self.attn = MHSA( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + **kwargs, + ) + self.dropout_path = DropoutPath(dropout_path) if dropout_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x: torch.Tensor, grid_size: tuple[int, int]) -> torch.Tensor: + """Forward pass for the :class:`VisionTransformerBlock`. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + grid_size : tuple[int, int] + The size of the grid used by the attention layer. + + Returns + ------- + torch.Tensor: The output tensor. 
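A minimal sketch of a single block in isolation, assuming the classes defined in this file; the grid size is passed at call time and forwarded to the attention layer (see the forward body below):

import torch
from direct.nn.transformers.vit import VisionTransformerBlock, VisionTransformerDimensionality

block = VisionTransformerBlock(
    dimensionality=VisionTransformerDimensionality.TWO_DIMENSIONAL,
    dim=18,
    num_heads=9,
    mlp_ratio=4.0,
    use_gpsa=True,
)
x = torch.rand(1, 16, 18)           # 4 x 4 = 16 patch tokens of dimension 18
out = block(x, grid_size=(4, 4))    # residual attention followed by residual MLP
print(out.shape)                    # torch.Size([1, 16, 18])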
+ """ + self.attn.current_grid_size = grid_size + x = x + self.dropout_path(self.attn(self.norm1(x))) + x = x + self.dropout_path(self.mlp(self.norm2(x))) + + return x + + +class PatchEmbedding(nn.Module): + """Image to Patch Embedding.""" + + def __init__( + self, patch_size, in_channels, embedding_dim, dimensionality: VisionTransformerDimensionality + ) -> None: + """Inits :class:`PatchEmbedding` module for Vision Transformer. + + Parameters + ---------- + patch_size : int or tuple[int, int] + The patch size. If an int is provided, the patch will be a square. + in_channels : int + Number of input channels. + embedding_dim : int + Dimension of the output embedding. + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + """ + super().__init__() + self.proj = (nn.Conv2d if dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL else nn.Conv3d)( + in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size + ) + self.apply(init_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`PatchEmbedding`. + + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + Patch embedding. + """ + x = self.proj(x) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer model. + + Parameters + ---------- + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + average_img_size : int or tuple[int, int] or tuple[int, int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_img_size, average_img_size) for 2D and + (average_img_size, average_img_size, average_img_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] or tuple[int, int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. 
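Note that, as implemented in ``__init__`` below, the effective token dimension is ``embedding_dim * num_heads``, and the head maps each token back to one flattened output patch; a tiny illustrative calculation with the default-like values (arithmetic only, not part of the model):

embedding_dim, num_heads = 64, 9
patch_size, out_channels = (16, 16), 2

effective_dim = embedding_dim * num_heads                          # 576 features per token
head_out_features = out_channels * patch_size[0] * patch_size[1]   # 512 values = one 2-channel 16x16 patch
print(effective_dim, head_out_features)                            # 576 512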
+ """ + + def __init__( + self, + dimensionality: VisionTransformerDimensionality, + average_img_size: int | tuple[int, int] | tuple[int, int, int] = 320, + patch_size: int | tuple[int, int] | tuple[int, int, int] = 16, + in_channels: int = COMPLEX_SIZE, + out_channels: int = None, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_gpsa: bool = True, + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + ) -> None: + """Inits :class:`VisionTransformer`. + + Parameters + ---------- + dimensionality : VisionTransformerDimensionality + The dimensionality of the input data. + average_img_size : int or tuple[int, int] or tuple[int, int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_img_size, average_img_size) for 2D and + (average_img_size, average_img_size, average_img_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] or tuple[int, int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + super().__init__() + + self.dimensionality = dimensionality + + self.depth = depth + embedding_dim *= num_heads + self.num_features = embedding_dim # num_features for consistency with other models + self.locality_strength = locality_strength + self.use_pos_embedding = use_pos_embedding + + if isinstance(average_img_size, int): + if self.dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL: + img_size = (average_img_size, average_img_size) + else: + img_size = (average_img_size, average_img_size, average_img_size) + else: + if len(average_img_size) != ( + 2 if self.dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL else 3 + ): + raise ValueError( + f"average_img_size should have length 2 for 2D and 3 for 3D, got {len(average_img_size)}." 
+ ) + img_size = average_img_size + + if isinstance(patch_size, int): + if self.dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL: + self.patch_size = (patch_size, patch_size) + else: + self.patch_size = (patch_size, patch_size, patch_size) + else: + if len(patch_size) != (2 if self.dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL else 3): + raise ValueError(f"patch_size should have length 2 for 2D and 3 for 3D, got {len(patch_size)}.") + self.patch_size = patch_size + + self.in_channels = in_channels + self.out_channels = out_channels if out_channels else in_channels + + self.patch_embed = PatchEmbedding( + patch_size=self.patch_size, + in_channels=in_channels, + embedding_dim=embedding_dim, + dimensionality=dimensionality, + ) + + self.pos_drop = nn.Dropout(p=drop_rate) + + if self.use_pos_embedding: + self.pos_embed = nn.Parameter( + torch.zeros(1, embedding_dim, *[img_size[i] // self.patch_size[i] for i in range(len(img_size))]) + ) + + init.trunc_normal_(self.pos_embed, std=0.02) + + dpr = [x.item() for x in torch.linspace(0, dropout_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList( + [ + VisionTransformerBlock( + dimensionality=dimensionality, + dim=embedding_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + dropout_path=dpr[i], + norm_layer=nn.LayerNorm, + use_gpsa=use_gpsa, + **({"locality_strength": locality_strength} if use_gpsa else {}), + ) + for i in range(depth) + ] + ) + + self.normalized = normalized + + self.norm = nn.LayerNorm(embedding_dim) + # head + self.feature_info = [dict(num_chs=embedding_dim, reduction=0, module="head")] + self.head = nn.Linear(self.num_features, self.out_channels * np.prod(self.patch_size)) + + self.head.apply(init_weights) + + def get_head(self) -> nn.Module: + """Returns the head of the model. + + Returns + ------- + nn.Module + """ + return self.head + + def reset_head(self) -> None: + """Resets the head of the model.""" + self.head = nn.Linear(self.num_features, self.out_channels * np.prod(self.patch_size)) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the feature extraction part of the model. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + torch.Tensor + """ + x = self.patch_embed(x) + size = x.shape[2:] + + if self.use_pos_embedding: + pos_embed = F.interpolate( + self.pos_embed, + size=size, + mode=( + "bilinear" + if self.dimensionality == VisionTransformerDimensionality.TWO_DIMENSIONAL + else "trilinear" + ), + align_corners=False, + ) + x = x + pos_embed + + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + for _, block in enumerate(self.blocks): + x = block(x, size) + + x = self.norm(x) + + return x + + @abstractmethod + def seq2img(self, x: torch.Tensor, img_size: tuple[int, ...]) -> torch.Tensor: + """Converts the sequence patches tensor to an image tensor. + + Parameters + ---------- + x : torch.Tensor + The sequence tensor. + img_size : tuple[int, ...] + The size of the image tensor. + + Returns + ------- + torch.Tensor + The image tensor. + """ + pass + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`VisionTransformer`. 
+ + Parameters + ---------- + x : torch.Tensor + + Returns + ------- + torch.Tensor + """ + x, pads = pad_to_divisible(x, self.patch_size) + + size = x.shape[2:] + + if self.normalized: + x, mean, std = norm(x) + + x = self.forward_features(x) + x = self.head(x) + x = self.seq2img(x, size) + + if self.normalized: + x = unnorm(x, mean, std) + + x = unpad_to_original(x, *pads) + + return x + + +class VisionTransformer2D(VisionTransformer): + """Vision Transformer model for 2D data. + + Parameters + ---------- + average_img_size : int or tuple[int, int] + The average size of the input image. If an int is provided, this will be determined by the + `dimensionality`, i.e., (average_img_size, average_img_size) for 2D and + (average_img_size, average_img_size, average_img_size) for 3D. Default: 320. + patch_size : int or tuple[int, int] + The size of the patch. If an int is provided, this will be determined by the `dimensionality`, i.e., + (patch_size, patch_size) for 2D and (patch_size, patch_size, patch_size) for 3D. Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + + def __init__( + self, + average_img_size: int | tuple[int, int] = 320, + patch_size: int | tuple[int, int] = 16, + in_channels: int = COMPLEX_SIZE, + out_channels: int = None, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_gpsa: bool = True, + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + ) -> None: + """Inits :class:`VisionTransformer2D`. + + Parameters + ---------- + average_img_size : int or tuple[int, int] + The average size of the input image. If an int is provided, this will be defined as + (average_img_size, average_img_size). Default: 320. + patch_size : int or tuple[int, int] + The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size). Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. 
+ num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + super().__init__( + dimensionality=VisionTransformerDimensionality.TWO_DIMENSIONAL, + average_img_size=average_img_size, + patch_size=patch_size, + in_channels=in_channels, + out_channels=out_channels, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + + def seq2img(self, x: torch.Tensor, img_size: tuple[int, ...]) -> torch.Tensor: + """Converts the sequence patches tensor to an image tensor. + + Parameters + ---------- + x : torch.Tensor + The sequence tensor. + img_size : tuple[int, ...] + The size of the image tensor. + + Returns + ------- + torch.Tensor + The image tensor. + """ + x = x.view(x.shape[0], x.shape[1], self.out_channels, self.patch_size[0], self.patch_size[1]) + x = x.chunk(x.shape[1], dim=1) + x = torch.cat(x, dim=4).permute(0, 1, 2, 4, 3) + x = x.chunk(img_size[0] // self.patch_size[0], dim=3) + x = torch.cat(x, dim=4).permute(0, 1, 2, 4, 3).squeeze(1) + + return x + + +class VisionTransformer3D(VisionTransformer): + """Vision Transformer model for 3D data. + + Parameters + ---------- + average_img_size : int or tuple[int, int, int] + The average size of the input image. If an int is provided, this will be defined as + (average_img_size, average_img_size, average_img_size). Default: 320. + patch_size : int or tuple[int, int, int] + The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). + Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. 
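An end-to-end sketch of :class:`VisionTransformer2D`, using small sizes close to the ones exercised in the tests at the end of this diff (import path as introduced here). The output shape matches the input because the per-token patch predictions are folded back by ``seq2img`` and the padding is removed:

import torch
from direct.nn.transformers.vit import VisionTransformer2D

model = VisionTransformer2D(
    average_img_size=64,   # only used to size the (interpolated) positional embedding
    patch_size=16,
    in_channels=2,
    embedding_dim=6,       # effective token dimension is 6 * num_heads = 18
    depth=2,
    num_heads=3,
)
x = torch.rand(1, 2, 64, 64)
print(model(x).shape)      # torch.Size([1, 2, 64, 64])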
+ dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + + def __init__( + self, + average_img_size: int | tuple[int, int, int] = 320, + patch_size: int | tuple[int, int, int] = 16, + in_channels: int = COMPLEX_SIZE, + out_channels: int = None, + embedding_dim: int = 64, + depth: int = 8, + num_heads: int = 9, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_scale: float = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + use_gpsa: bool = True, + locality_strength: float = 1.0, + use_pos_embedding: bool = True, + normalized: bool = True, + ) -> None: + """Inits :class:`VisionTransformer3D`. + + Parameters + ---------- + average_img_size : int or tuple[int, int, int] + The average size of the input image. If an int is provided, this will be defined as + (average_img_size, average_img_size, average_img_size). Default: 320. + patch_size : int or tuple[int, int, int] + The size of the patch. If an int is provided, this will be defined as (patch_size, patch_size, patch_size). + Default: 16. + in_channels : int + Number of input channels. Default: COMPLEX_SIZE. + out_channels : int or None + Number of output channels. If None, this will be set to `in_channels`. Default: None. + embedding_dim : int + Dimension of the output embedding. + depth : int + Number of transformer blocks. + num_heads : int + Number of attention heads. + mlp_ratio : float + The ratio of hidden dimension size to input dimension size in the MLP layer. Default: 4.0. + qkv_bias : bool + Whether to add bias to the query, key, and value projections. Default: False. + qk_scale : float + The scale factor for the query-key dot product. Default: None. + drop_rate : float + The dropout probability for all dropout layers except dropout_path. Default: 0.0. + attn_drop_rate : float + The dropout probability for the attention layer. Default: 0.0. + dropout_path_rate : float + The dropout probability for the dropout path. Default: 0.0. + use_gpsa: bool + Whether to use GPSA layer. Default: True. + locality_strength : float + The strength of the locality assumption in initialization. Default: 1.0. + use_pos_embedding : bool + Whether to use positional embeddings. Default: True. + normalized : bool + Whether to normalize the input tensor. Default: True. + """ + + super().__init__( + dimensionality=VisionTransformerDimensionality.THREE_DIMENSIONAL, + average_img_size=average_img_size, + patch_size=patch_size, + in_channels=in_channels, + out_channels=out_channels, + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + + def seq2img(self, x: torch.Tensor, img_size: tuple[int, ...]) -> torch.Tensor: + """Converts the sequence of 3D patches to a 3D image tensor. + + Parameters + ---------- + x : torch.Tensor + The sequence tensor, where each entry corresponds to a flattened 3D patch. 
+        img_size : tuple of ints
+            The size of the 3D image tensor (depth, height, width).
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed 3D image tensor.
+        """
+        # Reshape the sequence into patches of shape (batch, num_patches, out_channels, D, H, W)
+        x = x.view(
+            x.shape[0], x.shape[1], self.out_channels, self.patch_size[0], self.patch_size[1], self.patch_size[2]
+        )
+
+        # Number of patches along each spatial axis
+        depth_chunks = img_size[0] // self.patch_size[0]  # Number of patches along depth
+        height_chunks = img_size[1] // self.patch_size[1]  # Number of patches along height
+        width_chunks = img_size[2] // self.patch_size[2]  # Number of patches along width
+
+        # Reassemble the patches along the width axis
+        x = torch.cat(x.chunk(width_chunks, dim=1), dim=5)
+
+        # Then along the height axis
+        x = torch.cat(x.chunk(height_chunks, dim=1), dim=4)
+
+        # Finally along the depth axis, dropping the singleton sequence dimension
+        x = torch.cat(x.chunk(depth_chunks, dim=1), dim=3).squeeze(1)
+
+        return x
diff --git a/tests/tests_nn/test_transformers.py b/tests/tests_nn/test_transformers.py
new file mode 100644
index 00000000..5f3a9bfd
--- /dev/null
+++ b/tests/tests_nn/test_transformers.py
@@ -0,0 +1,274 @@
+# Copyright (c) DIRECT Contributors
+
+"""Tests for transformers models."""
+
+import pytest
+import torch
+
+from direct.nn.transformers.uformer import UFormerModel, AttentionTokenProjectionType, LeWinTransformerMLPTokenType
+from direct.nn.transformers.vit import VisionTransformer2D, VisionTransformer3D
+
+
+def create_input(shape):
+    data = torch.rand(shape).float()
+
+    return data
+
+
+# @pytest.mark.parametrize(
+#     "shape",
+#     [
+#         [3, 2, 32, 32],
+#         [3, 2, 16, 16],
+#     ],
+# )
+# @pytest.mark.parametrize(
+#     "embedding_dim",
+#     [20],
+# )
+# @pytest.mark.parametrize(
+#     "patch_size",
+#     [140],
+# )
+# @pytest.mark.parametrize(
+#     "encoder_depths, encoder_num_heads, bottleneck_depth, bottleneck_num_heads",
+#     [
+#         [(2, 2, 2), (1, 2, 4), 1, 8],
+#         [(2, 2, 2, 2), (1, 2, 4, 8), 2, 8],
+#     ],
+# )
+# @pytest.mark.parametrize(
+#     "patch_norm",
+#     [True, False],
+# )
+# @pytest.mark.parametrize(
+#     "win_size",
+#     [8],
+# )
+# @pytest.mark.parametrize(
+#     "mlp_ratio",
+#     [2],
+# )
+# @pytest.mark.parametrize(
+#     "qkv_bias",
+#     [True, False],
+# )
+# @pytest.mark.parametrize(
+#     "qk_scale",
+#     [None, 0.5],
+# )
+# @pytest.mark.parametrize(
+#     "token_projection",
+#     [AttentionTokenProjectionType.LINEAR, AttentionTokenProjectionType.CONV],
+# )
+# @pytest.mark.parametrize(
+#     "token_mlp",
+#     [LeWinTransformerMLPTokenType.FFN, LeWinTransformerMLPTokenType.MLP, LeWinTransformerMLPTokenType.LEFF],
+# )
+# def test_uformer(
+#     shape,
+#     patch_size,
+#     embedding_dim,
+#     encoder_depths,
+#     encoder_num_heads,
+#     bottleneck_depth,
+#     bottleneck_num_heads,
+#     win_size,
+#     mlp_ratio,
+#     patch_norm,
+#     qkv_bias,
+#     qk_scale,
+#     token_projection,
+#     token_mlp,
+# ):
+#     model = UFormerModel(
+#         patch_size=patch_size,
+#         in_channels=2,
+#         embedding_dim=embedding_dim,
+#         encoder_depths=encoder_depths,
+#         encoder_num_heads=encoder_num_heads,
+#         bottleneck_depth=bottleneck_depth,
+#         bottleneck_num_heads=bottleneck_num_heads,
+#         win_size=win_size,
+#         mlp_ratio=mlp_ratio,
+#         qkv_bias=qkv_bias,
+#         qk_scale=qk_scale,
+#         patch_norm=patch_norm,
+#         token_projection=token_projection,
+#         token_mlp=token_mlp,
+#     )
+#     data = create_input(shape).cpu()
+#     out = model(data)
+#     assert list(out.shape) == shape
+
+
+@pytest.mark.parametrize(
+    "shape, 
average_img_size", + [ + [[1, 3, 128, 128], 128], + [[3, 2, 64, 50], (64, 50)], + ], +) +@pytest.mark.parametrize( + "patch_size", + [16, 8, (16, 10)], +) +@pytest.mark.parametrize( + "embedding_dim", + [6, 12], +) +@pytest.mark.parametrize( + "depth", + [2, 4], +) +@pytest.mark.parametrize( + "num_heads", + [3, 4], +) +@pytest.mark.parametrize( + "mlp_ratio", + [4.0, 2.0], +) +@pytest.mark.parametrize( + "qkv_bias", + [True, False], +) +@pytest.mark.parametrize( + "qk_scale", + [None, 0.5], +) +@pytest.mark.parametrize( + "use_gpsa", + [True, False], +) +@pytest.mark.parametrize( + "locality_strength", + [0.5], +) +@pytest.mark.parametrize( + "use_pos_embedding", + [True, False], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_vision_transformer_2d( + shape, + average_img_size, + patch_size, + embedding_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + use_gpsa, + locality_strength, + use_pos_embedding, + normalized, +): + model = VisionTransformer2D( + average_img_size=average_img_size, + patch_size=patch_size, + in_channels=shape[1], + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + data = create_input(shape).cpu() + out = model(data) + assert list(out.shape) == [shape[0], shape[1], shape[2], shape[3]] + + +@pytest.mark.parametrize( + "shape, average_img_size", + [ + [[1, 3, 64, 64, 64], 64], + [[2, 2, 32, 32, 32], (32, 32, 32)], + ], +) +@pytest.mark.parametrize( + "patch_size", + [8, (8, 6, 8)], +) +@pytest.mark.parametrize( + "embedding_dim", + [8, 16], +) +@pytest.mark.parametrize( + "depth", + [4, 8], +) +@pytest.mark.parametrize( + "num_heads", + [6], +) +@pytest.mark.parametrize( + "mlp_ratio", + [4.0], +) +@pytest.mark.parametrize( + "qkv_bias", + [True, False], +) +@pytest.mark.parametrize( + "qk_scale", + [None, 0.5], +) +@pytest.mark.parametrize( + "use_gpsa", + [True, False], +) +@pytest.mark.parametrize( + "locality_strength", + [1.0], +) +@pytest.mark.parametrize( + "use_pos_embedding", + [True, False], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_vision_transformer_3d( + shape, + average_img_size, + patch_size, + embedding_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + use_gpsa, + locality_strength, + use_pos_embedding, + normalized, +): + model = VisionTransformer3D( + average_img_size=average_img_size, + patch_size=patch_size, + in_channels=shape[1], + embedding_dim=embedding_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + use_gpsa=use_gpsa, + locality_strength=locality_strength, + use_pos_embedding=use_pos_embedding, + normalized=normalized, + ) + data = create_input(shape).cpu() + out = model(data) + assert list(out.shape) == [shape[0], shape[1], shape[2], shape[3], shape[4]]