From 5bfb18e0b31584a50ea0386ce3d43855ee2061d4 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 29 Apr 2022 10:22:20 -0500 Subject: [PATCH] Make extended rewriter use standard Subs utility (#8682) This is work towards ensuring all substitutions are strictly typed. --- src/theory/quantifiers/extended_rewrite.cpp | 52 ++++++++------------- src/theory/quantifiers/extended_rewrite.h | 13 ++---- 2 files changed, 25 insertions(+), 40 deletions(-) diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index 975fe6e7e..d23cf209f 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -485,19 +485,17 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const { // If x is less than t based on an ordering, then we use { x -> t } as a // substitution to the children of ite( x = t ^ C, s, t ) below. - std::vector vars; - std::vector subs; - inferSubstitution(n[0], vars, subs, true); + Subs subs; + inferSubstitution(n[0], subs, true); - if (!vars.empty()) + if (!subs.empty()) { // reverse substitution to opposite child // r{ x -> t } = s implies ite( x=t ^ C, s, r ) ---> r // We can use ordinary substitute since the result of the substitution // is not being returned. In other words, nn is only being used to query // whether the second branch is a generalization of the first. - Node nn = - t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); + Node nn = subs.apply(t2); if (nn != t2) { nn = d_rew.rewrite(nn); @@ -511,7 +509,7 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r ) // must use partial substitute here, to avoid substitution into witness std::map rkinds; - nn = partialSubstitute(t1, vars, subs, rkinds); + nn = partialSubstitute(t1, subs, rkinds); nn = d_rew.rewrite(nn); if (nn != t1) { @@ -1064,11 +1062,10 @@ Node ExtendedRewriter::extendedRewriteEqRes(Kind andk, if (!eq.isNull()) { // see if it corresponds to a substitution - std::vector vars; - std::vector subs; - if (inferSubstitution(eq, vars, subs)) + Subs subs; + if (inferSubstitution(eq, subs)) { - Assert(vars.size() == 1); + Assert(subs.size() == 1); std::vector children; bool childrenChanged = false; // apply to all other children @@ -1080,7 +1077,7 @@ Node ExtendedRewriter::extendedRewriteEqRes(Kind andk, // Substitution is only applicable to compatible kinds. We always // use the partialSubstitute method to avoid substitution into // witness terms. - ccs = partialSubstitute(ccs, vars, subs, bcp_kinds); + ccs = partialSubstitute(ccs, subs, bcp_kinds); childrenChanged = childrenChanged || n[j] != ccs; } children.push_back(ccs); @@ -1610,16 +1607,12 @@ Node ExtendedRewriter::partialSubstitute( } Node ExtendedRewriter::partialSubstitute( - Node n, - const std::vector& vars, - const std::vector& subs, - const std::map& rkinds) const + Node n, const Subs& subs, const std::map& rkinds) const { - Assert(vars.size() == subs.size()); std::map assign; - for (size_t i = 0, nvars = vars.size(); i < nvars; i++) + for (size_t i = 0, nvars = subs.size(); i < nvars; i++) { - assign[vars[i]] = subs[i]; + assign[subs.d_vars[i]] = subs.d_subs[i]; } return partialSubstitute(n, assign, rkinds); } @@ -1632,17 +1625,14 @@ Node ExtendedRewriter::solveEquality(Node n) const return Node::null(); } -bool ExtendedRewriter::inferSubstitution(Node n, - std::vector& vars, - std::vector& subs, - bool usePred) const +bool ExtendedRewriter::inferSubstitution(Node n, Subs& subs, bool usePred) const { if (n.getKind() == AND) { bool ret = false; for (const Node& nc : n) { - bool cret = inferSubstitution(nc, vars, subs, usePred); + bool cret = inferSubstitution(nc, subs, usePred); ret = ret || cret; } return ret; @@ -1660,8 +1650,7 @@ bool ExtendedRewriter::inferSubstitution(Node n, { if (n[i].isConst()) { - vars.push_back(n[1 - i]); - subs.push_back(n[i]); + subs.add(n[1 - i], n[i]); return true; } if (n[i].isVar()) @@ -1685,11 +1674,9 @@ bool ExtendedRewriter::inferSubstitution(Node n, Assert(TermUtil::isNegate(n[i].getKind())); r2 = TermUtil::mkNegate(n[i].getKind(), r2); } - // TODO (#1706) : union find - if (std::find(vars.begin(), vars.end(), r1) == vars.end()) + if (!subs.contains(r1)) { - vars.push_back(r1); - subs.push_back(r2); + subs.add(r1, r2); return true; } } @@ -1698,8 +1685,9 @@ bool ExtendedRewriter::inferSubstitution(Node n, if (usePred) { bool negated = n.getKind() == NOT; - vars.push_back(negated ? n[0] : n); - subs.push_back(negated ? d_false : d_true); + Node var = negated ? n[0] : n; + Node s = NodeManager::currentNM()->mkConst(!negated); + subs.add(var, s); return true; } return false; diff --git a/src/theory/quantifiers/extended_rewrite.h b/src/theory/quantifiers/extended_rewrite.h index bc96cbeed..002c00b65 100644 --- a/src/theory/quantifiers/extended_rewrite.h +++ b/src/theory/quantifiers/extended_rewrite.h @@ -21,6 +21,7 @@ #include #include "expr/node.h" +#include "expr/subs.h" namespace cvc5::internal { namespace theory { @@ -207,10 +208,9 @@ class ExtendedRewriter Node partialSubstitute(Node n, const std::map& assign, const std::map& rkinds) const; - /** same as above, with vectors */ + /** same as above, with the subs utility */ Node partialSubstitute(Node n, - const std::vector& vars, - const std::vector& subs, + const Subs& subs, const std::map& rkinds) const; /** solve equality * @@ -223,15 +223,12 @@ class ExtendedRewriter * If n is an equality of the form x = t, where t is either: * (1) a constant, or * (2) a variable y such that x < y based on an ordering, - * then this method adds x to vars and y to subs and return true, otherwise + * then this method adds {x -> y} to subs and return true, otherwise * it returns false. * If usePred is true, we may additionally add n -> true, or n[0] -> false * is n is a negation. */ - bool inferSubstitution(Node n, - std::vector& vars, - std::vector& subs, - bool usePred = false) const; + bool inferSubstitution(Node n, Subs& subs, bool usePred = false) const; /** extended rewrite * * Prints debug information, indicating the rewrite n ---> ret was found. -- 2.30.2