Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 80 additions & 23 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrOps.h"
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand Down Expand Up @@ -90,7 +94,8 @@ struct ConvertCIRToMLIRPass
mlir::affine::AffineDialect, mlir::memref::MemRefDialect,
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
mlir::scf::SCFDialect, mlir::math::MathDialect,
mlir::vector::VectorDialect, mlir::LLVM::LLVMDialect>();
mlir::ptr::PtrDialect, mlir::vector::VectorDialect,
mlir::LLVM::LLVMDialect>();
}
void runOnOperation() final;

Expand Down Expand Up @@ -1612,32 +1617,83 @@ class CIRPtrStrideOpLowering
// only been used to propogate %base and %stride to memref.load/store and
// should be erased after the conversion.
mlir::LogicalResult
matchAndRewrite(cir::PtrStrideOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
if (!isCastArrayToPtrConsumer(op))
return mlir::failure();
if (!isLoadStoreOrCastArrayToPtrProduer(op))
return mlir::failure();
auto baseOp =
rewriteArrayDecay(cir::PtrStrideOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto baseDefiningOp =
adaptor.getBase().getDefiningOp<mlir::memref::ReinterpretCastOp>();
if (!baseOp)
if (!baseDefiningOp)
return mlir::failure();
auto base = baseOp->getOperand(0);
auto dstType = op.getType();
auto newDstType = llvm::cast<mlir::MemRefType>(convertTy(dstType));

auto base = baseDefiningOp->getOperand(0);
auto ptrType = op.getType();
auto memrefType = llvm::cast<mlir::MemRefType>(convertTy(ptrType));
auto stride = adaptor.getStride();
auto indexType = rewriter.getIndexType();

// Generate casting if the stride is not index type.
if (stride.getType() != indexType)
stride = mlir::arith::IndexCastOp::create(rewriter, op.getLoc(),
indexType, stride);
llvm::SmallVector<mlir::OpFoldResult> sizes, strides;
if (mlir::failed(prepareReinterpretMetadata(newDstType, rewriter, sizes,
strides, op.getOperation())))
return mlir::failure();

rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>(
op, newDstType, base, stride, sizes, strides);
rewriter.eraseOp(baseOp);
op, memrefType, base, stride, mlir::ValueRange{}, mlir::ValueRange{},
llvm::ArrayRef<mlir::NamedAttribute>{});

rewriter.eraseOp(baseDefiningOp);
return mlir::success();
}

mlir::LogicalResult
matchAndRewrite(cir::PtrStrideOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
if (isCastArrayToPtrConsumer(op) && isLoadStoreOrCastArrayToPtrProduer(op))
return rewriteArrayDecay(op, adaptor, rewriter);

auto base = adaptor.getBase();
auto stride = adaptor.getStride();

auto ptrType = op.getType();
auto elementType = convertTy(ptrType.getPointee());

auto ptrPtrType = mlir::ptr::PtrType::get(
rewriter.getContext(),
mlir::ptr::GenericSpaceAttr::get(op->getContext()));

mlir::Value elemSizeVal = mlir::ptr::TypeOffsetOp::create(
rewriter, op.getLoc(), rewriter.getIndexType(), elementType);

mlir::Value strideIndex = mlir::arith::IndexCastOp::create(
rewriter, op.getLoc(), rewriter.getIndexType(), stride);

mlir::Value offset = mlir::arith::MulIOp::create(rewriter, op.getLoc(),
strideIndex, elemSizeVal);

auto t1 = mlir::cast<mlir::MemRefType>(base.getType());
auto t2 =
mlir::MemRefType::get(t1.getShape(), t1.getElementType(),
t1.getLayout(), ptrPtrType.getMemorySpace());

auto ptrMetaType = mlir::ptr::PtrMetadataType::get(t2);

auto fixedBase = mlir::memref::MemorySpaceCastOp::create(
rewriter, op->getLoc(), t2, base);

auto getMetadataOp = mlir::ptr::GetMetadataOp::create(
rewriter, op->getLoc(), ptrMetaType, fixedBase);

auto toPtrOp = mlir::ptr::ToPtrOp::create(rewriter, op->getLoc(),
ptrPtrType, fixedBase);

auto ptrAddOp = mlir::ptr::PtrAddOp::create(rewriter, op.getLoc(),
ptrPtrType, toPtrOp, offset);

auto fromPtrOp = mlir::ptr::FromPtrOp::create(rewriter, op.getLoc(), t2,
ptrAddOp, getMetadataOp);

auto memrefCastOp = mlir::memref::MemorySpaceCastOp::create(
rewriter, op.getLoc(), t1, fromPtrOp);

rewriter.replaceOp(op, memrefCastOp);
return mlir::success();
}
};
Expand Down Expand Up @@ -1782,11 +1838,12 @@ void ConvertCIRToMLIRPass::runOnOperation() {

mlir::ConversionTarget target(getContext());
target.addLegalOp<mlir::ModuleOp>();
target.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::math::MathDialect, mlir::vector::VectorDialect,
mlir::LLVM::LLVMDialect>();
target
.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::ptr::PtrDialect, mlir::math::MathDialect,
mlir::vector::VectorDialect, mlir::LLVM::LLVMDialect>();
auto *context = patterns.getContext();

// We cannot mark cir dialect as illegal before conversion.
Expand Down
24 changes: 24 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/ptrstride-ptr.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: cir-opt %s --cir-to-mlir | FileCheck %s -check-prefix=MLIR

!s32i = !cir.int<s, 32>
module {
cir.func @raw_pointer(%p : !cir.ptr<!s32i>) -> !s32i {
// MLIR: %[[TWO:.*]] = arith.constant 2 : i32
// MLIR-NEXT: %[[TYPEOFFSET:.*]] = ptr.type_offset i32 : index
// MLIR-NEXT: %[[I:.*]] = arith.index_cast %[[TWO]] : i32 to index
// MLIR-NEXT: %[[OFFSET:.*]] = arith.muli %[[I]], %[[TYPEOFFSET]] : index
// MLIR-NEXT: %[[CAST1:.*]] = memref.memory_space_cast %arg0 : memref<i32> to memref<i32, #ptr.generic_space>
// MLIR-NEXT: %[[META:.*]] = ptr.get_metadata %[[CAST1]] : memref<i32, #ptr.generic_space>
// MLIR-NEXT: %[[P1:.*]] = ptr.to_ptr %[[CAST1]] : memref<i32, #ptr.generic_space> -> <#ptr.generic_space>
// MLIR-NEXT: %[[P2:.*]] = ptr.ptr_add %[[P1]], %[[OFFSET]] : !ptr.ptr<#ptr.generic_space>, index
// MLIR-NEXT: %[[PP:.*]] = ptr.from_ptr %[[P2]] metadata %[[META]] : <#ptr.generic_space> -> memref<i32, #ptr.generic_space>
// MLIR-NEXT: %[[CAST2:.*]] = memref.memory_space_cast %[[PP]] : memref<i32, #ptr.generic_space> to memref<i32>
// MLIR-NEXT: %[[R:.*]] = memref.load %[[CAST2]][] : memref<i32>
// MLIR-NEXT: return %[[R]] : i32

%0 = cir.const #cir.int<2> : !s32i
%1 = cir.ptr_stride %p, %0 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
%2 = cir.load %1 : !cir.ptr<!s32i>, !s32i
cir.return %2 : !s32i
}
}