diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 5c5228354a52..a70a9ce976a4 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -6118,6 +6118,39 @@ def CIR_LinkerOptionsOp : CIR_Op<"linker_options", [ }]; } +//===----------------------------------------------------------------------===// +// BlockAddressOp +//===----------------------------------------------------------------------===// + +def CIR_BlockAddressOp : CIR_Op<"blockaddress", [Pure]> { + let summary = "Get the address of a cir.label within a function"; + let description = [{ + The `cir.blockaddress` operation takes a function name and a label and + produces a pointer value that represents the address of that cir.label within + the specified function. + + This operation models GCC's "labels as values" extension (`&&label`), which + allows taking the address of a local label and using it as a computed + jump target (e.g., with `goto *addr;`). + + Example: + ```mlir + %1 = cir.alloca !cir.ptr, !cir.ptr>, ["ptr", init] {alignment = 8 : i64} + %addr = cir.blockaddress("foo", "label") -> !cir.ptr + cir.store align(8) %addr, %1 : !cir.ptr, !cir.ptr> + cir.br ^bb1 + ^bb1: + cir.label "label" + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$func, StrAttr:$label); + let results = (outs CIR_VoidPtrType:$addr); + let assemblyFormat = [{ + `(` $func `,` $label `)` `->` qualified(type($addr)) attr-dict + }]; +} + //===----------------------------------------------------------------------===// // Standard library function calls //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 2a0d77084628..e39622cace79 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -226,8 +226,18 @@ class ScalarExprEmitter : public StmtVisitor { } mlir::Value VisitUnaryExprOrTypeTraitExpr(const UnaryExprOrTypeTraitExpr *E); - mlir::Value VisitAddrLabelExpr(const AddrLabelExpr *E) { - llvm_unreachable("NYI"); + + mlir::Value VisitAddrLabelExpr(const AddrLabelExpr *e) { + auto func = cast(CGF.CurFn); + llvm::StringRef symName = func.getSymName(); + mlir::FlatSymbolRefAttr funName = + mlir::FlatSymbolRefAttr::get(&CGF.getMLIRContext(), symName); + mlir::StringAttr labelName = + mlir::StringAttr::get(&CGF.getMLIRContext(), e->getLabel()->getName()); + return cir::BlockAddressOp::create(Builder, CGF.getLoc(e->getSourceRange()), + CGF.convertType(e->getType()), funName, + labelName); + ; } mlir::Value VisitSizeOfPackExpr(SizeOfPackExpr *E) { llvm_unreachable("NYI"); diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index 6d353e5ee707..a928cfee2eb5 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -180,7 +180,6 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *S, case Stmt::CXXForRangeStmtClass: return emitCXXForRangeStmt(cast(*S), Attrs); - case Stmt::IndirectGotoStmtClass: case Stmt::ReturnStmtClass: // When implemented, GCCAsmStmtClass should fall-through to MSAsmStmtClass. case Stmt::GCCAsmStmtClass: @@ -196,6 +195,7 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *S, case Stmt::OMPBarrierDirectiveClass: return emitOMPBarrierDirective(cast(*S)); // Unsupported AST nodes: + case Stmt::IndirectGotoStmtClass: case Stmt::CapturedStmtClass: case Stmt::ObjCAtTryStmtClass: case Stmt::ObjCAtThrowStmtClass: diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 8a6f451d57ad..19b51d2bf819 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -17,13 +17,14 @@ #include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/Interfaces/CIRLoopOpInterface.h" #include "clang/CIR/MissingFeatures.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/LogicalResult.h" #include #include -#include #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" @@ -2928,23 +2929,46 @@ LogicalResult cir::FuncOp::verify() { << "' must have empty body"; } - std::set labels; - std::set gotos; - + llvm::SmallSet labels; + llvm::SmallSet gotos; + llvm::SmallSet blockAddresses; + bool invalidBlockAddress = false; getOperation()->walk([&](mlir::Operation *op) { if (auto lab = dyn_cast(op)) { - labels.emplace(lab.getLabel()); + labels.insert(lab.getLabel()); } else if (auto goTo = dyn_cast(op)) { - gotos.emplace(goTo.getLabel()); + gotos.insert(goTo.getLabel()); + } else if (auto blkAdd = dyn_cast(op)) { + if (blkAdd.getFunc() != getSymName()) { + // Stop the walk early, no need to continue + invalidBlockAddress = true; + return mlir::WalkResult::interrupt(); + } + blockAddresses.insert(blkAdd.getLabel()); } + return mlir::WalkResult::advance(); }); - std::vector mismatched; - std::set_difference(gotos.begin(), gotos.end(), labels.begin(), labels.end(), - std::back_inserter(mismatched)); + if (invalidBlockAddress) + return emitOpError() << "blockaddress references a different function"; + + llvm::SmallSet mismatched; + if (!labels.empty() || !gotos.empty()) { + mismatched = llvm::set_difference(gotos, labels); + + if (!mismatched.empty()) + return emitOpError() << "goto/label mismatch"; + } - if (!mismatched.empty()) - return emitOpError() << "goto/label mismatch"; + mismatched.clear(); + + if (!labels.empty() || !blockAddresses.empty()) { + mismatched = llvm::set_difference(blockAddresses, labels); + + if (!mismatched.empty()) + return emitOpError() + << "expects an existing label target in the referenced function"; + } return success(); } diff --git a/clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp b/clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp index e78049ecefae..e2d07c4ff433 100644 --- a/clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp +++ b/clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp @@ -3,6 +3,7 @@ #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/Passes.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" @@ -24,17 +25,29 @@ static void process(cir::FuncOp func) { mlir::OpBuilder rewriter(func.getContext()); llvm::StringMap labels; llvm::SmallVector gotos; + llvm::SmallSet blockAddrLabel; func.getBody().walk([&](mlir::Operation *op) { if (auto lab = dyn_cast(op)) { - // Will construct a string copy inplace. Safely erase the label labels.try_emplace(lab.getLabel(), lab->getBlock()); - lab.erase(); } else if (auto goTo = dyn_cast(op)) { gotos.push_back(goTo); + } else if (auto blockAddr = dyn_cast(op)) { + blockAddrLabel.insert(blockAddr.getLabel()); } }); + for (auto &lab : labels) { + StringRef labelName = lab.getKey(); + Block *block = lab.getValue(); + if (!blockAddrLabel.contains(labelName)) { + // erase the LabelOp inside the block if safe + if (auto lab = dyn_cast(&block->front())) { + lab.erase(); + } + } + } + for (auto goTo : gotos) { mlir::OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(goTo); @@ -54,4 +67,4 @@ void GotoSolverPass::runOnOperation() { std::unique_ptr mlir::createGotoSolverPass() { return std::make_unique(); -} \ No newline at end of file +} diff --git a/clang/test/CIR/CodeGen/label-values.c b/clang/test/CIR/CodeGen/label-values.c new file mode 100644 index 000000000000..503ab07a2811 --- /dev/null +++ b/clang/test/CIR/CodeGen/label-values.c @@ -0,0 +1,74 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir +// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR + +void A(void) { + void *ptr = &&A; +A: + return; +} +// CIR: cir.func dso_local @A +// CIR: [[PTR:%.*]] = cir.alloca !cir.ptr, !cir.ptr>, ["ptr", init] {alignment = 8 : i64} +// CIR: [[BLOCK:%.*]] = cir.blockaddress(@A, "A") -> !cir.ptr +// CIR: cir.store align(8) [[BLOCK]], [[PTR]] : !cir.ptr, !cir.ptr> +// CIR: cir.br ^bb1 +// CIR: ^bb1: // pred: ^bb0 +// CIR: cir.label "A" +// CIR: cir.return + +void B(void) { +B: + void *ptr = &&B; +} + +// CIR: cir.func dso_local @B() +// CIR: cir.label "B" +// CIR: [[PTR:%.*]] = cir.alloca !cir.ptr, !cir.ptr>, ["ptr", init] {alignment = 8 : i64} +// CIR: [[BLOCK:%.*]] = cir.blockaddress(@B, "B") -> !cir.ptr +// CIR: cir.store align(8) [[BLOCK]], [[PTR]] : !cir.ptr, !cir.ptr> +// CIR: cir.return + +void C(int x) { + void *ptr = (x == 0) ? &&A : &&B; +A: + return; +B: + return; +} + +// CIR: cir.func dso_local @C +// CIR: [[BLOCK1:%.*]] = cir.blockaddress(@C, "A") -> !cir.ptr +// CIR: [[BLOCK2:%.*]] = cir.blockaddress(@C, "B") -> !cir.ptr +// CIR: [[COND:%.*]] = cir.select if [[CMP:%.*]] then [[BLOCK1]] else [[BLOCK2]] : (!cir.bool, !cir.ptr, !cir.ptr) -> !cir.ptr +// CIR: cir.store align(8) [[COND]], [[PTR:%.*]] : !cir.ptr, !cir.ptr> +// CIR: cir.br ^bb2 +// CIR: ^bb1: // 2 preds: ^bb2, ^bb3 +// CIR: cir.return +// CIR: ^bb2: // pred: ^bb0 +// CIR: cir.label "A" +// CIR: cir.br ^bb1 +// CIR: ^bb3: // no predecessors +// CIR: cir.label "B" +// CIR: cir.br ^bb1 + +void D(void) { + void *ptr = &&A; + void *ptr2 = &&A; +A: + void *ptr3 = &&A; + return; +} + +// CIR: cir.func dso_local @D +// CIR: %[[PTR:.*]] = cir.alloca !cir.ptr, !cir.ptr>, ["ptr", init] +// CIR: %[[PTR2:.*]] = cir.alloca !cir.ptr, !cir.ptr>, ["ptr2", init] +// CIR: %[[PTR3:.*]] = cir.alloca !cir.ptr, !cir.ptr>, ["ptr3", init] +// CIR: %[[BLK1:.*]] = cir.blockaddress(@D, "A") -> !cir.ptr +// CIR: cir.store align(8) %[[BLK1]], %[[PTR]] : !cir.ptr, !cir.ptr> +// CIR: %[[BLK2:.*]] = cir.blockaddress(@D, "A") -> !cir.ptr +// CIR: cir.store align(8) %[[BLK2]], %[[PTR2]] : !cir.ptr, !cir.ptr> +// CIR: cir.br ^bb1 +// CIR: ^bb1: // pred: ^bb0 +// CIR: cir.label "A" +// CIR: %[[BLK3:.*]] = cir.blockaddress(@D, "A") -> !cir.ptr +// CIR: cir.store align(8) %[[BLK3]], %[[PTR3]] : !cir.ptr, !cir.ptr> +// CIR: cir.return diff --git a/clang/test/CIR/IR/block-adress.cir b/clang/test/CIR/IR/block-adress.cir new file mode 100644 index 000000000000..1b72bce3df41 --- /dev/null +++ b/clang/test/CIR/IR/block-adress.cir @@ -0,0 +1,34 @@ +// RUN: cir-opt %s | cir-opt | FileCheck %s + +!void = !cir.void + +module { + cir.func @block_address(){ + %0 = cir.blockaddress(@block_address, "label") -> !cir.ptr + cir.br ^bb1 + ^bb1: + cir.label "label" + cir.return + } +// CHECK: cir.func @block_address +// CHECK: %0 = cir.blockaddress(@block_address, "label") -> !cir.ptr +// CHECK: cir.br ^bb1 +// CHECK: ^bb1: +// CHECK: cir.label "label" +// CHECK: cir.return + +cir.func @block_address_inside_scope() -> () { + cir.scope{ + %0 = cir.blockaddress(@block_address_inside_scope, "label") -> !cir.ptr + } + cir.br ^bb1 +^bb1: + cir.label "label" + cir.return +} +// CHECK: cir.func @block_address_inside_scope +// CHECK: cir.scope +// CHECK: %0 = cir.blockaddress(@block_address_inside_scope, "label") -> !cir.ptr +// CHECK: cir.label "label" +// CHECK: cir.return +} diff --git a/clang/test/CIR/IR/invalid-block-address.cir b/clang/test/CIR/IR/invalid-block-address.cir new file mode 100644 index 000000000000..26ffa1a14568 --- /dev/null +++ b/clang/test/CIR/IR/invalid-block-address.cir @@ -0,0 +1,21 @@ +// RUN: cir-opt %s -verify-diagnostics -split-input-file + +!void = !cir.void + +// expected-error@+1 {{expects an existing label target in the referenced function}} +cir.func @bad_block_address() -> () { + %0 = cir.blockaddress(@bad_block_address, "label") -> !cir.ptr + cir.br ^bb1 + ^bb1: + cir.label "wrong_label" + cir.return +} + +// expected-error@+1 {{blockaddress references a different function}} +cir.func @bad_block_func() -> () { + %0 = cir.blockaddress(@mismatch_func, "label") -> !cir.ptr + cir.br ^bb1 + ^bb1: + cir.label "label" + cir.return +} diff --git a/clang/test/CIR/Transforms/goto_solver.cir b/clang/test/CIR/Transforms/goto_solver.cir new file mode 100644 index 000000000000..177a8ce652e0 --- /dev/null +++ b/clang/test/CIR/Transforms/goto_solver.cir @@ -0,0 +1,63 @@ +// RUN: cir-opt %s -cir-goto-solver -o - | FileCheck %s + +!void = !cir.void + +cir.func @a(){ + %0 = cir.alloca !cir.ptr, !cir.ptr>, ["ptr", init] {alignment = 8 : i64} + %1 = cir.blockaddress(@a, "label1") -> !cir.ptr + cir.store align(8) %1, %0 : !cir.ptr, !cir.ptr> + cir.br ^bb1 +^bb1: + cir.label "label1" + cir.br ^bb2 +^bb2: + // This label is not referenced by any blockaddressOp, so it should be removed + cir.label "label2" + cir.return +} + +// CHECK: cir.func @a() +// CHECK: %1 = cir.blockaddress(@a, "label1") -> !cir.ptr +// CHECK: ^bb1: +// CHECK: cir.label "label1" +// CHECK: cir.br ^bb2 +// CHECK: ^bb2: +// CHECK-NOT: cir.label "label2" + +cir.func @b(){ + %0 = cir.alloca !cir.ptr, !cir.ptr>, ["ptr", init] {alignment = 8 : i64} + %1 = cir.blockaddress(@b, "label1") -> !cir.ptr + cir.store align(8) %1, %0 : !cir.ptr, !cir.ptr> + cir.goto "label2" +^bb1: + cir.label "label1" + cir.br ^bb2 +^bb2: + // This label is not referenced by any blockaddressOp, so it should be removed + cir.label "label2" + cir.return +} + +// CHECK: cir.func @b() { +// CHECK: %1 = cir.blockaddress(@b, "label1") -> !cir.ptr +// CHECK: cir.store align(8) %1, {{.*}} : !cir.ptr, !cir.ptr> +// CHECK: cir.br ^bb2 +// CHECK: ^bb1: +// CHECK: cir.label "label1" +// CHECK: cir.br ^bb2 +// CHECK: ^bb2: +// CHECK-NOT: cir.label "label2" + +cir.func @c() { + cir.label "label1" + %0 = cir.alloca !cir.ptr, !cir.ptr>, ["ptr", init] {alignment = 8 : i64} + %1 = cir.blockaddress(@c, "label1") -> !cir.ptr + cir.store align(8) %1, %0 : !cir.ptr, !cir.ptr> + cir.return +} + +// CHECK: cir.func @c +// CHECK: cir.label "label1" +// CHECK: %1 = cir.blockaddress(@c, "label1") -> !cir.ptr +// CHECK: cir.store align(8) %1, {{.*}} : !cir.ptr, !cir.ptr> + diff --git a/clang/tools/cir-opt/cir-opt.cpp b/clang/tools/cir-opt/cir-opt.cpp index 2c242f9d2db1..68b074949df5 100644 --- a/clang/tools/cir-opt/cir-opt.cpp +++ b/clang/tools/cir-opt/cir-opt.cpp @@ -55,7 +55,9 @@ int main(int argc, char **argv) { ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { return mlir::createCIRSimplifyPass(); }); - + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::createGotoSolverPass(); + }); ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { return mlir::createSCFPreparePass(); });