Make IndexTrie take nodes (#8649)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 26 Apr 2022 22:37:37 +0000 (17:37 -0500)
committerGitHub <noreply@github.com>
Tue, 26 Apr 2022 22:37:37 +0000 (22:37 +0000)
This makes the class easier to use and allows for a usage where the null node is interpreted as specifying all nodes.

This is in preparation for using this class for testing whether an instantiation from any instantiation strategy is currently feasible based on learning in the style of fail masks from Janota et al FMCAD 2021.

Also, this class should be renamed to something more appropriate, since it no longer takes indices.

FYI @MikolasJanota

src/theory/quantifiers/index_trie.cpp
src/theory/quantifiers/index_trie.h
src/theory/quantifiers/inst_strategy_enumerative.cpp
src/theory/quantifiers/inst_strategy_pool.cpp
src/theory/quantifiers/term_tuple_enumerator.cpp
src/theory/quantifiers/term_tuple_enumerator.h

index d9046f79345497e67fbd91b6892ed5aebfb313f9..6c28396d6b594432a0912f70e05438d127ee61e7 100644 (file)
@@ -21,7 +21,7 @@ namespace theory {
 namespace quantifiers {
 
 void IndexTrie::add(const std::vector<bool>& mask,
-                    const std::vector<size_t>& values)
+                    const std::vector<Node>& values)
 {
   const size_t cardinality = std::count(mask.begin(), mask.end(), true);
   if (d_ignoreFullySpecified && cardinality == mask.size())
@@ -48,7 +48,7 @@ void IndexTrie::freeRec(IndexTrieNode* n)
 
 bool IndexTrie::findRec(const IndexTrieNode* n,
                         size_t index,
-                        const std::vector<size_t>& members,
+                        const std::vector<Node>& members,
                         size_t& nonBlankLength) const
 {
   if (!n || index >= members.size())
@@ -59,6 +59,11 @@ bool IndexTrie::findRec(const IndexTrieNode* n,
   {
     return true;  // found in the blank branch
   }
+  if (members[index].isNull())
+  {
+    // null is interpreted as "any", must have found in the blank branch
+    return false;
+  }
   nonBlankLength = index + 1;
   for (const auto& c : n->d_children)
   {
@@ -75,7 +80,7 @@ IndexTrieNode* IndexTrie::addRec(IndexTrieNode* n,
                                  size_t index,
                                  size_t cardinality,
                                  const std::vector<bool>& mask,
-                                 const std::vector<size_t>& values)
+                                 const std::vector<Node>& values)
 {
   if (!n)
   {
@@ -96,7 +101,7 @@ IndexTrieNode* IndexTrie::addRec(IndexTrieNode* n,
     return n;
   }
   Assert(cardinality);
-
+  Assert(!values[index].isNull());
   for (auto& edge : n->d_children)
   {
     if (edge.first == values[index])
index fa58a3e4837cd56424cbc09c7f25732b25393d65..c6d38f533bd56471265cb80a9286569f62d695b4 100644 (file)
@@ -22,6 +22,7 @@
 #include <vector>
 
 #include "base/check.h"
+#include "expr/node.h"
 
 namespace cvc5::internal {
 namespace theory {
@@ -30,13 +31,13 @@ namespace quantifiers {
 /** A single node of the IndexTrie. */
 struct IndexTrieNode
 {
-  std::vector<std::pair<size_t, IndexTrieNode*>> d_children;
+  std::vector<std::pair<Node, IndexTrieNode*>> d_children;
   IndexTrieNode* d_blank = nullptr;
 };
 
-/** Trie of  sequences indices, used to check for subsequence membership.
+/** Trie of Nodes, used to check for subsequence membership.
  *
- * The  data structure stores tuples of indices where some elements may be
+ * The data structure stores tuples of indices where some elements may be
  * left blank. The objective is to enable checking whether a given, completely
  * filled in, tuple has a  sub-tuple  present in the data structure.  This is
  * used in the term tuple enumeration (term_tuple_enumerator.cpp) to store
@@ -48,12 +49,19 @@ struct IndexTrieNode
  * tuple that contains 1 and 3 on second and forth position, respectively, would
  * match.
  *
- *  The data structure behaves essentially as a traditional trie. Each tuple
+ * The data structure behaves essentially as a traditional trie. Each tuple
  * is treated as a sequence of integers with a special symbol for blank, which
  * is in fact stored  in a special  child (member d_blank).  As a small
  * optimization, a suffix containing only blanks is represented by  the empty
  * subtree, i.e., a null pointer.
  *
+ * Additionally, this class accepts membership queries involving null nodes,
+ * which are interpreted as requiring that all possible values of the node at
+ * that position are contained. For example, writing `_` for null:
+ * (_, 1, 2, 3) is contained in (_, 1, _, 3)
+ * (1, 1, _, 3) is contained in (_, 1, _, 3)
+ * (_, 2, _, _) is not contained in (_, 1, _, 3)
+ * (_, 1, 2, 3) is not contained in (0, 1, _, 3)
  */
 class IndexTrie
 {
@@ -61,7 +69,7 @@ class IndexTrie
   /*  Construct the trie,  if the argument ignoreFullySpecified is true,
    *  the data structure will  store only data structure containing at least
    *  one blank. */
-  IndexTrie(bool ignoreFullySpecified)
+  IndexTrie(bool ignoreFullySpecified = true)
       : d_ignoreFullySpecified(ignoreFullySpecified),
         d_root(new IndexTrieNode())
   {
@@ -71,11 +79,11 @@ class IndexTrie
 
   /**  Add a tuple of values into the trie  masked by a bitmask, i.e.\ position
    * i is considered blank iff mask[i] is false. */
-  void add(const std::vector<bool>& mask, const std::vector<size_t>& values);
+  void add(const std::vector<bool>& mask, const std::vector<Node>& values);
 
   /** Check if the given set of indices is subsumed by something present in the
    * trie. If it is subsumed, give the maximum non-blank index. */
-  bool find(const std::vector<size_t>& members,
+  bool find(const std::vector<Node>& members,
             /*out*/ size_t& nonBlankLength) const
   {
     nonBlankLength = 0;
@@ -94,7 +102,7 @@ class IndexTrie
   /** Auxiliary recursive function for finding  subsuming tuple. */
   bool findRec(const IndexTrieNode* n,
                size_t index,
-               const std::vector<size_t>& members,
+               const std::vector<Node>& members,
                size_t& nonBlankLength) const;
 
   /** Add master values  starting from index  to a given subtree. The
@@ -103,7 +111,7 @@ class IndexTrie
                         size_t index,
                         size_t cardinality,
                         const std::vector<bool>& mask,
-                        const std::vector<size_t>& values);
+                        const std::vector<Node>& values);
 };
 
 }  // namespace quantifiers
index ab52d71bf9adc10c0f65be706d4ab0ca2ee35464..7e9493ccdd27186b47c54a178c06277749bf60b7 100644 (file)
@@ -182,18 +182,18 @@ bool InstStrategyEnum::process(Node quantifier, bool fullEffort, bool isRd)
     return false;
   }
 
+  Instantiate* ie = d_qim.getInstantiate();
   TermTupleEnumeratorEnv ttec;
   ttec.d_fullEffort = fullEffort;
   ttec.d_increaseSum = options().quantifiers.enumInstSum;
+  ttec.d_tr = &d_treg;
   // make the enumerator, which is either relevant domain or term database
   // based on the flag isRd.
   std::unique_ptr<TermTupleEnumeratorInterface> enumerator(
       isRd ? mkTermTupleEnumeratorRd(quantifier, &ttec, d_rd)
-           : mkTermTupleEnumerator(
-                 quantifier, &ttec, d_qstate, d_treg.getTermDatabase()));
+           : mkTermTupleEnumerator(quantifier, &ttec, d_qstate));
   std::vector<Node> terms;
   std::vector<bool> failMask;
-  Instantiate* ie = d_qim.getInstantiate();
   for (enumerator->init(); enumerator->hasNext();)
   {
     if (d_qstate.isInConflict())
index fa48a6a89589700f592b82776123033f8bb4ae9e..2cf081852609c80b15057b83fd8ff554bcd396e9 100644 (file)
@@ -127,13 +127,13 @@ std::string InstStrategyPool::identify() const
 
 bool InstStrategyPool::process(Node q, Node p, uint64_t& addedLemmas)
 {
+  Instantiate* ie = d_qim.getInstantiate();
   TermTupleEnumeratorEnv ttec;
   ttec.d_fullEffort = true;
   ttec.d_increaseSum = options().quantifiers.enumInstSum;
-  TermPools* tp = d_treg.getTermPools();
+  ttec.d_tr = &d_treg;
   std::shared_ptr<TermTupleEnumeratorInterface> enumerator(
-      mkTermTupleEnumeratorPool(q, &ttec, tp, p));
-  Instantiate* ie = d_qim.getInstantiate();
+      mkTermTupleEnumeratorPool(q, &ttec, p));
   std::vector<Node> terms;
   std::vector<bool> failMask;
   // we instantiate exhaustively
index 6e189e902105f59bf55479901fa321a23c48675f..cf5235f9cd199b9611fa0620c481ac3f81354bd8 100644 (file)
@@ -26,6 +26,7 @@
 #include "options/quantifiers_options.h"
 #include "smt/smt_statistics_registry.h"
 #include "theory/quantifiers/index_trie.h"
+#include "theory/quantifiers/instantiate.h"
 #include "theory/quantifiers/quant_module.h"
 #include "theory/quantifiers/relevant_domain.h"
 #include "theory/quantifiers/term_pools.h"
@@ -166,9 +167,10 @@ class TermTupleEnumeratorBasic : public TermTupleEnumeratorBase
  public:
   TermTupleEnumeratorBasic(Node quantifier,
                            const TermTupleEnumeratorEnv* env,
-                           QuantifiersState& qs,
-                           TermDb* td)
-      : TermTupleEnumeratorBase(quantifier, env), d_qs(qs), d_tdb(td)
+                           QuantifiersState& qs)
+      : TermTupleEnumeratorBase(quantifier, env),
+        d_qs(qs),
+        d_tdb(env->d_tr->getTermDatabase())
   {
   }
 
@@ -273,7 +275,9 @@ void TermTupleEnumeratorBase::failureReason(const std::vector<bool>& mask)
   {
     traceMaskedVector("inst-alg", "failureReason", mask, d_termIndex);
   }
-  d_disabledCombinations.add(mask, d_termIndex);  // record failure
+  std::vector<Node> tti;
+  next(tti);
+  d_disabledCombinations.add(mask, tti);  // record failure
   // update change prefix accordingly
   for (d_changePrefix = mask.size();
        d_changePrefix && !mask[d_changePrefix - 1];
@@ -287,13 +291,14 @@ void TermTupleEnumeratorBase::next(/*out*/ std::vector<Node>& terms)
   terms.resize(d_variableCount);
   for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
   {
-    const Node t = d_termsSizes[variableIx] == 0
-                       ? Node::null()
-                       : getTerm(variableIx, d_termIndex[variableIx]);
+    const Node t =
+        d_termsSizes[variableIx] == 0
+            ? d_env->d_tr->getTermForType(d_quantifier[0][variableIx].getType())
+            : getTerm(variableIx, d_termIndex[variableIx]);
     terms[variableIx] = t;
     Trace("inst-alg-rd") << t << "  ";
-    Assert(t.isNull()
-           || t.getType().isComparableTo(d_quantifier[0][variableIx].getType()))
+    Assert(!t.isNull());
+    Assert(t.getType().isComparableTo(d_quantifier[0][variableIx].getType()))
         << "Bad type: " << t << " " << t.getType() << " "
         << d_quantifier[0][variableIx].getType();
   }
@@ -356,7 +361,9 @@ bool TermTupleEnumeratorBase::nextCombination()
     {
       return false;  // ran out of combinations
     }
-    if (!d_disabledCombinations.find(d_termIndex, d_changePrefix))
+    std::vector<Node> tti;
+    next(tti);
+    if (!d_disabledCombinations.find(tti, d_changePrefix))
     {
       return true;  // current combination vetted by disabled combinations
     }
@@ -501,9 +508,10 @@ class TermTupleEnumeratorPool : public TermTupleEnumeratorBase
  public:
   TermTupleEnumeratorPool(Node quantifier,
                           const TermTupleEnumeratorEnv* env,
-                          TermPools* tp,
                           Node pool)
-      : TermTupleEnumeratorBase(quantifier, env), d_tp(tp), d_pool(pool)
+      : TermTupleEnumeratorBase(quantifier, env),
+        d_tp(env->d_tr->getTermPools()),
+        d_pool(pool)
   {
     Assert(d_pool.getKind() == kind::INST_POOL);
   }
@@ -536,10 +544,10 @@ class TermTupleEnumeratorPool : public TermTupleEnumeratorBase
 };
 
 TermTupleEnumeratorInterface* mkTermTupleEnumerator(
-    Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs, TermDb* td)
+    Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs)
 {
   return static_cast<TermTupleEnumeratorInterface*>(
-      new TermTupleEnumeratorBasic(q, env, qs, td));
+      new TermTupleEnumeratorBasic(q, env, qs));
 }
 TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd(
     Node q, const TermTupleEnumeratorEnv* env, RelevantDomain* rd)
@@ -549,10 +557,10 @@ TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd(
 }
 
 TermTupleEnumeratorInterface* mkTermTupleEnumeratorPool(
-    Node q, const TermTupleEnumeratorEnv* env, TermPools* tp, Node pool)
+    Node q, const TermTupleEnumeratorEnv* env, Node pool)
 {
   return static_cast<TermTupleEnumeratorInterface*>(
-      new TermTupleEnumeratorPool(q, env, tp, pool));
+      new TermTupleEnumeratorPool(q, env, pool));
 }
 
 }  // namespace quantifiers
index 05cd1da5c1b993b48f1703c50c17d9b10b719186..2b3edfdecad0f95f74c9ba05e5865310ecea79d6 100644 (file)
@@ -24,9 +24,10 @@ namespace cvc5::internal {
 namespace theory {
 namespace quantifiers {
 
+class Instantiate;
 class TermPools;
 class QuantifiersState;
-class TermDb;
+class TermRegistry;
 class RelevantDomain;
 
 /**  Interface for enumeration of tuples of terms.
@@ -65,6 +66,8 @@ struct TermTupleEnumeratorEnv
   bool d_fullEffort;
   /** Whether we increase tuples based on sum instead of max (see below) */
   bool d_increaseSum;
+  /** Term registry */
+  TermRegistry* d_tr;
 };
 
 /**  A function to construct a tuple enumerator.
@@ -87,17 +90,14 @@ struct TermTupleEnumeratorEnv
  * duplicates modulo equality.
  */
 TermTupleEnumeratorInterface* mkTermTupleEnumerator(
-    Node q,
-    const TermTupleEnumeratorEnv* env,
-    QuantifiersState& qs,
-    TermDb* td);
+    Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs);
 /** Same as above, but draws terms from the relevant domain utility (rd). */
 TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd(
     Node q, const TermTupleEnumeratorEnv* env, RelevantDomain* rd);
 
 /** Make term pool enumerator */
 TermTupleEnumeratorInterface* mkTermTupleEnumeratorPool(
-    Node q, const TermTupleEnumeratorEnv* env, TermPools* tp, Node p);
+    Node q, const TermTupleEnumeratorEnv* env, Node p);
 
 }  // namespace quantifiers
 }  // namespace theory