Make extended rewriter use standard Subs utility (#8682)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 29 Apr 2022 15:22:20 +0000 (10:22 -0500)
committerGitHub <noreply@github.com>
Fri, 29 Apr 2022 15:22:20 +0000 (15:22 +0000)
This is work towards ensuring all substitutions are strictly typed.

src/theory/quantifiers/extended_rewrite.cpp
src/theory/quantifiers/extended_rewrite.h

index 975fe6e7e6040ff3c4aed0854bd3b528b889edb1..d23cf209fc9dfadbc2fcda4cafe047960fd8e22d 100644 (file)
@@ -485,19 +485,17 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const
   {
     // If x is less than t based on an ordering, then we use { x -> t } as a
     // substitution to the children of ite( x = t ^ C, s, t ) below.
-    std::vector<Node> vars;
-    std::vector<Node> subs;
-    inferSubstitution(n[0], vars, subs, true);
+    Subs subs;
+    inferSubstitution(n[0], subs, true);
 
-    if (!vars.empty())
+    if (!subs.empty())
     {
       // reverse substitution to opposite child
       // r{ x -> t } = s  implies  ite( x=t ^ C, s, r ) ---> r
       // We can use ordinary substitute since the result of the substitution
       // is not being returned. In other words, nn is only being used to query
       // whether the second branch is a generalization of the first.
-      Node nn =
-          t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+      Node nn = subs.apply(t2);
       if (nn != t2)
       {
         nn = d_rew.rewrite(nn);
@@ -511,7 +509,7 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const
       // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
       // must use partial substitute here, to avoid substitution into witness
       std::map<Kind, bool> rkinds;
-      nn = partialSubstitute(t1, vars, subs, rkinds);
+      nn = partialSubstitute(t1, subs, rkinds);
       nn = d_rew.rewrite(nn);
       if (nn != t1)
       {
@@ -1064,11 +1062,10 @@ Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
       if (!eq.isNull())
       {
         // see if it corresponds to a substitution
-        std::vector<Node> vars;
-        std::vector<Node> subs;
-        if (inferSubstitution(eq, vars, subs))
+        Subs subs;
+        if (inferSubstitution(eq, subs))
         {
-          Assert(vars.size() == 1);
+          Assert(subs.size() == 1);
           std::vector<Node> children;
           bool childrenChanged = false;
           // apply to all other children
@@ -1080,7 +1077,7 @@ Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
               // Substitution is only applicable to compatible kinds. We always
               // use the partialSubstitute method to avoid substitution into
               // witness terms.
-              ccs = partialSubstitute(ccs, vars, subs, bcp_kinds);
+              ccs = partialSubstitute(ccs, subs, bcp_kinds);
               childrenChanged = childrenChanged || n[j] != ccs;
             }
             children.push_back(ccs);
@@ -1610,16 +1607,12 @@ Node ExtendedRewriter::partialSubstitute(
 }
 
 Node ExtendedRewriter::partialSubstitute(
-    Node n,
-    const std::vector<Node>& vars,
-    const std::vector<Node>& subs,
-    const std::map<Kind, bool>& rkinds) const
+    Node n, const Subs& subs, const std::map<Kind, bool>& rkinds) const
 {
-  Assert(vars.size() == subs.size());
   std::map<Node, Node> assign;
-  for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
+  for (size_t i = 0, nvars = subs.size(); i < nvars; i++)
   {
-    assign[vars[i]] = subs[i];
+    assign[subs.d_vars[i]] = subs.d_subs[i];
   }
   return partialSubstitute(n, assign, rkinds);
 }
@@ -1632,17 +1625,14 @@ Node ExtendedRewriter::solveEquality(Node n) const
   return Node::null();
 }
 
-bool ExtendedRewriter::inferSubstitution(Node n,
-                                         std::vector<Node>& vars,
-                                         std::vector<Node>& subs,
-                                         bool usePred) const
+bool ExtendedRewriter::inferSubstitution(Node n, Subs& subs, bool usePred) const
 {
   if (n.getKind() == AND)
   {
     bool ret = false;
     for (const Node& nc : n)
     {
-      bool cret = inferSubstitution(nc, vars, subs, usePred);
+      bool cret = inferSubstitution(nc, subs, usePred);
       ret = ret || cret;
     }
     return ret;
@@ -1660,8 +1650,7 @@ bool ExtendedRewriter::inferSubstitution(Node n,
     {
       if (n[i].isConst())
       {
-        vars.push_back(n[1 - i]);
-        subs.push_back(n[i]);
+        subs.add(n[1 - i], n[i]);
         return true;
       }
       if (n[i].isVar())
@@ -1685,11 +1674,9 @@ bool ExtendedRewriter::inferSubstitution(Node n,
           Assert(TermUtil::isNegate(n[i].getKind()));
           r2 = TermUtil::mkNegate(n[i].getKind(), r2);
         }
-        // TODO (#1706) : union find
-        if (std::find(vars.begin(), vars.end(), r1) == vars.end())
+        if (!subs.contains(r1))
         {
-          vars.push_back(r1);
-          subs.push_back(r2);
+          subs.add(r1, r2);
           return true;
         }
       }
@@ -1698,8 +1685,9 @@ bool ExtendedRewriter::inferSubstitution(Node n,
   if (usePred)
   {
     bool negated = n.getKind() == NOT;
-    vars.push_back(negated ? n[0] : n);
-    subs.push_back(negated ? d_false : d_true);
+    Node var = negated ? n[0] : n;
+    Node s = NodeManager::currentNM()->mkConst(!negated);
+    subs.add(var, s);
     return true;
   }
   return false;
index bc96cbeed6796a75c4c0dce9706a1b533c44c73b..002c00b65c27f0ad83efc371006ae3e5c4f56285 100644 (file)
@@ -21,6 +21,7 @@
 #include <unordered_map>
 
 #include "expr/node.h"
+#include "expr/subs.h"
 
 namespace cvc5::internal {
 namespace theory {
@@ -207,10 +208,9 @@ class ExtendedRewriter
   Node partialSubstitute(Node n,
                          const std::map<Node, Node>& assign,
                          const std::map<Kind, bool>& rkinds) const;
-  /** same as above, with vectors */
+  /** same as above, with the subs utility */
   Node partialSubstitute(Node n,
-                         const std::vector<Node>& vars,
-                         const std::vector<Node>& subs,
+                         const Subs& subs,
                          const std::map<Kind, bool>& rkinds) const;
   /** solve equality
    *
@@ -223,15 +223,12 @@ class ExtendedRewriter
    * If n is an equality of the form x = t, where t is either:
    * (1) a constant, or
    * (2) a variable y such that x < y based on an ordering,
-   * then this method adds x to vars and y to subs and return true, otherwise
+   * then this method adds {x -> y} to subs and return true, otherwise
    * it returns false.
    * If usePred is true, we may additionally add n -> true, or n[0] -> false
    * is n is a negation.
    */
-  bool inferSubstitution(Node n,
-                         std::vector<Node>& vars,
-                         std::vector<Node>& subs,
-                         bool usePred = false) const;
+  bool inferSubstitution(Node n, Subs& subs, bool usePred = false) const;
   /** extended rewrite
    *
    * Prints debug information, indicating the rewrite n ---> ret was found.