-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for arbitrary image resolutions #24
base: main
Are you sure you want to change the base?
Changes from 2 commits
30f9393
d55b45a
af09c34
cd13378
ccdbe31
805c706
ff992fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
from tensorflow.keras import backend as K | ||
from tensorflow.keras import layers | ||
|
||
from ..layers import BlockImages, SwapAxes, UnblockImages | ||
from ..layers import SwapAxes, TFBlockImagesByGrid, TFUnblockImages | ||
|
||
|
||
def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"): | ||
|
@@ -18,9 +18,7 @@ def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"): | |
|
||
def apply(x): | ||
u, v = tf.split(x, 2, axis=-1) | ||
v = layers.LayerNormalization( | ||
epsilon=1e-06, name=f"{name}_intermediate_layernorm" | ||
)(v) | ||
v = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_intermediate_layernorm")(v) | ||
n = K.int_shape(x)[-3] # get spatial dim | ||
v = SwapAxes()(v, -1, -3) | ||
v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v) | ||
|
@@ -47,9 +45,8 @@ def apply(x): | |
K.int_shape(x)[3], | ||
) | ||
gh, gw = grid_size | ||
fh, fw = h // gh, w // gw | ||
|
||
x = BlockImages()(x, patch_size=(fh, fw)) | ||
x, ph, pw = TFBlockImagesByGrid()(x, grid_size=(gh, gw)) | ||
Comment on lines
-52
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How come these operations are the same? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From the original implementation, the authors implement BlockByGrid by computing the block size of a grid cell, and using BlockImages (which block images into patches of block-size). From the paper, the authors explain the difference between "grid" and "block" like that: They are equivalent because it does the split based on the grid_size as argument instead of the block_size (called as (fh, fw) in the code) as the authors did. A more formal test is performed here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for explaining.
How is the block size of [3, 2] interpreted in that case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the original code, it is done as explained, note here: gh, gw = grid_size
fh, fw = h // gh, w // gw
u = BlockImages()(u, patch_size=(fh, fw)) Note that this code is very similar to the pseudo-code written here. We can use the block [3,2] to compute the green part of the image (grid blocking with grid_size=[3,2]) this way: In the example shown in the image, we have that image size is [6,4]. Thus to split it with a grid_size of [2,2], we can do: gh, gw = (2, 2)
h, w = (6,4) # image dimensions
fh, fw = h // gh, w // gw # Note that fh = 3, and fw = 2
block_image = BlockImages()(image_from_the_picture, patch_size=(fh,fw)) # patch_size=(3,2) The above code snippet implements the green part of the image, and is very similar to what we described first. In case with the gh, gw = (2,2)
block_image_using_tfblockByGrid = TFBlockByGrid()(image_from_the_picture, grid_size=(gh,gw)) and I am not sure if this answers what you asked, though. Let me know. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! So, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Exactly! |
||
# gMLP1: Global (grid) mixing part, provides global grid communication. | ||
y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x) | ||
y = layers.Dense( | ||
|
@@ -66,7 +63,7 @@ def apply(x): | |
)(y) | ||
y = layers.Dropout(dropout_rate)(y) | ||
x = x + y | ||
x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw)) | ||
x = TFUnblockImages()(x, grid_size=(gh, gw), patch_size=(ph, pw)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same. You're changing the semantics of the code. Could you please elaborate why? Reading this change and also previous |
||
return x | ||
|
||
return apply |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,18 +8,14 @@ | |
from tensorflow.keras import backend as K | ||
from tensorflow.keras import layers | ||
|
||
from ..layers import BlockImages, SwapAxes, UnblockImages | ||
from ..layers import SwapAxes, TFBlockImages, TFBlockImagesByGrid, TFUnblockImages | ||
from .block_gating import BlockGmlpLayer | ||
from .grid_gating import GridGmlpLayer | ||
|
||
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same") | ||
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same") | ||
ConvT_up = functools.partial( | ||
layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same" | ||
) | ||
Conv_down = functools.partial( | ||
layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same" | ||
) | ||
ConvT_up = functools.partial(layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same") | ||
Conv_down = functools.partial(layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same") | ||
|
||
|
||
def ResidualSplitHeadMultiAxisGmlpLayer( | ||
|
@@ -116,24 +112,22 @@ def apply(x): | |
u, v = tf.split(x, 2, axis=-1) | ||
|
||
# Get grid MLP weights | ||
gh, gw = grid_size | ||
fh, fw = h // gh, w // gw | ||
u = BlockImages()(u, patch_size=(fh, fw)) | ||
dim_u = K.int_shape(u)[-3] | ||
ghu, gwu = grid_size | ||
u, phu, pwu = TFBlockImagesByGrid()(u, grid_size=(ghu, gwu)) | ||
dim_u = ghu * gwu | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Explain the rationale in the comment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, advisable not to change the original variable names here and elsewhere. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Variable names will be recovered on next push. Did it only for readability (since they get rewritten a few lines below) If I understood correctly, you are asking why we can substitute K.int_shape(u)[-3] for (gh * gw): From BlockImages(), we have that the output's shape is "b (gh gw) (fh fw) c". Thus, since: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same reason why fh and fw are getting replaced by gh and gw here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Essentially, those transformations are the same: def same_operations(random_image, grid_size=(gh,gw)):
b, h, w, c = random_image.shape
image_blocked_by_grid = BlockByGrid(random_image, grid_size=(gh, gw))
image_blocked_by_block = BlockByPatch(random_image, patch_size=(h // gh, w // gw))
image_blocked_by_grid == image_blocked_by_block # this should be True. we have this pseudo-code as a test here. Note that |
||
u = SwapAxes()(u, -1, -3) | ||
u = layers.Dense(dim_u, use_bias=use_bias, name=f"{name}_Dense_0")(u) | ||
u = SwapAxes()(u, -1, -3) | ||
u = UnblockImages()(u, grid_size=(gh, gw), patch_size=(fh, fw)) | ||
u = TFUnblockImages()(u, grid_size=(ghu, gwu), patch_size=(phu, pwu)) | ||
|
||
# Get Block MLP weights | ||
fh, fw = block_size | ||
gh, gw = h // fh, w // fw | ||
v = BlockImages()(v, patch_size=(fh, fw)) | ||
dim_v = K.int_shape(v)[-2] | ||
fhv, fwv = block_size | ||
v, ghv, gwv = TFBlockImages()(v, patch_size=(fhv, fwv)) | ||
dim_v = fhv * fwv | ||
v = SwapAxes()(v, -1, -2) | ||
v = layers.Dense(dim_v, use_bias=use_bias, name=f"{name}_Dense_1")(v) | ||
v = SwapAxes()(v, -1, -2) | ||
v = UnblockImages()(v, grid_size=(gh, gw), patch_size=(fh, fw)) | ||
v = TFUnblockImages()(v, patch_size=(fhv, fwv), grid_size=(ghv, gwv)) | ||
|
||
x = tf.concat([u, v], axis=-1) | ||
x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(x) | ||
|
@@ -159,9 +153,7 @@ def CrossGatingBlock( | |
def apply(x, y): | ||
# Upscale Y signal, y is the gating signal. | ||
if upsample_y: | ||
y = ConvT_up( | ||
filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0" | ||
)(y) | ||
y = ConvT_up(filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0")(y) | ||
|
||
x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_Conv_0")(x) | ||
n, h, w, num_channels = ( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,65 +1,89 @@ | ||
""" | ||
Layers based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py | ||
and reworked to cope with variable image dimensions | ||
""" | ||
|
||
import einops | ||
import tensorflow as tf | ||
from tensorflow.experimental import numpy as tnp | ||
from tensorflow.keras import backend as K | ||
from tensorflow.keras import layers | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable("maxim") | ||
class BlockImages(layers.Layer): | ||
class TFBlockImages(layers.Layer): | ||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
def call(self, x, patch_size): | ||
bs, h, w, num_channels = ( | ||
K.int_shape(x)[0], | ||
K.int_shape(x)[1], | ||
K.int_shape(x)[2], | ||
K.int_shape(x)[3], | ||
) | ||
def call(self, image, patch_size): | ||
bs, h, w, num_channels = (tf.shape(image)[0], tf.shape(image)[1], tf.shape(image)[2], tf.shape(image)[3]) | ||
ph, pw = patch_size | ||
gh = h // ph | ||
gw = w // pw | ||
pad = [[0, 0], [0, 0]] | ||
patches = tf.space_to_batch_nd(image, [ph, pw], pad) | ||
patches = tf.split(patches, ph * pw, axis=0) | ||
patches = tf.stack(patches, 3) # (bs, h/p, h/p, p*p, 3) | ||
patches_dim = tf.shape(patches) | ||
patches = tf.reshape(patches, [patches_dim[0], patches_dim[1], patches_dim[2], -1]) | ||
patches = tf.reshape(patches, (patches_dim[0], patches_dim[1] * patches_dim[2], ph * pw, num_channels)) | ||
return [patches, gh, gw] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm honestly not sure why we are getting rid of einops. This is significantly more lines of code and also more complex to read. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, using einops would be in hand. But please, consider this code snippet:
The problem with einops is that it expects int literals as argument to the symbols used in the pattern string. I could not make it work using tensors as shown by the example above. At some stages of the model (here, here, here), the split is computed in online fashion, thus relying on tensors (for the case where the img size is None). Thus it was necessary to rewrite using tensorflow. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But it wasn't a problem with the current version of the code. What changed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current version of the code informs the image dimension beforehand, thus when you do:
you have the integer literals we need for the einops operations. However, In case when we feed (None, None, 3) as input, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the line number you're using for Black formatting? The line-numbers seem long and should be formatted accordingly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. where I work we use 122, I am reformatting with 88 (black's default, IIRC). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 80 is the default. You can bump it to 90 (which is what I used). |
||
|
||
grid_height, grid_width = h // patch_size[0], w // patch_size[1] | ||
def get_config(self): | ||
return super().get_config() | ||
|
||
x = einops.rearrange( | ||
x, | ||
"n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c", | ||
gh=grid_height, | ||
gw=grid_width, | ||
fh=patch_size[0], | ||
fw=patch_size[1], | ||
) | ||
|
||
return x | ||
@tf.keras.utils.register_keras_serializable("maxim") | ||
class TFBlockImagesByGrid(layers.Layer): | ||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
def call(self, image, grid_size): | ||
bs, h, w, num_channels = (tf.shape(image)[0], tf.shape(image)[1], tf.shape(image)[2], tf.shape(image)[3]) | ||
gh, gw = grid_size | ||
ph = h // gh | ||
pw = w // gw | ||
pad = [[0, 0], [0, 0]] | ||
|
||
def block_single_image(img): | ||
pat = tf.expand_dims(img, 0) # batch = 1 | ||
pat = tf.space_to_batch_nd(pat, [ph, pw], pad) # p*p*bs, g, g, c | ||
pat = tf.expand_dims(pat, 3) # pxpxbs, g, g, 1, c | ||
pat = tf.transpose(pat, perm=[3, 1, 2, 0, 4]) # 1, g, g, pxp, c | ||
pat = tf.reshape(pat, [gh, gw, ph * pw, num_channels]) | ||
return pat | ||
|
||
patches = image | ||
patches = tf.map_fn(fn=lambda x: block_single_image(x), elems=patches) | ||
patches_dim = tf.shape(patches) | ||
patches = tf.reshape(patches, [patches_dim[0], patches_dim[1], patches_dim[2], -1]) | ||
patches = tf.reshape(patches, (patches_dim[0], patches_dim[1] * patches_dim[2], ph * pw, num_channels)) | ||
return [patches, ph, pw] | ||
|
||
def get_config(self): | ||
config = super().get_config().copy() | ||
return config | ||
return super().get_config() | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable("maxim") | ||
class UnblockImages(layers.Layer): | ||
class TFUnblockImages(layers.Layer): | ||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
def call(self, x, grid_size, patch_size): | ||
x = einops.rearrange( | ||
x, | ||
"n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c", | ||
gh=grid_size[0], | ||
gw=grid_size[1], | ||
fh=patch_size[0], | ||
fw=patch_size[1], | ||
) | ||
def call(self, x, patch_size, grid_size): | ||
bs, grid_sqrt, patch_sqrt, num_channels = (tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]) | ||
ph, pw = patch_size | ||
gh, gw = grid_size | ||
|
||
return x | ||
pad = [[0, 0], [0, 0]] | ||
|
||
y = tf.reshape(x, (bs, gh, gw, -1, num_channels)) # (bs, gh, gw, ph*pw, 3) | ||
y = tf.expand_dims(y, 0) | ||
y = tf.transpose(y, perm=[4, 1, 2, 3, 0, 5]) | ||
y = tf.reshape(y, [bs * ph * pw, gh, gw, num_channels]) | ||
y = tf.batch_to_space(y, [ph, pw], pad) | ||
|
||
return y | ||
|
||
def get_config(self): | ||
config = super().get_config().copy() | ||
return config | ||
return super().get_config() | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable("maxim") | ||
|
@@ -76,28 +100,60 @@ def get_config(self): | |
|
||
|
||
@tf.keras.utils.register_keras_serializable("maxim") | ||
class Resizing(layers.Layer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the need to segregate this to Up and Down? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found it easier to read, but indeed it adds a chunk of code. Reformatting to use a single layer only. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's easier to read, I would consider adding an elaborate comment in the script so that readers are aware. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think now is better (with a single resizing layer). |
||
def __init__(self, height, width, antialias=True, method="bilinear", **kwargs): | ||
class ResizingDown(tf.keras.layers.Layer): | ||
def __init__(self, ratio: float, method="bilinear", antialias=True, **kwargs): | ||
super().__init__(**kwargs) | ||
self.height = height | ||
self.width = width | ||
self.antialias = antialias | ||
self.ratio = ratio | ||
self.method = method | ||
self.antialias = antialias | ||
|
||
def call(self, x): | ||
return tf.image.resize( | ||
x, | ||
size=(self.height, self.width), | ||
def __call__(self, img): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prefer There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, fixing... |
||
n, h, w, c = (tf.shape(img)[0], tf.shape(img)[1], tf.shape(img)[2], tf.shape(img)[3]) | ||
x = tf.image.resize( | ||
img, | ||
(h // self.ratio, w // self.ratio), | ||
method=self.method, | ||
antialias=self.antialias, | ||
) | ||
return x | ||
|
||
def get_config(self): | ||
config = super().get_config().copy() | ||
config.update( | ||
{ | ||
"ratio": self.ratio, | ||
"antialias": self.antialias, | ||
"method": self.method, | ||
} | ||
) | ||
return config | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable("maxim") | ||
class ResizingUp(tf.keras.layers.Layer): | ||
def __init__(self, ratio: float, method="bilinear", antialias=True, **kwargs): | ||
super().__init__(**kwargs) | ||
self.ratio = tf.constant(ratio, dtype=tf.float32) | ||
self.method = method | ||
self.antialias = antialias | ||
|
||
def __call__(self, img): | ||
shape = tf.shape(img) | ||
new_sh = self.ratio * tf.cast(shape[1:3], tf.float32) | ||
|
||
x = tf.image.resize( | ||
img, | ||
size=tf.cast(new_sh, tf.int32), | ||
method=self.method, | ||
antialias=self.antialias, | ||
) | ||
return x | ||
|
||
def get_config(self): | ||
config = super().get_config().copy() | ||
config.update( | ||
{ | ||
"height": self.height, | ||
"width": self.width, | ||
"ratio": self.ratio, | ||
"antialias": self.antialias, | ||
"method": self.method, | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why there's a separate layer for handling blocking by grids?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BlockByGrid can be implemented as follows, please see a more detailed explanation here:
But, while implementing TFBlockImages I used tf.split which expects an int literal as argument for num_or_size_splits.
However, in cases where we only have the grid_size and the block_size has to be computed on the fly (as here), it needs to be a tensor, and we can't use tf.split in this case. That's why I also wrote BlockByGrid.