@@ -19,8 +19,6 @@ pub(crate) struct OffloadGlobals<'ll> {
1919 pub launcher_fn : & ' ll llvm:: Value ,
2020 pub launcher_ty : & ' ll llvm:: Type ,
2121
22- pub bin_desc : & ' ll llvm:: Type ,
23-
2422 pub kernel_args_ty : & ' ll llvm:: Type ,
2523
2624 pub offload_entry_ty : & ' ll llvm:: Type ,
@@ -31,8 +29,8 @@ pub(crate) struct OffloadGlobals<'ll> {
3129
3230 pub ident_t_global : & ' ll llvm:: Value ,
3331
34- pub register_lib : & ' ll llvm :: Value ,
35- pub unregister_lib : & ' ll llvm :: Value ,
32+ // FIXME(offload): Drop this once we have fully automated our offload compilation pipeline, since
33+ // LLVM will initialize the runtime libraries (RTLs) for us if it sees GPU kernels being registered.
3634 pub init_rtls : & ' ll llvm:: Value ,
3735}
3836
@@ -44,15 +42,6 @@ impl<'ll> OffloadGlobals<'ll> {
4442 let ( begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers ( cx) ;
4543 let ident_t_global = generate_at_one ( cx) ;
4644
47- let tptr = cx. type_ptr ( ) ;
48- let ti32 = cx. type_i32 ( ) ;
49- let tgt_bin_desc_ty = vec ! [ ti32, tptr, tptr, tptr] ;
50- let bin_desc = cx. type_named_struct ( "struct.__tgt_bin_desc" ) ;
51- cx. set_struct_body ( bin_desc, & tgt_bin_desc_ty, false ) ;
52-
53- let reg_lib_decl = cx. type_func ( & [ cx. type_ptr ( ) ] , cx. type_void ( ) ) ;
54- let register_lib = declare_offload_fn ( & cx, "__tgt_register_lib" , reg_lib_decl) ;
55- let unregister_lib = declare_offload_fn ( & cx, "__tgt_unregister_lib" , reg_lib_decl) ;
5645 let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
5746 let init_rtls = declare_offload_fn ( cx, "__tgt_init_all_rtls" , init_ty) ;
5847
@@ -63,20 +52,84 @@ impl<'ll> OffloadGlobals<'ll> {
6352 OffloadGlobals {
6453 launcher_fn,
6554 launcher_ty,
66- bin_desc,
6755 kernel_args_ty,
6856 offload_entry_ty,
6957 begin_mapper,
7058 end_mapper,
7159 mapper_fn_ty,
7260 ident_t_global,
73- register_lib,
74- unregister_lib,
7561 init_rtls,
7662 }
7763 }
7864}
7965
66+ // We need to register offload before using it. We also should unregister it once we are done, for
67+ // good measures. Previously we have done so before and after each individual offload intrinsic
68+ // call, but that comes at a performance cost. The repeated (un)register calls might also confuse
69+ // the LLVM ompOpt pass, which tries to move operations to a better location. The easiest solution,
70+ // which we copy from clang, is to just have those two calls once, in the global ctor/dtor section
71+ // of the final binary.
72+ pub ( crate ) fn register_offload < ' ll > ( cx : & CodegenCx < ' ll , ' _ > ) {
73+ // First we check quickly whether we already have done our setup, in which case we return early.
74+ // Shouldn't be needed for correctness.
75+ let register_lib_name = "__tgt_register_lib" ;
76+ if cx. get_function ( register_lib_name) . is_some ( ) {
77+ return ;
78+ }
79+
80+ let reg_lib_decl = cx. type_func ( & [ cx. type_ptr ( ) ] , cx. type_void ( ) ) ;
81+ let register_lib = declare_offload_fn ( & cx, register_lib_name, reg_lib_decl) ;
82+ let unregister_lib = declare_offload_fn ( & cx, "__tgt_unregister_lib" , reg_lib_decl) ;
83+
84+ let ptr_null = cx. const_null ( cx. type_ptr ( ) ) ;
85+ let const_struct = cx. const_struct ( & [ cx. get_const_i32 ( 0 ) , ptr_null, ptr_null, ptr_null] , false ) ;
86+ let omp_descriptor =
87+ add_global ( cx, ".omp_offloading.descriptor" , const_struct, InternalLinkage ) ;
88+ // @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 1, ptr @.omp_offloading.device_images, ptr @__start_llvm_offload_entries, ptr @__stop_llvm_offload_entries }
89+ // @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 0, ptr null, ptr null, ptr null }
90+
91+ let atexit = cx. type_func ( & [ cx. type_ptr ( ) ] , cx. type_i32 ( ) ) ;
92+ let atexit_fn = declare_offload_fn ( cx, "atexit" , atexit) ;
93+
94+ let desc_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
95+ let reg_name = ".omp_offloading.descriptor_reg" ;
96+ let unreg_name = ".omp_offloading.descriptor_unreg" ;
97+ let desc_reg_fn = declare_offload_fn ( cx, reg_name, desc_ty) ;
98+ let desc_unreg_fn = declare_offload_fn ( cx, unreg_name, desc_ty) ;
99+ llvm:: set_linkage ( desc_reg_fn, InternalLinkage ) ;
100+ llvm:: set_linkage ( desc_unreg_fn, InternalLinkage ) ;
101+ llvm:: set_section ( desc_reg_fn, c".text.startup" ) ;
102+ llvm:: set_section ( desc_unreg_fn, c".text.startup" ) ;
103+
104+ // define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
105+ // entry:
106+ // call void @__tgt_register_lib(ptr @.omp_offloading.descriptor)
107+ // %0 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg)
108+ // ret void
109+ // }
110+ let bb = Builder :: append_block ( cx, desc_reg_fn, "entry" ) ;
111+ let mut a = Builder :: build ( cx, bb) ;
112+ a. call ( reg_lib_decl, None , None , register_lib, & [ omp_descriptor] , None , None ) ;
113+ a. call ( atexit, None , None , atexit_fn, & [ desc_unreg_fn] , None , None ) ;
114+ a. ret_void ( ) ;
115+
116+ // define internal void @.omp_offloading.descriptor_unreg() section ".text.startup" {
117+ // entry:
118+ // call void @__tgt_unregister_lib(ptr @.omp_offloading.descriptor)
119+ // ret void
120+ // }
121+ let bb = Builder :: append_block ( cx, desc_unreg_fn, "entry" ) ;
122+ let mut a = Builder :: build ( cx, bb) ;
123+ a. call ( reg_lib_decl, None , None , unregister_lib, & [ omp_descriptor] , None , None ) ;
124+ a. ret_void ( ) ;
125+
126+ // @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 101, ptr @.omp_offloading.descriptor_reg, ptr null }]
127+ let args = vec ! [ cx. get_const_i32( 101 ) , desc_reg_fn, ptr_null] ;
128+ let const_struct = cx. const_struct ( & args, false ) ;
129+ let arr = cx. const_array ( cx. val_ty ( const_struct) , & [ const_struct] ) ;
130+ add_global ( cx, "llvm.global_ctors" , arr, AppendingLinkage ) ;
131+ }
132+
80133pub ( crate ) struct OffloadKernelDims < ' ll > {
81134 num_workgroups : & ' ll Value ,
82135 threads_per_block : & ' ll Value ,
@@ -487,9 +540,6 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
487540 let tgt_decl = offload_globals. launcher_fn ;
488541 let tgt_target_kernel_ty = offload_globals. launcher_ty ;
489542
490- // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
491- let tgt_bin_desc = offload_globals. bin_desc ;
492-
493543 let tgt_kernel_decl = offload_globals. kernel_args_ty ;
494544 let begin_mapper_decl = offload_globals. begin_mapper ;
495545 let end_mapper_decl = offload_globals. end_mapper ;
@@ -513,12 +563,9 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
513563 }
514564
515565 // Step 0)
516- // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
517- // %6 = alloca %struct.__tgt_bin_desc, align 8
518566 unsafe {
519567 llvm:: LLVMRustPositionBuilderPastAllocas ( & builder. llbuilder , builder. llfn ( ) ) ;
520568 }
521- let tgt_bin_desc_alloca = builder. direct_alloca ( tgt_bin_desc, Align :: EIGHT , "EmptyDesc" ) ;
522569
523570 let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
524571 // Baseptr are just the input pointer to the kernel, stored in a local alloca
@@ -536,7 +583,6 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
536583 unsafe {
537584 llvm:: LLVMPositionBuilderAtEnd ( & builder. llbuilder , bb) ;
538585 }
539- builder. memset ( tgt_bin_desc_alloca, cx. get_const_i8 ( 0 ) , cx. get_const_i64 ( 32 ) , Align :: EIGHT ) ;
540586
541587 // Now we allocate once per function param, a copy to be passed to one of our maps.
542588 let mut vals = vec ! [ ] ;
@@ -574,15 +620,9 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
574620 geps. push ( gep) ;
575621 }
576622
577- let mapper_fn_ty = cx. type_func ( & [ cx. type_ptr ( ) ] , cx. type_void ( ) ) ;
578- let register_lib_decl = offload_globals. register_lib ;
579- let unregister_lib_decl = offload_globals. unregister_lib ;
580623 let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
581624 let init_rtls_decl = offload_globals. init_rtls ;
582625
583- // FIXME(offload): Later we want to add them to the wrapper code, rather than our main function.
584- // call void @__tgt_register_lib(ptr noundef %6)
585- builder. call ( mapper_fn_ty, None , None , register_lib_decl, & [ tgt_bin_desc_alloca] , None , None ) ;
586626 // call void @__tgt_init_all_rtls()
587627 builder. call ( init_ty, None , None , init_rtls_decl, & [ ] , None , None ) ;
588628
@@ -679,6 +719,4 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
679719 num_args,
680720 s_ident_t,
681721 ) ;
682-
683- builder. call ( mapper_fn_ty, None , None , unregister_lib_decl, & [ tgt_bin_desc_alloca] , None , None ) ;
684722}
0 commit comments