From: Andrew Reynolds Date: Tue, 20 Mar 2018 22:32:43 +0000 (-0500) Subject: Minor refactor datatypes sygus (#1673) X-Git-Tag: cvc5-1.0.0~5229 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=62f58d62c6c597eeb9cae5e08d74f21c4a5c5c40;p=cvc5.git Minor refactor datatypes sygus (#1673) --- diff --git a/src/theory/datatypes/datatypes_sygus.cpp b/src/theory/datatypes/datatypes_sygus.cpp index 1779ab27b..3c90bc448 100644 --- a/src/theory/datatypes/datatypes_sygus.cpp +++ b/src/theory/datatypes/datatypes_sygus.cpp @@ -222,10 +222,10 @@ void SygusSymBreakNew::registerTerm( Node n, std::vector< Node >& lemmas ) { bool success = false; if( n.getKind()==kind::APPLY_SELECTOR_TOTAL ){ registerTerm( n[0], lemmas ); - std::map< Node, Node >::iterator it = d_term_to_anchor.find( n[0] ); + std::unordered_map::iterator it = + d_term_to_anchor.find(n[0]); if( it!=d_term_to_anchor.end() ) { d_term_to_anchor[n] = it->second; - d_term_to_anchor_conj[n] = d_term_to_anchor_conj[n[0]]; unsigned sel_weight = d_tds->getSelectorWeight(n[0].getType(), n.getOperator()); d = d_term_to_depth[n[0]] + sel_weight; @@ -236,9 +236,9 @@ void SygusSymBreakNew::registerTerm( Node n, std::vector< Node >& lemmas ) { registerSizeTerm( n, lemmas ); if( d_register_st[n] ){ d_term_to_anchor[n] = n; - d_term_to_anchor_conj[n] = d_tds->getConjectureForEnumerator(n); + d_anchor_to_conj[n] = d_tds->getConjectureForEnumerator(n); // this assertion fails if we have a sygus term in the search that is unmeasured - Assert(d_term_to_anchor_conj[n] != NULL); + Assert(d_anchor_to_conj[n] != NULL); d = 0; is_top_level = true; success = true; @@ -354,8 +354,8 @@ void SygusSymBreakNew::assertTesterInternal( int tindex, TNode n, Node exp, std: } // static conjecture-dependent symmetry breaking std::map::iterator itc = - d_term_to_anchor_conj.find(n); - if (itc != d_term_to_anchor_conj.end()) + d_anchor_to_conj.find(a); + if (itc != d_anchor_to_conj.end()) { quantifiers::CegConjecture* conj = itc->second; Assert(conj != NULL); @@ -691,7 +691,8 @@ unsigned SygusSymBreakNew::processSelectorChain( Node n, std::map< TypeNode, Nod void SygusSymBreakNew::registerSearchTerm( TypeNode tn, unsigned d, Node n, bool topLevel, std::vector< Node >& lemmas ) { //register this term - std::map< Node, Node >::iterator ita = d_term_to_anchor.find( n ); + std::unordered_map::iterator ita = + d_term_to_anchor.find(n); Assert( ita != d_term_to_anchor.end() ); Node a = ita->second; Assert( !a.isNull() ); @@ -722,10 +723,10 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d, Trace("sygus-sb-debug2") << "Registering search value " << n << " -> " << nv << std::endl; // must do this for all nodes, regardless of top-level if( d_cache[a].d_search_val_proc.find( nv )==d_cache[a].d_search_val_proc.end() ){ - d_cache[a].d_search_val_proc[nv] = true; + d_cache[a].d_search_val_proc.insert(nv); // get the root (for PBE symmetry breaking) - Assert(d_term_to_anchor_conj.find(a) != d_term_to_anchor_conj.end()); - quantifiers::CegConjecture* aconj = d_term_to_anchor_conj[a]; + Assert(d_anchor_to_conj.find(a) != d_anchor_to_conj.end()); + quantifiers::CegConjecture* aconj = d_anchor_to_conj[a]; Assert(aconj != NULL); Trace("sygus-sb-debug") << " ...register search value " << nv << ", type=" << tn << std::endl; Node bv = d_tds->sygusToBuiltin( nv, tn ); @@ -740,7 +741,8 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d, registerSymBreakLemmaForValue(a, nv, dbzet, Node::null(), lemmas); return false; }else{ - std::map< Node, Node >::iterator itsv = d_cache[a].d_search_val[tn].find( bvr ); + std::unordered_map::iterator itsv = + d_cache[a].d_search_val[tn].find(bvr); Node bad_val_bvr; bool by_examples = false; if( itsv==d_cache[a].d_search_val[tn].end() ){ @@ -787,20 +789,9 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d, d_tds, nv, options::sygusSamples(), false); its = d_sampler[a].find(tn); } - Node bvr_sample_ret; - std::map::iterator itsv = - d_cache[a].d_search_val_sample[tn].find(bvr); - if (itsv == d_cache[a].d_search_val_sample[tn].end()) - { - // initialize the sampler for the rewritten form of this node - bvr_sample_ret = its->second.registerTerm(bvr); - d_cache[a].d_search_val_sample[tn][bvr] = bvr_sample_ret; - } - else - { - bvr_sample_ret = itsv->second; - } + // register the rewritten node with the sampler + Node bvr_sample_ret = its->second.registerTerm(bvr); // register the current node with the sampler Node sample_ret = its->second.registerTerm(bv); @@ -1072,7 +1063,8 @@ void SygusSymBreakNew::notifySearchSize( Node m, unsigned s, Node exp, std::vect unsigned SygusSymBreakNew::getSearchSizeFor( Node n ) { Trace("sygus-sb-debug2") << "get search size for term : " << n << std::endl; - std::map< Node, Node >::iterator ita = d_term_to_anchor.find( n ); + std::unordered_map::iterator ita = + d_term_to_anchor.find(n); Assert( ita != d_term_to_anchor.end() ); return getSearchSizeForAnchor( ita->second ); } diff --git a/src/theory/datatypes/datatypes_sygus.h b/src/theory/datatypes/datatypes_sygus.h index cb7729658..2936c1561 100644 --- a/src/theory/datatypes/datatypes_sygus.h +++ b/src/theory/datatypes/datatypes_sygus.h @@ -47,8 +47,22 @@ class SygusSymBreakNew typedef context::CDHashMap< Node, bool, NodeHashFunction > BoolMap; typedef context::CDHashSet NodeSet; + public: + SygusSymBreakNew(TheoryDatatypes* td, + quantifiers::TermDbSygus* tds, + context::Context* c); + ~SygusSymBreakNew(); + /** add tester */ + void assertTester(int tindex, TNode n, Node exp, std::vector& lemmas); + void assertFact(Node n, bool polarity, std::vector& lemmas); + void preRegisterTerm(TNode n, std::vector& lemmas); + void check(std::vector& lemmas); + Node getNextDecisionRequest(unsigned& priority, std::vector& lemmas); + private: + /** Pointer to the datatype theory that owns this class. */ TheoryDatatypes* d_td; + /** Pointer to the sygus term database */ quantifiers::TermDbSygus* d_tds; IntMap d_testers; IntMap d_is_const; @@ -56,18 +70,15 @@ class SygusSymBreakNew NodeSet d_active_terms; IntMap d_currTermSize; Node d_zero; - - private: /** * Map from terms (selector chains) to their anchors. The anchor of a * selector chain S1( ... Sn( x ) ... ) is x. */ - std::map< Node, Node > d_term_to_anchor; + std::unordered_map d_term_to_anchor; /** - * Map from terms (selector chains) to the conjecture that their anchor is - * associated with. + * Map from anchors to the conjecture they are associated with. */ - std::map d_term_to_anchor_conj; + std::map d_anchor_to_conj; /** * Map from terms (selector chains) to their depth. The depth of a selector * chain S1( ... Sn( x ) ... ) is: @@ -75,7 +86,7 @@ class SygusSymBreakNew * where weight is the selector weight of Si * (see SygusTermDatabase::getSelectorWeight). */ - std::map< Node, unsigned > d_term_to_depth; + std::unordered_map d_term_to_depth; /** * Map from terms (selector chains) to whether they are the topmost term * of their type. For example, if: @@ -87,18 +98,24 @@ class SygusSymBreakNew * whereas S2( S1( x ) ) and S3( S2( S1( x ) ) ) are not. * */ - std::map< Node, bool > d_is_top_level; + std::unordered_map d_is_top_level; /** * Returns true if the selector chain n is top-level based on the above * definition, when tn is the type of n. */ bool computeTopLevel( TypeNode tn, Node n ); private: - //list of all terms encountered in search at depth - class SearchCache { + /** This caches all information regarding symmetry breaking for an anchor. */ + class SearchCache + { public: SearchCache(){} + /** + * A cache of all search terms for (types, sizes). See registerSearchTerm + * for definition of search terms. + */ std::map< TypeNode, std::map< unsigned, std::vector< Node > > > d_search_terms; + /** A cache of all symmetry breaking lemma templates for (types, sizes). */ std::map< TypeNode, std::map< unsigned, std::vector< Node > > > d_sb_lemmas; /** search value * @@ -107,20 +124,13 @@ private: * term. The range of this map can be updated if we later encounter a sygus * term that also rewrites to the builtin value but has a smaller term size. */ - std::map< TypeNode, std::map< Node, Node > > d_search_val; + std::map> + d_search_val; /** the size of terms in the range of d_search val. */ - std::map< TypeNode, std::map< Node, unsigned > > d_search_val_sz; - /** search value sample - * - * This is used for the sygusRewVerify() option. For each sygus term t - * of type tn with anchor a that we register with this cache, we set: - * d_search_val_sample[tn][r] = r' - * where r is the rewritten form of the builtin equivalent of t, and r' - * is the term returned by d_sampler[a][tn].registerTerm( r ). - */ - std::map> d_search_val_sample; + std::map> + d_search_val_sz; /** For each term, whether this cache has processed that term */ - std::map< Node, bool > d_search_val_proc; + std::unordered_set d_search_val_proc; }; /** An instance of the above cache, for each anchor */ std::map< Node, SearchCache > d_cache; @@ -158,12 +168,15 @@ private: * if applicable. */ void registerTerm(Node n, std::vector& lemmas); + + //------------------------dynamic symmetry breaking /** Register search term * - * This function is called when selector chain S_1( ... S_m( n ) ... ) is - * registered to the theory of datatypes, where tn is the type of n, - * d indicates the depth of n (the sum of weights of the selectors S_1...S_m), - * and topLevel is whether n is a top-level term (see d_is_top_level). + * This function is called when selector chain n of the form + * S_1( ... S_m( x ) ... ) is registered to the theory of datatypes, where + * tn is the type of n, d indicates the depth of n (the sum of weights of the + * selectors S_1...S_m), and topLevel is whether n is a top-level term + * (see d_is_top_level). We refer to n as a "search term". * * The purpose of this function is to notify this class that symmetry breaking * lemmas should be instantiated for n. Any symmetry breaking lemmas that @@ -261,27 +274,74 @@ private: /** add symmetry breaking lemma * * This adds the lemma R => lem{ x -> n } to lemmas, where R is a "relevancy - * condition" that states which contexts n is relevant in (contexts in which - * the selector chain n is specified). + * condition" that states which contexts n is relevant in (see + * getRelevancyCondition). */ void addSymBreakLemma(Node lem, TNode x, TNode n, std::vector& lemmas); + //------------------------end dynamic symmetry breaking - private: - std::map< Node, Node > d_rlv_cond; + /** Get relevancy condition + * + * This returns a predicate that holds in the contexts in which the selector + * chain n is specified. For example, the relevancy condition for + * sel_{C2,1}( sel_{C1,1}( d ) ) is is-C1( d ) ^ is-C2( sel_{C1,1}( d ) ). + * If shared selectors are enabled, this is a conjunction of disjunctions, + * since shared selectors may apply to multiple constructors. + */ Node getRelevancyCondition( Node n ); -private: - std::map< TypeNode, std::map< int, std::map< unsigned, Node > > > d_simple_sb_pred; - // user-context dependent if sygus-incremental - std::map< Node, unsigned > d_simple_proc; - //get simple symmetry breaking predicate + /** Cache of the above function */ + std::map d_rlv_cond; + + //------------------------static symmetry breaking + /** Get simple symmetry breakind predicate + * + * This function returns the "static" symmetry breaking lemma template for + * terms with type tn and constructor index tindex, for the given depth. This + * includes inferences about size with depth=0. Given grammar: + * A -> ite( B, A, A ) | A+A | x | 1 | 0 + * B -> A = A + * Examples of static symmetry breaking lemma templates are: + * for +, depth 0: size(z)=size(z.1)+size(z.2)+1 + * for +, depth 1: ~is-0( z.1 ) ^ ~is-0( z.2 ) ^ F + * where F ensures the constructor of z.1 is less than that of z.2 based + * on some ordering. + * for ite, depth 1: z.2 != z.3 + * These templates can be thought of as "hard-coded" cases of dynamic symmetry + * breaking lemma templates. Notice that the above lemma templates are in + * terms of getFreeVar( tn ), hence only one is created per + * (constructor, depth). A static symmetry break lemma template F[z] for + * constructor C are included in lemmas of the form: + * is-C( t ) => F[t] + * where t is a search term, see registerSearchTerm for definition of search + * term. + */ Node getSimpleSymBreakPred( TypeNode tn, int tindex, unsigned depth ); + /** Cache of the above function */ + std::map>> d_simple_sb_pred; + /** + * For each search term, this stores the maximum depth for which we have added + * a static symmetry breaking lemma. + * + * This should be user context-dependent if sygus is updated to work in + * incremental mode. + */ + std::unordered_map d_simple_proc; + //------------------------end static symmetry breaking + + /** Get the canonical free variable for type tn */ TNode getFreeVar( TypeNode tn ); Node getTermOrderPredicate( Node n1, Node n2 ); private: - //should be user-context dependent if sygus in incremental mode - std::map< Node, bool > d_register_st; - void registerSizeTerm( Node e, std::vector< Node >& lemmas ); - class SearchSizeInfo { + /** + * Map from registered variables to whether they are a sygus enumerator. + * + * This should be user context-dependent if sygus is updated to work in + * incremental mode. + */ + std::map d_register_st; + void registerSizeTerm(Node e, std::vector& lemmas); + class SearchSizeInfo + { public: SearchSizeInfo( Node t, context::Context* c ) : d_this( t ), d_curr_search_size(0), d_curr_lit( c, 0 ) {} Node d_this; @@ -323,16 +383,6 @@ private: int getGuardStatus( Node g ); private: void assertIsConst( Node n, bool polarity, std::vector< Node >& lemmas ); -public: - SygusSymBreakNew( TheoryDatatypes * td, quantifiers::TermDbSygus * tds, context::Context* c ); - ~SygusSymBreakNew(); - /** add tester */ - void assertTester( int tindex, TNode n, Node exp, std::vector< Node >& lemmas ); - void assertFact( Node n, bool polarity, std::vector< Node >& lemmas ); - void preRegisterTerm( TNode n, std::vector< Node >& lemmas ); - void check( std::vector< Node >& lemmas ); -public: - Node getNextDecisionRequest( unsigned& priority, std::vector< Node >& lemmas ); }; } diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index e8bdf2083..40183fe9c 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -161,42 +161,53 @@ Node TermDbSygus::mkGeneric(const Datatype& dt, int c, std::map& pre) return mkGeneric(dt, c, var_count, pre); } +struct SygusToBuiltinAttributeId +{ +}; +typedef expr::Attribute + SygusToBuiltinAttribute; + Node TermDbSygus::sygusToBuiltin( Node n, TypeNode tn ) { Assert( n.getType()==tn ); Assert( tn.isDatatype() ); - std::map< Node, Node >::iterator it = d_sygus_to_builtin[tn].find( n ); - if( it==d_sygus_to_builtin[tn].end() ){ - Trace("sygus-db-debug") << "SygusToBuiltin : compute for " << n << ", type = " << tn << std::endl; - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); - if( n.getKind()==APPLY_CONSTRUCTOR ){ - unsigned i = Datatype::indexOf( n.getOperator().toExpr() ); - Assert( n.getNumChildren()==dt[i].getNumArgs() ); - std::map< TypeNode, int > var_count; - std::map< int, Node > pre; - for (unsigned j = 0, size = n.getNumChildren(); j < size; j++) - { - pre[j] = sygusToBuiltin( n[j], getArgType( dt[i], j ) ); - } - Node ret = mkGeneric(dt, i, var_count, pre); - Trace("sygus-db-debug") << "SygusToBuiltin : Generic is " << ret << std::endl; - d_sygus_to_builtin[tn][n] = ret; - return ret; - } - if (n.hasAttribute(SygusPrintProxyAttribute())) + + // has it already been computed? + if (n.hasAttribute(SygusToBuiltinAttribute())) + { + return n.getAttribute(SygusToBuiltinAttribute()); + } + + Trace("sygus-db-debug") << "SygusToBuiltin : compute for " << n + << ", type = " << tn << std::endl; + const Datatype& dt = static_cast(tn.toType()).getDatatype(); + if (n.getKind() == APPLY_CONSTRUCTOR) + { + unsigned i = Datatype::indexOf(n.getOperator().toExpr()); + Assert(n.getNumChildren() == dt[i].getNumArgs()); + std::map var_count; + std::map pre; + for (unsigned j = 0, size = n.getNumChildren(); j < size; j++) { - // this variable was associated by an attribute to a builtin node - return n.getAttribute(SygusPrintProxyAttribute()); + pre[j] = sygusToBuiltin(n[j], getArgType(dt[i], j)); } - Assert(isFreeVar(n)); - // map to builtin variable type - int fv_num = getVarNum(n); - Assert(!dt.getSygusType().isNull()); - TypeNode vtn = TypeNode::fromType(dt.getSygusType()); - Node ret = getFreeVar(vtn, fv_num); + Node ret = mkGeneric(dt, i, var_count, pre); + Trace("sygus-db-debug") + << "SygusToBuiltin : Generic is " << ret << std::endl; + n.setAttribute(SygusToBuiltinAttribute(), ret); return ret; - }else{ - return it->second; } + if (n.hasAttribute(SygusPrintProxyAttribute())) + { + // this variable was associated by an attribute to a builtin node + return n.getAttribute(SygusPrintProxyAttribute()); + } + Assert(isFreeVar(n)); + // map to builtin variable type + int fv_num = getVarNum(n); + Assert(!dt.getSygusType().isNull()); + TypeNode vtn = TypeNode::fromType(dt.getSygusType()); + Node ret = getFreeVar(vtn, fv_num); + return ret; } Node TermDbSygus::sygusSubstituted( TypeNode tn, Node n, std::vector< Node >& args ) { diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index 7ef9e6151..57a127d8d 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -187,8 +187,6 @@ class TermDbSygus { //------------------------------end enumerators //-----------------------------conversion from sygus to builtin - /** cache for sygusToBuiltin */ - std::map > d_sygus_to_builtin; /** a cache of fresh variables for each type * * We store two versions of this list: