-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for arbitrary image resolutions #24
base: main
Are you sure you want to change the base?
Changes from all commits
30f9393
d55b45a
af09c34
cd13378
ccdbe31
805c706
ff992fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
from tensorflow.keras import backend as K | ||
from tensorflow.keras import layers | ||
|
||
from ..layers import BlockImages, SwapAxes, UnblockImages | ||
from ..layers import SwapAxes, TFBlockImagesByGrid, TFUnblockImages | ||
|
||
|
||
def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"): | ||
|
@@ -47,9 +47,8 @@ def apply(x): | |
K.int_shape(x)[3], | ||
) | ||
gh, gw = grid_size | ||
fh, fw = h // gh, w // gw | ||
|
||
x = BlockImages()(x, patch_size=(fh, fw)) | ||
x, ph, pw = TFBlockImagesByGrid()(x, grid_size=(gh, gw)) | ||
Comment on lines
-52
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How come these operations are the same? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From the original implementation, the authors implement BlockByGrid by computing the block size of a grid cell, and using BlockImages (which block images into patches of block-size). From the paper, the authors explain the difference between "grid" and "block" like that: They are equivalent because it does the split based on the grid_size as argument instead of the block_size (called as (fh, fw) in the code) as the authors did. A more formal test is performed here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for explaining.
How is the block size of [3, 2] interpreted in that case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the original code, it is done as explained, note here: gh, gw = grid_size
fh, fw = h // gh, w // gw
u = BlockImages()(u, patch_size=(fh, fw)) Note that this code is very similar to the pseudo-code written here. We can use the block [3,2] to compute the green part of the image (grid blocking with grid_size=[3,2]) this way: In the example shown in the image, we have that image size is [6,4]. Thus to split it with a grid_size of [2,2], we can do: gh, gw = (2, 2)
h, w = (6,4) # image dimensions
fh, fw = h // gh, w // gw # Note that fh = 3, and fw = 2
block_image = BlockImages()(image_from_the_piture, patch_size=(fh,fw)) # patch_size=(3,2) The above code snippet implements the green part of the image, and is very similar to what we described first. In case with the gh, gw = (2,2)
block_image_using_tfblockByGrid = TFBlockByGrid()(image_from_the_picture, grid_size=(gh,gw)) and I am not sure if this answer what you asked, though. Let me know. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! So, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Exactly! |
||
# gMLP1: Global (grid) mixing part, provides global grid communication. | ||
y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x) | ||
y = layers.Dense( | ||
|
@@ -66,7 +65,7 @@ def apply(x): | |
)(y) | ||
y = layers.Dropout(dropout_rate)(y) | ||
x = x + y | ||
x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw)) | ||
x = TFUnblockImages()(x, grid_size=(gh, gw), patch_size=(ph, pw)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same. You're changing the semanticity of the code. Could you please elaborate why? Reading this change and also previous |
||
return x | ||
|
||
return apply |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,65 +1,114 @@ | ||
""" | ||
Layers based on https://github.com/google-research/maxim/blob/main/maxim/models/maxim.py | ||
and reworked to cope with variable image dimensions | ||
""" | ||
|
||
import einops | ||
import tensorflow as tf | ||
from tensorflow.experimental import numpy as tnp | ||
from tensorflow.keras import backend as K | ||
from tensorflow.keras import layers | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable("maxim") | ||
class BlockImages(layers.Layer): | ||
class TFBlockImages(layers.Layer): | ||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
def call(self, x, patch_size): | ||
def call(self, image, patch_size): | ||
bs, h, w, num_channels = ( | ||
K.int_shape(x)[0], | ||
K.int_shape(x)[1], | ||
K.int_shape(x)[2], | ||
K.int_shape(x)[3], | ||
tf.shape(image)[0], | ||
tf.shape(image)[1], | ||
tf.shape(image)[2], | ||
tf.shape(image)[3], | ||
) | ||
ph, pw = patch_size | ||
gh = h // ph | ||
gw = w // pw | ||
pad = [[0, 0], [0, 0]] | ||
patches = tf.space_to_batch_nd(image, [ph, pw], pad) | ||
patches = tf.split(patches, ph * pw, axis=0) | ||
patches = tf.stack(patches, 3) # (bs, h/p, h/p, p*p, 3) | ||
patches_dim = tf.shape(patches) | ||
patches = tf.reshape( | ||
patches, [patches_dim[0], patches_dim[1], patches_dim[2], -1] | ||
) | ||
patches = tf.reshape( | ||
patches, | ||
(patches_dim[0], patches_dim[1] * patches_dim[2], ph * pw, num_channels), | ||
) | ||
return [patches, gh, gw] | ||
|
||
grid_height, grid_width = h // patch_size[0], w // patch_size[1] | ||
def get_config(self): | ||
return super().get_config() | ||
|
||
x = einops.rearrange( | ||
x, | ||
"n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c", | ||
gh=grid_height, | ||
gw=grid_width, | ||
fh=patch_size[0], | ||
fw=patch_size[1], | ||
) | ||
|
||
return x | ||
@tf.keras.utils.register_keras_serializable("maxim") | ||
class TFBlockImagesByGrid(layers.Layer): | ||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
def call(self, image, grid_size): | ||
bs, h, w, num_channels = ( | ||
tf.shape(image)[0], | ||
tf.shape(image)[1], | ||
tf.shape(image)[2], | ||
tf.shape(image)[3], | ||
) | ||
gh, gw = grid_size | ||
ph = h // gh | ||
pw = w // gw | ||
pad = [[0, 0], [0, 0]] | ||
|
||
def block_single_image(img): | ||
pat = tf.expand_dims(img, 0) # batch = 1 | ||
pat = tf.space_to_batch_nd(pat, [ph, pw], pad) # p*p*bs, g, g, c | ||
pat = tf.expand_dims(pat, 3) # pxpxbs, g, g, 1, c | ||
pat = tf.transpose(pat, perm=[3, 1, 2, 0, 4]) # 1, g, g, pxp, c | ||
pat = tf.reshape(pat, [gh, gw, ph * pw, num_channels]) | ||
return pat | ||
|
||
patches = image | ||
patches = tf.map_fn(fn=lambda x: block_single_image(x), elems=patches) | ||
patches_dim = tf.shape(patches) | ||
patches = tf.reshape( | ||
patches, [patches_dim[0], patches_dim[1], patches_dim[2], -1] | ||
) | ||
patches = tf.reshape( | ||
patches, | ||
(patches_dim[0], patches_dim[1] * patches_dim[2], ph * pw, num_channels), | ||
) | ||
return [patches, ph, pw] | ||
|
||
def get_config(self): | ||
config = super().get_config().copy() | ||
return config | ||
return super().get_config() | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable("maxim") | ||
class UnblockImages(layers.Layer): | ||
class TFUnblockImages(layers.Layer): | ||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
def call(self, x, grid_size, patch_size): | ||
x = einops.rearrange( | ||
x, | ||
"n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c", | ||
gh=grid_size[0], | ||
gw=grid_size[1], | ||
fh=patch_size[0], | ||
fw=patch_size[1], | ||
def call(self, x, patch_size, grid_size): | ||
bs, grid_sqrt, patch_sqrt, num_channels = ( | ||
tf.shape(x)[0], | ||
tf.shape(x)[1], | ||
tf.shape(x)[2], | ||
tf.shape(x)[3], | ||
) | ||
ph, pw = patch_size | ||
gh, gw = grid_size | ||
|
||
return x | ||
pad = [[0, 0], [0, 0]] | ||
|
||
y = tf.reshape(x, (bs, gh, gw, -1, num_channels)) # (bs, gh, gw, ph*pw, 3) | ||
y = tf.expand_dims(y, 0) | ||
y = tf.transpose(y, perm=[4, 1, 2, 3, 0, 5]) | ||
y = tf.reshape(y, [bs * ph * pw, gh, gw, num_channels]) | ||
y = tf.batch_to_space(y, [ph, pw], pad) | ||
|
||
return y | ||
|
||
def get_config(self): | ||
config = super().get_config().copy() | ||
return config | ||
return super().get_config() | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable("maxim") | ||
|
@@ -76,28 +125,31 @@ def get_config(self): | |
|
||
|
||
@tf.keras.utils.register_keras_serializable("maxim") | ||
class Resizing(layers.Layer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the need to segregate this to Up and Down? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found it easier to read, but indeed it adds a chunk of code. Reformatting to use a single layer only. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's easier to read, I would consider adding an elaborate comment in the script so that readers are aware. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think now is better (with a single resizing layer). |
||
def __init__(self, height, width, antialias=True, method="bilinear", **kwargs): | ||
class Resizing(tf.keras.layers.Layer): | ||
def __init__(self, ratio: float, method="bilinear", antialias=True, **kwargs): | ||
super().__init__(**kwargs) | ||
self.height = height | ||
self.width = width | ||
self.antialias = antialias | ||
self.ratio = ratio | ||
self.method = method | ||
self.antialias = antialias | ||
|
||
def call(self, x): | ||
return tf.image.resize( | ||
x, | ||
size=(self.height, self.width), | ||
antialias=self.antialias, | ||
def call(self, img): | ||
shape = tf.shape(img) | ||
|
||
new_sh = tf.cast(shape[1:3], tf.float32) // self.ratio | ||
|
||
x = tf.image.resize( | ||
img, | ||
size=tf.cast(new_sh, tf.int32), | ||
method=self.method, | ||
antialias=self.antialias, | ||
) | ||
return x | ||
|
||
def get_config(self): | ||
config = super().get_config().copy() | ||
config.update( | ||
{ | ||
"height": self.height, | ||
"width": self.width, | ||
"ratio": self.ratio, | ||
"antialias": self.antialias, | ||
"method": self.method, | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why there's a separate layer for handling blocking by grids?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BlockByGrid can be implemented as follows, please see a more detailed explanation here:
But, while implementing TFBlockImages I used tf.split which expects an int literal as argument for num_or_size_splits.
However, in cases where we only have the grid_size and the block_size has to be computed on the fly (as here), it needs to be a tensor, and we can't use tf.split ins this case. That's why I also wrote BlockByGrid.