From 7e4c1edc4cbf60ef2cda32c36e62c8125b7eedbe Mon Sep 17 00:00:00 2001 From: Jake Tronge Date: Tue, 28 Nov 2023 13:59:55 -0700 Subject: [PATCH] Inline generated functions and update validation Signed-off-by: Jake Tronge --- ompi/mpi/c/abi.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/ompi/mpi/c/abi.py b/ompi/mpi/c/abi.py index efa6adc4a84..f40986ff981 100755 --- a/ompi/mpi/c/abi.py +++ b/ompi/mpi/c/abi.py @@ -259,6 +259,10 @@ class ConvertOMPIToStandard: COMM = 'ompi_convert_comm_ompi_to_standard' +# Inline function attributes +INLINE_ATTRS = '__opal_attribute_always_inline__ static inline' + + def mpi_fn_name_from_base_fn_name(name): """Convert from a base name to the standard 'MPI_*' name.""" return f'MPI_{name.capitalize()}' @@ -318,7 +322,7 @@ def dump_lines(self, lines): self.dump(line) def generate_error_convert_fn(self): - self.dump(f'static inline int {ConvertFuncs.ERROR_CLASS}(int error_class)') + self.dump(f'{INLINE_ATTRS} int {ConvertFuncs.ERROR_CLASS}(int error_class)') self.dump('{') lines = [] lines.append('switch (error_class) {') @@ -333,7 +337,7 @@ def generate_error_convert_fn(self): def generic_convert(self, fn_name, param_name, type_, value_names): intern_type = self.mangle_name(type_) - self.dump(f'static inline {type_} {fn_name}({intern_type} {param_name})') + self.dump(f'{INLINE_ATTRS} {type_} {fn_name}({intern_type} {param_name})') self.dump('{') lines = [] for i, value_name in enumerate(value_names): @@ -350,7 +354,7 @@ def generic_convert(self, fn_name, param_name, type_, value_names): def generic_convert_reverse(self, fn_name, param_name, type_, value_names): intern_type = self.mangle_name(type_) - self.dump(f'static inline {intern_type} {fn_name}({type_} {param_name})') + self.dump(f'{INLINE_ATTRS} {intern_type} {fn_name}({type_} {param_name})') self.dump('{') lines = [] for i, value_name in enumerate(value_names): @@ -388,7 +392,7 @@ def generate_win_convert_fn(self): def generate_pointer_convert_fn(self, type_, fn_name, constants): abi_type = self.mangle_name(type_) - self.dump(f'static inline void {fn_name}({abi_type} *ptr)') + self.dump(f'{INLINE_ATTRS} void {fn_name}({abi_type} *ptr)') self.dump('{') lines = [] for i, ompi_name in enumerate(constants): @@ -411,7 +415,7 @@ def generate_file_convert_fn(self): def generate_status_convert_fn(self): type_ = 'MPI_Status' abi_type = self.mangle_name(type_) - self.dump(f'static inline void {ConvertFuncs.STATUS}({abi_type} *out, {type_} *inp)') + self.dump(f'{INLINE_ATTRS} void {ConvertFuncs.STATUS}({abi_type} *out, {type_} *inp)') self.dump('{') self.dump(' out->MPI_SOURCE = inp->MPI_SOURCE;') self.dump(' out->MPI_TAG = inp->MPI_TAG;') @@ -1051,7 +1055,7 @@ class TemplateParseError(Exception): def validate_body(body): """Validate the body of a template.""" - # Just do a simple bracket balance test determine the bounds of the + # Just do a simple bracket balance test to determine the bounds of the # function body. All lines after the function body should be blank. There # are cases where this will break, such as if someone puts code all on one # line. @@ -1060,13 +1064,16 @@ def validate_body(body): for line in body: line = line.strip() if bracket_balance == 0 and line_count > 0 and line: - raise TemplateParserError('Extra code found in template; only one function body is allowed') + raise TemplateParseError('Extra code found in template; only one function body is allowed') update = line.count('{') - line.count('}') bracket_balance += update if bracket_balance != 0: line_count += 1 + if bracket_balance != 0: + raise TemplateParseError('Mismatched brackets found in template') + class SourceTemplate: """Source template for a single API function.""" @@ -1182,7 +1189,7 @@ def standard_abi(base_name, template): internal_name = f'ompi_abi_{template.prototype.name}' internal_sig = template.prototype.signature('ompi', internal_name, count_type='MPI_Count') - print('static inline', internal_sig) + print(INLINE_ATTRS, internal_sig) template.print_body(func_name=base_name) def generate_function(prototype, fn_name, internal_fn, count_type='int'):