Skip to content

Commit

Permalink
[Bug] Fix a bug where the shared memory becomes zero in LaunchKernelS…
Browse files Browse the repository at this point in the history
…tmt (#58)

* .

* .
  • Loading branch information
yaoyaoding committed Jan 5, 2023
1 parent 4cb20c1 commit 9c26d88
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def visit_Function(self, func: Function) -> Doc:
doc += ') {'

# comments
label = func.get_attr('label')
label = func.get_attr('label', default=None, allow_missing=True)
if label:
doc += (NewLine() + '// label: {}'.format(label)).indent()

Expand Down
31 changes: 28 additions & 3 deletions python/hidet/ir/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,40 @@ def __init__(self, name: str, params, body, ret_type, kind: str, extern_vars=Non
def __call__(self, *args, **kwargs) -> Call:
raise ValueError('Can only call script function in another script function, or lower it to execute.')

def get_attr(self, attr_name, default=None):
def get_attr(self, attr_name, default=None, allow_missing=False):
"""
Get attribute of this function.
When default is not None or allow_missing is True, this function will return the default value (in case
default is not None) or None (in case default is None) when the attribute is not found. Otherwise,
this function will raise a KeyError.
Parameters
----------
attr_name: str
The name of attribute
default: Any, optional
The default value of attribute
allow_missing: bool, default False
Returns
-------
attr_value: Any
The value of attribute
"""
if attr_name in self.attrs:
return self.attrs[attr_name]
return default
if default is not None or allow_missing:
return default
else:
raise KeyError('Attribute {} is not found in function {}'.format(attr_name, self.name))


class IRModule(Node):
"""
The intermidiate representation of tensor programs.
The intermediate representation of tensor programs.
An IRModule contains one or more functions. It is the basic compilation unit of hidet.
"""
Expand Down
5 changes: 3 additions & 2 deletions python/hidet/transforms/generate_packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ def process_module(self, ir_module: IRModule) -> IRModule:
if func.kind not in ['cuda_kernel', 'host_kernel']:
# only generate packed func for entry function
continue
if func.get_attr('packed_func', None) is not None:
if func.get_attr('packed_func', allow_missing=True) is not None:
# this function itself is a packed function
continue
if any(
f.get_attr('packed_func', None) is ir_module.lookup_var(func.name) for f in ir_module.functions.values()
f.get_attr('packed_func', allow_missing=True) is ir_module.lookup_var(func.name)
for f in ir_module.functions.values()
):
# the packed function for current function has existed, skip
continue
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/transforms/tools/generate_packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def extract_params_and_call(num_args: Expr, arg_types: Expr, args: Expr) -> Stmt
[param2arg[param] for param in func.params],
grid_dim=_rewrite_dim3(_normalize_dim3(func.get_attr('cuda_grid_dim')), param2arg),
block_dim=_rewrite_dim3(_normalize_dim3(func.get_attr('cuda_block_dim')), param2arg),
shared_mem=rewrite(int32(func.get_attr('shared_mem', 0)), param2arg),
shared_mem=rewrite(int32(func.get_attr('cuda_dynamic_smem_bytes', 0)), param2arg),
)
elif func.kind == 'host_kernel':
sb += Call(func_var, [param2arg[param] for param in func.params])
Expand Down

0 comments on commit 9c26d88

Please sign in to comment.