Improvements to quant+BV/Bool variable elimination (#1495)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 8 Jan 2018 19:21:29 +0000 (13:21 -0600)
committerGitHub <noreply@github.com>
Mon, 8 Jan 2018 19:21:29 +0000 (13:21 -0600)
src/theory/quantifiers/bv_inverter.cpp
src/theory/quantifiers/bv_inverter.h
src/theory/quantifiers/quantifiers_rewriter.cpp
src/theory/quantifiers/quantifiers_rewriter.h
test/regress/regress0/quantifiers/Makefile.am

index 3ad99999cce685911299458197540d6b7cf996a1..f0ad4f797a8373c58e7aceb3291e31e8372405cd 100644 (file)
@@ -83,11 +83,19 @@ Node BvInverter::getInversionNode(Node cond, TypeNode tn, BvInverterQuery* m)
   if (c.isNull())
   {
     NodeManager* nm = NodeManager::currentNM();
-    Node x = m->getBoundVariable(tn);
-    Node ccond = new_cond.substitute(solve_var, x);
-    c = nm->mkNode(kind::CHOICE, nm->mkNode(BOUND_VAR_LIST, x), ccond);
-    Trace("cegqi-bv-skvinv") << "SKVINV : Make " << c << " for " << new_cond
-                             << std::endl;
+    if (m)
+    {
+      Node x = m->getBoundVariable(tn);
+      Node ccond = new_cond.substitute(solve_var, x);
+      c = nm->mkNode(kind::CHOICE, nm->mkNode(BOUND_VAR_LIST, x), ccond);
+      Trace("cegqi-bv-skvinv")
+          << "SKVINV : Make " << c << " for " << new_cond << std::endl;
+    }
+    else
+    {
+      Trace("bv-invert") << "...fail for " << cond << " : no inverter query!"
+                         << std::endl;
+    }
   }
   // currently shouldn't cache since
   // the return value depends on the
@@ -175,7 +183,7 @@ Node BvInverter::getPathToPv(
   std::unordered_set<TNode, TNodeHashFunction> visited;
   Node slit = getPathToPv(lit, pv, sv, path, visited);
   // if we are able to find a (invertible) path to pv
-  if (!slit.isNull())
+  if (!slit.isNull() && !pvs.isNull())
   {
     // substitute pvs for the other occurrences of pv
     TNode tpv = pv;
@@ -2295,15 +2303,39 @@ Node BvInverter::solveBvLit(Node sv,
     {
       Assert(nchildren >= 2);
       Node s = nchildren == 2 ? sv_t[1 - index] : dropChild(sv_t, index);
+      Node t_new;
       /* Note: All n-ary kinds except for CONCAT (i.e., AND, OR, MULT, PLUS)
        *       are commutative (no case split based on index). */
+
+      // handle cases where the inversion has a unique solution
       if (k == BITVECTOR_PLUS)
       {
-        t = nm->mkNode(BITVECTOR_SUB, t, s);
+        t_new = nm->mkNode(BITVECTOR_SUB, t, s);
       }
       else if (k == BITVECTOR_XOR)
       {
-        t = nm->mkNode(BITVECTOR_XOR, t, s);
+        t_new = nm->mkNode(BITVECTOR_XOR, t, s);
+      }
+      else if (k == BITVECTOR_MULT)
+      {
+        if (s.isConst() && bv::utils::getBit(s, 0))
+        {
+          unsigned ssize = bv::utils::getSize(s);
+          Integer a = s.getConst<BitVector>().toInteger();
+          Integer w = Integer(1).multiplyByPow2(ssize);
+          Trace("bv-invert-debug")
+              << "Compute inverse : " << a << " " << w << std::endl;
+          Integer inv = a.modInverse(w);
+          Trace("bv-invert-debug") << "Inverse : " << inv << std::endl;
+          Node inv_val = nm->mkConst(BitVector(ssize, inv));
+          t_new = nm->mkNode(BITVECTOR_MULT, inv_val, t);
+        }
+      }
+
+      if (!t_new.isNull())
+      {
+        // In this case, s op x = t is equivalent to x = t_new
+        t = t_new;
       }
       else
       {
@@ -2362,6 +2394,10 @@ Node BvInverter::solveBvLit(Node sv,
         pol = true;
         /* t = fresh skolem constant */
         t = getInversionNode(sc, solve_tn, m);
+        if (t.isNull())
+        {
+          return t;
+        }
       }
     }
     sv_t = sv_t[index];
index ce2a58695f05f5bad75331732ba44d3df4a464c2..470c3a71ff2a8240289f631692a9e901dd5a30cc 100644 (file)
@@ -54,16 +54,33 @@ class BvInverter
   /** get dummy fresh variable of type tn, used as argument for sv */
   Node getSolveVariable(TypeNode tn);
 
-  /** Get path to pv in lit, replace that occurrence by sv and all others by
-   * pvs. If return value R is non-null, then : lit.path = pv R.path = sv
+  /**
+   * Get path to pv in lit, replace that occurrence by sv and all others by
+   * pvs (if pvs is non-null). If return value R is non-null, then :
+   *   lit.path = pv R.path = sv
    *   R.path' = pvs for all lit.path' = pv, where path' != path
    */
   Node getPathToPv(
       Node lit, Node pv, Node sv, Node pvs, std::vector<unsigned>& path);
 
+  /**
+   * Same as above, but does not linearize lit for pv.
+   * Use this version if we know lit is linear wrt pv.
+   */
+  Node getPathToPv(Node lit, Node pv, std::vector<unsigned>& path)
+  {
+    return getPathToPv(lit, pv, pv, Node::null(), path);
+  }
+
   /** solveBvLit
-   * solve for sv in lit, where lit.path = sv
-   * status accumulates side conditions
+   *
+   * Solve for sv in lit, where lit.path = sv. If this function returns a
+   * non-null node t, then sv = t is the solved form of lit.
+   *
+   * If the BvInverterQuery provided to this function call is null, then
+   * the solution returned by this call will not contain CHOICE expressions.
+   * If the solved form for lit requires introducing a CHOICE expression,
+   * then this call will return null.
    */
   Node solveBvLit(Node sv,
                   Node lit,
@@ -95,6 +112,9 @@ class BvInverter
    * the solve variable. For example, if cond is x = t where x is
    * getSolveVariable(tn), then we return t instead of introducing the choice
    * function.
+   *
+   * This function will return the null node if the BvInverterQuery m provided
+   * to this call is null.
    */
   Node getInversionNode(Node cond, TypeNode tn, BvInverterQuery* m);
 };
index 511e8f051f474c16846d539634e8486de81b5dfb..17214112b28c91927b9fa85918276ae16a634ae7 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "options/quantifiers_options.h"
 #include "theory/arith/arith_msum.h"
+#include "theory/quantifiers/bv_inverter.h"
 #include "theory/quantifiers/quantifiers_attributes.h"
 #include "theory/quantifiers/skolemize.h"
 #include "theory/quantifiers/term_database.h"
@@ -537,7 +538,7 @@ Node QuantifiersRewriter::computeProcessTerms2( Node body, bool hasPol, bool pol
     Trace("quantifiers-rewrite-term-debug2") << "Return (cached) " << ret << " for " << body << std::endl;
   }else{
     //only do context dependent processing up to depth 8
-    bool doCD = nCurrCond<8;
+    bool doCD = options::condRewriteQuant() && nCurrCond < 8;
     bool changed = false;
     std::vector< Node > children;
     //set entailed conditions based on OR/AND
@@ -890,18 +891,120 @@ void QuantifiersRewriter::isVariableBoundElig( Node n, std::map< Node, int >& ex
   }
 }
 
-bool QuantifiersRewriter::computeVariableElimLit( Node lit, bool pol, std::vector< Node >& args, std::vector< Node >& vars, std::vector< Node >& subs,
-                                                  std::map< Node, std::map< bool, std::map< Node, bool > > >& num_bounds ) {
-  if( lit.getKind()==EQUAL && pol && options::varElimQuant() ){
-    for( unsigned i=0; i<2; i++ ){
-      std::vector< Node >::iterator ita = std::find( args.begin(), args.end(), lit[i] );
-      if( ita!=args.end() ){
-        if( isVariableElim( lit[i], lit[1-i] ) ){
-          Trace("var-elim-quant") << "Variable eliminate based on equality : " << lit[i] << " -> " << lit[1-i] << std::endl;
-          vars.push_back( lit[i] );
-          subs.push_back( lit[1-i] );
-          args.erase( ita );
-          return true;
+Node QuantifiersRewriter::computeVariableElimLitBv(Node lit,
+                                                   std::vector<Node>& args,
+                                                   Node& var)
+{
+  if (Trace.isOn("quant-velim-bv"))
+  {
+    Trace("quant-velim-bv") << "Bv-Elim : " << lit << " varList = { ";
+    for (const Node& v : args)
+    {
+      Trace("quant-velim-bv") << v << " ";
+    }
+    Trace("quant-velim-bv") << "} ?" << std::endl;
+  }
+  Assert(lit.getKind() == EQUAL);
+  // TODO (#1494) : linearize the literal using utility
+
+  // figure out if this literal is linear and invertible on path with args
+  std::map<TNode, bool> linear;
+  std::unordered_set<TNode, TNodeHashFunction> visited;
+  std::unordered_set<TNode, TNodeHashFunction>::iterator it;
+  std::vector<TNode> visit;
+  TNode cur;
+  visit.push_back(lit);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    if (std::find(args.begin(), args.end(), cur) != args.end())
+    {
+      linear[cur] = linear.find(cur) == linear.end();
+    }
+    if (visited.find(cur) == visited.end())
+    {
+      visited.insert(cur);
+
+      for (const Node& cn : cur)
+      {
+        visit.push_back(cn);
+      }
+    }
+  } while (!visit.empty());
+
+  BvInverter binv;
+  for (std::pair<const TNode, bool>& lp : linear)
+  {
+    if (lp.second)
+    {
+      TNode cvar = lp.first;
+      Trace("quant-velim-bv") << "...linear wrt " << cvar << std::endl;
+      std::vector<unsigned> path;
+      Node slit = binv.getPathToPv(lit, cvar, path);
+      if (!slit.isNull())
+      {
+        Node slv = binv.solveBvLit(cvar, lit, path, nullptr);
+        Trace("quant-velim-bv") << "...solution : " << slv << std::endl;
+        if (!slv.isNull())
+        {
+          var = cvar;
+          return slv;
+        }
+      }
+      else
+      {
+        Trace("quant-velim-bv") << "...non-invertible path." << std::endl;
+      }
+    }
+  }
+
+  return Node::null();
+}
+
+bool QuantifiersRewriter::computeVariableElimLit(
+    Node lit,
+    bool pol,
+    std::vector<Node>& args,
+    std::vector<Node>& vars,
+    std::vector<Node>& subs,
+    std::map<Node, std::map<bool, std::map<Node, bool> > >& num_bounds)
+{
+  Trace("var-elim-quant-debug")
+      << "Eliminate : " << lit << ", pol = " << pol << "?" << std::endl;
+  if (lit.getKind() == EQUAL && options::varElimQuant())
+  {
+    if (pol || lit[0].getType().isBoolean())
+    {
+      for (unsigned i = 0; i < 2; i++)
+      {
+        bool tpol = pol;
+        Node v_slv = lit[i];
+        if (v_slv.getKind() == NOT)
+        {
+          v_slv = v_slv[0];
+          tpol = !tpol;
+        }
+        std::vector<Node>::iterator ita =
+            std::find(args.begin(), args.end(), v_slv);
+        if (ita != args.end())
+        {
+          if (isVariableElim(v_slv, lit[1 - i]))
+          {
+            Node slv = lit[1 - i];
+            if (!tpol)
+            {
+              Assert(slv.getType().isBoolean());
+              slv = slv.negate();
+            }
+            Trace("var-elim-quant")
+                << "Variable eliminate based on equality : " << v_slv << " -> "
+                << slv << std::endl;
+            vars.push_back(v_slv);
+            subs.push_back(slv);
+            args.erase(ita);
+            return true;
+          }
         }
       }
     }
@@ -996,7 +1099,28 @@ bool QuantifiersRewriter::computeVariableElimLit( Node lit, bool pol, std::vecto
       }
     }
   }
-  
+  else if (lit.getKind() == EQUAL && lit[0].getType().isBitVector() && pol
+           && options::varElimQuant())
+  {
+    Node var;
+    Node slv = computeVariableElimLitBv(lit, args, var);
+    if (!slv.isNull())
+    {
+      Assert(!var.isNull());
+      std::vector<Node>::iterator ita =
+          std::find(args.begin(), args.end(), var);
+      Assert(ita != args.end());
+      Assert(isVariableElim(var, slv));
+      Trace("var-elim-quant")
+          << "Variable eliminate based on bit-vector inversion : " << var
+          << " -> " << slv << std::endl;
+      vars.push_back(var);
+      subs.push_back(slv);
+      args.erase(ita);
+      return true;
+    }
+  }
+
   return false;
 }
 
@@ -1559,21 +1683,39 @@ Node QuantifiersRewriter::computeAggressiveMiniscoping( std::vector< Node >& arg
 bool QuantifiersRewriter::doOperation( Node q, int computeOption, QAttributes& qa ){
   bool is_strict_trigger = qa.d_hasPattern && options::userPatternsQuant()==USER_PAT_MODE_TRUST;
   bool is_std = !qa.d_sygus && !qa.d_quant_elim && !qa.isFunDef() && !is_strict_trigger;
-  if( computeOption==COMPUTE_ELIM_SYMBOLS ){
+  if (computeOption == COMPUTE_ELIM_SYMBOLS)
+  {
     return true;
-  }else if( computeOption==COMPUTE_MINISCOPING ){
+  }
+  else if (computeOption == COMPUTE_MINISCOPING)
+  {
     return is_std;
-  }else if( computeOption==COMPUTE_AGGRESSIVE_MINISCOPING ){
+  }
+  else if (computeOption == COMPUTE_AGGRESSIVE_MINISCOPING)
+  {
     return options::aggressiveMiniscopeQuant() && is_std;
-  }else if( computeOption==COMPUTE_PROCESS_TERMS ){
-    return options::condRewriteQuant();
-  }else if( computeOption==COMPUTE_COND_SPLIT ){
-    return ( options::iteDtTesterSplitQuant() || options::condVarSplitQuant() ) && !is_strict_trigger;
-  }else if( computeOption==COMPUTE_PRENEX ){
-    return options::prenexQuant()!=PRENEX_QUANT_NONE && !options::aggressiveMiniscopeQuant() && is_std;
-  }else if( computeOption==COMPUTE_VAR_ELIMINATION ){
-    return ( options::varElimQuant() || options::dtVarExpandQuant() ) && is_std;
-  }else{
+  }
+  else if (computeOption == COMPUTE_PROCESS_TERMS)
+  {
+    return options::condRewriteQuant() || options::elimExtArithQuant()
+           || options::iteLiftQuant() != ITE_LIFT_QUANT_MODE_NONE;
+  }
+  else if (computeOption == COMPUTE_COND_SPLIT)
+  {
+    return (options::iteDtTesterSplitQuant() || options::condVarSplitQuant())
+           && !is_strict_trigger;
+  }
+  else if (computeOption == COMPUTE_PRENEX)
+  {
+    return options::prenexQuant() != PRENEX_QUANT_NONE
+           && !options::aggressiveMiniscopeQuant() && is_std;
+  }
+  else if (computeOption == COMPUTE_VAR_ELIMINATION)
+  {
+    return (options::varElimQuant() || options::dtVarExpandQuant()) && is_std;
+  }
+  else
+  {
     return false;
   }
 }
index b179110e7b6f86f9f4b431b840b8cd14cd85a099..149380b841370009979e4cb2c8acaeaba5e3ea08 100644 (file)
@@ -52,7 +52,16 @@ private:
   static bool computeVariableElimLit( Node n, bool pol, std::vector< Node >& args, std::vector< Node >& var, std::vector< Node >& subs,
                                       std::map< Node, std::map< bool, std::map< Node, bool > > >& num_bounds );
   static Node computeVarElimination2( Node body, std::vector< Node >& args, QAttributes& qa );
-public:
+  /** variable eliminate for bit-vector literals
+   *
+   * If this returns a non-null value ret, then var is updated to a member of
+   * args, and lit is equivalent to ( var = ret ).
+   */
+  static Node computeVariableElimLitBv(Node lit,
+                                       std::vector<Node>& args,
+                                       Node& var);
+
+ public:
   static Node computeElimSymbols( Node body );
   static Node computeMiniscoping( std::vector< Node >& args, Node body, QAttributes& qa );
   static Node computeAggressiveMiniscoping( std::vector< Node >& args, Node body );
index 133c2018dc167b7d8632cc11bd70e68542c28fb1..64f8b6f16582df50725b385d246a19e1d2e08318 100644 (file)
@@ -79,7 +79,6 @@ TESTS =       \
        parametric-lists.smt2 \
        partial-trigger.smt2 \
        inst-max-level-segf.smt2 \
-       small-bug1-fixpoint-3.smt2 \
        z3.620661-no-fv-trigger.smt2 \
        bug_743.smt2 \
        quaternion_ds1_symm_0428.fof.smt2 \
@@ -145,6 +144,9 @@ TESTS =     \
 # disabled since bvcomp handling is currently disabled
 # qbv-test-invert-bvcomp.smt2
 
+# disabled, broken by variable elimination (was solved heuristically previously)
+# small-bug1-fixpoint-3.smt2 
+
 # removed because they take more than 20s
 #              javafe.ast.ArrayInit.35.smt2