diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 5b4ec9c5cf16..297031f2c42b 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -26,6 +26,10 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" @@ -90,7 +94,8 @@ struct ConvertCIRToMLIRPass mlir::affine::AffineDialect, mlir::memref::MemRefDialect, mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect, mlir::scf::SCFDialect, mlir::math::MathDialect, - mlir::vector::VectorDialect, mlir::LLVM::LLVMDialect>(); + mlir::ptr::PtrDialect, mlir::vector::VectorDialect, + mlir::LLVM::LLVMDialect>(); } void runOnOperation() final; @@ -1612,32 +1617,83 @@ class CIRPtrStrideOpLowering // only been used to propogate %base and %stride to memref.load/store and // should be erased after the conversion. mlir::LogicalResult - matchAndRewrite(cir::PtrStrideOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - if (!isCastArrayToPtrConsumer(op)) - return mlir::failure(); - if (!isLoadStoreOrCastArrayToPtrProduer(op)) - return mlir::failure(); - auto baseOp = + rewriteArrayDecay(cir::PtrStrideOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto baseDefiningOp = adaptor.getBase().getDefiningOp(); - if (!baseOp) + if (!baseDefiningOp) return mlir::failure(); - auto base = baseOp->getOperand(0); - auto dstType = op.getType(); - auto newDstType = llvm::cast(convertTy(dstType)); + + auto base = baseDefiningOp->getOperand(0); + auto ptrType = op.getType(); + auto memrefType = llvm::cast(convertTy(ptrType)); auto stride = adaptor.getStride(); auto indexType = rewriter.getIndexType(); + // Generate casting if the stride is not index type. if (stride.getType() != indexType) stride = mlir::arith::IndexCastOp::create(rewriter, op.getLoc(), indexType, stride); - llvm::SmallVector sizes, strides; - if (mlir::failed(prepareReinterpretMetadata(newDstType, rewriter, sizes, - strides, op.getOperation()))) - return mlir::failure(); + rewriter.replaceOpWithNewOp( - op, newDstType, base, stride, sizes, strides); - rewriter.eraseOp(baseOp); + op, memrefType, base, stride, mlir::ValueRange{}, mlir::ValueRange{}, + llvm::ArrayRef{}); + + rewriter.eraseOp(baseDefiningOp); + return mlir::success(); + } + + mlir::LogicalResult + matchAndRewrite(cir::PtrStrideOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + if (isCastArrayToPtrConsumer(op) && isLoadStoreOrCastArrayToPtrProduer(op)) + return rewriteArrayDecay(op, adaptor, rewriter); + + auto base = adaptor.getBase(); + auto stride = adaptor.getStride(); + + auto ptrType = op.getType(); + auto elementType = convertTy(ptrType.getPointee()); + + auto ptrPtrType = mlir::ptr::PtrType::get( + rewriter.getContext(), + mlir::ptr::GenericSpaceAttr::get(op->getContext())); + + mlir::Value elemSizeVal = mlir::ptr::TypeOffsetOp::create( + rewriter, op.getLoc(), rewriter.getIndexType(), elementType); + + mlir::Value strideIndex = mlir::arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getIndexType(), stride); + + mlir::Value offset = mlir::arith::MulIOp::create(rewriter, op.getLoc(), + strideIndex, elemSizeVal); + + auto t1 = mlir::cast(base.getType()); + auto t2 = + mlir::MemRefType::get(t1.getShape(), t1.getElementType(), + t1.getLayout(), ptrPtrType.getMemorySpace()); + + auto ptrMetaType = mlir::ptr::PtrMetadataType::get(t2); + + auto fixedBase = mlir::memref::MemorySpaceCastOp::create( + rewriter, op->getLoc(), t2, base); + + auto getMetadataOp = mlir::ptr::GetMetadataOp::create( + rewriter, op->getLoc(), ptrMetaType, fixedBase); + + auto toPtrOp = mlir::ptr::ToPtrOp::create(rewriter, op->getLoc(), + ptrPtrType, fixedBase); + + auto ptrAddOp = mlir::ptr::PtrAddOp::create(rewriter, op.getLoc(), + ptrPtrType, toPtrOp, offset); + + auto fromPtrOp = mlir::ptr::FromPtrOp::create(rewriter, op.getLoc(), t2, + ptrAddOp, getMetadataOp); + + auto memrefCastOp = mlir::memref::MemorySpaceCastOp::create( + rewriter, op.getLoc(), t1, fromPtrOp); + + rewriter.replaceOp(op, memrefCastOp); return mlir::success(); } }; @@ -1782,11 +1838,12 @@ void ConvertCIRToMLIRPass::runOnOperation() { mlir::ConversionTarget target(getContext()); target.addLegalOp(); - target.addLegalDialect(); + target + .addLegalDialect(); auto *context = patterns.getContext(); // We cannot mark cir dialect as illegal before conversion. diff --git a/clang/test/CIR/Lowering/ThroughMLIR/ptrstride-ptr.cir b/clang/test/CIR/Lowering/ThroughMLIR/ptrstride-ptr.cir new file mode 100644 index 000000000000..ffd9e9646e9d --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/ptrstride-ptr.cir @@ -0,0 +1,24 @@ +// RUN: cir-opt %s --cir-to-mlir | FileCheck %s -check-prefix=MLIR + +!s32i = !cir.int +module { + cir.func @raw_pointer(%p : !cir.ptr) -> !s32i { + // MLIR: %[[TWO:.*]] = arith.constant 2 : i32 + // MLIR-NEXT: %[[TYPEOFFSET:.*]] = ptr.type_offset i32 : index + // MLIR-NEXT: %[[I:.*]] = arith.index_cast %[[TWO]] : i32 to index + // MLIR-NEXT: %[[OFFSET:.*]] = arith.muli %[[I]], %[[TYPEOFFSET]] : index + // MLIR-NEXT: %[[CAST1:.*]] = memref.memory_space_cast %arg0 : memref to memref + // MLIR-NEXT: %[[META:.*]] = ptr.get_metadata %[[CAST1]] : memref + // MLIR-NEXT: %[[P1:.*]] = ptr.to_ptr %[[CAST1]] : memref -> <#ptr.generic_space> + // MLIR-NEXT: %[[P2:.*]] = ptr.ptr_add %[[P1]], %[[OFFSET]] : !ptr.ptr<#ptr.generic_space>, index + // MLIR-NEXT: %[[PP:.*]] = ptr.from_ptr %[[P2]] metadata %[[META]] : <#ptr.generic_space> -> memref + // MLIR-NEXT: %[[CAST2:.*]] = memref.memory_space_cast %[[PP]] : memref to memref + // MLIR-NEXT: %[[R:.*]] = memref.load %[[CAST2]][] : memref + // MLIR-NEXT: return %[[R]] : i32 + + %0 = cir.const #cir.int<2> : !s32i + %1 = cir.ptr_stride %p, %0 : (!cir.ptr, !s32i) -> !cir.ptr + %2 = cir.load %1 : !cir.ptr, !s32i + cir.return %2 : !s32i + } +} \ No newline at end of file