Skip to content

Commit

Permalink
Revive dynamic shape support with torch.compile (#162)
Browse files Browse the repository at this point in the history
Revive dynamic shape support with `torch.compile`.
It was broken due to changes in pytorch interface.
  • Loading branch information
vadiklyutiy committed Jul 22, 2024
1 parent 982b552 commit b75e5d8
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 101 deletions.
38 changes: 21 additions & 17 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import torch
import hidet.option
from hidet import Tensor
from hidet.ir import dtypes
from hidet.ir.type import DataType
from hidet.ir.expr import SymbolVar
from hidet.runtime import CompiledGraph
Expand All @@ -30,26 +29,31 @@

logger = logging.getLogger(__name__)


def get_flow_graph(interpreter, example_inputs):
# prepare dummy and symbolic inputs for correctness and flow graph construction
inputs: List[Union[Tensor, SymbolVar, int, bool, float]] = [] # for flow graph construction
for example_input in example_inputs:
# NOTES ABOUT DYNAMIC SHAPE.
# From pytorch we got two argument:
# - fxgraph
# - example_inputs
# In case when we are requested to create dynamic shape, `example_inputs` contain info
# about used symbols only (all symbols are presented in `example_input` as element of list).
# But in `example_inputs` there is no information about what dimentions of input tensors
# should be symbolic and correspondence between symbol and dimention.
# These info is presented in fxgraph. Every input corresponds fxgraph node.
# in `fx_node.meta['example_value']` stored `FakeTensor` that contain all symbols in its shape.
# We use this data to determinate shapes of the inputs.
def get_flow_graph(interpreter: Interpreter, example_inputs):
inputs: List[Union[Tensor, SymbolVar]] = [] # for flow graph construction
for fxgraph_node, example_input in zip(interpreter.graph.nodes, example_inputs):
if isinstance(example_input, torch.Tensor):
symbolic_input = symbol_like_torch(example_input)
fake_input = fxgraph_node.meta['example_value']
symbolic_input = symbol_like_torch(fake_input)
inputs.append(symbolic_input)
elif isinstance(example_input, (int, bool, float)):
elif isinstance(example_input, int):
inputs.append(example_input)
elif isinstance(example_input, torch.SymInt):
from torch.fx.experimental.symbolic_shapes import SymNode

node: SymNode = example_input.node
try:
inputs.append(node.pytype(example_input))
except RuntimeError:
# is a symbolic scalar input
pytype2dtype = {int: dtypes.int32, float: dtypes.float32, bool: dtypes.boolean}
inputs.append(hidet.symbol_var(name=str(example_input), dtype=pytype2dtype[node.pytype]))
assert fxgraph_node.op == 'placeholder' and fxgraph_node.type is torch.SymInt
name = fxgraph_node.name
var = hidet.symbol_var(name)
inputs.append(var)
else:
raise ValueError(f"hidet_backend: unexpected example input {example_input}, type {type(example_input)}")

Expand Down
18 changes: 12 additions & 6 deletions python/hidet/graph/frontend/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,20 @@ def symbol_like_torch(tensor) -> Tensor:
from torch._subclasses.fake_tensor import FakeTensor

if isinstance(tensor, FakeTensor):
# this should be fine for now; torch wraps around the sympy library
symbolic_shape = []
for s in tensor.shape:
try:
i = int(s)
except Exception: # pylint: disable=broad-except
i = str(s)
symbolic_shape.append(i)
if isinstance(s, int):
symbolic_shape.append(s)
else:
assert isinstance(s, torch.SymInt)
expr = s.node.expr
if expr.is_Integer:
i = int(s)
symbolic_shape.append(i)
else:
assert expr.is_Symbol
name = s.node.expr.name
symbolic_shape.append(name)
return hidet.symbol(shape=symbolic_shape, dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type)
elif isinstance(tensor, torch.Tensor):
return hidet.symbol(
Expand Down
73 changes: 73 additions & 0 deletions python/hidet/testing/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,76 @@ def check_module(model: torch.nn.Module, args: Sequence[torch.Tensor], atol=1e-4
torch_output = torch_output.detach().cpu().numpy()
hidet_output = hidet_output.detach().cpu().numpy()
numpy.testing.assert_allclose(torch_output, hidet_output, atol=atol, rtol=rtol)


# Class to initialise backend, run compilation
class Backend:
def __init__(self, backend, dtype, search_space=2) -> None:
assert backend in [
'hidet',
'max-autotune',
'max-autotune-no-cudagraphs',
'eager',
], 'backend is hidet or max-autotune or max-autotune-no-cudagraphs or eager supported only'
self.backend = backend
self.dtype = dtype
self.search_space = search_space
if self.backend == 'hidet':
self.init_hidet()

def init_hidet(self):
import hidet
import os

use_fp16 = self.dtype == 'float16'
hidet.torch.dynamo_config.search_space(self.search_space)
hidet.torch.dynamo_config.use_fp16(use_fp16)
hidet.torch.dynamo_config.use_fp16_reduction(use_fp16)
hidet.torch.dynamo_config.use_attention(True)
hidet.torch.dynamo_config.use_tensor_core(True)
hidet.torch.dynamo_config.use_cuda_graph(True)
hidet.option.search_space(self.search_space)

# hidet.option.cache_dir(hidet.option.get_cache_dir() + '/regression')
# hidet.option.parallel_tune(max_parallel_jobs=1)
# hidet.option.debug_cache_tuning(True)
# hidet.option.save_lower_ir(True)
# hidet.option.debug_show_verbose_flow_graph(True)

# Initialise compiler server
if os.environ.get('CI_CS_HOSTNAME'):
hidet.option.compile_server.addr(os.environ.get('CI_CS_HOSTNAME'))
hidet.option.compile_server.port(int(os.environ.get('CI_CS_PORT')))
hidet.option.compile_server.username(os.environ.get('CI_CS_USERNAME'))
hidet.option.compile_server.password(os.environ.get('CI_CS_PASSWORD'))
hidet.option.compile_server.repo(os.environ.get('REPO_NAME').strip(), os.environ.get('REPO_BRANCH').strip())
hidet.option.compile_server.enable(flag=True)

def compile(self, model):
if self.backend == 'hidet':
model = torch.compile(model, backend=self.backend)
elif self.backend == 'eager':
pass
else:
model = torch.compile(model, mode=self.backend)
return model


# Make benchmarking of given torch model
def bench_torch_model(model, torch_inputs, bench_iters=100, warmup_iters=10):
for _ in range(warmup_iters):
out = model(*torch_inputs) # pylint:disable=unused-variable
torch.cuda.empty_cache()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
for _ in range(bench_iters):
out = model(*torch_inputs) # pylint:disable=unused-variable
end.record()
end.synchronize()
torch.cuda.empty_cache()

latency = start.elapsed_time(end) / bench_iters
return latency
3 changes: 2 additions & 1 deletion tests/benchmarks/bench_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import argparse
import numpy as np
import hidet
from bench_utils import bench_torch_model, Backend

from hidet.testing.torch_utils import bench_torch_model, Backend


def bench_matmul_f16(params: str, *args, **kwargs) -> float:
Expand Down
6 changes: 4 additions & 2 deletions tests/benchmarks/bench_op_torch_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import argparse
import torch
from bench_utils import bench_torch_model, Backend
from hidet.testing.torch_utils import bench_torch_model, Backend


# MATMUL BENCHMARKS #
Expand Down Expand Up @@ -139,10 +139,12 @@ def create_model_transpose(params: str, dtype):
# Main benchmark function for ops.
# Calls bench_torch_model
def bench_op(operator, params, dtype, backend):
dtype = getattr(torch, dtype)
comp_backend = Backend(backend, dtype)
dtype = getattr(torch, dtype)

model_creator = getattr(sys.modules[__name__], "create_model_" + operator)
model, model_inputs = model_creator(params, dtype)
model = model.eval().to(dtype).cuda()
with torch.no_grad(), torch.autocast("cuda"):
opt_model = comp_backend.compile(model)
latency = bench_torch_model(opt_model, model_inputs)
Expand Down
2 changes: 1 addition & 1 deletion tests/benchmarks/bench_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM, logging
from bench_utils import bench_torch_model, Backend
from hidet.testing.torch_utils import bench_torch_model, Backend

os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.set_verbosity_error()
Expand Down
73 changes: 0 additions & 73 deletions tests/benchmarks/bench_utils.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/benchmarks/bench_vision.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import torch
import torchvision
from bench_utils import bench_torch_model, Backend
from hidet.testing.torch_utils import bench_torch_model, Backend


def bench_torchvision(model_name, shape, dtype, backend):
Expand Down
97 changes: 97 additions & 0 deletions tests/frontends/torch/test_torch_dyn_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pytest, sys
import torch
import hidet
from hidet.testing.torch_utils import Backend


def no_compilaion(*args, **kwargs):
assert False, 'At this point must not be compilation, everything should be covered by dynamic shapes'


# REDUCE #
class torch_sum(torch.nn.Module):
def __init__(self, axis):
super(torch_sum, self).__init__()
self.axis = axis

def forward(self, x):
return torch.sum(x, dim=self.axis)


def create_model_reduce(axis):
model = torch_sum(axis=axis)
return model


@pytest.mark.parametrize('operator', ['reduce'])
@pytest.mark.parametrize('dtype', ['float32'])
@pytest.mark.parametrize('axis', [[1, 2]])
def test_dynamic_shape_w_mark_dynamic(operator, dtype, axis):
hidet_backend = Backend('hidet', dtype, search_space=0)
torch_backend = Backend('eager', dtype)
dtype = getattr(torch, dtype)

model_creator = getattr(sys.modules[__name__], "create_model_" + operator)
model = model_creator(axis)
model = model.eval().to(dtype).cuda()
with torch.no_grad(), torch.autocast("cuda"):
hidet_model = hidet_backend.compile(model)
torch_model = torch_backend.compile(model)

model_inputs1x = torch.randn(*[2, 16, 16, 3], dtype=dtype, device='cuda')
# Mark dimension as dynamic
torch._dynamo.mark_dynamic(model_inputs1x, 0)
hidet_out = hidet_model(model_inputs1x)
torch_out = torch_model(model_inputs1x)
assert torch.allclose(hidet_out, torch_out, rtol=1e-04, atol=1e-04)

tmp = hidet.drivers.build_task
hidet.drivers.build_task = no_compilaion

model_inputs2x = torch.randn(*[3, 16, 16, 3], dtype=dtype, device='cuda')
hidet_out = hidet_model(model_inputs2x)
torch_out = torch_model(model_inputs2x)
assert torch.allclose(hidet_out, torch_out, rtol=1e-04, atol=1e-04)

model_inputs3x = torch.randn(*[5, 16, 16, 3], dtype=dtype, device='cuda')
hidet_out = hidet_model(model_inputs3x)
torch_out = torch_model(model_inputs3x)
assert torch.allclose(hidet_out, torch_out, rtol=1e-04, atol=1e-04)

hidet.drivers.build_task = tmp


@pytest.mark.parametrize('operator', ['reduce'])
@pytest.mark.parametrize('dtype', ['float32'])
@pytest.mark.parametrize('axis', [[1, 2]])
def test_dynamic_shape_w_heuristic_mark(operator, dtype, axis):
hidet_backend = Backend('hidet', dtype, search_space=0)
torch_backend = Backend('eager', dtype)
dtype = getattr(torch, dtype)

model_creator = getattr(sys.modules[__name__], "create_model_" + operator)
model = model_creator(axis)
model = model.eval().to(dtype).cuda()
with torch.no_grad(), torch.autocast("cuda"):
hidet_model = hidet_backend.compile(model)
torch_model = torch_backend.compile(model)

model_inputs1x = torch.randn(*[2, 16, 16, 3], dtype=dtype, device='cuda')
hidet_out = hidet_model(model_inputs1x)
torch_out = torch_model(model_inputs1x)
assert torch.allclose(hidet_out, torch_out, rtol=1e-04, atol=1e-04)

model_inputs2x = torch.randn(*[3, 16, 16, 3], dtype=dtype, device='cuda')
hidet_out = hidet_model(model_inputs2x)
torch_out = torch_model(model_inputs2x)
assert torch.allclose(hidet_out, torch_out, rtol=1e-04, atol=1e-04)

tmp = hidet.drivers.build_task
hidet.drivers.build_task = no_compilaion

model_inputs3x = torch.randn(*[5, 16, 16, 3], dtype=dtype, device='cuda')
hidet_out = hidet_model(model_inputs3x)
torch_out = torch_model(model_inputs3x)
assert torch.allclose(hidet_out, torch_out, rtol=1e-04, atol=1e-04)

hidet.drivers.build_task = tmp

0 comments on commit b75e5d8

Please sign in to comment.