diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py index 567009068..7791944a0 100644 --- a/python/hidet/backend/codegen.py +++ b/python/hidet/backend/codegen.py @@ -237,7 +237,7 @@ def visit_Function(self, func: Function) -> Doc: doc += ') {' # comments - label = func.get_attr('label') + label = func.get_attr('label', default=None, allow_missing=True) if label: doc += (NewLine() + '// label: {}'.format(label)).indent() diff --git a/python/hidet/ir/func.py b/python/hidet/ir/func.py index 7547691eb..cc4ea2ed8 100644 --- a/python/hidet/ir/func.py +++ b/python/hidet/ir/func.py @@ -72,15 +72,40 @@ def __init__(self, name: str, params, body, ret_type, kind: str, extern_vars=Non def __call__(self, *args, **kwargs) -> Call: raise ValueError('Can only call script function in another script function, or lower it to execute.') - def get_attr(self, attr_name, default=None): + def get_attr(self, attr_name, default=None, allow_missing=False): + """ + Get attribute of this function. + + When default is not None or allow_missing is True, this function will return the default value (in case + default is not None) or None (in case default is None) when the attribute is not found. Otherwise, + this function will raise a KeyError. + + Parameters + ---------- + attr_name: str + The name of attribute + + default: Any, optional + The default value of attribute + + allow_missing: bool, default False + + Returns + ------- + attr_value: Any + The value of attribute + """ if attr_name in self.attrs: return self.attrs[attr_name] - return default + if default is not None or allow_missing: + return default + else: + raise KeyError('Attribute {} is not found in function {}'.format(attr_name, self.name)) class IRModule(Node): """ - The intermidiate representation of tensor programs. + The intermediate representation of tensor programs. An IRModule contains one or more functions. It is the basic compilation unit of hidet. """ diff --git a/python/hidet/transforms/generate_packed_func.py b/python/hidet/transforms/generate_packed_func.py index 18b14edc6..66c3f7e66 100644 --- a/python/hidet/transforms/generate_packed_func.py +++ b/python/hidet/transforms/generate_packed_func.py @@ -28,11 +28,12 @@ def process_module(self, ir_module: IRModule) -> IRModule: if func.kind not in ['cuda_kernel', 'host_kernel']: # only generate packed func for entry function continue - if func.get_attr('packed_func', None) is not None: + if func.get_attr('packed_func', allow_missing=True) is not None: # this function itself is a packed function continue if any( - f.get_attr('packed_func', None) is ir_module.lookup_var(func.name) for f in ir_module.functions.values() + f.get_attr('packed_func', allow_missing=True) is ir_module.lookup_var(func.name) + for f in ir_module.functions.values() ): # the packed function for current function has existed, skip continue diff --git a/python/hidet/transforms/tools/generate_packed_func.py b/python/hidet/transforms/tools/generate_packed_func.py index f8773d301..ce250c546 100644 --- a/python/hidet/transforms/tools/generate_packed_func.py +++ b/python/hidet/transforms/tools/generate_packed_func.py @@ -86,7 +86,7 @@ def extract_params_and_call(num_args: Expr, arg_types: Expr, args: Expr) -> Stmt [param2arg[param] for param in func.params], grid_dim=_rewrite_dim3(_normalize_dim3(func.get_attr('cuda_grid_dim')), param2arg), block_dim=_rewrite_dim3(_normalize_dim3(func.get_attr('cuda_block_dim')), param2arg), - shared_mem=rewrite(int32(func.get_attr('shared_mem', 0)), param2arg), + shared_mem=rewrite(int32(func.get_attr('cuda_dynamic_smem_bytes', 0)), param2arg), ) elif func.kind == 'host_kernel': sb += Call(func_var, [param2arg[param] for param in func.params])