Skip to content

Commit d5dcd5b

Browse files
committed
feat : Fold Majority Graph
1 parent acbf18e commit d5dcd5b

File tree

1 file changed

+99
-35
lines changed

1 file changed

+99
-35
lines changed

lib/Dialect/Synth/SynthOps.cpp

Lines changed: 99 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ llvm::APInt MajorityInverterOp::evaluate(ArrayRef<APInt> inputs) {
4949
auto width = inputs[0].getBitWidth();
5050
APInt result(width, 0);
5151

52+
// is every input of same width ?
5253
for (size_t bit = 0; bit < width; ++bit) {
5354
size_t count = 0;
5455
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -67,12 +68,78 @@ llvm::APInt MajorityInverterOp::evaluate(ArrayRef<APInt> inputs) {
6768
OpFoldResult MajorityInverterOp::fold(FoldAdaptor adaptor) {
6869
// TODO: Implement maj(x, 1, 1) = 1, maj(x, 0, 0) = 0
6970

71+
// x x 1 -> x
72+
// x 1 1 -> 1
73+
// x ~x 1 -> 1
74+
// x 0 0 -> 0
75+
// x 1 1 -> 1
76+
77+
if (getNumOperands() != 3)
78+
return {};
79+
bool isOpConstant = true;
80+
// for all constant inputs
7081
SmallVector<APInt, 3> inputValues;
82+
size_t i = 0;
7183
for (auto input : adaptor.getInputs()) {
7284
auto attr = llvm::dyn_cast_or_null<IntegerAttr>(input);
7385
if (!attr)
74-
return {};
75-
inputValues.push_back(attr.getValue());
86+
isOpConstant = false;
87+
else {
88+
auto value = isInverted(i) ? ~attr.getValue() : attr.getValue();
89+
inputValues.push_back(value);
90+
}
91+
i++;
92+
}
93+
if (!isOpConstant) {
94+
// x 0 0
95+
// x 1 1
96+
if (inputValues.size() == 2) {
97+
if (inputValues[0] != inputValues[1])
98+
return {};
99+
else
100+
return IntegerAttr::get(
101+
IntegerType::get(getContext(), inputValues[0].getBitWidth()),
102+
inputValues[0]);
103+
}
104+
auto getConstant = [&](unsigned index) -> std::optional<llvm::APInt> {
105+
APInt value;
106+
if (mlir::matchPattern(getInputs()[index], mlir::m_ConstantInt(&value)))
107+
return isInverted(index) ? ~value : value;
108+
return std::nullopt;
109+
};
110+
// Pattern match following cases:
111+
// maj_inv(x, x, y) -> x
112+
// maj_inv(x, y, not y) -> x
113+
for (int i = 0; i < 2; ++i) {
114+
for (int j = i + 1; j < 3; ++j) {
115+
int k = 3 - (i + j);
116+
assert(k >= 0 && k < 3);
117+
// If we have two identical operands, we can fold.
118+
if (getOperand(i) == getOperand(j)) {
119+
// If they are inverted differently, we can fold to the third.
120+
if (isInverted(i) != isInverted(j)) {
121+
return getOperand(k);
122+
}
123+
return getOperand(i);
124+
}
125+
126+
// If i and j are constant.
127+
if (auto c1 = getConstant(i)) {
128+
if (auto c2 = getConstant(j)) {
129+
// If both constants are equal, we can fold.
130+
if (*c1 == *c2) {
131+
// auto value = cast<IntegerAttr>(getInputs()[i]).getValue();
132+
// return IntegerAttr::get(IntegerType::get(getContext(),
133+
// value.getBitWidth()),value);
134+
return IntegerAttr::get(getType(), *c1);
135+
}
136+
// If constants are complementary, we can fold.
137+
if (*c1 == ~*c2)
138+
return getOperand(k);
139+
}
140+
}
141+
}
142+
}
76143
}
77144

78145
auto result = evaluate(inputValues);
@@ -112,39 +179,36 @@ LogicalResult MajorityInverterOp::canonicalize(MajorityInverterOp op,
112179
return success();
113180
};
114181

115-
// Pattern match following cases:
116-
// maj_inv(x, x, y) -> x
117-
// maj_inv(x, y, not y) -> x
118-
for (int i = 0; i < 2; ++i) {
119-
for (int j = i + 1; j < 3; ++j) {
120-
int k = 3 - (i + j);
121-
assert(k >= 0 && k < 3);
122-
// If we have two identical operands, we can fold.
123-
if (op.getOperand(i) == op.getOperand(j)) {
124-
// If they are inverted differently, we can fold to the third.
125-
if (op.isInverted(i) != op.isInverted(j)) {
126-
return replaceWithIndex(k);
127-
}
128-
rewriter.replaceOp(op, op.getOperand(i));
129-
return success();
130-
}
131-
132-
// If i and j are constant.
133-
if (auto c1 = getConstant(i)) {
134-
if (auto c2 = getConstant(j)) {
135-
// If both constants are equal, we can fold.
136-
if (*c1 == *c2) {
137-
rewriter.replaceOpWithNewOp<hw::ConstantOp>(
138-
op, op.getType(), mlir::IntegerAttr::get(op.getType(), *c1));
139-
return success();
140-
}
141-
// If constants are complementary, we can fold.
142-
if (*c1 == ~*c2)
143-
return replaceWithIndex(k);
144-
}
145-
}
146-
}
147-
}
182+
// for (int i = 0; i < 2; ++i) {
183+
// for (int j = i + 1; j < 3; ++j) {
184+
// int k = 3 - (i + j);
185+
// assert(k >= 0 && k < 3);
186+
// // If we have two identical operands, we can fold.
187+
// if (op.getOperand(i) == op.getOperand(j)) {
188+
// // If they are inverted differently, we can fold to the third.
189+
// if (op.isInverted(i) != op.isInverted(j)) {
190+
// return replaceWithIndex(k);
191+
// }
192+
// rewriter.replaceOp(op, op.getOperand(i));
193+
// return success();
194+
// }
195+
196+
// // If i and j are constant.
197+
// if (auto c1 = getConstant(i)) {
198+
// if (auto c2 = getConstant(j)) {
199+
// // If both constants are equal, we can fold.
200+
// if (*c1 == *c2) {
201+
// rewriter.replaceOpWithNewOp<hw::ConstantOp>(
202+
// op, op.getType(), mlir::IntegerAttr::get(op.getType(), *c1));
203+
// return success();
204+
// }
205+
// // If constants are complementary, we can fold.
206+
// if (*c1 == ~*c2)
207+
// return replaceWithIndex(k);
208+
// }
209+
// }
210+
// }
211+
//}
148212
return failure();
149213
}
150214

0 commit comments

Comments
 (0)