class Add2DRoPE(nn.Module):
  """Helper Module for adding a 2D version of RoPE embedding to inputs.

  The last axis of the input is split into two halves. The first half is
  rotated by angles that depend on the index along the third-to-last (`X`)
  axis, the second half by angles that depend on the index along the
  second-to-last (`Y`) axis. Within each half, feature `i` is paired with
  feature `i + C // 4` to form the 2D vector that gets rotated.

  Attributes:
    num_wavelengths: Number of wavelengths.
  """

  num_wavelengths: int

  def __call__(self, inputs: Float['*B X Y C']) -> Float['*B X Y C']:
    """Applies the rotary embedding along the `X` and `Y` axes.

    Args:
      inputs: Array whose last three axes are `(X, Y, C)`. `C` must be
        divisible by 4. Any number of leading batch axes is accepted.

    Returns:
      The rotated array, with the same shape as `inputs`.

    Raises:
      AssertionError: If the size of the last axis is not divisible by 4.
    """
    num_features = inputs.shape[-1]
    assert num_features % 4 == 0, (
        f'Last axis must be divisible by 4, got {num_features}.'
    )
    # Each spatial axis gets half of the features; each rotation pair uses two
    # features, hence `num_channels // 2` distinct frequencies per axis.
    num_channels = num_features // 2
    freq_exponents = (2. / num_channels) * jnp.arange(num_channels // 2)
    # Keep the full per-frequency vector: reducing it to a single element makes
    # the broadcasts below fail and would collapse all frequencies into one.
    # NOTE(review): the base of the geometric progression is the feature count
    # and `self.num_wavelengths` is not read anywhere in this method -- confirm
    # whether the attribute was meant to be the base instead.
    timescale = num_features ** freq_exponents  # [C/4]

    posx = jnp.arange(inputs.shape[-3])
    posy = jnp.arange(inputs.shape[-2])
    # Shape the angles so they broadcast against the trailing (X, Y, C/4) axes
    # only. No leading singleton axes are added, so the output keeps the rank
    # of `inputs` for any number of batch axes.
    radx = posx[:, None, None] / timescale  # [X, 1, C/4]
    rady = posy[:, None] / timescale  # [Y, C/4]
    sinx, cosx = jnp.sin(radx), jnp.cos(radx)
    siny, cosy = jnp.sin(rady), jnp.cos(rady)

    x_inputs, y_inputs = jnp.split(inputs, 2, axis=-1)
    x1, x2 = jnp.split(x_inputs, 2, axis=-1)
    y1, y2 = jnp.split(y_inputs, 2, axis=-1)
    # Standard 2D rotation of each (first, second) feature pair.
    x_res = jnp.concatenate(
        [x1 * cosx - x2 * sinx, x2 * cosx + x1 * sinx], axis=-1)
    y_res = jnp.concatenate(
        [y1 * cosy - y2 * siny, y2 * cosy + y1 * siny], axis=-1)
    return jnp.concatenate([x_res, y_res], axis=-1)