Refactor InstMatch (#8646)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 26 Apr 2022 21:43:27 +0000 (16:43 -0500)
committerGitHub <noreply@github.com>
Tue, 26 Apr 2022 21:43:27 +0000 (16:43 -0500)
Also simplifies the (old) version of multi trigger matching, which copied InstMatch objects unnecessarily, and also used EqualityEngine iteration instead of iterating on a trie.

This is in preparation for making InstMatch objects optionally track fast (and generalized) failures based on configurable criteria.

src/theory/quantifiers/ematching/inst_match_generator.cpp
src/theory/quantifiers/ematching/inst_match_generator_multi.cpp
src/theory/quantifiers/ematching/relational_match_generator.cpp
src/theory/quantifiers/ematching/trigger.cpp
src/theory/quantifiers/ematching/var_match_generator.cpp
src/theory/quantifiers/inst_match.cpp
src/theory/quantifiers/inst_match.h

index b4aaee6ac70cf3eb4a73be786674d23936990c30..e31153bbf3d893753628e94a5e949ffbac945d25 100644 (file)
@@ -435,7 +435,7 @@ int InstMatchGenerator::getMatch(Node f, Node t, InstMatch& m)
   {
     for (int& pv : prev)
     {
-      m.d_vals[pv] = Node::null();
+      m.reset(pv);
     }
   }
   return ret_val;
index de5ae94759e5adca33eb8ed8efdc607a6ada234c..a1102feec2db2bf34127e8329bf22aab97a0b9e9 100644 (file)
@@ -163,23 +163,15 @@ uint64_t InstMatchGeneratorMulti::addInstantiations(Node q)
     std::vector<InstMatch> newMatches;
     InstMatch m(q);
     while (d_children[i]->getNextMatch(q, m) > 0)
-    {
-      // m.makeRepresentative( qe );
-      newMatches.push_back(InstMatch(&m));
-      m.clear();
-    }
-    Trace("multi-trigger-cache") << "Made " << newMatches.size()
-                                 << " new matches for index " << i << std::endl;
-    for (size_t j = 0, nmatches = newMatches.size(); j < nmatches; j++)
     {
       Trace("multi-trigger-cache2")
-          << "...processing " << j << " / " << newMatches.size()
-          << ", #lemmas = " << addedLemmas << std::endl;
-      processNewMatch(newMatches[j], i, addedLemmas);
+          << "...processing new match, #lemmas = " << addedLemmas << std::endl;
+      processNewMatch(m, i, addedLemmas);
       if (d_qstate.isInConflict())
       {
         return addedLemmas;
       }
+      m.clear();
     }
   }
   return addedLemmas;
@@ -190,7 +182,7 @@ void InstMatchGeneratorMulti::processNewMatch(InstMatch& m,
                                               uint64_t& addedLemmas)
 {
   // see if these produce new matches
-  d_children_trie[fromChildIndex].addInstMatch(d_qstate, d_quant, m.d_vals);
+  d_children_trie[fromChildIndex].addInstMatch(d_qstate, d_quant, m.get());
   // possibly only do the following if we know that new matches will be
   // produced? the issue is that instantiations are filtered in quantifiers
   // engine, and so there is no guarentee that
@@ -234,20 +226,25 @@ void InstMatchGeneratorMulti::processNewInstantiations(InstMatch& m,
   {
     size_t curr_index = iio->d_order[trieIndex];
     Node n = m.get(curr_index);
+    QuantifiersState& qs = d_qstate;
     if (n.isNull())
     {
       // add to InstMatch
       for (std::pair<const Node, InstMatchTrie>& d : tr->d_data)
       {
-        InstMatch mn(&m);
-        mn.setValue(curr_index, d.first);
-        processNewInstantiations(mn,
+        // try to set
+        if (!m.set(qs, curr_index, d.first))
+        {
+          continue;
+        }
+        processNewInstantiations(m,
                                  addedLemmas,
                                  &(d.second),
                                  trieIndex + 1,
                                  childIndex,
                                  endChildIndex,
                                  modEq);
+        m.reset(curr_index);
         if (d_qstate.isInConflict())
         {
           break;
@@ -270,35 +267,27 @@ void InstMatchGeneratorMulti::processNewInstantiations(InstMatch& m,
     {
       return;
     }
-    QuantifiersState& qs = d_qstate;
     // check modulo equality for other possible instantiations
     if (!qs.hasTerm(n))
     {
       return;
     }
-    eq::EqClassIterator eqc(qs.getRepresentative(n), qs.getEqualityEngine());
-    while (!eqc.isFinished())
+    for (std::pair<const Node, InstMatchTrie>& d : tr->d_data)
     {
-      Node en = (*eqc);
-      if (en != n)
+      if (d.first != n && qs.areEqual(d.first, n))
       {
-        std::map<Node, InstMatchTrie>::iterator itc = tr->d_data.find(en);
-        if (itc != tr->d_data.end())
+        processNewInstantiations(m,
+                                 addedLemmas,
+                                 &(d.second),
+                                 trieIndex + 1,
+                                 childIndex,
+                                 endChildIndex,
+                                 modEq);
+        if (d_qstate.isInConflict())
         {
-          processNewInstantiations(m,
-                                   addedLemmas,
-                                   &(itc->second),
-                                   trieIndex + 1,
-                                   childIndex,
-                                   endChildIndex,
-                                   modEq);
-          if (d_qstate.isInConflict())
-          {
-            break;
-          }
+          break;
         }
       }
-      ++eqc;
     }
   }
   else
index 4dee86a09fc84308639396dc74ea4d10168e67b6..c3b9a346b7cf8f325c147c6973fe0b5f9efbf370 100644 (file)
@@ -114,7 +114,7 @@ int RelationalMatchGenerator::getNextMatch(Node q, InstMatch& m)
       // failed
       if (rmPrev)
       {
-        m.d_vals[d_vindex] = Node::null();
+        m.reset(d_vindex);
       }
     }
   }
index c22629cc479fdf9b71cc8d325df02e6cec1ce165..c0cd6dc53245220205074855c99d752f7ced1a27 100644 (file)
@@ -178,7 +178,7 @@ bool Trigger::sendInstantiation(std::vector<Node>& m, InferenceId id)
 
 bool Trigger::sendInstantiation(InstMatch& m, InferenceId id)
 {
-  return sendInstantiation(m.d_vals, id);
+  return sendInstantiation(m.get(), id);
 }
 
 int Trigger::getActiveScore() { return d_mg->getActiveScore(); }
index 80838a824e2b01e9c1983faf1687db6947fd9951..03f394ce68d5abe515c85bef725bb4ac1a91b299 100644 (file)
@@ -45,6 +45,7 @@ bool VarMatchGeneratorTermSubs::reset(Node eqc)
 
 int VarMatchGeneratorTermSubs::getNextMatch(Node q, InstMatch& m)
 {
+  size_t index = d_children_types[0];
   int ret_val = -1;
   if (!d_eq_class.isNull())
   {
@@ -57,8 +58,8 @@ int VarMatchGeneratorTermSubs::getNextMatch(Node q, InstMatch& m)
         << "...got " << s << ", " << s.getKind() << std::endl;
     d_eq_class = Node::null();
     // if( s.getType().isSubtypeOf( d_var_type ) ){
-    d_rm_prev = m.get(d_children_types[0]).isNull();
-    if (!m.set(d_qstate, d_children_types[0], s))
+    d_rm_prev = m.get(index).isNull();
+    if (!m.set(d_qstate, index, s))
     {
       return -1;
     }
@@ -74,7 +75,7 @@ int VarMatchGeneratorTermSubs::getNextMatch(Node q, InstMatch& m)
   }
   if (d_rm_prev)
   {
-    m.d_vals[d_children_types[0]] = Node::null();
+    m.reset(index);
     d_rm_prev = false;
   }
   return -1;
index fdba4afb1b69e08fdd4bb7bf82d9a40d05298dcc..5445f955f23b95f5b4cdba9ff183b24779a4d08c 100644 (file)
@@ -21,7 +21,7 @@ namespace cvc5::internal {
 namespace theory {
 namespace quantifiers {
 
-InstMatch::InstMatch(TNode q)
+InstMatch::InstMatch(TNode q) : d_quant(q)
 {
   d_vals.resize(q[0].getNumChildren());
   Assert(!d_vals.empty());
@@ -29,10 +29,6 @@ InstMatch::InstMatch(TNode q)
   Assert(d_vals[0].isNull());
 }
 
-InstMatch::InstMatch( InstMatch* m ) {
-  d_vals.insert( d_vals.end(), m->d_vals.begin(), m->d_vals.end() );
-}
-
 void InstMatch::add(InstMatch& m)
 {
   for (unsigned i = 0, size = d_vals.size(); i < size; i++)
@@ -52,8 +48,28 @@ void InstMatch::debugPrint( const char* c ){
   }
 }
 
-bool InstMatch::isComplete() {
-  for (Node& v : d_vals)
+void InstMatch::toStream(std::ostream& out) const
+{
+  out << "INST_MATCH( ";
+  bool printed = false;
+  for (size_t i = 0, size = d_vals.size(); i < size; i++)
+  {
+    if (!d_vals[i].isNull())
+    {
+      if (printed)
+      {
+        out << ", ";
+      }
+      out << i << " -> " << d_vals[i];
+      printed = true;
+    }
+  }
+  out << " )";
+}
+
+bool InstMatch::isComplete() const
+{
+  for (const Node& v : d_vals)
   {
     if (v.isNull())
     {
@@ -63,8 +79,9 @@ bool InstMatch::isComplete() {
   return true;
 }
 
-bool InstMatch::empty() {
-  for (Node& v : d_vals)
+bool InstMatch::empty() const
+{
+  for (const Node& v : d_vals)
   {
     if (!v.isNull())
     {
@@ -91,16 +108,28 @@ void InstMatch::setValue(size_t i, TNode n)
   Assert(i < d_vals.size());
   d_vals[i] = n;
 }
+
 bool InstMatch::set(QuantifiersState& qs, size_t i, TNode n)
 {
   Assert(i < d_vals.size());
-  if( !d_vals[i].isNull() ){
+  if (!d_vals[i].isNull())
+  {
+    // if they are equal, we do nothing
     return qs.areEqual(d_vals[i], n);
   }
+  // otherwise, we update the value
   d_vals[i] = n;
   return true;
 }
 
+void InstMatch::reset(size_t i)
+{
+  Assert(!d_vals[i].isNull());
+  d_vals[i] = Node::null();
+}
+
+std::vector<Node>& InstMatch::get() { return d_vals; }
+
 }  // namespace quantifiers
 }  // namespace theory
 }  // namespace cvc5::internal
index fc4f27171eaa92dec0cd0546400ed4e699b645a1..d31d638ee98616b549a05823fffd2aace7beddf4 100644 (file)
@@ -38,12 +38,8 @@ class QuantifiersState;
  * yet to be initialized.
  */
 class InstMatch {
-public:
-  InstMatch(){}
-  explicit InstMatch(TNode q);
-  InstMatch( InstMatch* m );
-  /* map from variable to ground terms */
-  std::vector<Node> d_vals;
+ public:
+  InstMatch(TNode q);
   /** add match m
    *
    * This adds the initialized fields of m to this match for each field that is
@@ -51,26 +47,15 @@ public:
    */
   void add(InstMatch& m);
   /** is this complete, i.e. are all fields non-null? */
-  bool isComplete();
+  bool isComplete() const;
   /** is this empty, i.e. are all fields the null node? */
-  bool empty();
+  bool empty() const;
   /** clear the instantiation, i.e. set all fields to the null node */
   void clear();
   /** debug print method */
   void debugPrint(const char* c);
   /** to stream */
-  inline void toStream(std::ostream& out) const {
-    out << "INST_MATCH( ";
-    bool printed = false;
-    for( unsigned i=0; i<d_vals.size(); i++ ){
-      if( !d_vals[i].isNull() ){
-        if( printed ){ out << ", "; }
-        out << i << " -> " << d_vals[i];
-        printed = true;
-      }
-    }
-    out << " )";
-  }
+  void toStream(std::ostream& out) const;
   /** get the i^th term in the instantiation */
   Node get(size_t i) const;
   /** set/overwrites the i^th field in the instantiation with n */
@@ -81,6 +66,19 @@ public:
    * or is equivalent to n modulo the equalities given by q.
    */
   bool set(QuantifiersState& qs, size_t i, TNode n);
+  /** Resets index i */
+  void reset(size_t i);
+  /** Get the values */
+  std::vector<Node>& get();
+
+ private:
+  /**
+   * Ground terms for each variable of the quantified formula, in order.
+   * Null nodes indicate the variable has not been set.
+   */
+  std::vector<Node> d_vals;
+  /** The quantified formula */
+  Node d_quant;
 };
 
 inline std::ostream& operator<<(std::ostream& out, const InstMatch& m) {