@@ -49,6 +49,7 @@ llvm::APInt MajorityInverterOp::evaluate(ArrayRef<APInt> inputs) {
4949 auto width = inputs[0 ].getBitWidth ();
5050 APInt result (width, 0 );
5151
52+ // is every input of same width ?
5253 for (size_t bit = 0 ; bit < width; ++bit) {
5354 size_t count = 0 ;
5455 for (size_t i = 0 ; i < inputs.size (); ++i) {
@@ -67,12 +68,78 @@ llvm::APInt MajorityInverterOp::evaluate(ArrayRef<APInt> inputs) {
6768OpFoldResult MajorityInverterOp::fold (FoldAdaptor adaptor) {
6869 // TODO: Implement maj(x, 1, 1) = 1, maj(x, 0, 0) = 0
6970
71+ // x x 1 -> x
72+ // x 1 1 -> 1
73+ // x ~x 1 -> 1
74+ // x 0 0 -> 0
75+ // x 1 1 -> 1
76+
77+ if (getNumOperands () != 3 )
78+ return {};
79+ bool isOpConstant = true ;
80+ // for all constant inputs
7081 SmallVector<APInt, 3 > inputValues;
82+ size_t i = 0 ;
7183 for (auto input : adaptor.getInputs ()) {
7284 auto attr = llvm::dyn_cast_or_null<IntegerAttr>(input);
7385 if (!attr)
74- return {};
75- inputValues.push_back (attr.getValue ());
86+ isOpConstant = false ;
87+ else {
88+ auto value = isInverted (i) ? ~attr.getValue () : attr.getValue ();
89+ inputValues.push_back (value);
90+ }
91+ i++;
92+ }
93+ if (!isOpConstant) {
94+ // x 0 0
95+ // x 1 1
96+ if (inputValues.size () == 2 ) {
97+ if (inputValues[0 ] != inputValues[1 ])
98+ return {};
99+ else
100+ return IntegerAttr::get (
101+ IntegerType::get (getContext (), inputValues[0 ].getBitWidth ()),
102+ inputValues[0 ]);
103+ }
104+ auto getConstant = [&](unsigned index) -> std::optional<llvm::APInt> {
105+ APInt value;
106+ if (mlir::matchPattern (getInputs ()[index], mlir::m_ConstantInt (&value)))
107+ return isInverted (index) ? ~value : value;
108+ return std::nullopt ;
109+ };
110+ // Pattern match following cases:
111+ // maj_inv(x, x, y) -> x
112+ // maj_inv(x, y, not y) -> x
113+ for (int i = 0 ; i < 2 ; ++i) {
114+ for (int j = i + 1 ; j < 3 ; ++j) {
115+ int k = 3 - (i + j);
116+ assert (k >= 0 && k < 3 );
117+ // If we have two identical operands, we can fold.
118+ if (getOperand (i) == getOperand (j)) {
119+ // If they are inverted differently, we can fold to the third.
120+ if (isInverted (i) != isInverted (j)) {
121+ return getOperand (k);
122+ }
123+ return getOperand (i);
124+ }
125+
126+ // If i and j are constant.
127+ if (auto c1 = getConstant (i)) {
128+ if (auto c2 = getConstant (j)) {
129+ // If both constants are equal, we can fold.
130+ if (*c1 == *c2) {
131+ // auto value = cast<IntegerAttr>(getInputs()[i]).getValue();
132+ // return IntegerAttr::get(IntegerType::get(getContext(),
133+ // value.getBitWidth()),value);
134+ return IntegerAttr::get (getType (), *c1);
135+ }
136+ // If constants are complementary, we can fold.
137+ if (*c1 == ~*c2)
138+ return getOperand (k);
139+ }
140+ }
141+ }
142+ }
76143 }
77144
78145 auto result = evaluate (inputValues);
@@ -112,39 +179,36 @@ LogicalResult MajorityInverterOp::canonicalize(MajorityInverterOp op,
112179 return success ();
113180 };
114181
115- // Pattern match following cases:
116- // maj_inv(x, x, y) -> x
117- // maj_inv(x, y, not y) -> x
118- for (int i = 0 ; i < 2 ; ++i) {
119- for (int j = i + 1 ; j < 3 ; ++j) {
120- int k = 3 - (i + j);
121- assert (k >= 0 && k < 3 );
122- // If we have two identical operands, we can fold.
123- if (op.getOperand (i) == op.getOperand (j)) {
124- // If they are inverted differently, we can fold to the third.
125- if (op.isInverted (i) != op.isInverted (j)) {
126- return replaceWithIndex (k);
127- }
128- rewriter.replaceOp (op, op.getOperand (i));
129- return success ();
130- }
131-
132- // If i and j are constant.
133- if (auto c1 = getConstant (i)) {
134- if (auto c2 = getConstant (j)) {
135- // If both constants are equal, we can fold.
136- if (*c1 == *c2) {
137- rewriter.replaceOpWithNewOp <hw::ConstantOp>(
138- op, op.getType (), mlir::IntegerAttr::get (op.getType (), *c1));
139- return success ();
140- }
141- // If constants are complementary, we can fold.
142- if (*c1 == ~*c2)
143- return replaceWithIndex (k);
144- }
145- }
146- }
147- }
182+ // for (int i = 0; i < 2; ++i) {
183+ // for (int j = i + 1; j < 3; ++j) {
184+ // int k = 3 - (i + j);
185+ // assert(k >= 0 && k < 3);
186+ // // If we have two identical operands, we can fold.
187+ // if (op.getOperand(i) == op.getOperand(j)) {
188+ // // If they are inverted differently, we can fold to the third.
189+ // if (op.isInverted(i) != op.isInverted(j)) {
190+ // return replaceWithIndex(k);
191+ // }
192+ // rewriter.replaceOp(op, op.getOperand(i));
193+ // return success();
194+ // }
195+
196+ // // If i and j are constant.
197+ // if (auto c1 = getConstant(i)) {
198+ // if (auto c2 = getConstant(j)) {
199+ // // If both constants are equal, we can fold.
200+ // if (*c1 == *c2) {
201+ // rewriter.replaceOpWithNewOp<hw::ConstantOp>(
202+ // op, op.getType(), mlir::IntegerAttr::get(op.getType(), *c1));
203+ // return success();
204+ // }
205+ // // If constants are complementary, we can fold.
206+ // if (*c1 == ~*c2)
207+ // return replaceWithIndex(k);
208+ // }
209+ // }
210+ // }
211+ // }
148212 return failure ();
149213}
150214
0 commit comments