
transpose 2d v1
zhiwei-fang committed Mar 1, 2024
1 parent 8befb62 commit 1408300
Showing 3 changed files with 105 additions and 1 deletion.
11 changes: 11 additions & 0 deletions .github/scripts/bench/bench_op.py
@@ -42,6 +42,16 @@ def bench_conv2d(params: str, *args, **kwargs) -> float:
    g = g.cuda_graph()
    return bench_torch_model(lambda: g.run_async(), [])

def bench_transpose2d(params: str, *args, **kwargs) -> float:
    x_shape = params
    x_shape = [int(s) for s in x_shape.split('x')]
    x = hidet.symbol(x_shape, dtype='float32', device='cuda')
    o = hidet.ops.transpose(x)
    g = hidet.trace_from(o, inputs=[x])
    g = hidet.graph.optimize(g)
    g = g.cuda_graph()
    return bench_torch_model(lambda: g.run_async(), [])

def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float:
    x_shape, w_shape = params.split(',')
    x_shape = [int(s) for s in x_shape.split('x')]
@@ -101,6 +111,7 @@ def bench_reduce(params: str, *args, **kwargs) -> float:
    'matmul_f16': bench_matmul_f16,
    'batch_matmul': bench_batch_matmul,
    'conv2d': bench_conv2d,
    'transpose2d': bench_transpose2d,
    'conv2d_gemm_f16': bench_conv2d_gemm_f16,
    'attn': bench_attn,
    'attn_mask_add': bench_attn_mask_add,
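
For reference, a minimal sketch of how the new benchmark entry can be exercised on its own (assuming a CUDA-enabled hidet build; the '4096x4096' shape string is illustrative):

    latency = bench_transpose2d('4096x4096')  # shape string parsed as rows x cols; returns the measured latency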
90 changes: 89 additions & 1 deletion python/hidet/graph/ops/transform.py
@@ -12,12 +12,15 @@
from typing import List, Optional, Union, Sequence
from hidet.ir.type import DataType, data_type
from hidet.ir.expr import Expr, Constant, if_then_else, convert, cast as ir_cast, is_constant
from hidet.ir.expr import Int
from hidet.ir.expr import Int, logical_and
from hidet.ir.layout import RowMajorLayout
from hidet.ir.utils import index_deserialize, index_serialize
from hidet.utils import prod
from .utils import Task, InverseMap, Operator, Tensor, TensorNode, compute, input_like, normalize_dim, can_broadcast
from .utils import TensorInput, normalize_slice
from hidet.ir.module import IRModule
from hidet.ir.library import tune
from hidet.utils.py import cdiv


def is_true(x: Union[Expr, bool]) -> bool:
@@ -320,6 +323,84 @@ def fmap(*indices):
        super().__init__(name='tile', inputs=[data], outputs=[out])


class TransposeTask2D(Task):
    def __init__(self, input: TensorNode):
        self.input_shape = input.shape
        self.input_dtype = input.type.dtype
        self.output_shape = [self.input_shape[1], self.input_shape[0]]

        output = compute(name='output', shape=self.output_shape, fcompute=lambda i, j: input[j, i])

        super().__init__(name='transpose2d', inputs=[input], outputs=[output], attributes={})

    def allow_prologue(self) -> bool:
        return False

    def allow_epilogue(self) -> bool:
        return True

    def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
        return tune.extract_ir_modules(self.cuda_schedule_threads_coarsening_transpose)

    @tune.space(1, coarsen_factor_row=[1, 2, 3, 4], coarsen_factor_col=[1, 2, 3, 4])
    def cuda_schedule_threads_coarsening_transpose(self, coarsen_factor_row=1, coarsen_factor_col=1) -> IRModule:
        # pylint: disable=unused-variable
        import hidet
        from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
        from hidet.lang.cuda import shared_tensor
        from hidet.lang import attrs

        input, output = self.inputs[0], self.outputs[0]
        tile_size_baseline = 32
        numElementsPerThread_row, numElementsPerThread_col = coarsen_factor_row, coarsen_factor_col
        blockSize_row = min(cdiv(self.input_shape[0], numElementsPerThread_row), tile_size_baseline)
        blockSize_col = min(cdiv(self.input_shape[1], numElementsPerThread_col), tile_size_baseline)
        numElementsPerBlock_row = numElementsPerThread_row * blockSize_row
        numElementsPerBlock_col = numElementsPerThread_col * blockSize_col

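        # Assumed rationale: pad the shared-memory tile by one column when the thread block is a
        # full 32x32 tile, so the column-wise accesses during the transposed write-back below do
        # not all land in the same shared-memory bank.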
        sharedMemSize_row, sharedMemSize_col = numElementsPerBlock_row, numElementsPerBlock_col
        if blockSize_row % tile_size_baseline == 0 and blockSize_col % tile_size_baseline == 0:
            sharedMemSize_col += 1
        block_size = (blockSize_row, blockSize_col)
        grid_size = (
            cdiv(self.input_shape[0], numElementsPerBlock_row),
            cdiv(self.input_shape[1], numElementsPerBlock_col),
        )
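        # Illustrative launch-configuration arithmetic (not part of the schedule): for a
        # 2048 x 2048 float32 input with coarsen_factor_row = coarsen_factor_col = 2:
        #   blockSize_row = min(cdiv(2048, 2), 32) = 32, and likewise blockSize_col = 32
        #   numElementsPerBlock_row = 2 * 32 = 64,       numElementsPerBlock_col = 64
        #   grid_size = (cdiv(2048, 64), cdiv(2048, 64)) = (32, 32)
        # i.e. each 32 x 32 thread block moves one 64 x 64 tile through shared memory, with
        # every thread handling a 2 x 2 patch of elements.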
        with hidet.script_module() as module:

            @hidet.script
            def transpose_kernel(
                input: self.input_dtype[self.input_shape], output: self.input_dtype[self.output_shape]
            ):
                attrs.cuda.grid_dim = grid_size
                attrs.cuda.block_dim = block_size
                tile = shared_tensor(self.input_dtype, shape=[sharedMemSize_row, sharedMemSize_col])

                for kx in range(coarsen_factor_row):
                    for ky in range(coarsen_factor_col):
                        tx = threadIdx.x + blockSize_row * kx
                        ty = threadIdx.y + blockSize_col * ky
                        xIndex = blockIdx.x * numElementsPerBlock_row + tx
                        yIndex = blockIdx.y * numElementsPerBlock_col + ty

                        if xIndex < self.input_shape[0] and yIndex < self.input_shape[1]:
                            tile[tx, ty] = input[xIndex, yIndex]

                syncthreads()
                for kx in range(coarsen_factor_row):
                    for ky in range(coarsen_factor_col):
                        tx = threadIdx.x + blockSize_row * kx
                        ty = threadIdx.y + blockSize_col * ky
                        xIndex = blockIdx.y * numElementsPerBlock_col + tx
                        yIndex = blockIdx.x * numElementsPerBlock_row + ty

                        if xIndex < self.output_shape[0] and yIndex < self.output_shape[1]:
                            output[xIndex, yIndex] = tile[ty, tx]

        ir_module = module.ir_module()
        return ir_module


class ReshapeOp(Operator):
    def __init__(self, x: Tensor, shape):
        task = ReshapeTask(input_like(x, 'x'), shape)
@@ -536,8 +617,15 @@ def flatten(x: Tensor, start_dim=0, end_dim=-1) -> Tensor:
    return FlattenOp(x, start_dim, end_dim).outputs[0]


class TransposeOp2D(Operator):
    def __init__(self, input: Tensor):
        super().__init__(inputs=[input], attributes={}, task=TransposeTask2D(input_like(input, 'input')))


def transpose(x: Tensor, axes: Optional[Sequence[int]] = None) -> Tensor:
    rank = len(x.shape)
    if rank == 2 and axes is None:
        return TransposeOp2D(x).outputs[0]
    if axes is None:
        axes = list(reversed(range(rank)))
    axes = [normalize_dim(dim, rank) for dim in axes]
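
A minimal usage sketch of the new rank-2 fast path (assuming a CUDA-enabled hidet build; the shape is illustrative):

    import hidet
    x = hidet.randn([1024, 2048], dtype='float32', device='cuda')
    y = hidet.ops.transpose(x)  # rank-2 input with axes omitted, handled by TransposeOp2D
    print(y.shape)              # expected shape: 2048 x 1024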
5 changes: 5 additions & 0 deletions tests/operators/test_transform.py
@@ -87,6 +87,11 @@ def test_transpose(shape, axes):
    check_transform(shape, lambda x: np.transpose(x, axes), lambda x: ops.transpose(x, axes))


@pytest.mark.parametrize("shape", [[33, 44], [1, 100], [100, 1], [10, 20], [20, 10], [100, 200], [2000, 3000]])
def test_transpose_2d(shape):
check_transform(shape, lambda x: np.transpose(x), lambda x: ops.transpose(x))
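
To run only the new case, a typical pytest invocation (file path taken from the header above) would be:

    pytest tests/operators/test_transform.py -k test_transpose_2d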


@pytest.mark.parametrize(
    "shapes, dtype, axis",
    [
