From f3ccb8763d18d23fe2c18e2e0310234c346af0eb Mon Sep 17 00:00:00 2001 From: Kevin Tong Date: Wed, 21 Feb 2024 13:39:14 -0500 Subject: [PATCH] [Fixbug] Fix dynamic memcpy bug (#427) Minimal failure case: ``` resize_inputs: Tensor = symbol([1, 3, "h", "w"], dtype="int32", device="cpu") resize_outputs = self.resize(resize_inputs.to(self.dtype, self.device)) # (float32, cuda) resize_graph: FlowGraph = trace_from(resize_outputs, resize_inputs) resize_graph.build() ``` compiles this launch where symbols `h` and `w` are undefined. ``` DLL void hidet_launch_0(float * __restrict__ x, float * __restrict__ y) { cudaMemcpyAsync(y, x, (4 * ((3 * h) * w)), cudaMemcpyHostToDevice, (cudaStream_t)get_cuda_stream()); } ``` Fix is to add exprs to BlackBoxStmt so that symbols defined in exprs can be visited during codegen. --- python/hidet/ir/primitives/cuda/memcpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/ir/primitives/cuda/memcpy.py b/python/hidet/ir/primitives/cuda/memcpy.py index bf62c0d53..224fa08a8 100644 --- a/python/hidet/ir/primitives/cuda/memcpy.py +++ b/python/hidet/ir/primitives/cuda/memcpy.py @@ -27,5 +27,5 @@ def memcpy_async(dst: Expr, src: Expr, count: Expr, kind: str): raise RuntimeError(f'Unsupported transfer from {src} to {dst}, candidate kinds are {list(kind_map.keys())}') return BlackBoxStmt( - 'cudaMemcpyAsync({}, {}, {}, {}, (cudaStream_t){});'.format(dst, src, count, kind_map[kind], get_cuda_stream()) + f'cudaMemcpyAsync({{}}, {{}}, {{}}, {kind_map[kind]}, (cudaStream_t){{}});', dst, src, count, get_cuda_stream() )