From: Andrew Reynolds Date: Tue, 6 Mar 2018 20:22:43 +0000 (-0600) Subject: Refactor symmetry breaking in datatypes sygus (#1640) X-Git-Tag: cvc5-1.0.0~5248 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=e0909efd64c96311c69dec223411ab6b7988d01d;p=cvc5.git Refactor symmetry breaking in datatypes sygus (#1640) --- diff --git a/src/theory/datatypes/datatypes_sygus.cpp b/src/theory/datatypes/datatypes_sygus.cpp index b8185e9c8..1779ab27b 100644 --- a/src/theory/datatypes/datatypes_sygus.cpp +++ b/src/theory/datatypes/datatypes_sygus.cpp @@ -669,14 +669,7 @@ Node SygusSymBreakNew::getSimpleSymBreakPred( TypeNode tn, int tindex, unsigned } TNode SygusSymBreakNew::getFreeVar( TypeNode tn ) { - std::map< TypeNode, Node >::iterator it = d_free_var.find( tn ); - if( it==d_free_var.end() ){ - Node x = NodeManager::currentNM()->mkSkolem( "x", tn ); - d_free_var[tn] = x; - return x; - }else{ - return it->second; - } + return d_tds->getFreeVar(tn, 0); } unsigned SygusSymBreakNew::processSelectorChain( Node n, std::map< TypeNode, Node >& top_level, std::map< Node, unsigned >& tdepth, std::vector< Node >& lemmas ) { @@ -741,15 +734,11 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d, Trace("sygus-sb-debug") << " ......rewrites to " << bvr << std::endl; Trace("dt-sygus") << " * DT builtin : " << n << " -> " << bvr << std::endl; unsigned sz = d_tds->getSygusTermSize( nv ); - std::vector< Node > exp; - bool do_exclude = false; if( d_tds->involvesDivByZero( bvr ) ){ - Node x = getFreeVar( tn ); quantifiers::DivByZeroSygusInvarianceTest dbzet; Trace("sygus-sb-mexp-debug") << "Minimize explanation for div-by-zero in " << d_tds->sygusToBuiltin( nv ) << std::endl; - d_tds->getExplain()->getExplanationFor( - x, nv, exp, dbzet, Node::null(), sz); - do_exclude = true; + registerSymBreakLemmaForValue(a, nv, dbzet, Node::null(), lemmas); + return false; }else{ std::map< Node, Node >::iterator itsv = d_cache[a].d_search_val[tn].find( bvr ); Node bad_val_bvr; @@ -880,43 +869,43 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d, // do analysis of the evaluation FIXME: does not work (evaluation is non-constant) quantifiers::EquivSygusInvarianceTest eset; eset.init(d_tds, tn, aconj, a, bvr); + Trace("sygus-sb-mexp-debug") << "Minimize explanation for eval[" << d_tds->sygusToBuiltin( bad_val ) << "] = " << bvr << std::endl; - d_tds->getExplain()->getExplanationFor( - x, bad_val, exp, eset, bad_val_o, sz); - do_exclude = true; - } - } - if( do_exclude ){ - Node lem = exp.size()==1 ? exp[0] : NodeManager::currentNM()->mkNode( kind::AND, exp ); - lem = lem.negate(); - /* add min type depth to size : TODO? - Assert( d_term_to_anchor.find( n )!=d_term_to_anchor.end() ); - TypeNode atype = d_term_to_anchor[n].getType(); - if( atype!=tn ){ - unsigned min_type_depth = d_tds->getMinTypeDepth( atype, tn ); - if( min_type_depth>0 ){ - Trace("sygus-sb-exc") << " ........min type depth for " << ((DatatypeType)tn.toType()).getDatatype().getName() << " in "; - Trace("sygus-sb-exc") << ((DatatypeType)atype.toType()).getDatatype().getName() << " is " << min_type_depth << std::endl; - sz = sz + min_type_depth; - } + registerSymBreakLemmaForValue(a, bad_val, eset, bad_val_o, lemmas); + return false; } - */ - Trace("sygus-sb-exc") << " ........exc lemma is " << lem << ", size = " << sz << std::endl; - registerSymBreakLemma( tn, lem, sz, a, lemmas ); - Trace("dt-sygus") - << " ...excluded by dynamic symmetry breaking, based on " << n - << " == " << bvr << std::endl; - return false; } } return true; } - +void SygusSymBreakNew::registerSymBreakLemmaForValue( + Node a, + Node val, + quantifiers::SygusInvarianceTest& et, + Node valr, + std::vector& lemmas) +{ + TypeNode tn = val.getType(); + Node x = getFreeVar(tn); + unsigned sz = d_tds->getSygusTermSize(val); + std::vector exp; + d_tds->getExplain()->getExplanationFor(x, val, exp, et, valr, sz); + Node lem = + exp.size() == 1 ? exp[0] : NodeManager::currentNM()->mkNode(AND, exp); + lem = lem.negate(); + Trace("sygus-sb-exc") << " ........exc lemma is " << lem << ", size = " << sz + << std::endl; + registerSymBreakLemma(tn, lem, sz, a, lemmas); +} void SygusSymBreakNew::registerSymBreakLemma( TypeNode tn, Node lem, unsigned sz, Node a, std::vector< Node >& lemmas ) { // lem holds for all terms of type tn, and is applicable to terms of size sz - Trace("sygus-sb-debug") << " register sym break lemma : " << lem << ", size " << sz << std::endl; + Trace("sygus-sb-debug") << " register sym break lemma : " << lem + << std::endl; + Trace("sygus-sb-debug") << " anchor : " << a << std::endl; + Trace("sygus-sb-debug") << " type : " << tn << std::endl; + Trace("sygus-sb-debug") << " size : " << sz << std::endl; Assert( !a.isNull() ); d_cache[a].d_sb_lemmas[tn][sz].push_back( lem ); TNode x = getFreeVar( tn ); @@ -928,7 +917,7 @@ void SygusSymBreakNew::registerSymBreakLemma( TypeNode tn, Node lem, unsigned sz for( unsigned k=0; ksecond.size(); k++ ){ TNode t = itt->second[k]; if( !options::sygusSymBreakLazy() || d_active_terms.find( t )!=d_active_terms.end() ){ - addSymBreakLemma( tn, lem, x, t, sz, d, lemmas ); + addSymBreakLemma(lem, x, t, lemmas); } } } @@ -953,14 +942,18 @@ void SygusSymBreakNew::addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, No if( (int)it->first<=max_sz ){ for( unsigned k=0; ksecond.size(); k++ ){ Node lem = it->second[k]; - addSymBreakLemma( tn, lem, x, t, it->first, d, lemmas ); + addSymBreakLemma(lem, x, t, lemmas); } } } } } -void SygusSymBreakNew::addSymBreakLemma( TypeNode tn, Node lem, TNode x, TNode n, unsigned lem_sz, unsigned n_depth, std::vector< Node >& lemmas ) { +void SygusSymBreakNew::addSymBreakLemma(Node lem, + TNode x, + TNode n, + std::vector& lemmas) +{ Assert( !options::sygusSymBreakLazy() || d_active_terms.find( n )!=d_active_terms.end() ); // apply lemma Node slem = lem.substitute( x, n ); @@ -1124,7 +1117,7 @@ void SygusSymBreakNew::incrementCurrentSearchSize( Node m, std::vector< Node >& if( !options::sygusSymBreakLazy() || d_active_terms.find( t )!=d_active_terms.end() ){ for( unsigned j=0; jsecond.size(); j++ ){ Node lem = it->second[j]; - addSymBreakLemma( tn, lem, x, t, sz, new_depth, lemmas ); + addSymBreakLemma(lem, x, t, lemmas); } } } @@ -1137,6 +1130,30 @@ void SygusSymBreakNew::incrementCurrentSearchSize( Node m, std::vector< Node >& void SygusSymBreakNew::check( std::vector< Node >& lemmas ) { Trace("sygus-sb") << "SygusSymBreakNew::check" << std::endl; + + // check for externally registered symmetry breaking lemmas + std::vector anchors; + if (d_tds->hasSymBreakLemmas(anchors)) + { + for (const Node& a : anchors) + { + std::vector sbl; + d_tds->getSymBreakLemmas(a, sbl); + for (const Node& lem : sbl) + { + TypeNode tn = d_tds->getTypeForSymBreakLemma(lem); + unsigned sz = d_tds->getSizeForSymBreakLemma(lem); + registerSymBreakLemma(tn, lem, sz, a, lemmas); + } + } + d_tds->clearSymBreakLemmas(); + if (!lemmas.empty()) + { + return; + } + } + + // register search values, add symmetry breaking lemmas if applicable for( std::map< Node, bool >::iterator it = d_register_st.begin(); it != d_register_st.end(); ++it ){ if( it->second ){ Node prog = it->first; diff --git a/src/theory/datatypes/datatypes_sygus.h b/src/theory/datatypes/datatypes_sygus.h index fa3918270..cb7729658 100644 --- a/src/theory/datatypes/datatypes_sygus.h +++ b/src/theory/datatypes/datatypes_sygus.h @@ -30,6 +30,7 @@ #include "expr/datatype.h" #include "expr/node.h" #include "theory/quantifiers/sygus/ce_guided_conjecture.h" +#include "theory/quantifiers/sygus/sygus_explain.h" #include "theory/quantifiers/sygus_sampler.h" #include "theory/quantifiers/term_database.h" @@ -41,13 +42,14 @@ class TheoryDatatypes; class SygusSymBreakNew { -private: - TheoryDatatypes * d_td; - quantifiers::TermDbSygus * d_tds; typedef context::CDHashMap< Node, int, NodeHashFunction > IntMap; typedef context::CDHashMap< Node, Node, NodeHashFunction > NodeMap; typedef context::CDHashMap< Node, bool, NodeHashFunction > BoolMap; typedef context::CDHashSet NodeSet; + + private: + TheoryDatatypes* d_td; + quantifiers::TermDbSygus* d_tds; IntMap d_testers; IntMap d_is_const; NodeMap d_testers_exp; @@ -86,7 +88,10 @@ private: * */ std::map< Node, bool > d_is_top_level; - void registerTerm( Node n, std::vector< Node >& lemmas ); + /** + * Returns true if the selector chain n is top-level based on the above + * definition, when tn is the type of n. + */ bool computeTopLevel( TypeNode tn, Node n ); private: //list of all terms encountered in search at depth @@ -117,7 +122,7 @@ private: /** For each term, whether this cache has processed that term */ std::map< Node, bool > d_search_val_proc; }; - // anchor -> cache + /** An instance of the above cache, for each anchor */ std::map< Node, SearchCache > d_cache; /** a sygus sampler object for each (anchor, sygus type) pair * @@ -125,21 +130,147 @@ private: * the rewriter. */ std::map> d_sampler; - Node d_null; + /** Assert tester internal + * + * This function is called when the tester with index tindex is asserted for + * n, exp is the tester predicate. For example, for grammar: + * A -> A+A | x | 1 | 0 + * when is_+( d ) is asserted, + * assertTesterInternal(0, s( d ), is_+( s( d ) ),...) is called. This + * function may add lemmas to lemmas, which are sent out on the output + * channel of datatypes by the caller. + * + * These lemmas are of various forms, including: + * (1) dynamic symmetry breaking clauses for subterms of n (those added to + * lemmas on calls to addSymBreakLemmasFor, see function below), + * (2) static symmetry breaking clauses for subterms of n (those added to + * lemmas on getSimpleSymBreakPred, see function below), + * (3) conjecture-specific symmetry breaking lemmas, see + * CegConjecture::getSymmetryBreakingPredicate, + * (4) fairness conflicts if sygusFair() is SYGUS_FAIR_DIRECT, e.g.: + * size( d ) <= 1 V ~is-C1( d ) V ~is-C2( d.1 ) + * where C1 and C2 are non-nullary constructors. + */ void assertTesterInternal( int tindex, TNode n, Node exp, std::vector< Node >& lemmas ); - // register search term + /** + * This function is called when term n is registered to the theory of + * datatypes. It makes the appropriate call to registerSearchTerm below, + * if applicable. + */ + void registerTerm(Node n, std::vector& lemmas); + /** Register search term + * + * This function is called when selector chain S_1( ... S_m( n ) ... ) is + * registered to the theory of datatypes, where tn is the type of n, + * d indicates the depth of n (the sum of weights of the selectors S_1...S_m), + * and topLevel is whether n is a top-level term (see d_is_top_level). + * + * The purpose of this function is to notify this class that symmetry breaking + * lemmas should be instantiated for n. Any symmetry breaking lemmas that + * are active for n (see description of addSymBreakLemmasFor) are added to + * lemmas in this call. + */ void registerSearchTerm( TypeNode tn, unsigned d, Node n, bool topLevel, std::vector< Node >& lemmas ); + /** Register search value + * + * This function is called when a selector chain n has been assigned a model + * value nv. This function calls itself recursively so that extensions of the + * selector chain n are registered with all the subterms of nv. For example, + * if we call this function with: + * n = x, nv = +( 1(), x() ) + * we make recursive calls with: + * n = x.1, nv = 1() and n = x.2, nv = x() + * + * a : the anchor of n, + * d : the depth of n. + * + * This function determines if the value nv is equivalent via rewriting to + * any previously registered search values for anchor a. If so, we construct + * a symmetry breaking lemma template and register it in d_cache[a]. For + * example, for grammar: + * A -> A+A | x | 1 | 0 + * Registering search value d -> x followed by d -> +( x, 0 ) results in the + * construction of the symmetry breaking lemma template: + * ~is_+( z ) V ~is_x( z.1 ) V ~is_0( z.2 ) + * which is stored in d_cache[a].d_sb_lemmas. This lemma is instantiated with + * z -> t for all terms t of appropriate depth, including d. + * This function strengthens blocking clauses using generalization techniques + * described in Reynolds et al SYNT 2017. + */ bool registerSearchValue( Node a, Node n, Node nv, unsigned d, std::vector< Node >& lemmas ); - void registerSymBreakLemma( TypeNode tn, Node lem, unsigned sz, Node e, std::vector< Node >& lemmas ); - void addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, Node e, std::vector< Node >& lemmas ); + /** Register symmetry breaking lemma + * + * This function adds the symmetry breaking lemma template lem for terms of + * type tn with anchor a. This is added to d_cache[a].d_sb_lemmas. Notice that + * we use lem as a template with free variable x, e.g. our template is: + * (lambda ((x tn)) lem) + * where x = getFreeVar( tn ). For all search terms t of the appropriate + * depth, + * we add the lemma lem{ x -> t } to lemmas. + * + * The argument sz indicates the size of terms that the lemma applies to, e.g. + * ~is_+( z ) has size 1 + * ~is_+( z ) V ~is_x( z.1 ) V ~is_0( z.2 ) has size 1 + * ~is_+( z ) V ~is_+( z.1 ) has size 2 + * This is equivalent to sum of weights of constructors corresponding to each + * tester, e.g. above + has weight 1, and x and 0 have weight 0. + */ + void registerSymBreakLemma( + TypeNode tn, Node lem, unsigned sz, Node a, std::vector& lemmas); + /** Register symmetry breaking lemma for value + * + * This function adds a symmetry breaking lemma template for selector chains + * with anchor a, that effectively states that val should never be a subterm + * of any value for a. + * + * et : an "invariance test" (see sygus/sygus_invariance.h) which states a + * criterion that val meets, which is the reason for its exclusion. This is + * used for generalizing the symmetry breaking lemma template. + * valr : if non-null, this states a value that should *not* be excluded by + * the symmetry breaking lemma template, which is a restriction to the above + * generalization. + * + * This function may add instances of the symmetry breaking template for + * existing search terms, which are added to lemmas. + */ + void registerSymBreakLemmaForValue(Node a, + Node val, + quantifiers::SygusInvarianceTest& et, + Node valr, + std::vector& lemmas); + /** Add symmetry breaking lemmas for term + * + * Adds all active symmetry breaking lemmas for selector chain t to lemmas. A + * symmetry breaking lemma L is active for t based on three factors: + * (1) the current search size sz(a) for its anchor a, + * (2) the depth d of term t (see d_term_to_depth), + * (3) the size sz(L) of the symmetry breaking lemma L. + * In particular, L is active if sz(L) <= sz(a) - d. In other words, a + * symmetry breaking lemma is active if it is intended to block terms of + * size sz(L), and the maximum size that t can take in the current search, + * sz(a)-d, is greater than or equal to this value. + * + * tn : the type of term t, + * a : the anchor of term t, + * d : the depth of term t. + */ + void addSymBreakLemmasFor( + TypeNode tn, Node t, unsigned d, Node a, std::vector& lemmas); + /** calls the above function where a is the anchor t */ void addSymBreakLemmasFor( TypeNode tn, Node t, unsigned d, std::vector< Node >& lemmas ); - void addSymBreakLemma( TypeNode tn, Node lem, TNode x, TNode n, unsigned lem_sz, unsigned n_depth, std::vector< Node >& lemmas ); -private: + /** add symmetry breaking lemma + * + * This adds the lemma R => lem{ x -> n } to lemmas, where R is a "relevancy + * condition" that states which contexts n is relevant in (contexts in which + * the selector chain n is specified). + */ + void addSymBreakLemma(Node lem, TNode x, TNode n, std::vector& lemmas); + + private: std::map< Node, Node > d_rlv_cond; Node getRelevancyCondition( Node n ); private: std::map< TypeNode, std::map< int, std::map< unsigned, Node > > > d_simple_sb_pred; - std::map< TypeNode, Node > d_free_var; // user-context dependent if sygus-incremental std::map< Node, unsigned > d_simple_proc; //get simple symmetry breaking predicate diff --git a/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp b/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp index 1dd4dcbeb..2273db5ea 100644 --- a/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp +++ b/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp @@ -547,9 +547,8 @@ void CegConjecture::printAndContinueStream() { sol = d_cinfo[cprog].d_inst.back(); // add to explanation of exclusion - d_qe->getTermDatabaseSygus() - ->getExplain() - ->getExplanationForConstantEquality(cprog, sol, exp); + d_qe->getTermDatabaseSygus()->getExplain()->getExplanationForEquality( + cprog, sol, exp); } } Assert(!exp.empty()); @@ -612,6 +611,8 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation if (eq_sol != sol) { ++(cei->d_statistics.d_candidate_rewrites); + // if eq_sol is null, then we have an uninteresting candidate rewrite, + // e.g. one that is alpha-equivalent to another. if (!eq_sol.isNull()) { // The analog of terms sol and eq_sol are equivalent under sample diff --git a/src/theory/quantifiers/sygus/sygus_explain.cpp b/src/theory/quantifiers/sygus/sygus_explain.cpp index aafaa07e1..f76edb1c3 100644 --- a/src/theory/quantifiers/sygus/sygus_explain.cpp +++ b/src/theory/quantifiers/sygus/sygus_explain.cpp @@ -110,19 +110,25 @@ Node TermRecBuild::build(unsigned d) return NodeManager::currentNM()->mkNode(d_kind[d], children); } -void SygusExplain::getExplanationForConstantEquality(Node n, - Node vn, - std::vector& exp) +void SygusExplain::getExplanationForEquality(Node n, + Node vn, + std::vector& exp) { std::map cexc; - getExplanationForConstantEquality(n, vn, exp, cexc); + getExplanationForEquality(n, vn, exp, cexc); } -void SygusExplain::getExplanationForConstantEquality( - Node n, Node vn, std::vector& exp, std::map& cexc) +void SygusExplain::getExplanationForEquality(Node n, + Node vn, + std::vector& exp, + std::map& cexc) { - Assert(vn.getKind() == kind::APPLY_CONSTRUCTOR); Assert(n.getType() == vn.getType()); + if (n == vn) + { + return; + } + Assert(vn.getKind() == kind::APPLY_CONSTRUCTOR); TypeNode tn = n.getType(); Assert(tn.isDatatype()); const Datatype& dt = ((DatatypeType)tn.toType()).getDatatype(); @@ -137,22 +143,23 @@ void SygusExplain::getExplanationForConstantEquality( kind::APPLY_SELECTOR_TOTAL, Node::fromExpr(dt[i].getSelectorInternal(tn.toType(), j)), n); - getExplanationForConstantEquality(sel, vn[j], exp); + getExplanationForEquality(sel, vn[j], exp); } } } -Node SygusExplain::getExplanationForConstantEquality(Node n, Node vn) +Node SygusExplain::getExplanationForEquality(Node n, Node vn) { std::map cexc; - return getExplanationForConstantEquality(n, vn, cexc); + return getExplanationForEquality(n, vn, cexc); } -Node SygusExplain::getExplanationForConstantEquality( - Node n, Node vn, std::map& cexc) +Node SygusExplain::getExplanationForEquality(Node n, + Node vn, + std::map& cexc) { std::vector exp; - getExplanationForConstantEquality(n, vn, exp, cexc); + getExplanationForEquality(n, vn, exp, cexc); Assert(!exp.empty()); return exp.size() == 1 ? exp[0] : NodeManager::currentNM()->mkNode(kind::AND, exp); @@ -250,7 +257,7 @@ void SygusExplain::getExplanationFor(TermRecBuild& trb, // if excluded, we may need to add the explanation for this if (vnr_exp.isNull() && !vnr_c.isNull()) { - vnr_exp = getExplanationForConstantEquality(sel, vnr[i]); + vnr_exp = getExplanationForEquality(sel, vnr[i]); } } } @@ -264,7 +271,7 @@ void SygusExplain::getExplanationFor(Node n, unsigned& sz) { // naive : - // return getExplanationForConstantEquality( n, vn, exp ); + // return getExplanationForEquality( n, vn, exp ); // set up the recursion object std::map var_count; diff --git a/src/theory/quantifiers/sygus/sygus_explain.h b/src/theory/quantifiers/sygus/sygus_explain.h index ad26f29e4..818f51438 100644 --- a/src/theory/quantifiers/sygus/sygus_explain.h +++ b/src/theory/quantifiers/sygus/sygus_explain.h @@ -100,7 +100,7 @@ class TermRecBuild * (datatype) sygus term n is: * (if (gt x 0) 0 0) * where if, gt, x, 0 are datatype constructors. - * The explanation returned by getExplanationForConstantEquality + * The explanation returned by getExplanationForEquality * below for n and the above term is: * { ((_ is if) n), ((_ is geq) n.0), * ((_ is x) n.0.0), ((_ is 0) n.0.1), @@ -142,20 +142,19 @@ class SygusExplain public: SygusExplain(TermDbSygus* tdb) : d_tdb(tdb) {} ~SygusExplain() {} - /** get explanation for constant equality + /** get explanation for equality * * This function constructs an explanation, stored in exp, such that: * - All formulas in exp are of the form ((_ is C) ns), where ns * is a chain of selectors applied to n, and * - exp => ( n = vn ) */ - void getExplanationForConstantEquality(Node n, - Node vn, - std::vector& exp); + void getExplanationForEquality(Node n, Node vn, std::vector& exp); /** returns the conjunction of exp computed in the above function */ - Node getExplanationForConstantEquality(Node n, Node vn); + Node getExplanationForEquality(Node n, Node vn); - /** get explanation for constant equality + /** get explanation for equality + * * This is identical to the above function except that we * take an additional argument cexc, which says which * children of vn should be excluded from the explanation. @@ -165,14 +164,14 @@ class SygusExplain * { ((_ is plus) n), ((_ is y) n.1) } * where notice that the 0^th argument of vn is excluded. */ - void getExplanationForConstantEquality(Node n, - Node vn, - std::vector& exp, - std::map& cexc); + void getExplanationForEquality(Node n, + Node vn, + std::vector& exp, + std::map& cexc); /** returns the conjunction of exp computed in the above function */ - Node getExplanationForConstantEquality(Node n, - Node vn, - std::map& cexc); + Node getExplanationForEquality(Node n, + Node vn, + std::map& cexc); /** get explanation for * diff --git a/src/theory/quantifiers/sygus/sygus_pbe.cpp b/src/theory/quantifiers/sygus/sygus_pbe.cpp index 36e883848..1c61544e1 100644 --- a/src/theory/quantifiers/sygus/sygus_pbe.cpp +++ b/src/theory/quantifiers/sygus/sygus_pbe.cpp @@ -1303,7 +1303,7 @@ void CegConjecturePbe::addEnumeratedValue( Node x, Node v, std::vector< Node >& if (exp_exc.isNull()) { // if we did not already explain why this should be excluded, use default - exp_exc = d_tds->getExplain()->getExplanationForConstantEquality(x, v); + exp_exc = d_tds->getExplain()->getExplanationForEquality(x, v); } Node exlem = NodeManager::currentNM()->mkNode(OR, g.negate(), exp_exc.negate()); diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index b12a23c83..e8bdf2083 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -733,6 +733,59 @@ void TermDbSygus::getEnumerators(std::vector& mts) } } +void TermDbSygus::registerSymBreakLemma(Node e, + Node lem, + TypeNode tn, + unsigned sz) +{ + d_enum_to_sb_lemmas[e].push_back(lem); + d_sb_lemma_to_type[lem] = tn; + d_sb_lemma_to_size[lem] = sz; +} + +bool TermDbSygus::hasSymBreakLemmas(std::vector& enums) const +{ + if (!d_enum_to_sb_lemmas.empty()) + { + for (std::pair > sb : d_enum_to_sb_lemmas) + { + enums.push_back(sb.first); + } + return true; + } + return false; +} + +void TermDbSygus::getSymBreakLemmas(Node e, std::vector& lemmas) const +{ + std::map >::const_iterator itsb = + d_enum_to_sb_lemmas.find(e); + if (itsb != d_enum_to_sb_lemmas.end()) + { + lemmas.insert(lemmas.end(), itsb->second.begin(), itsb->second.end()); + } +} + +TypeNode TermDbSygus::getTypeForSymBreakLemma(Node lem) const +{ + std::map::const_iterator it = d_sb_lemma_to_type.find(lem); + Assert(it != d_sb_lemma_to_type.end()); + return it->second; +} +unsigned TermDbSygus::getSizeForSymBreakLemma(Node lem) const +{ + std::map::const_iterator it = d_sb_lemma_to_size.find(lem); + Assert(it != d_sb_lemma_to_size.end()); + return it->second; +} + +void TermDbSygus::clearSymBreakLemmas() +{ + d_enum_to_sb_lemmas.clear(); + d_sb_lemma_to_type.clear(); + d_sb_lemma_to_size.clear(); +} + bool TermDbSygus::isRegistered( TypeNode tn ) { return d_register.find( tn )!=d_register.end(); } @@ -1202,7 +1255,7 @@ void TermDbSygus::registerModelValue( Node a, Node v, std::vector< Node >& terms unsigned start = d_node_mv_args_proc[n][vn]; // get explanation in terms of testers std::vector< Node > antec_exp; - d_syexp->getExplanationForConstantEquality(n, vn, antec_exp); + d_syexp->getExplanationForEquality(n, vn, antec_exp); Node antec = antec_exp.size()==1 ? antec_exp[0] : NodeManager::currentNM()->mkNode( kind::AND, antec_exp ); //Node antec = n.eqNode( vn ); TypeNode tn = n.getType(); diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index e796a3adc..7ef9e6151 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -40,21 +40,31 @@ class TermDbSygus { std::string identify() const { return "TermDbSygus"; } /** register the sygus type */ void registerSygusType(TypeNode tn); - /** register a variable e that we will do enumerative search on - * conj is the conjecture that the enumeration of e is for. - * f is the synth-fun that the enumeration of e is for. - * mkActiveGuard is whether we want to make an active guard for e + + //------------------------------utilities + /** get the explanation utility */ + SygusExplain* getExplain() { return d_syexp.get(); } + /** get the extended rewrite utility */ + ExtendedRewriter* getExtRewriter() { return d_ext_rw.get(); } + //------------------------------end utilities + + //------------------------------enumerators + /** + * Register a variable e that we will do enumerative search on. + * conj : the conjecture that the enumeration of e is for. + * f : the synth-fun that the enumeration of e is for. + * mkActiveGuard : whether we want to make an active guard for e * (see d_enum_to_active_guard). * - * Notice that enumerator e may not be equivalent - * to f in synthesis-through-unification approaches - * (e.g. decision tree construction for PBE synthesis). + * Notice that enumerator e may not be one-to-one with f in + * synthesis-through-unification approaches (e.g. decision tree construction + * for PBE synthesis). */ void registerEnumerator(Node e, Node f, CegConjecture* conj, bool mkActiveGuard = false); - /** is e an enumerator? */ + /** is e an enumerator registered with this class? */ bool isEnumerator(Node e) const; /** return the conjecture e is associated with */ CegConjecture* getConjectureForEnumerator(Node e); @@ -64,10 +74,36 @@ class TermDbSygus { Node getActiveGuardForEnumerator(Node e); /** get all registered enumerators */ void getEnumerators(std::vector& mts); - /** get the explanation utility */ - SygusExplain* getExplain() { return d_syexp.get(); } - /** get the extended rewrite utility */ - ExtendedRewriter* getExtRewriter() { return d_ext_rw.get(); } + /** Register symmetry breaking lemma + * + * This function registers lem as a symmetry breaking lemma template for + * subterms of enumerator e. For more information on symmetry breaking + * lemma templates, see datatypes/datatypes_sygus.h. + * + * tn : the (sygus datatype) type that lem applies to, i.e. the + * type of terms that lem blocks models for, + * sz : the minimum size of terms that the lem blocks. + * + * Notice that the symmetry breaking lemma template should be relative to x, + * where x is returned by the call to getFreeVar( tn, 0 ) in this class. + */ + void registerSymBreakLemma(Node e, Node lem, TypeNode tn, unsigned sz); + /** Has symmetry breaking lemmas been added for any enumerator? */ + bool hasSymBreakLemmas(std::vector& enums) const; + /** Get symmetry breaking lemmas + * + * Returns the set of symmetry breaking lemmas that have been registered + * for enumerator e. It adds these to lemmas. + */ + void getSymBreakLemmas(Node e, std::vector& lemmas) const; + /** Get the type of term symmetry breaking lemma lem applies to */ + TypeNode getTypeForSymBreakLemma(Node lem) const; + /** Get the minimum size of terms symmetry breaking lemma lem applies to */ + unsigned getSizeForSymBreakLemma(Node lem) const; + /** Clear information about symmetry breaking lemmas */ + void clearSymBreakLemmas(); + //------------------------------end enumerators + //-----------------------------conversion from sygus to builtin /** get free variable * @@ -121,10 +157,15 @@ class TermDbSygus { private: /** reference to the quantifiers engine */ QuantifiersEngine* d_quantEngine; + + //------------------------------utilities /** sygus explanation */ std::unique_ptr d_syexp; /** sygus explanation */ std::unique_ptr d_ext_rw; + //------------------------------end utilities + + //------------------------------enumerators /** mapping from enumerator terms to the conjecture they are associated with */ std::map d_enum_to_conjecture; @@ -137,6 +178,13 @@ class TermDbSygus { * if G is true, then there are more values of e to enumerate". */ std::map d_enum_to_active_guard; + /** mapping from enumerators to symmetry breaking clauses for them */ + std::map > d_enum_to_sb_lemmas; + /** mapping from symmetry breaking lemmas to type */ + std::map d_sb_lemma_to_type; + /** mapping from symmetry breaking lemmas to size */ + std::map d_sb_lemma_to_size; + //------------------------------end enumerators //-----------------------------conversion from sygus to builtin /** cache for sygusToBuiltin */