Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 2D RoPE embedding to Kauldron. #683

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kauldron/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@
from kauldron.modules.pos_embeddings import FourierEmbedding
from kauldron.modules.pos_embeddings import LearnedEmbedding
from kauldron.modules.pos_embeddings import ZeroEmbedding
from kauldron.modules.pos_embeddings import Add2DRoPE
36 changes: 36 additions & 0 deletions kauldron/modules/pos_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,42 @@
from kauldron.typing import Axes, Dtype, Float, Initializer, Shape, typechecked # pylint: disable=g-multiple-import,g-importing-member


class Add2DRoPE(nn.Module):
"""Helper Module for adding a 2D version of RoPE embedding to inputs.

Attributes:
num_wavelengths: Number of wavelengths.
"""

num_wavelengths: int

def __call__(self, inputs: Float['*B X Y C']) -> Float['*B X Y C']:
num_wavelengths = inputs.shape[-1]
assert num_wavelengths % 4 == 0
num_channels = num_wavelengths // 2
freq_exponents = (2. / num_channels) * jnp.arange(num_channels // 2)
timescale = (num_wavelengths ** freq_exponents)[-3]
posx, posy = jnp.arange(inputs.shape[-3]), jnp.arange(inputs.shape[-2])
radx = posx[None, None, ..., None] / timescale[None, None, None, :]
# radians.shape = [...,L,1,d=D/2]
sinx, cosx = jnp.sin(radx), jnp.cos(radx)
rady = posy[None, None, ..., None] / timescale[None, None, None, :]
# radians.shape = [...,L,1,d=D/2]
siny, cosy = jnp.sin(rady), jnp.cos(rady)

x_inputs, y_inputs = jnp.split(inputs, 2, axis=-1)
x1, x2 = jnp.split(x_inputs, 2, axis=-1)
cosx, sinx = jnp.expand_dims(cosx, axis=-2), jnp.expand_dims(sinx, axis=-2)
cosy, siny = jnp.expand_dims(cosy, axis=-3), jnp.expand_dims(siny, axis=-3)
x_res = jnp.concatenate(
[x1 * cosx - x2 * sinx, x2 * cosx + x1 * sinx], axis=-1)
y1, y2 = jnp.split(y_inputs, 2, axis=-1)
y_res = jnp.concatenate(
[y1 * cosy - y2 * siny, y2 * cosy + y1 * siny], axis=-1)
res = jnp.concatenate([x_res, y_res], axis=-1)
return res


class AddEmbedding(nn.Module):
"""Helper Module for adding a PositionEmbedding e.g. in a `knn.Sequential`.

Expand Down