From 4a516e33436fb0abd9efd9b8ec92a8e65534ce3a Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Mon, 2 Apr 2018 20:03:16 -0500 Subject: [PATCH] Improvements to extended rewriter for Booleans and ITE (#1705) --- src/options/quantifiers_options.toml | 8 + src/theory/quantifiers/extended_rewrite.cpp | 1030 ++++++++++++++++--- src/theory/quantifiers/extended_rewrite.h | 143 ++- src/theory/quantifiers/term_util.cpp | 22 +- src/theory/quantifiers/term_util.h | 12 + 5 files changed, 1055 insertions(+), 160 deletions(-) diff --git a/src/options/quantifiers_options.toml b/src/options/quantifiers_options.toml index 28a9e58a7..f877143a2 100644 --- a/src/options/quantifiers_options.toml +++ b/src/options/quantifiers_options.toml @@ -1085,6 +1085,14 @@ header = "options/quantifiers_options.h" default = "false" help = "enumerate a stream of solutions instead of terminating after the first one" +[[option]] + name = "sygusExtRew" + category = "regular" + long = "sygus-ext-rew" + type = "bool" + default = "true" + help = "use extended rewriter for sygus" + [[option]] name = "cegisSample" category = "regular" diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index dd4fc86ba..756413b54 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -14,7 +14,9 @@ #include "theory/quantifiers/extended_rewrite.h" +#include "options/quantifiers_options.h" #include "theory/arith/arith_msum.h" +#include "theory/bv/theory_bv_utils.h" #include "theory/datatypes/datatypes_rewriter.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" @@ -26,201 +28,176 @@ namespace CVC4 { namespace theory { namespace quantifiers { +struct ExtRewriteAttributeId +{ +}; +typedef expr::Attribute ExtRewriteAttribute; + ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr) { - d_true = NodeManager::currentNM()->mkConst(true); - d_false = NodeManager::currentNM()->mkConst(false); } - -Node ExtendedRewriter::extendedRewritePullIte(Node n) +void ExtendedRewriter::setCache(Node n, Node ret) { - // generalize this? - Assert(n.getNumChildren() == 2); - Assert(n.getType().isBoolean()); - Assert(n.getMetaKind() != kind::metakind::PARAMETERIZED); - std::vector children; - for (unsigned i = 0; i < n.getNumChildren(); i++) - { - children.push_back(n[i]); - } - for (unsigned i = 0; i < 2; i++) - { - if (n[i].getKind() == kind::ITE) - { - for (unsigned j = 0; j < 2; j++) - { - children[i] = n[i][j + 1]; - Node eqr = extendedRewrite( - NodeManager::currentNM()->mkNode(n.getKind(), children)); - children[i] = n[i]; - if (eqr.isConst()) - { - std::vector new_children; - Kind new_k; - if (eqr == d_true) - { - new_k = kind::OR; - new_children.push_back(j == 0 ? n[i][0] : n[i][0].negate()); - } - else - { - Assert(eqr == d_false); - new_k = kind::AND; - new_children.push_back(j == 0 ? n[i][0].negate() : n[i][0]); - } - children[i] = n[i][2 - j]; - Node rem_eq = NodeManager::currentNM()->mkNode(n.getKind(), children); - children[i] = n[i]; - new_children.push_back(rem_eq); - Node nc = NodeManager::currentNM()->mkNode(new_k, new_children); - Trace("q-ext-rewrite") << "sygus-extr : " << n << " rewrites to " - << nc << " by simple ITE pulling." - << std::endl; - return nc; - } - } - } - } - return Node::null(); + ExtRewriteAttribute era; + n.setAttribute(era, ret); } Node ExtendedRewriter::extendedRewrite(Node n) { n = Rewriter::rewrite(n); - std::unordered_map::iterator it = - d_ext_rewrite_cache.find(n); - if (it != d_ext_rewrite_cache.end()) + if (!options::sygusExtRew()) + { + return n; + } + + // has it already been computed? + if (n.hasAttribute(ExtRewriteAttribute())) { - return it->second; + return n.getAttribute(ExtRewriteAttribute()); } + Node ret = n; + NodeManager* nm = NodeManager::currentNM(); + + //--------------------pre-rewrite + Node pre_new_ret; + if (ret.getKind() == IMPLIES) + { + pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]); + debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim"); + } + else if (ret.getKind() == XOR) + { + pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]); + debugExtendedRewrite(ret, pre_new_ret, "XOR elim"); + } + else if (ret.getKind() == NOT) + { + pre_new_ret = extendedRewriteNnf(ret); + debugExtendedRewrite(ret, pre_new_ret, "NNF"); + } + if (!pre_new_ret.isNull()) + { + ret = extendedRewrite(pre_new_ret); + Trace("q-ext-rewrite-debug") << "...ext-pre-rewrite : " << n << " -> " + << pre_new_ret << std::endl; + setCache(n, ret); + return ret; + } + //--------------------end pre-rewrite + + //--------------------rewrite children if (n.getNumChildren() > 0) { std::vector children; - if (n.getMetaKind() == kind::metakind::PARAMETERIZED) + if (n.getMetaKind() == metakind::PARAMETERIZED) { children.push_back(n.getOperator()); } + Kind k = n.getKind(); bool childChanged = false; + bool isNonAdditive = TermUtil::isNonAdditive(k); for (unsigned i = 0; i < n.getNumChildren(); i++) { Node nc = extendedRewrite(n[i]); childChanged = nc != n[i] || childChanged; - children.push_back(nc); + // If the operator is non-additive, do not consider duplicates + if (isNonAdditive + && std::find(children.begin(), children.end(), nc) != children.end()) + { + childChanged = true; + } + else + { + children.push_back(nc); + } } + Assert(!children.empty()); // Some commutative operators have rewriters that are agnostic to order, // thus, we sort here. - if (TermUtil::isComm(n.getKind()) && (d_aggr || children.size() <= 5)) + if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5)) { childChanged = true; std::sort(children.begin(), children.end()); } if (childChanged) { - ret = NodeManager::currentNM()->mkNode(n.getKind(), children); + if (isNonAdditive && children.size() == 1) + { + // we may have subsumed children down to one + ret = children[0]; + } + else + { + ret = nm->mkNode(k, children); + } } } ret = Rewriter::rewrite(ret); + //--------------------end rewrite children + + // now, do extended rewrite Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret << " (from " << n << ")" << std::endl; - Node new_ret; - if (ret.getKind() == kind::EQUAL) + + //---------------------- theory-independent post-rewriting + if (ret.getKind() == ITE) { - if (new_ret.isNull()) - { - // simple ITE pulling - new_ret = extendedRewritePullIte(ret); - } + new_ret = extendedRewriteIte(ITE, ret); } - else if (ret.getKind() == kind::ITE) + else if (ret.getKind() == AND || ret.getKind() == OR) { - Assert(ret[1] != ret[2]); - if (ret[0].getKind() == NOT) - { - ret = NodeManager::currentNM()->mkNode( - kind::ITE, ret[0][0], ret[2], ret[1]); - } - if (ret[0].getKind() == kind::EQUAL) - { - // simple invariant ITE - for (unsigned i = 0; i < 2; i++) - { - if (ret[1] == ret[0][i] && ret[2] == ret[0][1 - i]) - { - Trace("q-ext-rewrite") - << "sygus-extr : " << ret << " rewrites to " << ret[2] - << " due to simple invariant ITE." << std::endl; - new_ret = ret[2]; - break; - } - } - // notice this is strictly more general than the above - if (new_ret.isNull()) - { - // simple substitution - for (unsigned i = 0; i < 2; i++) - { - TNode r1 = ret[0][i]; - TNode r2 = ret[0][1 - i]; - if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst())) - { - Node retn = ret[1].substitute(r1, r2); - if (retn != ret[1]) - { - new_ret = NodeManager::currentNM()->mkNode( - kind::ITE, ret[0], retn, ret[2]); - Trace("q-ext-rewrite") - << "sygus-extr : " << ret << " rewrites to " << new_ret - << " due to simple ITE substitution." << std::endl; - } - } - } - } - } + // all kinds are legal to substitute over : hence we give the empty map + std::map bcp_kinds; + new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, ret); + debugExtendedRewrite(ret, new_ret, "Bool bcp"); } - else if (ret.getKind() == DIVISION || ret.getKind() == INTS_DIVISION - || ret.getKind() == INTS_MODULUS) + else if (ret.getKind() == EQUAL) { - // rewrite as though total - std::vector children; - bool all_const = true; - for (unsigned i = 0; i < ret.getNumChildren(); i++) + new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret); + debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify"); + } + if (new_ret.isNull() && ret.getKind() != ITE) + { + // simple ITE pulling + new_ret = extendedRewritePullIte(ITE, ret); + } + //----------------------end theory-independent post-rewriting + + //----------------------theory-specific post-rewriting + if (new_ret.isNull()) + { + Node atom = ret.getKind() == NOT ? ret[0] : ret; + bool pol = ret.getKind() != NOT; + TheoryId tid = Theory::theoryOf(atom); + if (tid == THEORY_ARITH) { - if (ret[i].isConst()) - { - children.push_back(ret[i]); - } - else - { - all_const = false; - break; - } + new_ret = extendedRewriteArith(atom, pol); } - if (all_const) + // add back negation if not processed + if (!pol && !new_ret.isNull()) { - Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL - : (ret.getKind() == INTS_DIVISION - ? INTS_DIVISION_TOTAL - : INTS_MODULUS_TOTAL)); - new_ret = NodeManager::currentNM()->mkNode(new_k, children); - Trace("q-ext-rewrite") - << "sygus-extr : " << ret << " rewrites to " << new_ret - << " due to total interpretation." << std::endl; + new_ret = new_ret.negate(); } } - // more expensive rewrites + //----------------------end theory-specific post-rewriting + + //----------------------aggressive rewrites if (new_ret.isNull() && d_aggr) { new_ret = extendedRewriteAggr(ret); } + //----------------------end aggressive rewrites - d_ext_rewrite_cache[n] = ret; + setCache(n, ret); if (!new_ret.isNull()) { ret = extendedRewrite(new_ret); } - d_ext_rewrite_cache[n] = ret; + Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret + << std::endl; + setCache(n, ret); return ret; } @@ -234,6 +211,8 @@ Node ExtendedRewriter::extendedRewriteAggr(Node n) if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal()) || ret_atom.getKind() == GEQ) { + // ITE term removal in polynomials + // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 ) Trace("q-ext-rewrite-debug2") << "Compute monomial sum " << ret_atom << std::endl; // compute monomial sum @@ -255,7 +234,7 @@ Node ExtendedRewriter::extendedRewriteAggr(Node n) Trace("q-ext-rewrite-debug") << " have ITE relation, solved form : " << veq << std::endl; // try pulling ITE - new_ret = extendedRewritePullIte(veq); + new_ret = extendedRewritePullIte(ITE, veq); if (!new_ret.isNull()) { if (!polarity) @@ -279,10 +258,781 @@ Node ExtendedRewriter::extendedRewriteAggr(Node n) << " failed to get monomial sum of " << n << std::endl; } } - // TODO (#1599) : conditional rewriting, condition merging + // TODO (#1706) : conditional rewriting, condition merging + return new_ret; +} + +Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) +{ + Assert(n.getKind() == itek); + Assert(n[1] != n[2]); + + NodeManager* nm = NodeManager::currentNM(); + + Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl; + + Node flip_cond; + if (n[0].getKind() == NOT) + { + flip_cond = n[0][0]; + } + else if (n[0].getKind() == OR) + { + // a | b ---> ~( ~a & ~b ) + flip_cond = TermUtil::simpleNegate(n[0]); + } + if (!flip_cond.isNull()) + { + Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]); + // only print debug trace if full=true + if (full) + { + debugExtendedRewrite(n, new_ret, "ITE flip"); + } + return new_ret; + } + + // get entailed equalities in the condition + std::vector eq_conds; + Kind ck = n[0].getKind(); + if (ck == EQUAL) + { + eq_conds.push_back(n[0]); + } + else if (ck == AND) + { + for (const Node& cn : n[0]) + { + if (cn.getKind() == EQUAL) + { + eq_conds.push_back(cn); + } + } + } + + Node new_ret; + Node b; + Node e; + Node t1 = n[1]; + Node t2 = n[2]; + std::stringstream ss_reason; + + for (const Node& eq : eq_conds) + { + // simple invariant ITE + for (unsigned i = 0; i <= 1; i++) + { + // ite( x = y ^ C, y, x ) ---> x + // this is subsumed by the rewrites below + if (t2 == eq[i] && t1 == eq[1 - i]) + { + new_ret = t2; + ss_reason << "ITE simple rev subs"; + break; + } + } + if (!new_ret.isNull()) + { + break; + } + } + + if (new_ret.isNull() && d_aggr) + { + // If x is less than t based on an ordering, then we use { x -> t } as a + // substitution to the children of ite( x = t ^ C, s, t ) below. + std::vector vars; + std::vector subs; + for (const Node& eq : eq_conds) + { + inferSubstitution(eq, vars, subs); + } + + if (!vars.empty()) + { + // reverse substitution to opposite child + // r{ x -> t } = s implies ite( x=t ^ C, s, r ) ---> r + Node nn = + t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); + if (nn != t2) + { + nn = Rewriter::rewrite(nn); + if (nn == t1) + { + new_ret = t2; + ss_reason << "ITE rev subs"; + } + } + + // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r ) + nn = t1.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); + if (nn != t1) + { + // If full=false, then we've duplicated a term u in the children of n. + // For example, when ITE pulling, we have n is of the form: + // ite( C, f( u, t1 ), f( u, t2 ) ) + // We must show that at least one copy of u dissappears in this case. + nn = Rewriter::rewrite(nn); + if (nn == t2) + { + new_ret = nn; + ss_reason << "ITE subs invariant"; + } + else if (full || nn.isConst()) + { + new_ret = nm->mkNode(itek, n[0], nn, t2); + ss_reason << "ITE subs"; + } + } + } + } + + // only print debug trace if full=true + if (!new_ret.isNull() && full) + { + debugExtendedRewrite(n, new_ret, ss_reason.str().c_str()); + } + + return new_ret; +} + +Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) +{ + NodeManager* nm = NodeManager::currentNM(); + TypeNode tn = n.getType(); + std::vector children; + bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED); + if (hasOp) + { + children.push_back(n.getOperator()); + } + unsigned nchildren = n.getNumChildren(); + for (unsigned i = 0; i < nchildren; i++) + { + children.push_back(n[i]); + } + std::map > ite_c; + for (unsigned i = 0; i < nchildren; i++) + { + if (n[i].getKind() == itek) + { + unsigned ii = hasOp ? i + 1 : i; + for (unsigned j = 0; j < 2; j++) + { + children[ii] = n[i][j + 1]; + Node pull = nm->mkNode(n.getKind(), children); + Node pullr = Rewriter::rewrite(pull); + children[ii] = n[i]; + ite_c[i][j] = pullr; + } + if (ite_c[i][0] == ite_c[i][1]) + { + // ITE dual invariance + // f( t1..s1..tn ) ---> t and f( t1..s2..tn ) ---> t implies + // f( t1..ite( A, s1, s2 )..tn ) ---> t + debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant"); + return ite_c[i][0]; + } + else if (d_aggr) + { + for (unsigned j = 0; j < 2; j++) + { + Node pullr = ite_c[i][j]; + if (pullr.isConst() || pullr == n[i][j + 1]) + { + // ITE single child elimination + // f( t1..s1..tn ) ---> t where t is a constant or s1 itself + // implies + // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) ) + Node new_ret; + if (tn.isBoolean()) + { + // remove false/true child immediately + bool pol = pullr.getConst(); + std::vector new_children; + new_children.push_back((j == 0) == pol ? n[i][0] + : n[i][0].negate()); + new_children.push_back(ite_c[i][1 - j]); + new_ret = nm->mkNode(pol ? OR : AND, new_children); + debugExtendedRewrite(n, new_ret, "ITE Bool single elim"); + } + else + { + new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]); + debugExtendedRewrite(n, new_ret, "ITE single elim"); + } + return new_ret; + } + } + } + } + } + + for (std::pair >& ip : ite_c) + { + Node nite = n[ip.first]; + Assert(nite.getKind() == itek); + // now, simply pull the ITE and try ITE rewrites + Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]); + pull_ite = Rewriter::rewrite(pull_ite); + if (pull_ite.getKind() == ITE) + { + Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false); + if (!new_pull_ite.isNull()) + { + debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite"); + return new_pull_ite; + } + } + else + { + // A general rewrite could eliminate the ITE by pulling. + // An example is: + // ~( ite( C, ~x, ~ite( C, y, x ) ) ) ---> + // ite( C, ~~x, ite( C, y, x ) ) ---> + // x + // where ~ is bitvector negation. + debugExtendedRewrite(n, pull_ite, "ITE pull basic elim"); + return pull_ite; + } + } + + return Node::null(); +} + +Node ExtendedRewriter::extendedRewriteNnf(Node ret) +{ + Assert(ret.getKind() == NOT); + + Kind nk = ret[0].getKind(); + bool neg_ch = false; + bool neg_ch_1 = false; + if (nk == AND || nk == OR) + { + neg_ch = true; + nk = nk == AND ? OR : AND; + } + else if (nk == IMPLIES) + { + neg_ch = true; + neg_ch_1 = true; + nk = AND; + } + else if (nk == ITE) + { + neg_ch = true; + neg_ch_1 = true; + } + else if (nk == XOR) + { + nk = EQUAL; + } + else if (nk == EQUAL && ret[0][0].getType().isBoolean()) + { + neg_ch_1 = true; + } + else + { + return Node::null(); + } + + std::vector new_children; + for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++) + { + Node c = ret[0][i]; + c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c; + new_children.push_back(c); + } + return NodeManager::currentNM()->mkNode(nk, new_children); +} + +Node ExtendedRewriter::extendedRewriteBcp( + Kind andk, Kind ork, Kind notk, std::map& bcp_kinds, Node ret) +{ + Kind k = ret.getKind(); + Assert(k == andk || k == ork); + Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl; + + NodeManager* nm = NodeManager::currentNM(); + + TypeNode tn = ret.getType(); + Node truen = TermUtil::mkTypeMaxValue(tn); + Node falsen = TermUtil::mkTypeValue(tn, 0); + + // terms to process + std::vector to_process; + for (const Node& cn : ret) + { + to_process.push_back(cn); + } + // the processing terms + std::vector clauses; + // the terms we have propagated information to + std::unordered_set prop_clauses; + // the assignment + std::map assign; + std::vector avars; + std::vector asubs; + + Kind ok = k == andk ? ork : andk; + // global polarity : when k=ork, everything is negated + bool gpol = k == andk; + + do + { + // process the current nodes + while (!to_process.empty()) + { + std::vector new_to_process; + for (const Node& cn : to_process) + { + Trace("ext-rew-bcp-debug") << "process " << cn << std::endl; + Kind cnk = cn.getKind(); + bool pol = cnk != notk; + Node cln = cnk == notk ? cn[0] : cn; + Assert(cln.getKind() != notk); + if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok)) + { + // flatten + for (const Node& ccln : cln) + { + Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln); + new_to_process.push_back(lccln); + } + } + else + { + // add it to the assignment + Node val = gpol == pol ? truen : falsen; + std::map::iterator it = assign.find(cln); + Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val + << std::endl; + if (it != assign.end()) + { + if (val != it->second) + { + Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl; + // a conflicting assignment: we are done + return gpol ? falsen : truen; + } + } + else + { + assign[cln] = val; + avars.push_back(cln); + asubs.push_back(val); + } + + // also, treat it as clause if possible + if (cln.getNumChildren() > 0 + & (bcp_kinds.empty() + || bcp_kinds.find(cln.getKind()) != bcp_kinds.end())) + { + if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end() + && prop_clauses.find(cn) == prop_clauses.end()) + { + Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl; + clauses.push_back(cn); + } + } + } + } + to_process.clear(); + to_process.insert( + to_process.end(), new_to_process.begin(), new_to_process.end()); + } + + // apply substitution to all subterms of clauses + std::vector new_clauses; + for (const Node& c : clauses) + { + bool cpol = c.getKind() != notk; + Node ca = c.getKind() == notk ? c[0] : c; + bool childChanged = false; + std::vector ccs_children; + for (const Node& cc : ca) + { + Node ccs = cc; + if (bcp_kinds.empty()) + { + Trace("ext-rew-bcp-debug") << "...do ordinary substitute" + << std::endl; + ccs = cc.substitute( + avars.begin(), avars.end(), asubs.begin(), asubs.end()); + } + else + { + Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl; + // substitution is only applicable to compatible kinds + ccs = partialSubstitute(ccs, assign, bcp_kinds); + } + childChanged = childChanged || ccs != cc; + ccs_children.push_back(ccs); + } + if (childChanged) + { + if (ca.getMetaKind() == metakind::PARAMETERIZED) + { + ccs_children.insert(ccs_children.begin(), ca.getOperator()); + } + Node ccs = nm->mkNode(ca.getKind(), ccs_children); + ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs); + Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs + << std::endl; + ccs = Rewriter::rewrite(ccs); + Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl; + to_process.push_back(ccs); + // store this as a node that propagation touched. This marks c so that + // it will not be included in the final construction. + prop_clauses.insert(ca); + } + else + { + new_clauses.push_back(c); + } + } + clauses.clear(); + clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end()); + } while (!to_process.empty()); + + // remake the node + if (!prop_clauses.empty()) + { + std::vector children; + for (std::pair& l : assign) + { + Node a = l.first; + // if propagation did not touch a + if (prop_clauses.find(a) == prop_clauses.end()) + { + Assert(l.second == truen || l.second == falsen); + Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a); + children.push_back(ln); + } + } + Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children); + Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl; + return new_ret; + } + + return Node::null(); +} + +Node ExtendedRewriter::extendedRewriteEqChain( + Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor) +{ + Assert(ret.getKind() == eqk); + + NodeManager* nm = NodeManager::currentNM(); + + TypeNode tn = ret[0].getType(); + + // sort/cancelling for Boolean EQUAL/XOR-chains + Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl; + + // get the children on either side + bool gpol = true; + std::vector children; + for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++) + { + Node curr = ret[r]; + // assume, if necessary, right associative + while (curr.getKind() == eqk && curr[0].getType() == tn) + { + children.push_back(curr[0]); + curr = curr[1]; + } + children.push_back(curr); + } + + std::map cstatus; + // add children to status + for (const Node& c : children) + { + Node a = c; + if (a.getKind() == notk) + { + gpol = !gpol; + a = a[0]; + } + Trace("ext-rew-eqchain") << "...child : " << a << std::endl; + std::map::iterator itc = cstatus.find(a); + if (itc == cstatus.end()) + { + cstatus[a] = true; + } + else + { + // cancels + cstatus.erase(a); + if (isXor) + { + gpol = !gpol; + } + } + } + Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl; + + if (cstatus.empty()) + { + return TermUtil::mkTypeConst(tn, gpol); + } + + children.clear(); + + // cancel AND/OR children if possible + for (std::pair& cp : cstatus) + { + if (cp.second) + { + Node c = cp.first; + Kind ck = c.getKind(); + if (ck == andk || ck == ork) + { + for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++) + { + Node cl = c[j]; + Node ca = cl.getKind() == notk ? cl[0] : cl; + bool capol = cl.getKind() != notk; + // if this already exists as a child of the equality chain + std::map::iterator itc = cstatus.find(ca); + if (itc != cstatus.end() && itc->second) + { + // cancel it + cstatus[ca] = false; + cstatus[c] = false; + // make new child + // x = ( y | ~x ) ---> y & x + // x = ( y | x ) ---> ~y | x + // x = ( y & x ) ---> y | ~x + // x = ( y & ~x ) ---> ~y & ~x + std::vector new_children; + for (unsigned k = 0, nchild = c.getNumChildren(); k < nchild; k++) + { + if (j != k) + { + new_children.push_back(c[k]); + } + } + Node nc[2]; + nc[0] = c[j]; + nc[1] = new_children.size() == 1 ? new_children[0] + : nm->mkNode(ck, new_children); + // negate the proper child + unsigned nindex = (ck == andk) == capol ? 0 : 1; + nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]); + Kind nk = capol ? ork : andk; + // store as new child + children.push_back(nm->mkNode(nk, nc[0], nc[1])); + if (isXor) + { + gpol = !gpol; + } + break; + } + } + } + } + } + + // sorted right associative chain + bool has_const = false; + unsigned const_index = 0; + for (std::pair& cp : cstatus) + { + if (cp.second) + { + if (cp.first.isConst()) + { + has_const = true; + const_index = children.size(); + } + children.push_back(cp.first); + } + } + std::sort(children.begin(), children.end()); + + Node new_ret; + if (!gpol) + { + // negate the constant child if it exists + unsigned nindex = has_const ? const_index : 0; + children[nindex] = TermUtil::mkNegate(notk, children[nindex]); + } + new_ret = children.back(); + unsigned index = children.size() - 1; + while (index > 0) + { + index--; + new_ret = nm->mkNode(eqk, children[index], new_ret); + } + new_ret = Rewriter::rewrite(new_ret); + if (new_ret != ret) + { + return new_ret; + } + return Node::null(); +} + +Node ExtendedRewriter::partialSubstitute(Node n, + std::map& assign, + std::map& rkinds) +{ + std::unordered_map visited; + std::unordered_map::iterator it; + std::vector visit; + TNode cur; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + + if (it == visited.end()) + { + std::map::iterator it = assign.find(cur); + if (it != assign.end()) + { + visited[cur] = it->second; + } + else + { + // can only recurse on these kinds + Kind k = cur.getKind(); + if (rkinds.find(k) != rkinds.end()) + { + visited[cur] = Node::null(); + visit.push_back(cur); + for (const Node& cn : cur) + { + visit.push_back(cn); + } + } + else + { + visited[cur] = cur; + } + } + } + else if (it->second.isNull()) + { + Node ret = cur; + bool childChanged = false; + std::vector children; + if (cur.getMetaKind() == metakind::PARAMETERIZED) + { + children.push_back(cur.getOperator()); + } + for (const Node& cn : cur) + { + it = visited.find(cn); + Assert(it != visited.end()); + Assert(!it->second.isNull()); + childChanged = childChanged || cn != it->second; + children.push_back(it->second); + } + if (childChanged) + { + ret = NodeManager::currentNM()->mkNode(cur.getKind(), children); + } + visited[cur] = ret; + } + } while (!visit.empty()); + Assert(visited.find(n) != visited.end()); + Assert(!visited.find(n)->second.isNull()); + return visited[n]; +} + +Node ExtendedRewriter::solveEquality(Node n) +{ + // TODO (#1706) : implement + Assert(n.getKind() == EQUAL); + + return Node::null(); +} + +bool ExtendedRewriter::inferSubstitution(Node n, + std::vector& vars, + std::vector& subs) +{ + if (n.getKind() == EQUAL) + { + // see if it can be put into form x = y + Node slv_eq = solveEquality(n); + if (!slv_eq.isNull()) + { + n = slv_eq; + } + for (unsigned i = 0; i < 2; i++) + { + TNode r1 = n[i]; + TNode r2 = n[1 - i]; + if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst())) + { + // TODO (#1706) : union find + if (std::find(vars.begin(), vars.end(), r1) == vars.end()) + { + vars.push_back(r1); + subs.push_back(r2); + return true; + } + } + } + } + return false; +} + +Node ExtendedRewriter::extendedRewriteArith(Node ret, bool& pol) +{ + Kind k = ret.getKind(); + NodeManager* nm = NodeManager::currentNM(); + Node new_ret; + if (k == DIVISION || k == INTS_DIVISION || k == INTS_MODULUS) + { + // rewrite as though total + std::vector children; + bool all_const = true; + for (unsigned i = 0, size = ret.getNumChildren(); i < size; i++) + { + if (ret[i].isConst()) + { + children.push_back(ret[i]); + } + else + { + all_const = false; + break; + } + } + if (all_const) + { + Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL + : (ret.getKind() == INTS_DIVISION + ? INTS_DIVISION_TOTAL + : INTS_MODULUS_TOTAL)); + new_ret = nm->mkNode(new_k, children); + debugExtendedRewrite(ret, new_ret, "total-interpretation"); + } + } return new_ret; } +void ExtendedRewriter::debugExtendedRewrite(Node n, + Node ret, + const char* c) const +{ + if (Trace.isOn("q-ext-rewrite")) + { + if (!ret.isNull()) + { + Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl; + Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n + << " rewrites to " << ret << std::endl; + } + } +} + } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */ diff --git a/src/theory/quantifiers/extended_rewrite.h b/src/theory/quantifiers/extended_rewrite.h index 25d710a6b..2daa42b18 100644 --- a/src/theory/quantifiers/extended_rewrite.h +++ b/src/theory/quantifiers/extended_rewrite.h @@ -35,10 +35,14 @@ namespace quantifiers { * * This class extended the standard techniques for rewriting * with techniques, including but not limited to: - * - ITE branch merging, + * - Redundant child elimination, + * - Sorting children of commutative operators, + * - Boolean constraint propagation, + * - Equality chain normalization, + * - Negation normal form, + * - Simple ITE pulling, * - ITE conditional variable elimination, - * - ITE condition subsumption, and - * - Aggressive rewriting for string equalities. + * - ITE condition subsumption. */ class ExtendedRewriter { @@ -60,21 +64,128 @@ class ExtendedRewriter * may be applied as a preprocessing step. */ bool d_aggr; - /** true and false nodes */ - Node d_true; - Node d_false; - /** cache for extendedRewrite */ - std::unordered_map d_ext_rewrite_cache; - /** pull ITE - * Do simple ITE pulling, e.g.: - * C2 --->^E false - * implies: - * ite( C, C1, C2 ) --->^E C ^ C1 - * where ---->^E denotes extended rewriting. + /** cache that the extended rewritten form of n is ret */ + void setCache(Node n, Node ret); + + //--------------------------------------generic utilities + /** Rewrite ITE, for example: + * + * ite( ~C, s, t ) ---> ite( C, t, s ) + * ite( A or B, s, t ) ---> ite( ~A and ~B, t, s ) + * ite( x = c, x, t ) --> ite( x = c, c, t ) + * t * { x -> c } = s => ite( x = c, s, t ) ---> t + * + * The parameter "full" indicates an effort level that this rewrite will + * take. If full is false, then we do only perform rewrites that + * strictly decrease the term size of n. + */ + Node extendedRewriteIte(Kind itek, Node n, bool full = true); + /** Pull ITE, for example: + * + * D=C2 ---> false + * implies + * D=ite( C, C1, C2 ) ---> C ^ D=C1 + * + * f(t,t1) --> s and f(t,t2)---> s + * implies + * f(t,ite(C,t1,t2)) ---> s + * + * If this function returns a non-null node ret, then n ---> ret. + */ + Node extendedRewritePullIte(Kind itek, Node n); + /** Negation Normal Form (NNF), for example: + * + * ~( A & B ) ---> ( ~ A | ~B ) + * ~( ite( A, B, C ) ---> ite( A, ~B, ~C ) + * + * If this function returns a non-null node ret, then n ---> ret. + */ + Node extendedRewriteNnf(Node n); + /** (type-independent) Boolean constraint propagation, for example: + * + * ~A & ( B V A ) ---> ~A & B + * A & ( B = ( A V C ) ) ---> A & B + * + * This function takes as arguments the kinds that specify AND, OR, and NOT. + * It additionally takes as argument a map bcp_kinds. If this map is + * non-empty, then all terms that have a Kind that is *not* in this map should + * be treated as immutable. This is for instance to prevent propagation + * beneath illegal terms. As an example: + * (bvand A (bvor A B)) is equivalent to (bvand A (bvor 1...1 B)), but + * (bvand A (bvplus A B)) is not equivalent to (bvand A (bvplus 1..1 B)), + * hence, when using this function to do BCP for bit-vectors, we have that + * BITVECTOR_AND is a bcp_kind, but BITVECTOR_PLUS is not. + * + * If this function returns a non-null node ret, then n ---> ret. + */ + Node extendedRewriteBcp( + Kind andk, Kind ork, Kind notk, std::map& bcp_kinds, Node n); + /** (type-independent) Equality chain rewriting, for example: + * + * A = ( A = B ) ---> B + * ( A = D ) = ( C = B ) ---> A = ( B = ( C = D ) ) + * A = ( A & B ) ---> ~A | B + * + * If this function returns a non-null node ret, then n ---> ret. + * This function takes as arguments the kinds that specify EQUAL, AND, OR, + * and NOT. If the flag isXor is true, the eqk is treated as XOR. + */ + Node extendedRewriteEqChain( + Kind eqk, Kind andk, Kind ork, Kind notk, Node n, bool isXor = false); + /** extended rewrite aggressive + * + * All aggressive rewriting techniques (those that should be prioritized + * at a lower level) go in this function. */ - Node extendedRewritePullIte(Node n); - /** extended rewrite aggressive */ Node extendedRewriteAggr(Node n); + /** Decompose right associative chain + * + * For term f( ... f( f( base, tn ), t{n-1} ) ... t1 ), returns term base, and + * appends t1...tn to children. + */ + Node decomposeRightAssocChain(Kind k, Node n, std::vector& children); + /** Make right associative chain + * + * Sorts children to obtain list { tn...t1 }, and returns the term + * f( ... f( f( base, tn ), t{n-1} ) ... t1 ). + */ + Node mkRightAssocChain(Kind k, Node base, std::vector& children); + /** Partial substitute + * + * Applies the substitution specified by assign to n, recursing only beneath + * terms whose Kind appears in rec_kinds. + */ + Node partialSubstitute(Node n, + std::map& assign, + std::map& rkinds); + /** solve equality + * + * If this function returns a non-null node n', then n' is equivalent to n + * and is of the form that can be used by inferSubstitution below. + */ + Node solveEquality(Node n); + /** infer substitution + * + * If n is an equality of the form x = t, where t is either: + * (1) a constant, or + * (2) a variable y such that x < y based on an ordering, + * then this method adds x to vars and y to subs and return true, otherwise + * it returns false. + */ + bool inferSubstitution(Node n, + std::vector& vars, + std::vector& subs); + /** extended rewrite + * + * Prints debug information, indicating the rewrite n ---> ret was found. + */ + inline void debugExtendedRewrite(Node n, Node ret, const char* c) const; + //--------------------------------------end generic utilities + + //--------------------------------------theory-specific top-level calls + /** extended rewrite arith */ + Node extendedRewriteArith(Node ret, bool& pol); + //--------------------------------------end theory-specific top-level calls }; } /* CVC4::theory::quantifiers namespace */ diff --git a/src/theory/quantifiers/term_util.cpp b/src/theory/quantifiers/term_util.cpp index 3b8d03399..5965906cb 100644 --- a/src/theory/quantifiers/term_util.cpp +++ b/src/theory/quantifiers/term_util.cpp @@ -773,13 +773,22 @@ bool TermUtil::containsUninterpretedConstant( Node n ) { Node TermUtil::simpleNegate( Node n ){ if( n.getKind()==OR || n.getKind()==AND ){ std::vector< Node > children; - for( unsigned i=0; imkNode( n.getKind()==OR ? AND : OR, children ); - }else{ - return n.negate(); } + return n.negate(); +} + +Node TermUtil::mkNegate(Kind notk, Node n) +{ + if (n.getKind() == notk) + { + return n[0]; + } + return NodeManager::currentNM()->mkNode(notk, n); } bool TermUtil::isAssoc( Kind k ) { @@ -912,6 +921,11 @@ Node TermUtil::getTypeValueOffset(TypeNode tn, return it->second; } +Node TermUtil::mkTypeConst(TypeNode tn, bool pol) +{ + return pol ? mkTypeValue(tn, 0) : mkTypeMaxValue(tn); +} + bool TermUtil::isAntisymmetric(Kind k, Kind& dk) { if (k == GT) diff --git a/src/theory/quantifiers/term_util.h b/src/theory/quantifiers/term_util.h index 8ec2fc8e2..97f4edcd5 100644 --- a/src/theory/quantifiers/term_util.h +++ b/src/theory/quantifiers/term_util.h @@ -289,6 +289,11 @@ public: static int getTermDepth( Node n ); /** simple negate */ static Node simpleNegate( Node n ); + /** + * Make negated term, returns the negation of n wrt Kind notk, eliminating + * double negation if applicable, e.g. mkNegate( ~, ~x ) ---> x. + */ + static Node mkNegate(Kind notk, Node n); /** is assoc */ static bool isAssoc( Kind k ); /** is k commutative? */ @@ -364,6 +369,13 @@ public: static Node mkTypeValueOffset(TypeNode tn, Node val, int offset, int& status); /** make max value, static version of get max value */ static Node mkTypeMaxValue(TypeNode tn); + /** + * Make const, returns pol ? mkTypeValue(tn,0) : mkTypeMaxValue(tn). + * In other words, this returns either the minimum element of tn if pol is + * true, and the maximum element in pol is false. The type tn should have + * minimum and maximum elements, for example tn is Bool or BitVector. + */ + static Node mkTypeConst(TypeNode tn, bool pol); // for higher-order private: -- 2.30.2