Do not construct instantiation for checking propagating instantiations spurious ...
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 20 Oct 2021 20:34:33 +0000 (15:34 -0500)
committerGitHub <noreply@github.com>
Wed, 20 Oct 2021 20:34:33 +0000 (20:34 +0000)
This makes the check for when an instantiation is "propagating" faster by not constructing the substitution + rewriting of the entire formula, and instead threading the substitution through the entailment check utility's evaluateTerm utility.

On a handful of challenge Facebook benchmarks, we go 35 seconds -> 18 seconds with this change.

This also eliminates the argument exp to the evaluateTerm method, which is no longer used, and eliminates hasSubs from several methods, which is redundant.

src/theory/quantifiers/entailment_check.cpp
src/theory/quantifiers/entailment_check.h
src/theory/quantifiers/quant_conflict_find.cpp
src/theory/quantifiers/quant_conflict_find.h

index f27e14121b6e383636830cb6cd12050801c4476b..543414a4e8dd3ff88eae145308dcf440b20c749a 100644 (file)
@@ -33,11 +33,12 @@ EntailmentCheck::EntailmentCheck(Env& env, QuantifiersState& qs, TermDb& tdb)
 }
 
 EntailmentCheck::~EntailmentCheck() {}
+
 Node EntailmentCheck::evaluateTerm2(TNode n,
                                     std::map<TNode, Node>& visited,
-                                    std::vector<Node>& exp,
+                                    std::map<TNode, TNode>& subs,
+                                    bool subsRep,
                                     bool useEntailmentTests,
-                                    bool computeExp,
                                     bool reqHasTerm)
 {
   std::map<TNode, Node>::iterator itv = visited.find(n);
@@ -45,36 +46,43 @@ Node EntailmentCheck::evaluateTerm2(TNode n,
   {
     return itv->second;
   }
-  size_t prevSize = exp.size();
   Trace("term-db-eval") << "evaluate term : " << n << std::endl;
   Node ret = n;
-  if (n.getKind() == FORALL || n.getKind() == BOUND_VARIABLE)
+  Kind k = n.getKind();
+  if (k == FORALL)
   {
     // do nothing
   }
-  else if (d_qstate.hasTerm(n))
+  else if (k == BOUND_VARIABLE)
   {
-    Trace("term-db-eval") << "...exists in ee, return rep" << std::endl;
-    ret = d_qstate.getRepresentative(n);
-    if (computeExp)
+    std::map<TNode, TNode>::iterator it = subs.find(n);
+    if (it != subs.end())
     {
-      if (n != ret)
+      if (!subsRep)
       {
-        exp.push_back(n.eqNode(ret));
+        Assert(d_qstate.hasTerm(it->second));
+        ret = d_qstate.getRepresentative(it->second);
+      }
+      else
+      {
+        ret = it->second;
       }
     }
+  }
+  else if (d_qstate.hasTerm(n))
+  {
+    Trace("term-db-eval") << "...exists in ee, return rep" << std::endl;
+    ret = d_qstate.getRepresentative(n);
     reqHasTerm = false;
   }
   else if (n.hasOperator())
   {
     std::vector<TNode> args;
     bool ret_set = false;
-    Kind k = n.getKind();
-    std::vector<Node> tempExp;
     for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
     {
       TNode c = evaluateTerm2(
-          n[i], visited, tempExp, useEntailmentTests, computeExp, reqHasTerm);
+          n[i], visited, subs, subsRep, useEntailmentTests, reqHasTerm);
       if (c.isNull())
       {
         ret = Node::null();
@@ -95,32 +103,19 @@ Node EntailmentCheck::evaluateTerm2(TNode n,
         {
           ret = evaluateTerm2(n[c == d_true ? 1 : 2],
                               visited,
-                              tempExp,
+                              subs,
+                              subsRep,
                               useEntailmentTests,
-                              computeExp,
                               reqHasTerm);
           ret_set = true;
           reqHasTerm = false;
           break;
         }
       }
-      if (computeExp)
-      {
-        exp.insert(exp.end(), tempExp.begin(), tempExp.end());
-      }
       Trace("term-db-eval") << "  child " << i << " : " << c << std::endl;
       args.push_back(c);
     }
-    if (ret_set)
-    {
-      // if we short circuited
-      if (computeExp)
-      {
-        exp.clear();
-        exp.insert(exp.end(), tempExp.begin(), tempExp.end());
-      }
-    }
-    else
+    if (!ret_set)
     {
       // get the (indexed) operator of n, if it exists
       TNode f = d_tdb.getMatchOperator(n);
@@ -133,29 +128,11 @@ Node EntailmentCheck::evaluateTerm2(TNode n,
                               << " from DB for " << n << std::endl;
         if (!nn.isNull())
         {
-          if (computeExp)
-          {
-            Assert(nn.getNumChildren() == n.getNumChildren());
-            for (size_t i = 0, nchild = nn.getNumChildren(); i < nchild; i++)
-            {
-              if (nn[i] != n[i])
-              {
-                exp.push_back(nn[i].eqNode(n[i]));
-              }
-            }
-          }
           ret = d_qstate.getRepresentative(nn);
           Trace("term-db-eval") << "return rep" << std::endl;
           ret_set = true;
           reqHasTerm = false;
           Assert(!ret.isNull());
-          if (computeExp)
-          {
-            if (n != ret)
-            {
-              exp.push_back(nn.eqNode(ret));
-            }
-          }
         }
       }
       if (!ret_set)
@@ -188,10 +165,6 @@ Node EntailmentCheck::evaluateTerm2(TNode n,
               if (et.first)
               {
                 ret = j == 0 ? d_true : d_false;
-                if (computeExp)
-                {
-                  exp.push_back(et.second);
-                }
                 break;
               }
             }
@@ -203,9 +176,9 @@ Node EntailmentCheck::evaluateTerm2(TNode n,
   // must have the term
   if (reqHasTerm && !ret.isNull())
   {
-    Kind k = ret.getKind();
-    if (k != OR && k != AND && k != EQUAL && k != ITE && k != NOT
-        && k != FORALL)
+    Kind rk = ret.getKind();
+    if (rk != OR && rk != AND && rk != EQUAL && rk != ITE && rk != NOT
+        && rk != FORALL)
     {
       if (!d_qstate.hasTerm(ret))
       {
@@ -215,19 +188,13 @@ Node EntailmentCheck::evaluateTerm2(TNode n,
   }
   Trace("term-db-eval") << "evaluated term : " << n << ", got : " << ret
                         << ", reqHasTerm = " << reqHasTerm << std::endl;
-  // clear the explanation if failed
-  if (computeExp && ret.isNull())
-  {
-    exp.resize(prevSize);
-  }
   visited[n] = ret;
   return ret;
 }
 
 TNode EntailmentCheck::getEntailedTerm2(TNode n,
                                         std::map<TNode, TNode>& subs,
-                                        bool subsRep,
-                                        bool hasSubs)
+                                        bool subsRep)
 {
   Trace("term-db-entail") << "get entailed term : " << n << std::endl;
   if (d_qstate.hasTerm(n))
@@ -237,30 +204,27 @@ TNode EntailmentCheck::getEntailedTerm2(TNode n,
   }
   else if (n.getKind() == BOUND_VARIABLE)
   {
-    if (hasSubs)
+    std::map<TNode, TNode>::iterator it = subs.find(n);
+    if (it != subs.end())
     {
-      std::map<TNode, TNode>::iterator it = subs.find(n);
-      if (it != subs.end())
+      Trace("term-db-entail")
+          << "...substitution is : " << it->second << std::endl;
+      if (subsRep)
       {
-        Trace("term-db-entail")
-            << "...substitution is : " << it->second << std::endl;
-        if (subsRep)
-        {
-          Assert(d_qstate.hasTerm(it->second));
-          Assert(d_qstate.getRepresentative(it->second) == it->second);
-          return it->second;
-        }
-        return getEntailedTerm2(it->second, subs, subsRep, hasSubs);
+        Assert(d_qstate.hasTerm(it->second));
+        Assert(d_qstate.getRepresentative(it->second) == it->second);
+        return it->second;
       }
+      return getEntailedTerm2(it->second, subs, subsRep);
     }
   }
   else if (n.getKind() == ITE)
   {
     for (uint32_t i = 0; i < 2; i++)
     {
-      if (isEntailed2(n[0], subs, subsRep, hasSubs, i == 0))
+      if (isEntailed2(n[0], subs, subsRep, i == 0))
       {
-        return getEntailedTerm2(n[i == 0 ? 1 : 2], subs, subsRep, hasSubs);
+        return getEntailedTerm2(n[i == 0 ? 1 : 2], subs, subsRep);
       }
     }
   }
@@ -274,7 +238,7 @@ TNode EntailmentCheck::getEntailedTerm2(TNode n,
         std::vector<TNode> args;
         for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
         {
-          TNode c = getEntailedTerm2(n[i], subs, subsRep, hasSubs);
+          TNode c = getEntailedTerm2(n[i], subs, subsRep);
           if (c.isNull())
           {
             return TNode::null();
@@ -294,48 +258,52 @@ TNode EntailmentCheck::getEntailedTerm2(TNode n,
 }
 
 Node EntailmentCheck::evaluateTerm(TNode n,
+                                   std::map<TNode, TNode>& subs,
+                                   bool subsRep,
                                    bool useEntailmentTests,
                                    bool reqHasTerm)
 {
   std::map<TNode, Node> visited;
-  std::vector<Node> exp;
-  return evaluateTerm2(n, visited, exp, useEntailmentTests, false, reqHasTerm);
+  return evaluateTerm2(
+      n, visited, subs, subsRep, useEntailmentTests, reqHasTerm);
 }
 
 Node EntailmentCheck::evaluateTerm(TNode n,
-                                   std::vector<Node>& exp,
                                    bool useEntailmentTests,
                                    bool reqHasTerm)
 {
   std::map<TNode, Node> visited;
-  return evaluateTerm2(n, visited, exp, useEntailmentTests, true, reqHasTerm);
+  std::map<TNode, TNode> subs;
+  return evaluateTerm2(n, visited, subs, false, useEntailmentTests, reqHasTerm);
 }
 
 TNode EntailmentCheck::getEntailedTerm(TNode n,
                                        std::map<TNode, TNode>& subs,
                                        bool subsRep)
 {
-  return getEntailedTerm2(n, subs, subsRep, true);
+  return getEntailedTerm2(n, subs, subsRep);
 }
 
 TNode EntailmentCheck::getEntailedTerm(TNode n)
 {
   std::map<TNode, TNode> subs;
-  return getEntailedTerm2(n, subs, false, false);
+  return getEntailedTerm2(n, subs, false);
 }
 
-bool EntailmentCheck::isEntailed2(
-    TNode n, std::map<TNode, TNode>& subs, bool subsRep, bool hasSubs, bool pol)
+bool EntailmentCheck::isEntailed2(TNode n,
+                                  std::map<TNode, TNode>& subs,
+                                  bool subsRep,
+                                  bool pol)
 {
   Trace("term-db-entail") << "Check entailed : " << n << ", pol = " << pol
                           << std::endl;
   Assert(n.getType().isBoolean());
   if (n.getKind() == EQUAL && !n[0].getType().isBoolean())
   {
-    TNode n1 = getEntailedTerm2(n[0], subs, subsRep, hasSubs);
+    TNode n1 = getEntailedTerm2(n[0], subs, subsRep);
     if (!n1.isNull())
     {
-      TNode n2 = getEntailedTerm2(n[1], subs, subsRep, hasSubs);
+      TNode n2 = getEntailedTerm2(n[1], subs, subsRep);
       if (!n2.isNull())
       {
         if (n1 == n2)
@@ -360,14 +328,14 @@ bool EntailmentCheck::isEntailed2(
   }
   else if (n.getKind() == NOT)
   {
-    return isEntailed2(n[0], subs, subsRep, hasSubs, !pol);
+    return isEntailed2(n[0], subs, subsRep, !pol);
   }
   else if (n.getKind() == OR || n.getKind() == AND)
   {
     bool simPol = (pol && n.getKind() == OR) || (!pol && n.getKind() == AND);
     for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
     {
-      if (isEntailed2(n[i], subs, subsRep, hasSubs, pol))
+      if (isEntailed2(n[i], subs, subsRep, pol))
       {
         if (simPol)
         {
@@ -389,17 +357,17 @@ bool EntailmentCheck::isEntailed2(
   {
     for (size_t i = 0; i < 2; i++)
     {
-      if (isEntailed2(n[0], subs, subsRep, hasSubs, i == 0))
+      if (isEntailed2(n[0], subs, subsRep, i == 0))
       {
         size_t ch = (n.getKind() == EQUAL || i == 0) ? 1 : 2;
         bool reqPol = (n.getKind() == ITE || i == 0) ? pol : !pol;
-        return isEntailed2(n[ch], subs, subsRep, hasSubs, reqPol);
+        return isEntailed2(n[ch], subs, subsRep, reqPol);
       }
     }
   }
   else if (n.getKind() == APPLY_UF)
   {
-    TNode n1 = getEntailedTerm2(n, subs, subsRep, hasSubs);
+    TNode n1 = getEntailedTerm2(n, subs, subsRep);
     if (!n1.isNull())
     {
       Assert(d_qstate.hasTerm(n1));
@@ -419,7 +387,7 @@ bool EntailmentCheck::isEntailed2(
   }
   else if (n.getKind() == FORALL && !pol)
   {
-    return isEntailed2(n[1], subs, subsRep, hasSubs, pol);
+    return isEntailed2(n[1], subs, subsRep, pol);
   }
   return false;
 }
@@ -427,7 +395,7 @@ bool EntailmentCheck::isEntailed2(
 bool EntailmentCheck::isEntailed(TNode n, bool pol)
 {
   std::map<TNode, TNode> subs;
-  return isEntailed2(n, subs, false, false, pol);
+  return isEntailed2(n, subs, false, pol);
 }
 
 bool EntailmentCheck::isEntailed(TNode n,
@@ -435,7 +403,7 @@ bool EntailmentCheck::isEntailed(TNode n,
                                  bool subsRep,
                                  bool pol)
 {
-  return isEntailed2(n, subs, subsRep, true, pol);
+  return isEntailed2(n, subs, subsRep, pol);
 }
 
 }  // namespace quantifiers
index 5f0af60a6387107564a0e96f8c0df52a721bb15c..ecf74fe85a9ccf26e95d3607304d74ce869faf89 100644 (file)
@@ -44,19 +44,18 @@ class EntailmentCheck : protected EnvObj
   ~EntailmentCheck();
   /** evaluate term
    *
-   * Returns a term n' such that n = n' is entailed based on the equality
-   * information ee.  This function may generate new terms. In particular,
-   * we typically rewrite subterms of n of maximal size to terms that exist in
-   * the equality engine specified by ee.
+   * Returns a term n' such that n * subs = n' is entailed based on the current
+   * set of equalities, where ( n * subs ) is term n under the substitution
+   * subs.
+   *
+   * This function may generate new terms. In particular, we typically rewrite
+   * subterms of n of maximal size (in terms of the AST) to terms that exist
+   * in the equality engine.
    *
    * useEntailmentTests is whether to call the theory engine's entailmentTest
    * on literals n for which this call fails to find a term n' that is
    * equivalent to n, for increased precision. This is not frequently used.
    *
-   * The vector exp stores the explanation for why n evaluates to that term,
-   * that is, if this call returns a non-null node n', then:
-   *   exp => n = n'
-   *
    * If reqHasTerm, then we require that the returned term is a Boolean
    * combination of terms that exist in the equality engine used by this call.
    * If no such term is constructable, this call returns null. The motivation
@@ -64,12 +63,23 @@ class EntailmentCheck : protected EnvObj
    * of this function to only involve existing terms. This is used e.g. in
    * the "propagating instances" portion of conflict-based instantiation
    * (quant_conflict_find.h).
+   *
+   * @param n The term under consideration
+   * @param subs The substitution under consideration
+   * @param subsRep Whether the range of subs are representatives in the current
+   * equality engine
+   * @param useEntailmentTests Whether to use entailment tests to show
+   * n * subs is equivalent to true/false.
+   * @param reqHasTerm Whether we require the returned term to be a Boolean
+   * combination of terms known to the current equality engine
+   * @return the term n * subs evaluates to
    */
   Node evaluateTerm(TNode n,
-                    std::vector<Node>& exp,
+                    std::map<TNode, TNode>& subs,
+                    bool subsRep,
                     bool useEntailmentTests = false,
                     bool reqHasTerm = false);
-  /** same as above, without exp */
+  /** Same as above, without a substitution */
   Node evaluateTerm(TNode n,
                     bool useEntailmentTests = false,
                     bool reqHasTerm = false);
@@ -119,20 +129,16 @@ class EntailmentCheck : protected EnvObj
   /** helper for evaluate term */
   Node evaluateTerm2(TNode n,
                      std::map<TNode, Node>& visited,
-                     std::vector<Node>& exp,
+                     std::map<TNode, TNode>& subs,
+                     bool subsRep,
                      bool useEntailmentTests,
-                     bool computeExp,
                      bool reqHasTerm);
   /** helper for get entailed term */
-  TNode getEntailedTerm2(TNode n,
-                         std::map<TNode, TNode>& subs,
-                         bool subsRep,
-                         bool hasSubs);
+  TNode getEntailedTerm2(TNode n, std::map<TNode, TNode>& subs, bool subsRep);
   /** helper for is entailed */
   bool isEntailed2(TNode n,
                    std::map<TNode, TNode>& subs,
                    bool subsRep,
-                   bool hasSubs,
                    bool pol);
   /** The quantifiers state object */
   QuantifiersState& d_qstate;
index dc1043d28715ea0d369ab838ad15cfa89998a75f..3233988796ca0ea1c2ed433aa343a26db011ea43 100644 (file)
@@ -52,12 +52,6 @@ QuantInfo::~QuantInfo() {
   d_var_mg.clear();
 }
 
-QuantifiersInferenceManager& QuantInfo::getInferenceManager()
-{
-  Assert(d_parent != nullptr);
-  return d_parent->getInferenceManager();
-}
-
 void QuantInfo::initialize( QuantConflictFind * p, Node q, Node qn ) {
   d_parent = p;
   d_q = q;
@@ -578,29 +572,32 @@ bool QuantInfo::isTConstraintSpurious(QuantConflictFind* p,
   if( options::qcfEagerTest() ){
     //check whether the instantiation evaluates as expected
     EntailmentCheck* echeck = p->getTermRegistry().getEntailmentCheck();
+    std::map<TNode, TNode> subs;
+    for (size_t i = 0, tsize = terms.size(); i < tsize; i++)
+    {
+      Trace("qcf-instance-check") << "  " << terms[i] << std::endl;
+      subs[d_q[0][i]] = terms[i];
+    }
+    for (size_t i = 0, evsize = d_extra_var.size(); i < evsize; i++)
+    {
+      Node n = getCurrentExpValue(d_extra_var[i]);
+      Trace("qcf-instance-check")
+          << "  " << d_extra_var[i] << " -> " << n << std::endl;
+      subs[d_extra_var[i]] = n;
+    }
     if (p->atConflictEffort()) {
       Trace("qcf-instance-check") << "Possible conflict instance for " << d_q << " : " << std::endl;
-      std::map< TNode, TNode > subs;
-      for( unsigned i=0; i<terms.size(); i++ ){
-        Trace("qcf-instance-check") << "  " << terms[i] << std::endl;
-        subs[d_q[0][i]] = terms[i];
-      }
-      for( unsigned i=0; i<d_extra_var.size(); i++ ){
-        Node n = getCurrentExpValue( d_extra_var[i] );
-        Trace("qcf-instance-check") << "  " << d_extra_var[i] << " -> " << n << std::endl;
-        subs[d_extra_var[i]] = n;
-      }
       if (!echeck->isEntailed(d_q[1], subs, false, false))
       {
         Trace("qcf-instance-check") << "...not entailed to be false." << std::endl;
         return true;
       }
     }else{
-      Node inst =
-          getInferenceManager().getInstantiate()->getInstantiation(d_q, terms);
-      inst = Rewriter::rewrite(inst);
-      Node inst_eval =
-          echeck->evaluateTerm(inst, options::qcfTConstraint(), true);
+      // see if the body of the quantified formula evaluates to a Boolean
+      // combination of known terms under the current substitution. We use
+      // the helper method evaluateTerm from the entailment check utility.
+      Node inst_eval = echeck->evaluateTerm(
+          d_q[1], subs, false, options::qcfTConstraint(), true);
       if( Trace.isOn("qcf-instance-check") ){
         Trace("qcf-instance-check") << "Possible propagating instance for " << d_q << " : " << std::endl;
         for( unsigned i=0; i<terms.size(); i++ ){
@@ -608,6 +605,10 @@ bool QuantInfo::isTConstraintSpurious(QuantConflictFind* p,
         }
         Trace("qcf-instance-check") << "...evaluates to " << inst_eval << std::endl;
       }
+      // If it is the case that instantiation can be rewritten to a Boolean
+      // combination of terms that exist in the current context, then inst_eval
+      // is non-null. Moreover, we insist that inst_eval is not true, or else
+      // the instantiation is trivially entailed and hence is spurious.
       if (inst_eval.isNull()
           || (inst_eval.isConst() && inst_eval.getConst<bool>()))
       {
index 927a74ff25be4ca2acebd70104622054a6d6b06c..d14e281fbdb9c9c8f4e6877930e2c54a50ed18e0 100644 (file)
@@ -132,8 +132,6 @@ public:
 public:
   QuantInfo();
   ~QuantInfo();
-  /** Get quantifiers inference manager */
-  QuantifiersInferenceManager& getInferenceManager();
   std::vector< TNode > d_vars;
   std::vector< TypeNode > d_var_types;
   std::map< TNode, int > d_var_num;