Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HLSL] Add Increment/DecrementCounter methods to structured buffers #114148

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

hekota
Copy link
Member

@hekota hekota commented Oct 29, 2024

Introduces __builtin_hlsl_buffer_update_counter clang buildin that is used to implement the IncrementCounter and DecrementCounter methods on RWStructuredBuffer and RasterizerOrderedStructuredBuffer (see Note).

The builtin is translated to LLVM intrisic llvm.dx.bufferUpdateCounter or llvm.spv.bufferUpdateCounter.

Introduces BuiltinTypeMethodBuilder helper in HLSLExternalSemaSource that enables adding methods to builtin types using builder pattern like this:

   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();

Note: RasterizerOrderedStructuredBuffer does not exist yet, it is being added in PR #113648. After #113648 is merged this PR will be updated to add Increment/DecrementCounter on this buffer type as well.

Fixes #113513

Introduces `__builtin_hlsl_buffer_update_counter` clang buildin that is used to implement IncrementCounter and DecrementCounter methods on RWStructuredBuffer and RasterizerOrderedStructuredBuffer. The builtin is translated to LLVM intrisics llvm.dx.bufferUpdateCounter/llvm.spv.bufferUpdateCounter.

Introduces `BuiltinTypeMethodBuilder` helper in `HLSLExternalSemaSource` that allows adding methods to builtin types
using the builder pattern like this:

   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();

Fixes llvm#113513
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen backend:DirectX HLSL HLSL Language Support llvm:ir labels Oct 29, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-directx

Author: Helena Kotas (hekota)

Changes

Introduces __builtin_hlsl_buffer_update_counter clang buildin that is used to implement the IncrementCounter and DecrementCounter methods on RWStructuredBuffer and RasterizerOrderedStructuredBuffer (see Note).

The builtin is translated to LLVM intrisic llvm.dx.bufferUpdateCounter or llvm.spv.bufferUpdateCounter.

Introduces BuiltinTypeMethodBuilder helper in HLSLExternalSemaSource that allows adding methods to builtin types using builder pattern like this:

   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();

Note: RasterizerOrderedStructuredBuffer does not exist yet, it is being added in PR llvm/llvm-project#113648. After llvm/llvm-project#113648 is merged this PR will be updated to add Increment/DecrementCounter on this buffer type as well.

Fixes #113513


Patch is 25.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114148.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6-1)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Sema/HLSLExternalSemaSource.cpp (+247-31)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+4)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+41)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl (+25)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-ps.hlsl (+29)
  • (added) clang/test/SemaHLSL/BuiltIns/buffer_update_counter-errors.hlsl (+22)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+3)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..72bc2d5e7df23e 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4846,7 +4846,6 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-
 def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_select"];
   let Attributes = [NoThrow, Const];
@@ -4871,6 +4870,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLBufferUpdateCounter : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_buffer_update_counter"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "uint32_t(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 8e4718008ece72..2aea6bb657578a 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -7256,6 +7256,8 @@ def err_typecheck_illegal_increment_decrement : Error<
   "cannot %select{decrement|increment}1 value of type %0">;
 def err_typecheck_expect_int : Error<
   "used type %0 where integer is required">;
+def err_typecheck_expect_hlsl_resource : Error<
+  "used type %0 where __hlsl_resource_t is required">;
 def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error<
   "arithmetic on a pointer to %select{an incomplete|sizeless}0 type %1">;
 def err_typecheck_pointer_arith_function_type : Error<
@@ -12485,6 +12487,8 @@ def warn_attr_min_eq_max:  Warning<
 
 def err_hlsl_attribute_number_arguments_insufficient_shader_model: Error<
   "attribute %0 with %1 arguments requires shader model %2 or greater">;
+def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
+  "argument %0 must be constant integer 1 or -1">;
 
 // Layout randomization diagnostics.
 def err_non_designated_init_used : Error<
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e2d03eff8ab4a0..71273de3400b17 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18959,6 +18959,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.radians");
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    Value *ResHandle = EmitScalarExpr(E->getArg(0));
+    Value *Offset = EmitScalarExpr(E->getArg(1));
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Offset->getType(),
+        CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(),
+        ArrayRef<Value *>{ResHandle, Offset}, nullptr);
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..aac93dfc373ed4 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -93,6 +93,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
 
   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(BufferUpdateCounter, bufferUpdateCounter)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index ce8564429b3802..24c3954b134c5f 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -22,12 +22,15 @@
 #include "clang/Sema/SemaHLSL.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Frontend/HLSL/HLSLResource.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #include <functional>
 
 using namespace clang;
 using namespace llvm::hlsl;
 
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name);
+
 namespace {
 
 struct TemplateParameterListBuilder;
@@ -121,12 +124,8 @@ struct BuiltinTypeDeclBuilder {
     TypeSourceInfo *ElementTypeInfo = nullptr;
 
     QualType ElemTy = Ctx.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     ElementTypeInfo = Ctx.getTrivialTypeSourceInfo(ElemTy, SourceLocation());
 
     // add handle member with resource type attributes
@@ -145,25 +144,6 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
-  static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
-                                            StringRef Name) {
-    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
-    DeclarationNameInfo NameInfo =
-        DeclarationNameInfo(DeclarationName(&II), SourceLocation());
-    LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
-    // AllowBuiltinCreation is false but LookupDirect will create
-    // the builtin when searching the global scope anyways...
-    S.LookupName(R, S.getCurScope());
-    // FIXME: If the builtin function was user-declared in global scope,
-    // this assert *will* fail. Should this call LookupBuiltin instead?
-    assert(R.isSingleResult() &&
-           "Since this is a builtin it should always resolve!");
-    auto *VD = cast<ValueDecl>(R.getFoundDecl());
-    QualType Ty = VD->getType();
-    return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
-                               VD, false, NameInfo, Ty, VK_PRValue);
-  }
-
   static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
     return IntegerLiteral::Create(
         AST,
@@ -211,12 +191,8 @@ struct BuiltinTypeDeclBuilder {
 
     ASTContext &AST = Record->getASTContext();
     QualType ElemTy = AST.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     QualType ReturnTy = ElemTy;
 
     FunctionProtoType::ExtProtoInfo ExtInfo;
@@ -282,6 +258,23 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
+  FieldDecl *getResourceHandleField() {
+    FieldDecl *FD = Fields["h"];
+    if (FD && FD->getType()->isHLSLAttributedResourceType())
+      return FD;
+    return nullptr;
+  }
+
+  QualType getFirstTemplateTypeParam() {
+    if (Template) {
+      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
+              Template->getTemplateParameters()->getParam(0))) {
+        return QualType(TTD->getTypeForDecl(), 0);
+      }
+    }
+    return QualType();
+  }
+
   BuiltinTypeDeclBuilder &startDefinition() {
     if (Record->isCompleteDefinition())
       return *this;
@@ -302,6 +295,10 @@ struct BuiltinTypeDeclBuilder {
   TemplateParameterListBuilder addTemplateArgumentList(Sema &S);
   BuiltinTypeDeclBuilder &addSimpleTemplateParams(Sema &S,
                                                   ArrayRef<StringRef> Names);
+
+  // Builtin types methods
+  BuiltinTypeDeclBuilder &addIncrementCounterMethod(Sema &S);
+  BuiltinTypeDeclBuilder &addDecrementCounterMethod(Sema &S);
 };
 
 struct TemplateParameterListBuilder {
@@ -359,6 +356,176 @@ struct TemplateParameterListBuilder {
     return Builder;
   }
 };
+
+// Builder for methods of builtin types. Allows adding methods to builtin types
+// using the builder pattern like this:
+//
+//   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
+//       .addParam("param_name", Type, InOutModifier)
+//       .callBuiltin("buildin_name", { BuiltinParams })
+//       .finalizeMethod();
+//
+// The builder needs to have all of the method parameters before it can create
+// a CXXMethodDecl. It collects them in addParam calls and when a first
+// method that builds the body is called it creates the CXXMethodDecl and
+// ParmVarDecls instances. These can then be referenced from the body building
+// methods. Destructor or an explicit call to finalizeMethod() will complete
+// the method definition.
+struct BuiltinTypeMethodBuilder {
+  struct MethodParam {
+    const IdentifierInfo &NameII;
+    QualType Ty;
+    HLSLParamModifierAttr::Spelling Modifier;
+    MethodParam(const IdentifierInfo &NameII, QualType Ty,
+                HLSLParamModifierAttr::Spelling Modifier)
+        : NameII(NameII), Ty(Ty), Modifier(Modifier) {}
+  };
+
+  BuiltinTypeDeclBuilder &DeclBuilder;
+  Sema &S;
+  DeclarationNameInfo NameInfo;
+  QualType ReturnTy;
+  CXXMethodDecl *Method;
+  llvm::SmallVector<MethodParam> Params;
+  llvm::SmallVector<Stmt *> StmtsList;
+
+public:
+  BuiltinTypeMethodBuilder(Sema &S, BuiltinTypeDeclBuilder &DB, StringRef Name,
+                           QualType ReturnTy)
+      : DeclBuilder(DB), S(S), ReturnTy(ReturnTy), Method(nullptr) {
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    NameInfo = DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  }
+
+  BuiltinTypeMethodBuilder &addParam(StringRef Name, QualType Ty,
+                                     HLSLParamModifierAttr::Spelling Modifier =
+                                         HLSLParamModifierAttr::Keyword_in) {
+    assert(Method == nullptr && "Cannot add param, method already created");
+
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    Params.emplace_back(II, Ty, Modifier);
+    return *this;
+  }
+
+private:
+  void createMethodDecl() {
+    assert(Method == nullptr && "Method already created");
+
+    // create method type
+    ASTContext &AST = S.getASTContext();
+    SmallVector<QualType> ParamTypes;
+    for (auto &MP : Params)
+      ParamTypes.emplace_back(MP.Ty);
+    QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes,
+                                            FunctionProtoType::ExtProtoInfo());
+
+    // create method decl
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    Method =
+        CXXMethodDecl::Create(AST, DeclBuilder.Record, SourceLocation(),
+                              NameInfo, MethodTy, TSInfo, SC_None, false, false,
+                              ConstexprSpecKind::Unspecified, SourceLocation());
+
+    // create params & set them to the function prototype
+    SmallVector<ParmVarDecl *> ParmDecls;
+    auto FnProtoLoc =
+        Method->getTypeSourceInfo()->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    unsigned i = 0;
+    for (auto &MP : Params) {
+      ParmVarDecl *Parm = ParmVarDecl::Create(
+          AST, Method->getDeclContext(), SourceLocation(), SourceLocation(),
+          &MP.NameII, MP.Ty,
+          AST.getTrivialTypeSourceInfo(MP.Ty, SourceLocation()), SC_None,
+          nullptr);
+      if (MP.Modifier != HLSLParamModifierAttr::Keyword_in) {
+        auto *Mod =
+            HLSLParamModifierAttr::Create(AST, SourceRange(), MP.Modifier);
+        Parm->addAttr(Mod);
+      }
+      ParmDecls.push_back(Parm);
+      FnProtoLoc.setParam(i++, Parm);
+    }
+    Method->setParams({ParmDecls});
+  }
+
+  void addResourceHandleToParms(SmallVector<Expr *> &Parms) {
+    ASTContext &AST = S.getASTContext();
+    FieldDecl *HandleField = DeclBuilder.getResourceHandleField();
+    auto *This = CXXThisExpr::Create(
+        AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
+    Parms.push_back(MemberExpr::CreateImplicit(AST, This, false, HandleField,
+                                               HandleField->getType(),
+                                               VK_LValue, OK_Ordinary));
+  }
+
+public:
+  ~BuiltinTypeMethodBuilder() { finalizeMethod(); }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltin(StringRef BuiltinName, ArrayRef<Expr *> CallParms,
+              bool AddResourceHandleAsFirstArg = true) {
+    if (!Method)
+      createMethodDecl();
+
+    ASTContext &AST = S.getASTContext();
+    DeclRefExpr *Fn = lookupBuiltinFunction(S, BuiltinName);
+    Expr *Call = nullptr;
+
+    if (AddResourceHandleAsFirstArg) {
+      SmallVector<Expr *> NewCallParms;
+      addResourceHandleToParms(NewCallParms);
+      for (auto *P : CallParms)
+        NewCallParms.push_back(P);
+
+      Call = CallExpr::Create(AST, Fn, NewCallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    } else {
+      Call = CallExpr::Create(AST, Fn, CallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    }
+    StmtsList.push_back(Call);
+    return *this;
+  }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltinForwardArgs(StringRef BuiltinName,
+                         bool AddResourceHandleAsFirstArg = true) {
+    // FIXME: Call the buildin with all of the method parameters
+    // plus optional resource handle as the first arg.
+    llvm_unreachable("not yet implemented");
+  }
+
+  BuiltinTypeDeclBuilder &finalizeMethod() {
+    if (DeclBuilder.Record->isCompleteDefinition())
+      return DeclBuilder;
+
+    if (!Method)
+      createMethodDecl();
+
+    if (!Method->hasBody()) {
+      ASTContext &AST = S.getASTContext();
+      if (ReturnTy != AST.VoidTy && !StmtsList.empty()) {
+        if (Expr *LastExpr = dyn_cast<Expr>(StmtsList.back())) {
+          StmtsList.pop_back();
+          StmtsList.push_back(
+              ReturnStmt::Create(AST, SourceLocation(), LastExpr, nullptr));
+        }
+      }
+
+      Method->setBody(CompoundStmt::Create(AST, StmtsList, FPOptionsOverride(),
+                                           SourceLocation(), SourceLocation()));
+      Method->setLexicalDeclContext(DeclBuilder.Record);
+      Method->setAccess(AccessSpecifier::AS_public);
+      Method->addAttr(AlwaysInlineAttr::CreateImplicit(
+          AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
+      DeclBuilder.Record->addDecl(Method);
+    }
+    return DeclBuilder;
+  }
+};
+
 } // namespace
 
 TemplateParameterListBuilder
@@ -375,6 +542,30 @@ BuiltinTypeDeclBuilder::addSimpleTemplateParams(Sema &S,
   return Builder.finalizeTemplateArgs();
 }
 
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addIncrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *One =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), 1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "IncrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {One})
+      .finalizeMethod();
+}
+
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addDecrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *NegOne =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), -1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "DecrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {NegOne})
+      .finalizeMethod();
+}
+
 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
 
 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
@@ -528,8 +719,13 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
                     ResourceKind::TypedBuffer, /*IsROV=*/false,
                     /*RawBuffer=*/true)
         .addArraySubscriptOperators()
+        .addIncrementCounterMethod(*SemaPtr)
+        .addDecrementCounterMethod(*SemaPtr)
         .completeDefinition();
   });
+
+  // FIXME: Also add Increment/DecrementCounter to
+  // RasterizerOrderedStructuredBuffer when llvm/llvm-project/#113648 is merged.
 }
 
 void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
@@ -552,3 +748,23 @@ void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
     return;
   It->second(Record);
 }
+
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name) {
+  IdentifierInfo &II =
+      S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+  DeclarationNameInfo NameInfo =
+      DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
+  // AllowBuiltinCreation is false but LookupDirect will create
+  // the builtin when searching the global scope anyways...
+  S.LookupName(R, S.getCurScope());
+  // FIXME: If the builtin function was user-declared in global scope,
+  // this assert *will* fail. Should this call LookupBuiltin instead?
+  assert(R.isSingleResult() &&
+         "Since this is a builtin it should always resolve!");
+  auto *VD = cast<ValueDecl>(R.getFoundDecl());
+  QualType Ty = VD->getType();
+  return DeclRefExpr::Create(S.getASTContext(), NestedNameSpecifierLoc(),
+                             SourceLocation(), VD, false, NameInfo, Ty,
+                             VK_PRValue);
+}
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff6616901016ab..770bd4a81633e1 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -986,6 +986,10 @@ Sema::VarArgKind Sema::isValidVarArgType(const QualType &Ty) {
   if (getLangOpts().MSVCCompat)
     return VAK_MSVCUndefined;
 
+  if (getLangOpts().HLSL &&
+      Ty->getUnqualifiedDesugaredType()->isHLSLAttributedResourceType())
+    return VAK_Valid;
+
   // FIXME: In C++11, these cases are conditionally-supported, meaning we're
   // permitted to reject them. We should consider doing so.
   return VAK_Undefined;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1f6c5b8d4561bc..1b7f0456a3e82a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1860,6 +1860,31 @@ static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType.getTypePtr()
+           ->getUnqualifiedDesugaredType()
+           ->isHLSLAttributedResourceType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(),
+            diag::err_typecheck_expect_hlsl_resource)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
+static bool CheckInt(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType->isIntegerType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(), diag::err_typecheck_expect_int)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
 // Note: returning true in this case results in CheckBuiltinFunctionCall
 // returning an ExprError
 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -2100,6 +2125,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    if (SemaRef.checkArgCount(TheCall, 2) ||
+        CheckResourceHandle(&SemaRef, TheCall, 0) ||
+        CheckInt(&SemaRef, TheCall, 1))
+      return true;
+    Expr *OffsetExpr = TheCall->getArg(1);
+    std::optional<llvm::APSInt> Offset =
+        OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
+    if (!Offset.has_value() || abs(Offset->getExtValue()) != 1) {
+      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+                   diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
+          << 1;
+      return true;
+    }
+    break;
+  }
   }
   return false;
 }
diff --git a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
new file mode 100644
index 00000000000000..c8ff5d3cd905fb
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN-DISABLED...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-clang-codegen

Author: Helena Kotas (hekota)

Changes

Introduces __builtin_hlsl_buffer_update_counter clang buildin that is used to implement the IncrementCounter and DecrementCounter methods on RWStructuredBuffer and RasterizerOrderedStructuredBuffer (see Note).

The builtin is translated to LLVM intrisic llvm.dx.bufferUpdateCounter or llvm.spv.bufferUpdateCounter.

Introduces BuiltinTypeMethodBuilder helper in HLSLExternalSemaSource that allows adding methods to builtin types using builder pattern like this:

   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();

Note: RasterizerOrderedStructuredBuffer does not exist yet, it is being added in PR llvm/llvm-project#113648. After llvm/llvm-project#113648 is merged this PR will be updated to add Increment/DecrementCounter on this buffer type as well.

Fixes #113513


Patch is 25.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114148.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6-1)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Sema/HLSLExternalSemaSource.cpp (+247-31)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+4)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+41)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl (+25)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-ps.hlsl (+29)
  • (added) clang/test/SemaHLSL/BuiltIns/buffer_update_counter-errors.hlsl (+22)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+3)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..72bc2d5e7df23e 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4846,7 +4846,6 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-
 def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_select"];
   let Attributes = [NoThrow, Const];
@@ -4871,6 +4870,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLBufferUpdateCounter : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_buffer_update_counter"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "uint32_t(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 8e4718008ece72..2aea6bb657578a 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -7256,6 +7256,8 @@ def err_typecheck_illegal_increment_decrement : Error<
   "cannot %select{decrement|increment}1 value of type %0">;
 def err_typecheck_expect_int : Error<
   "used type %0 where integer is required">;
+def err_typecheck_expect_hlsl_resource : Error<
+  "used type %0 where __hlsl_resource_t is required">;
 def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error<
   "arithmetic on a pointer to %select{an incomplete|sizeless}0 type %1">;
 def err_typecheck_pointer_arith_function_type : Error<
@@ -12485,6 +12487,8 @@ def warn_attr_min_eq_max:  Warning<
 
 def err_hlsl_attribute_number_arguments_insufficient_shader_model: Error<
   "attribute %0 with %1 arguments requires shader model %2 or greater">;
+def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
+  "argument %0 must be constant integer 1 or -1">;
 
 // Layout randomization diagnostics.
 def err_non_designated_init_used : Error<
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e2d03eff8ab4a0..71273de3400b17 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18959,6 +18959,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.radians");
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    Value *ResHandle = EmitScalarExpr(E->getArg(0));
+    Value *Offset = EmitScalarExpr(E->getArg(1));
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Offset->getType(),
+        CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(),
+        ArrayRef<Value *>{ResHandle, Offset}, nullptr);
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..aac93dfc373ed4 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -93,6 +93,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
 
   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(BufferUpdateCounter, bufferUpdateCounter)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index ce8564429b3802..24c3954b134c5f 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -22,12 +22,15 @@
 #include "clang/Sema/SemaHLSL.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Frontend/HLSL/HLSLResource.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #include <functional>
 
 using namespace clang;
 using namespace llvm::hlsl;
 
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name);
+
 namespace {
 
 struct TemplateParameterListBuilder;
@@ -121,12 +124,8 @@ struct BuiltinTypeDeclBuilder {
     TypeSourceInfo *ElementTypeInfo = nullptr;
 
     QualType ElemTy = Ctx.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     ElementTypeInfo = Ctx.getTrivialTypeSourceInfo(ElemTy, SourceLocation());
 
     // add handle member with resource type attributes
@@ -145,25 +144,6 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
-  static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
-                                            StringRef Name) {
-    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
-    DeclarationNameInfo NameInfo =
-        DeclarationNameInfo(DeclarationName(&II), SourceLocation());
-    LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
-    // AllowBuiltinCreation is false but LookupDirect will create
-    // the builtin when searching the global scope anyways...
-    S.LookupName(R, S.getCurScope());
-    // FIXME: If the builtin function was user-declared in global scope,
-    // this assert *will* fail. Should this call LookupBuiltin instead?
-    assert(R.isSingleResult() &&
-           "Since this is a builtin it should always resolve!");
-    auto *VD = cast<ValueDecl>(R.getFoundDecl());
-    QualType Ty = VD->getType();
-    return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
-                               VD, false, NameInfo, Ty, VK_PRValue);
-  }
-
   static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
     return IntegerLiteral::Create(
         AST,
@@ -211,12 +191,8 @@ struct BuiltinTypeDeclBuilder {
 
     ASTContext &AST = Record->getASTContext();
     QualType ElemTy = AST.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     QualType ReturnTy = ElemTy;
 
     FunctionProtoType::ExtProtoInfo ExtInfo;
@@ -282,6 +258,23 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
+  FieldDecl *getResourceHandleField() {
+    FieldDecl *FD = Fields["h"];
+    if (FD && FD->getType()->isHLSLAttributedResourceType())
+      return FD;
+    return nullptr;
+  }
+
+  QualType getFirstTemplateTypeParam() {
+    if (Template) {
+      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
+              Template->getTemplateParameters()->getParam(0))) {
+        return QualType(TTD->getTypeForDecl(), 0);
+      }
+    }
+    return QualType();
+  }
+
   BuiltinTypeDeclBuilder &startDefinition() {
     if (Record->isCompleteDefinition())
       return *this;
@@ -302,6 +295,10 @@ struct BuiltinTypeDeclBuilder {
   TemplateParameterListBuilder addTemplateArgumentList(Sema &S);
   BuiltinTypeDeclBuilder &addSimpleTemplateParams(Sema &S,
                                                   ArrayRef<StringRef> Names);
+
+  // Builtin types methods
+  BuiltinTypeDeclBuilder &addIncrementCounterMethod(Sema &S);
+  BuiltinTypeDeclBuilder &addDecrementCounterMethod(Sema &S);
 };
 
 struct TemplateParameterListBuilder {
@@ -359,6 +356,176 @@ struct TemplateParameterListBuilder {
     return Builder;
   }
 };
+
+// Builder for methods of builtin types. Allows adding methods to builtin types
+// using the builder pattern like this:
+//
+//   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
+//       .addParam("param_name", Type, InOutModifier)
+//       .callBuiltin("buildin_name", { BuiltinParams })
+//       .finalizeMethod();
+//
+// The builder needs to have all of the method parameters before it can create
+// a CXXMethodDecl. It collects them in addParam calls and when a first
+// method that builds the body is called it creates the CXXMethodDecl and
+// ParmVarDecls instances. These can then be referenced from the body building
+// methods. Destructor or an explicit call to finalizeMethod() will complete
+// the method definition.
+struct BuiltinTypeMethodBuilder {
+  struct MethodParam {
+    const IdentifierInfo &NameII;
+    QualType Ty;
+    HLSLParamModifierAttr::Spelling Modifier;
+    MethodParam(const IdentifierInfo &NameII, QualType Ty,
+                HLSLParamModifierAttr::Spelling Modifier)
+        : NameII(NameII), Ty(Ty), Modifier(Modifier) {}
+  };
+
+  BuiltinTypeDeclBuilder &DeclBuilder;
+  Sema &S;
+  DeclarationNameInfo NameInfo;
+  QualType ReturnTy;
+  CXXMethodDecl *Method;
+  llvm::SmallVector<MethodParam> Params;
+  llvm::SmallVector<Stmt *> StmtsList;
+
+public:
+  BuiltinTypeMethodBuilder(Sema &S, BuiltinTypeDeclBuilder &DB, StringRef Name,
+                           QualType ReturnTy)
+      : DeclBuilder(DB), S(S), ReturnTy(ReturnTy), Method(nullptr) {
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    NameInfo = DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  }
+
+  BuiltinTypeMethodBuilder &addParam(StringRef Name, QualType Ty,
+                                     HLSLParamModifierAttr::Spelling Modifier =
+                                         HLSLParamModifierAttr::Keyword_in) {
+    assert(Method == nullptr && "Cannot add param, method already created");
+
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    Params.emplace_back(II, Ty, Modifier);
+    return *this;
+  }
+
+private:
+  void createMethodDecl() {
+    assert(Method == nullptr && "Method already created");
+
+    // create method type
+    ASTContext &AST = S.getASTContext();
+    SmallVector<QualType> ParamTypes;
+    for (auto &MP : Params)
+      ParamTypes.emplace_back(MP.Ty);
+    QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes,
+                                            FunctionProtoType::ExtProtoInfo());
+
+    // create method decl
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    Method =
+        CXXMethodDecl::Create(AST, DeclBuilder.Record, SourceLocation(),
+                              NameInfo, MethodTy, TSInfo, SC_None, false, false,
+                              ConstexprSpecKind::Unspecified, SourceLocation());
+
+    // create params & set them to the function prototype
+    SmallVector<ParmVarDecl *> ParmDecls;
+    auto FnProtoLoc =
+        Method->getTypeSourceInfo()->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    unsigned i = 0;
+    for (auto &MP : Params) {
+      ParmVarDecl *Parm = ParmVarDecl::Create(
+          AST, Method->getDeclContext(), SourceLocation(), SourceLocation(),
+          &MP.NameII, MP.Ty,
+          AST.getTrivialTypeSourceInfo(MP.Ty, SourceLocation()), SC_None,
+          nullptr);
+      if (MP.Modifier != HLSLParamModifierAttr::Keyword_in) {
+        auto *Mod =
+            HLSLParamModifierAttr::Create(AST, SourceRange(), MP.Modifier);
+        Parm->addAttr(Mod);
+      }
+      ParmDecls.push_back(Parm);
+      FnProtoLoc.setParam(i++, Parm);
+    }
+    Method->setParams({ParmDecls});
+  }
+
+  void addResourceHandleToParms(SmallVector<Expr *> &Parms) {
+    ASTContext &AST = S.getASTContext();
+    FieldDecl *HandleField = DeclBuilder.getResourceHandleField();
+    auto *This = CXXThisExpr::Create(
+        AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
+    Parms.push_back(MemberExpr::CreateImplicit(AST, This, false, HandleField,
+                                               HandleField->getType(),
+                                               VK_LValue, OK_Ordinary));
+  }
+
+public:
+  ~BuiltinTypeMethodBuilder() { finalizeMethod(); }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltin(StringRef BuiltinName, ArrayRef<Expr *> CallParms,
+              bool AddResourceHandleAsFirstArg = true) {
+    if (!Method)
+      createMethodDecl();
+
+    ASTContext &AST = S.getASTContext();
+    DeclRefExpr *Fn = lookupBuiltinFunction(S, BuiltinName);
+    Expr *Call = nullptr;
+
+    if (AddResourceHandleAsFirstArg) {
+      SmallVector<Expr *> NewCallParms;
+      addResourceHandleToParms(NewCallParms);
+      for (auto *P : CallParms)
+        NewCallParms.push_back(P);
+
+      Call = CallExpr::Create(AST, Fn, NewCallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    } else {
+      Call = CallExpr::Create(AST, Fn, CallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    }
+    StmtsList.push_back(Call);
+    return *this;
+  }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltinForwardArgs(StringRef BuiltinName,
+                         bool AddResourceHandleAsFirstArg = true) {
+    // FIXME: Call the buildin with all of the method parameters
+    // plus optional resource handle as the first arg.
+    llvm_unreachable("not yet implemented");
+  }
+
+  BuiltinTypeDeclBuilder &finalizeMethod() {
+    if (DeclBuilder.Record->isCompleteDefinition())
+      return DeclBuilder;
+
+    if (!Method)
+      createMethodDecl();
+
+    if (!Method->hasBody()) {
+      ASTContext &AST = S.getASTContext();
+      if (ReturnTy != AST.VoidTy && !StmtsList.empty()) {
+        if (Expr *LastExpr = dyn_cast<Expr>(StmtsList.back())) {
+          StmtsList.pop_back();
+          StmtsList.push_back(
+              ReturnStmt::Create(AST, SourceLocation(), LastExpr, nullptr));
+        }
+      }
+
+      Method->setBody(CompoundStmt::Create(AST, StmtsList, FPOptionsOverride(),
+                                           SourceLocation(), SourceLocation()));
+      Method->setLexicalDeclContext(DeclBuilder.Record);
+      Method->setAccess(AccessSpecifier::AS_public);
+      Method->addAttr(AlwaysInlineAttr::CreateImplicit(
+          AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
+      DeclBuilder.Record->addDecl(Method);
+    }
+    return DeclBuilder;
+  }
+};
+
 } // namespace
 
 TemplateParameterListBuilder
@@ -375,6 +542,30 @@ BuiltinTypeDeclBuilder::addSimpleTemplateParams(Sema &S,
   return Builder.finalizeTemplateArgs();
 }
 
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addIncrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *One =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), 1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "IncrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {One})
+      .finalizeMethod();
+}
+
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addDecrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *NegOne =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), -1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "DecrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {NegOne})
+      .finalizeMethod();
+}
+
 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
 
 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
@@ -528,8 +719,13 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
                     ResourceKind::TypedBuffer, /*IsROV=*/false,
                     /*RawBuffer=*/true)
         .addArraySubscriptOperators()
+        .addIncrementCounterMethod(*SemaPtr)
+        .addDecrementCounterMethod(*SemaPtr)
         .completeDefinition();
   });
+
+  // FIXME: Also add Increment/DecrementCounter to
+  // RasterizerOrderedStructuredBuffer when llvm/llvm-project/#113648 is merged.
 }
 
 void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
@@ -552,3 +748,23 @@ void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
     return;
   It->second(Record);
 }
+
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name) {
+  IdentifierInfo &II =
+      S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+  DeclarationNameInfo NameInfo =
+      DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
+  // AllowBuiltinCreation is false but LookupDirect will create
+  // the builtin when searching the global scope anyways...
+  S.LookupName(R, S.getCurScope());
+  // FIXME: If the builtin function was user-declared in global scope,
+  // this assert *will* fail. Should this call LookupBuiltin instead?
+  assert(R.isSingleResult() &&
+         "Since this is a builtin it should always resolve!");
+  auto *VD = cast<ValueDecl>(R.getFoundDecl());
+  QualType Ty = VD->getType();
+  return DeclRefExpr::Create(S.getASTContext(), NestedNameSpecifierLoc(),
+                             SourceLocation(), VD, false, NameInfo, Ty,
+                             VK_PRValue);
+}
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff6616901016ab..770bd4a81633e1 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -986,6 +986,10 @@ Sema::VarArgKind Sema::isValidVarArgType(const QualType &Ty) {
   if (getLangOpts().MSVCCompat)
     return VAK_MSVCUndefined;
 
+  if (getLangOpts().HLSL &&
+      Ty->getUnqualifiedDesugaredType()->isHLSLAttributedResourceType())
+    return VAK_Valid;
+
   // FIXME: In C++11, these cases are conditionally-supported, meaning we're
   // permitted to reject them. We should consider doing so.
   return VAK_Undefined;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1f6c5b8d4561bc..1b7f0456a3e82a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1860,6 +1860,31 @@ static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType.getTypePtr()
+           ->getUnqualifiedDesugaredType()
+           ->isHLSLAttributedResourceType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(),
+            diag::err_typecheck_expect_hlsl_resource)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
+static bool CheckInt(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType->isIntegerType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(), diag::err_typecheck_expect_int)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
 // Note: returning true in this case results in CheckBuiltinFunctionCall
 // returning an ExprError
 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -2100,6 +2125,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    if (SemaRef.checkArgCount(TheCall, 2) ||
+        CheckResourceHandle(&SemaRef, TheCall, 0) ||
+        CheckInt(&SemaRef, TheCall, 1))
+      return true;
+    Expr *OffsetExpr = TheCall->getArg(1);
+    std::optional<llvm::APSInt> Offset =
+        OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
+    if (!Offset.has_value() || abs(Offset->getExtValue()) != 1) {
+      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+                   diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
+          << 1;
+      return true;
+    }
+    break;
+  }
   }
   return false;
 }
diff --git a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
new file mode 100644
index 00000000000000..c8ff5d3cd905fb
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN-DISABLED...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-llvm-ir

Author: Helena Kotas (hekota)

Changes

Introduces __builtin_hlsl_buffer_update_counter clang buildin that is used to implement the IncrementCounter and DecrementCounter methods on RWStructuredBuffer and RasterizerOrderedStructuredBuffer (see Note).

The builtin is translated to LLVM intrisic llvm.dx.bufferUpdateCounter or llvm.spv.bufferUpdateCounter.

Introduces BuiltinTypeMethodBuilder helper in HLSLExternalSemaSource that allows adding methods to builtin types using builder pattern like this:

   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();

Note: RasterizerOrderedStructuredBuffer does not exist yet, it is being added in PR llvm/llvm-project#113648. After llvm/llvm-project#113648 is merged this PR will be updated to add Increment/DecrementCounter on this buffer type as well.

Fixes #113513


Patch is 25.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114148.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6-1)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Sema/HLSLExternalSemaSource.cpp (+247-31)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+4)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+41)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl (+25)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-ps.hlsl (+29)
  • (added) clang/test/SemaHLSL/BuiltIns/buffer_update_counter-errors.hlsl (+22)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+3)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..72bc2d5e7df23e 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4846,7 +4846,6 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-
 def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_select"];
   let Attributes = [NoThrow, Const];
@@ -4871,6 +4870,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLBufferUpdateCounter : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_buffer_update_counter"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "uint32_t(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 8e4718008ece72..2aea6bb657578a 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -7256,6 +7256,8 @@ def err_typecheck_illegal_increment_decrement : Error<
   "cannot %select{decrement|increment}1 value of type %0">;
 def err_typecheck_expect_int : Error<
   "used type %0 where integer is required">;
+def err_typecheck_expect_hlsl_resource : Error<
+  "used type %0 where __hlsl_resource_t is required">;
 def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error<
   "arithmetic on a pointer to %select{an incomplete|sizeless}0 type %1">;
 def err_typecheck_pointer_arith_function_type : Error<
@@ -12485,6 +12487,8 @@ def warn_attr_min_eq_max:  Warning<
 
 def err_hlsl_attribute_number_arguments_insufficient_shader_model: Error<
   "attribute %0 with %1 arguments requires shader model %2 or greater">;
+def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
+  "argument %0 must be constant integer 1 or -1">;
 
 // Layout randomization diagnostics.
 def err_non_designated_init_used : Error<
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e2d03eff8ab4a0..71273de3400b17 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18959,6 +18959,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.radians");
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    Value *ResHandle = EmitScalarExpr(E->getArg(0));
+    Value *Offset = EmitScalarExpr(E->getArg(1));
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Offset->getType(),
+        CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(),
+        ArrayRef<Value *>{ResHandle, Offset}, nullptr);
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..aac93dfc373ed4 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -93,6 +93,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
 
   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(BufferUpdateCounter, bufferUpdateCounter)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index ce8564429b3802..24c3954b134c5f 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -22,12 +22,15 @@
 #include "clang/Sema/SemaHLSL.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Frontend/HLSL/HLSLResource.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #include <functional>
 
 using namespace clang;
 using namespace llvm::hlsl;
 
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name);
+
 namespace {
 
 struct TemplateParameterListBuilder;
@@ -121,12 +124,8 @@ struct BuiltinTypeDeclBuilder {
     TypeSourceInfo *ElementTypeInfo = nullptr;
 
     QualType ElemTy = Ctx.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     ElementTypeInfo = Ctx.getTrivialTypeSourceInfo(ElemTy, SourceLocation());
 
     // add handle member with resource type attributes
@@ -145,25 +144,6 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
-  static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
-                                            StringRef Name) {
-    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
-    DeclarationNameInfo NameInfo =
-        DeclarationNameInfo(DeclarationName(&II), SourceLocation());
-    LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
-    // AllowBuiltinCreation is false but LookupDirect will create
-    // the builtin when searching the global scope anyways...
-    S.LookupName(R, S.getCurScope());
-    // FIXME: If the builtin function was user-declared in global scope,
-    // this assert *will* fail. Should this call LookupBuiltin instead?
-    assert(R.isSingleResult() &&
-           "Since this is a builtin it should always resolve!");
-    auto *VD = cast<ValueDecl>(R.getFoundDecl());
-    QualType Ty = VD->getType();
-    return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
-                               VD, false, NameInfo, Ty, VK_PRValue);
-  }
-
   static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
     return IntegerLiteral::Create(
         AST,
@@ -211,12 +191,8 @@ struct BuiltinTypeDeclBuilder {
 
     ASTContext &AST = Record->getASTContext();
     QualType ElemTy = AST.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     QualType ReturnTy = ElemTy;
 
     FunctionProtoType::ExtProtoInfo ExtInfo;
@@ -282,6 +258,23 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
+  FieldDecl *getResourceHandleField() {
+    FieldDecl *FD = Fields["h"];
+    if (FD && FD->getType()->isHLSLAttributedResourceType())
+      return FD;
+    return nullptr;
+  }
+
+  QualType getFirstTemplateTypeParam() {
+    if (Template) {
+      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
+              Template->getTemplateParameters()->getParam(0))) {
+        return QualType(TTD->getTypeForDecl(), 0);
+      }
+    }
+    return QualType();
+  }
+
   BuiltinTypeDeclBuilder &startDefinition() {
     if (Record->isCompleteDefinition())
       return *this;
@@ -302,6 +295,10 @@ struct BuiltinTypeDeclBuilder {
   TemplateParameterListBuilder addTemplateArgumentList(Sema &S);
   BuiltinTypeDeclBuilder &addSimpleTemplateParams(Sema &S,
                                                   ArrayRef<StringRef> Names);
+
+  // Builtin types methods
+  BuiltinTypeDeclBuilder &addIncrementCounterMethod(Sema &S);
+  BuiltinTypeDeclBuilder &addDecrementCounterMethod(Sema &S);
 };
 
 struct TemplateParameterListBuilder {
@@ -359,6 +356,176 @@ struct TemplateParameterListBuilder {
     return Builder;
   }
 };
+
+// Builder for methods of builtin types. Allows adding methods to builtin types
+// using the builder pattern like this:
+//
+//   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
+//       .addParam("param_name", Type, InOutModifier)
+//       .callBuiltin("buildin_name", { BuiltinParams })
+//       .finalizeMethod();
+//
+// The builder needs to have all of the method parameters before it can create
+// a CXXMethodDecl. It collects them in addParam calls and when a first
+// method that builds the body is called it creates the CXXMethodDecl and
+// ParmVarDecls instances. These can then be referenced from the body building
+// methods. Destructor or an explicit call to finalizeMethod() will complete
+// the method definition.
+struct BuiltinTypeMethodBuilder {
+  struct MethodParam {
+    const IdentifierInfo &NameII;
+    QualType Ty;
+    HLSLParamModifierAttr::Spelling Modifier;
+    MethodParam(const IdentifierInfo &NameII, QualType Ty,
+                HLSLParamModifierAttr::Spelling Modifier)
+        : NameII(NameII), Ty(Ty), Modifier(Modifier) {}
+  };
+
+  BuiltinTypeDeclBuilder &DeclBuilder;
+  Sema &S;
+  DeclarationNameInfo NameInfo;
+  QualType ReturnTy;
+  CXXMethodDecl *Method;
+  llvm::SmallVector<MethodParam> Params;
+  llvm::SmallVector<Stmt *> StmtsList;
+
+public:
+  BuiltinTypeMethodBuilder(Sema &S, BuiltinTypeDeclBuilder &DB, StringRef Name,
+                           QualType ReturnTy)
+      : DeclBuilder(DB), S(S), ReturnTy(ReturnTy), Method(nullptr) {
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    NameInfo = DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  }
+
+  BuiltinTypeMethodBuilder &addParam(StringRef Name, QualType Ty,
+                                     HLSLParamModifierAttr::Spelling Modifier =
+                                         HLSLParamModifierAttr::Keyword_in) {
+    assert(Method == nullptr && "Cannot add param, method already created");
+
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    Params.emplace_back(II, Ty, Modifier);
+    return *this;
+  }
+
+private:
+  void createMethodDecl() {
+    assert(Method == nullptr && "Method already created");
+
+    // create method type
+    ASTContext &AST = S.getASTContext();
+    SmallVector<QualType> ParamTypes;
+    for (auto &MP : Params)
+      ParamTypes.emplace_back(MP.Ty);
+    QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes,
+                                            FunctionProtoType::ExtProtoInfo());
+
+    // create method decl
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    Method =
+        CXXMethodDecl::Create(AST, DeclBuilder.Record, SourceLocation(),
+                              NameInfo, MethodTy, TSInfo, SC_None, false, false,
+                              ConstexprSpecKind::Unspecified, SourceLocation());
+
+    // create params & set them to the function prototype
+    SmallVector<ParmVarDecl *> ParmDecls;
+    auto FnProtoLoc =
+        Method->getTypeSourceInfo()->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    unsigned i = 0;
+    for (auto &MP : Params) {
+      ParmVarDecl *Parm = ParmVarDecl::Create(
+          AST, Method->getDeclContext(), SourceLocation(), SourceLocation(),
+          &MP.NameII, MP.Ty,
+          AST.getTrivialTypeSourceInfo(MP.Ty, SourceLocation()), SC_None,
+          nullptr);
+      if (MP.Modifier != HLSLParamModifierAttr::Keyword_in) {
+        auto *Mod =
+            HLSLParamModifierAttr::Create(AST, SourceRange(), MP.Modifier);
+        Parm->addAttr(Mod);
+      }
+      ParmDecls.push_back(Parm);
+      FnProtoLoc.setParam(i++, Parm);
+    }
+    Method->setParams({ParmDecls});
+  }
+
+  void addResourceHandleToParms(SmallVector<Expr *> &Parms) {
+    ASTContext &AST = S.getASTContext();
+    FieldDecl *HandleField = DeclBuilder.getResourceHandleField();
+    auto *This = CXXThisExpr::Create(
+        AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
+    Parms.push_back(MemberExpr::CreateImplicit(AST, This, false, HandleField,
+                                               HandleField->getType(),
+                                               VK_LValue, OK_Ordinary));
+  }
+
+public:
+  ~BuiltinTypeMethodBuilder() { finalizeMethod(); }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltin(StringRef BuiltinName, ArrayRef<Expr *> CallParms,
+              bool AddResourceHandleAsFirstArg = true) {
+    if (!Method)
+      createMethodDecl();
+
+    ASTContext &AST = S.getASTContext();
+    DeclRefExpr *Fn = lookupBuiltinFunction(S, BuiltinName);
+    Expr *Call = nullptr;
+
+    if (AddResourceHandleAsFirstArg) {
+      SmallVector<Expr *> NewCallParms;
+      addResourceHandleToParms(NewCallParms);
+      for (auto *P : CallParms)
+        NewCallParms.push_back(P);
+
+      Call = CallExpr::Create(AST, Fn, NewCallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    } else {
+      Call = CallExpr::Create(AST, Fn, CallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    }
+    StmtsList.push_back(Call);
+    return *this;
+  }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltinForwardArgs(StringRef BuiltinName,
+                         bool AddResourceHandleAsFirstArg = true) {
+    // FIXME: Call the buildin with all of the method parameters
+    // plus optional resource handle as the first arg.
+    llvm_unreachable("not yet implemented");
+  }
+
+  BuiltinTypeDeclBuilder &finalizeMethod() {
+    if (DeclBuilder.Record->isCompleteDefinition())
+      return DeclBuilder;
+
+    if (!Method)
+      createMethodDecl();
+
+    if (!Method->hasBody()) {
+      ASTContext &AST = S.getASTContext();
+      if (ReturnTy != AST.VoidTy && !StmtsList.empty()) {
+        if (Expr *LastExpr = dyn_cast<Expr>(StmtsList.back())) {
+          StmtsList.pop_back();
+          StmtsList.push_back(
+              ReturnStmt::Create(AST, SourceLocation(), LastExpr, nullptr));
+        }
+      }
+
+      Method->setBody(CompoundStmt::Create(AST, StmtsList, FPOptionsOverride(),
+                                           SourceLocation(), SourceLocation()));
+      Method->setLexicalDeclContext(DeclBuilder.Record);
+      Method->setAccess(AccessSpecifier::AS_public);
+      Method->addAttr(AlwaysInlineAttr::CreateImplicit(
+          AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
+      DeclBuilder.Record->addDecl(Method);
+    }
+    return DeclBuilder;
+  }
+};
+
 } // namespace
 
 TemplateParameterListBuilder
@@ -375,6 +542,30 @@ BuiltinTypeDeclBuilder::addSimpleTemplateParams(Sema &S,
   return Builder.finalizeTemplateArgs();
 }
 
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addIncrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *One =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), 1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "IncrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {One})
+      .finalizeMethod();
+}
+
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addDecrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *NegOne =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), -1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "DecrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {NegOne})
+      .finalizeMethod();
+}
+
 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
 
 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
@@ -528,8 +719,13 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
                     ResourceKind::TypedBuffer, /*IsROV=*/false,
                     /*RawBuffer=*/true)
         .addArraySubscriptOperators()
+        .addIncrementCounterMethod(*SemaPtr)
+        .addDecrementCounterMethod(*SemaPtr)
         .completeDefinition();
   });
+
+  // FIXME: Also add Increment/DecrementCounter to
+  // RasterizerOrderedStructuredBuffer when llvm/llvm-project/#113648 is merged.
 }
 
 void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
@@ -552,3 +748,23 @@ void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
     return;
   It->second(Record);
 }
+
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name) {
+  IdentifierInfo &II =
+      S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+  DeclarationNameInfo NameInfo =
+      DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
+  // AllowBuiltinCreation is false but LookupDirect will create
+  // the builtin when searching the global scope anyways...
+  S.LookupName(R, S.getCurScope());
+  // FIXME: If the builtin function was user-declared in global scope,
+  // this assert *will* fail. Should this call LookupBuiltin instead?
+  assert(R.isSingleResult() &&
+         "Since this is a builtin it should always resolve!");
+  auto *VD = cast<ValueDecl>(R.getFoundDecl());
+  QualType Ty = VD->getType();
+  return DeclRefExpr::Create(S.getASTContext(), NestedNameSpecifierLoc(),
+                             SourceLocation(), VD, false, NameInfo, Ty,
+                             VK_PRValue);
+}
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff6616901016ab..770bd4a81633e1 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -986,6 +986,10 @@ Sema::VarArgKind Sema::isValidVarArgType(const QualType &Ty) {
   if (getLangOpts().MSVCCompat)
     return VAK_MSVCUndefined;
 
+  if (getLangOpts().HLSL &&
+      Ty->getUnqualifiedDesugaredType()->isHLSLAttributedResourceType())
+    return VAK_Valid;
+
   // FIXME: In C++11, these cases are conditionally-supported, meaning we're
   // permitted to reject them. We should consider doing so.
   return VAK_Undefined;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1f6c5b8d4561bc..1b7f0456a3e82a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1860,6 +1860,31 @@ static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType.getTypePtr()
+           ->getUnqualifiedDesugaredType()
+           ->isHLSLAttributedResourceType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(),
+            diag::err_typecheck_expect_hlsl_resource)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
+static bool CheckInt(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType->isIntegerType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(), diag::err_typecheck_expect_int)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
 // Note: returning true in this case results in CheckBuiltinFunctionCall
 // returning an ExprError
 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -2100,6 +2125,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    if (SemaRef.checkArgCount(TheCall, 2) ||
+        CheckResourceHandle(&SemaRef, TheCall, 0) ||
+        CheckInt(&SemaRef, TheCall, 1))
+      return true;
+    Expr *OffsetExpr = TheCall->getArg(1);
+    std::optional<llvm::APSInt> Offset =
+        OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
+    if (!Offset.has_value() || abs(Offset->getExtValue()) != 1) {
+      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+                   diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
+          << 1;
+      return true;
+    }
+    break;
+  }
   }
   return false;
 }
diff --git a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
new file mode 100644
index 00000000000000..c8ff5d3cd905fb
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN-DISABLED...
[truncated]

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly looking good. I really like the additions to the builder API. Can you please add an AST test to verify the shape of the new AST nodes and their instantiations?

clang/lib/Sema/HLSLExternalSemaSource.cpp Outdated Show resolved Hide resolved
if (DeclBuilder.Record->isCompleteDefinition())
return DeclBuilder;

if (!Method)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the check above both seem like they would be errors. Maybe they should be asserts?

Copy link
Member Author

@hekota hekota Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isCompleteDefinition: I copied this approach from the way BuiltinTypeDeclBuilder currently does it. I will change these to asserts everywhere.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other check (!Method) is not quite an error, it just means no body building method have been called (the created method has no body).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the check if (DeclBuilder.Record->isCompleteDefinition()) needs to stay in. We might already have a complete definition from a precompiled header.

If I replace these with asserts the precompiled header test C:\llvm-project3\clang\test\AST\HLSL\pch_with_buf.hlsl fails.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the PCH is providing a completed definition, we probably shouldn't be calling into the complete callback since all of that stuff should fail. Can we detect that earlier and never get here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other check (!Method) is not quite an error, it just means no body building method have been called (the created method has no body).

A method with no body seems like a logic error on the person writing code (hence the request for an assert). Why would we want to enable generating empty methods?

clang/lib/Sema/HLSLExternalSemaSource.cpp Show resolved Hide resolved
clang/lib/Sema/HLSLExternalSemaSource.cpp Outdated Show resolved Hide resolved
clang/lib/Sema/HLSLExternalSemaSource.cpp Outdated Show resolved Hide resolved
Copy link

github-actions bot commented Nov 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple small comment take or leave, but looks good.

clang/include/clang/Basic/Builtins.td Show resolved Hide resolved
BuiltinTypeMethodBuilder &
callBuiltin(StringRef BuiltinName, ArrayRef<Expr *> CallParms,
bool AddResourceHandleAsFirstArg = true) {
if (!Method)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Might be nice to have a comment here explaining the need for the create call.

Suggested change
if (!Method)
// The first statement added to a method creates the declaration.
if (!Method)

// FIXME: Call the buildin with all of the method parameters
// plus optional resource handle as the first arg.
llvm_unreachable("not yet implemented");
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is unused, maybe we should remove it from this PR and add it in the subsequent PR when it is used.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX clang:codegen clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

[HLSL] Implement IncrementCounter/DecrementCounter on structured buffers
3 participants