Refactoring theory-specific variable elimination in quantifiers rewriter (#7026)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 17 Aug 2021 17:31:35 +0000 (12:31 -0500)
committerGitHub <noreply@github.com>
Tue, 17 Aug 2021 17:31:35 +0000 (17:31 +0000)
No behavior changes in this PR, just code reorganization.

src/theory/quantifiers/quantifiers_rewriter.cpp
src/theory/quantifiers/quantifiers_rewriter.h

index aba2d79bf5387c26af9cffb3498672ec14084ef6..02af92887c35daf601e27885ba4e363b253c45f6 100644 (file)
@@ -701,9 +701,65 @@ bool QuantifiersRewriter::isVarElim(Node v, Node s)
   return !expr::hasSubterm(s, v) && s.getType().isSubtypeOf(v.getType());
 }
 
-Node QuantifiersRewriter::getVarElimLitBv(Node lit,
-                                          const std::vector<Node>& args,
-                                          Node& var)
+Node QuantifiersRewriter::getVarElimEq(Node lit,
+                                       const std::vector<Node>& args,
+                                       Node& var)
+{
+  Assert(lit.getKind() == EQUAL);
+  Node slv;
+  TypeNode tt = lit[0].getType();
+  if (tt.isReal())
+  {
+    slv = getVarElimEqReal(lit, args, var);
+  }
+  else if (tt.isBitVector())
+  {
+    slv = getVarElimEqBv(lit, args, var);
+  }
+  else if (tt.isStringLike())
+  {
+    slv = getVarElimEqString(lit, args, var);
+  }
+  return slv;
+}
+
+Node QuantifiersRewriter::getVarElimEqReal(Node lit,
+                                           const std::vector<Node>& args,
+                                           Node& var)
+{
+  // for arithmetic, solve the equality
+  std::map<Node, Node> msum;
+  if (!ArithMSum::getMonomialSumLit(lit, msum))
+  {
+    return Node::null();
+  }
+  std::vector<Node>::const_iterator ita;
+  for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end();
+       ++itm)
+  {
+    if (itm->first.isNull())
+    {
+      continue;
+    }
+    ita = std::find(args.begin(), args.end(), itm->first);
+    if (ita != args.end())
+    {
+      Node veq_c;
+      Node val;
+      int ires = ArithMSum::isolate(itm->first, msum, veq_c, val, EQUAL);
+      if (ires != 0 && veq_c.isNull() && isVarElim(itm->first, val))
+      {
+        var = itm->first;
+        return val;
+      }
+    }
+  }
+  return Node::null();
+}
+
+Node QuantifiersRewriter::getVarElimEqBv(Node lit,
+                                         const std::vector<Node>& args,
+                                         Node& var)
 {
   if (Trace.isOn("quant-velim-bv"))
   {
@@ -752,9 +808,9 @@ Node QuantifiersRewriter::getVarElimLitBv(Node lit,
   return Node::null();
 }
 
-Node QuantifiersRewriter::getVarElimLitString(Node lit,
-                                              const std::vector<Node>& args,
-                                              Node& var)
+Node QuantifiersRewriter::getVarElimEqString(Node lit,
+                                             const std::vector<Node>& args,
+                                             Node& var)
 {
   Assert(lit.getKind() == EQUAL);
   NodeManager* nm = NodeManager::currentNM();
@@ -900,48 +956,10 @@ bool QuantifiersRewriter::getVarElimLit(Node lit,
       return true;
     }
   }
-  if (lit.getKind() == EQUAL && lit[0].getType().isReal() && pol)
-  {
-    // for arithmetic, solve the equality
-    std::map< Node, Node > msum;
-    if (ArithMSum::getMonomialSumLit(lit, msum))
-    {
-      for( std::map< Node, Node >::iterator itm = msum.begin(); itm != msum.end(); ++itm ){
-        if( !itm->first.isNull() ){
-          std::vector< Node >::iterator ita = std::find( args.begin(), args.end(), itm->first );
-          if( ita!=args.end() ){
-            Assert(pol);
-            Node veq_c;
-            Node val;
-            int ires = ArithMSum::isolate(itm->first, msum, veq_c, val, EQUAL);
-            if (ires != 0 && veq_c.isNull() && isVarElim(itm->first, val))
-            {
-              Trace("var-elim-quant")
-                  << "Variable eliminate based on solved equality : "
-                  << itm->first << " -> " << val << std::endl;
-              vars.push_back(itm->first);
-              subs.push_back(val);
-              args.erase(ita);
-              return true;
-            }
-          }
-        }
-      }
-    }
-  }
   if (lit.getKind() == EQUAL && pol)
   {
     Node var;
-    Node slv;
-    TypeNode tt = lit[0].getType();
-    if (tt.isBitVector())
-    {
-      slv = getVarElimLitBv(lit, args, var);
-    }
-    else if (tt.isStringLike())
-    {
-      slv = getVarElimLitString(lit, args, var);
-    }
+    Node slv = getVarElimEq(lit, args, var);
     if (!slv.isNull())
     {
       Assert(!var.isNull());
index ae7f75f34a3e153c4f85b7b1b491e90d00445c4f..f0c3b041426c251b88d41e59dc70e4554e2b999d 100644 (file)
@@ -80,22 +80,34 @@ class QuantifiersRewriter : public TheoryRewriter
                             std::vector<Node>& args,
                             std::vector<Node>& vars,
                             std::vector<Node>& subs);
+  /**
+   * Get variable eliminate for an equality based on theory-specific reasoning.
+   */
+  static Node getVarElimEq(Node lit, const std::vector<Node>& args, Node& var);
+  /** variable eliminate for real equalities
+   *
+   * If this returns a non-null value ret, then var is updated to a member of
+   * args, lit is equivalent to ( var = ret ).
+   */
+  static Node getVarElimEqReal(Node lit,
+                               const std::vector<Node>& args,
+                               Node& var);
   /** variable eliminate for bit-vector equalities
    *
    * If this returns a non-null value ret, then var is updated to a member of
    * args, lit is equivalent to ( var = ret ).
    */
-  static Node getVarElimLitBv(Node lit,
-                              const std::vector<Node>& args,
-                              Node& var);
+  static Node getVarElimEqBv(Node lit,
+                             const std::vector<Node>& args,
+                             Node& var);
   /** variable eliminate for string equalities
    *
    * If this returns a non-null value ret, then var is updated to a member of
    * args, lit is equivalent to ( var = ret ).
    */
-  static Node getVarElimLitString(Node lit,
-                                  const std::vector<Node>& args,
-                                  Node& var);
+  static Node getVarElimEqString(Node lit,
+                                 const std::vector<Node>& args,
+                                 Node& var);
   /** get variable elimination
    *
    * If n asserted with polarity pol entails a literal lit that corresponds