[CIR][CUDA] Generate device stubs #1332

Merged: 1 commit, Feb 12, 2025
171 changes: 171 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp
@@ -0,0 +1,171 @@
//===--- CIRGenCUDARuntime.cpp - Interface to CUDA Runtimes ----*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides an abstract class for CUDA CIR generation. Concrete
// subclasses of this implement code generation for specific CUDA
// runtime libraries.
//
//===----------------------------------------------------------------------===//

#include "CIRGenCUDARuntime.h"
#include "CIRGenFunction.h"
#include "clang/Basic/Cuda.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"

using namespace clang;
using namespace clang::CIRGen;

CIRGenCUDARuntime::~CIRGenCUDARuntime() {}

void CIRGenCUDARuntime::emitDeviceStubBodyLegacy(CIRGenFunction &cgf,
                                                 cir::FuncOp fn,
                                                 FunctionArgList &args) {
  llvm_unreachable("NYI");
}

void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
                                              cir::FuncOp fn,
                                              FunctionArgList &args) {
  if (cgm.getLangOpts().HIP)
    llvm_unreachable("NYI");

  // This requires arguments to be sent to kernels in a different way.
  if (cgm.getLangOpts().OffloadViaLLVM)
    llvm_unreachable("NYI");

  auto &builder = cgm.getBuilder();

  // For cudaLaunchKernel, we must add another layer of indirection
  // to arguments. For example, for function `add(int a, float b)`,
  // we need to pass it as `void *args[2] = { &a, &b }`.

  auto loc = fn.getLoc();
  auto voidPtrArrayTy =
      cir::ArrayType::get(&cgm.getMLIRContext(), cgm.VoidPtrTy, args.size());
  mlir::Value kernelArgs = builder.createAlloca(
      loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy,
      "kernel_args", CharUnits::fromQuantity(16));

  // Store arguments into kernelArgs.
  for (auto [i, arg] : llvm::enumerate(args)) {
    mlir::Value index =
        builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, i));
    mlir::Value storePos = builder.createPtrStride(loc, kernelArgs, index);
    builder.CIRBaseBuilderTy::createStore(
        loc, cgf.GetAddrOfLocalVar(arg).getPointer(), storePos);
  }

  // We retrieve the dim3 type by looking at the second argument of
  // cudaLaunchKernel, as is done in OG.
  TranslationUnitDecl *tuDecl = cgm.getASTContext().getTranslationUnitDecl();
  DeclContext *dc = TranslationUnitDecl::castToDeclContext(tuDecl);

  // The default stream is usually stream 0 (the legacy default stream).
  // For per-thread default stream, we need a different LaunchKernel function.
  if (cgm.getLangOpts().GPUDefaultStream ==
      LangOptions::GPUDefaultStreamKind::PerThread)
    llvm_unreachable("NYI");

  std::string launchAPI = "cudaLaunchKernel";
  const IdentifierInfo &launchII = cgm.getASTContext().Idents.get(launchAPI);
  FunctionDecl *launchFD = nullptr;
  for (auto *result : dc->lookup(&launchII)) {
    if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result))
      launchFD = fd;
  }

  if (launchFD == nullptr) {
    cgm.Error(cgf.CurFuncDecl->getLocation(),
              "Can't find declaration for " + launchAPI);
    return;
  }

  // Use this function to retrieve arguments for cudaLaunchKernel:
  //   int __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim,
  //                                  size_t *sharedMem, cudaStream_t *stream)
  //
  // Here cudaStream_t, which is also the 6th argument of cudaLaunchKernel,
  // is a pointer to an opaque struct.

  mlir::Type dim3Ty =
      cgf.getTypes().convertType(launchFD->getParamDecl(1)->getType());
  mlir::Type streamTy =
      cgf.getTypes().convertType(launchFD->getParamDecl(5)->getType());

  mlir::Value gridDim =
      builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
                           "grid_dim", CharUnits::fromQuantity(8));
  mlir::Value blockDim =
      builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
                           "block_dim", CharUnits::fromQuantity(8));
  mlir::Value sharedMem =
      builder.createAlloca(loc, cir::PointerType::get(cgm.SizeTy), cgm.SizeTy,
                           "shared_mem", cgm.getSizeAlign());
  mlir::Value stream =
      builder.createAlloca(loc, cir::PointerType::get(streamTy), streamTy,
                           "stream", cgm.getPointerAlign());

  cir::FuncOp popConfig = cgm.createRuntimeFunction(
      cir::FuncType::get({gridDim.getType(), blockDim.getType(),
                          sharedMem.getType(), stream.getType()},
                         cgm.SInt32Ty),
      "__cudaPopCallConfiguration");
  cgf.emitRuntimeCall(loc, popConfig, {gridDim, blockDim, sharedMem, stream});

  // Now emit the call to cudaLaunchKernel:
  //   cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim,
  //                                dim3 blockDim, void **args,
  //                                size_t sharedMem, cudaStream_t stream);
  auto kernelTy =
      cir::PointerType::get(&cgm.getMLIRContext(), fn.getFunctionType());

  mlir::Value kernel =
      builder.create<cir::GetGlobalOp>(loc, kernelTy, fn.getSymName());
  mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
  CallArgList launchArgs;

  mlir::Value kernelArgsDecayed =
      builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
                         cir::PointerType::get(cgm.VoidPtrTy));

  launchArgs.add(RValue::get(func), launchFD->getParamDecl(0)->getType());
  launchArgs.add(
      RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))),
      launchFD->getParamDecl(1)->getType());
  launchArgs.add(
      RValue::getAggregate(Address(blockDim, CharUnits::fromQuantity(8))),
      launchFD->getParamDecl(2)->getType());
  launchArgs.add(RValue::get(kernelArgsDecayed),
                 launchFD->getParamDecl(3)->getType());
  launchArgs.add(
      RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
      launchFD->getParamDecl(4)->getType());
  launchArgs.add(RValue::get(stream), launchFD->getParamDecl(5)->getType());

  mlir::Type launchTy = cgm.getTypes().convertType(launchFD->getType());
  mlir::Operation *launchFn =
      cgm.createRuntimeFunction(cast<cir::FuncType>(launchTy), launchAPI);
  const auto &callInfo = cgm.getTypes().arrangeFunctionDeclaration(launchFD);
  cgf.emitCall(callInfo, CIRGenCallee::forDirect(launchFn), ReturnValueSlot(),
               launchArgs);
}

void CIRGenCUDARuntime::emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
                                       FunctionArgList &args) {
  // Device stub and its handle might be different.
  if (cgm.getLangOpts().HIP)
    llvm_unreachable("NYI");

  // CUDA 9.0 changed the way to launch kernels.
  if (CudaFeatureEnabled(cgm.getTarget().getSDKVersion(),
                         CudaFeature::CUDA_USES_NEW_LAUNCH) ||
      cgm.getLangOpts().OffloadViaLLVM)
    emitDeviceStubBodyNew(cgf, fn, args);
  else
    emitDeviceStubBodyLegacy(cgf, fn, args);
}
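As a rough mental model for reviewers, the stub emitted by `emitDeviceStubBodyNew` corresponds to the following host-side C++. This is a hedged sketch, not the actual CIR output: the `dim3`/`cudaStream_t` stand-ins and the stub body are reconstructed from the calls above, assuming the CUDA 9.0+ launch protocol.

```cpp
#include <cstddef>

// Minimal stand-ins for CUDA runtime types (real builds get these from
// cuda_runtime.h; the test below uses a similar Inputs/cuda.h).
struct dim3 { unsigned x = 1, y = 1, z = 1; };
using cudaStream_t = struct CUstream_st *;
using cudaError_t = int;

extern "C" int __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim,
                                          size_t *sharedMem, void *stream);
extern "C" cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim,
                                        dim3 blockDim, void **args,
                                        size_t sharedMem, cudaStream_t stream);

// What the generated stub for `__global__ void global_fn(int a)` boils
// down to.
void __device_stub__global_fn(int a) {
  void *kernel_args[1] = {&a}; // one extra level of indirection per argument
  dim3 grid_dim, block_dim;
  size_t shared_mem;
  cudaStream_t stream;
  // Retrieve the <<<grid, block, shmem, stream>>> configuration pushed by
  // the caller.
  __cudaPopCallConfiguration(&grid_dim, &block_dim, &shared_mem, &stream);
  // The stub passes its own address as the host-side kernel handle.
  cudaLaunchKernel((const void *)__device_stub__global_fn, grid_dim,
                   block_dim, kernel_args, shared_mem, stream);
}
```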
47 changes: 47 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h
@@ -0,0 +1,47 @@
//===------ CIRGenCUDARuntime.h - Interface to CUDA Runtimes -----*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides an abstract class for CUDA CIR generation. Concrete
// subclasses of this implement code generation for specific CUDA
// runtime libraries.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H
#define LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H

#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"

namespace clang::CIRGen {

class CIRGenFunction;
class CIRGenModule;
class FunctionArgList;

class CIRGenCUDARuntime {
protected:
  CIRGenModule &cgm;

private:
  void emitDeviceStubBodyLegacy(CIRGenFunction &cgf, cir::FuncOp fn,
                                FunctionArgList &args);
  void emitDeviceStubBodyNew(CIRGenFunction &cgf, cir::FuncOp fn,
                             FunctionArgList &args);

public:
  CIRGenCUDARuntime(CIRGenModule &cgm) : cgm(cgm) {}
  virtual ~CIRGenCUDARuntime();

  virtual void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
                              FunctionArgList &args);
};

} // namespace clang::CIRGen

#endif // LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H
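The class is deliberately virtual, mirroring OG CodeGen's `CGCUDARuntime`, so a variant runtime can substitute its own stub emission later. A hypothetical sketch of such a subclass (the name and body are illustrative only, not part of this PR):

```cpp
// Hypothetical subclass; note that HIP support is still
// llvm_unreachable("NYI") in the implementation above.
class CIRGenHIPRuntime : public CIRGenCUDARuntime {
public:
  CIRGenHIPRuntime(CIRGenModule &cgm) : CIRGenCUDARuntime(cgm) {}
  void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
                      FunctionArgList &args) override {
    // HIP-specific stub emission would go here.
  }
};
```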
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenFunction.cpp
@@ -753,7 +753,7 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl GD, cir::FuncOp Fn,
    emitConstructorBody(Args);
  else if (getLangOpts().CUDA && !getLangOpts().CUDAIsDevice &&
           FD->hasAttr<CUDAGlobalAttr>())
-    llvm_unreachable("NYI");
+    CGM.getCUDARuntime().emitDeviceStub(*this, Fn, Args);
  else if (isa<CXXMethodDecl>(FD) &&
           cast<CXXMethodDecl>(FD)->isLambdaStaticInvoker()) {
    // The lambda static invoker function is special, because it forwards or
4 changes: 3 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenModule.cpp
@@ -9,6 +9,7 @@
// This is the internal per-translation-unit state used for CIR translation.
//
//===----------------------------------------------------------------------===//
+#include "CIRGenCUDARuntime.h"
#include "CIRGenCXXABI.h"
#include "CIRGenCstEmitter.h"
#include "CIRGenFunction.h"
@@ -108,7 +109,8 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &mlirContext,
      theModule{mlir::ModuleOp::create(builder.getUnknownLoc())}, Diags(Diags),
      target(astContext.getTargetInfo()), ABI(createCXXABI(*this)),
      genTypes{*this}, VTables{*this},
-     openMPRuntime(new CIRGenOpenMPRuntime(*this)) {
+     openMPRuntime(new CIRGenOpenMPRuntime(*this)),
+     cudaRuntime(new CIRGenCUDARuntime(*this)) {

  // Initialize CIR signed integer types cache.
  SInt8Ty = cir::IntType::get(&getMLIRContext(), 8, /*isSigned=*/true);
12 changes: 11 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenModule.h
@@ -15,6 +15,7 @@

#include "Address.h"
#include "CIRGenBuilder.h"
+#include "CIRGenCUDARuntime.h"
#include "CIRGenCall.h"
#include "CIRGenOpenCLRuntime.h"
#include "CIRGenTBAA.h"
@@ -113,6 +114,9 @@ class CIRGenModule : public CIRGenTypeCache {
  /// Holds the OpenMP runtime
  std::unique_ptr<CIRGenOpenMPRuntime> openMPRuntime;

+  /// Holds the CUDA runtime
+  std::unique_ptr<CIRGenCUDARuntime> cudaRuntime;
+
  /// Per-function codegen information. Updated every time emitCIR is called
  /// for FunctionDecls.
  CIRGenFunction *CurCGF = nullptr;
@@ -862,12 +866,18 @@ class CIRGenModule : public CIRGenTypeCache {
  /// Print out an error that codegen doesn't support the specified decl yet.
  void ErrorUnsupported(const Decl *D, const char *Type);

-  /// Return a reference to the configured OpenMP runtime.
+  /// Return a reference to the configured OpenCL runtime.
  CIRGenOpenCLRuntime &getOpenCLRuntime() {
    assert(openCLRuntime != nullptr);
    return *openCLRuntime;
  }

+  /// Return a reference to the configured CUDA runtime.
+  CIRGenCUDARuntime &getCUDARuntime() {
+    assert(cudaRuntime != nullptr);
+    return *cudaRuntime;
+  }
+
  void createOpenCLRuntime() {
    openCLRuntime.reset(new CIRGenOpenCLRuntime(*this));
  }
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/CMakeLists.txt
@@ -19,6 +19,7 @@ add_clang_library(clangCIR
  CIRGenClass.cpp
  CIRGenCleanup.cpp
  CIRGenCoroutine.cpp
+  CIRGenCUDARuntime.cpp
  CIRGenDecl.cpp
  CIRGenDeclCXX.cpp
  CIRGenException.cpp
26 changes: 14 additions & 12 deletions clang/test/CIR/CodeGen/CUDA/simple.cu
@@ -1,15 +1,17 @@
#include "../Inputs/cuda.h"

// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
-// RUN:   -x cuda -emit-cir %s -o %t.cir
+// RUN:   -x cuda -emit-cir -target-sdk-version=12.3 \
+// RUN:   %s -o %t.cir
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s

// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
-// RUN:   -fcuda-is-device -emit-cir %s -o %t.cir
+// RUN:   -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
+// RUN:   %s -o %t.cir
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s

// Attribute for global_fn
-// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fnv>{{.*}}
+// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}

__host__ void host_fn(int *a, int *b, int *c) {}
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
@@ -19,13 +21,13 @@
__device__ void device_fn(int* a, double b, float c) {}
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
// CIR-DEVICE: cir.func @_Z9device_fnPidf

-#ifdef __CUDA_ARCH__
-__global__ void global_fn() {}
-#else
-__global__ void global_fn();
-#endif
-// CIR-HOST: @_Z24__device_stub__global_fnv(){{.*}}extra([[Kernel]])
-// CIR-DEVICE: @_Z9global_fnv
+__global__ void global_fn(int a) {}
+// CIR-DEVICE: @_Z9global_fni

-// Make sure `global_fn` indeed gets emitted
-__host__ void x() { auto v = global_fn; }
+// Check for device stub emission.
+
+// CIR-HOST: @_Z24__device_stub__global_fni{{.*}}extra([[Kernel]])
+// CIR-HOST: cir.alloca {{.*}}"kernel_args"
+// CIR-HOST: cir.call @__cudaPopCallConfiguration
+// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
+// CIR-HOST: cir.call @cudaLaunchKernel
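For context, a hedged sketch of the host code this test exercises: a triple-chevron launch expands into a push of the launch configuration followed by a call to the generated stub, which pops it again. The `__cudaPushCallConfiguration` name is the usual runtime counterpart, not something emitted by this PR.

```cpp
__global__ void global_fn(int a);

void launch() {
  // Conceptually: __cudaPushCallConfiguration({1,1,1}, {32,1,1}, 0, nullptr);
  //               __device_stub__global_fn(42); // pops the config, then
  //                                             // calls cudaLaunchKernel
  global_fn<<<1, 32>>>(42);
}
```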