diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index e03fa84e50..7a0400ea91 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -10,6 +10,7 @@ import dataclasses import pathlib +import types from typing import Protocol, TypeVar import factory @@ -42,6 +43,9 @@ def __call__( ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: ... +_MODULES: list[types.ModuleType] = [] + + @dataclasses.dataclass(frozen=True) class Compiler( workflow.ChainableWorkflowMixin[ @@ -83,11 +87,12 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - compiled_prog = getattr( - importer.import_from_path(src_dir / new_data.module), new_data.entry_point_name - ) + m = importer.import_from_path(src_dir / new_data.module) + # Keep a reference to the module so they are not garbage collected. This avoids a SEGFAULT + # in nanobind when calling the compiled program. + _MODULES.append(m) - return compiled_prog + return getattr(m, new_data.entry_point_name) class CompilerFactory(factory.Factory):