lint
hjjq committed Oct 17, 2023
1 parent 4ec132a commit 1c68107
Showing 2 changed files with 16 additions and 10 deletions.
23 changes: 15 additions & 8 deletions python/hidet/graph/ops/reduce/reduce.py
@@ -200,7 +200,7 @@ def cuda_schedule_reduce_by_default(self, max_block_size=256, use_atomic=True) -
 from hidet.ir.compute import ReduceOperation
 from hidet.ir.type import data_type, Int, tensor_type
 from hidet.ir.expr import cast, address, is_constant
-from hidet.ir.layout import row_major, local_layout
+from hidet.ir.layout import row_major
 from hidet.lang import spatial, repeat, attrs, tensor_pointer
 from hidet.lang.cuda import dynamic_shared_memory, syncthreads

@@ -236,7 +236,6 @@ def cuda_schedule_reduce_by_default(self, max_block_size=256, use_atomic=True) -
 else:
 write_remain_shape = read_remain_shape[:]
 y_vectorized_shape = write_remain_shape
-write_layout = row_major(*write_remain_shape)
 reduce_shape = [v for i, v in enumerate(shape) if i in dims]
 reduce_extent = hidet.utils.prod(reduce_shape)
 remain_extent = hidet.utils.prod(read_remain_shape)
@@ -251,11 +250,13 @@ def cuda_schedule_reduce_by_default(self, max_block_size=256, use_atomic=True) -
 num_warps = reduce_warps * remain_warps
 block_size = num_warps * WARP_SIZE
 grid_size = cdiv(remain_extent, remain_warps * WARP_SIZE)
-read_task_mapping = spatial(1, grid_size) * spatial(reduce_warps, remain_warps * WARP_SIZE) * repeat(repeats_per_reduce, 1)
+read_task_mapping = (
+    spatial(1, grid_size) * spatial(reduce_warps, remain_warps * WARP_SIZE) * repeat(repeats_per_reduce, 1)
+)
 write_task_mapping = spatial(1, grid_size) * spatial(reduce_warps, remain_warps * WARP_SIZE)
 remain_write_mapping = spatial(*write_remain_shape)

-use_smem = False if (is_constant(reduce_warps) and reduce_warps == 1) else True
+use_smem = not (is_constant(reduce_warps) and reduce_warps == 1)
 smem_length = remain_warps * WARP_SIZE * lanes
 smem_flattened_layout = row_major(smem_length)
 smem_task_mapping = spatial(reduce_warps, remain_warps * WARP_SIZE) * repeat(1, lanes)
@@ -333,9 +334,11 @@ def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
 reduce_round = indices[0]
 if reduce_round == k:
 remain_idx = indices[1]
-smem_staging[remain_idx] = ro.combine(smem_staging[remain_idx], rv[remain_idx % lanes])
+smem_staging[remain_idx] = ro.combine(
+    smem_staging[remain_idx], rv[remain_idx % lanes]
+)
 syncthreads()

 # At this point, the shared memory (or rv, if not using smem) contains the final reduction value.
 # Next, need to write back to global memory
 if threadIdx.x < remain_warps * WARP_SIZE:
@@ -344,9 +347,13 @@ def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
 if lanes > 1:
 lane_vec = cast(~write_val, ~vtype.lane_type)
 if use_smem:
-lane_vec[remain_idx % lanes] = ro.finalize(acc=smem_staging[remain_idx], size=reduce_extent)
+lane_vec[remain_idx % lanes] = ro.finalize(
+    acc=smem_staging[remain_idx], size=reduce_extent
+)
 else:
-lane_vec[remain_idx % lanes] = ro.finalize(acc=rv[remain_idx % lanes], size=reduce_extent)
+lane_vec[remain_idx % lanes] = ro.finalize(
+    acc=rv[remain_idx % lanes], size=reduce_extent
+)
 else:
 if use_smem:
 write_val[0] = ro.finalize(acc=smem_staging[remain_idx], size=reduce_extent)
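The reduce.py edits above are pure lint fixes: they drop the local_layout import and the write_layout assignment (presumably dead code), re-wrap over-long statements in parentheses, and simplify the use_smem conditional to a negation. A minimal standalone sketch (hypothetical flag name, not hidet code) showing that the boolean rewrite is behavior-preserving:

# Minimal sketch, not hidet code: the simplification applied to use_smem.
# A plain bool stands in for the is_constant(...) and ... expression.
def use_smem_old(reduce_warps_is_one: bool) -> bool:
    # Original form flagged by the linter.
    return False if reduce_warps_is_one else True


def use_smem_new(reduce_warps_is_one: bool) -> bool:
    # Lint-preferred form: a plain negation, identical in behavior.
    return not reduce_warps_is_one


# The two forms agree on both possible inputs.
assert all(use_smem_old(v) == use_smem_new(v) for v in (True, False))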
3 changes: 1 addition & 2 deletions python/hidet/graph/ops/reduce/resolve.py
@@ -58,15 +58,14 @@ def resolve_simplify(self, op: Operator) -> Optional[List[Tensor]]:

 def resolve_decompose(self, op: Operator) -> Optional[List[Tensor]]:
 dims = op.attrs['dims']
-keepdims = op.attrs['keepdims']
 x: Tensor = op.inputs[0]
 shape = x.shape
 dims = normalize_dim(dims, len(shape))
 if (len(shape) - 1) not in dims and len(dims) > 1:
 # start from highest dim to support keepdims=True
 dims.sort(reverse=True)
 for dim in dims:
-x = op.reforward([x], {'dims':[dim]})[0]
+x = op.reforward([x], {'dims': [dim]})[0]
 return [x]
 return None

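The resolve.py edits drop an unused keepdims read and fix the spacing in the dims dict literal; the surrounding resolve_decompose logic, unchanged here, splits a multi-axis reduction into successive single-axis reductions, sorted so the highest axis is reduced first. A minimal numpy sketch (not hidet's API) of why that decomposition yields the same result:

import numpy as np

# Minimal sketch, not hidet code: reducing one axis at a time, highest
# axis first so the lower axis indices stay valid, matches a single
# multi-axis reduction.
x = np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5)
dims = [1, 2]

decomposed = x
for dim in sorted(dims, reverse=True):
    decomposed = decomposed.sum(axis=dim)

assert np.allclose(decomposed, x.sum(axis=tuple(dims)))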
