diff --git a/python/hidet/ir/schedulers/cuda/scheduler.py b/python/hidet/ir/schedulers/cuda/scheduler.py index cae20ab74..558b56aa0 100644 --- a/python/hidet/ir/schedulers/cuda/scheduler.py +++ b/python/hidet/ir/schedulers/cuda/scheduler.py @@ -35,7 +35,13 @@ def schedule_grid_compute(self, node: GridCompute, tensor_map: Dict[TensorNode, grid_dim: Expr = (prod(node.shape) + block_dim - 1) // block_dim if self.task is not None: - name = f'{self.task.name}_compute_{node.name}' + from hidet.graph.ops.fusion.fused_operator import FusedTask + + if isinstance(self.task, FusedTask): + fused_name = self.task.attrs['fused_ops'].replace(' ', '_') + name = f'fused_{fused_name}_{node.name}' + else: + name = f'{self.task.name}_{node.name}' else: name = f'compute_{node.name}'