From: Andrew Reynolds Date: Mon, 5 Nov 2018 15:25:33 +0000 (-0600) Subject: Allow partial models with optimized sygus enumeration (#2682) X-Git-Tag: cvc5-1.0.0~4380 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=700ee947a84ee8df9a7a50d44999a48ccc2626d8;p=cvc5.git Allow partial models with optimized sygus enumeration (#2682) --- diff --git a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp index 24770ade0..e8daa4256 100644 --- a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp +++ b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp @@ -608,8 +608,17 @@ bool EnumStreamSubstitution::CombinationState::getNextCombination() } void EnumStreamConcrete::initialize(Node e) { d_ess.initialize(e.getType()); } -void EnumStreamConcrete::addValue(Node v) { d_ess.resetValue(v); } -Node EnumStreamConcrete::getNext() { return d_ess.getNext(); } +void EnumStreamConcrete::addValue(Node v) +{ + d_ess.resetValue(v); + d_currTerm = d_ess.getNext(); +} +bool EnumStreamConcrete::increment() +{ + d_currTerm = d_ess.getNext(); + return !d_currTerm.isNull(); +} +Node EnumStreamConcrete::getCurrent() { return d_currTerm; } } // namespace quantifiers } // namespace theory } // namespace CVC4 diff --git a/src/theory/quantifiers/sygus/enum_stream_substitution.h b/src/theory/quantifiers/sygus/enum_stream_substitution.h index 38fa0627b..476a364ea 100644 --- a/src/theory/quantifiers/sygus/enum_stream_substitution.h +++ b/src/theory/quantifiers/sygus/enum_stream_substitution.h @@ -286,12 +286,16 @@ class EnumStreamConcrete : public EnumValGenerator void initialize(Node e) override; /** get that value v was enumerated */ void addValue(Node v) override; - /** get the next value enumerated by this class */ - Node getNext() override; + /** increment */ + bool increment() override; + /** get the current value enumerated by this class */ + Node getCurrent() override; private: /** stream substitution utility */ EnumStreamSubstitution d_ess; + /** the current term generated by this class */ + Node d_currTerm; }; } // namespace quantifiers diff --git a/src/theory/quantifiers/sygus/sygus_enumerator.cpp b/src/theory/quantifiers/sygus/sygus_enumerator.cpp index a39c9e958..aab580650 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator.cpp @@ -25,11 +25,7 @@ namespace theory { namespace quantifiers { SygusEnumerator::SygusEnumerator(TermDbSygus* tds, SynthConjecture* p) - : d_tds(tds), - d_parent(p), - d_tlEnum(nullptr), - d_abortSize(-1), - d_firstTime(false) + : d_tds(tds), d_parent(p), d_tlEnum(nullptr), d_abortSize(-1) { } @@ -39,7 +35,6 @@ void SygusEnumerator::initialize(Node e) d_etype = d_enum.getType(); d_tlEnum = getMasterEnumForType(d_etype); d_abortSize = options::sygusAbortSize(); - d_firstTime = true; } void SygusEnumerator::addValue(Node v) @@ -47,17 +42,9 @@ void SygusEnumerator::addValue(Node v) // do nothing } -Node SygusEnumerator::getNext() +bool SygusEnumerator::increment() { return d_tlEnum->increment(); } +Node SygusEnumerator::getCurrent() { - if (d_firstTime) - { - d_firstTime = false; - } - else if (!d_tlEnum->increment()) - { - // no more values - return Node::null(); - } if (d_abortSize >= 0) { int cs = static_cast(d_tlEnum->getCurrentSize()); @@ -70,8 +57,12 @@ Node SygusEnumerator::getNext() } } Node ret = d_tlEnum->getCurrent(); - Trace("sygus-enum") << "Enumerate : " << d_tds->sygusToBuiltin(ret) - << std::endl; + if (Trace.isOn("sygus-enum")) + { + Trace("sygus-enum") << "Enumerate : "; + TermDbSygus::toStreamSygus("sygus-enum", ret); + Trace("sygus-enum") << std::endl; + } return ret; } @@ -241,6 +232,7 @@ bool SygusEnumerator::TermCache::addTerm(Node n) d_terms.push_back(n); return true; } + Assert(!n.isNull()); if (options::sygusSymBreakDynamic()) { Node bn = d_tds->sygusToBuiltin(n); @@ -387,7 +379,7 @@ bool SygusEnumerator::TermEnumSlave::validateIndex() Trace("sygus-enum-debug2") << "slave(" << d_tn << ") : validate index...\n"; SygusEnumerator::TermCache& tc = d_se->d_tcache[d_tn]; // ensure that index is in the range - if (d_index >= tc.getNumTerms()) + while (d_index >= tc.getNumTerms()) { Assert(d_index == tc.getNumTerms()); Trace("sygus-enum-debug2") << "slave(" << d_tn << ") : force master...\n"; @@ -497,6 +489,7 @@ SygusEnumerator::TermEnum* SygusEnumerator::getMasterEnumForType(TypeNode tn) SygusEnumerator::TermEnumMaster::TermEnumMaster() : TermEnum(), d_isIncrementing(false), + d_currTermSet(false), d_consClassNum(0), d_ccWeight(0), d_consNum(0), @@ -518,6 +511,7 @@ bool SygusEnumerator::TermEnumMaster::initialize(SygusEnumerator* se, d_currChildSize = 0; d_ccCons.clear(); d_isIncrementing = false; + d_currTermSet = false; bool ret = increment(); Trace("sygus-enum-debug") << "master(" << tn << "): finish init, ret = " << ret << "\n"; @@ -526,10 +520,11 @@ bool SygusEnumerator::TermEnumMaster::initialize(SygusEnumerator* se, Node SygusEnumerator::TermEnumMaster::getCurrent() { - if (!d_currTerm.isNull()) + if (d_currTermSet) { return d_currTerm; } + d_currTermSet = true; // construct based on the children std::vector children; const Datatype& dt = d_tn.getDatatype(); @@ -541,7 +536,13 @@ Node SygusEnumerator::TermEnumMaster::getCurrent() for (unsigned i = 0, nargs = dt[cnum].getNumArgs(); i < nargs; i++) { Assert(d_children.find(i) != d_children.end()); - children.push_back(d_children[i].getCurrent()); + Node cc = d_children[i].getCurrent(); + if (cc.isNull()) + { + d_currTerm = cc; + return cc; + } + children.push_back(cc); } d_currTerm = NodeManager::currentNM()->mkNode(APPLY_CONSTRUCTOR, children); return d_currTerm; @@ -708,21 +709,27 @@ bool SygusEnumerator::TermEnumMaster::incrementInternal() Assert(d_childrenValid == d_ccTypes.size()); // do we have more constructors for the given children? - while (d_consNum < d_ccCons.size()) + if (d_consNum < d_ccCons.size()) { Trace("sygus-enum-debug2") << "master(" << d_tn << "): try constructor " << d_consNum << std::endl; // increment constructor index // we will build for the current constructor and the given children d_consNum++; + d_currTermSet = false; d_currTerm = Node::null(); Node c = getCurrent(); - if (tc.addTerm(c)) + if (!c.isNull()) { - return true; + if (!tc.addTerm(c)) + { + // the term was not unique based on rewriting + Trace("sygus-enum-debug2") << "master(" << d_tn + << "): failed addTerm\n"; + d_currTerm = Node::null(); + } } - Trace("sygus-enum-debug2") << "master(" << d_tn << "): failed addTerm\n"; - // the term was not unique based on rewriting + return true; } // finished constructors for this set of children, must increment children diff --git a/src/theory/quantifiers/sygus/sygus_enumerator.h b/src/theory/quantifiers/sygus/sygus_enumerator.h index 10a44da03..28f8f4126 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator.h +++ b/src/theory/quantifiers/sygus/sygus_enumerator.h @@ -50,8 +50,10 @@ class SygusEnumerator : public EnumValGenerator void initialize(Node e) override; /** Inform this generator that abstract value v was enumerated. */ void addValue(Node v) override; + /** increment */ + bool increment() override; /** Get the next concrete value generated by this class. */ - Node getNext() override; + Node getCurrent() override; private: /** pointer to term database sygus */ @@ -322,6 +324,8 @@ class SygusEnumerator : public EnumValGenerator bool d_isIncrementing; /** cache for getCurrent() */ Node d_currTerm; + /** is d_currTerm set */ + bool d_currTermSet; //----------------------------- current constructor class information /** the next constructor class we are using */ unsigned d_consClassNum; @@ -429,8 +433,6 @@ class SygusEnumerator : public EnumValGenerator TermEnum* d_tlEnum; /** the abort size, caches the value of --sygus-abort-size */ int d_abortSize; - /** this flag is true for the first time to getNext() after initialize(e) */ - bool d_firstTime; /** get master enumerator for type tn */ TermEnum* getMasterEnumForType(TypeNode tn); }; diff --git a/src/theory/quantifiers/sygus/synth_conjecture.cpp b/src/theory/quantifiers/sygus/synth_conjecture.cpp index e668b2206..03344d2e7 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.cpp +++ b/src/theory/quantifiers/sygus/synth_conjecture.cpp @@ -346,14 +346,17 @@ bool SynthConjecture::doCheck(std::vector& lems) { // get the model value of the relevant terms from the master module std::vector enum_values; - bool fullModel = getEnumeratedValues(terms, enum_values); + bool activeIncomplete = false; + bool fullModel = getEnumeratedValues(terms, enum_values, activeIncomplete); // if the master requires a full model and the model is partial, we fail if (!d_master->allowPartialModel() && !fullModel) { // we retain the values in d_ev_active_gen_waiting Trace("cegqi-engine") << "...partial model, fail." << std::endl; - return true; + // if we are partial due to an active enumerator, we may still succeed + // on the next call + return !activeIncomplete; } // the waiting values are passed to the module below, clear d_ev_active_gen_waiting.clear(); @@ -396,7 +399,7 @@ bool SynthConjecture::doCheck(std::vector& lems) if (emptyModel) { Trace("cegqi-engine") << "...empty model, fail." << std::endl; - return true; + return !activeIncomplete; } Assert(candidate_values.empty()); constructed_cand = d_master->constructCandidates( @@ -650,7 +653,8 @@ void SynthConjecture::preregisterConjecture(Node q) } bool SynthConjecture::getEnumeratedValues(std::vector& n, - std::vector& v) + std::vector& v, + bool& activeIncomplete) { std::vector ncheck = n; n.clear(); @@ -670,7 +674,7 @@ bool SynthConjecture::getEnumeratedValues(std::vector& n, continue; } } - Node nv = getEnumeratedValue(e); + Node nv = getEnumeratedValue(e, activeIncomplete); n.push_back(e); v.push_back(nv); ret = ret && !nv.isNull(); @@ -692,42 +696,50 @@ class EnumValGeneratorBasic : public EnumValGenerator /** initialize (do nothing) */ void initialize(Node e) override {} /** initialize (do nothing) */ - void addValue(Node v) override {} + void addValue(Node v) override { d_currTerm = *d_te; } /** * Get next returns the next (T-rewriter-unique) value based on the type * enumerator. */ - Node getNext() override + bool increment() override { if (d_te.isFinished()) { - return Node::null(); + d_currTerm = Node::null(); + return false; } - Node next = *d_te; + d_currTerm = *d_te; ++d_te; if (options::sygusSymBreakDynamic()) { - Node nextb = d_tds->sygusToBuiltin(next); + Node nextb = d_tds->sygusToBuiltin(d_currTerm); nextb = d_tds->getExtRewriter()->extendedRewrite(nextb); - if (d_cache.find(nextb) != d_cache.end()) + if (d_cache.find(nextb) == d_cache.end()) + { + d_cache.insert(nextb); + // only return the current term if not unique + } + else { - return getNext(); + d_currTerm = Node::null(); } - d_cache.insert(nextb); } - return next; + return true; } - + /** get the current term */ + Node getCurrent() override { return d_currTerm; } private: /** pointer to term database sygus */ TermDbSygus* d_tds; /** the type enumerator */ TypeEnumerator d_te; + /** the current term */ + Node d_currTerm; /** cache of (enumerated) builtin values we have enumerated so far */ std::unordered_set d_cache; }; -Node SynthConjecture::getEnumeratedValue(Node e) +Node SynthConjecture::getEnumeratedValue(Node e, bool& activeIncomplete) { bool isEnum = d_tds->isEnumerator(e); @@ -790,6 +802,7 @@ Node SynthConjecture::getEnumeratedValue(Node e) // Check if there is an (abstract) value absE we were actively generating // values based on. Node absE = d_ev_curr_active_gen[e]; + bool firstTime = false; if (absE.isNull()) { // None currently exist. The next abstract value is the model value for e. @@ -804,9 +817,22 @@ Node SynthConjecture::getEnumeratedValue(Node e) } d_ev_curr_active_gen[e] = absE; iteg->second->addValue(absE); + firstTime = true; + } + bool inc = true; + if (!firstTime) + { + inc = iteg->second->increment(); + } + Node v; + if (inc) + { + v = iteg->second->getCurrent(); } - Node v = iteg->second->getNext(); - if (v.isNull()) + Trace("sygus-active-gen-debug") << "...generated " << v + << ", with increment success : " << inc + << std::endl; + if (!inc) { // No more concrete values generated from absE. NodeManager* nm = NodeManager::currentNM(); @@ -852,7 +878,14 @@ Node SynthConjecture::getEnumeratedValue(Node e) else { // We are waiting to send e -> v to the module that requested it. - d_ev_active_gen_waiting[e] = v; + if (v.isNull()) + { + activeIncomplete = true; + } + else + { + d_ev_active_gen_waiting[e] = v; + } if (Trace.isOn("sygus-active-gen")) { Trace("sygus-active-gen") << "Active-gen : " << e << " : "; diff --git a/src/theory/quantifiers/sygus/synth_conjecture.h b/src/theory/quantifiers/sygus/synth_conjecture.h index ef1b4459f..3a43eb83d 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.h +++ b/src/theory/quantifiers/sygus/synth_conjecture.h @@ -49,8 +49,14 @@ class EnumValGenerator virtual void initialize(Node e) = 0; /** Inform this generator that abstract value v was enumerated. */ virtual void addValue(Node v) = 0; - /** Get the next concrete value generated by this class. */ - virtual Node getNext() = 0; + /** + * Increment this value generator. If this returns false, then we are out of + * values. If this returns true, getCurrent(), if non-null, returns the + * current term. + */ + virtual bool increment() = 0; + /** Get the current concrete value generated by this class. */ + virtual Node getCurrent() = 0; }; /** a synthesis conjecture @@ -193,15 +199,25 @@ class SynthConjecture * Get model values for terms n, store in vector v. This method returns true * if and only if all values added to v are non-null. * + * The argument activeIncomplete indicates whether n contains an active + * enumerator that is currently not finished enumerating values, but returned + * null on a call to getEnumeratedValue. This value is used for determining + * whether we should call getEnumeratedValues again within a call to + * SynthConjecture::check. + * * It removes terms from n that correspond to "inactive" enumerators, that * is, enumerators whose values have been exhausted. */ - bool getEnumeratedValues(std::vector& n, std::vector& v); + bool getEnumeratedValues(std::vector& n, + std::vector& v, + bool& activeIncomplete); /** * Get model value for term n. If n has a value that was excluded by - * datatypes sygus symmetry breaking, this method returns null. + * datatypes sygus symmetry breaking, this method returns null. It sets + * activeIncomplete to true if there is an actively-generated enumerator whose + * current value is null but it has not finished generating values. */ - Node getEnumeratedValue(Node n); + Node getEnumeratedValue(Node n, bool& activeIncomplete); /** enumerator generators for each actively-generated enumerator */ std::map > d_evg; /** diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index 435e1a00f..9af990086 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -756,9 +756,16 @@ void TermDbSygus::toStreamSygus(const char* c, Node n) { if (Trace.isOn(c)) { - std::stringstream ss; - Printer::getPrinter(options::outputLanguage())->toStreamSygus(ss, n); - Trace(c) << ss.str(); + if (n.isNull()) + { + Trace(c) << n; + } + else + { + std::stringstream ss; + Printer::getPrinter(options::outputLanguage())->toStreamSygus(ss, n); + Trace(c) << ss.str(); + } } } diff --git a/test/regress/regress1/sygus/commutative-stream.sy b/test/regress/regress1/sygus/commutative-stream.sy index 7b96a2bf3..8203fd9cf 100644 --- a/test/regress/regress1/sygus/commutative-stream.sy +++ b/test/regress/regress1/sygus/commutative-stream.sy @@ -3,7 +3,7 @@ ; EXPECT: (error "Maximum term size (2) for enumerative SyGuS exceeded.") ; EXIT: 1 -; COMMAND-LINE: --sygus-stream --sygus-abort-size=2 --decision=justification +; COMMAND-LINE: --sygus-stream --sygus-abort-size=2 --sygus-active-gen=none --decision=justification (set-logic LIA)