
Paper review | Video Swin Transformer

Video Swin Transformer: the Swin Transformer extended to video

2023.03.20 by Jisu Yu

Let's take a look at the Video Swin Transformer, which extends the Swin Transformer, a model that achieved SOTA (state-of-the-art) in the image domain, to video.

 

 

Abstract

 

Recently, the trend of computer vision models has been moving from CNNs to Transformers. While existing video Transformer models compute self-attention globally over patches across the spatiotemporal dimensions, this paper borrows the structure of the Swin Transformer, an image-domain model, and applies self-attention locally in a video transformer. As a result, it achieves SOTA on video recognition tasks including action recognition.

 

1. Introduction

 

The success of the Vision Transformer (ViT) for images has led to research on applying Transformer structures to video recognition tasks, producing models such as ViViT, VTN, and TimeSformer. The model in this paper is also a Transformer structure, and its contributions are as follows.

  • Performs self-attention locally instead of across the entire video, based on the assumption that neighboring time points (=frames) and locations (=pixels) in a video tend to have similar values (spatiotemporal locality).
  • Experimentally found that setting the backbone's learning rate to 0.1 times that of the head leads to higher performance.

 

As a result, it showed higher performance than the existing SOTA on video recognition tasks (Kinetics-400/600) while reducing computation and model size.

 

2. Video Swin Transformer

 

The model used in this paper is a straightforward adaptation of the Swin Transformer structure, with the only difference being that it is extended by one dimension, the time axis, for video application. Therefore, before looking at the Video Swin Transformer, we will first briefly review the Swin Transformer structure.

 

-----------------------

Swin Transformer (Shifted WINdow Transformer)

Background

 

Problems with the existing Vision Transformer (ViT)

  • Uses a fixed patch (=token) size, which does not reflect image characteristics → image resolution and object size are not taken into account
  • Computation increases quadratically with the number of patches because global self-attention is applied


Existing Vision Transformer

Model architecture

To compensate for these problems of ViT, a hierarchical structure is applied so that the model is flexible with respect to resolution and object size.

  • Processes a progressively larger area at once through patch merging when going from Stage 1 to Stage 4 (see the sketch below)
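
As a rough sketch of what patch merging does (a toy example with illustrative variable names, not the official implementation): the features of each 2×2 group of neighboring patches are concatenated (C → 4C) and linearly projected down to 2C, so the spatial resolution halves while the channel dimension doubles.

import torch
import torch.nn as nn

# Toy patch merging on a (B, H, W, C) feature map: concatenate each 2x2 group
# of neighboring patch features (C -> 4C) and project to 2C.
B, H, W, C = 1, 56, 56, 96
x = torch.randn(B, H, W, C)

merged = torch.cat([x[:, 0::2, 0::2, :], x[:, 1::2, 0::2, :],
                    x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]], dim=-1)  # (B, H/2, W/2, 4C)
reduction = nn.Linear(4 * C, 2 * C)
out = reduction(merged)                                                 # (B, H/2, W/2, 2C)
print(out.shape)  # torch.Size([1, 28, 28, 192])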

 

Swin Transformer Block

  • W-MSA (Window Multi-head Self-Attention)
  • Applies a local window and performs self-attention only between patches within each window
    → Since the number of patches per window is fixed, the computation grows linearly with image size (see the complexity comparison below).

  • SW-MSA (Shifted Window Multi-head Self-Attention)
  • Performs self-attention after cyclically shifting the windows by M/2, so that neighboring pixels located on a window boundary are also considered (M = window size, i.e., the number of patches per window side).
    → Areas that were not adjacent to each other in the original image are masked so that they do not attend to each other.
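
For reference, the Swin Transformer paper quantifies this difference for a feature map of h × w patches with C channels and window size M (the first term covers the QKV/output projections, the second the attention itself):

Ω(MSA) = 4hwC² + 2(hw)²C
Ω(W-MSA) = 4hwC² + 2M²hwC

Global MSA is therefore quadratic in the number of patches hw, while W-MSA is linear in hw for a fixed M (M = 7 by default).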

 

W-MSA & SW-MSA

 

That was a brief explanation of Swin Transformer, now let's get back to the point and take a look at Video Swin Transformer.

 

-----------------

 

2.1 Overall Architecture

 

The difference from the original model is the addition of a time dimension (T) and the extension of W-MSA (Window Multi-head Self-Attention) and SW-MSA (Shifted Window Multi-head Self-Attention) to 3D.

Swin Transformer vs Video Swin Transformer
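
As a rough sketch of the resulting token shapes (a toy example assuming the paper's default 2×4×4 3D patch size and Swin-T's C = 96; the variable names are illustrative, not taken from the official code):

import torch
import torch.nn as nn

# Each 2x4x4 (time x height x width) RGB patch becomes one token, so a clip of
# size T x H x W yields T/2 x H/4 x W/4 tokens of dimension C.
B, T, H, W, C = 1, 32, 224, 224, 96
video = torch.randn(B, 3, T, H, W)

patch_embed = nn.Conv3d(3, C, kernel_size=(2, 4, 4), stride=(2, 4, 4))
tokens = patch_embed(video)
print(tokens.shape)  # torch.Size([1, 96, 16, 56, 56]) = (B, C, T/2, H/4, W/4)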

2.1.1 PyTorch code

 

[Swin Transformer]


# Swin Transformer
class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

 

[Video Swin Transformer]


class SwinTransformerBlock3D(nn.Module):
    """ Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Window size.
        shift_size (tuple[int]): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0),
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint=use_checkpoint

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(
            dim, window_size=self.window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)      

Comparing the code of the two models, they have the same structure except for the two differences below.

  • Addition of a time axis: window_size (7, 7) → (2, 7, 7)
  • WindowAttention → WindowAttention3D

 

Also, one of the main ideas, cyclic shift, differs only by the addition of one more dimension.


# Swin Transformer: cyclic shift over the two spatial dimensions
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

# Video Swin Transformer: the same shift with an extra temporal dimension
shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
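
As a toy illustration (not the official code) of what this roll does in 3D: tokens wrap around along the time, height, and width axes, and the reverse roll restores the original layout after attention. For the default (2, 7, 7) window, the shift is half the window size along each axis, i.e. (1, 3, 3).

import torch

# A toy (B, T, H, W, C) token grid labeled by position, with C = 1.
T, H, W = 4, 8, 8
x = torch.arange(T * H * W).reshape(1, T, H, W, 1)

shift_size = (1, 3, 3)  # half of the (2, 7, 7) window along each axis
shifted = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))

# The reverse roll undoes the shift exactly.
restored = torch.roll(shifted, shifts=shift_size, dims=(1, 2, 3))
assert torch.equal(x, restored)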

2.2 3D Shifted Window based MSA Module

 

3D shifted windows

 

The overall logic is similar, as we have extended the Swin Transformer structure from 2D to 3D.

The basic logic of extending the Swin Transformer structure from 2D to 3D
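
For completeness, the attention inside each 3D window has the same form as in the 2D model, with a 3D relative position bias B added to the attention logits (as given in the paper):

Attention(Q, K, V) = SoftMax(QKᵀ / √d + B) V

where Q, K, V ∈ R^(PM²×d) are the query, key, and value of the PM² tokens in a window, d is the query dimension of each head, and the entries of B are taken from a learned bias table B̂ ∈ R^((2P−1)×(2M−1)×(2M−1)) indexed by the relative position along each axis.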

 

2.3 Architecture

 

The architecture comes in 4 versions, whose model size and computational complexity are roughly 0.25×, 0.5×, 1×, and 2× those of the base model, respectively:

  • Swin-T : C=96, layer numbers = {2, 2, 6, 2}
  • Swin-S : C=96, layer numbers = {2, 2, 18, 2}
  • Swin-B : C=128, layer numbers = {2, 2, 18, 2}
  • Swin-L : C=192, layer numbers = {2, 2, 18, 2}

where C is the channel number of the hidden layers in Stage 1, and the default window size is P = 8 and M = 7.

Furthermore, the query dimension of each head is d = 32 and the expansion ratio of the MLP layer is α = 4.
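
For reference, the four variants can be summarized as configuration dictionaries (a minimal sketch; the key names are illustrative, not the official config format):

# Stage-1 channel dim C and number of blocks per stage for each variant.
VIDEO_SWIN_VARIANTS = {
    "Swin-T": {"embed_dim": 96,  "depths": (2, 2, 6, 2)},
    "Swin-S": {"embed_dim": 96,  "depths": (2, 2, 18, 2)},
    "Swin-B": {"embed_dim": 128, "depths": (2, 2, 18, 2)},
    "Swin-L": {"embed_dim": 192, "depths": (2, 2, 18, 2)},
}

# Shared defaults from the paper: window size P x M x M = 8 x 7 x 7,
# per-head query dim d = 32, MLP expansion ratio alpha = 4.
SHARED_DEFAULTS = {"window_size": (8, 7, 7), "head_dim": 32, "mlp_ratio": 4.0}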

 

2.4 Initialization from Pre-trained Model

 

As with the Swin Transformer, the model is initialized from a model pre-trained on a large image dataset, but the following two blocks have a different shape from the 2D model, so they are handled separately.

 

Linear embedding layer

  • The time dimension is halved in Stage 1 (T/2) and the input feature dimension of the linear embedding is doubled (48 → 96).
    → Duplicate the pre-trained weights along the new dimension and multiply the whole matrix by 0.5, so that the mean and variance of the output x remain unchanged.

 

Video Swin Transformer block's relative position biases

  • (2M - 1, 2M - 1) → (2P - 1, 2M - 1, 2M - 1)
  • Initialize by replicating the (2M - 1, 2M - 1) matrix of the pre-trained model (2P - 1) times, which makes the relative position biases the same within each frame (a toy shape example follows below)
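
A toy illustration of the two adjustments at the shape level (illustrative tensors only; the real implementation is the inflate_weights function shown in 2.4.1):

import torch

# 1) Patch-embedding weights: repeat the 2D kernel along a new temporal axis of
#    size 2 and divide by 2, so the output statistics stay roughly unchanged.
w2d = torch.randn(96, 3, 4, 4)                              # (C_out, 3, 4, 4)
w3d = w2d.unsqueeze(2).repeat(1, 1, 2, 1, 1) / 2            # (96, 3, 2, 4, 4)

# 2) Relative position bias table: tile the ((2M-1)^2, nH) table (2P-1) times
#    so the bias is the same within each frame.
M, P, n_heads = 7, 8, 3
bias_2d = torch.randn((2 * M - 1) ** 2, n_heads)            # (169, 3)
bias_3d = bias_2d.repeat(2 * P - 1, 1)                      # (2535, 3)
print(w3d.shape, bias_3d.shape)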

 

2.4.1 PyTorch code

As we saw above, the model in this paper uses weights from a model pre-trained on images, so the weight dimensions need to be inflated to fit video. In the official code, this is implemented as a function called inflate_weights, and you can see that the weights are replicated and rescaled as described in the paper.


  def inflate_weights(self, logger):
      """Inflate the swin2d parameters to swin3d.
      The differences between swin3d and swin2d mainly lie in an extra
      axis. To utilize the pretrained parameters in 2d model,
      the weight of swin2d models should be inflated to fit in the shapes of
      the 3d counterpart.
      Args:
          logger (logging.Logger): The logger used to print
              debugging information.
      """
      checkpoint = torch.load(self.pretrained, map_location='cpu')
      state_dict = checkpoint['model']

      # delete relative_position_index since we always re-init it
      relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
      for k in relative_position_index_keys:
          del state_dict[k]

      # delete attn_mask since we always re-init it
      attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
      for k in attn_mask_keys:
          del state_dict[k]

      state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).repeat(1,1,self.patch_size[0],1,1) / self.patch_size[0]

      # bicubic interpolate relative_position_bias_table if not match
      relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
      for k in relative_position_bias_table_keys:
          relative_position_bias_table_pretrained = state_dict[k]
          relative_position_bias_table_current = self.state_dict()[k]
          L1, nH1 = relative_position_bias_table_pretrained.size()
          L2, nH2 = relative_position_bias_table_current.size()
          L2 = (2*self.window_size[1]-1) * (2*self.window_size[2]-1)
          wd = self.window_size[0]
          if nH1 != nH2:
              logger.warning(f"Error in loading {k}, passing")
          else:
              if L1 != L2:
                  S1 = int(L1 ** 0.5)
                  relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                      relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(2*self.window_size[1]-1, 2*self.window_size[2]-1),
                      mode='bicubic')
                  relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
          state_dict[k] = relative_position_bias_table_pretrained.repeat(2*wd-1,1)

      msg = self.load_state_dict(state_dict, strict=False)
      logger.info(msg)
      logger.info(f"=> loaded successfully '{self.pretrained}'")
      del checkpoint
      torch.cuda.empty_cache()

 

3. Experiments

3.1 Datasets

 

The datasets for the experiments are organized as follows, and performance is evaluated with top-1 and top-5 accuracy.

 

Human action recognition

  • Kinetics-400: 400 action categories, 240k (train) / 20k (val) videos
  • Kinetics-600: 600 action categories, 370k (train) / 28.3k (val) videos

 

Temporal modeling

  • Something-Something V2 (SSv2): 174 classes, 168.9K (train) / 24.7K (val) videos

 

3.2 Comparison to state-of-the-art

 

The experimental results show that the model achieves SOTA on both Kinetics and Something-Something V2.

Comparison to state-of-the-art on Kinetics-600

3.3 Ablation Study

 

Different designs for spatiotemporal attention

 

We also conducted an experiment to compare the performance of three different designs of spatiotemporal attention (Joint / Split / Factorized). Each design is described below.

  • Joint (default): Computes spatiotemporal attention jointly in each 3D window-based MSA layer.
  • Split: Adds two temporal transformer layers on top of a spatial-only Swin Transformer.
    (This approach has proven useful in ViViT and VTN.)
  • Factorized: Adds a temporal-only MSA layer after each spatial-only MSA layer in the Swin Transformer.
    (The method used in TimeSformer.)
  • The results show that the joint version performs best.

 

spatiotemporal attention with Swin-T on K400

Temporal dimension of 3D tokens & Temporal window size

  • The following is a performance comparison experiment by adjusting the temporal dimension and window size.

 

Ablation study on temporal dimension of 3D tokens and temporal window size with Swin-T on K400

  • In the experiments on the temporal dimension, a larger dimension leads to higher top-1 accuracy, but it also increases the computational cost and slows down inference.
  • In the window size comparison experiment, we can also see that the performance improves as the window size increases (+0.3%), but the performance gain comes at a significant cost in computation (+17%).

 

 

3D shifted windows

 

We also ran an experiment to show the benefits of applying 3D shifted windows, which also resulted in a performance improvement.

Ablation study on the 3D shifted window approach with Swin-T on K400

Ratio of backbone/head learning rate

This result appears to come from the authors' hyperparameter tuning: in the table below, you can see that performance is better when the backbone's learning rate is 0.1 times that of the head (a minimal setup sketch follows the table).

Ablation study of the ratio of backbone lr and head lr with Swin-B on K400
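
A minimal sketch of setting up this ratio with PyTorch optimizer parameter groups (the TinyVideoModel, its backbone/head attribute names, and the learning-rate values are hypothetical, not the official training recipe):

import torch
import torch.nn as nn

# Hypothetical model with a backbone and a classification head.
class TinyVideoModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(96, 96)   # stand-in for the Swin backbone
        self.head = nn.Linear(96, 400)      # stand-in for the K400 classification head

model = TinyVideoModel()
head_lr = 1e-3  # illustrative value

# Backbone learning rate = 0.1 x head learning rate, as in the ablation above.
optimizer = torch.optim.AdamW([
    {"params": model.backbone.parameters(), "lr": 0.1 * head_lr},
    {"params": model.head.parameters(),     "lr": head_lr},
], weight_decay=0.05)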

4. Conclusion

 

To conclude: this paper proposed an extension of the existing Swin Transformer to 3D that reuses pre-trained image models while reducing the amount of computation. As a result, it achieved SOTA on three widely used benchmarks (Kinetics-400, Kinetics-600, and Something-Something v2), showing the potential of extending/adapting image models to video.

 

This has been a review of the Video Swin Transformer paper.

 

Video Swin Transformer references

GitHub - microsoft/Swin-Transformer: This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".

GitHub - SwinTransformer/Video-Swin-Transformer: This is an official implementation for "Video Swin Transformers".
