Skip to content

Commit

Permalink
Inline generated functions and update validation
Browse files Browse the repository at this point in the history
Signed-off-by: Jake Tronge <[email protected]>
  • Loading branch information
jtronge committed Nov 29, 2023
1 parent 8df3db5 commit 7e4c1ed
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions ompi/mpi/c/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ class ConvertOMPIToStandard:
COMM = 'ompi_convert_comm_ompi_to_standard'


# Inline function attributes
INLINE_ATTRS = '__opal_attribute_always_inline__ static inline'


def mpi_fn_name_from_base_fn_name(name):
"""Convert from a base name to the standard 'MPI_*' name."""
return f'MPI_{name.capitalize()}'
Expand Down Expand Up @@ -318,7 +322,7 @@ def dump_lines(self, lines):
self.dump(line)

def generate_error_convert_fn(self):
self.dump(f'static inline int {ConvertFuncs.ERROR_CLASS}(int error_class)')
self.dump(f'{INLINE_ATTRS} int {ConvertFuncs.ERROR_CLASS}(int error_class)')
self.dump('{')
lines = []
lines.append('switch (error_class) {')
Expand All @@ -333,7 +337,7 @@ def generate_error_convert_fn(self):

def generic_convert(self, fn_name, param_name, type_, value_names):
intern_type = self.mangle_name(type_)
self.dump(f'static inline {type_} {fn_name}({intern_type} {param_name})')
self.dump(f'{INLINE_ATTRS} {type_} {fn_name}({intern_type} {param_name})')
self.dump('{')
lines = []
for i, value_name in enumerate(value_names):
Expand All @@ -350,7 +354,7 @@ def generic_convert(self, fn_name, param_name, type_, value_names):

def generic_convert_reverse(self, fn_name, param_name, type_, value_names):
intern_type = self.mangle_name(type_)
self.dump(f'static inline {intern_type} {fn_name}({type_} {param_name})')
self.dump(f'{INLINE_ATTRS} {intern_type} {fn_name}({type_} {param_name})')
self.dump('{')
lines = []
for i, value_name in enumerate(value_names):
Expand Down Expand Up @@ -388,7 +392,7 @@ def generate_win_convert_fn(self):

def generate_pointer_convert_fn(self, type_, fn_name, constants):
abi_type = self.mangle_name(type_)
self.dump(f'static inline void {fn_name}({abi_type} *ptr)')
self.dump(f'{INLINE_ATTRS} void {fn_name}({abi_type} *ptr)')
self.dump('{')
lines = []
for i, ompi_name in enumerate(constants):
Expand All @@ -411,7 +415,7 @@ def generate_file_convert_fn(self):
def generate_status_convert_fn(self):
type_ = 'MPI_Status'
abi_type = self.mangle_name(type_)
self.dump(f'static inline void {ConvertFuncs.STATUS}({abi_type} *out, {type_} *inp)')
self.dump(f'{INLINE_ATTRS} void {ConvertFuncs.STATUS}({abi_type} *out, {type_} *inp)')
self.dump('{')
self.dump(' out->MPI_SOURCE = inp->MPI_SOURCE;')
self.dump(' out->MPI_TAG = inp->MPI_TAG;')
Expand Down Expand Up @@ -1051,7 +1055,7 @@ class TemplateParseError(Exception):

def validate_body(body):
"""Validate the body of a template."""
# Just do a simple bracket balance test determine the bounds of the
# Just do a simple bracket balance test to determine the bounds of the
# function body. All lines after the function body should be blank. There
# are cases where this will break, such as if someone puts code all on one
# line.
Expand All @@ -1060,13 +1064,16 @@ def validate_body(body):
for line in body:
line = line.strip()
if bracket_balance == 0 and line_count > 0 and line:
raise TemplateParserError('Extra code found in template; only one function body is allowed')
raise TemplateParseError('Extra code found in template; only one function body is allowed')

update = line.count('{') - line.count('}')
bracket_balance += update
if bracket_balance != 0:
line_count += 1

if bracket_balance != 0:
raise TemplateParseError('Mismatched brackets found in template')


class SourceTemplate:
"""Source template for a single API function."""
Expand Down Expand Up @@ -1182,7 +1189,7 @@ def standard_abi(base_name, template):
internal_name = f'ompi_abi_{template.prototype.name}'
internal_sig = template.prototype.signature('ompi', internal_name,
count_type='MPI_Count')
print('static inline', internal_sig)
print(INLINE_ATTRS, internal_sig)
template.print_body(func_name=base_name)

def generate_function(prototype, fn_name, internal_fn, count_type='int'):
Expand Down

0 comments on commit 7e4c1ed

Please sign in to comment.