From 674273191afb2da164468557fae66afd9199a36b Mon Sep 17 00:00:00 2001 From: Said Mazouz <95222894+smazouz42@users.noreply.github.com> Date: Wed, 15 May 2024 12:58:50 +0100 Subject: [PATCH] Fix import handling (#49) This pull request fixes https://github.com/pyccel/pyccel-cuda/issues/48, by implementing a tiny wrapper for CUDA and a wrapper for non-CUDA functionalities only with external 'C'. **Commit Summary** - Implemented new header printer for CUDA. - Added CUDA wrapper assignment - Instead of wrapping all local headers, wrap only C functions with extern 'C' --------- Co-authored-by: EmilyBourne Co-authored-by: bauom <40796259+bauom@users.noreply.github.com> --- AUTHORS | 1 + CHANGELOG.md | 3 +- pyccel/codegen/printing/cucode.py | 45 ++++++++---- pyccel/codegen/python_wrapper.py | 4 ++ pyccel/codegen/wrapper/cuda_to_c_wrapper.py | 78 +++++++++++++++++++++ tests/epyccel/modules/cuda_module.py | 13 ++++ tests/epyccel/test_epyccel_modules.py | 13 ++++ 7 files changed, 143 insertions(+), 14 deletions(-) create mode 100644 pyccel/codegen/wrapper/cuda_to_c_wrapper.py create mode 100644 tests/epyccel/modules/cuda_module.py diff --git a/AUTHORS b/AUTHORS index 6c30ce5830..3dbaa2f249 100644 --- a/AUTHORS +++ b/AUTHORS @@ -31,3 +31,4 @@ Contributors * Farouk Ech-Charef * Mustapha Belbiad * Varadarajan Rengaraj +* Said Mazouz diff --git a/CHANGELOG.md b/CHANGELOG.md index 941ef5f341..3d414f1256 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,8 @@ All notable changes to this project will be documented in this file. ### Added -- #32 : add support for `nvcc` Compiler and `cuda` language as a possible option. +- #32 : Add support for `nvcc` Compiler and `cuda` language as a possible option. +- #48 : Fix incorrect handling of imports in `cuda`. ## \[UNRELEASED\] diff --git a/pyccel/codegen/printing/cucode.py b/pyccel/codegen/printing/cucode.py index 86146b065b..277d2a3a6a 100644 --- a/pyccel/codegen/printing/cucode.py +++ b/pyccel/codegen/printing/cucode.py @@ -52,19 +52,7 @@ def _print_Module(self, expr): # Print imports last to be sure that all additional_imports have been collected imports = [Import(expr.name, Module(expr.name,(),())), *self._additional_imports.values()] - c_headers_imports = '' - local_imports = '' - - for imp in imports: - if imp.source in c_library_headers: - c_headers_imports += self._print(imp) - else: - local_imports += self._print(imp) - - imports = f'{c_headers_imports}\ - extern "C"{{\n\ - {local_imports}\ - }}' + imports = ''.join(self._print(i) for i in imports) code = f'{imports}\n\ {global_variables}\n\ @@ -72,3 +60,34 @@ def _print_Module(self, expr): self.exit_scope() return code + + def _print_ModuleHeader(self, expr): + self.set_scope(expr.module.scope) + self._in_header = True + name = expr.module.name + + funcs = "" + cuda_headers = "" + for f in expr.module.funcs: + if not f.is_inline: + if 'kernel' in f.decorators: # Checking for 'kernel' decorator + cuda_headers += self.function_signature(f) + ';\n' + else: + funcs += self.function_signature(f) + ';\n' + global_variables = ''.join('extern '+self._print(d) for d in expr.module.declarations if not d.variable.is_private) + # Print imports last to be sure that all additional_imports have been collected + imports = [*expr.module.imports, *self._additional_imports.values()] + imports = ''.join(self._print(i) for i in imports) + + self._in_header = False + self.exit_scope() + function_declaration = f'{cuda_headers}\n\ + extern "C"{{\n\ + {funcs}\ + }}\n' + return '\n'.join((f"#ifndef {name.upper()}_H", + f"#define {name.upper()}_H", + global_variables, + function_declaration, + "#endif // {name.upper()}_H\n")) + diff --git a/pyccel/codegen/python_wrapper.py b/pyccel/codegen/python_wrapper.py index 9437727042..62c303fa64 100644 --- a/pyccel/codegen/python_wrapper.py +++ b/pyccel/codegen/python_wrapper.py @@ -13,6 +13,7 @@ from pyccel.codegen.printing.fcode import FCodePrinter from pyccel.codegen.wrapper.fortran_to_c_wrapper import FortranToCWrapper from pyccel.codegen.wrapper.c_to_python_wrapper import CToPythonWrapper +from pyccel.codegen.wrapper.cuda_to_c_wrapper import CudaToCWrapper from pyccel.codegen.utilities import recompile_object from pyccel.codegen.utilities import copy_internal_library from pyccel.codegen.utilities import internal_libs @@ -144,6 +145,9 @@ def create_shared_library(codegen, verbose=verbose) timings['Bind C wrapping'] = time.time() - start_bind_c_compiling c_ast = bind_c_mod + elif language == 'cuda': + wrapper = CudaToCWrapper() + c_ast = wrapper.wrap(codegen.ast) else: c_ast = codegen.ast diff --git a/pyccel/codegen/wrapper/cuda_to_c_wrapper.py b/pyccel/codegen/wrapper/cuda_to_c_wrapper.py new file mode 100644 index 0000000000..c0e24c7c09 --- /dev/null +++ b/pyccel/codegen/wrapper/cuda_to_c_wrapper.py @@ -0,0 +1,78 @@ +# coding: utf-8 +#------------------------------------------------------------------------------------------# +# This file is part of Pyccel which is released under MIT License. See the LICENSE file or # +# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details. # +#------------------------------------------------------------------------------------------# +""" +Module describing the code-wrapping class : CudaToPythonWrapper +which creates an interface exposing Cuda code to C. +""" + +from pyccel.ast.bind_c import BindCModule +from pyccel.errors.errors import Errors +from pyccel.ast.bind_c import BindCVariable +from .wrapper import Wrapper + +errors = Errors() + +class CudaToCWrapper(Wrapper): + """ + Class for creating a wrapper exposing Cuda code to C. + + While CUDA is typically compatible with C by default. + this wrapper becomes necessary in scenarios where specific adaptations + or modifications are required to ensure seamless integration with C. + """ + + def _wrap_Module(self, expr): + """ + Create a Module which is compatible with C. + + Create a Module which provides an interface between C and the + Module described by expr. + + Parameters + ---------- + expr : pyccel.ast.core.Module + The module to be wrapped. + + Returns + ------- + pyccel.ast.core.BindCModule + The C-compatible module. + """ + init_func = expr.init_func + if expr.interfaces: + errors.report("Interface wrapping is not yet supported for Cuda", + severity='warning', symbol=expr) + if expr.classes: + errors.report("Class wrapping is not yet supported for Cuda", + severity='warning', symbol=expr) + + variables = [self._wrap(v) for v in expr.variables] + + return BindCModule(expr.name, variables, expr.funcs, + init_func=init_func, + scope = expr.scope, + original_module=expr) + + def _wrap_Variable(self, expr): + """ + Create all objects necessary to expose a module variable to C. + + Create and return the objects which must be printed in the wrapping + module in order to expose the variable to C + + Parameters + ---------- + expr : pyccel.ast.variables.Variable + The module variable. + + Returns + ------- + pyccel.ast.core.BindCVariable + The C-compatible variable. which must be printed in + the wrapping module to expose the variable. + """ + return expr.clone(expr.name, new_class = BindCVariable) + diff --git a/tests/epyccel/modules/cuda_module.py b/tests/epyccel/modules/cuda_module.py new file mode 100644 index 0000000000..bb7ae6b98a --- /dev/null +++ b/tests/epyccel/modules/cuda_module.py @@ -0,0 +1,13 @@ +# pylint: disable=missing-function-docstring, missing-module-docstring +import numpy as np + +g = np.float64(9.81) +r0 = np.float32(1.0) +rmin = 0.01 +rmax = 1.0 + +skip_centre = True + +method = 3 + +tiny = np.int32(4) diff --git a/tests/epyccel/test_epyccel_modules.py b/tests/epyccel/test_epyccel_modules.py index 445e8dc457..42aebf0134 100644 --- a/tests/epyccel/test_epyccel_modules.py +++ b/tests/epyccel/test_epyccel_modules.py @@ -182,3 +182,16 @@ def test_awkward_names(language): assert mod.function() == modnew.function() assert mod.pure() == modnew.pure() assert mod.allocate(1) == modnew.allocate(1) + +def test_cuda_module(language_with_cuda): + import modules.cuda_module as mod + + modnew = epyccel(mod, language=language_with_cuda) + + atts = ('g', 'r0', 'rmin', 'rmax', 'skip_centre', + 'method', 'tiny') + for att in atts: + mod_att = getattr(mod, att) + modnew_att = getattr(modnew, att) + assert mod_att == modnew_att + assert type(mod_att) is type(modnew_att)