Skip to content

Commit

Permalink
Pass args to forward as tuple
Browse files Browse the repository at this point in the history
Format

Small typo
  • Loading branch information
destefy committed Mar 13, 2024
1 parent dc900d6 commit 13f83e3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
9 changes: 5 additions & 4 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from hidet.graph.flow_graph import FlowGraph
from hidet.graph.transforms import PassContext, optimize
from hidet.cuda.graph import CudaGraphCreationError
from hidet.ffi import runtime_api
from .dynamo_config import dynamo_config
from .interpreter import Interpreter
from .utils import serialize_output, deserialize_output, resolve_save_dir_multigraph
from .utils import symbol_like_torch
from hidet.ffi import runtime_api

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -116,7 +116,8 @@ def __init__(self, cgraph: CompiledGraph, inputs, output_format):
self.inputs = inputs
self.output_format = output_format

def forward(self, *args):
# Due to the structure of the torch.fx.Graph that wraps this function, arguments are passed as a tuple
def forward(self, args):
if dynamo_config['use_cuda_graph']:
try:
runner = self.cgraph.cuda_graph()
Expand All @@ -138,7 +139,7 @@ def forward(self, *args):
# ignore constant
pass

hidet_inputs = preprocess_inputs(*tensor_args)
hidet_inputs = preprocess_inputs(tensor_args)
hidet_outputs: List[hidet.Tensor] = runner.run_async(hidet_inputs)
outputs: Sequence[torch.Tensor] = [tensor.torch() for tensor in hidet_outputs]
return deserialize_output(self.output_format, outputs)
Expand Down Expand Up @@ -176,7 +177,7 @@ def wrapper(*args):
cff = CompiledForwardFunction(cgraph, inputs, output_format)

# The torch.fx.GraphModule compiled forward function expects arguments to be passed as a tuple
# But the compiled forward function expects arguments to be passed like forward(a, b, c, ...)
# But CompiledForwardFunction.forward expects arguments to be passed like forward(a, b, c, ...)
def args_to_tuple_wrapper(*args):
return cff(args)

Expand Down
1 change: 0 additions & 1 deletion python/hidet/runtime/compiled_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ class GraphExecution:
tensor_device: List[str]



class CompiledGraph:
"""
A compiled graph that can be directly called in Python.
Expand Down

0 comments on commit 13f83e3

Please sign in to comment.