Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 8592b83

Browse files
Vaibhav Aggarwalfacebook-github-bot
Vaibhav Aggarwal
authored andcommitted
Remove input_shape property and set default to None. (#90)
Summary: Pull Request resolved: fairinternal/ClassyVision#90 Pull Request resolved: #687 This does not update video models or models with custom logic. ci/circleci: cpu_tests are failed due to formatting differences between arc lint and circleci lint. We prefer the arc lint formatting. Reviewed By: mannatsingh Differential Revision: D25957137 fbshipit-source-id: 6defbac19e3baddb609ed33b9fd161959eaa7dfe
1 parent 56459a5 commit 8592b83

File tree

5 files changed

+5
-24
lines changed

5 files changed

+5
-24
lines changed

classy_vision/models/classy_model.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,11 @@ def execute_heads(self) -> Dict[str, torch.Tensor]:
407407

408408
@property
409409
def input_shape(self):
410-
"""If implemented, returns expected input tensor shape"""
411-
raise NotImplementedError
410+
"""Returns the input shape that the model can accept, excluding the batch dimension.
411+
412+
By default it returns (3, 224, 224).
413+
"""
414+
return (3, 224, 224)
412415

413416

414417
class _ClassyModelAdapter(ClassyModel):

classy_vision/models/densenet.py

-7
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,3 @@ def forward(self, x):
274274
out = self.features(out)
275275

276276
return out
277-
278-
@property
279-
def input_shape(self):
280-
if self.small_input:
281-
return (3, 32, 32)
282-
else:
283-
return (3, 224, 224)

classy_vision/models/mlp.py

-4
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,3 @@ def forward(self, x):
8585
out = x.view(batchsize_per_replica, -1)
8686
out = self.mlp(out)
8787
return out
88-
89-
@property
90-
def input_shape(self):
91-
return (self._num_inputs,)

classy_vision/models/regnet.py

-4
Original file line numberDiff line numberDiff line change
@@ -537,10 +537,6 @@ def init_weights(self):
537537
m.weight.data.normal_(mean=0.0, std=0.01)
538538
m.bias.data.zero_()
539539

540-
@property
541-
def input_shape(self):
542-
return (3, 224, 224)
543-
544540

545541
# Register some "classic" RegNets
546542
class _RegNet(RegNet):

classy_vision/models/resnext.py

-7
Original file line numberDiff line numberDiff line change
@@ -434,13 +434,6 @@ def forward(self, x):
434434

435435
return out
436436

437-
@property
438-
def input_shape(self):
439-
if self.small_input:
440-
return (3, 32, 32)
441-
else:
442-
return (3, 224, 224)
443-
444437
def _convert_model_state(self, state):
445438
"""Convert model state from the old implementation to the current format.
446439

0 commit comments

Comments
 (0)