diff --git a/include/circt/Dialect/Arc/ArcOps.td b/include/circt/Dialect/Arc/ArcOps.td index 4db69689c4db..8a1b6327ec96 100644 --- a/include/circt/Dialect/Arc/ArcOps.td +++ b/include/circt/Dialect/Arc/ArcOps.td @@ -425,15 +425,11 @@ def MemoryWriteOp : ArcOp<"memory_write", [ let arguments = (ins MemoryType:$memory, AnyInteger:$address, - Optional:$enable, AnyInteger:$data ); let assemblyFormat = [{ - $memory `[` $address `]` `,` $data (`if` $enable^)? - attr-dict `:` type($memory) + $memory `[` $address `]` `,` $data attr-dict `:` type($memory) }]; - let hasFolder = 1; - let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// @@ -594,19 +590,17 @@ def StateWriteOp : ArcOp<"state_write", [ let arguments = (ins StateType:$state, AnyType:$value, - Optional:$condition, OptionalAttr:$traceTapModel, OptionalAttr:$traceTapIndex); let assemblyFormat = [{ - $state `=` $value (`if` $condition^)? + $state `=` $value (`tap` $traceTapModel`[`$traceTapIndex^`]` )? attr-dict `:` type($state) }]; let hasVerifier = true; let builders = [ OpBuilder<(ins "mlir::Value":$state, - "mlir::Value":$value, - "mlir::Value":$condition), [{ - build($_builder, $_state, state, value, condition, + "mlir::Value":$value), [{ + build($_builder, $_state, state, value, mlir::FlatSymbolRefAttr{}, mlir::IntegerAttr{}); }]> ]; diff --git a/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp b/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp index a5d71ac4a6e1..e6db31c9e4dc 100644 --- a/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp +++ b/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp @@ -152,30 +152,17 @@ struct StateWriteOpLowering : public OpConversionPattern { LogicalResult matchAndRewrite(arc::StateWriteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - auto doStore = [&](OpBuilder &builder, Location loc) { - if (!isa(op.getValue().getType())) { - LLVM::StoreOp::create(builder, loc, adaptor.getValue(), - adaptor.getState()); - return; - } - - int numBytes = op.getState().getType().getByteWidth(); - Value size = LLVM::ConstantOp::create(builder, loc, rewriter.getI64Type(), - numBytes); - LLVM::MemcpyOp::create(builder, loc, adaptor.getState(), - adaptor.getValue(), size, /*volatile=*/false); - }; - - if (adaptor.getCondition()) { - rewriter.replaceOpWithNewOp( - op, adaptor.getCondition(), [&](auto &builder, auto loc) { - doStore(builder, loc); - scf::YieldOp::create(builder, loc); - }); - } else { - doStore(rewriter, op.getLoc()); - rewriter.eraseOp(op); + if (!isa(op.getValue().getType())) { + rewriter.replaceOpWithNewOp(op, adaptor.getValue(), + adaptor.getState()); + return success(); } + + int numBytes = op.getState().getType().getByteWidth(); + Value size = LLVM::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getI64Type(), numBytes); + rewriter.replaceOpWithNewOp( + op, adaptor.getState(), adaptor.getValue(), size, /*volatile=*/false); return success(); } }; @@ -355,9 +342,6 @@ struct MemoryWriteOpLowering : public OpConversionPattern { op.getLoc(), adaptor.getMemory(), adaptor.getAddress(), cast(op.getMemory().getType()), rewriter); auto enable = access.withinBounds; - if (adaptor.getEnable()) - enable = LLVM::AndOp::create(rewriter, op.getLoc(), adaptor.getEnable(), - enable); // Only attempt to write the memory if the address is within bounds. rewriter.replaceOpWithNewOp( diff --git a/lib/Dialect/Arc/ArcFolds.cpp b/lib/Dialect/Arc/ArcFolds.cpp index bbfbb857e87c..65d52b05e2ec 100644 --- a/lib/Dialect/Arc/ArcFolds.cpp +++ b/lib/Dialect/Arc/ArcFolds.cpp @@ -92,24 +92,6 @@ LogicalResult StateOp::canonicalize(StateOp op, PatternRewriter &rewriter) { return failure(); } -//===----------------------------------------------------------------------===// -// MemoryWriteOp -//===----------------------------------------------------------------------===// - -LogicalResult MemoryWriteOp::fold(FoldAdaptor adaptor, - SmallVectorImpl &results) { - if (isAlways(adaptor.getEnable(), true)) - return getEnableMutable().clear(), success(); - return failure(); -} - -LogicalResult MemoryWriteOp::canonicalize(MemoryWriteOp op, - PatternRewriter &rewriter) { - if (isAlways(op.getEnable(), false)) - return rewriter.eraseOp(op), success(); - return failure(); -} - //===----------------------------------------------------------------------===// // StorageGetOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Arc/Transforms/LowerState.cpp b/lib/Dialect/Arc/Transforms/LowerState.cpp index d71dd1d68ed1..0108b6bb4535 100644 --- a/lib/Dialect/Arc/Transforms/LowerState.cpp +++ b/lib/Dialect/Arc/Transforms/LowerState.cpp @@ -373,7 +373,7 @@ Value ModuleLowering::detectPosedge(Value clock) { // Read the old clock value from storage and write the new clock value to // storage. auto oldClock = StateReadOp::create(builder, loc, oldStorage); - StateWriteOp::create(builder, loc, oldStorage, clock, Value{}); + StateWriteOp::create(builder, loc, oldStorage, clock); // Detect a rising edge. auto edge = comb::XorOp::create(builder, loc, oldClock, clock); @@ -497,8 +497,7 @@ LogicalResult OpLowering::lower(StateOp op) { auto state = module.getAllocatedState(result); if (!state) return failure(); - StateWriteOp::create(module.initialBuilder, value.getLoc(), state, value, - Value{}); + StateWriteOp::create(module.initialBuilder, value.getLoc(), state, value); } return success(); } @@ -628,8 +627,7 @@ LogicalResult OpLowering::lowerStateful( if (value.getType() != type) value = BitcastOp::create(module.builder, loweredReset.getLoc(), type, value); - StateWriteOp::create(module.builder, loweredReset.getLoc(), state, value, - Value{}); + StateWriteOp::create(module.builder, loweredReset.getLoc(), state, value); } module.builder.setInsertionPoint(ifResetOp.elseYield()); } @@ -664,7 +662,7 @@ LogicalResult OpLowering::lowerStateful( // Compute the transfer function and write its results to the state's storage. auto loweredResults = createMapping(loweredInputs); for (auto [state, value] : llvm::zip(states, loweredResults)) - StateWriteOp::create(module.builder, value.getLoc(), state, value, Value{}); + StateWriteOp::create(module.builder, value.getLoc(), state, value); // Since we just wrote the new state value to storage, insert read ops just // before the if op that keep the old value around for any later ops that @@ -782,8 +780,7 @@ LogicalResult OpLowering::lower(MemoryOp op) { } // Actually write to the memory. - MemoryWriteOp::create(module.builder, write.getLoc(), state, address, - Value{}, data); + MemoryWriteOp::create(module.builder, write.getLoc(), state, address, data); } return success(); @@ -808,7 +805,7 @@ LogicalResult OpLowering::lower(TapOp op) { alloc->setAttr("names", op.getNamesAttr()); state = alloc; } - StateWriteOp::create(module.builder, op.getLoc(), state, value, Value{}); + StateWriteOp::create(module.builder, op.getLoc(), state, value); return success(); } @@ -836,7 +833,7 @@ LogicalResult OpLowering::lower(InstanceOp op) { state->setAttr("name", module.builder.getStringAttr( op.getInstanceName() + "/" + cast(name).getValue())); - StateWriteOp::create(module.builder, value.getLoc(), state, value, Value{}); + StateWriteOp::create(module.builder, value.getLoc(), state, value); } // HACK: Also ensure that storage has been allocated for all outputs. @@ -869,7 +866,7 @@ LogicalResult OpLowering::lower(hw::OutputOp op) { auto state = RootOutputOp::create( module.allocBuilder, value.getLoc(), StateType::get(value.getType()), cast(name), module.storageArg); - StateWriteOp::create(module.builder, value.getLoc(), state, value, Value{}); + StateWriteOp::create(module.builder, value.getLoc(), state, value); } return success(); } @@ -1267,8 +1264,7 @@ Value OpLowering::lowerValue(seq::InitialOp op, OpResult result, Phase phase) { module.storageArg); OpBuilder::InsertionGuard guard(module.initialBuilder); module.initialBuilder.setInsertionPointAfterValue(value); - StateWriteOp::create(module.initialBuilder, value.getLoc(), state, value, - Value{}); + StateWriteOp::create(module.initialBuilder, value.getLoc(), state, value); } // Read back the value computed during the initial phase. diff --git a/test/Conversion/ArcToLLVM/lower-arc-to-llvm.mlir b/test/Conversion/ArcToLLVM/lower-arc-to-llvm.mlir index b18ec9e9e591..8c4a9a3b3251 100644 --- a/test/Conversion/ArcToLLVM/lower-arc-to-llvm.mlir +++ b/test/Conversion/ArcToLLVM/lower-arc-to-llvm.mlir @@ -77,14 +77,6 @@ func.func @StateUpdates(%arg0: !arc.storage<1>) { // CHECK-NEXT: [[LOAD:%.+]] = llvm.load [[PTR]] : !llvm.ptr -> i1 arc.state_write %0 = %1 : // CHECK-NEXT: llvm.store [[LOAD]], [[PTR]] : i1, !llvm.ptr - %false = hw.constant false - arc.state_write %0 = %false if %1 : - // CHECK-NEXT: [[FALSE:%.+]] = llvm.mlir.constant(false) - // CHECK-NEXT: llvm.cond_br [[LOAD]], [[BB1:\^.+]], [[BB2:\^.+]] - // CHECK-NEXT: [[BB1]]: - // CHECK-NEXT: llvm.store [[FALSE]], [[PTR]] - // CHECK-NEXT: llvm.br [[BB2]] - // CHECK-NEXT: [[BB2]]: return // CHECK-NEXT: llvm.return } @@ -117,18 +109,6 @@ func.func @MemoryUpdates(%arg0: !arc.storage<24>, %enable: i1) { // CHECK-NEXT: [[BB_RESUME]]([[LOADED:%.+]]: i42): // CHECK: [[ADDED:%.+]] = llvm.add [[LOADED]], [[LOADED]] - arc.memory_write %0[%c3_i19], %2 if %enable : <4 x i42, i19> - // CHECK-NEXT: [[ADDR:%.+]] = llvm.zext [[THREE]] : i19 to i20 - // CHECK-NEXT: [[FOUR:%.+]] = llvm.mlir.constant(4 - // CHECK-NEXT: [[INBOUNDS:%.+]] = llvm.icmp "ult" [[ADDR]], [[FOUR]] - // CHECK-NEXT: [[GEP:%.+]] = llvm.getelementptr [[PTR]][[[ADDR]]] : (!llvm.ptr, i20) -> !llvm.ptr, i64 - // CHECK-NEXT: [[COND:%.+]] = llvm.and %arg1, [[INBOUNDS]] - // CHECK-NEXT: llvm.cond_br [[COND]], [[BB_STORE:\^.+]], [[BB_RESUME:\^.+]] - // CHECK-NEXT: [[BB_STORE]]: - // CHECK-NEXT: llvm.store [[ADDED]], [[GEP]] : i42, !llvm.ptr - // CHECK-NEXT: llvm.br [[BB_RESUME]] - // CHECK-NEXT: [[BB_RESUME]]: - arc.memory_write %0[%c3_i19], %2 : <4 x i42, i19> // CHECK-NEXT: [[ADDR:%.+]] = llvm.zext [[THREE]] : i19 to i20 // CHECK-NEXT: [[FOUR:%.+]] = llvm.mlir.constant(4 diff --git a/test/Dialect/Arc/basic.mlir b/test/Dialect/Arc/basic.mlir index 0d2ed28a5419..c362fc0a510b 100644 --- a/test/Dialect/Arc/basic.mlir +++ b/test/Dialect/Arc/basic.mlir @@ -140,8 +140,6 @@ hw.module @memoryOps(in %clk: !seq.clock, in %en: i1, in %mask: i32, in %arg: i1 // CHECK: %{{.+}} = arc.memory_read [[MEM]][%c0_i32] : <4 x i32, i32> %2 = arc.memory_read %mem[%c0_i32] : <4 x i32, i32> - // CHECK-NEXT: arc.memory_write [[MEM]][%c0_i32], %c0_i32 if %en : <4 x i32, i32> - arc.memory_write %mem[%c0_i32], %c0_i32 if %en : <4 x i32, i32> // CHECK-NEXT: arc.memory_write [[MEM]][%c0_i32], %c0_i32 : <4 x i32, i32> arc.memory_write %mem[%c0_i32], %c0_i32 : <4 x i32, i32> @@ -381,8 +379,6 @@ func.func @ReadsWrites(%arg0: !arc.state, %arg1: i42, %arg2: i1) { arc.state_read %arg0 : // CHECK: arc.state_write %arg0 = %arg1 : arc.state_write %arg0 = %arg1 : - // CHECK: arc.state_write %arg0 = %arg1 if %arg2 : - arc.state_write %arg0 = %arg1 if %arg2 : return } diff --git a/test/Dialect/Arc/canonicalizers.mlir b/test/Dialect/Arc/canonicalizers.mlir index e358c8499eff..5ebbf3be69b6 100644 --- a/test/Dialect/Arc/canonicalizers.mlir +++ b/test/Dialect/Arc/canonicalizers.mlir @@ -79,16 +79,6 @@ hw.module @clockDomainDCE(in %clk: !seq.clock) { } } -// CHECK-LABEL: hw.module @memoryOps -hw.module @memoryOps(in %clk: i1, in %mem: !arc.memory<4 x i32, i32>, in %addr: i32, in %data: i32) { - %true = hw.constant true - // CHECK-NEXT: arc.memory_write %mem[%addr], %data : <4 x i32, i32> - arc.memory_write %mem[%addr], %data if %true : <4 x i32, i32> - - %false = hw.constant false - arc.memory_write %mem[%addr], %data if %false : <4 x i32, i32> -} - // CHECK-LABEL: hw.module @clockDomainCanonicalizer hw.module @clockDomainCanonicalizer(in %clk: !seq.clock, in %data: i32, out out0: i32, out out1: i1, out out2: i32, out out3: i32, out out4: i32) { %c0_i32 = hw.constant 0 : i32