Skip to content

Commit 542e8e2

Browse files
committed
move initialization of omp/ol runtimes into global_ctor/dtor
1 parent 1b39278 commit 542e8e2

File tree

3 files changed

+64
-32
lines changed

3 files changed

+64
-32
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 56 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ pub(crate) struct OffloadGlobals<'ll> {
1818
pub launcher_fn: &'ll llvm::Value,
1919
pub launcher_ty: &'ll llvm::Type,
2020

21-
pub bin_desc: &'ll llvm::Type,
22-
2321
pub kernel_args_ty: &'ll llvm::Type,
2422

2523
pub offload_entry_ty: &'ll llvm::Type,
@@ -30,8 +28,6 @@ pub(crate) struct OffloadGlobals<'ll> {
3028

3129
pub ident_t_global: &'ll llvm::Value,
3230

33-
pub register_lib: &'ll llvm::Value,
34-
pub unregister_lib: &'ll llvm::Value,
3531
pub init_rtls: &'ll llvm::Value,
3632
}
3733

@@ -43,35 +39,79 @@ impl<'ll> OffloadGlobals<'ll> {
4339
let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
4440
let ident_t_global = generate_at_one(cx);
4541

46-
let tptr = cx.type_ptr();
47-
let ti32 = cx.type_i32();
48-
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
49-
let bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
50-
cx.set_struct_body(bin_desc, &tgt_bin_desc_ty, false);
51-
52-
let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void());
53-
let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", reg_lib_decl);
54-
let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl);
5542
let init_ty = cx.type_func(&[], cx.type_void());
5643
let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
5744

5845
OffloadGlobals {
5946
launcher_fn,
6047
launcher_ty,
61-
bin_desc,
6248
kernel_args_ty,
6349
offload_entry_ty,
6450
begin_mapper,
6551
end_mapper,
6652
mapper_fn_ty,
6753
ident_t_global,
68-
register_lib,
69-
unregister_lib,
7054
init_rtls,
7155
}
7256
}
7357
}
7458

59+
pub(crate) fn setup<'ll>(cx: &CodegenCx<'ll, '_>) {
60+
let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void());
61+
let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", reg_lib_decl);
62+
let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl);
63+
64+
let i32_0 = cx.get_const_i32(0);
65+
let ptr_null = cx.const_null(cx.type_ptr());
66+
let const_struct = cx.const_struct(&[i32_0, ptr_null, ptr_null, ptr_null], false);
67+
let omp_descriptor =
68+
add_global(cx, ".omp_offloading.descriptor", const_struct, InternalLinkage);
69+
// @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 1, ptr @.omp_offloading.device_images, ptr @__start_llvm_offload_entries, ptr @__stop_llvm_offload_entries }
70+
// @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 0, ptr null, ptr null, ptr null }
71+
unsafe { llvm::LLVMDumpModule(cx.llmod()) };
72+
73+
let atexit = cx.type_func(&[cx.type_ptr()], cx.type_i32());
74+
let atexit_fn = declare_offload_fn(cx, "atexit", atexit);
75+
76+
let reg_name = ".omp_offloading.descriptor_reg";
77+
let unreg_name = ".omp_offloading.descriptor_unreg";
78+
let desc_ty = cx.type_func(&[], cx.type_void());
79+
let desc_reg_fn = declare_offload_fn(cx, reg_name, desc_ty);
80+
let desc_unreg_fn = declare_offload_fn(cx, unreg_name, desc_ty);
81+
llvm::set_linkage(desc_reg_fn, InternalLinkage);
82+
llvm::set_linkage(desc_unreg_fn, InternalLinkage);
83+
llvm::set_section(desc_reg_fn, c".text.startup");
84+
llvm::set_section(desc_unreg_fn, c".text.startup");
85+
86+
// define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
87+
// entry:
88+
// call void @__tgt_register_lib(ptr @.omp_offloading.descriptor)
89+
// %0 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg)
90+
// ret void
91+
// }
92+
let bb = Builder::append_block(cx, desc_reg_fn, "entry");
93+
let mut a = Builder::build(cx, bb);
94+
a.call(reg_lib_decl, None, None, register_lib, &[omp_descriptor], None, None);
95+
a.call(atexit, None, None, atexit_fn, &[desc_unreg_fn], None, None);
96+
a.ret_void();
97+
98+
// define internal void @.omp_offloading.descriptor_unreg() section ".text.startup" {
99+
// entry:
100+
// call void @__tgt_unregister_lib(ptr @.omp_offloading.descriptor)
101+
// ret void
102+
// }
103+
let bb = Builder::append_block(cx, desc_unreg_fn, "entry");
104+
let mut a = Builder::build(cx, bb);
105+
a.call(reg_lib_decl, None, None, unregister_lib, &[omp_descriptor], None, None);
106+
a.ret_void();
107+
108+
// @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 101, ptr @.omp_offloading.descriptor_reg, ptr null }]
109+
let args = vec![cx.get_const_i32(101), desc_reg_fn, ptr_null];
110+
let const_struct = cx.const_struct(&args, false);
111+
let arr = cx.const_array(cx.val_ty(const_struct), &[const_struct]);
112+
let _global_ctor = add_global(cx, "llvm.global_ctors", arr, AppendingLinkage);
113+
}
114+
75115
pub(crate) struct OffloadKernelDims<'ll> {
76116
num_workgroups: &'ll Value,
77117
threads_per_block: &'ll Value,
@@ -478,9 +518,6 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
478518
let tgt_decl = offload_globals.launcher_fn;
479519
let tgt_target_kernel_ty = offload_globals.launcher_ty;
480520

481-
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
482-
let tgt_bin_desc = offload_globals.bin_desc;
483-
484521
let tgt_kernel_decl = offload_globals.kernel_args_ty;
485522
let begin_mapper_decl = offload_globals.begin_mapper;
486523
let end_mapper_decl = offload_globals.end_mapper;
@@ -504,12 +541,9 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
504541
}
505542

506543
// Step 0)
507-
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
508-
// %6 = alloca %struct.__tgt_bin_desc, align 8
509544
unsafe {
510545
llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
511546
}
512-
let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
513547

514548
let ty = cx.type_array(cx.type_ptr(), num_args);
515549
// Baseptr are just the input pointer to the kernel, stored in a local alloca
@@ -527,7 +561,6 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
527561
unsafe {
528562
llvm::LLVMPositionBuilderAtEnd(&builder.llbuilder, bb);
529563
}
530-
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
531564

532565
// Now we allocate once per function param, a copy to be passed to one of our maps.
533566
let mut vals = vec![];
@@ -539,15 +572,9 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
539572
geps.push(gep);
540573
}
541574

542-
let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
543-
let register_lib_decl = offload_globals.register_lib;
544-
let unregister_lib_decl = offload_globals.unregister_lib;
545575
let init_ty = cx.type_func(&[], cx.type_void());
546576
let init_rtls_decl = offload_globals.init_rtls;
547577

548-
// FIXME(offload): Later we want to add them to the wrapper code, rather than our main function.
549-
// call void @__tgt_register_lib(ptr noundef %6)
550-
builder.call(mapper_fn_ty, None, None, register_lib_decl, &[tgt_bin_desc_alloca], None, None);
551578
// call void @__tgt_init_all_rtls()
552579
builder.call(init_ty, None, None, init_rtls_decl, &[], None, None);
553580

@@ -644,6 +671,4 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
644671
num_args,
645672
s_ident_t,
646673
);
647-
648-
builder.call(mapper_fn_ty, None, None, unregister_lib_decl, &[tgt_bin_desc_alloca], None, None);
649674
}

compiler/rustc_codegen_llvm/src/common.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
124124
pub(crate) fn const_null(&self, t: &'ll Type) -> &'ll Value {
125125
unsafe { llvm::LLVMConstNull(t) }
126126
}
127+
128+
pub(crate) fn const_struct(&self, elts: &[&'ll Value], packed: bool) -> &'ll Value {
129+
struct_in_context(self.llcx(), elts, packed)
130+
}
127131
}
128132

129133
impl<'ll, 'tcx> ConstCodegenMethods for CodegenCx<'ll, 'tcx> {

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ use tracing::debug;
3030
use crate::abi::FnAbiLlvmExt;
3131
use crate::builder::Builder;
3232
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
33-
use crate::builder::gpu_offload::{OffloadKernelDims, gen_call_handling, gen_define_handling};
33+
use crate::builder::gpu_offload::{
34+
OffloadKernelDims, gen_call_handling, gen_define_handling, setup,
35+
};
3436
use crate::context::CodegenCx;
3537
use crate::declare::declare_raw_fn;
3638
use crate::errors::{
@@ -1403,6 +1405,7 @@ fn codegen_offload<'ll, 'tcx>(
14031405
return;
14041406
}
14051407
};
1408+
setup(cx);
14061409
let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals);
14071410
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
14081411
}

0 commit comments

Comments
 (0)