flatten operation (resnet50) (pytorch#61265)
Summary: Pull Request resolved: pytorch#61265

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D29626383

Pulled By: migeed-z

fbshipit-source-id: 107769fc14f1fad295a93a10e84235f25ae17357
migeed-z authored and facebook-github-bot committed Jul 16, 2021
1 parent 4479aa8 commit 4e2fe97
Showing 3 changed files with 183 additions and 14 deletions.
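
The heart of the change is a new inference rule for torch.flatten in the FX gradual typechecker, which is what lets a traced resnet50 (whose forward ends with a flatten before the fully connected layer) be type-checked. Below is a minimal standalone sketch of the dimension arithmetic the rule performs; None stands in for Dyn, and this is only the shape rule, not the checker's API:

from functools import reduce

Dyn = None  # stand-in for the typechecker's "dynamic" dimension

def flatten_shape(dims, start_dim=1, end_dim=-1):
    # Collapse dims[start_dim .. end_dim] into one dimension, as flatten_check
    # in this commit does; any Dyn inside the collapsed range makes the
    # collapsed dimension Dyn.
    n = len(dims)
    end = n + end_dim + 1 if end_dim < 0 else end_dim + 1
    lhs, mid, rhs = dims[:start_dim], dims[start_dim:end], dims[end:]
    mid = [Dyn] if Dyn in mid else [reduce(lambda x, y: x * y, mid)]
    return tuple(lhs + mid + rhs)

print(flatten_shape([1, 2, 3, 5, Dyn], 1, 2))  # (1, 6, 5, None), cf. test_type_check_flatten
print(flatten_shape([4, 16, 5, 5], 1, -1))     # (4, 400), the resnet-style flatten
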
123 changes: 121 additions & 2 deletions test/fx/test_gradual_type.py
@@ -6,6 +6,20 @@
from torch.fx.experimental.graph_gradual_typechecker import GraphTypeChecker, broadcast_types
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx import GraphModule
from torch.fx.passes.shape_prop import ShapeProp

try:
from torchvision.models import resnet50

HAS_TORCHVISION = True
except ImportError:
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
skipIfNoMkldnn = unittest.skipIf(
not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()),
"no MKLDNN",
)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
@@ -551,7 +565,7 @@ def forward(self, x: TensorType((2, 2, 4, 5))):
assert isinstance(n.type, TensorType)
assert torch.Size(n.type.__args__) == B.forward(torch.rand(2, 2, 4, 5)).size()

def test_type_check_conv2D_and_maxpool2d(self):
def test_type_check_conv2D_maxpool2d_flatten(self):

class BasicBlock(torch.nn.Module):
def __init__(self):
@@ -570,6 +584,7 @@ def forward(self, x : TensorType((4, 3, 32, 32))):
out = self.pool(out)
out = self.fc1(out)
out = self.pool2(out)
out = torch.flatten(out, 1)
return out

B = BasicBlock()
@@ -582,13 +597,40 @@ def forward(self, x : TensorType((4, 3, 32, 32))):
expected_ph_types = [TensorType((4, 3, 32, 32)), TensorType((4, 6, 28, 28)),
TensorType((4, 6, 14, 14)), TensorType((4, 16, 10, 10)),
TensorType((4, 16, 5, 5)), TensorType((4, 16, 5, 120)),
TensorType((4, 16, 6, 7)), TensorType((4, 16, 6, 7))]
TensorType((4, 16, 6, 7)), TensorType((4, 672)), TensorType((4, 672))]

expected_iter = iter(expected_ph_types)

for n in traced.graph.nodes:
assert n.type == next(expected_iter)

def test_type_check_flatten(self):
class M(torch.nn.Module):
def forward(self, x: TensorType((1, 2, 3, 5, Dyn))):
return torch.flatten(x, 1, 2)

module = M()
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
tc = GraphTypeChecker({}, symbolic_traced)
tc.type_check()
for n in symbolic_traced.graph.nodes:
if n.op == 'output':
assert n.type == TensorType((1, 6, 5, Dyn))


def test_type_check_flatten_2(self):
class M(torch.nn.Module):
def forward(self, x: TensorType((1, Dyn, 3, 5, Dyn))):
return torch.flatten(x, 1, 2)

module = M()
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
tc = GraphTypeChecker({}, symbolic_traced)
tc.type_check()
for n in symbolic_traced.graph.nodes:
if n.op == 'output':
assert n.type == TensorType((1, Dyn, 5, Dyn))



def test_type_typechecl_maxpool2d_3dinput(self):
@@ -700,6 +742,83 @@ def forward(self, x):
assert is_consistent(n.type, TensorType(b.size()))


def test_flatten_fully_static(self):
annotation_list = [Dyn, TensorType((2, 5, 6, 9)), TensorType((10, 15, 13, 14)),
TensorType((10, Dyn, 13, 14)), TensorType((Dyn, Dyn, Dyn, 10))]
input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
(10, 15, 13, 14), (2, 2, 10, 10)]

intermediate_list = [Dyn, (2, 5, 6, 9), (10, 15, 13, 14),
(10, 15, 13, 14), (2, 2, 10, 10)]

start_dim = [1, 2, 1, 2, 0]
end_dim = [1, 3, 3, 3, -2]

for i in range(5):
annotation = annotation_list[i]
input = input_list[i]
# intermediate_type = intermediate_list[i]

class BasicBlock(torch.nn.Module):
def __init__(self, start, end):
super(BasicBlock, self).__init__()
self.start = start
self.end = end

def forward(self, x):
out = torch.flatten(x, self.start, self.end)
return out

B = BasicBlock(start_dim[i], end_dim[i])
ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(B)
traced = GraphModule(ast_rewriter.root, graph, "gm")

# annotate our argument
for n in graph.nodes:
if n.op == 'placeholder':
n.type = annotation

b = B.forward(torch.rand(input))
tc = GraphTypeChecker({}, traced)
tc.type_check()

for n in graph.nodes:
if n.op == 'output':
assert is_consistent(n.type, TensorType(b.size()))

@skipIfNoTorchVision
def test_resnet50(self):
gm_run = symbolic_trace(resnet50())
sample_input = torch.randn(1, 3, 224, 224)

# run our nodes
ShapeProp(gm_run).propagate(sample_input)

gm_static = symbolic_trace(resnet50())

for n in gm_static.graph.nodes:
n.type = None

g = GraphTypeChecker({}, gm_static)
g.type_check()
# here we are checking for consistency with fully dynamic nodes
for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes):
assert is_consistent(n1.type, TensorType(n2.meta['tensor_meta'].shape))

# here we give the same input as in the runtime run
gm_static_with_types = symbolic_trace(resnet50())

# we initialize our placeholder
for n in gm_static_with_types.graph.nodes:
if n.op == 'placeholder':
n.type = TensorType((1, 3, 224, 224))

g = GraphTypeChecker({}, gm_static_with_types)
g.type_check()
for n1, n2 in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes):
assert n1.type == TensorType(n2.meta['tensor_meta'].shape)


if __name__ == '__main__':
unittest.main()
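
For reference, a minimal driver in the style of test_type_check_flatten above; this sketch assumes TensorType and Dyn are importable from torch.fx.tensor_type (that import sits outside the hunks shown here) and uses a hypothetical FlattenModule:

import torch
from torch.fx import symbolic_trace
from torch.fx.tensor_type import TensorType, Dyn  # assumed import path
from torch.fx.experimental.graph_gradual_typechecker import GraphTypeChecker

class FlattenModule(torch.nn.Module):  # hypothetical module, mirrors test_type_check_flatten
    def forward(self, x: TensorType((1, 2, 3, 5, Dyn))):
        return torch.flatten(x, 1, 2)

traced = symbolic_trace(FlattenModule())
GraphTypeChecker({}, traced).type_check()
for n in traced.graph.nodes:
    if n.op == 'output':
        print(n.type)  # the test expects TensorType((1, 6, 5, Dyn))
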
71 changes: 60 additions & 11 deletions torch/fx/experimental/graph_gradual_typechecker.py
@@ -7,7 +7,6 @@
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d


_INFERENCE_RULES: Dict[Target, Callable] = {}


@@ -23,7 +22,7 @@ def expand_to_tensor_dim(t, n):
return TensorType(tuple(dims))
elif isinstance(t, TensorType):
if len(t.__args__) != n:
raise TypeError(f'Cannot extend tensor dimension. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
return t
else:
raise TypeError(f'Cannot match the type {t}')
@@ -208,11 +207,10 @@ def bn2d_inference_rule(n: Node, module_instance):
raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')


def calculate(d_in, module_instance, index):
def calculate_out_dimension(d_in, module_instance, index):
"""
For calculating h_out and w_out.
"""

padding = (module_instance.padding, module_instance.padding) \
if isinstance(module_instance.padding, int) else module_instance.padding
kernel_size = (module_instance.kernel_size, module_instance.kernel_size)\
@@ -269,12 +267,11 @@ def conv2d_inference_rule(n: Node, module_instance):
if is_consistent(arg_type.__args__[1], module_instance.in_channels):
w_in = arg_type.__args__[3]
h_in = arg_type.__args__[2]
h_out = calculate(h_in, module_instance, 0)
w_out = calculate(w_in, module_instance, 1)
h_out = calculate_out_dimension(h_in, module_instance, 0)
w_out = calculate_out_dimension(w_in, module_instance, 1)
new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out))
gub = get_greatest_upper_bound(new_type, curr_node_type)
n.type = gub

return n.type
else:
raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}')
Expand All @@ -300,8 +297,10 @@ def maxpool2d_check(typ, module_instance):
if len(new_type_list) == 4 or len(new_type_list) == 3:
w_in = new_type_list[-1]
h_in = new_type_list[-2]
h_out = calculate(h_in, module_instance, 0)
w_out = calculate(w_in, module_instance, 1)

h_out = calculate_out_dimension(h_in, module_instance, 0)
w_out = calculate_out_dimension(w_in, module_instance, 1)

new_type_list[-1] = w_out
new_type_list[-2] = h_out
return TensorType(tuple(new_type_list))
@@ -360,7 +359,6 @@ def linear_inference_rule(n: Node, module_instance):
return n.type



def adaptiveavgpool2d_check(tensor_type, module_instance):
output_size = module_instance.output_size
if isinstance(output_size, int):
@@ -397,6 +395,50 @@ def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
n.type = get_greatest_upper_bound(n.type, output_type)
return n.type

def flatten_check(tensor_type, start_dim, end_dim):
l = len(tensor_type.__args__)

start_dim = l if start_dim == -1 else start_dim
end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1

if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim:
my_args = list(tensor_type.__args__)
lhs = my_args[0:start_dim]
rhs = my_args[end_dim:]
mid = my_args[start_dim:end_dim]
if Dyn in mid:
mid = [Dyn]
else:
mid = [reduce(lambda x, y: x * y, my_args[start_dim:end_dim])]
new_type_list = lhs + mid + rhs
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}')

@register_inference_rule(torch.flatten)
def flatten_inference_rule(n: Node):
assert isinstance(n.args[0], Node)

# set the default start and end dims
start_dim = 1
end_dim = -1

if len(n.args) > 1:
assert isinstance(n.args[1], int)
start_dim = n.args[1]

if len(n.args) > 2:
assert isinstance(n.args[2], int)
end_dim = n.args[2]

if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))

if isinstance(n.args[0].type, TensorType):
output_type = flatten_check(n.args[0].type, start_dim, end_dim)
n.type = get_greatest_upper_bound(output_type , n.type)

return n.type

class GraphTypeChecker:
def __init__(self, env, traced):
@@ -424,6 +466,13 @@ def type_check_node(self, n: Node):
- Reshape
- Transpose
- Add
- Relu
- conv2d
- batchnorm2d
- flatten
- maxpool2d
- adaptiveavgpool2d
- linear
"""
if n.type is None:
n.type = Dyn
@@ -438,7 +487,7 @@ def type_check_node(self, n: Node):
raise RuntimeError(f'No inference rule registered for target {n.target}!')

if n.op == 'call_module':
module_instance = getattr(self.traced, str(n.target))
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _INFERENCE_RULES:
return _INFERENCE_RULES[type(module_instance)](n, module_instance)
else:
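
One line in type_check_node above switches from getattr to get_submodule. That matters for resnet50: traced call_module nodes carry dotted targets such as 'layer1.0.conv1', which a plain getattr on the root module cannot resolve. A small illustration with a hypothetical Outer module, not part of this diff:

import torch

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

m = Outer()
print(m.get_submodule('block.0'))  # resolves the dotted path down to the Conv2d
# getattr(m, 'block.0')            # AttributeError: no attribute named 'block.0'
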
3 changes: 2 additions & 1 deletion torch/fx/passes/shape_prop.py
@@ -89,7 +89,8 @@ def forward(self, x):
ShapeProp(gm).propagate(sample_input)
for node in gm.graph.nodes:
print(node.name, node.dtype, node.shape)
print(node.name, node.meta['tensor_meta'].dtype,
node.meta['tensor_meta'].shape)
The output of this code is:
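
A runnable variant of the corrected docstring usage, with a small Sequential standing in for the module used in the docstring's example (names here are illustrative):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()))
ShapeProp(gm).propagate(torch.randn(2, 8))
for node in gm.graph.nodes:
    if 'tensor_meta' in node.meta:  # nodes whose result was a tensor
        print(node.name, node.meta['tensor_meta'].dtype, node.meta['tensor_meta'].shape)
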
