From: Andrew Reynolds Date: Tue, 26 Apr 2022 22:37:37 +0000 (-0500) Subject: Make IndexTrie take nodes (#8649) X-Git-Tag: cvc5-1.0.1~218 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=bd818d0ab653eea2e7816b2824824b409c723911;p=cvc5.git Make IndexTrie take nodes (#8649) This makes the class easier to use and allows for a usage where the null node is interpreted as specifying all nodes. This is in preparation for using this class for testing whether an instantiation from any instantiation strategy is currently feasible based on learning in the style of fail masks from Janota et al FMCAD 2021. Also, this class should be renamed to something more appropriate, since it no longer takes indices. FYI @MikolasJanota --- diff --git a/src/theory/quantifiers/index_trie.cpp b/src/theory/quantifiers/index_trie.cpp index d9046f793..6c28396d6 100644 --- a/src/theory/quantifiers/index_trie.cpp +++ b/src/theory/quantifiers/index_trie.cpp @@ -21,7 +21,7 @@ namespace theory { namespace quantifiers { void IndexTrie::add(const std::vector& mask, - const std::vector& values) + const std::vector& values) { const size_t cardinality = std::count(mask.begin(), mask.end(), true); if (d_ignoreFullySpecified && cardinality == mask.size()) @@ -48,7 +48,7 @@ void IndexTrie::freeRec(IndexTrieNode* n) bool IndexTrie::findRec(const IndexTrieNode* n, size_t index, - const std::vector& members, + const std::vector& members, size_t& nonBlankLength) const { if (!n || index >= members.size()) @@ -59,6 +59,11 @@ bool IndexTrie::findRec(const IndexTrieNode* n, { return true; // found in the blank branch } + if (members[index].isNull()) + { + // null is interpreted as "any", must have found in the blank branch + return false; + } nonBlankLength = index + 1; for (const auto& c : n->d_children) { @@ -75,7 +80,7 @@ IndexTrieNode* IndexTrie::addRec(IndexTrieNode* n, size_t index, size_t cardinality, const std::vector& mask, - const std::vector& values) + const std::vector& values) { if (!n) { @@ -96,7 +101,7 @@ IndexTrieNode* IndexTrie::addRec(IndexTrieNode* n, return n; } Assert(cardinality); - + Assert(!values[index].isNull()); for (auto& edge : n->d_children) { if (edge.first == values[index]) diff --git a/src/theory/quantifiers/index_trie.h b/src/theory/quantifiers/index_trie.h index fa58a3e48..c6d38f533 100644 --- a/src/theory/quantifiers/index_trie.h +++ b/src/theory/quantifiers/index_trie.h @@ -22,6 +22,7 @@ #include #include "base/check.h" +#include "expr/node.h" namespace cvc5::internal { namespace theory { @@ -30,13 +31,13 @@ namespace quantifiers { /** A single node of the IndexTrie. */ struct IndexTrieNode { - std::vector> d_children; + std::vector> d_children; IndexTrieNode* d_blank = nullptr; }; -/** Trie of sequences indices, used to check for subsequence membership. +/** Trie of Nodes, used to check for subsequence membership. * - * The data structure stores tuples of indices where some elements may be + * The data structure stores tuples of indices where some elements may be * left blank. The objective is to enable checking whether a given, completely * filled in, tuple has a sub-tuple present in the data structure. This is * used in the term tuple enumeration (term_tuple_enumerator.cpp) to store @@ -48,12 +49,19 @@ struct IndexTrieNode * tuple that contains 1 and 3 on second and forth position, respectively, would * match. * - * The data structure behaves essentially as a traditional trie. Each tuple + * The data structure behaves essentially as a traditional trie. Each tuple * is treated as a sequence of integers with a special symbol for blank, which * is in fact stored in a special child (member d_blank). As a small * optimization, a suffix containing only blanks is represented by the empty * subtree, i.e., a null pointer. * + * Additionally, this class accepts membership queries involving null nodes, + * which are interpreted as requiring that all possible values of the node at + * that position are contained. For example, writing `_` for null: + * (_, 1, 2, 3) is contained in (_, 1, _, 3) + * (1, 1, _, 3) is contained in (_, 1, _, 3) + * (_, 2, _, _) is not contained in (_, 1, _, 3) + * (_, 1, 2, 3) is not contained in (0, 1, _, 3) */ class IndexTrie { @@ -61,7 +69,7 @@ class IndexTrie /* Construct the trie, if the argument ignoreFullySpecified is true, * the data structure will store only data structure containing at least * one blank. */ - IndexTrie(bool ignoreFullySpecified) + IndexTrie(bool ignoreFullySpecified = true) : d_ignoreFullySpecified(ignoreFullySpecified), d_root(new IndexTrieNode()) { @@ -71,11 +79,11 @@ class IndexTrie /** Add a tuple of values into the trie masked by a bitmask, i.e.\ position * i is considered blank iff mask[i] is false. */ - void add(const std::vector& mask, const std::vector& values); + void add(const std::vector& mask, const std::vector& values); /** Check if the given set of indices is subsumed by something present in the * trie. If it is subsumed, give the maximum non-blank index. */ - bool find(const std::vector& members, + bool find(const std::vector& members, /*out*/ size_t& nonBlankLength) const { nonBlankLength = 0; @@ -94,7 +102,7 @@ class IndexTrie /** Auxiliary recursive function for finding subsuming tuple. */ bool findRec(const IndexTrieNode* n, size_t index, - const std::vector& members, + const std::vector& members, size_t& nonBlankLength) const; /** Add master values starting from index to a given subtree. The @@ -103,7 +111,7 @@ class IndexTrie size_t index, size_t cardinality, const std::vector& mask, - const std::vector& values); + const std::vector& values); }; } // namespace quantifiers diff --git a/src/theory/quantifiers/inst_strategy_enumerative.cpp b/src/theory/quantifiers/inst_strategy_enumerative.cpp index ab52d71bf..7e9493ccd 100644 --- a/src/theory/quantifiers/inst_strategy_enumerative.cpp +++ b/src/theory/quantifiers/inst_strategy_enumerative.cpp @@ -182,18 +182,18 @@ bool InstStrategyEnum::process(Node quantifier, bool fullEffort, bool isRd) return false; } + Instantiate* ie = d_qim.getInstantiate(); TermTupleEnumeratorEnv ttec; ttec.d_fullEffort = fullEffort; ttec.d_increaseSum = options().quantifiers.enumInstSum; + ttec.d_tr = &d_treg; // make the enumerator, which is either relevant domain or term database // based on the flag isRd. std::unique_ptr enumerator( isRd ? mkTermTupleEnumeratorRd(quantifier, &ttec, d_rd) - : mkTermTupleEnumerator( - quantifier, &ttec, d_qstate, d_treg.getTermDatabase())); + : mkTermTupleEnumerator(quantifier, &ttec, d_qstate)); std::vector terms; std::vector failMask; - Instantiate* ie = d_qim.getInstantiate(); for (enumerator->init(); enumerator->hasNext();) { if (d_qstate.isInConflict()) diff --git a/src/theory/quantifiers/inst_strategy_pool.cpp b/src/theory/quantifiers/inst_strategy_pool.cpp index fa48a6a89..2cf081852 100644 --- a/src/theory/quantifiers/inst_strategy_pool.cpp +++ b/src/theory/quantifiers/inst_strategy_pool.cpp @@ -127,13 +127,13 @@ std::string InstStrategyPool::identify() const bool InstStrategyPool::process(Node q, Node p, uint64_t& addedLemmas) { + Instantiate* ie = d_qim.getInstantiate(); TermTupleEnumeratorEnv ttec; ttec.d_fullEffort = true; ttec.d_increaseSum = options().quantifiers.enumInstSum; - TermPools* tp = d_treg.getTermPools(); + ttec.d_tr = &d_treg; std::shared_ptr enumerator( - mkTermTupleEnumeratorPool(q, &ttec, tp, p)); - Instantiate* ie = d_qim.getInstantiate(); + mkTermTupleEnumeratorPool(q, &ttec, p)); std::vector terms; std::vector failMask; // we instantiate exhaustively diff --git a/src/theory/quantifiers/term_tuple_enumerator.cpp b/src/theory/quantifiers/term_tuple_enumerator.cpp index 6e189e902..cf5235f9c 100644 --- a/src/theory/quantifiers/term_tuple_enumerator.cpp +++ b/src/theory/quantifiers/term_tuple_enumerator.cpp @@ -26,6 +26,7 @@ #include "options/quantifiers_options.h" #include "smt/smt_statistics_registry.h" #include "theory/quantifiers/index_trie.h" +#include "theory/quantifiers/instantiate.h" #include "theory/quantifiers/quant_module.h" #include "theory/quantifiers/relevant_domain.h" #include "theory/quantifiers/term_pools.h" @@ -166,9 +167,10 @@ class TermTupleEnumeratorBasic : public TermTupleEnumeratorBase public: TermTupleEnumeratorBasic(Node quantifier, const TermTupleEnumeratorEnv* env, - QuantifiersState& qs, - TermDb* td) - : TermTupleEnumeratorBase(quantifier, env), d_qs(qs), d_tdb(td) + QuantifiersState& qs) + : TermTupleEnumeratorBase(quantifier, env), + d_qs(qs), + d_tdb(env->d_tr->getTermDatabase()) { } @@ -273,7 +275,9 @@ void TermTupleEnumeratorBase::failureReason(const std::vector& mask) { traceMaskedVector("inst-alg", "failureReason", mask, d_termIndex); } - d_disabledCombinations.add(mask, d_termIndex); // record failure + std::vector tti; + next(tti); + d_disabledCombinations.add(mask, tti); // record failure // update change prefix accordingly for (d_changePrefix = mask.size(); d_changePrefix && !mask[d_changePrefix - 1]; @@ -287,13 +291,14 @@ void TermTupleEnumeratorBase::next(/*out*/ std::vector& terms) terms.resize(d_variableCount); for (size_t variableIx = 0; variableIx < d_variableCount; variableIx++) { - const Node t = d_termsSizes[variableIx] == 0 - ? Node::null() - : getTerm(variableIx, d_termIndex[variableIx]); + const Node t = + d_termsSizes[variableIx] == 0 + ? d_env->d_tr->getTermForType(d_quantifier[0][variableIx].getType()) + : getTerm(variableIx, d_termIndex[variableIx]); terms[variableIx] = t; Trace("inst-alg-rd") << t << " "; - Assert(t.isNull() - || t.getType().isComparableTo(d_quantifier[0][variableIx].getType())) + Assert(!t.isNull()); + Assert(t.getType().isComparableTo(d_quantifier[0][variableIx].getType())) << "Bad type: " << t << " " << t.getType() << " " << d_quantifier[0][variableIx].getType(); } @@ -356,7 +361,9 @@ bool TermTupleEnumeratorBase::nextCombination() { return false; // ran out of combinations } - if (!d_disabledCombinations.find(d_termIndex, d_changePrefix)) + std::vector tti; + next(tti); + if (!d_disabledCombinations.find(tti, d_changePrefix)) { return true; // current combination vetted by disabled combinations } @@ -501,9 +508,10 @@ class TermTupleEnumeratorPool : public TermTupleEnumeratorBase public: TermTupleEnumeratorPool(Node quantifier, const TermTupleEnumeratorEnv* env, - TermPools* tp, Node pool) - : TermTupleEnumeratorBase(quantifier, env), d_tp(tp), d_pool(pool) + : TermTupleEnumeratorBase(quantifier, env), + d_tp(env->d_tr->getTermPools()), + d_pool(pool) { Assert(d_pool.getKind() == kind::INST_POOL); } @@ -536,10 +544,10 @@ class TermTupleEnumeratorPool : public TermTupleEnumeratorBase }; TermTupleEnumeratorInterface* mkTermTupleEnumerator( - Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs, TermDb* td) + Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs) { return static_cast( - new TermTupleEnumeratorBasic(q, env, qs, td)); + new TermTupleEnumeratorBasic(q, env, qs)); } TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd( Node q, const TermTupleEnumeratorEnv* env, RelevantDomain* rd) @@ -549,10 +557,10 @@ TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd( } TermTupleEnumeratorInterface* mkTermTupleEnumeratorPool( - Node q, const TermTupleEnumeratorEnv* env, TermPools* tp, Node pool) + Node q, const TermTupleEnumeratorEnv* env, Node pool) { return static_cast( - new TermTupleEnumeratorPool(q, env, tp, pool)); + new TermTupleEnumeratorPool(q, env, pool)); } } // namespace quantifiers diff --git a/src/theory/quantifiers/term_tuple_enumerator.h b/src/theory/quantifiers/term_tuple_enumerator.h index 05cd1da5c..2b3edfdec 100644 --- a/src/theory/quantifiers/term_tuple_enumerator.h +++ b/src/theory/quantifiers/term_tuple_enumerator.h @@ -24,9 +24,10 @@ namespace cvc5::internal { namespace theory { namespace quantifiers { +class Instantiate; class TermPools; class QuantifiersState; -class TermDb; +class TermRegistry; class RelevantDomain; /** Interface for enumeration of tuples of terms. @@ -65,6 +66,8 @@ struct TermTupleEnumeratorEnv bool d_fullEffort; /** Whether we increase tuples based on sum instead of max (see below) */ bool d_increaseSum; + /** Term registry */ + TermRegistry* d_tr; }; /** A function to construct a tuple enumerator. @@ -87,17 +90,14 @@ struct TermTupleEnumeratorEnv * duplicates modulo equality. */ TermTupleEnumeratorInterface* mkTermTupleEnumerator( - Node q, - const TermTupleEnumeratorEnv* env, - QuantifiersState& qs, - TermDb* td); + Node q, const TermTupleEnumeratorEnv* env, QuantifiersState& qs); /** Same as above, but draws terms from the relevant domain utility (rd). */ TermTupleEnumeratorInterface* mkTermTupleEnumeratorRd( Node q, const TermTupleEnumeratorEnv* env, RelevantDomain* rd); /** Make term pool enumerator */ TermTupleEnumeratorInterface* mkTermTupleEnumeratorPool( - Node q, const TermTupleEnumeratorEnv* env, TermPools* tp, Node p); + Node q, const TermTupleEnumeratorEnv* env, Node p); } // namespace quantifiers } // namespace theory