Add ViViT (Video Vision Transformer) to KerasCV #2335

Open: wants to merge 14 commits into base: master
15 changes: 15 additions & 0 deletions keras_cv/models/video_classification/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.models.video_classification.vivit import ViViT
201 changes: 201 additions & 0 deletions keras_cv/models/video_classification/vivit.py
@@ -0,0 +1,201 @@
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.models.task import Task


@keras_cv_export(
    [
        "keras_cv.models.ViViT",
        "keras_cv.models.video_classification.ViViT",
    ]
)
class ViViT(Task):
    """A Keras model implementing a Video Vision Transformer
    for video classification.

    References:
      - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
        (ICCV 2021)
    Args:
        tubelet_embedder: `keras.layers.Layer`. A layer for spatio-temporal
            tubelet embedding applied to input sequences of video frames.
        positional_encoder: `keras.layers.Layer`. A layer for adding
            positional information to the encoded video tokens.
        inp_shape: tuple, the shape of the input video frames, excluding
            the batch dimension.
        num_classes: int, the number of classes for video classification.
        transformer_layers: int, the number of transformer layers in the
            model. Defaults to 8.
        num_heads: int, the number of heads for the multi-head
            self-attention mechanism. Defaults to 8.
        embed_dim: int, number of dimensions in the embedding space.
            Defaults to 128.
        layer_norm_eps: float, epsilon value for layer normalization.
            Defaults to 1e-6.

    Examples:
    ```python
    import keras
    import numpy as np

    from keras_cv.layers import PositionalEncoder
    from keras_cv.layers import TubeletEmbedding
    from keras_cv.models import ViViT

    INPUT_SHAPE = (32, 32, 32, 1)
    NUM_CLASSES = 11
    PATCH_SIZE = (8, 8, 8)
    LAYER_NORM_EPS = 1e-6
    PROJECTION_DIM = 128
    NUM_HEADS = 8
    NUM_LAYERS = 8

    frames = np.random.uniform(size=(5, 32, 32, 32, 1))
    labels = np.ones(shape=(5,))
    model = ViViT(
        tubelet_embedder=TubeletEmbedding(
            embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
        ),
        positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
        inp_shape=INPUT_SHAPE,
        transformer_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        embed_dim=PROJECTION_DIM,
        layer_norm_eps=LAYER_NORM_EPS,
        num_classes=NUM_CLASSES,
    )

    # Run a forward pass on the sample frames.
    model(frames)

    # Train the model.
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        ],
    )

    model.fit(frames, labels, epochs=3)
    ```
    """

    def __init__(
        self,
        tubelet_embedder,
        positional_encoder,
        inp_shape,
        num_classes,
        transformer_layers=8,
        num_heads=8,
        embed_dim=128,
        layer_norm_eps=1e-6,
        **kwargs,
    ):
        if not isinstance(tubelet_embedder, keras.layers.Layer):
            raise ValueError(
                "Argument `tubelet_embedder` must be a "
                "`keras.layers.Layer` instance. Received instead "
                f"tubelet_embedder={tubelet_embedder} "
                f"(of type {type(tubelet_embedder)})."
            )

        if not isinstance(positional_encoder, keras.layers.Layer):
            raise ValueError(
                "Argument `positional_encoder` must be a "
                "`keras.layers.Layer` instance. Received instead "
                f"positional_encoder={positional_encoder} "
                f"(of type {type(positional_encoder)})."
            )

        inputs = keras.layers.Input(shape=inp_shape)
        patches = tubelet_embedder(inputs)
        encoded_patches = positional_encoder(patches)

        for _ in range(transformer_layers):
            # Pre-norm multi-head self-attention with a residual connection.
            x1 = keras.layers.LayerNormalization(epsilon=layer_norm_eps)(
                encoded_patches
            )
            attention_output = keras.layers.MultiHeadAttention(
                num_heads=num_heads,
                key_dim=embed_dim // num_heads,
                dropout=0.1,
            )(x1, x1)

            x2 = keras.layers.Add()([attention_output, encoded_patches])

            # Pre-norm MLP block with a residual connection.
            x3 = keras.layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
            x3 = keras.Sequential(
                [
                    keras.layers.Dense(
                        units=embed_dim * 4, activation=keras.ops.gelu
                    ),
                    keras.layers.Dense(
                        units=embed_dim, activation=keras.ops.gelu
                    ),
                ]
            )(x3)

            encoded_patches = keras.layers.Add()([x3, x2])

        # Pool the token representations and classify.
        representation = keras.layers.LayerNormalization(
            epsilon=layer_norm_eps
        )(encoded_patches)
        representation = keras.layers.GlobalAvgPool1D()(representation)

        outputs = keras.layers.Dense(units=num_classes, activation="softmax")(
            representation
        )

        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        self.inp_shape = inp_shape
        self.num_classes = num_classes
        self.transformer_layers = transformer_layers
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.layer_norm_eps = layer_norm_eps
        self.tubelet_embedder = tubelet_embedder
        self.positional_encoder = positional_encoder
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "inp_shape": self.inp_shape,
                "num_classes": self.num_classes,
                "transformer_layers": self.transformer_layers,
                "num_heads": self.num_heads,
                "embed_dim": self.embed_dim,
                "layer_norm_eps": self.layer_norm_eps,
                "tubelet_embedder": keras.saving.serialize_keras_object(
                    self.tubelet_embedder
                ),
                "positional_encoder": keras.saving.serialize_keras_object(
                    self.positional_encoder
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        if "tubelet_embedder" in config and isinstance(
            config["tubelet_embedder"], dict
        ):
            config["tubelet_embedder"] = keras.layers.deserialize(
                config["tubelet_embedder"]
            )
        if "positional_encoder" in config and isinstance(
            config["positional_encoder"], dict
        ):
            config["positional_encoder"] = keras.layers.deserialize(
                config["positional_encoder"]
            )
        return super().from_config(config)
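
A minimal sketch (not part of the diff) of the save/load round trip that `get_config` and `from_config` above are meant to support. It assumes the exports this PR registers: `keras_cv.models.ViViT`, `keras_cv.layers.TubeletEmbedding`, and `keras_cv.layers.PositionalEncoder`.

```python
import numpy as np

from keras_cv.layers import PositionalEncoder
from keras_cv.layers import TubeletEmbedding
from keras_cv.models import ViViT

model = ViViT(
    tubelet_embedder=TubeletEmbedding(embed_dim=128, patch_size=8),
    positional_encoder=PositionalEncoder(embed_dim=128),
    inp_shape=(32, 32, 32, 1),
    num_classes=11,
)

# The nested layers are serialized to plain dicts inside the config ...
config = model.get_config()
assert isinstance(config["tubelet_embedder"], dict)

# ... and deserialized back into layer instances by `from_config`, so the
# restored model has the same architecture (weights are re-initialized).
restored = ViViT.from_config(config)
videos = np.random.uniform(size=(2, 32, 32, 32, 1))
print(restored(videos).shape)  # (2, 11)
```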
102 changes: 102 additions & 0 deletions keras_cv/models/video_classification/vivit_layers.py
@@ -0,0 +1,102 @@
# Copyright 2024 The KerasCV Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops


@keras_cv_export(
    "keras_cv.layers.TubeletEmbedding",
    package="keras_cv.layers",
)
class TubeletEmbedding(keras.layers.Layer):
    """A Keras layer for spatio-temporal tubelet embedding applied to input
    sequences of video frames.

    References:
      - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
        (ICCV 2021)

    Args:
        embed_dim: int, number of dimensions in the embedding space.
            Defaults to 128.
        patch_size: tuple or int, size of the spatio-temporal patch.
            If int, the same size is used for all three dimensions.
            If tuple, specifies the size for each dimension.
            Defaults to 8.
    """

    def __init__(self, embed_dim=128, patch_size=8, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.patch_size = patch_size

    def build(self, input_shape):
        # Review comment (Collaborator): define all layers in __init__ and
        # build them here, like self.layer_name.build(expected_input_shape).
        self.projection = keras.layers.Conv3D(
            filters=self.embed_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="VALID",
        )
        self.flatten = keras.layers.Reshape(target_shape=(-1, self.embed_dim))

    def call(self, videos):
        # Project non-overlapping tubelets, then flatten the spatio-temporal
        # grid into a sequence of tokens.
        projected_patches = self.projection(videos)
        flattened_patches = self.flatten(projected_patches)
        return flattened_patches

    def get_config(self):
        config = super().get_config()
        config.update(
            {"embed_dim": self.embed_dim, "patch_size": self.patch_size}
        )
        return config
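
For reference, a minimal sketch of the pattern the review comment above suggests: define sublayers in `__init__`, then build them explicitly once the input shape is known. The class name `TubeletEmbeddingV2` is hypothetical, and this is illustrative rather than the code under review.

```python
from keras_cv.backend import keras


class TubeletEmbeddingV2(keras.layers.Layer):
    """Variant following the review suggestion (hypothetical name)."""

    def __init__(self, embed_dim=128, patch_size=8, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        # Define all sublayers eagerly in __init__ ...
        self.projection = keras.layers.Conv3D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="VALID",
        )
        self.flatten = keras.layers.Reshape(target_shape=(-1, embed_dim))

    def build(self, input_shape):
        # ... and build them here, with the expected input shape.
        self.projection.build(input_shape)
        self.built = True

    def call(self, videos):
        return self.flatten(self.projection(videos))
```

Either way, with `patch_size=8` a `(32, 32, 32, 1)` input yields a 4 x 4 x 4 grid of tubelets, which `Reshape` flattens into 64 tokens of width `embed_dim`.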


@keras_cv_export(
    "keras_cv.layers.PositionalEncoder",
    package="keras_cv.layers",
)
class PositionalEncoder(keras.layers.Layer):
    """A Keras layer for adding positional information to the encoded video
    tokens.

    References:
      - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
        (ICCV 2021)

    Args:
        embed_dim: int, number of dimensions in the embedding space.
            Defaults to 128.
    """

    def __init__(self, embed_dim=128, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        _, num_tokens, _ = input_shape
        # One learned position embedding per token in the sequence.
        self.position_embedding = keras.layers.Embedding(
            input_dim=num_tokens, output_dim=self.embed_dim
        )
        self.positions = ops.arange(start=0, stop=num_tokens, step=1)

    def call(self, encoded_tokens):
        # Add the learned position embeddings to the token embeddings.
        encoded_positions = self.position_embedding(self.positions)
        encoded_tokens = encoded_tokens + encoded_positions
        return encoded_tokens
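
Finally, a short usage sketch (shapes assumed, not part of the diff) chaining the two layers the way `ViViT` does internally:

```python
import numpy as np

from keras_cv.layers import PositionalEncoder
from keras_cv.layers import TubeletEmbedding

videos = np.random.uniform(size=(2, 32, 32, 32, 1))

# 32 / 8 = 4 tubelets per dimension -> 4 * 4 * 4 = 64 tokens of size 128.
tokens = TubeletEmbedding(embed_dim=128, patch_size=8)(videos)
print(tokens.shape)  # (2, 64, 128)

# Positional encoding keeps the shape and adds learned position information.
encoded = PositionalEncoder(embed_dim=128)(tokens)
print(encoded.shape)  # (2, 64, 128)
```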