@@ -161,16 +161,26 @@ struct ConstantStringOpLowering : public OpRewritePattern<sql::SQLConstantString
161161
162162 SymbolTableCollection symbolTable;
163163 symbolTable.getSymbolTable (module );
164-
165- auto expr = (op.getInput () + " \0 " ).str ();
164+ for (auto u: op.getResult ().getUsers ()){
165+ if (isa<sql::SQLStringConcatOp>(u)) return failure ();
166+ }
167+ auto expr = op.getInput ().str ();
166168 auto name = " str" + std::to_string ((long long int )(Operation *)op);
167- auto MT = MemRefType::get ({expr.size ()}, rewriter.getI8Type ());
169+ auto MT = MemRefType::get ({expr.size () + 1 }, rewriter.getI8Type ());
170+ // auto type = MemRefType::get(mt.getShape(), mt.getElementType(), {});
168171 auto getglob = rewriter.create <memref::GetGlobalOp>(op.getLoc (), MT, name);
172+
173+ SmallVector<char , 1 > data (expr.begin (), expr.end ());
174+ data.push_back (' \0 ' );
175+ auto attr = DenseElementsAttr::get<char >(
176+ RankedTensorType::get (MT.getShape (), MT.getElementType ()), data);
169177
170- rewriter.setInsertionPointToStart (module .getBody ());
171- auto res = rewriter.create <memref::GlobalOp>(op.getLoc (), rewriter.getStringAttr (name),
172- mlir::StringAttr (), mlir::TypeAttr::get (MT), rewriter.getStringAttr (expr), mlir::UnitAttr (), /* alignment*/ nullptr );
178+ auto loc = op.getLoc ();
173179 rewriter.replaceOpWithNewOp <memref::CastOp>(op, MemRefType::get ({-1 }, rewriter.getI8Type ()), getglob.getResult ());
180+ rewriter.setInsertionPointToStart (module .getBody ());
181+ auto res = rewriter.create <memref::GlobalOp>(loc, rewriter.getStringAttr (name),
182+ mlir::StringAttr (), mlir::TypeAttr::get (MT), attr, rewriter.getUnitAttr (), /* alignment*/ nullptr );
183+
174184 return success ();
175185 }
176186};
@@ -207,22 +217,27 @@ struct ToStringOpLowering : public OpRewritePattern<sql::SQLToStringOp> {
207217 Value current = rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
208218 MemRefType::get ({-1 }, rewriter.getI8Type ()), " SELECT " );
209219 bool prevColumn = false ;
210- for (auto v : selectOp.getColumns ()) {
211- Value columns = rewriter.create <sql::SQLToStringOp>(op.getLoc (),
212- MemRefType::get ({-1 }, rewriter.getI8Type ()), v);
213- Value args[] = { current, columns };
214- current = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
215- MemRefType::get ({-1 }, rewriter.getI8Type ()),args);
220+ auto columns = selectOp.getColumns ();
221+ for (mlir::Value v : columns) {
216222 if (prevColumn) {
217- Value args[] = { current, rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
218- MemRefType::get ({-1 }, rewriter.getI8Type ()), " , " ) };
219- current = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
220- MemRefType::get ({-1 }, rewriter.getI8Type ()),args);
223+ Value args[] = { current, rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
224+ MemRefType::get ({-1 }, rewriter.getI8Type ()), " , " ) };
225+ current = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
226+ MemRefType::get ({-1 }, rewriter.getI8Type ()),args);
221227 }
228+ Value col = rewriter.create <sql::SQLToStringOp>(op.getLoc (),
229+ MemRefType::get ({-1 }, rewriter.getI8Type ()), v);
230+ Value args[] = { col, rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
231+ MemRefType::get ({-1 }, rewriter.getI8Type ()), " " )};
232+ col = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
233+ MemRefType::get ({-1 }, rewriter.getI8Type ()), args);
234+ Value args2[] = { current, col };
235+ current = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
236+ MemRefType::get ({-1 }, rewriter.getI8Type ()), args2);
222237 prevColumn = true ;
223238 }
224239 auto tableOp = selectOp.getTable ().getDefiningOp <sql::TableOp>();
225- if (! tableOp || !tableOp. getExpr (). empty () ) {
240+ if (tableOp) {
226241 Value args[] = { current, rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
227242 MemRefType::get ({-1 }, rewriter.getI8Type ()), " FROM " ) };
228243 current = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
@@ -233,6 +248,16 @@ struct ToStringOpLowering : public OpRewritePattern<sql::SQLToStringOp> {
233248 MemRefType::get ({-1 }, rewriter.getI8Type ()),args2);
234249 }
235250 rewriter.replaceOp (op, current);
251+ } else if (auto selectAllOp = dyn_cast<sql::SelectAllOp>(definingOp)){
252+ auto table = rewriter.create <sql::SQLToStringOp>(op.getLoc (),
253+ MemRefType::get ({-1 }, rewriter.getI8Type ()), selectAllOp.getTable ());
254+ Value res = rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
255+ MemRefType::get ({-1 }, rewriter.getI8Type ()), " SELECT * FROM " );
256+ Value args[] = { res, table };
257+ res = rewriter.create <sql::SQLStringConcatOp>(op.getLoc (),
258+ MemRefType::get ({-1 }, rewriter.getI8Type ()),args);
259+
260+ rewriter.replaceOp (op, res);
236261 } else if (auto tabOp = dyn_cast<sql::TableOp>(definingOp)) {
237262 Value res = rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
238263 MemRefType::get ({-1 }, rewriter.getI8Type ()), tabOp.getExpr ());
@@ -244,6 +269,7 @@ struct ToStringOpLowering : public OpRewritePattern<sql::SQLToStringOp> {
244269 } else if (auto intOp = dyn_cast<sql::IntOp>(definingOp)){
245270 Value res = rewriter.create <sql::SQLConstantStringOp>(op.getLoc (),
246271 MemRefType::get ({-1 }, rewriter.getI8Type ()), intOp.getExpr ());
272+ llvm::errs () << " intOp: " << intOp.getExpr () << " \n " ;
247273 rewriter.replaceOp (op, res);
248274 } else {
249275 assert (0 && " unknown type to convert to string" );
@@ -280,6 +306,8 @@ struct ExecuteOpLowering : public OpRewritePattern<sql::ExecuteOp> {
280306 // auto name = "str" + std::to_string((long long int)(Operation *)command.getDefiningOp());
281307 command = rewriter.create <sql::SQLToStringOp>(op.getLoc (),
282308 MemRefType::get ({-1 }, rewriter.getI8Type ()), command);
309+ llvm::errs () << " command: " << command << " \n " ;
310+ llvm::errs () << " command type: " << command.getType () << " \n " ;
283311 // auto type = MemRefType::get({-1}, rewriter.getI8Type());
284312
285313
@@ -295,6 +323,12 @@ struct ExecuteOpLowering : public OpRewritePattern<sql::ExecuteOp> {
295323 Value res =
296324 rewriter.create <mlir::func::CallOp>(op.getLoc (), executefn, args)
297325 ->getResult (0 );
326+ res = rewriter.create <polygeist::Memref2PointerOp>(op.getLoc (),
327+ LLVM::LLVMPointerType::get (rewriter.getI8Type ()), res);
328+ res = rewriter.create <LLVM::PtrToIntOp>(
329+ op.getLoc (), rewriter.getI64Type (), res);
330+ res = rewriter.create <arith::IndexCastOp>(op.getLoc (),
331+ op.getType (), res);
298332
299333 rewriter.replaceOp (op, res);
300334
0 commit comments