diff --git a/deel/lip/layers/pooling.py b/deel/lip/layers/pooling.py index 92cb883d..d1fc209a 100644 --- a/deel/lip/layers/pooling.py +++ b/deel/lip/layers/pooling.py @@ -475,13 +475,10 @@ def call(self, inputs, **kwargs): # convert to channels_first inputs = tf.transpose(inputs, [0, 2, 3, 1]) # from shape (bs, w, h, c*pw*ph) to (bs, w, h, pw, ph, c) - bs, w, h = inputs.shape[:-1] - ( - pw, - ph, - ) = self.pool_size - c = inputs.shape[-1] // (pw * ph) - print(c) + input_shape = tf.shape(inputs) + w, h, c_in = input_shape[1], input_shape[2], input_shape[3] + pw, ph = self.pool_size + c = c_in // (pw * ph) inputs = tf.reshape(inputs, (-1, w, h, pw, ph, c)) inputs = tf.transpose( tf.reshape(