Skip to content

Commit

Permalink
Format inputs to compiled function
Browse files Browse the repository at this point in the history
  • Loading branch information
destefy committed Mar 13, 2024
1 parent f46bc71 commit dc900d6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
19 changes: 14 additions & 5 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def get_compiled_graph(flow_graph: FlowGraph):
logger.info('finish building computation graph')
return cgraph


def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
torch_inputs: List[torch.Tensor] = []
for x in inputs:
Expand All @@ -107,7 +108,8 @@ def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
hidet_inputs: List[hidet.Tensor] = [hidet.from_torch(tensor) for tensor in torch_inputs]
return hidet_inputs

class CompiledForwardFunction(torch.nn.Module):

class CompiledForwardFunction(torch.nn.Module):
def __init__(self, cgraph: CompiledGraph, inputs, output_format):
super().__init__()
self.cgraph = cgraph
Expand All @@ -120,7 +122,7 @@ def forward(self, *args):
runner = self.cgraph.cuda_graph()
except CudaGraphCreationError:
runner = self.cgraph

tensor_args = []
for param, arg in zip(self.inputs, args):
if isinstance(param, Tensor):
Expand All @@ -135,8 +137,8 @@ def forward(self, *args):
else:
# ignore constant
pass
hidet_inputs = preprocess_inputs(tensor_args)

hidet_inputs = preprocess_inputs(*tensor_args)
hidet_outputs: List[hidet.Tensor] = runner.run_async(hidet_inputs)
outputs: Sequence[torch.Tensor] = [tensor.torch() for tensor in hidet_outputs]
return deserialize_output(self.output_format, outputs)
Expand Down Expand Up @@ -171,4 +173,11 @@ def wrapper(*args):

cgraph = get_compiled_graph(flow_graph)

return CompiledForwardFunction(cgraph, example_inputs, output_format)
cff = CompiledForwardFunction(cgraph, inputs, output_format)

# The torch.fx.GraphModule compiled forward function expects arguments to be passed as a tuple
# But the compiled forward function expects arguments to be passed like forward(a, b, c, ...)
def args_to_tuple_wrapper(*args):
return cff(args)

return args_to_tuple_wrapper
5 changes: 3 additions & 2 deletions python/hidet/runtime/compiled_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ class GraphExecution:
outputs_index: List[int]
tensor_device: List[str]

import sys


class CompiledGraph:
"""
A compiled graph that can be directly called in Python.
Expand Down Expand Up @@ -522,7 +523,7 @@ def load(self, path: str):
See Also
--------
CompiledGraph.save or save_compiled_graph
Parameters
----------
path: str
Expand Down

0 comments on commit dc900d6

Please sign in to comment.