55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===----------------------------------------------------------------------===//
8+ #include < algorithm>
9+ #include < regex>
810#include < string>
911#include < vector>
10- #include < regex>
11- #include < algorithm>
1212
1313#include " mlir/Dialect/Arith/IR/Arith.h"
1414#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1818#include " mlir/IR/OpImplementation.h"
1919#include " mlir/Interfaces/SideEffectInterfaces.h"
2020
21+ #include " polygeist/Ops.h"
22+ #include " sql/Parser.h"
2123#include " sql/SQLDialect.h"
2224#include " sql/SQLOps.h"
2325#include " sql/SQLTypes.h"
24- #include " sql/Parser.h"
25- #include " polygeist/Ops.h"
2626
2727#define GET_OP_CLASSES
2828#include " sql/SQLOps.cpp.inc"
4141#include " llvm/ADT/SetVector.h"
4242#include " llvm/Support/Debug.h"
4343
44-
45- #include " mlir/IR/Value.h"
44+ #include " mlir/IR/Attributes.h"
4645#include " mlir/IR/Builders.h"
46+ #include " mlir/IR/BuiltinTypes.h"
4747#include " mlir/IR/Location.h"
48- #include " mlir/IR/Attributes .h"
48+ #include " mlir/IR/Value .h"
4949#include " llvm/ADT/SmallVector.h"
50- #include " mlir/IR/BuiltinTypes.h"
5150
5251#define DEBUG_TYPE " sql"
5352
5453using namespace mlir ;
5554using namespace sql ;
5655using namespace mlir ::arith;
5756
58-
5957class GetValueOpTypeFix final : public OpRewritePattern<GetValueOp> {
6058public:
6159 using OpRewritePattern<GetValueOp>::OpRewritePattern;
@@ -67,38 +65,38 @@ class GetValueOpTypeFix final : public OpRewritePattern<GetValueOp> {
6765
6866 Value handle = op.getOperand (0 );
6967 if (!handle.getType ().isa <IndexType>()) {
70- handle = rewriter.create <IndexCastOp>(op.getLoc (),
71- rewriter.getIndexType (), handle);
72- changed = true ;
68+ handle = rewriter.create <IndexCastOp>(op.getLoc (),
69+ rewriter.getIndexType (), handle);
70+ changed = true ;
7371 }
7472 Value row = op.getOperand (1 );
7573 if (!row.getType ().isa <IndexType>()) {
76- row = rewriter.create <IndexCastOp>(op.getLoc (),
77- rewriter. getIndexType (), row);
78- changed = true ;
74+ row = rewriter.create <IndexCastOp>(op.getLoc (), rewriter. getIndexType (),
75+ row);
76+ changed = true ;
7977 }
8078 Value column = op.getOperand (2 );
8179 if (!column.getType ().isa <IndexType>()) {
82- column = rewriter.create <IndexCastOp>(op.getLoc (),
83- rewriter.getIndexType (), column);
84- changed = true ;
80+ column = rewriter.create <IndexCastOp>(op.getLoc (),
81+ rewriter.getIndexType (), column);
82+ changed = true ;
8583 }
8684
87- if (!changed) return failure ();
85+ if (!changed)
86+ return failure ();
8887
89- rewriter.replaceOpWithNewOp <GetValueOp>(op, op.getType (), handle, row, column);
88+ rewriter.replaceOpWithNewOp <GetValueOp>(op, op.getType (), handle, row,
89+ column);
9090
9191 return success (changed);
9292 }
9393};
9494
9595void GetValueOp::getCanonicalizationPatterns (RewritePatternSet &results,
96- MLIRContext *context) {
96+ MLIRContext *context) {
9797 results.insert <GetValueOpTypeFix>(context);
9898}
9999
100-
101-
102100class NumResultsOpTypeFix final : public OpRewritePattern<NumResultsOp> {
103101public:
104102 using OpRewritePattern<NumResultsOp>::OpRewritePattern;
@@ -108,34 +106,35 @@ class NumResultsOpTypeFix final : public OpRewritePattern<NumResultsOp> {
108106 bool changed = false ;
109107 Value handle = op->getOperand (0 );
110108
111- if (handle.getType ().isa <IndexType>() && op->getResultTypes ()[0 ].isa <IndexType>())
112- return failure ();
109+ if (handle.getType ().isa <IndexType>() &&
110+ op->getResultTypes ()[0 ].isa <IndexType>())
111+ return failure ();
113112
114113 if (!handle.getType ().isa <IndexType>()) {
115- handle = rewriter.create <IndexCastOp>(op.getLoc (),
116- rewriter.getIndexType (), handle);
117- changed = true ;
114+ handle = rewriter.create <IndexCastOp>(op.getLoc (),
115+ rewriter.getIndexType (), handle);
116+ changed = true ;
118117 }
119118
120- mlir::Value res = rewriter.create <NumResultsOp>(op.getLoc (), rewriter.getIndexType (), handle);
119+ mlir::Value res = rewriter.create <NumResultsOp>(
120+ op.getLoc (), rewriter.getIndexType (), handle);
121121
122122 if (op->getResultTypes ()[0 ].isa <IndexType>()) {
123- rewriter.replaceOp (op, res);
123+ rewriter.replaceOp (op, res);
124124 } else {
125- rewriter.replaceOpWithNewOp <IndexCastOp>(op, op->getResultTypes ()[0 ], res);
125+ rewriter.replaceOpWithNewOp <IndexCastOp>(op, op->getResultTypes ()[0 ],
126+ res);
126127 }
127128
128129 return success (changed);
129130 }
130131};
131132
132133void NumResultsOp::getCanonicalizationPatterns (RewritePatternSet &results,
133- MLIRContext *context) {
134+ MLIRContext *context) {
134135 results.insert <NumResultsOpTypeFix>(context);
135136}
136137
137-
138-
139138// class ExecuteOpTypeFix final : public OpRewritePattern<ExecuteOp> {
140139// public:
141140// using OpRewritePattern<ExecuteOp>::OpRewritePattern;
@@ -147,39 +146,44 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results,
147146// Value conn = op->getOperand(0);
148147// Value command = op->getOperand(1);
149148
150- // if (conn.getType().isa<IndexType>() && command.getType().isa<IndexType>() && op->getResultTypes()[0].isa<IndexType>())
149+ // if (conn.getType().isa<IndexType>() && command.getType().isa<IndexType>()
150+ // && op->getResultTypes()[0].isa<IndexType>())
151151// return failure();
152152
153153// if (!conn.getType().isa<IndexType>()) {
154154// conn = rewriter.create<IndexCastOp>(op.getLoc(),
155- // rewriter.getIndexType(), conn);
155+ // rewriter.getIndexType(),
156+ // conn);
156157// changed = true;
157158// }
158159// if (command.getType().isa<MemRefType>()) {
159- // command = rewriter.create<polygeist::Memref2PointerOp>(op.getLoc(),
160- // LLVM::LLVMPointerType::get(rewriter.getI8Type()), command);
160+ // command = rewriter.create<polygeist::Memref2PointerOp>(op.getLoc(),
161+ // LLVM::LLVMPointerType::get(rewriter.getI8Type()),
162+ // command);
161163// changed = true;
162164// }
163165
164-
165166// if (command.getType().isa<LLVM::LLVMPointerType>()) {
166- // command = rewriter.create<LLVM::PtrToIntOp>(op.getLoc(),
167- // rewriter.getI64Type(), command);
167+ // command = rewriter.create<LLVM::PtrToIntOp>(op.getLoc(),
168+ // rewriter.getI64Type(),
169+ // command);
168170// changed = true;
169171// }
170172// if (!command.getType().isa<IndexType>()) {
171- // command = rewriter.create<IndexCastOp>(op.getLoc(),
172- // rewriter.getIndexType(), command);
173+ // command = rewriter.create<IndexCastOp>(op.getLoc(),
174+ // rewriter.getIndexType(),
175+ // command);
173176// changed = true;
174177// }
175178
176179// if (!changed) return failure();
177- // mlir::Value res = rewriter.create<ExecuteOp>(op.getLoc(), rewriter.getIndexType(), conn, command);
178- // rewriter.replaceOp(op, res);
180+ // mlir::Value res = rewriter.create<ExecuteOp>(op.getLoc(),
181+ // rewriter.getIndexType(), conn, command); rewriter. replaceOp(op, res);
179182// // if (op->getResultTypes()[0].isa<IndexType>()) {
180183// // rewriter.replaceOp(op, res);
181184// // } else {
182- // // rewriter.replaceOpWithNewOp<IndexCastOp>(op, op->getResultTypes()[0], res);
185+ // // rewriter.replaceOpWithNewOp<IndexCastOp>(op,
186+ // op->getResultTypes()[0], res);
183187// // }
184188// return success(changed);
185189// }
@@ -190,8 +194,7 @@ void NumResultsOp::getCanonicalizationPatterns(RewritePatternSet &results,
190194// results.insert<ExecuteOpTypeFix>(context);
191195// }
192196
193-
194- template <typename T>
197+ template <typename T>
195198class UnparsedOpInnerCast final : public OpRewritePattern<UnparsedOp> {
196199public:
197200 using OpRewritePattern<UnparsedOp>::OpRewritePattern;
@@ -200,39 +203,91 @@ class UnparsedOpInnerCast final : public OpRewritePattern<UnparsedOp> {
200203 PatternRewriter &rewriter) const override {
201204
202205 Value input = op->getOperand (0 );
203-
206+
204207 auto cst = input.getDefiningOp <T>();
205- if (!cst) return failure ();
208+ if (!cst)
209+ return failure ();
206210
207211 rewriter.replaceOpWithNewOp <UnparsedOp>(op, op.getType (), cst.getOperand ());
208212 return success ();
209213 }
210214};
211215
212216void UnparsedOp::getCanonicalizationPatterns (RewritePatternSet &results,
213- MLIRContext *context) {
214- results.insert <UnparsedOpInnerCast<polygeist::Pointer2MemrefOp> >(context);
217+ MLIRContext *context) {
218+ results.insert <UnparsedOpInnerCast<polygeist::Pointer2MemrefOp>>(context);
215219}
216220
217-
218- class SQLStringConcatOpCanonicalization final : public OpRewritePattern<SQLStringConcatOp> {
221+ class SQLStringConcatOpCanonicalization final
222+ : public OpRewritePattern<SQLStringConcatOp> {
219223public:
220224 using OpRewritePattern<SQLStringConcatOp>::OpRewritePattern;
221225
222226 LogicalResult matchAndRewrite (SQLStringConcatOp op,
223227 PatternRewriter &rewriter) const override {
224-
225- auto input1 = op->getOperand (0 ).getDefiningOp <SQLConstantStringOp>();
226- auto input2 = op->getOperand (1 ).getDefiningOp <SQLConstantStringOp>();
227-
228- if (!input1 || !input2) return failure ();
229-
230- rewriter.replaceOpWithNewOp <SQLConstantStringOp>(op, op.getType (), (input1.getInput () + input2.getInput ()).str ());
231- return success ();
228+ // Whether we changed the state. If we make no simplifications we need to
229+ // return failure otherwise we will infinite loop
230+ bool changed = false ;
231+ // Operands to the simplified concat
232+ SmallVector<Value> operands;
233+ // Constants that we will merge, "current running constant"
234+ SmallVector<SQLConstantStringOp> constants;
235+ for (auto op : op->getOperands ()) {
236+ if (auto constOp = op.getDefiningOp <SQLConstantStringOp>()) {
237+ constants.push_back (constOp);
238+ continue ;
239+ }
240+ if (constants.size () != 0 ) {
241+ if (constants.size () == 1 ) {
242+ operands.push_back (constants[0 ]);
243+ } else {
244+ std::string nextStr;
245+ changed = true ;
246+ for (auto str : constants)
247+ nextStr += str.getInput ().str ();
248+
249+ operands.push_back (rewriter.create <SQLConstantStringOp>(
250+ op.getLoc (), MemRefType::get ({-1 }, rewriter.getI8Type ()), nextStr));
251+ }
252+ }
253+ constants.clear ();
254+ if (auto concat = op.getDefiningOp <SQLStringConcatOp>()) {
255+ changed = true ;
256+ for (auto op2 : concat->getOperands ())
257+ operands.push_back (op2);
258+ continue ;
259+ }
260+ operands.push_back (op);
261+ }
262+ if (constants.size () != 0 ) {
263+ if (constants.size () == 1 ) {
264+ operands.push_back (constants[0 ]);
265+ } else {
266+ std::string nextStr;
267+ changed = true ;
268+ for (auto str : constants)
269+ nextStr = nextStr + str.getInput ().str ();
270+ operands.push_back (rewriter.create <SQLConstantStringOp>(
271+ op.getLoc (), MemRefType::get ({-1 }, rewriter.getI8Type ()), nextStr));
272+ }
273+ }
274+ if (operands.size () == 0 ) {
275+ rewriter.replaceOpWithNewOp <SQLConstantStringOp>(op, MemRefType::get ({-1 }, rewriter.getI8Type ()), " " );
276+ return success ();
277+ }
278+ if (operands.size () == 1 ) {
279+ rewriter.replaceOp (op, operands[0 ]);
280+ return success ();
281+ }
282+ if (changed) {
283+ rewriter.replaceOpWithNewOp <SQLStringConcatOp>(op, MemRefType::get ({-1 }, rewriter.getI8Type ()), operands);
284+ return success ();
285+ }
286+ return failure ();
232287 }
233288};
234289
235290void SQLStringConcatOp::getCanonicalizationPatterns (RewritePatternSet &results,
236- MLIRContext *context) {
291+ MLIRContext *context) {
237292 results.insert <SQLStringConcatOpCanonicalization>(context);
238- }
293+ }
0 commit comments