From 223155cfb300458f534f4be6b88e5fdc17b0ff14 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Thu, 11 Mar 2021 02:44:30 -0600 Subject: [PATCH] Direct lemmas and inference ids for sygus extension (#5960) This adds inference ID for the datatypes sygus solver, and changes its style to send lemmas instead of passing them to the caller. --- src/theory/datatypes/inference_manager.cpp | 14 -- src/theory/datatypes/inference_manager.h | 5 - src/theory/datatypes/sygus_extension.cpp | 245 +++++++++++---------- src/theory/datatypes/sygus_extension.h | 79 ++++--- src/theory/datatypes/theory_datatypes.cpp | 16 +- src/theory/inference_id.cpp | 21 ++ src/theory/inference_id.h | 31 +++ 7 files changed, 224 insertions(+), 187 deletions(-) diff --git a/src/theory/datatypes/inference_manager.cpp b/src/theory/datatypes/inference_manager.cpp index 4d0dd4998..1a837e73b 100644 --- a/src/theory/datatypes/inference_manager.cpp +++ b/src/theory/datatypes/inference_manager.cpp @@ -96,20 +96,6 @@ void InferenceManager::sendDtConflict(const std::vector& conf, InferenceId conflictExp(id, conf, d_ipc.get()); } -bool InferenceManager::sendLemmas(const std::vector& lemmas, - InferenceId id) -{ - bool ret = false; - for (const Node& lem : lemmas) - { - if (lemma(lem, id)) - { - ret = true; - } - } - return ret; -} - bool InferenceManager::isProofEnabled() const { return d_ipc != nullptr; } TrustNode InferenceManager::processDtLemma(Node conc, Node exp, InferenceId id) diff --git a/src/theory/datatypes/inference_manager.h b/src/theory/datatypes/inference_manager.h index 110747043..83876817b 100644 --- a/src/theory/datatypes/inference_manager.h +++ b/src/theory/datatypes/inference_manager.h @@ -72,11 +72,6 @@ class InferenceManager : public InferenceManagerBuffered * Send conflict immediately on the output channel */ void sendDtConflict(const std::vector& conf, InferenceId id); - /** - * Send lemmas with property NONE on the output channel immediately. - * Returns true if any lemma was sent. - */ - bool sendLemmas(const std::vector& lemmas, InferenceId id); private: /** Are proofs enabled? */ diff --git a/src/theory/datatypes/sygus_extension.cpp b/src/theory/datatypes/sygus_extension.cpp index e75d1005f..9a952601d 100644 --- a/src/theory/datatypes/sygus_extension.cpp +++ b/src/theory/datatypes/sygus_extension.cpp @@ -23,10 +23,11 @@ #include "options/quantifiers_options.h" #include "printer/printer.h" #include "smt/logic_exception.h" +#include "theory/datatypes/inference_manager.h" #include "theory/datatypes/sygus_datatype_utils.h" #include "theory/datatypes/theory_datatypes_utils.h" -#include "theory/quantifiers/sygus/synth_conjecture.h" #include "theory/quantifiers/sygus/sygus_explain.h" +#include "theory/quantifiers/sygus/synth_conjecture.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" @@ -61,8 +62,9 @@ SygusExtension::~SygusExtension() { } /** add tester */ -void SygusExtension::assertTester( int tindex, TNode n, Node exp, std::vector< Node >& lemmas ) { - registerTerm( n, lemmas ); +void SygusExtension::assertTester(int tindex, TNode n, Node exp) +{ + registerTerm(n); // check if this is a relevant (sygus) term if( d_term_to_anchor.find( n )!=d_term_to_anchor.end() ){ Trace("sygus-sb-debug2") << "Sygus : process tester : " << exp << std::endl; @@ -95,7 +97,7 @@ void SygusExtension::assertTester( int tindex, TNode n, Node exp, std::vector< N } } if( do_add ){ - assertTesterInternal( tindex, n, exp, lemmas ); + assertTesterInternal(tindex, n, exp); }else{ Trace("sygus-sb-debug2") << "...ignore inactive tester : " << exp << std::endl; } @@ -107,7 +109,8 @@ void SygusExtension::assertTester( int tindex, TNode n, Node exp, std::vector< N } } -void SygusExtension::assertFact( Node n, bool polarity, std::vector< Node >& lemmas ) { +void SygusExtension::assertFact(Node n, bool polarity) +{ if (n.getKind() == kind::DT_SYGUS_BOUND) { Node m = n[0]; @@ -118,14 +121,14 @@ void SygusExtension::assertFact( Node n, bool polarity, std::vector< Node >& lem std::map>::iterator its = d_szinfo.find(m); Assert(its != d_szinfo.end()); - Node mt = its->second->getOrMkMeasureValue(lemmas); + Node mt = its->second->getOrMkMeasureValue(); //it relates the measure term to arithmetic Node blem = n.eqNode( NodeManager::currentNM()->mkNode( kind::LEQ, mt, n[1] ) ); - lemmas.push_back( blem ); + d_im.lemma(blem, InferenceId::DATATYPES_SYGUS_FAIR_SIZE); } if( polarity ){ - unsigned s = n[1].getConst().getNumerator().toUnsignedInt(); - notifySearchSize( m, s, n, lemmas ); + uint64_t s = n[1].getConst().getNumerator().toUnsignedInt(); + notifySearchSize(m, s, n); } }else if( n.getKind() == kind::DT_HEIGHT_BOUND || n.getKind()==DT_SIZE_BOUND ){ //reduce to arithmetic TODO ? @@ -142,7 +145,8 @@ Node SygusExtension::getTermOrderPredicate( Node n1, Node n2 ) { return szGeq; } -void SygusExtension::registerTerm( Node n, std::vector< Node >& lemmas ) { +void SygusExtension::registerTerm(Node n) +{ if( d_is_top_level.find( n )==d_is_top_level.end() ){ d_is_top_level[n] = false; TypeNode tn = n.getType(); @@ -150,7 +154,7 @@ void SygusExtension::registerTerm( Node n, std::vector< Node >& lemmas ) { bool is_top_level = false; bool success = false; if( n.getKind()==kind::APPLY_SELECTOR_TOTAL ){ - registerTerm( n[0], lemmas ); + registerTerm(n[0]); std::unordered_map::iterator it = d_term_to_anchor.find(n[0]); if( it!=d_term_to_anchor.end() ) { @@ -162,7 +166,7 @@ void SygusExtension::registerTerm( Node n, std::vector< Node >& lemmas ) { success = true; } }else if( n.isVar() ){ - registerSizeTerm( n, lemmas ); + registerSizeTerm(n); if( d_register_st[n] ){ d_term_to_anchor[n] = n; d_anchor_to_conj[n] = d_tds->getConjectureForEnumerator(n); @@ -179,7 +183,7 @@ void SygusExtension::registerTerm( Node n, std::vector< Node >& lemmas ) { << ", type = " << tn.getDType().getName() << std::endl; d_term_to_depth[n] = d; d_is_top_level[n] = is_top_level; - registerSearchTerm( tn, d, n, is_top_level, lemmas ); + registerSearchTerm(tn, d, n, is_top_level); }else{ Trace("sygus-sb-debug2") << "Term " << n << " is not part of sygus search." << std::endl; } @@ -196,7 +200,8 @@ bool SygusExtension::computeTopLevel( TypeNode tn, Node n ){ } } -void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::vector< Node >& lemmas ) { +void SygusExtension::assertTesterInternal(int tindex, TNode n, Node exp) +{ TypeNode ntn = n.getType(); if (!ntn.isDatatype()) { @@ -263,7 +268,8 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v conflict.push_back( itsz->second->d_search_size_exp[ssz] ); Node conf = NodeManager::currentNM()->mkNode( kind::AND, conflict ); Trace("sygus-sb-fair") << "Conflict is : " << conf << std::endl; - lemmas.push_back( conf.negate() ); + Node confn = conf.negate(); + d_im.lemma(confn, InferenceId::DATATYPES_SYGUS_FAIR_SIZE_CONFLICT); return; } } @@ -275,7 +281,7 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v //Assert( d<=ssz ); if( options::sygusSymBreakLazy() ){ // dynamic symmetry breaking - addSymBreakLemmasFor( ntn, n, d, lemmas ); + addSymBreakLemmasFor(ntn, n, d); } Trace("sygus-sb-debug") << "Get simple symmetry breaking predicates...\n"; @@ -284,7 +290,7 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v NodeManager* nm = NodeManager::currentNM(); if( min_depth<=max_depth ){ TNode x = getFreeVar( ntn ); - std::vector sb_lemmas; + std::vector> sbLemmas; // symmetry breaking lemmas requiring predicate elimination std::map sb_elim_pred; bool usingSymCons = d_tds->usingSymbolicConsForEnumerator(m); @@ -297,7 +303,8 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v m, ntn, tindex, ds, usingSymCons, isVarAgnostic); if (!ipred.isNull()) { - sb_lemmas.push_back(ipred); + sbLemmas.emplace_back(ipred, + InferenceId::DATATYPES_SYGUS_SIMPLE_SYM_BREAK); if (ds == 0 && isVarAgnostic) { sb_elim_pred[ipred] = true; @@ -317,7 +324,8 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v conj->getSymmetryBreakingPredicate(x, a, ntn, tindex, ds); if (!dpred.isNull()) { - sb_lemmas.push_back(dpred); + sbLemmas.emplace_back(dpred, + InferenceId::DATATYPES_SYGUS_CDEP_SYM_BREAK); } } } @@ -326,8 +334,9 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v // add the above symmetry breaking predicates to lemmas std::unordered_map cache; Node rlv = getRelevancyCondition(n); - for (const Node& slem : sb_lemmas) + for (std::pair& sbl : sbLemmas) { + Node slem = sbl.first; Node sslem = slem.substitute(x, n, cache); // if we require predicate elimination if (sb_elim_pred.find(slem) != sb_elim_pred.end()) @@ -342,7 +351,7 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v { sslem = nm->mkNode(OR, rlv, sslem); } - lemmas.push_back(sslem); + d_im.lemma(sslem, sbl.second); } } d_simple_proc[exp] = max_depth + 1; @@ -359,7 +368,7 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v IntMap::const_iterator itt = d_testers.find( sel ); if( itt != d_testers.end() ){ Assert(d_testers_exp.find(sel) != d_testers_exp.end()); - assertTesterInternal( (*itt).second, sel, d_testers_exp[sel], lemmas ); + assertTesterInternal((*itt).second, sel, d_testers_exp[sel]); } } Trace("sygus-sb-debug") << "...finished" << std::endl; @@ -930,7 +939,11 @@ TNode SygusExtension::getFreeVar( TypeNode tn ) { return d_tds->getFreeVar(tn, 0); } -void SygusExtension::registerSearchTerm( TypeNode tn, unsigned d, Node n, bool topLevel, std::vector< Node >& lemmas ) { +void SygusExtension::registerSearchTerm(TypeNode tn, + unsigned d, + Node n, + bool topLevel) +{ //register this term std::unordered_map::iterator ita = d_term_to_anchor.find(n); @@ -945,7 +958,7 @@ void SygusExtension::registerSearchTerm( TypeNode tn, unsigned d, Node n, bool t Trace("sygus-sb-debug") << " register search term : " << n << " at depth " << d << ", type=" << tn << ", tl=" << topLevel << std::endl; sca.d_search_terms[tn][d].push_back(n); if( !options::sygusSymBreakLazy() ){ - addSymBreakLemmasFor( tn, n, d, lemmas ); + addSymBreakLemmasFor(tn, n, d); } } } @@ -954,7 +967,6 @@ Node SygusExtension::registerSearchValue(Node a, Node n, Node nv, unsigned d, - std::vector& lemmas, bool isVarAgnostic, bool doSym) { @@ -989,7 +1001,6 @@ Node SygusExtension::registerSearchValue(Node a, sel, nv[i], d + 1, - lemmas, isVarAgnostic, doSym && (!isVarAgnostic || i == 0)); if (nvc.isNull()) @@ -1033,8 +1044,7 @@ Node SygusExtension::registerSearchValue(Node a, quantifiers::DivByZeroSygusInvarianceTest dbzet; Trace("sygus-sb-mexp-debug") << "Minimize explanation for div-by-zero in " << bv << std::endl; - registerSymBreakLemmaForValue( - a, nv, dbzet, Node::null(), var_count, lemmas); + registerSymBreakLemmaForValue(a, nv, dbzet, Node::null(), var_count); return Node::null(); }else{ std::unordered_map& scasv = @@ -1143,8 +1153,7 @@ Node SygusExtension::registerSearchValue(Node a, eset.init(d_tds, tn, aconj, a, bvr); Trace("sygus-sb-mexp-debug") << "Minimize explanation for eval[" << d_tds->sygusToBuiltin( bad_val ) << "] = " << bvr << std::endl; - registerSymBreakLemmaForValue( - a, bad_val, eset, bad_val_o, var_count, lemmas); + registerSymBreakLemmaForValue(a, bad_val, eset, bad_val_o, var_count); // other generalization criteria go here @@ -1166,8 +1175,7 @@ void SygusExtension::registerSymBreakLemmaForValue( Node val, quantifiers::SygusInvarianceTest& et, Node valr, - std::map& var_count, - std::vector& lemmas) + std::map& var_count) { TypeNode tn = val.getType(); Node x = getFreeVar(tn); @@ -1179,10 +1187,14 @@ void SygusExtension::registerSymBreakLemmaForValue( lem = lem.negate(); Trace("sygus-sb-exc") << " ........exc lemma is " << lem << ", size = " << sz << std::endl; - registerSymBreakLemma(tn, lem, sz, a, lemmas); + registerSymBreakLemma(tn, lem, sz, a); } -void SygusExtension::registerSymBreakLemma( TypeNode tn, Node lem, unsigned sz, Node a, std::vector< Node >& lemmas ) { +void SygusExtension::registerSymBreakLemma(TypeNode tn, + Node lem, + unsigned sz, + Node a) +{ // lem holds for all terms of type tn, and is applicable to terms of size sz Trace("sygus-sb-debug") << " register sym break lemma : " << lem << std::endl; @@ -1191,7 +1203,7 @@ void SygusExtension::registerSymBreakLemma( TypeNode tn, Node lem, unsigned sz, Trace("sygus-sb-debug") << " size : " << sz << std::endl; Assert(!a.isNull()); SearchCache& sca = d_cache[a]; - sca.d_sb_lemmas[tn][sz].push_back(lem); + sca.d_sbLemmas[tn][sz].push_back(lem); TNode x = getFreeVar( tn ); unsigned csz = getSearchSizeForAnchor( a ); int max_depth = ((int)csz)-((int)sz); @@ -1212,42 +1224,49 @@ void SygusExtension::registerSymBreakLemma( TypeNode tn, Node lem, unsigned sz, { slem = nm->mkNode(OR, rlv, slem); } - lemmas.push_back(slem); + d_im.lemma(slem, InferenceId::DATATYPES_SYGUS_SYM_BREAK); } } } } } -void SygusExtension::addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, std::vector< Node >& lemmas ) { +void SygusExtension::addSymBreakLemmasFor(TypeNode tn, TNode t, unsigned d) +{ Assert(d_term_to_anchor.find(t) != d_term_to_anchor.end()); Node a = d_term_to_anchor[t]; - addSymBreakLemmasFor( tn, t, d, a, lemmas ); + addSymBreakLemmasFor(tn, t, d, a); } -void SygusExtension::addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, Node a, std::vector< Node >& lemmas ) { +void SygusExtension::addSymBreakLemmasFor(TypeNode tn, + TNode t, + unsigned d, + Node a) +{ Assert(t.getType() == tn); Assert(!a.isNull()); Trace("sygus-sb-debug2") << "add sym break lemmas for " << t << " " << d << " " << a << std::endl; SearchCache& sca = d_cache[a]; - std::map>>::iterator its = - sca.d_sb_lemmas.find(tn); + std::map>>::iterator its = + sca.d_sbLemmas.find(tn); Node rlv = getRelevancyCondition(t); NodeManager* nm = NodeManager::currentNM(); - if (its != sca.d_sb_lemmas.end()) + if (its != sca.d_sbLemmas.end()) { TNode x = getFreeVar( tn ); //get symmetry breaking lemmas for this term unsigned csz = getSearchSizeForAnchor( a ); - int max_sz = ((int)csz) - ((int)d); + uint64_t max_sz = d > csz ? 0 : (csz - d); Trace("sygus-sb-debug2") << "add lemmas up to size " << max_sz << ", which is (search_size) " << csz << " - (depth) " << d << std::endl; std::unordered_map cache; - for( std::map< unsigned, std::vector< Node > >::iterator it = its->second.begin(); it != its->second.end(); ++it ){ - if( (int)it->first<=max_sz ){ - for (const Node& lem : it->second) + for (std::pair>& sbls : its->second) + { + if (sbls.first <= max_sz) + { + for (const Node& lem : sbls.second) { Node slem = lem.substitute(x, t, cache); // add the relevancy condition for t @@ -1255,7 +1274,7 @@ void SygusExtension::addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, Node { slem = nm->mkNode(OR, rlv, slem); } - lemmas.push_back(slem); + d_im.lemma(slem, InferenceId::DATATYPES_SYGUS_SYM_BREAK); } } } @@ -1263,14 +1282,15 @@ void SygusExtension::addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, Node Trace("sygus-sb-debug2") << "...finished." << std::endl; } -void SygusExtension::preRegisterTerm( TNode n, std::vector< Node >& lemmas ) { +void SygusExtension::preRegisterTerm(TNode n) +{ if( n.isVar() ){ Trace("sygus-sb-debug") << "Pre-register variable : " << n << std::endl; - registerSizeTerm( n, lemmas ); + registerSizeTerm(n); } } -void SygusExtension::registerSizeTerm(Node e, std::vector& lemmas) +void SygusExtension::registerSizeTerm(Node e) { if (d_register_st.find(e) != d_register_st.end()) { @@ -1344,15 +1364,15 @@ void SygusExtension::registerSizeTerm(Node e, std::vector& lemmas) if (options::sygusFairMax()) { Node ds = nm->mkNode(DT_SIZE, e); - slem = nm->mkNode(LEQ, ds, d_szinfo[m]->getOrMkMeasureValue(lemmas)); + slem = nm->mkNode(LEQ, ds, d_szinfo[m]->getOrMkMeasureValue()); }else{ - Node mt = d_szinfo[m]->getOrMkActiveMeasureValue(lemmas); - Node new_mt = d_szinfo[m]->getOrMkActiveMeasureValue(lemmas, true); + Node mt = d_szinfo[m]->getOrMkActiveMeasureValue(); + Node new_mt = d_szinfo[m]->getOrMkActiveMeasureValue(true); Node ds = nm->mkNode(DT_SIZE, e); slem = mt.eqNode(nm->mkNode(PLUS, new_mt, ds)); } Trace("sygus-sb") << "...size lemma : " << slem << std::endl; - lemmas.push_back(slem); + d_im.lemma(slem, InferenceId::DATATYPES_SYGUS_MT_BOUND); } if (d_tds->isVariableAgnosticEnumerator(e)) { @@ -1380,7 +1400,7 @@ void SygusExtension::registerSizeTerm(Node e, std::vector& lemmas) Trace("sygus-sb") << "...variable order : " << preNoVarProc << std::endl; Trace("sygus-sb-tp") << "...variable order : " << preNoVarProc << std::endl; - lemmas.push_back(preNoVarProc); + d_im.lemma(preNoVarProc, InferenceId::DATATYPES_SYGUS_VAR_AGNOSTIC); } } } @@ -1390,15 +1410,15 @@ void SygusExtension::registerMeasureTerm( Node m ) { d_szinfo.find(m); if( it==d_szinfo.end() ){ Trace("sygus-sb") << "Sygus : register measure term : " << m << std::endl; - d_szinfo[m].reset(new SygusSizeDecisionStrategy( - m, d_state.getSatContext(), d_state.getValuation())); + d_szinfo[m].reset(new SygusSizeDecisionStrategy(d_im, m, d_state)); // register this as a decision strategy d_dm->registerStrategy(DecisionManager::STRAT_DT_SYGUS_ENUM_SIZE, d_szinfo[m].get()); } } -void SygusExtension::notifySearchSize( Node m, unsigned s, Node exp, std::vector< Node >& lemmas ) { +void SygusExtension::notifySearchSize(TNode m, uint64_t s, Node exp) +{ std::map>::iterator its = d_szinfo.find(m); Assert(its != d_szinfo.end()); @@ -1411,19 +1431,9 @@ void SygusExtension::notifySearchSize( Node m, unsigned s, Node exp, std::vector Trace("sygus-fair") << "SygusExtension:: now considering term measure : " << s << " for " << m << std::endl; Assert(s >= its->second->d_curr_search_size); while( s>its->second->d_curr_search_size ){ - incrementCurrentSearchSize( m, lemmas ); + incrementCurrentSearchSize(m); } Trace("sygus-fair") << "...finish increment for term measure : " << s << std::endl; - /* - //re-add all testers (some may now be relevant) TODO - for( IntMap::const_iterator it = d_testers.begin(); it != d_testers.end(); - ++it ){ Node n = (*it).first; NodeMap::const_iterator itx = - d_testers_exp.find( n ); if( itx!=d_testers_exp.end() ){ int tindex = - (*it).second; Node exp = (*itx).second; assertTester( tindex, n, exp, lemmas - ); }else{ Assert( false ); - } - } - */ } } @@ -1450,8 +1460,9 @@ unsigned SygusExtension::getSearchSizeForMeasureTerm(Node m) Assert(its != d_szinfo.end()); return its->second->d_curr_search_size; } - -void SygusExtension::incrementCurrentSearchSize( Node m, std::vector< Node >& lemmas ) { + +void SygusExtension::incrementCurrentSearchSize(TNode m) +{ std::map>::iterator itsz = d_szinfo.find(m); Assert(itsz != d_szinfo.end()); @@ -1464,12 +1475,14 @@ void SygusExtension::incrementCurrentSearchSize( Node m, std::vector< Node >& le // check whether a is bounded by m Assert(d_anchor_to_measure_term.find(a) != d_anchor_to_measure_term.end()); if( d_anchor_to_measure_term[a]==m ){ - for( std::map< TypeNode, std::map< unsigned, std::vector< Node > > >::iterator its = itc->second.d_sb_lemmas.begin(); - its != itc->second.d_sb_lemmas.end(); ++its ){ - TypeNode tn = its->first; + for (std::pair>>& + sbl : itc->second.d_sbLemmas) + { + TypeNode tn = sbl.first; TNode x = getFreeVar( tn ); - for( std::map< unsigned, std::vector< Node > >::iterator it = its->second.begin(); it != its->second.end(); ++it ){ - unsigned sz = it->first; + for (std::pair>& s : sbl.second) + { + unsigned sz = s.first; int new_depth = ((int)itsz->second->d_curr_search_size) - ((int)sz); std::map< unsigned, std::vector< Node > >::iterator itt = itc->second.d_search_terms[tn].find( new_depth ); if( itt!=itc->second.d_search_terms[tn].end() ){ @@ -1477,18 +1490,18 @@ void SygusExtension::incrementCurrentSearchSize( Node m, std::vector< Node >& le { if (!options::sygusSymBreakLazy() || (d_active_terms.find(t) != d_active_terms.end() - && !it->second.empty())) + && !s.second.empty())) { Node rlv = getRelevancyCondition(t); std::unordered_map cache; - for (const Node& lem : it->second) + for (const Node& lem : s.second) { Node slem = lem.substitute(x, t, cache); if (!rlv.isNull()) { slem = nm->mkNode(OR, rlv, slem); } - lemmas.push_back(slem); + d_im.lemma(slem, InferenceId::DATATYPES_SYGUS_SYM_BREAK); } } } @@ -1499,9 +1512,13 @@ void SygusExtension::incrementCurrentSearchSize( Node m, std::vector< Node >& le } } -void SygusExtension::check( std::vector< Node >& lemmas ) { +void SygusExtension::check() +{ Trace("sygus-sb") << "SygusExtension::check" << std::endl; + // reset the count of lemmas sent + d_im.reset(); + // check for externally registered symmetry breaking lemmas std::vector anchors; if (d_tds->hasSymBreakLemmas(anchors)) @@ -1527,21 +1544,21 @@ void SygusExtension::check( std::vector< Node >& lemmas ) { // register the lemma template TypeNode tn = d_tds->getTypeForSymBreakLemma(lem); unsigned sz = d_tds->getSizeForSymBreakLemma(lem); - registerSymBreakLemma(tn, lem, sz, a, lemmas); + registerSymBreakLemma(tn, lem, sz, a); } else { Trace("dt-sygus-debug") << "DT sym break lemma : " << lem << std::endl; // it is a normal lemma - lemmas.push_back(lem); + d_im.lemma(lem, InferenceId::DATATYPES_SYGUS_ENUM_SYM_BREAK); } } d_tds->clearSymBreakLemmas(a); } } } - if (!lemmas.empty()) + if (d_im.hasSentLemma()) { return; } @@ -1557,7 +1574,7 @@ void SygusExtension::check( std::vector< Node >& lemmas ) { if (d_register_st.find(prog) == d_register_st.end()) { // not yet registered, do so now - registerSizeTerm(prog, lemmas); + registerSizeTerm(prog); needsRecheck = true; } else @@ -1575,7 +1592,7 @@ void SygusExtension::check( std::vector< Node >& lemmas ) { } // first check that the value progv for prog is what we expected bool isExc = true; - if (checkValue(prog, progv, 0, lemmas)) + if (checkValue(prog, progv, 0)) { isExc = false; //debugging : ensure fairness was properly handled @@ -1591,7 +1608,7 @@ void SygusExtension::check( std::vector< Node >& lemmas ) { Node szlem = NodeManager::currentNM()->mkNode( kind::OR, prog.eqNode( progv ).negate(), prog_sz.eqNode( progv_sz ) ); Trace("sygus-sb-warn") << "SygusSymBreak : WARNING : adding size correction : " << szlem << std::endl; - lemmas.push_back(szlem); + d_im.lemma(szlem, InferenceId::DATATYPES_SYGUS_SIZE_CORRECTION); isExc = true; } } @@ -1603,8 +1620,8 @@ void SygusExtension::check( std::vector< Node >& lemmas ) { bool isVarAgnostic = d_tds->isVariableAgnosticEnumerator(prog); // check that it is unique up to theory-specific rewriting and // conjecture-specific symmetry breaking. - Node rsv = registerSearchValue( - prog, prog, progv, 0, lemmas, isVarAgnostic, true); + Node rsv = + registerSearchValue(prog, prog, progv, 0, isVarAgnostic, true); if (rsv.isNull()) { isExc = true; @@ -1624,12 +1641,12 @@ void SygusExtension::check( std::vector< Node >& lemmas ) { if (needsRecheck) { Trace("sygus-sb") << " SygusExtension::rechecking..." << std::endl; - return check(lemmas); + return check(); } if (Trace.isOn("sygus-engine") && !d_szinfo.empty()) { - if (lemmas.empty()) + if (d_im.hasSentLemma()) { Trace("sygus-engine") << "*** Sygus : passed datatypes check. term size(s) : "; for (std::pair>& @@ -1644,18 +1661,11 @@ void SygusExtension::check( std::vector< Node >& lemmas ) { { Trace("sygus-engine") << "*** Sygus : produced symmetry breaking lemmas" << std::endl; - for (const Node& lem : lemmas) - { - Trace("sygus-engine-debug") << " " << lem << std::endl; - } } } } -bool SygusExtension::checkValue(Node n, - Node vn, - int ind, - std::vector& lemmas) +bool SygusExtension::checkValue(Node n, TNode vn, int ind) { if (vn.getKind() != kind::APPLY_CONSTRUCTOR) { @@ -1699,14 +1709,14 @@ bool SygusExtension::checkValue(Node n, "missing split for " << n << "." << std::endl; Assert(!split.isNull()); - lemmas.push_back( split ); + d_im.lemma(split, InferenceId::DATATYPES_SYGUS_VALUE_CORRECTION); return false; } } for( unsigned i=0; imkNode( APPLY_SELECTOR_TOTAL, dt[cindex].getSelectorInternal(tn, i), n); - if (!checkValue(sel, vn[i], ind + 1, lemmas)) + if (!checkValue(sel, vn[i], ind + 1)) { return false; } @@ -1737,35 +1747,42 @@ Node SygusExtension::getCurrentTemplate( Node n, std::map< TypeNode, int >& var_ } } -Node SygusExtension::SygusSizeDecisionStrategy::getOrMkMeasureValue( - std::vector& lemmas) +SygusExtension::SygusSizeDecisionStrategy::SygusSizeDecisionStrategy( + InferenceManager& im, Node t, TheoryState& s) + : DecisionStrategyFmf(s.getSatContext(), s.getValuation()), + d_this(t), + d_curr_search_size(0), + d_im(im) +{ +} + +Node SygusExtension::SygusSizeDecisionStrategy::getOrMkMeasureValue() { if (d_measure_value.isNull()) { - d_measure_value = NodeManager::currentNM()->mkSkolem( - "mt", NodeManager::currentNM()->integerType()); - lemmas.push_back(NodeManager::currentNM()->mkNode( - kind::GEQ, - d_measure_value, - NodeManager::currentNM()->mkConst(Rational(0)))); + NodeManager* nm = NodeManager::currentNM(); + d_measure_value = nm->mkSkolem("mt", nm->integerType()); + Node mtlem = + nm->mkNode(kind::GEQ, d_measure_value, nm->mkConst(Rational(0))); + d_im.lemma(mtlem, InferenceId::DATATYPES_SYGUS_MT_POS); } return d_measure_value; } Node SygusExtension::SygusSizeDecisionStrategy::getOrMkActiveMeasureValue( - std::vector& lemmas, bool mkNew) + bool mkNew) { if (mkNew) { - Node new_mt = NodeManager::currentNM()->mkSkolem( - "mt", NodeManager::currentNM()->integerType()); - lemmas.push_back(NodeManager::currentNM()->mkNode( - kind::GEQ, new_mt, NodeManager::currentNM()->mkConst(Rational(0)))); + NodeManager* nm = NodeManager::currentNM(); + Node new_mt = nm->mkSkolem("mt", nm->integerType()); + Node mtlem = nm->mkNode(kind::GEQ, new_mt, nm->mkConst(Rational(0))); d_measure_value_active = new_mt; + d_im.lemma(mtlem, InferenceId::DATATYPES_SYGUS_MT_POS); } else if (d_measure_value_active.isNull()) { - d_measure_value_active = getOrMkMeasureValue(lemmas); + d_measure_value_active = getOrMkMeasureValue(); } return d_measure_value_active; } diff --git a/src/theory/datatypes/sygus_extension.h b/src/theory/datatypes/sygus_extension.h index c35fc86ff..6cf96eefc 100644 --- a/src/theory/datatypes/sygus_extension.h +++ b/src/theory/datatypes/sygus_extension.h @@ -76,35 +76,36 @@ class SygusExtension ~SygusExtension(); /** * Notify this class that tester for constructor tindex has been asserted for - * n. Exp is the literal corresponding to this tester. This method may add - * lemmas to the vector lemmas, for details see assertTesterInternal below. + * n. Exp is the literal corresponding to this tester. This method may send + * lemmas via inference manager, for details see assertTesterInternal below. * These lemmas are sent out on the output channel of datatypes by the caller. */ - void assertTester(int tindex, TNode n, Node exp, std::vector& lemmas); + void assertTester(int tindex, TNode n, Node exp); /** * Notify this class that literal n has been asserted with the given - * polarity. This method may add lemmas to the vector lemmas, for instance + * polarity. This method may send lemmas via inference manager, for instance * based on inferring consequences of (not) n. One example is if n is * (DT_SIZE_BOUND x n), we add the lemma: * (DT_SIZE_BOUND x n) <=> ((DT_SIZE x) <= n ) */ - void assertFact(Node n, bool polarity, std::vector& lemmas); + void assertFact(Node n, bool polarity); /** pre-register term n * * This is called when n is pre-registered with the theory of datatypes. - * If n is a sygus enumerator, then we may add lemmas to the vector lemmas + * If n is a sygus enumerator, then we may send lemmas via inference manager * that are used to enforce fairness regarding the size of n. */ - void preRegisterTerm(TNode n, std::vector& lemmas); + void preRegisterTerm(TNode n); /** check * * This is called at last call effort, when the current model assignment is * satisfiable according to the quantifier-free decision procedures and a - * model is built. This method may add lemmas to the vector lemmas based + * model is built. This method may send lemmas via inference manager based * on dynamic symmetry breaking techniques, based on the model values of * all preregistered enumerators. */ - void check(std::vector& lemmas); + void check(); + private: /** The theory state of the datatype theory */ TheoryState& d_state; @@ -191,7 +192,7 @@ private: */ std::map< TypeNode, std::map< unsigned, std::vector< Node > > > d_search_terms; /** A cache of all symmetry breaking lemma templates for (types, sizes). */ - std::map< TypeNode, std::map< unsigned, std::vector< Node > > > d_sb_lemmas; + std::map>> d_sbLemmas; /** search value * * For each sygus type, a map from a builtin term to a sygus term for that @@ -298,8 +299,7 @@ private: * A -> A+A | x | 1 | 0 * when is_+( d ) is asserted, * assertTesterInternal(0, s( d ), is_+( s( d ) ),...) is called. This - * function may add lemmas to lemmas, which are sent out on the output - * channel of datatypes by the caller. + * function may send lemmas via inference manager. * * These lemmas are of various forms, including: * (1) dynamic symmetry breaking clauses for subterms of n (those added to @@ -312,13 +312,13 @@ private: * size( d ) <= 1 V ~is-C1( d ) V ~is-C2( d.1 ) * where C1 and C2 are non-nullary constructors. */ - void assertTesterInternal( int tindex, TNode n, Node exp, std::vector< Node >& lemmas ); + void assertTesterInternal(int tindex, TNode n, Node exp); /** * This function is called when term n is registered to the theory of * datatypes. It makes the appropriate call to registerSearchTerm below, * if applicable. */ - void registerTerm(Node n, std::vector& lemmas); + void registerTerm(Node n); //------------------------dynamic symmetry breaking /** Register search term @@ -334,7 +334,7 @@ private: * are active for n (see description of addSymBreakLemmasFor) are added to * lemmas in this call. */ - void registerSearchTerm( TypeNode tn, unsigned d, Node n, bool topLevel, std::vector< Node >& lemmas ); + void registerSearchTerm(TypeNode tn, unsigned d, Node n, bool topLevel); /** Register search value * * This function is called when a selector chain n has been assigned a model @@ -356,7 +356,7 @@ private: * Registering search value d -> x followed by d -> +( x, 0 ) results in the * construction of the symmetry breaking lemma template: * ~is_+( z ) V ~is_x( z.1 ) V ~is_0( z.2 ) - * which is stored in d_cache[a].d_sb_lemmas. This lemma is instantiated with + * which is stored in d_cache[a].d_sbLemmas. This lemma is instantiated with * z -> t for all terms t of appropriate depth, including d. * This function strengthens blocking clauses using generalization techniques * described in Reynolds et al SYNT 2017. @@ -392,18 +392,16 @@ private: Node n, Node nv, unsigned d, - std::vector& lemmas, bool isVarAgnostic, bool doSym); /** Register symmetry breaking lemma * * This function adds the symmetry breaking lemma template lem for terms of - * type tn with anchor a. This is added to d_cache[a].d_sb_lemmas. Notice that + * type tn with anchor a. This is added to d_cache[a].d_sbLemmas. Notice that * we use lem as a template with free variable x, e.g. our template is: * (lambda ((x tn)) lem) * where x = getFreeVar( tn ). For all search terms t of the appropriate - * depth, - * we add the lemma lem{ x -> t } to lemmas. + * depth, we send the lemma lem{ x -> t } via the inference manager. * * The argument sz indicates the size of terms that the lemma applies to, e.g. * ~is_+( z ) has size 1 @@ -412,8 +410,7 @@ private: * This is equivalent to sum of weights of constructors corresponding to each * tester, e.g. above + has weight 1, and x and 0 have weight 0. */ - void registerSymBreakLemma( - TypeNode tn, Node lem, unsigned sz, Node a, std::vector& lemmas); + void registerSymBreakLemma(TypeNode tn, Node lem, unsigned sz, Node a); /** Register symmetry breaking lemma for value * * This function adds a symmetry breaking lemma template for selector chains @@ -428,18 +425,18 @@ private: * generalization. * * This function may add instances of the symmetry breaking template for - * existing search terms, which are added to lemmas. + * existing search terms, which are sent via the inference manager. */ void registerSymBreakLemmaForValue(Node a, Node val, quantifiers::SygusInvarianceTest& et, Node valr, - std::map& var_count, - std::vector& lemmas); + std::map& var_count); /** Add symmetry breaking lemmas for term * - * Adds all active symmetry breaking lemmas for selector chain t to lemmas. A - * symmetry breaking lemma L is active for t based on three factors: + * Sends all active symmetry breaking lemmas for selector chain t via the + * inference manager. A symmetry breaking lemma L is active for t based on + * three factors: * (1) the current search size sz(a) for its anchor a, * (2) the depth d of term t (see d_term_to_depth), * (3) the size sz(L) of the symmetry breaking lemma L. @@ -452,10 +449,9 @@ private: * a : the anchor of term t, * d : the depth of term t. */ - void addSymBreakLemmasFor( - TypeNode tn, Node t, unsigned d, Node a, std::vector& lemmas); + void addSymBreakLemmasFor(TypeNode tn, TNode t, unsigned d, Node a); /** calls the above function where a is the anchor t */ - void addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, std::vector< Node >& lemmas ); + void addSymBreakLemmasFor(TypeNode tn, TNode t, unsigned d); //------------------------end dynamic symmetry breaking /** Get relevancy condition @@ -553,17 +549,15 @@ private: * decision strategy decides on literals of the form (DT_SYGUS_BOUND m n). * * After determining the measure term m for e, if applicable, we initialize - * SygusSizeDecisionStrategy for m below. This may result in lemmas + * SygusSizeDecisionStrategy for m below. This may result in lemmas sent via + * the inference manager. */ - void registerSizeTerm(Node e, std::vector& lemmas); + void registerSizeTerm(Node e); /** A decision strategy for each measure term allocated by this class */ class SygusSizeDecisionStrategy : public DecisionStrategyFmf { public: - SygusSizeDecisionStrategy(Node t, context::Context* c, Valuation valuation) - : DecisionStrategyFmf(c, valuation), d_this(t), d_curr_search_size(0) - { - } + SygusSizeDecisionStrategy(InferenceManager& im, Node t, TheoryState& s); /** the measure term */ Node d_this; /** @@ -593,7 +587,7 @@ private: * literals. Then, if we are enforcing fairness based on the maximum size, * we assert: (DT_SIZE e) <= v for all enumerators e. */ - Node getOrMkMeasureValue(std::vector& lemmas); + Node getOrMkMeasureValue(); /** get or make the active measure value * * The active measure value av is an integer variable that corresponds to @@ -611,8 +605,7 @@ private: * If the flag mkNew is set to true, then we return a fresh variable and * update the active measure value. */ - Node getOrMkActiveMeasureValue(std::vector& lemmas, - bool mkNew = false); + Node getOrMkActiveMeasureValue(bool mkNew = false); /** Returns the s^th fairness literal for this measure term. */ Node mkLiteral(unsigned s) override; /** identify */ @@ -622,6 +615,8 @@ private: } private: + /** The inference manager we are using */ + InferenceManager& d_im; /** the measure value */ Node d_measure_value; /** the sygus measure value */ @@ -650,7 +645,7 @@ private: * of how search size affects which lemmas are relevant above * addSymBreakLemmasFor. */ - void incrementCurrentSearchSize( Node m, std::vector< Node >& lemmas ); + void incrementCurrentSearchSize(TNode m); /** * Notify this class that we are currently searching for terms of size at * most s as model values for measure term m. Literal exp corresponds to the @@ -658,7 +653,7 @@ private: * incrementSearchSize above, until the total number of times we have called * incrementSearchSize so far is at least s. */ - void notifySearchSize( Node m, unsigned s, Node exp, std::vector< Node >& lemmas ); + void notifySearchSize(TNode m, uint64_t s, Node exp); /** Allocates a SygusSizeDecisionStrategy object in d_szinfo. */ void registerMeasureTerm( Node m ); /** @@ -706,7 +701,7 @@ private: * method should not ever add anything to lemmas. However, due to its * importance, we check this regardless. */ - bool checkValue(Node n, Node vn, int ind, std::vector& lemmas); + bool checkValue(Node n, TNode vn, int ind); /** * Get the current SAT status of the guard g. * In particular, this returns 1 if g is asserted true, -1 if it is asserted diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index d0b7790b2..53680530d 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -184,9 +184,7 @@ void TheoryDatatypes::postCheck(Effort level) if (level == EFFORT_LAST_CALL) { Assert(d_sygusExtension != nullptr); - std::vector lemmas; - d_sygusExtension->check(lemmas); - d_im.sendLemmas(lemmas, InferenceId::UNKNOWN); + d_sygusExtension->check(); return; } else if (level == EFFORT_FULL && !d_state.isInConflict() @@ -397,9 +395,7 @@ void TheoryDatatypes::notifyFact(TNode atom, // could be sygus-specific if (d_sygusExtension) { - std::vector< Node > lemmas; - d_sygusExtension->assertFact(atom, polarity, lemmas); - d_im.sendLemmas(lemmas, InferenceId::UNKNOWN); + d_sygusExtension->assertFact(atom, polarity); } //add to tester if applicable Node t_arg; @@ -419,10 +415,8 @@ void TheoryDatatypes::notifyFact(TNode atom, if (d_sygusExtension) { Trace("dt-tester") << "Assert tester to sygus : " << atom << std::endl; - std::vector< Node > lemmas; - d_sygusExtension->assertTester(tindex, t_arg, atom, lemmas); + d_sygusExtension->assertTester(tindex, t_arg, atom); Trace("dt-tester") << "Done assert tester to sygus." << std::endl; - d_im.sendLemmas(lemmas, InferenceId::UNKNOWN); } } }else{ @@ -480,9 +474,7 @@ void TheoryDatatypes::preRegisterTerm(TNode n) d_equalityEngine->addTerm(n); if (d_sygusExtension) { - std::vector< Node > lemmas; - d_sygusExtension->preRegisterTerm(n, lemmas); - d_im.sendLemmas(lemmas, InferenceId::UNKNOWN); + d_sygusExtension->preRegisterTerm(n); } break; } diff --git a/src/theory/inference_id.cpp b/src/theory/inference_id.cpp index 3db147a16..7acf2e861 100644 --- a/src/theory/inference_id.cpp +++ b/src/theory/inference_id.cpp @@ -101,6 +101,27 @@ const char* toString(InferenceId i) case InferenceId::DATATYPES_CYCLE: return "DATATYPES_CYCLE"; case InferenceId::DATATYPES_SIZE_POS: return "DATATYPES_SIZE_POS"; case InferenceId::DATATYPES_HEIGHT_ZERO: return "DATATYPES_HEIGHT_ZERO"; + case InferenceId::DATATYPES_SYGUS_SYM_BREAK: + return "DATATYPES_SYGUS_SYM_BREAK"; + case InferenceId::DATATYPES_SYGUS_CDEP_SYM_BREAK: + return "DATATYPES_SYGUS_CDEP_SYM_BREAK"; + case InferenceId::DATATYPES_SYGUS_ENUM_SYM_BREAK: + return "DATATYPES_SYGUS_ENUM_SYM_BREAK"; + case InferenceId::DATATYPES_SYGUS_SIMPLE_SYM_BREAK: + return "DATATYPES_SYGUS_SIMPLE_SYM_BREAK"; + case InferenceId::DATATYPES_SYGUS_FAIR_SIZE: + return "DATATYPES_SYGUS_FAIR_SIZE"; + case InferenceId::DATATYPES_SYGUS_FAIR_SIZE_CONFLICT: + return "DATATYPES_SYGUS_FAIR_SIZE_CONFLICT"; + case InferenceId::DATATYPES_SYGUS_VAR_AGNOSTIC: + return "DATATYPES_SYGUS_VAR_AGNOSTIC"; + case InferenceId::DATATYPES_SYGUS_SIZE_CORRECTION: + return "DATATYPES_SYGUS_SIZE_CORRECTION"; + case InferenceId::DATATYPES_SYGUS_VALUE_CORRECTION: + return "DATATYPES_SYGUS_VALUE_CORRECTION"; + case InferenceId::DATATYPES_SYGUS_MT_BOUND: + return "DATATYPES_SYGUS_MT_BOUND"; + case InferenceId::DATATYPES_SYGUS_MT_POS: return "DATATYPES_SYGUS_MT_POS"; case InferenceId::SEP_PTO_NEG_PROP: return "SEP_PTO_NEG_PROP"; case InferenceId::SEP_PTO_PROP: return "SEP_PTO_PROP"; diff --git a/src/theory/inference_id.h b/src/theory/inference_id.h index 73f7a2404..8cc678162 100644 --- a/src/theory/inference_id.h +++ b/src/theory/inference_id.h @@ -167,6 +167,37 @@ enum class InferenceId DATATYPES_SIZE_POS, // (=> (= (dt.height t) 0) => (and (= (dt.height (sel_1 t)) 0) .... )) DATATYPES_HEIGHT_ZERO, + //-------------------- sygus extension + // a sygus symmetry breaking lemma (or ~is-C1( t1 ) V ... V ~is-Cn( tn ) ) + // where t1 ... tn are unique shared selector chains. For details see + // Reynolds et al CAV 2019 + DATATYPES_SYGUS_SYM_BREAK, + // a conjecture-dependent symmetry breaking lemma, which may be used to + // exclude constructors for variables that irrelevant for a synthesis + // conjecture + DATATYPES_SYGUS_CDEP_SYM_BREAK, + // an enumerator-specific symmetry breaking lemma, which are used e.g. for + // excluding certain kinds of constructors + DATATYPES_SYGUS_ENUM_SYM_BREAK, + // a simple static symmetry breaking lemma (see Reynolds et al CAV 2019) + DATATYPES_SYGUS_SIMPLE_SYM_BREAK, + // (dt.size t) <= N, to implement fair enumeration when sygus-fair=dt-size + DATATYPES_SYGUS_FAIR_SIZE, + // (dt.size t) <= N => (or ~is-C1( t1 ) V ... V ~is-Cn( tn ) ) if using + // sygus-fair=direct + DATATYPES_SYGUS_FAIR_SIZE_CONFLICT, + // used for implementing variable agnostic enumeration + DATATYPES_SYGUS_VAR_AGNOSTIC, + // handles case the model value for a sygus term violates the size bound + DATATYPES_SYGUS_SIZE_CORRECTION, + // handles case the model value for a sygus term does not exist + DATATYPES_SYGUS_VALUE_CORRECTION, + // s <= (dt.size t), where s is a term that must be less than the current + // size bound based on our fairness strategy. For instance, s may be + // (dt.size e) for (each) enumerator e when multiple enumerators are present. + DATATYPES_SYGUS_MT_BOUND, + // (dt.size t) >= 0 + DATATYPES_SYGUS_MT_POS, // ---------------------------------- end datatypes theory //-------------------------------------- quantifiers theory -- 2.30.2