diff --git a/clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp b/clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp index dba29d8cab83..112d268a64d5 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp @@ -75,6 +75,12 @@ class CIRGenNVCUDARuntime : public CIRGenCUDARuntime { mlir::Operation *getKernelHandle(cir::FuncOp fn, GlobalDecl GD) override; + mlir::Operation *getKernelStub(mlir::Operation *handle) override { + auto loc = KernelStubs.find(handle); + assert(loc != KernelStubs.end()); + return loc->second; + } + void internalizeDeviceSideVar(const VarDecl *d, cir::GlobalLinkageKind &linkage) override; /// Returns function or variable name on device side even if the current diff --git a/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h index 0694a9a95d6f..a7c99b75cb36 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h +++ b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h @@ -45,6 +45,9 @@ class CIRGenCUDARuntime { const CUDAKernelCallExpr *expr, ReturnValueSlot retValue); virtual mlir::Operation *getKernelHandle(cir::FuncOp fn, GlobalDecl GD) = 0; + /// Get kernel stub by kernel handle. + virtual mlir::Operation *getKernelStub(mlir::Operation *handle) = 0; + virtual void internalizeDeviceSideVar(const VarDecl *d, cir::GlobalLinkageKind &linkage) = 0; /// Returns function or variable name on device side even if the current diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index 449c7625b267..2241fd9bb573 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -533,9 +533,11 @@ static CIRGenCallee emitDirectCallee(CIRGenModule &CGM, GlobalDecl GD) { return CIRGenCallee::forBuiltin(builtinID, FD); } - auto CalleePtr = emitFunctionDeclPointer(CGM, GD); + mlir::Operation *CalleePtr = emitFunctionDeclPointer(CGM, GD); - assert(!CGM.getLangOpts().CUDA && "NYI"); + if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice && + FD->hasAttr()) + CalleePtr = CGM.getCUDARuntime().getKernelStub(CalleePtr); return CIRGenCallee::forDirect(CalleePtr, GD); } diff --git a/clang/test/CIR/CodeGen/CUDA/cuda-builtin-vars.cu b/clang/test/CIR/CodeGen/CUDA/cuda-builtin-vars.cu index fcdb6c0cb30b..1f4042f01bda 100644 --- a/clang/test/CIR/CodeGen/CUDA/cuda-builtin-vars.cu +++ b/clang/test/CIR/CodeGen/CUDA/cuda-builtin-vars.cu @@ -6,84 +6,102 @@ // RUN: -fcuda-is-device -emit-cir -o - %s \ // RUN: | FileCheck --check-prefix=CIR %s +// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda \ +// RUN: -fcuda-is-device -emit-llvm -o - %s \ +// RUN: | FileCheck --check-prefix=OGCG %s + #include "__clang_cuda_builtin_vars.h" // LLVM: define{{.*}} void @_Z6kernelPi(ptr %0) -// CIR-LABEL: @_Z6kernelPi +// OGCG: define{{.*}} void @_Z6kernelPi(ptr noundef %out) __attribute__((global)) void kernel(int *out) { int i = 0; - // out[i++] = threadIdx.x; - // CIR-DISABLED: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_xEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.x" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.x() - - // out[i++] = threadIdx.y; - // CIR-DISABLED: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_yEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.y" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.y() - - // out[i++] = threadIdx.z; - // CIR-DISABLED: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_zEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.z" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.z() - - - // out[i++] = blockIdx.x; - // CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_xEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.x" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() - - // out[i++] = blockIdx.y; - // CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_yEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.y" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() - - // out[i++] = blockIdx.z; - // CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_zEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.z" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() - - - // out[i++] = blockDim.x; - // CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_xEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.x" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.x() - - // out[i++] = blockDim.y; - // CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_yEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.y" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.y() - - // out[i++] = blockDim.z; - // CIR-DISABLED: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_zEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.z" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.z() - - - // out[i++] = gridDim.x; - // CIR-DISABLED: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_xEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.x" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.x() - - // out[i++] = gridDim.y; - // CIR-DISABLED: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_yEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.y" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.y() - - // out[i++] = gridDim.z; - // CIR-DISABLED: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_zEv() - // CIR-DISABLED: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.z" - // LLVM-DISABLED: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.z() - - - // out[i++] = warpSize; - // CIR-DISABLED: [[REGISTER:%.*]] = cir.const #cir.int<32> - // CIR-DISABLED: cir.store{{.*}} [[REGISTER]] - // LLVM-DISABLED: store i32 32, - - - // CIR-DISABLED: cir.return loc - // LLVM-DISABLED: ret void + out[i++] = threadIdx.x; + // CIR: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_xEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.x" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.x() + + out[i++] = threadIdx.y; + // CIR: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_yEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.y" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.y() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.y() + + out[i++] = threadIdx.z; + // CIR: cir.func linkonce_odr @_ZN26__cuda_builtin_threadIdx_t17__fetch_builtin_zEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.tid.z" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.z() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.z() + + + out[i++] = blockIdx.x; + // CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_xEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.x" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() + + out[i++] = blockIdx.y; + // CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_yEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.y" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() + + out[i++] = blockIdx.z; + // CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockIdx_t17__fetch_builtin_zEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ctaid.z" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() + + + out[i++] = blockDim.x; + // CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_xEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.x" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + + out[i++] = blockDim.y; + // CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_yEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.y" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.y() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.y() + + out[i++] = blockDim.z; + // CIR: cir.func linkonce_odr @_ZN25__cuda_builtin_blockDim_t17__fetch_builtin_zEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.ntid.z" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.z() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.z() + + + out[i++] = gridDim.x; + // CIR: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_xEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.x" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.x() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.x() + + out[i++] = gridDim.y; + // CIR: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_yEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.y" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.y() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.y() + + out[i++] = gridDim.z; + // CIR: cir.func linkonce_odr @_ZN24__cuda_builtin_gridDim_t17__fetch_builtin_zEv() + // CIR: cir.llvm.intrinsic "nvvm.read.ptx.sreg.nctaid.z" + // LLVM: call{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.z() + // OGCG: call noundef{{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.z() + + + out[i++] = warpSize; + // CIR: [[REGISTER:%.*]] = cir.const #cir.int<32> + // CIR: cir.store{{.*}} [[REGISTER]] + // LLVM: store i32 32, + // OGCG: store i32 32, + + + // CIR: cir.return loc + // LLVM: ret void + // OGCG: ret void }