
Video Swin Transformer

VideoSwin is a pure transformer-based video modeling architecture that attains top accuracy on major video recognition benchmarks. The authors advocate an inductive bias of locality in video transformers, which leads to a better speed-accuracy trade-off than previous approaches that compute self-attention globally, even with spatial-temporal factorization. The locality of the proposed video architecture is realized by adapting the Swin Transformer designed for the image domain, while continuing to leverage the power of pre-trained image models.

This is an unofficial Keras 3 implementation of Video Swin Transformer. The official PyTorch implementation is here, based on mmaction2. The official PyTorch weights have been converted to a Keras 3 compatible format. This implementation supports running the model on multiple backends, i.e. TensorFlow, PyTorch, and JAX. However, to work with tensorflow.keras, check the tfkeras branch.
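For example, the backend can be selected with the standard Keras 3 KERAS_BACKEND environment variable before Keras is imported; the constructor call below mirrors the one used in the Inference section (a minimal sketch, not an additional API):

import os

# Select the Keras 3 backend before importing keras or videoswin.
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow", "torch"

import keras
from videoswin import VideoSwinT

print(keras.backend.backend())  # confirms the active backend, e.g. "jax"
model = VideoSwinT(num_classes=400, include_rescaling=False, activation=None)  # same model code runs on any backend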

Install

!git clone https://github.com/innat/VideoSwin.git
%cd VideoSwin
!pip install -e . 

Checkpoints

The VideoSwin checkpoints are available in .weights.h5 format for the Kinetics 400/600 and Something-Something V2 datasets. The available variants of this model are tiny, small, and base. Check the model zoo page for details.
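A rough sketch of building a variant and loading its converted checkpoint is shown below; it assumes the small and base variants are exposed as VideoSwinS and VideoSwinB alongside VideoSwinT, and uses a hypothetical checkpoint filename that follows the tiny checkpoint's naming pattern in the Inference section.

from videoswin import VideoSwinT, VideoSwinS, VideoSwinB  # VideoSwinS / VideoSwinB assumed to exist

# Base variant with a Kinetics-400 classification head.
model = VideoSwinB(num_classes=400, include_rescaling=False, activation=None)

# Hypothetical filename; download the matching .weights.h5 from the model zoo / release page.
model.load_weights('videoswin_base_kinetics400_classifier.weights.h5')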

Inference

Sample usage with a pretrained weight is shown below. We can pick any backend, i.e. tensorflow, torch, or jax.

import os
os.environ["KERAS_BACKEND"] = "torch"  # or any other supported backend

import numpy as np
import torch

from videoswin import VideoSwinT

def vswin_tiny():
    # Download the converted Kinetics-400 checkpoint for the tiny variant.
    !wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400_classifier.weights.h5 -q

    model = VideoSwinT(
        num_classes=400,
        include_rescaling=False,
        activation=None
    )
    model.load_weights(
        'videoswin_tiny_kinetics400_classifier.weights.h5'
    )
    return model

model = vswin_tiny()

# read_video and frame_sampling are helper utilities for decoding the clip and
# uniformly sampling frames; a sketch of both is shown after this example.
container = read_video('sample.mp4')
frames = frame_sampling(container, num_frames=32)
y_pred = model(frames)
y_pred.shape  # [1, 400]

probabilities = torch.nn.functional.softmax(y_pred, dim=-1).detach().numpy()
probabilities = probabilities.squeeze(0)

# label_map_inv maps class indices back to Kinetics-400 label names.
confidences = {
    label_map_inv[i]: float(probabilities[i])
    for i in np.argsort(probabilities)[::-1]
}
confidences
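The read_video and frame_sampling helpers above are not defined in this snippet; a minimal, hypothetical sketch using OpenCV and NumPy (uniform temporal sampling, resize to 224x224, batch dimension added; any normalization expected by the checkpoint is not applied here) could look like:

import cv2
import numpy as np

def read_video(path):
    # Decode all frames of the clip into an (N, H, W, 3) RGB array.
    cap = cv2.VideoCapture(path)
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    return np.stack(frames)

def frame_sampling(container, num_frames=32, size=224):
    # Uniformly sample num_frames frames, resize, and add a batch dimension.
    idx = np.linspace(0, len(container) - 1, num_frames).astype(int)
    clip = np.stack([cv2.resize(container[i], (size, size)) for i in idx])
    return clip[None, ...].astype("float32")  # shape: (1, num_frames, size, size, 3)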

Top-5 classification results on a sample clip from Kinetics-400:

{
    'playing_cello': 0.9941741824150085,
    'playing_violin': 0.0016851733671501279,
    'playing_recorder': 0.0011555481469258666,
    'playing_clarinet': 0.0009695519111119211,
    'playing_harp': 0.0007713600643910468
}

To get the backbone of Video Swin, we can pass include_top=False to exclude the classification head. For example:

from videoswin import VideoSwinT

backbone = VideoSwinT(
    include_top=False, input_shape=(32, 224, 224, 3)
)
)

Or, we can use the VideoSwinBackbone API directly from videoswin.backbone.
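For instance, a sketch assuming VideoSwinBackbone accepts an input_shape argument mirroring the constructors above:

import numpy as np
from videoswin.backbone import VideoSwinBackbone

# Headless feature extractor (argument names assumed to match VideoSwinT).
backbone = VideoSwinBackbone(input_shape=(32, 224, 224, 3))

# Dummy clip: (batch, frames, height, width, channels).
features = backbone(np.zeros((1, 32, 224, 224, 3), dtype="float32"))
print(features.shape)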

Arbitrary Input Shape

By default, Video Swin is officially trained with an input shape of (32, 224, 224, 3). However, we can build the model with a different input shape and load the pretrained weights partially.

model = VideoSwinT(
    input_shape=(8, 224, 256, 3),
    include_rescaling=False,
    num_classes=10,
)
model.load_weights('...weights.h5', skip_mismatch=True)  # skips layers whose weight shapes do not match

Guides

  1. Comparison of the Keras 3 implementation vs. the official PyTorch implementation.
  2. Full evaluation on the Kinetics-400 test set using the PyTorch backend.
  3. Fine-tune with the TensorFlow backend.
  4. Fine-tune with the JAX backend.
  5. Fine-tune with the native PyTorch backend.
  6. Fine-tune with PyTorch Lightning.
  7. Convert to ONNX format.

Citation

If you use this VideoSwin implementation in your research, please cite it using the metadata from our CITATION.cff file, along with the original paper:

@article{liu2021video,
  title={Video Swin Transformer},
  author={Liu, Ze and Ning, Jia and Cao, Yue and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Hu, Han},
  journal={arXiv preprint arXiv:2106.13230},
  year={2021}
}