최근 computer vision 모델의 트렌드가 CNN에서 Transformer로 많이 넘어오고 있습니다. 기존 Transformer 모델은 시공간적 차원에서 patch들을 전역적(globally)으로 연결하는 transformer layer에 기반을 두고 있는 반면 본 논문에서는 이미지 도메인 모델인 Swin Transformer의 구조를 차용해 video transformer에서 self-attention을 지역적(locally)으로 적용했습니다. 그 결과 Action Recognition을 포함해 Video Recognition Task에서 SOTA를 달성하였다고 합니다.
1. Introduction
이미지에 대한 Vision Transformer (ViT)의 성공으로 video-based recognition task에 Transformer 구조를 적용하기 위한 연구가 진행되었고, ViViT, MTN, TimeSFormer 등의 모델이 등장했습다. 본 논문의 모델 역시 Transformer 구조로 contribution은 다음과 같습니다.
비디오에서 인접한 시점(=프레임)이나 위치(=픽셀)는 비슷한 값을 가진다는 가정(spatiotemporal locality)을 전제로 비디오 전체에 걸쳐 self-attention을 수행하는 대신 지역적으로 연산을 수행
Backbone의 learning rate를 head의 0.1 배로 설정하면 성능이 더 높다는 사실을 실험적으로 발견
그 결과 연산량과 모델 사이즈를 줄이면서도 Video Recognition Task(Kinetics-400/600)에서 기존 SOTA 보다 더 높은 성능을 보여주었습니다.
2. Video Swin Transformer
본 논문에서 사용한 모델은 Swin Transformer 구조를 그대로 차용한 모델로, 비디오 적용을 위해 시간 축으로 한 차원 확장했다는 점에서만 차이를 갖습니다. 따라서 Video Swin Transformer를 살펴보기 전에 먼저 Swin Transformer 구조에 대해 간략하게 리뷰하도록 하겠습니다.
-----------------------
Swin Transformer (Shifted WINdow Transformer)
Background
기존 Vision Transformer (ViT)의 문제점
고정된 patch(=token)를 사용해 이미지 특성이 반영되지 않음 → 이미지 해상도와 객체의 크기 고려 X
Global self-attention을 적용해 patch 수가 증가함에 따라 연산량이 quadratic하게 증가
기존 Vision Transformer
Model architecture
ViT의 문제점을 보완하기 위해 해상도와 객체 크기에 유연하도록 hierarchical 구조를 적용
Stage 1~4로 가면서 Patch Merging을 통해 더 넓은 부분을 한 번에 처리
Swin Transformer Block
W-MSA (Window Multi-head Self Attention)
Local window를 적용하여 window 내 patch 간 self-attention을 수행 → window 내 patch 수가 고정되어 있기 때문에 연산량은 이미지 크기에 선형적으로 증가
SW-MSA (Shifted Window Multi-head Self Attention)
Window 경계에 위치한 인접 pixel의 정보를 고려하기 위해 Window를 M/2만큼 cyclic하게 shift하여 self-attention을 수행 (M = window 내 patch 개수) → 원본 이미지에서 서로 인접하지 않았던 영역은 mask 처리해 반영하지 않도록 만듦
W-MAS & SW-MSA
지금까지 Swin Transformer에 대한 간략한 설명이었습니다. 이제 다시 본론으로 돌아와서 Video Swin Transformer에 대해 살펴보도록 하겠습니다.
-----------------
2.1 Overall Architecture
기존 모델과 다른 점은 시간 차원(T)이 추가된 것과 W-MSA (Window Multi-head Self Attention)와 SW-MSA (Shifted Window Multi-head Self Attention)가 3D로 확장된 것입니다.
Swin Transformer
vs
Video Swin Transformer
pyTorch 코드로 살펴보기
[Swin Transformer]
# Swin Transformer
class SwinTransformerBlock(nn.Module):
""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
[Video Swin Transformer]
class SwinTransformerBlock3D(nn.Module):
""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (tuple[int]): Window size.
shift_size (tuple[int]): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0),
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.use_checkpoint=use_checkpoint
self.norm1 = norm_layer(dim)
self.attn = WindowAttention3D(
dim, window_size=self.window_size, num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
두 모델 코드를 비교해보면, 아래 두 가지 차이를 제외하고는 모두 동일한 구조를 갖고 있습니다.
시간축의 추가로 window_size (7, 7) → (2, 7, 7)
WindowAttention → WindowAttention3D
또한 메인 아이디어 중 하나인 cyclic shift도 차원이 하나 추가되었다는 점에서만 차이를 갖습니다.
Swin Transformer 구조를 2D에서 3D로 확장했기 때문에 전체적인 로직도 유사합니다.
2.3 Architecture
4개의 다른 버전으로 구성, 각 모델의 사이즈와 계산 복잡도는 베이스 모델의 x0.25, x0.5, x1, x2
Swin-T : C=96, layer numbers = {2, 2, 6, 2}
Swin-S : C=96, layer numbers = {2, 2, 18, 2}
Swin-B : C=128, layer numbers = {2, 2, 18, 2}
Swin-L : C=192, layer numbers = {2, 2, 18, 2}
여기서 C는 1 stage의 hidden layer의 channel number 이고, Window size P = 8, M = 7 를 default로 사용했습니다.
추가적으로 각 Head의 query 차원은 d = 32, MLP Layer의 expansion은 Ɑ = 4 입니다.
2.4 Initialization from Pre-trained Model
Swin Transformer와 동일하게 대규모 데이터셋으로 학습된 사전 학습 모델로 초기화를 진행하였, 아래 두 블럭은 기존 모델과 다른 형태을 가지므로 별도의 처리를 해주었습니다.
Linear embedding layer
Stage 1에서 시간 차원이 반으로 줄어 (T/2) 기존보다 채널이 2배 증가 (48 → 96) → 사전 훈련된 모델에서 가중치를 두 번 복제 후 전체 행렬에 0.5를 multiply (출력의 평균과 분산에는 영향 x)
Video Swin Transformer block의 relative position biases
(2M - 1, 2M - 1) → (2P - 1, 2M - 1, 2M - 1)
각 프레임 내에서 relative position biases을 동일하게 만들기 위해 사전 훈련된 모델의 매트릭스를 (2P - 1)회 복제하여 초기화
PyTorch 코드로 살펴보기
위에서 보았듯이 본 논문의 모델은 이미지로 사전 훈련된 모델의 weight를 사용하기 때문에 비디오에 맞게 차원을 확장해 줄 필요가 있습니다. 공식 코드를 보면 inflate_weights라는 함수로 구현되어 있습니다. 논문에서 한 설명처럼 weight를 복제해서 사용한 걸 확인할 수 있습니다.
def inflate_weights(self, logger):
"""Inflate the swin2d parameters to swin3d.
The differences between swin3d and swin2d mainly lie in an extra
axis. To utilize the pretrained parameters in 2d model,
the weight of swin2d models should be inflated to fit in the shapes of
the 3d counterpart.
Args:
logger (logging.Logger): The logger used to print
debugging infomation.
"""
checkpoint = torch.load(self.pretrained, map_location='cpu')
state_dict = checkpoint['model']
# delete relative_position_index since we always re-init it
relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
for k in relative_position_index_keys:
del state_dict[k]
# delete attn_mask since we always re-init it
attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
for k in attn_mask_keys:
del state_dict[k]
state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).repeat(1,1,self.patch_size[0],1,1) / self.patch_size[0]
# bicubic interpolate relative_position_bias_table if not match
relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
for k in relative_position_bias_table_keys:
relative_position_bias_table_pretrained = state_dict[k]
relative_position_bias_table_current = self.state_dict()[k]
L1, nH1 = relative_position_bias_table_pretrained.size()
L2, nH2 = relative_position_bias_table_current.size()
L2 = (2*self.window_size[1]-1) * (2*self.window_size[2]-1)
wd = self.window_size[0]
if nH1 != nH2:
logger.warning(f"Error in loading {k}, passing")
else:
if L1 != L2:
S1 = int(L1 ** 0.5)
relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(2*self.window_size[1]-1, 2*self.window_size[2]-1),
mode='bicubic')
relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
state_dict[k] = relative_position_bias_table_pretrained.repeat(2*wd-1,1)
msg = self.load_state_dict(state_dict, strict=False)
logger.info(msg)
logger.info(f"=> loaded successfully '{self.pretrained}'")
del checkpoint
torch.cuda.empty_cache()
3. Experiments
3.1 Datasets
실험을 위한 데이터셋의 구성은 다음과 같고 성능 평가는 top-1와 top-5로 이루어졌습니다.
Human action recognition
Kinetics-400: 400개의 action category, 240k(train) / 20k(val) video
Kinetics-600: 600개의 action category, 370k(train) / 28.3k(val) video
Temporal modeling
Something-Something V2 (SSv2): 174 클래스, 168.9K(train) / 24.7K(val) video
3.2 Comparison to state-of-the-art
실험 결과 Kinetics과 Something-Something V2에서 각각 SOTA를 달성했습니다.
3.3 Ablation Study
Different designs for spatiotemporal attention
세 가지 방식(Joint / Split / Factorized )의 spatiotemporal attention을 적용하여 성능을 비교하는 실험도 진행하였습니다. 각 방식에 대한 설명은 다음과 같습니다.
Joint: (default): 각 3D windows-based MSA layer에서 spatiotemporal attention을 계산
Split: spatial-only Swin Transformer위에 2개의 temporal transformer layer를 추가 (ViViT와 VTN에서 유용하다고 입증된 방법)
Factorized: Swin Transformer의 각 spatial-only MSA Layer 뒤에 temporal-only MSA Layer 추가 (TimeSformer에서 사용된 방법)
결과를 확인해보니 joint 버전이 가장 좋은 성능을 보였다고 합니다.
Temporal dimension of 3D tokens & Temporal window size
다음으로 Temporal dimension와 window size를 조절하면서 진행한 성능 비교 실험입니다.
Temporal dimension의 비교 실험 결과를 보면, 차원이 클수록 상위 Top 1-Acc가 높다는 것을 알 수 있습니다. 하지만 차원이 커진 만큼 계산 비용이 커지고 추론 속도도 느려졌습니다.
Window size 비교 실험에서도 window size의 크기가 커질수록 성능(+0.3%)은 향상되지만 성능 향상도 대비 계산량(+17%) 상당히 많아지는 것을 확인할 수 있습니다.
3D shifted windows
3D shifted window 적용의 이점을 보여주기 위한 실험도 진행했습니다. 그 결과 성능 향상을 확인할 수 있었습니다.
Ratio of backbone/head learning rate
이 실험의 결과는 파라미터 튜닝을 진행하면서 발견한 결과인 것 같습니다. 아래 표를 보면, Backbone의 learning rate를 head의 0.1배로 하면 성능이 더 잘 나오는 것을 확인할 수 있습니다.
4. Conclusion
마지막 결론입니다. 본 논문은 기존 Swin Transformer를 3D로 확장하여, 계산량을 줄이면서 사전 훈련된 이미지 모델을 활용할 수 있는 방안을 제시하였습니다. 그 결과 널리 사용되는 세 가지 벤치마크(Kinetics-400, Kinetics-600, Something-Something v2)에서 SOTA를 달성하였고, 이를 통해 이미지 모델을 비디오로 확장/적용할 수 있는 가능성을 보여주었습니다.