Improvements and refactoring for enumeratative strategy (#6030)
authorMikolasJanota <MikolasJanota@users.noreply.github.com>
Thu, 11 Mar 2021 16:28:51 +0000 (17:28 +0100)
committerGitHub <noreply@github.com>
Thu, 11 Mar 2021 16:28:51 +0000 (10:28 -0600)
Refactoring out the code from  `inst_strategy_enumerative` into a separate
class. Some additional tricks to avoid duplicate instantiations, most
notably, before instantiation, a tuple is checked if it's not a
super-tuple of some tuple that had earlier resulted in a useless
instantiation.

Signed-off-by: mikolas <mikolas.janota@gmail.com>
src/CMakeLists.txt
src/options/quantifiers_options.toml
src/theory/quantifiers/index_trie.cpp [new file with mode: 0644]
src/theory/quantifiers/index_trie.h [new file with mode: 0644]
src/theory/quantifiers/inst_strategy_enumerative.cpp
src/theory/quantifiers/term_tuple_enumerator.cpp [new file with mode: 0644]
src/theory/quantifiers/term_tuple_enumerator.h [new file with mode: 0644]

index a086d4224c8d8f661fc9892e4c7208def026f2ae..ad06eb568253ed6795391bc72df8060655f4d2c1 100644 (file)
@@ -714,6 +714,8 @@ libcvc4_add_sources(
   theory/quantifiers/fmf/model_engine.h
   theory/quantifiers/fun_def_evaluator.cpp
   theory/quantifiers/fun_def_evaluator.h
+  theory/quantifiers/index_trie.cpp
+  theory/quantifiers/index_trie.h
   theory/quantifiers/inst_match.cpp
   theory/quantifiers/inst_match.h
   theory/quantifiers/inst_match_trie.cpp
@@ -762,6 +764,8 @@ libcvc4_add_sources(
   theory/quantifiers/skolemize.h
   theory/quantifiers/solution_filter.cpp
   theory/quantifiers/solution_filter.h
+  theory/quantifiers/term_tuple_enumerator.cpp
+  theory/quantifiers/term_tuple_enumerator.h
   theory/quantifiers/sygus/ce_guided_single_inv.cpp
   theory/quantifiers/sygus/ce_guided_single_inv.h
   theory/quantifiers/sygus/ce_guided_single_inv_sol.cpp
index d03e9715a98c5e83dc28035691abe513ef2ae5e9..db7100e9c62ca521a3b25db11095652f5807e9be 100644 (file)
@@ -555,6 +555,15 @@ header = "options/quantifiers_options.h"
   read_only  = true
   help       = "stratify effort levels in enumerative instantiation, which favors speed over fairness"
 
+[[option]]
+  name       = "fullSaturateSum"
+  category   = "regular"
+  long       = "fs-sum"
+  type       = "bool"
+  default    = "false"
+  read_only  = true
+  help       = "enumerating tuples of quantifiers by increasing the sum of indices, rather than the maximum"
+
 [[option]]
   name       = "literalMatchMode"
   category   = "regular"
diff --git a/src/theory/quantifiers/index_trie.cpp b/src/theory/quantifiers/index_trie.cpp
new file mode 100644 (file)
index 0000000..728d51f
--- /dev/null
@@ -0,0 +1,117 @@
+/*********************                                                        */
+/*! \file index_trie.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mikolas Janota
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Implementation of a trie that store subsets of tuples of term indices
+ ** that are not yielding  useful instantiations. of quantifier instantiation.
+ ** This is used in the term_tuple_enumerator.
+ **/
+#include "theory/quantifiers/index_trie.h"
+
+namespace CVC4 {
+namespace theory {
+namespace quantifiers {
+
+void IndexTrie::add(const std::vector<bool>& mask,
+                    const std::vector<size_t>& values)
+{
+  const size_t cardinality = std::count(mask.begin(), mask.end(), true);
+  if (d_ignoreFullySpecified && cardinality == mask.size())
+  {
+    return;
+  }
+
+  d_root = addRec(d_root, 0, cardinality, mask, values);
+}
+
+void IndexTrie::freeRec(IndexTrieNode* n)
+{
+  if (!n)
+  {
+    return;
+  }
+  for (auto c : n->d_children)
+  {
+    freeRec(c.second);
+  }
+  freeRec(n->d_blank);
+  delete n;
+}
+
+bool IndexTrie::findRec(const IndexTrieNode* n,
+                        size_t index,
+                        const std::vector<size_t>& members,
+                        size_t& nonBlankLength) const
+{
+  if (!n || index >= members.size())
+  {
+    return true;  // all elements of members matched
+  }
+  if (n->d_blank && findRec(n->d_blank, index + 1, members, nonBlankLength))
+  {
+    return true;  // found in the blank branch
+  }
+  nonBlankLength = index + 1;
+  for (const auto& c : n->d_children)
+  {
+    if (c.first == members[index]
+        && findRec(c.second, index + 1, members, nonBlankLength))
+    {
+      return true;  // found in the matching subtree
+    }
+  }
+  return false;
+}
+
+IndexTrieNode* IndexTrie::addRec(IndexTrieNode* n,
+                                 size_t index,
+                                 size_t cardinality,
+                                 const std::vector<bool>& mask,
+                                 const std::vector<size_t>& values)
+{
+  if (!n)
+  {
+    return nullptr;  // this tree matches everything, no point to add
+  }
+  if (cardinality == 0)  // all blanks, all strings match
+  {
+    freeRec(n);
+    return nullptr;
+  }
+
+  Assert(index < mask.size());
+
+  if (!mask[index])  // blank position in the added vector
+  {
+    auto blank = n->d_blank ? n->d_blank : new IndexTrieNode();
+    n->d_blank = addRec(blank, index + 1, cardinality, mask, values);
+    return n;
+  }
+  Assert(cardinality);
+
+  for (auto& edge : n->d_children)
+  {
+    if (edge.first == values[index])
+    {
+      // value already amongst the children
+      edge.second =
+          addRec(edge.second, index + 1, cardinality - 1, mask, values);
+      return n;
+    }
+  }
+  // new child needs to be added
+  auto child =
+      addRec(new IndexTrieNode(), index + 1, cardinality - 1, mask, values);
+  n->d_children.push_back(std::make_pair(values[index], child));
+  return n;
+}
+}  // namespace quantifiers
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/quantifiers/index_trie.h b/src/theory/quantifiers/index_trie.h
new file mode 100644 (file)
index 0000000..b770951
--- /dev/null
@@ -0,0 +1,110 @@
+/*********************                                                        */
+/*! \file index_trie.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mikolas Janota
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Implementation of a trie that store subsets of tuples of term indices
+ ** that are not yielding  useful instantiations. of quantifier instantiation.
+ ** This is used in the term_tuple_enumerator.
+ **/
+#ifndef CVC4__THEORY__QUANTIFIERS__INDEX_TRIE_H
+#define CVC4__THEORY__QUANTIFIERS__INDEX_TRIE_H
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "base/check.h"
+
+namespace CVC4 {
+namespace theory {
+namespace quantifiers {
+
+/** A single node of the IndexTrie. */
+struct IndexTrieNode
+{
+  std::vector<std::pair<size_t, IndexTrieNode*>> d_children;
+  IndexTrieNode* d_blank = nullptr;
+};
+
+/** Trie of  sequences indices, used to check for subsequence membership.
+ *
+ * The  data structure stores tuples of indices where some elements may be
+ * left blank. The objective is to enable checking whether a given, completely
+ * filled in, tuple has a  sub-tuple  present in the data structure.  This is
+ * used in the term tuple enumeration (term_tuple_enumerator.cpp) to store
+ * combinations of terms that had yielded a useless instantiation  and therefore
+ * should not be repeated.  Hence, we are always assuming that all tuples have
+ * the same number of elements.
+ *
+ * So for instance, if the data structure contains (_, 1, _, 3),  any  given
+ * tuple that contains 1 and 3 on second and forth position, respectively, would
+ * match.
+ *
+ *  The data structure behaves essentially as a traditional trie. Each tuple
+ * is treated as a sequence of integers with a special symbol for blank, which
+ * is in fact stored  in a special  child (member d_blank).  As a small
+ * optimization, a suffix containing only blanks is represented by  the empty
+ * subtree, i.e., a null pointer.
+ *
+ */
+class IndexTrie
+{
+ public:
+  /*  Construct the trie,  if the argument ignoreFullySpecified is true,
+   *  the data structure will  store only data structure containing at least
+   *  one blank. */
+  IndexTrie(bool ignoreFullySpecified)
+      : d_ignoreFullySpecified(ignoreFullySpecified),
+        d_root(new IndexTrieNode())
+  {
+  }
+
+  virtual ~IndexTrie() { freeRec(d_root); }
+
+  /**  Add a tuple of values into the trie  masked by a bitmask, i.e.\ position
+   * i is considered blank iff mask[i] is false. */
+  void add(const std::vector<bool>& mask, const std::vector<size_t>& values);
+
+  /** Check if the given set of indices is subsumed by something present in the
+   * trie. If it is subsumed, give the maximum non-blank index. */
+  bool find(const std::vector<size_t>& members,
+            /*out*/ size_t& nonBlankLength) const
+  {
+    nonBlankLength = 0;
+    return findRec(d_root, 0, members, nonBlankLength);
+  }
+
+ private:
+  /**  ignore tuples with no blanks in the add method */
+  const bool d_ignoreFullySpecified;
+  /**  the root of the trie, becomes null, if all tuples should match */
+  IndexTrieNode* d_root;
+
+  /** Auxiliary recursive function for cleanup. */
+  void freeRec(IndexTrieNode* n);
+
+  /** Auxiliary recursive function for finding  subsuming tuple. */
+  bool findRec(const IndexTrieNode* n,
+               size_t index,
+               const std::vector<size_t>& members,
+               size_t& nonBlankLength) const;
+
+  /** Add master values  starting from index  to a given subtree. The
+   * cardinality represents the number of non-blank elements left. */
+  IndexTrieNode* addRec(IndexTrieNode* n,
+                        size_t index,
+                        size_t cardinality,
+                        const std::vector<bool>& mask,
+                        const std::vector<size_t>& values);
+};
+
+}  // namespace quantifiers
+}  // namespace theory
+}  // namespace CVC4
+#endif /* THEORY__QUANTIFIERS__INDEX_TRIE_H */
index 0f2fc0ba026c61e282969275317104587e7805cb..0595484fabae45fcce9daf7962c566a735c91a8f 100644 (file)
@@ -18,6 +18,7 @@
 #include "theory/quantifiers/instantiate.h"
 #include "theory/quantifiers/relevant_domain.h"
 #include "theory/quantifiers/term_database.h"
+#include "theory/quantifiers/term_tuple_enumerator.h"
 #include "theory/quantifiers/term_util.h"
 #include "theory/quantifiers_engine.h"
 
@@ -175,189 +176,52 @@ void InstStrategyEnum::check(Theory::Effort e, QEffort quant_e)
   }
 }
 
-bool InstStrategyEnum::process(Node f, bool fullEffort, bool isRd)
+bool InstStrategyEnum::process(Node quantifier, bool fullEffort, bool isRd)
 {
-  // ignore if constant true (rare case of non-standard quantifier whose body is
-  // rewritten to true)
-  if (f[1].isConst() && f[1].getConst<bool>())
+  // ignore if constant true (rare case of non-standard quantifier whose body
+  // is rewritten to true)
+  if (quantifier[1].isConst() && quantifier[1].getConst<bool>())
   {
     return false;
   }
-  unsigned final_max_i = 0;
-  std::vector<unsigned> maxs;
-  std::vector<bool> max_zero;
-  bool has_zero = false;
-  std::map<TypeNode, std::vector<Node> > term_db_list;
-  std::vector<TypeNode> ftypes;
-  TermDb* tdb = d_quantEngine->getTermDatabase();
-  QuantifiersState& qs = d_quantEngine->getState();
-  // iterate over substitutions for variables
-  for (unsigned i = 0; i < f[0].getNumChildren(); i++)
+
+  TermTupleEnumeratorContext ttec;
+  ttec.d_quantEngine = d_quantEngine;
+  ttec.d_rd = d_rd;
+  ttec.d_fullEffort = fullEffort;
+  ttec.d_increaseSum = options::fullSaturateSum();
+  ttec.d_isRd = isRd;
+  std::unique_ptr<TermTupleEnumeratorInterface> enumerator(
+      mkTermTupleEnumerator(quantifier, &ttec));
+  std::vector<Node> terms;
+  std::vector<bool> failMask;
+  Instantiate* ie = d_quantEngine->getInstantiate();
+  for (enumerator->init(); enumerator->hasNext();)
   {
-    TypeNode tn = f[0][i].getType();
-    ftypes.push_back(tn);
-    unsigned ts;
-    if (isRd)
-    {
-      ts = d_rd->getRDomain(f, i)->d_terms.size();
-    }
-    else
+    if (d_qstate.isInConflict())
     {
-      ts = tdb->getNumTypeGroundTerms(tn);
-      std::map<TypeNode, std::vector<Node> >::iterator ittd =
-          term_db_list.find(tn);
-      if (ittd == term_db_list.end())
-      {
-        std::map<Node, Node> reps_found;
-        for (unsigned j = 0; j < ts; j++)
-        {
-          Node gt = tdb->getTypeGroundTerm(ftypes[i], j);
-          if (!options::cegqi() || !quantifiers::TermUtil::hasInstConstAttr(gt))
-          {
-            Node rep = qs.getRepresentative(gt);
-            if (reps_found.find(rep) == reps_found.end())
-            {
-              reps_found[rep] = gt;
-              term_db_list[tn].push_back(gt);
-            }
-          }
-        }
-        ts = term_db_list[tn].size();
-      }
-      else
-      {
-        ts = ittd->second.size();
-      }
+      // could be conflicting for an internal reason
+      return false;
     }
-    // consider a default value if at full effort
-    max_zero.push_back(fullEffort && ts == 0);
-    ts = (fullEffort && ts == 0) ? 1 : ts;
-    Trace("inst-alg-rd") << "Variable " << i << " has " << ts
-                         << " in relevant domain." << std::endl;
-    if (ts == 0)
+    enumerator->next(terms);
+    // try instantiation
+    failMask.clear();
+    /* if (ie->addInstantiation(quantifier, terms)) */
+    if (ie->addInstantiationExpFail(quantifier, terms, failMask, false))
     {
-      has_zero = true;
-      break;
+      Trace("inst-alg-rd") << "Success!" << std::endl;
+      ++(d_quantEngine->d_statistics.d_instantiations_guess);
+      return true;
     }
-    maxs.push_back(ts);
-    if (ts > final_max_i)
+    else
     {
-      final_max_i = ts;
+      enumerator->failureReason(failMask);
     }
   }
-  if (!has_zero)
-  {
-    Trace("inst-alg-rd") << "Will do " << final_max_i
-                         << " stages of instantiation." << std::endl;
-    unsigned max_i = 0;
-    bool success;
-    Instantiate* ie = d_quantEngine->getInstantiate();
-    while (max_i <= final_max_i)
-    {
-      Trace("inst-alg-rd") << "Try stage " << max_i << "..." << std::endl;
-      std::vector<unsigned> childIndex;
-      int index = 0;
-      do
-      {
-        while (index >= 0 && index < (int)f[0].getNumChildren())
-        {
-          if (index == static_cast<int>(childIndex.size()))
-          {
-            childIndex.push_back(-1);
-          }
-          else
-          {
-            Assert(index == static_cast<int>(childIndex.size()) - 1);
-            unsigned nv = childIndex[index] + 1;
-            if (nv < maxs[index] && nv <= max_i)
-            {
-              childIndex[index] = nv;
-              index++;
-            }
-            else
-            {
-              childIndex.pop_back();
-              index--;
-            }
-          }
-        }
-        success = index >= 0;
-        if (success)
-        {
-          if (Trace.isOn("inst-alg-rd"))
-          {
-            Trace("inst-alg-rd") << "Try instantiation { ";
-            for (unsigned i : childIndex)
-            {
-              Trace("inst-alg-rd") << i << " ";
-            }
-            Trace("inst-alg-rd") << "}" << std::endl;
-          }
-          // try instantiation
-          std::vector<Node> terms;
-          for (unsigned i = 0, nchild = f[0].getNumChildren(); i < nchild; i++)
-          {
-            if (max_zero[i])
-            {
-              // no terms available, will report incomplete instantiation
-              terms.push_back(Node::null());
-              Trace("inst-alg-rd") << "  null" << std::endl;
-            }
-            else if (isRd)
-            {
-              terms.push_back(d_rd->getRDomain(f, i)->d_terms[childIndex[i]]);
-              Trace("inst-alg-rd")
-                  << "  (rd) " << d_rd->getRDomain(f, i)->d_terms[childIndex[i]]
-                  << std::endl;
-            }
-            else
-            {
-              Assert(childIndex[i] < term_db_list[ftypes[i]].size());
-              terms.push_back(term_db_list[ftypes[i]][childIndex[i]]);
-              Trace("inst-alg-rd")
-                  << "  " << term_db_list[ftypes[i]][childIndex[i]]
-                  << std::endl;
-            }
-            Assert(terms[i].isNull()
-                   || terms[i].getType().isComparableTo(ftypes[i]))
-                << "Incompatible type " << f << ", " << terms[i].getType()
-                << ", " << ftypes[i] << std::endl;
-          }
-          std::vector<bool> failMask;
-          if (ie->addInstantiationExpFail(f, terms, failMask, false))
-          {
-            Trace("inst-alg-rd") << "Success!" << std::endl;
-            ++(d_quantEngine->d_statistics.d_instantiations_guess);
-            return true;
-          }
-          else
-          {
-            index--;
-            // currently, we use the failmask only for backtracking, although
-            // more could be learned here (wishue #81).
-            Assert(failMask.size() == terms.size());
-            while (!failMask.empty() && !failMask.back())
-            {
-              failMask.pop_back();
-              childIndex.pop_back();
-              index--;
-            }
-          }
-          if (d_qstate.isInConflict())
-          {
-            // could be conflicting for an internal reason (such as term
-            // indices computed in above calls)
-            return false;
-          }
-        }
-      } while (success);
-      max_i++;
-    }
-  }
-  // TODO : term enumerator instantiation?
   return false;
+  // TODO : term enumerator instantiation?
 }
 
-} /* CVC4::theory::quantifiers namespace */
-} /* CVC4::theory namespace */
-} /* CVC4 namespace */
+}  // namespace quantifiers
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/quantifiers/term_tuple_enumerator.cpp b/src/theory/quantifiers/term_tuple_enumerator.cpp
new file mode 100644 (file)
index 0000000..1466e10
--- /dev/null
@@ -0,0 +1,501 @@
+/*********************                                                        */
+/*! \file  term_tuple_enumerator.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mikolas Janota
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Implementation of an enumeration of tuples of terms for the purpose
+ *of quantifier instantiation.
+ **/
+#include "theory/quantifiers/term_tuple_enumerator.h"
+
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <map>
+#include <vector>
+
+#include "base/map_util.h"
+#include "base/output.h"
+#include "options/quantifiers_options.h"
+#include "smt/smt_statistics_registry.h"
+#include "theory/quantifiers/index_trie.h"
+#include "theory/quantifiers/quant_module.h"
+#include "theory/quantifiers/relevant_domain.h"
+#include "theory/quantifiers/term_util.h"
+#include "theory/quantifiers_engine.h"
+#include "util/statistics_registry.h"
+
+namespace CVC4 {
+
+template <typename T>
+static CVC4ostream& operator<<(CVC4ostream& out, const std::vector<T>& v)
+{
+  out << "[ ";
+  std::copy(v.begin(), v.end(), std::ostream_iterator<T>(out, " "));
+  return out << "]";
+}
+
+/** Tracing purposes, printing a masked vector of indices. */
+static void traceMaskedVector(const char* trace,
+                              const char* name,
+                              const std::vector<bool>& mask,
+                              const std::vector<size_t>& values)
+{
+  Assert(mask.size() == values.size());
+  Trace(trace) << name << " [ ";
+  for (size_t variableIx = 0; variableIx < mask.size(); variableIx++)
+  {
+    if (mask[variableIx])
+    {
+      Trace(trace) << values[variableIx] << " ";
+    }
+    else
+    {
+      Trace(trace) << "_ ";
+    }
+  }
+  Trace(trace) << "]" << std::endl;
+}
+
+namespace theory {
+namespace quantifiers {
+/**
+ * Base class for enumerators of tuples of terms for the purpose of
+ * quantification instantiation. The tuples are represented as tuples of
+ * indices of  terms, where the tuple has as many elements as there are
+ * quantified variables in the considered quantifier.
+ *
+ * Like so, we see a tuple as a number, where the digits may have different
+ * ranges. The most significant digits are stored first.
+ *
+ * Tuples are enumerated  in a lexicographic order in stages. There are 2
+ * possible strategies, either  all tuples in a given stage have the same sum of
+ * digits, or, the maximum  over these digits is the same.
+ * */
+class TermTupleEnumeratorBase : public TermTupleEnumeratorInterface
+{
+ public:
+  /** Initialize the class with the quantifier to be instantiated. */
+  TermTupleEnumeratorBase(Node quantifier,
+                          const TermTupleEnumeratorContext* context)
+      : d_quantifier(quantifier),
+        d_variableCount(d_quantifier[0].getNumChildren()),
+        d_context(context),
+        d_stepCounter(0),
+        d_disabledCombinations(
+            true)  // do not record combinations with no blanks
+
+  {
+    d_changePrefix = d_variableCount;
+  }
+
+  virtual ~TermTupleEnumeratorBase() = default;
+
+  // implementation of the TermTupleEnumeratorInterface
+  virtual void init() override;
+  virtual bool hasNext() override;
+  virtual void next(/*out*/ std::vector<Node>& terms) override;
+  virtual void failureReason(const std::vector<bool>& mask) override;
+  // end of implementation of the TermTupleEnumeratorInterface
+
+ protected:
+  /** the quantifier whose variables are being instantiated */
+  const Node d_quantifier;
+  /** number of variables in the quantifier */
+  const size_t d_variableCount;
+  /** context of structures with a longer lifespan */
+  const TermTupleEnumeratorContext* const d_context;
+  /** type for each variable */
+  std::vector<TypeNode> d_typeCache;
+  /** number of candidate terms for each variable */
+  std::vector<size_t> d_termsSizes;
+  /** tuple of indices of the current terms */
+  std::vector<size_t> d_termIndex;
+  /** total number of steps of the enumerator */
+  uint32_t d_stepCounter;
+
+  /** a data structure storing disabled combinations of terms */
+  IndexTrie d_disabledCombinations;
+
+  /** current sum/max  of digits, depending on the strategy */
+  size_t d_currentStage;
+  /**total number of stages*/
+  size_t d_stageCount;
+  /**becomes false once the enumerator runs out of options*/
+  bool d_hasNext;
+  /** the length of the prefix that has to be changed in the next
+  combination, i.e.  the number of the most significant digits that need to be
+  changed in order to escape a  useless instantiation */
+  size_t d_changePrefix;
+  /** Move onto the next stage */
+  bool increaseStage();
+  /** Move onto the next stage, sum strategy. */
+  bool increaseStageSum();
+  /** Move onto the next stage, max strategy. */
+  bool increaseStageMax();
+  /** Move on in the current stage */
+  bool nextCombination();
+  /** Move onto the next combination. */
+  bool nextCombinationInternal();
+  /** Find the next lexicographically smallest combination of terms that change
+   * on the change prefix, each digit is within the current state,  and there is
+   * at least one digit not in the previous stage. */
+  bool nextCombinationSum();
+  /** Find the next lexicographically smallest combination of terms that change
+   * on the change prefix and their sum is equal to d_currentStage. */
+  bool nextCombinationMax();
+  /** Set up terms for given variable.  */
+  virtual size_t prepareTerms(size_t variableIx) = 0;
+  /** Get a given term for a given variable.  */
+  virtual Node getTerm(size_t variableIx,
+                       size_t term_index) CVC4_WARN_UNUSED_RESULT = 0;
+};
+
+/**
+ * Enumerate ground terms as they come from the term database.
+ */
+class TermTupleEnumeratorBasic : public TermTupleEnumeratorBase
+{
+ public:
+  TermTupleEnumeratorBasic(Node quantifier,
+                           const TermTupleEnumeratorContext* context)
+      : TermTupleEnumeratorBase(quantifier, context)
+  {
+  }
+
+  virtual ~TermTupleEnumeratorBasic() = default;
+
+ protected:
+  /**  a list of terms for each type */
+  std::map<TypeNode, std::vector<Node> > d_termDbList;
+  virtual size_t prepareTerms(size_t variableIx) override;
+  virtual Node getTerm(size_t variableIx, size_t term_index) override;
+};
+
+/**
+ * Enumerate ground terms according to the relevant domain class.
+ */
+class TermTupleEnumeratorRD : public TermTupleEnumeratorBase
+{
+ public:
+  TermTupleEnumeratorRD(Node quantifier,
+                        const TermTupleEnumeratorContext* context)
+      : TermTupleEnumeratorBase(quantifier, context)
+  {
+  }
+  virtual ~TermTupleEnumeratorRD() = default;
+
+ protected:
+  virtual size_t prepareTerms(size_t variableIx) override
+  {
+    return d_context->d_rd->getRDomain(d_quantifier, variableIx)
+        ->d_terms.size();
+  }
+  virtual Node getTerm(size_t variableIx, size_t term_index) override
+  {
+    return d_context->d_rd->getRDomain(d_quantifier, variableIx)
+        ->d_terms[term_index];
+  }
+};
+
+TermTupleEnumeratorInterface* mkTermTupleEnumerator(
+    Node quantifier, const TermTupleEnumeratorContext* context)
+{
+  return context->d_isRd ? static_cast<TermTupleEnumeratorInterface*>(
+             new TermTupleEnumeratorRD(quantifier, context))
+                         : static_cast<TermTupleEnumeratorInterface*>(
+                             new TermTupleEnumeratorBasic(quantifier, context));
+}
+
+void TermTupleEnumeratorBase::init()
+{
+  Trace("inst-alg-rd") << "Initializing enumeration " << d_quantifier
+                       << std::endl;
+  d_currentStage = 0;
+  d_hasNext = true;
+  d_stageCount = 1;  // in the case of full effort we do at least one stage
+
+  if (d_variableCount == 0)
+  {
+    d_hasNext = false;
+    return;
+  }
+
+  // prepare a sequence of terms for each quantified variable
+  // additionally initialize the cache for variable types
+  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
+  {
+    d_typeCache.push_back(d_quantifier[0][variableIx].getType());
+    const size_t termsSize = prepareTerms(variableIx);
+    Trace("inst-alg-rd") << "Variable " << variableIx << " has " << termsSize
+                         << " in relevant domain." << std::endl;
+    if (termsSize == 0 && !d_context->d_fullEffort)
+    {
+      d_hasNext = false;
+      return;  // give up on an empty dommain
+    }
+    d_termsSizes.push_back(termsSize);
+    d_stageCount = std::max(d_stageCount, termsSize);
+  }
+
+  Trace("inst-alg-rd") << "Will do " << d_stageCount
+                       << " stages of instantiation." << std::endl;
+  d_termIndex.resize(d_variableCount, 0);
+}
+
+bool TermTupleEnumeratorBase::hasNext()
+{
+  if (!d_hasNext)
+  {
+    return false;
+  }
+
+  if (d_stepCounter++ == 0)
+  {  // TODO:any (nice)  way of avoiding this special if?
+    Assert(d_currentStage == 0);
+    Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..."
+                         << std::endl;
+    return true;
+  }
+
+  // try to find the next combination
+  return d_hasNext = nextCombination();
+}
+
+void TermTupleEnumeratorBase::failureReason(const std::vector<bool>& mask)
+{
+  if (Trace.isOn("inst-alg"))
+  {
+    traceMaskedVector("inst-alg", "failureReason", mask, d_termIndex);
+  }
+  d_disabledCombinations.add(mask, d_termIndex);  // record failure
+  // update change prefix accordingly
+  for (d_changePrefix = mask.size();
+       d_changePrefix && !mask[d_changePrefix - 1];
+       d_changePrefix--)
+    ;
+}
+
+void TermTupleEnumeratorBase::next(/*out*/ std::vector<Node>& terms)
+{
+  Trace("inst-alg-rd") << "Try instantiation: " << d_termIndex << std::endl;
+  terms.resize(d_variableCount);
+  for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++)
+  {
+    const Node t = d_termsSizes[variableIx] == 0
+                       ? Node::null()
+                       : getTerm(variableIx, d_termIndex[variableIx]);
+    terms[variableIx] = t;
+    Trace("inst-alg-rd") << t << "  ";
+    Assert(terms[variableIx].isNull()
+           || terms[variableIx].getType().isComparableTo(
+               d_quantifier[0][variableIx].getType()));
+  }
+  Trace("inst-alg-rd") << std::endl;
+}
+
+bool TermTupleEnumeratorBase::increaseStageSum()
+{
+  const size_t lowerBound = d_currentStage + 1;
+  Trace("inst-alg-rd") << "Try sum " << lowerBound << "..." << std::endl;
+  d_currentStage = 0;
+  for (size_t digit = d_termIndex.size();
+       d_currentStage < lowerBound && digit--;)
+  {
+    const size_t missing = lowerBound - d_currentStage;
+    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
+    d_termIndex[digit] = std::min(missing, maxValue);
+    d_currentStage += d_termIndex[digit];
+  }
+  return d_currentStage >= lowerBound;
+}
+
+bool TermTupleEnumeratorBase::increaseStage()
+{
+  d_changePrefix = d_variableCount;  // simply reset upon increase stage
+  return d_context->d_increaseSum ? increaseStageSum() : increaseStageMax();
+}
+
+bool TermTupleEnumeratorBase::increaseStageMax()
+{
+  d_currentStage++;
+  if (d_currentStage >= d_stageCount)
+  {
+    return false;
+  }
+  Trace("inst-alg-rd") << "Try stage " << d_currentStage << "..." << std::endl;
+  // skipping some elements that have already been definitely seen
+  // find the least significant digit that can be set to the current stage
+  // TODO: should we skip all?
+  std::fill(d_termIndex.begin(), d_termIndex.end(), 0);
+  bool found = false;
+  for (size_t digit = d_termIndex.size(); !found && digit--;)
+  {
+    if (d_termsSizes[digit] > d_currentStage)
+    {
+      found = true;
+      d_termIndex[digit] = d_currentStage;
+    }
+  }
+  Assert(found);
+  return found;
+}
+
+bool TermTupleEnumeratorBase::nextCombination()
+{
+  while (true)
+  {
+    Trace("inst-alg-rd") << "changePrefix " << d_changePrefix << std::endl;
+    if (!nextCombinationInternal() && !increaseStage())
+    {
+      return false;  // ran out of combinations
+    }
+    if (!d_disabledCombinations.find(d_termIndex, d_changePrefix))
+    {
+      return true;  // current combination vetted by disabled combinations
+    }
+  }
+}
+
+/** Move onto the next combination, depending on the strategy. */
+bool TermTupleEnumeratorBase::nextCombinationInternal()
+{
+  return d_context->d_increaseSum ? nextCombinationSum() : nextCombinationMax();
+}
+
+/** Find the next lexicographically smallest combination of terms that change
+ * on the change prefix and their sum is equal to d_currentStage. */
+bool TermTupleEnumeratorBase::nextCombinationMax()
+{
+  // look for the least significant digit, within change prefix,
+  // that can be increased
+  bool found = false;
+  size_t increaseDigit = d_changePrefix;
+  while (!found && increaseDigit--)
+  {
+    const size_t new_value = d_termIndex[increaseDigit] + 1;
+    if (new_value < d_termsSizes[increaseDigit] && new_value <= d_currentStage)
+    {
+      d_termIndex[increaseDigit] = new_value;
+      // send everything after the increased digit to 0
+      std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
+      found = true;
+    }
+  }
+  if (!found)
+  {
+    return false;  // nothing to increase
+  }
+  // check if the combination has at least one digit in the current stage,
+  // since at least one digit was increased, no need for this in stage 1
+  bool inStage = d_currentStage <= 1;
+  for (size_t i = increaseDigit + 1; !inStage && i--;)
+  {
+    inStage = d_termIndex[i] >= d_currentStage;
+  }
+  if (!inStage)  // look for a digit that can increase to current stage
+  {
+    for (increaseDigit = d_variableCount, found = false;
+         !found && increaseDigit--;)
+    {
+      found = d_termsSizes[increaseDigit] > d_currentStage;
+    }
+    if (!found)
+    {
+      return false;  // nothing to increase to the current stage
+    }
+    Assert(d_termsSizes[increaseDigit] > d_currentStage
+           && d_termIndex[increaseDigit] < d_currentStage);
+    d_termIndex[increaseDigit] = d_currentStage;
+    // send everything after the increased digit to 0
+    std::fill(d_termIndex.begin() + increaseDigit + 1, d_termIndex.end(), 0);
+  }
+  return true;
+}
+
+/** Find the next lexicographically smallest combination of terms that change
+ * on the change prefix, each digit is within the current state,  and there is
+ * at least one digit not in the previous stage. */
+bool TermTupleEnumeratorBase::nextCombinationSum()
+{
+  size_t suffixSum = 0;
+  bool found = false;
+  size_t increaseDigit = d_termIndex.size();
+  while (increaseDigit--)
+  {
+    const size_t newValue = d_termIndex[increaseDigit] + 1;
+    found = suffixSum > 0 && newValue < d_termsSizes[increaseDigit]
+            && increaseDigit < d_changePrefix;
+    if (found)
+    {
+      // digit can be increased and suffix can be decreased
+      d_termIndex[increaseDigit] = newValue;
+      break;
+    }
+    suffixSum += d_termIndex[increaseDigit];
+    d_termIndex[increaseDigit] = 0;
+  }
+  if (!found)
+  {
+    return false;
+  }
+  Assert(suffixSum > 0);
+  // increaseDigit went up by one, hence, distribute (suffixSum - 1) in the
+  // least significant digits
+  suffixSum--;
+  for (size_t digit = d_termIndex.size(); suffixSum > 0 && digit--;)
+  {
+    const size_t maxValue = d_termsSizes[digit] ? d_termsSizes[digit] - 1 : 0;
+    d_termIndex[digit] = std::min(suffixSum, maxValue);
+    suffixSum -= d_termIndex[digit];
+  }
+  Assert(suffixSum == 0);  // everything should have been distributed
+  return true;
+}
+
+size_t TermTupleEnumeratorBasic::prepareTerms(size_t variableIx)
+{
+  TermDb* const tdb = d_context->d_quantEngine->getTermDatabase();
+  QuantifiersState& qs = d_context->d_quantEngine->getState();
+  const TypeNode type_node = d_typeCache[variableIx];
+
+  if (!ContainsKey(d_termDbList, type_node))
+  {
+    const size_t ground_terms_count = tdb->getNumTypeGroundTerms(type_node);
+    std::map<Node, Node> repsFound;
+    for (size_t j = 0; j < ground_terms_count; j++)
+    {
+      Node gt = tdb->getTypeGroundTerm(type_node, j);
+      if (!options::cegqi() || !quantifiers::TermUtil::hasInstConstAttr(gt))
+      {
+        Node rep = qs.getRepresentative(gt);
+        if (repsFound.find(rep) == repsFound.end())
+        {
+          repsFound[rep] = gt;
+          d_termDbList[type_node].push_back(gt);
+        }
+      }
+    }
+  }
+
+  Trace("inst-alg-rd") << "Instantiation Terms for child " << variableIx << ": "
+                       << d_termDbList[type_node] << std::endl;
+  return d_termDbList[type_node].size();
+}
+
+Node TermTupleEnumeratorBasic::getTerm(size_t variableIx, size_t term_index)
+{
+  const TypeNode type_node = d_typeCache[variableIx];
+  Assert(term_index < d_termDbList[type_node].size());
+  return d_termDbList[type_node][term_index];
+}
+
+}  // namespace quantifiers
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/quantifiers/term_tuple_enumerator.h b/src/theory/quantifiers/term_tuple_enumerator.h
new file mode 100644 (file)
index 0000000..bd32971
--- /dev/null
@@ -0,0 +1,89 @@
+/*********************                                                        */
+/*! \file  term_tuple_enumerator.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mikolas Janota
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Implementation of an enumeration of tuples of terms for the purpose
+ *of quantifier instantiation.
+ **/
+#ifndef CVC4__THEORY__QUANTIFIERS__TERM_TUPLE_ENUMERATOR_H
+#define CVC4__THEORY__QUANTIFIERS__TERM_TUPLE_ENUMERATOR_H
+#include <vector>
+
+#include "expr/node.h"
+
+namespace CVC4 {
+namespace theory {
+
+class QuantifiersEngine;
+
+namespace quantifiers {
+
+class RelevantDomain;
+
+/**  Interface for enumeration of tuples of terms.
+ *
+ * The interface should be used as follows. Firstly, init is called, then,
+ * repeatedly,  verify if there are any combinations left by calling hasNext
+ * and obtaining the next combination by calling next.
+ *
+ *  Optionally, if the  most recent combination is determined to be undesirable
+ * (for whatever reason), the method failureReason is used to indicate which
+ *  positions of the tuple are responsible for the said failure.
+ */
+class TermTupleEnumeratorInterface
+{
+ public:
+  /** Initialize the enumerator. */
+  virtual void init() = 0;
+  /** Test if there are any more combinations. */
+  virtual bool hasNext() = 0;
+  /** Obtain the next combination, meaningful only if hasNext Returns true. */
+  virtual void next(/*out*/ std::vector<Node>& terms) = 0;
+  /** Record which of the terms obtained by the last call of next should not be
+   * explored again. */
+  virtual void failureReason(const std::vector<bool>& mask) = 0;
+  virtual ~TermTupleEnumeratorInterface() = default;
+};
+
+/** A struct bundling up parameters for term tuple enumerator.*/
+struct TermTupleEnumeratorContext
+{
+  QuantifiersEngine* d_quantEngine;
+  RelevantDomain* d_rd;
+  bool d_fullEffort;
+  bool d_increaseSum;
+  bool d_isRd;
+};
+
+/**  A function to construct a tuple enumerator.
+ *
+ * Currently we support the enumerators based on the following idea.
+ * The tuples are represented as tuples of
+ * indices of  terms, where the tuple has as many elements as there are
+ * quantified variables in the considered quantifier.
+ *
+ * Like so, we see a tuple as a number, where the digits may have different
+ * ranges. The most significant digits are stored first.
+ *
+ * Tuples are enumerated  in a lexicographic order in stages. There are 2
+ * possible strategies, either  all tuples in a given stage have the same sum of
+ * digits, or, the maximum  over these digits is the same (controlled by
+ * d_increaseSum).
+ *
+ * Further, an enumerator  either draws ground terms from the term database or
+ * using the relevant domain class  (controlled by d_isRd).
+ */
+TermTupleEnumeratorInterface* mkTermTupleEnumerator(
+    Node quantifier, const TermTupleEnumeratorContext* context);
+
+}  // namespace quantifiers
+}  // namespace theory
+}  // namespace CVC4
+#endif /* TERM_TUPLE_ENUMERATOR_H_7640 */