diff --git a/src/ae/AeValSolver.hpp b/src/ae/AeValSolver.hpp index 9287cd61d..59832877e 100644 --- a/src/ae/AeValSolver.hpp +++ b/src/ae/AeValSolver.hpp @@ -173,13 +173,15 @@ namespace ufo { ExprSet lits; u.getTrueLiterals(pr, m, lits, true); - pr = simplifyArithm(mixQE(conjoin(lits, efac), exp, m, u, debug)); + outs() << "Lits with var : { \n"; + for (auto a: lits) + if (contains(a, exp)) + outs() << a << "\n"; + outs() << "}\n" << std::endl; + pr = mixQE(conjoin(lits, efac), exp, m, u, debug); if(m.eval(exp) != exp) modelMap[exp] = mk(exp, m.eval(exp)); - if(debug) - MBPSanityCheck(m, pr); - if(debug >= 2) { outs() << "\nmodel " << partitioning_size << ":\n"; @@ -198,8 +200,12 @@ namespace ufo } outs() << "projection:\n"; pprint(pr, 2); + outs() << std::endl; } + if(debug) + MBPSanityCheck(m, pr); + for(auto it = lits.begin(); it != lits.end();) { if(contains(*it, exp)) @@ -223,10 +229,14 @@ namespace ufo { assert(isOpX(m.eval(pr))); ExprVector args; - for(auto temp : v) + ExprVector argsPr; + for(auto temp : v) { args.push_back(temp->last()); + argsPr.push_back(temp->last()); + } args.push_back(t); - boost::tribool impl = u.implies(pr, mknary(args)); + argsPr.push_back(pr); + boost::tribool impl = u.implies(mknary(argsPr), mknary(args)); tribool_assert(impl); }; @@ -1267,8 +1277,8 @@ namespace ufo minusSets(ex_qvars, fa_qvars); } - s = convertIntsToReals
(s); - t = convertIntsToReals
(t); + s = convertIdivToDiv
(s); + t = convertIdivToDiv
(t); if(debug >= 3) { diff --git a/src/ae/ExprSimpl.hpp b/src/ae/ExprSimpl.hpp index 47a8da07c..2c0092633 100644 --- a/src/ae/ExprSimpl.hpp +++ b/src/ae/ExprSimpl.hpp @@ -532,6 +532,8 @@ namespace ufo return e; } + static Expr realSimplifyMult(Expr fla); + /** * Helper used in ineqMover */ @@ -539,7 +541,7 @@ namespace ufo Expr l = e->left(); Expr r = e->right(); ExprVector orig_lhs, orig_rhs, lhs, rhs; - + ExprVector realCoefs; // parse getAddTerm(l, orig_lhs); @@ -568,6 +570,20 @@ namespace ufo coef -= lexical_cast(subExpr->left()); found = true; } + else if (isOpX(subExpr) && isRealConst(var)) { + ExprVector tmp; + int skipped = 0; + for (int i = 0; i < (*it)->arity(); ++i) + { + if (subExpr->arg(i) != var) + tmp.push_back(subExpr->arg(i)); + else + skipped++; + } + assert(skipped == 1); + realCoefs.push_back(mk(mkmult(tmp, var->efac()))); + found = true; + } else if (isOp(subExpr) && 2 == subExpr->arity() && isOpX(subExpr->right()) && subExpr->left() == var) { coef -= rational(1, lexical_cast(subExpr->right())); found = true; @@ -577,6 +593,20 @@ namespace ufo coef += lexical_cast((*it)->left()); found = true; } + else if (isOpX(*it) && isRealConst(var)) { + ExprVector tmp; + int skipped = 0; + for (int i = 0; i < (*it)->arity(); ++i) + { + if ((*it)->arg(i) != var) + tmp.push_back((*it)->arg(i)); + else + skipped++; + } + assert(skipped == 1); + realCoefs.push_back(mkmult(tmp, var->efac())); + found = true; + } if (isOp(*it) && 2 == (*it)->arity() && isOpX((*it)->right()) && (*it)->left() == var) { coef += rational(1, lexical_cast((*it)->right())); found = true; @@ -594,7 +624,12 @@ namespace ufo } r = mkplus(rhs, e->getFactory()); - + if (!realCoefs.empty()) + { + l = mkplus(realCoefs, var->efac()); + l = mk(l, var); + return mk(l,r); + } if (coef == 0){ l = mkMPZ (0, e->getFactory()); } else if (coef == 1){ @@ -1857,6 +1892,7 @@ namespace ufo } static Expr rewriteMultAdd (Expr exp); + static Expr rewriteDivAdd (Expr exp); inline static void getAddTerm (Expr a, ExprVector &terms) // implementation (mutually recursive) { @@ -1900,10 +1936,42 @@ namespace ufo } else if (isOpX(a)) { + if (a->arity() == 2) { + Expr lhs = a->left(); + Expr rhs = a->right(); + if (isOpX
(lhs)) { + Expr tmp = mk
(mk(rhs, lhs->left() ), lhs->right()); + getAddTerm(tmp, terms); + return; + } + else if (isOpX
(rhs)) { + Expr tmp = mk
(mk(lhs, rhs->left()), rhs->right()); + getAddTerm(tmp, terms); + return; + } + else if (isOpX(lhs) && isOpX
(lhs->left())) { + rhs = additiveInverse(rhs); + Expr tmp = mk
(mk(rhs, lhs->left()->left() ), lhs->left()->right()); + getAddTerm(tmp, terms); + return; + } + else if (isOpX(rhs) && isOpX
(rhs->left())) { + lhs = additiveInverse(lhs); + Expr tmp = mk
(mk(rhs->left()->left(), lhs), rhs->left()->right()); + getAddTerm(tmp, terms); + return; + } + } Expr tmp = rewriteMultAdd(a); if (tmp == a) terms.push_back(a); else getAddTerm(tmp, terms); } + else if (isOpX
(a)) + { + Expr tmp = rewriteDivAdd(a); + if (tmp == a) terms.push_back(a); + else getAddTerm(tmp, terms); + } else if (lexical_cast(a) != "0") { bool found = false; @@ -1959,6 +2027,38 @@ namespace ufo return dagVisit (mu, exp); } + struct AddDivDistr + { + AddDivDistr () {}; + + Expr operator() (Expr exp) + { + if (isOpX
(exp) && exp->arity() == 2) + { + Expr lhs = exp->left(); + Expr rhs = exp->right(); + + ExprVector alllhs; + getAddTerm(lhs, alllhs); + + ExprVector unf; + for (auto &a : alllhs) + { + unf.push_back(mk
(a, rhs)); + } + return mkplus(unf, exp->getFactory()); + } + + return exp; + } + }; + + inline static Expr rewriteDivAdd (Expr exp) + { + RW mu(new AddDivDistr()); + return dagVisit (mu, exp); + } + struct FindNonlinAndRewrite { ExprVector& vars; @@ -2213,6 +2313,147 @@ namespace ufo return typeOf(e) == mk(e->getFactory()); } + static Expr divSimplifier(Expr fla, int& minusOps) + { + Expr tmp = fla; + ExprFactory& efac = fla->getFactory(); + ExprVector dividers; + + while (true) + { + if (isOpX
(tmp)) + { + dividers.push_back(tmp->right()); + if(lexical_cast(tmp->right())[0] == '-') + minusOps++; // TODO: complicated divs + tmp = tmp->left(); + } + else if (isOpX(tmp) && isOpX
(tmp->left())) + { + dividers.push_back(tmp->left()->right()); + if(lexical_cast(tmp->left()->right())[0] == '-') + minusOps++; // TODO: complicated divs + tmp = additiveInverse(tmp->left()->left()); + } + else + break; + } + + if (dividers.size() == 0) + return fla; + return mk
(tmp, mkmult(dividers, efac)); + } + + static Expr realRewriteDivs(Expr fla, Expr var) + { + assert (isOp(fla)); + ExprFactory& efac = var->efac(); + ExprVector plusOpsLeft; + ExprVector plusOpsRight; + + ExprVector lhss; + ExprVector rhss; + int minusOps = 0; + + getAddTerm(fla->left(), plusOpsLeft); + getAddTerm(fla->right(), plusOpsRight); + + for (auto r : plusOpsRight) + plusOpsLeft.push_back(additiveInverse(r)); + + ExprSet divs; + for (auto it = plusOpsLeft.begin(); it != plusOpsLeft.end(); it++) + { + if(!contains(*it, var)) + continue; + *it = divSimplifier(*it, minusOps); + if(isOpX
(*it)) + divs.insert((*it)->right()); + else if (isOpX(*it) && isOpX
((*it)->left())) + divs.insert((*it)->left()->right()); + } + + for(auto ite = plusOpsLeft.begin(); ite != plusOpsLeft.end(); ite++) + { + if (!contains(*ite, var)) { + Expr m = mkmult(divs, efac); + if (m != mkMPZ (1, efac)) + *ite = mk(*ite, m); + } + else if (isOpX
(*ite)) + { + Expr d = (*ite)->right(); + divs.erase(d); + Expr m = mkmult(divs, efac); + if (m != mkMPZ (1, efac)) + *ite = mk((*ite)->left(), m); + else + *ite = (*ite)->left(); + divs.insert(d); + } + else if (isOpX(*ite) && isOpX
((*ite)->left())) + { + Expr d = (*ite)->left()->right(); + divs.erase(d); + Expr m = mkmult(divs, efac); + if (m != mkMPZ (1, efac)) + *ite = mk(mk((*ite)->left()->left()), m); + else + *ite = mk((*ite)->left()->left()); + divs.insert(d); + } + else { + Expr m = mkmult(divs, efac); + if (m != mkMPZ (1, efac)) + *ite = mk(*ite, m); + } + } + if (minusOps % 2 == 0) + return (mk(fla->op(), mkplus(plusOpsLeft, efac), mkMPZ (0, efac))); + return reBuildCmpSym(fla, mkMPZ (0, efac), mkplus(plusOpsLeft, efac)); + } + + static void realMultHelper(Expr fla, ExprVector& mults) + { + if (!isOpX(fla)) + mults.push_back(fla); + else + { + for (int i = 0; i < fla->arity(); i++) + realMultHelper(fla->arg(i), mults); + } + } + + static Expr realSimplifyMult(Expr fla) + { + ExprFactory& efac = fla->getFactory(); + ExprVector plusOpsLeft; + ExprVector plusOpsRight; + int minusOps = 0; + getAddTerm(fla->left(), plusOpsLeft); + getAddTerm(fla->right(), plusOpsRight); + + for(auto ite = plusOpsRight.begin(); ite != plusOpsRight.end(); ite++) + { + *ite = divSimplifier(*ite, minusOps); + if (!isOpX(*ite)) + continue; + ExprVector mults; + realMultHelper(*ite, mults); + *ite = mkmult(mults, efac); + } + for(auto ite = plusOpsLeft.begin(); ite != plusOpsLeft.end(); ite++) + { + *ite = divSimplifier(*ite, minusOps); + if (!isOpX(*ite)) + continue; + ExprVector mults; + realMultHelper(*ite, mults); + *ite = mkmult(mults, efac); + } + return mk(fla->op(), mkplus(plusOpsLeft, efac), mkplus(plusOpsRight, efac)); + } + inline Expr rewriteDivConstraints(Expr fla) { // heuristic for the divisibility constraints @@ -2355,11 +2596,38 @@ namespace ufo return fla; } - template static Expr convertIntsToReals (Expr exp); + static Expr convertIntsToReals (Expr exp); - template struct IntToReal + struct AllIntsToReals { - IntToReal () {}; + AllIntsToReals () {}; + + Expr operator() (Expr exp) + { + ExprVector args; + for (int i = 0; i < exp->arity(); i++) + { + Expr e = exp->arg(i); + if (isOpX(e)) + e = mkTerm (mpq_class (lexical_cast(e)), exp->getFactory()); + else if (bind::isIntConst(e)) + e = realConst(fname(e)); + args.push_back(e); + } + return mknary(exp->op(), args.begin(), args.end()); + } + }; + + static Expr convertIntsToReals (Expr exp) + { + RW rw(new AllIntsToReals()); + return dagVisit (rw, exp); + } + + template static Expr convertIdivToDiv (Expr exp); + template struct IdivToDiv + { + IdivToDiv () {}; Expr operator() (Expr exp) { @@ -2372,10 +2640,10 @@ namespace ufo if (isOpX(e)) e = mkTerm (mpq_class (lexical_cast(e)), exp->getFactory()); else { - e = convertIntsToReals(e); - e = convertIntsToReals(e); - e = convertIntsToReals(e); - e = convertIntsToReals(e); + e = convertIdivToDiv(e); + e = convertIdivToDiv(e); + e = convertIdivToDiv(e); + e = convertIdivToDiv(e); } args.push_back(e); } @@ -2385,9 +2653,9 @@ namespace ufo } }; - template static Expr convertIntsToReals (Expr exp) + template static Expr convertIdivToDiv (Expr exp) { - RW> rw(new IntToReal()); + RW> rw(new IdivToDiv()); return dagVisit (rw, exp); } @@ -4558,6 +4826,117 @@ namespace ufo return false; } + + + enum laType { + REALTYPE = 0, + INTTYPE, + MIXTYPE, + NOTYPE + }; + + /** + * intOrReal - checks expression type + */ + static int intOrReal(Expr s) + { + ExprVector sVec; + bool realType = false, intType = false; + filter(s, bind::IsNumber(), back_inserter(sVec)); + filter(s, bind::IsConst(), back_inserter(sVec)); + for(auto ite : sVec) + { + if(bind::isIntConst(ite) || isOpX(ite)) + intType = true; + else if(bind::isRealConst(ite) || isOpX(ite)) + realType = true; + } + + if(realType && intType) + return MIXTYPE; // a bad case + else if(realType) + return REALTYPE; + else if(intType) + return INTTYPE; + else + return NOTYPE; // unknown + } + + static Expr tryToRemoveMixType(Expr exp) + { + if (intOrReal(exp) != MIXTYPE) + return exp; + + return convertIntsToReals(exp); + } + + // static void getLiterals (Expr exp, ExprSet& lits, bool splitEqs = true); + + // static void getLiteralsBool(Expr exp, ExprSet& lits, bool splitEqs = true) + // { + // Expr el = exp->left(); + // Expr er = exp->right(); + // if (isOp(exp) && !splitEqs && isBoolConstOrNegation(el) && isBoolConstOrNegation(er)) { + // lits.insert(exp); + // } else if (isOpX(exp) || isOpX(exp) || isOpX(exp) || isOpX(exp)) { + // getLiterals(mkNeg(el), lits, splitEqs); + // getLiterals(er, lits, splitEqs); + // getLiterals(mkNeg(er), lits, splitEqs); + // getLiterals(el, lits, splitEqs); + // } else if (isOpX(exp) || isOpX(exp)) { + // for (int i = 0; i < exp->arity(); i++) + // getLiterals(exp->arg(i), lits, splitEqs); + // } else if (isOpX(exp)) { + // if (isBoolConst(el)) + // lits.insert(exp); + // else + // getLiterals(mkNeg(el), lits, splitEqs); + // } else if (isOpX(exp)) { + // getLiterals(mkNeg(el), lits, splitEqs); + // getLiterals(er, lits, splitEqs); + // } + // } + + // static void getLiteralsNumeric(Expr exp, ExprSet& lits, bool splitEqs = true) + // { + // exp = tryToRemoveMixType(exp); + // Expr el = exp->left(); + // Expr er = exp->right(); + // if (isOp(exp) && !splitEqs) { + // lits.insert(exp); + // } else if (isOpX(exp) && !containsOp(exp)) { + // getLiterals(mk(el, er), lits, splitEqs); + // getLiterals(mk(el, er), lits, splitEqs); + // } else if (isOpX(exp) && !containsOp(exp)) { + // getLiterals(mk(el, er), lits, splitEqs); + // getLiterals(mk(el, er), lits, splitEqs); + // } else if (isOp(exp)) { + // exp = rewriteDivConstraints(exp); + // exp = rewriteModConstraints(exp); + // if (isOpX(exp) || isOpX(exp)) + // getLiterals(exp, lits, splitEqs); + // else lits.insert(exp); + // } + // } + + // static void getLiterals (Expr exp, ExprSet& lits, bool splitEqs) + // { + // ExprFactory& efac = exp->getFactory(); + // Expr el = exp->left(); + // Expr er = exp->right(); + // if (isOp(exp) || isOp(exp) && isBoolean(el)) + // getLiteralsBool(exp, lits, splitEqs); + // else if (isOp(exp) && isNumeric(el)) + // getLiteralsNumeric(exp, lits, splitEqs); + // else if (bind::typeOf(exp) == mk(efac) && + // !containsOp(exp) && !containsOp(exp)) { + // lits.insert(exp); + // } else if (!isOpX(exp) && !isOpX(exp)) { + // errs () << "unable lit: " << *exp << "\n"; + // assert(0); + // } + // } + static void getLiterals (Expr exp, ExprSet& lits, bool splitEqs = true) { ExprFactory& efac = exp->getFactory(); @@ -4616,6 +4995,7 @@ namespace ufo { if (isOp(exp)) { + exp = tryToRemoveMixType(exp); exp = rewriteDivConstraints(exp); exp = rewriteModConstraints(exp); if (isOpX(exp) || isOpX(exp)) @@ -4636,7 +5016,6 @@ namespace ufo } } - void pprint(Expr exp, int inden, bool upper); template static void pprint(Range& exprs, int inden = 0) @@ -4689,5 +5068,4 @@ namespace ufo if (upper) outs() << "\n"; } } - #endif diff --git a/src/ae/MBPUtils.cpp b/src/ae/MBPUtils.cpp index 5816dac33..d41b65836 100644 --- a/src/ae/MBPUtils.cpp +++ b/src/ae/MBPUtils.cpp @@ -3,35 +3,6 @@ using namespace ufo; -/** - * intOrReal - checks expression type - */ -int intOrReal(Expr s) -{ - ExprVector sVec; - bool realType = false, intType = false; - filter(s, bind::IsNumber(), back_inserter(sVec)); - filter(s, bind::IsConst(), back_inserter(sVec)); - for(auto ite : sVec) - { - if(bind::isIntConst(ite) || isOpX(ite)) - intType = true; - else if(bind::isRealConst(ite) || isOpX(ite)) - realType = true; - else - assert(false); // Error identifying - } - - if(realType && intType) - return MIXTYPE; // a bad case - else if(realType) - return REALTYPE; - else if(intType) - return INTTYPE; - else - return NOTYPE; // t == true -} - /** * laMergeBounds - merges lower and upper bounds * @@ -71,6 +42,20 @@ void laMergeBounds( return isOpX(m.eval(mk(ra, rb))); }); + outs() << "upVec: "; + for (auto a: upVec) + { + outs() << a << " : " << m.eval(a->right()) << ";"; + } + outs() << endl; + + outs() << "loVec: "; + for (auto a: loVec) + { + outs() << a << " : " << m.eval(a->right()) << ";"; + } + outs() << endl; + Expr loBound = loVec.back(); Expr upBound = upVec.front(); @@ -96,17 +81,28 @@ void laMergeBounds( Expr lraMultTrans(Expr t, Expr eVar) { Expr lhs = t->left(), rhs = t->right(); + unsigned minOps = 0; while(isOp(lhs)) //until lhs is no longer * { - Expr lOperand = lhs->left(), rOperand = lhs->right(); - bool yOnTheLeft = contains(lOperand, eVar); - - rhs = mk
(rhs, yOnTheLeft ? rOperand : lOperand); - lhs = yOnTheLeft ? lOperand : rOperand; + Expr varOperand; + for (int i = 0; i < lhs->arity(); i++) + { + if (contains(lhs->arg(i), eVar)) + varOperand = lhs->arg(i); + else + { + if (lexical_cast(lhs->arg(i))[0] == '-') + minOps++; + rhs = mk
(rhs, lhs->arg(i)); + } + } + lhs = varOperand; if (!contains(lhs, eVar)) unreachable(); } - return (mk(t->op(), lhs, rhs)); + if (minOps % 2 == 0) + return (mk(t->op(), lhs, rhs)); + return reBuildCmpSym(t, rhs, lhs); } /** @@ -313,20 +309,26 @@ Expr intQE(ExprSet sSet, Expr eVar, ZSolver::Model &m) */ Expr ineqPrepare(Expr t, Expr eVar) { + int intVSreal = intOrReal(t); if(isOpX(t) && isOp(t->left())) t = mkNeg(t->left()); if(isOp(t)) { + Expr zero = intVSreal == INTTYPE ? mkMPZ(0, eVar->efac()) + : mkMPQ("0", eVar->efac()); // rewrite so that y is on lhs, with positive coef t = simplifyArithm(reBuildCmp( t, mk(t->arg(0), additiveInverse(t->arg(1))), - mkMPZ(0, eVar->efac()))); + zero)); + if (isRealConst(eVar)) { + t = realRewriteDivs(t, eVar); + t = realSimplifyMult(t); + } t = ineqSimplifier(eVar, t); } else unreachable(); - int intVSreal = intOrReal(t); if(isReal(eVar) && (intVSreal == REALTYPE)) return t; @@ -334,7 +336,6 @@ Expr ineqPrepare(Expr t, Expr eVar) return t; else if(intVSreal != NOTYPE) notImplemented(); - return t; } @@ -376,5 +377,10 @@ Expr ufo::mixQE( outSet.insert(isReal(eVar) ? realQE(sameTypeSet, eVar, m) : intQE(sameTypeSet, eVar, m)); + outs() << "OutSet:\n"; + for (auto o : outSet) + outs() << o << "\n"; + outs() << std::endl; + return conjoin(outSet, eVar->getFactory()); } diff --git a/src/ae/MBPUtils.hpp b/src/ae/MBPUtils.hpp index 60ada80d0..09dcd777c 100644 --- a/src/ae/MBPUtils.hpp +++ b/src/ae/MBPUtils.hpp @@ -6,13 +6,6 @@ namespace ufo { Expr mixQE(Expr s, Expr eVar, ZSolver::Model &m, SMTUtils &u, int debug); - - enum laType { - REALTYPE = 0, - INTTYPE, - MIXTYPE, - NOTYPE - }; } -#endif \ No newline at end of file +#endif diff --git a/src/sygus/SyGuSSolver.hpp b/src/sygus/SyGuSSolver.hpp index 4bc4335ea..2d0c807e9 100644 --- a/src/sygus/SyGuSSolver.hpp +++ b/src/sygus/SyGuSSolver.hpp @@ -218,7 +218,7 @@ class SyGuSSolver faArgs.push_back(mknary(exArgs)); Expr aeProb = mknary(faArgs); aeProb = regularizeQF(aeProb); - aeProb = convertIntsToReals
(aeProb); + aeProb = convertIdivToDiv
(aeProb); if (debug > 1) { outs() << "Sending to aeval: "; u.print(aeProb); outs() << endl; } diff --git a/src/ufo/Expr.hpp b/src/ufo/Expr.hpp index e4cc6cd2d..fd7a0093e 100644 --- a/src/ufo/Expr.hpp +++ b/src/ufo/Expr.hpp @@ -2350,6 +2350,11 @@ namespace expr return mkTerm (mpz_class (a), efac); } + static Expr mkMPQ(std::string a, ExprFactory& efac) + { + return mkTerm (mpq_class (a), efac); + } + struct FAPP_PS { static inline void print (std::ostream &OS,