From: Andrew Reynolds Date: Sat, 7 Jul 2018 06:49:11 +0000 (+0100) Subject: sygusComp2018: improve extended rewriter for Bool (#2107) X-Git-Tag: cvc5-1.0.0~4904 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=ad454857a1f57386f7b132c01ad460750ca8d3aa;p=cvc5.git sygusComp2018: improve extended rewriter for Bool (#2107) --- diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index 4195a8a16..cdd597a5c 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -93,6 +93,7 @@ Node ExtendedRewriter::extendedRewrite(Node n) if (!pre_new_ret.isNull()) { ret = extendedRewrite(pre_new_ret); + Trace("q-ext-rewrite-debug") << "...ext-pre-rewrite : " << n << " -> " << pre_new_ret << std::endl; setCache(n, ret); @@ -167,17 +168,7 @@ Node ExtendedRewriter::extendedRewrite(Node n) } else if (ret.getKind() == AND || ret.getKind() == OR) { - // all kinds are legal to substitute over : hence we give the empty map - std::map bcp_kinds; - new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, ret); - debugExtendedRewrite(ret, new_ret, "Bool bcp"); - if (new_ret.isNull()) - { - // equality resolution - new_ret = - extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, ret, false); - debugExtendedRewrite(ret, new_ret, "Bool eq res"); - } + new_ret = extendedRewriteAndOr(ret); } else if (ret.getKind() == EQUAL) { @@ -194,17 +185,20 @@ Node ExtendedRewriter::extendedRewrite(Node n) //----------------------theory-specific post-rewriting if (new_ret.isNull()) { - Node atom = ret.getKind() == NOT ? ret[0] : ret; - bool pol = ret.getKind() != NOT; - TheoryId tid = Theory::theoryOf(atom); - if (tid == THEORY_ARITH) + TheoryId tid; + if (ret.getKind() == ITE) { - new_ret = extendedRewriteArith(atom, pol); + tid = Theory::theoryOf(ret.getType()); } - // add back negation if not processed - if (!pol && !new_ret.isNull()) + else { - new_ret = new_ret.negate(); + tid = Theory::theoryOf(ret); + } + Trace("q-ext-rewrite-debug") << "theoryOf( " << ret << " )= " << tid + << std::endl; + if (tid == THEORY_ARITH) + { + new_ret = extendedRewriteArith(ret); } } //----------------------end theory-specific post-rewriting @@ -317,6 +311,38 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) } return new_ret; } + // Boolean true/false return + TypeNode tn = n.getType(); + if (tn.isBoolean()) + { + for (unsigned i = 1; i <= 2; i++) + { + if (n[i].isConst()) + { + Node cond = i == 1 ? n[0] : n[0].negate(); + Node other = n[i == 1 ? 2 : 1]; + Kind retk = AND; + if (n[i].getConst()) + { + retk = OR; + } + else + { + cond = cond.negate(); + } + Node new_ret = nm->mkNode(retk, cond, other); + if (full) + { + // ite( A, true, B ) ---> A V B + // ite( A, false, B ) ---> ~A /\ B + // ite( A, B, true ) ---> ~A V B + // ite( A, B, false ) ---> A /\ B + debugExtendedRewrite(n, new_ret, "ITE const return"); + } + return new_ret; + } + } + } // get entailed equalities in the condition std::vector eq_conds; @@ -422,6 +448,31 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) return new_ret; } +Node ExtendedRewriter::extendedRewriteAndOr(Node n) +{ + Node new_ret; + // all kinds are legal to substitute over : hence we give the empty map + std::map bcp_kinds; + new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, n); + if (!new_ret.isNull()) + { + debugExtendedRewrite(n, new_ret, "Bool bcp"); + return new_ret; + } + // factoring + new_ret = extendedRewriteFactoring(AND, OR, NOT, n); + if (!new_ret.isNull()) + { + debugExtendedRewrite(n, new_ret, "Bool factoring"); + return new_ret; + } + + // equality resolution + new_ret = extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, n, false); + debugExtendedRewrite(n, new_ret, "Bool eq res"); + return new_ret; +} + Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) { NodeManager* nm = NodeManager::currentNM(); @@ -744,6 +795,96 @@ Node ExtendedRewriter::extendedRewriteBcp( return Node::null(); } +Node ExtendedRewriter::extendedRewriteFactoring(Kind andk, + Kind ork, + Kind notk, + Node n) +{ + Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl; + NodeManager* nm = NodeManager::currentNM(); + + Kind nk = n.getKind(); + Assert(nk == andk || nk == ork); + Kind onk = nk == andk ? ork : andk; + // count the number of times atoms occur + std::map > lit_to_cl; + std::map > cl_to_lits; + for (const Node& nc : n) + { + Kind nck = nc.getKind(); + if (nck == onk) + { + for (const Node& ncl : nc) + { + if (std::find(lit_to_cl[ncl].begin(), lit_to_cl[ncl].end(), nc) + == lit_to_cl[ncl].end()) + { + lit_to_cl[ncl].push_back(nc); + cl_to_lits[nc].push_back(ncl); + } + } + } + else + { + lit_to_cl[nc].push_back(nc); + cl_to_lits[nc].push_back(nc); + } + } + // get the maximum shared literal to factor + unsigned max_size = 0; + Node flit; + for (const std::pair >& ltc : lit_to_cl) + { + if (ltc.second.size() > max_size) + { + max_size = ltc.second.size(); + flit = ltc.first; + } + } + if (max_size > 1) + { + // do the factoring + std::vector children; + std::vector fchildren; + std::map >::iterator itl = lit_to_cl.find(flit); + std::vector& cls = itl->second; + for (const Node& nc : n) + { + if (std::find(cls.begin(), cls.end(), nc) == cls.end()) + { + children.push_back(nc); + } + else + { + // rebuild + std::vector& lits = cl_to_lits[nc]; + std::vector::iterator itlfl = + std::find(lits.begin(), lits.end(), flit); + Assert(itlfl != lits.end()); + lits.erase(itlfl); + // rebuild + if (!lits.empty()) + { + Node new_cl = lits.size() == 1 ? lits[0] : nm->mkNode(onk, lits); + fchildren.push_back(new_cl); + } + } + } + // rebuild the factored children + Assert(!fchildren.empty()); + Node fcn = fchildren.size() == 1 ? fchildren[0] : nm->mkNode(nk, fchildren); + children.push_back(nm->mkNode(onk, flit, fcn)); + Node ret = children.size() == 1 ? children[0] : nm->mkNode(nk, children); + Trace("ext-rew-factoring") << "Factoring: *** OUTPUT: " << ret << std::endl; + return ret; + } + else + { + Trace("ext-rew-factoring") << "Factoring: no change" << std::endl; + } + return Node::null(); +} + Node ExtendedRewriter::extendedRewriteEqRes(Kind andk, Kind ork, Kind eqk, @@ -830,6 +971,81 @@ Node ExtendedRewriter::extendedRewriteEqRes(Kind andk, return Node::null(); } +/** sort pairs by their second (unsigned) argument */ +static bool sortPairSecond(const std::pair& a, + const std::pair& b) +{ + return (a.second < b.second); +} + +/** A simple subsumption trie used to compute pairwise list subsets */ +class SimpSubsumeTrie +{ + public: + /** the children of this node */ + std::map d_children; + /** the term at this node */ + Node d_data; + /** add term to the trie + * + * This adds term c to this trie, whose atom list is alist. This adds terms + * s to subsumes such that the atom list of s is a subset of the atom list + * of c. For example, say: + * c1.alist = { A } + * c2.alist = { C } + * c3.alist = { B, C } + * c4.alist = { A, B, D } + * c5.alist = { A, B, C } + * If these terms are added in the order c1, c2, c3, c4, c5, then: + * addTerm c1 results in subsumes = {} + * addTerm c2 results in subsumes = {} + * addTerm c3 results in subsumes = { c2 } + * addTerm c4 results in subsumes = { c1 } + * addTerm c5 results in subsumes = { c1, c2, c3 } + * Notice that the intended use case of this trie is to add term t before t' + * only when size( t.alist ) <= size( t'.alist ). + * + * The last two arguments describe the state of the path [t0...tn] we + * have followed in the trie during the recursive call. + * If doAdd = true, + * then n+1 = index and alist[1]...alist[n] = t1...tn. If index=alist.size() + * we add c as the current node of this trie. + * If doAdd = false, + * then t1...tn occur in alist. + */ + void addTerm(Node c, + std::vector& alist, + std::vector& subsumes, + unsigned index = 0, + bool doAdd = true) + { + if (!d_data.isNull()) + { + subsumes.push_back(d_data); + } + if (doAdd) + { + if (index == alist.size()) + { + d_data = c; + return; + } + } + // try all children where we have this atom + for (std::pair& cp : d_children) + { + if (std::find(alist.begin(), alist.end(), cp.first) != alist.end()) + { + cp.second.addTerm(c, alist, subsumes, 0, false); + } + } + if (doAdd) + { + d_children[alist[index]].addTerm(c, alist, subsumes, index + 1, doAdd); + } + } +}; + Node ExtendedRewriter::extendedRewriteEqChain( Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor) { @@ -892,60 +1108,253 @@ Node ExtendedRewriter::extendedRewriteEqChain( children.clear(); - // cancel AND/OR children if possible + // compute the atoms of each child + Trace("ext-rew-eqchain") << "eqchain-simplify: begin\n"; + Trace("ext-rew-eqchain") << " eqchain-simplify: get atoms...\n"; + std::map > atoms; + std::map > alist; + std::vector > atom_count; for (std::pair& cp : cstatus) { - if (cp.second) + if (!cp.second) { - Node c = cp.first; - Kind ck = c.getKind(); - if (ck == andk || ck == ork) + // already eliminated + continue; + } + Node c = cp.first; + Kind ck = c.getKind(); + if (ck == andk || ck == ork) + { + for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++) { - for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++) + Node cl = c[j]; + bool pol = cl.getKind() != notk; + Node ca = pol ? cl : cl[0]; + Assert(atoms[c].find(ca) == atoms[c].end()); + // polarity is flipped when we are AND + atoms[c][ca] = (ck == andk ? !pol : pol); + alist[c].push_back(ca); + + // if this already exists as a child of the equality chain, eliminate. + // this catches cases like ( x & y ) = ( ( x & y ) | z ), where we + // consider ( x & y ) a unit, whereas below it is expanded to + // ~( ~x | ~y ). + std::map::iterator itc = cstatus.find(ca); + if (itc != cstatus.end() && itc->second) { - Node cl = c[j]; - Node ca = cl.getKind() == notk ? cl[0] : cl; - bool capol = cl.getKind() != notk; - // if this already exists as a child of the equality chain - std::map::iterator itc = cstatus.find(ca); - if (itc != cstatus.end() && itc->second) + // cancel it + cstatus[ca] = false; + cstatus[c] = false; + // make new child + // x = ( y | ~x ) ---> y & x + // x = ( y | x ) ---> ~y | x + // x = ( y & x ) ---> y | ~x + // x = ( y & ~x ) ---> ~y & ~x + std::vector new_children; + for (unsigned k = 0, nchild = c.getNumChildren(); k < nchild; k++) { - // cancel it - cstatus[ca] = false; - cstatus[c] = false; - // make new child - // x = ( y | ~x ) ---> y & x - // x = ( y | x ) ---> ~y | x - // x = ( y & x ) ---> y | ~x - // x = ( y & ~x ) ---> ~y & ~x - std::vector new_children; - for (unsigned k = 0, nchild = c.getNumChildren(); k < nchild; k++) - { - if (j != k) - { - new_children.push_back(c[k]); - } - } - Node nc[2]; - nc[0] = c[j]; - nc[1] = new_children.size() == 1 ? new_children[0] - : nm->mkNode(ck, new_children); - // negate the proper child - unsigned nindex = (ck == andk) == capol ? 0 : 1; - nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]); - Kind nk = capol ? ork : andk; - // store as new child - children.push_back(nm->mkNode(nk, nc[0], nc[1])); - if (isXor) + if (j != k) { - gpol = !gpol; + new_children.push_back(c[k]); } - break; } + Node nc[2]; + nc[0] = c[j]; + nc[1] = new_children.size() == 1 ? new_children[0] + : nm->mkNode(ck, new_children); + // negate the proper child + unsigned nindex = (ck == andk) == pol ? 0 : 1; + nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]); + Kind nk = pol ? ork : andk; + // store as new child + children.push_back(nm->mkNode(nk, nc[0], nc[1])); + if (isXor) + { + gpol = !gpol; + } + break; } } } + else + { + bool pol = ck != notk; + Node ca = pol ? c : c[0]; + atoms[c][ca] = pol; + alist[c].push_back(ca); + } + atom_count.push_back(std::pair(c, alist[c].size())); + } + // sort the atoms in each atom list + for (std::map >::iterator it = alist.begin(); + it != alist.end(); + ++it) + { + std::sort(it->second.begin(), it->second.end()); } + // check subsumptions + // sort by #atoms + std::sort(atom_count.begin(), atom_count.end(), sortPairSecond); + if (Trace.isOn("ext-rew-eqchain")) + { + for (const std::pair& ac : atom_count) + { + Trace("ext-rew-eqchain") << " eqchain-simplify: " << ac.first << " has " + << ac.second << " atoms." << std::endl; + } + Trace("ext-rew-eqchain") << " eqchain-simplify: compute subsumptions...\n"; + } + SimpSubsumeTrie sst; + for (std::pair& cp : cstatus) + { + if (!cp.second) + { + // already eliminated + continue; + } + Node c = cp.first; + std::map >::iterator itc = atoms.find(c); + Assert(itc != atoms.end()); + Trace("ext-rew-eqchain") << " - add term " << c << " with atom list " + << alist[c] << "...\n"; + std::vector subsumes; + sst.addTerm(c, alist[c], subsumes); + for (const Node& cc : subsumes) + { + if (!cstatus[cc]) + { + // subsumes a child that was already eliminated + continue; + } + Trace("ext-rew-eqchain") << " eqchain-simplify: " << c << " subsumes " + << cc << std::endl; + // for each of the atoms in cc + std::map >::iterator itcc = atoms.find(cc); + Assert(itcc != atoms.end()); + std::vector common_children; + std::vector diff_children; + for (const std::pair& ap : itcc->second) + { + // compare the polarity + Node a = ap.first; + bool polcc = ap.second; + Assert(itc->second.find(a) != itc->second.end()); + bool polc = itc->second[a]; + Trace("ext-rew-eqchain") << " eqchain-simplify: atom " << a + << " has polarities : " << polc << " " << polcc + << "\n"; + Node lit = polc ? a : TermUtil::mkNegate(notk, a); + if (polc != polcc) + { + diff_children.push_back(lit); + } + else + { + common_children.push_back(lit); + } + } + std::vector rem_children; + for (const std::pair& ap : itc->second) + { + Node a = ap.first; + if (atoms[cc].find(a) == atoms[cc].end()) + { + bool polc = ap.second; + rem_children.push_back(polc ? a : TermUtil::mkNegate(notk, a)); + } + } + Trace("ext-rew-eqchain") + << " #common/diff/rem: " << common_children.size() << "/" + << diff_children.size() << "/" << rem_children.size() << "\n"; + bool do_rewrite = false; + if (common_children.empty() && itc->second.size() == itcc->second.size() + && itcc->second.size() == 2) + { + // x | y = ~x | ~y ---> ~( x = y ) + do_rewrite = true; + children.push_back(diff_children[0]); + children.push_back(diff_children[1]); + gpol = !gpol; + Trace("ext-rew-eqchain") << " apply 2-child all-diff\n"; + } + else if (common_children.empty() && diff_children.size() == 1) + { + do_rewrite = true; + // x = ( ~x | y ) ---> ~( ~x | ~y ) + Node remn = rem_children.size() == 1 ? rem_children[0] + : nm->mkNode(ork, rem_children); + remn = TermUtil::mkNegate(notk, remn); + children.push_back(nm->mkNode(ork, diff_children[0], remn)); + if (!isXor) + { + gpol = !gpol; + } + Trace("ext-rew-eqchain") << " apply unit resolution\n"; + } + else if (diff_children.size() == 1 + && itc->second.size() == itcc->second.size()) + { + // ( x | y | z ) = ( x | ~y | z ) ---> ( x | z ) + do_rewrite = true; + Assert(!common_children.empty()); + Node comn = common_children.size() == 1 + ? common_children[0] + : nm->mkNode(ork, common_children); + children.push_back(comn); + if (isXor) + { + gpol = !gpol; + } + Trace("ext-rew-eqchain") << " apply resolution\n"; + } + else if (diff_children.empty()) + { + do_rewrite = true; + if (rem_children.empty()) + { + // x | y = x | y ---> true + // this can happen if we have ( ~x & ~y ) = ( x | y ) + children.push_back(TermUtil::mkTypeMaxValue(tn)); + if (isXor) + { + gpol = !gpol; + } + Trace("ext-rew-eqchain") << " apply cancel\n"; + } + else + { + // x | y = ( x | y | z ) ---> ( x | y | ~z ) + Node remn = rem_children.size() == 1 ? rem_children[0] + : nm->mkNode(ork, rem_children); + remn = TermUtil::mkNegate(notk, remn); + Node comn = common_children.size() == 1 + ? common_children[0] + : nm->mkNode(ork, common_children); + children.push_back(nm->mkNode(ork, comn, remn)); + if (isXor) + { + gpol = !gpol; + } + Trace("ext-rew-eqchain") << " apply subsume\n"; + } + } + if (do_rewrite) + { + // eliminate the children, reverse polarity as needed + for (unsigned r = 0; r < 2; r++) + { + Node c_rem = r == 0 ? c : cc; + cstatus[c_rem] = false; + if (c_rem.getKind() == andk) + { + gpol = !gpol; + } + } + break; + } + } + } + Trace("ext-rew-eqchain") << "eqchain-simplify: finish" << std::endl; // sorted right associative chain bool has_nvar = false; @@ -1079,7 +1488,13 @@ bool ExtendedRewriter::inferSubstitution(Node n, Node v[2]; for (unsigned i = 0; i < 2; i++) { - if (n[i].isVar() || n[i].isConst()) + if (n[i].isConst()) + { + vars.push_back(n[1 - i]); + subs.push_back(n[i]); + return true; + } + if (n[i].isVar()) { v[i] = n[i]; } @@ -1097,7 +1512,7 @@ bool ExtendedRewriter::inferSubstitution(Node n, r2 = n[1 - i]; if (v[i] != n[i]) { - Assert( TermUtil::isNegate( n[i].getKind() ) ); + Assert(TermUtil::isNegate(n[i].getKind())); r2 = TermUtil::mkNegate(n[i].getKind(), r2); } // TODO (#1706) : union find @@ -1113,7 +1528,7 @@ bool ExtendedRewriter::inferSubstitution(Node n, return false; } -Node ExtendedRewriter::extendedRewriteArith(Node ret, bool& pol) +Node ExtendedRewriter::extendedRewriteArith(Node ret) { Kind k = ret.getKind(); NodeManager* nm = NodeManager::currentNM(); diff --git a/src/theory/quantifiers/extended_rewrite.h b/src/theory/quantifiers/extended_rewrite.h index 37c179f94..4d3f08b1d 100644 --- a/src/theory/quantifiers/extended_rewrite.h +++ b/src/theory/quantifiers/extended_rewrite.h @@ -87,6 +87,12 @@ class ExtendedRewriter * strictly decrease the term size of n. */ Node extendedRewriteIte(Kind itek, Node n, bool full = true); + /** Rewrite AND/OR + * + * This implements BCP, factoring, and equality resolution for the Boolean + * term n whose top symbolic is AND/OR. + */ + Node extendedRewriteAndOr(Node n); /** Pull ITE, for example: * * D=C2 ---> false @@ -127,6 +133,15 @@ class ExtendedRewriter */ Node extendedRewriteBcp( Kind andk, Kind ork, Kind notk, std::map& bcp_kinds, Node n); + /** (type-independent) factoring, for example: + * + * ( A V B ) ^ ( A V C ) ----> A V ( B ^ C ) + * ( A ^ B ) V ( A ^ C ) ----> A ^ ( B V C ) + * + * This function takes as arguments the kinds that specify AND, OR, NOT. + * We assume that the children of n do not contain duplicates. + */ + Node extendedRewriteFactoring(Kind andk, Kind ork, Kind notk, Node n); /** (type-independent) equality resolution, for example: * * ( A V C ) & ( A = B ) ---> ( B V C ) & ( A = B ) @@ -211,7 +226,7 @@ class ExtendedRewriter //--------------------------------------theory-specific top-level calls /** extended rewrite arith */ - Node extendedRewriteArith(Node ret, bool& pol); + Node extendedRewriteArith(Node ret); //--------------------------------------end theory-specific top-level calls };