Skip to content

Commit

Permalink
Remove input_shape property and set default to None. (facebookresearc…
Browse files Browse the repository at this point in the history
…h#90)

Summary:
Pull Request resolved: fairinternal/ClassyVision#90

Pull Request resolved: facebookresearch#687

This does not update video models or models with custom logic.

ci/circleci: cpu_tests are failed due to formatting differences between arc lint and circleci lint. We prefer the arc lint formatting.

Reviewed By: mannatsingh

Differential Revision: D25957137

fbshipit-source-id: 6defbac19e3baddb609ed33b9fd161959eaa7dfe
  • Loading branch information
Vaibhav Aggarwal authored and facebook-github-bot committed Jan 23, 2021
1 parent 56459a5 commit 8592b83
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 24 deletions.
7 changes: 5 additions & 2 deletions classy_vision/models/classy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,11 @@ def execute_heads(self) -> Dict[str, torch.Tensor]:

@property
def input_shape(self):
"""If implemented, returns expected input tensor shape"""
raise NotImplementedError
"""Returns the input shape that the model can accept, excluding the batch dimension.
By default it returns (3, 224, 224).
"""
return (3, 224, 224)


class _ClassyModelAdapter(ClassyModel):
Expand Down
7 changes: 0 additions & 7 deletions classy_vision/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,3 @@ def forward(self, x):
out = self.features(out)

return out

@property
def input_shape(self):
if self.small_input:
return (3, 32, 32)
else:
return (3, 224, 224)
4 changes: 0 additions & 4 deletions classy_vision/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,3 @@ def forward(self, x):
out = x.view(batchsize_per_replica, -1)
out = self.mlp(out)
return out

@property
def input_shape(self):
return (self._num_inputs,)
4 changes: 0 additions & 4 deletions classy_vision/models/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,6 @@ def init_weights(self):
m.weight.data.normal_(mean=0.0, std=0.01)
m.bias.data.zero_()

@property
def input_shape(self):
return (3, 224, 224)


# Register some "classic" RegNets
class _RegNet(RegNet):
Expand Down
7 changes: 0 additions & 7 deletions classy_vision/models/resnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,6 @@ def forward(self, x):

return out

@property
def input_shape(self):
if self.small_input:
return (3, 32, 32)
else:
return (3, 224, 224)

def _convert_model_state(self, state):
"""Convert model state from the old implementation to the current format.
Expand Down

0 comments on commit 8592b83

Please sign in to comment.