From: Mathias Preiner Date: Thu, 1 Oct 2020 20:48:37 +0000 (-0700) Subject: Add additional ground terms to SyGuS instantiation grammar (#5167) X-Git-Tag: cvc5-1.0.0~2771 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=776ee02237b06eb3130e56af4d98d9ff36667d8b;p=cvc5.git Add additional ground terms to SyGuS instantiation grammar (#5167) This PR adds options to add additional ground terms to the SyGuS instantiation grammars. --- diff --git a/src/options/quantifiers_options.toml b/src/options/quantifiers_options.toml index 4b130158c..724a2ef2b 100644 --- a/src/options/quantifiers_options.toml +++ b/src/options/quantifiers_options.toml @@ -1995,3 +1995,39 @@ header = "options/quantifiers_options.h" type = "bool" default = "false" help = "Enable SyGuS instantiation quantifiers module" + +[[option]] + name = "sygusInstScope" + category = "regular" + long = "sygus-inst-scope=MODE" + type = "SygusInstScope" + default = "IN" + help = "select scope of ground terms" + help_mode = "scope for collecting ground terms for the grammar." +[[option.mode.IN]] + name = "in" + help = "use ground terms inside given quantified formula only." +[[option.mode.OUT]] + name = "out" + help = "use ground terms outside of quantified formulas only." +[[option.mode.BOTH]] + name = "both" + help = "combines inside and outside." + +[[option]] + name = "sygusInstTermSel" + category = "regular" + long = "sygus-inst-term-sel=MODE" + type = "SygusInstTermSelMode" + default = "MIN" + help = "granularity for ground terms" + help_mode = "Ground term selection modes." +[[option.mode.MIN]] + name = "min" + help = "collect minimal ground terms only." +[[option.mode.MAX]] + name = "max" + help = "collect maximal ground terms only." +[[option.mode.BOTH]] + name = "both" + help = "combines minimal and maximal ." diff --git a/src/theory/quantifiers/sygus_inst.cpp b/src/theory/quantifiers/sygus_inst.cpp index f9a6456e1..4192ca746 100644 --- a/src/theory/quantifiers/sygus_inst.cpp +++ b/src/theory/quantifiers/sygus_inst.cpp @@ -29,10 +29,142 @@ namespace CVC4 { namespace theory { namespace quantifiers { +namespace { + +/** + * Collect maximal ground terms with type tn in node n. + * + * @param n: Node to traverse. + * @param tn: Collects only terms with type tn. + * @param terms: Collected terms. + * @param cache: Caches visited nodes. + * @param skip_quant: Do not traverse quantified formulas (skip quantifiers). + */ +void getMaxGroundTerms(TNode n, + TypeNode tn, + std::unordered_set& terms, + std::unordered_set& cache, + bool skip_quant = false) +{ + if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX + && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH) + { + return; + } + + Trace("sygus-inst-term") << "Find maximal terms with type " << tn + << " in: " << n << std::endl; + + Node cur; + std::vector visit; + + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + + if (cache.find(cur) != cache.end()) + { + continue; + } + cache.insert(cur); + + if (expr::hasBoundVar(cur) || cur.getType() != tn) + { + if (!skip_quant || cur.getKind() != kind::FORALL) + { + visit.insert(visit.end(), cur.begin(), cur.end()); + } + } + else + { + terms.insert(cur); + Trace("sygus-inst-term") << " found: " << cur << std::endl; + } + } while (!visit.empty()); +} + +/* + * Collect minimal ground terms with type tn in node n. + * + * @param n: Node to traverse. + * @param tn: Collects only terms with type tn. + * @param terms: Collected terms. + * @param cache: Caches visited nodes and flags indicating whether a minimal + * term was already found in a subterm. + * @param skip_quant: Do not traverse quantified formulas (skip quantifiers). + */ +void getMinGroundTerms( + TNode n, + TypeNode tn, + std::unordered_set& terms, + std::unordered_map, TNodeHashFunction>& cache, + bool skip_quant = false) +{ + if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN + && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH) + { + return; + } + + Trace("sygus-inst-term") << "Find minimal terms with type " << tn + << " in: " << n << std::endl; + + Node cur; + std::vector visit; + + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + + auto it = cache.find(cur); + if (it == cache.end()) + { + cache.emplace(cur, std::make_pair(false, false)); + if (!skip_quant || cur.getKind() != kind::FORALL) + { + visit.push_back(cur); + visit.insert(visit.end(), cur.begin(), cur.end()); + } + } + /* up-traversal */ + else if (!it->second.first) + { + bool found_min_term = false; + + /* Check if we found a minimal term in one of the children. */ + for (const Node& c : cur) + { + found_min_term |= cache[c].second; + if (found_min_term) break; + } + + /* If we haven't found a minimal term yet, add this term if it has the + * right type. */ + if (cur.getType() == tn && !expr::hasBoundVar(cur) && !found_min_term) + { + terms.insert(cur); + found_min_term = true; + Trace("sygus-inst-term") << " found: " << cur << std::endl; + } + + it->second.first = true; + it->second.second = found_min_term; + } + } while (!visit.empty()); +} + +} // namespace + SygusInst::SygusInst(QuantifiersEngine* qe) : QuantifiersModule(qe), d_lemma_cache(qe->getUserContext()), - d_ce_lemma_added(qe->getUserContext()) + d_ce_lemma_added(qe->getUserContext()), + d_global_terms(qe->getUserContext()), + d_notified_assertions(qe->getUserContext()) { } @@ -149,14 +281,79 @@ void SygusInst::registerQuantifier(Node q) std::map> include_cons; std::unordered_set term_irrelevant; - /* Collect extra symbols in 'q' to be used in the grammar. */ - std::unordered_set syms; - expr::getSymbols(q, syms); - for (const TNode& var : syms) + /* Collect relevant local ground terms for each variable type. */ + if (options::sygusInstScope() == options::SygusInstScope::IN + || options::sygusInstScope() == options::SygusInstScope::BOTH) + { + std::unordered_map, + TypeNodeHashFunction> + relevant_terms; + for (const Node& var : q[0]) + { + TypeNode tn = var.getType(); + + /* Collect relevant ground terms for type tn. */ + if (relevant_terms.find(tn) == relevant_terms.end()) + { + std::unordered_set terms; + std::unordered_set cache_max; + std::unordered_map, TNodeHashFunction> + cache_min; + + getMinGroundTerms(q, tn, terms, cache_min); + getMaxGroundTerms(q, tn, terms, cache_max); + relevant_terms.emplace(tn, terms); + } + + /* Add relevant ground terms to grammar. */ + auto& terms = relevant_terms[tn]; + for (const auto& t : terms) + { + TypeNode ttn = t.getType(); + extra_cons[ttn].insert(t); + Trace("sygus-inst") << "Adding (local) extra cons: " << t << std::endl; + } + } + } + + /* Collect relevant global ground terms for each variable type. */ + if (options::sygusInstScope() == options::SygusInstScope::OUT + || options::sygusInstScope() == options::SygusInstScope::BOTH) { - TypeNode tn = var.getType(); - extra_cons[tn].insert(var); - Trace("sygus-inst") << "Found symbol: " << var << std::endl; + for (const Node& var : q[0]) + { + TypeNode tn = var.getType(); + + /* Collect relevant ground terms for type tn. */ + if (d_global_terms.find(tn) == d_global_terms.end()) + { + std::unordered_set terms; + std::unordered_set cache_max; + std::unordered_map, TNodeHashFunction> + cache_min; + + for (const Node& a : d_notified_assertions) + { + getMinGroundTerms(a, tn, terms, cache_min, true); + getMaxGroundTerms(a, tn, terms, cache_max, true); + } + d_global_terms.insert(tn, terms); + } + + /* Add relevant ground terms to grammar. */ + auto it = d_global_terms.find(tn); + if (it != d_global_terms.end()) + { + for (const auto& t : (*it).second) + { + TypeNode ttn = t.getType(); + extra_cons[ttn].insert(t); + Trace("sygus-inst") + << "Adding (global) extra cons: " << t << std::endl; + } + } + } } /* Construct grammar for each bound variable of 'q'. */ @@ -190,6 +387,14 @@ void SygusInst::preRegisterQuantifier(Node q) addCeLemma(q); } +void SygusInst::ppNotifyAssertions(const std::vector& assertions) +{ + for (const Node& a : assertions) + { + d_notified_assertions.insert(a); + } +} + /*****************************************************************************/ /* private methods */ /*****************************************************************************/ diff --git a/src/theory/quantifiers/sygus_inst.h b/src/theory/quantifiers/sygus_inst.h index 2361c4a2b..c95c6a026 100644 --- a/src/theory/quantifiers/sygus_inst.h +++ b/src/theory/quantifiers/sygus_inst.h @@ -82,6 +82,9 @@ class SygusInst : public QuantifiersModule /* Called once for every quantifier 'q' per context. */ void preRegisterQuantifier(Node q) override; + /* For collecting global terms from all available assertions. */ + void ppNotifyAssertions(const std::vector& assertions); + std::string identify() const override { return "SygusInst"; } private: @@ -124,6 +127,15 @@ class SygusInst : public QuantifiersModule /* Indicates whether a counterexample lemma was added for a quantified * formula in the current context. */ context::CDHashSet d_ce_lemma_added; + + /* Set of global ground terms in assertions (outside of quantifiers). */ + context::CDHashMap, + TypeNodeHashFunction> + d_global_terms; + + /* Assertions sent by ppNotifyAssertions. */ + context::CDHashSet d_notified_assertions; }; } // namespace quantifiers diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp index 557d444d6..cceb04d9f 100644 --- a/src/theory/quantifiers_engine.cpp +++ b/src/theory/quantifiers_engine.cpp @@ -370,6 +370,14 @@ void QuantifiersEngine::ppNotifyAssertions( sye->preregisterAssertion(a); } } + /* The SyGuS instantiation module needs a global view of all available + * assertions to collect global terms that get added to each grammar. + */ + if (options::sygusInst()) + { + quantifiers::SygusInst* si = d_qmodules->d_sygus_inst.get(); + si->ppNotifyAssertions(assertions); + } } void QuantifiersEngine::check( Theory::Effort e ){ @@ -976,8 +984,11 @@ void QuantifiersEngine::flushLemmas(){ //take default output channel if none is provided d_hasAddedLemma = true; std::map::iterator itp; - for (const Node& lemw : d_lemmas_waiting) + // Note: Do not use foreach loop here and do not cache size() call. + // New lemmas can be added while iterating over d_lemmas_waiting. + for (size_t i = 0; i < d_lemmas_waiting.size(); ++i) { + const Node& lemw = d_lemmas_waiting[i]; Trace("qe-lemma") << "Lemma : " << lemw << std::endl; itp = d_lemmasWaitingPg.find(lemw); if (itp != d_lemmasWaitingPg.end())