From: Andrew Reynolds Date: Fri, 20 Jul 2018 23:42:25 +0000 (+0200) Subject: Cleanup and additions for candidate generator (#2173) X-Git-Tag: cvc5-1.0.0~4883 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=2b4fa75f5b7cb6657b3e1ebc35534ca4fd0ac422;p=cvc5.git Cleanup and additions for candidate generator (#2173) --- diff --git a/src/theory/quantifiers/ematching/candidate_generator.cpp b/src/theory/quantifiers/ematching/candidate_generator.cpp index 96719cc0f..4208b11ae 100644 --- a/src/theory/quantifiers/ematching/candidate_generator.cpp +++ b/src/theory/quantifiers/ematching/candidate_generator.cpp @@ -33,38 +33,10 @@ bool CandidateGenerator::isLegalCandidate( Node n ){ return d_qe->getTermDatabase()->isTermActive( n ) && ( !options::cbqi() || !quantifiers::TermUtil::hasInstConstAttr(n) ); } -void CandidateGeneratorQueue::addCandidate( Node n ) { - if( isLegalCandidate( n ) ){ - d_candidates.push_back( n ); - } -} - -void CandidateGeneratorQueue::reset( Node eqc ){ - if( d_candidate_index>0 ){ - d_candidates.erase( d_candidates.begin(), d_candidates.begin() + d_candidate_index ); - d_candidate_index = 0; - } - if( !eqc.isNull() ){ - d_candidates.push_back( eqc ); - } -} -Node CandidateGeneratorQueue::getNextCandidate(){ - if( d_candidate_index<(int)d_candidates.size() ){ - Node n = d_candidates[d_candidate_index]; - d_candidate_index++; - return n; - }else{ - d_candidate_index = 0; - d_candidates.clear(); - return Node::null(); - } -} - CandidateGeneratorQE::CandidateGeneratorQE( QuantifiersEngine* qe, Node pat ) : CandidateGenerator( qe ), d_term_iter( -1 ){ d_op = qe->getTermDatabase()->getMatchOperator( pat ); Assert( !d_op.isNull() ); - d_op_arity = pat.getNumChildren(); } void CandidateGeneratorQE::resetInstantiationRound(){ @@ -83,22 +55,16 @@ void CandidateGeneratorQE::reset( Node eqc ){ if( ee->hasTerm( eqc ) ){ quantifiers::TermArgTrie * tat = d_qe->getTermDatabase()->getTermArgTrie( eqc, d_op ); if( tat ){ -#if 1 //create an equivalence class iterator in eq class eqc Node rep = ee->getRepresentative( eqc ); d_eqc_iter = eq::EqClassIterator( rep, ee ); d_mode = cand_term_eqc; -#else - d_tindex.push_back( tat ); - d_tindex_iter.push_back( tat->d_data.begin() ); - d_mode = cand_term_tindex; -#endif }else{ d_mode = cand_term_none; } }else{ //the only match is this term itself - d_n = eqc; + d_eqc = eqc; d_mode = cand_term_ident; } } @@ -144,41 +110,12 @@ Node CandidateGeneratorQE::getNextCandidate(){ return n; } } - }else if( d_mode==cand_term_tindex ){ - Debug("cand-gen-qe") << "...get next candidate in tindex " << d_op << " " << d_op_arity << std::endl; - //increment the term index iterator - if( !d_tindex.empty() ){ - //populate the vector - while( d_tindex_iter.size()<=d_op_arity ){ - Assert( !d_tindex_iter.empty() ); - Assert( !d_tindex_iter.back()->second.d_data.empty() ); - d_tindex.push_back( &(d_tindex_iter.back()->second) ); - d_tindex_iter.push_back( d_tindex_iter.back()->second.d_data.begin() ); - } - //get the current node - Assert( d_tindex_iter.back()->second.hasNodeData() ); - Node n = d_tindex_iter.back()->second.getNodeData(); - Debug("cand-gen-qe") << "...returning " << n << std::endl; - Assert( !n.isNull() ); - Assert( isLegalOpCandidate( n ) ); - //increment - bool success = false; - do{ - ++d_tindex_iter.back(); - if( d_tindex_iter.back()==d_tindex.back()->d_data.end() ){ - d_tindex.pop_back(); - d_tindex_iter.pop_back(); - }else{ - success = true; - } - }while( !success && !d_tindex.empty() ); - return n; - } }else if( d_mode==cand_term_ident ){ Debug("cand-gen-qe") << "...get next candidate identity" << std::endl; - if( !d_n.isNull() ){ - Node n = d_n; - d_n = Node::null(); + if (!d_eqc.isNull()) + { + Node n = d_eqc; + d_eqc = Node::null(); if( isLegalOpCandidate( n ) ){ return n; } @@ -187,45 +124,6 @@ Node CandidateGeneratorQE::getNextCandidate(){ return Node::null(); } -CandidateGeneratorQELitEq::CandidateGeneratorQELitEq( QuantifiersEngine* qe, Node mpat ) : - CandidateGenerator( qe ), d_match_pattern( mpat ){ - Assert( mpat.getKind()==EQUAL ); - for( unsigned i=0; i<2; i++ ){ - if( !quantifiers::TermUtil::hasInstConstAttr(mpat[i]) ){ - d_match_gterm = mpat[i]; - } - } -} -void CandidateGeneratorQELitEq::resetInstantiationRound(){ - -} -void CandidateGeneratorQELitEq::reset( Node eqc ){ - if( d_match_gterm.isNull() ){ - d_eq = eq::EqClassesIterator( d_qe->getEqualityQuery()->getEngine() ); - }else{ - d_do_mgt = true; - } -} -Node CandidateGeneratorQELitEq::getNextCandidate(){ - if( d_match_gterm.isNull() ){ - while( !d_eq.isFinished() ){ - Node n = (*d_eq); - ++d_eq; - if( n.getType().isComparableTo( d_match_pattern[0].getType() ) ){ - //an equivalence class with the same type as the pattern, return reflexive equality - return NodeManager::currentNM()->mkNode( d_match_pattern.getKind(), n, n ); - } - } - }else{ - if( d_do_mgt ){ - d_do_mgt = false; - return NodeManager::currentNM()->mkNode( d_match_pattern.getKind(), d_match_gterm, d_match_gterm ); - } - } - return Node::null(); -} - - CandidateGeneratorQELitDeq::CandidateGeneratorQELitDeq( QuantifiersEngine* qe, Node mpat ) : CandidateGenerator( qe ), d_match_pattern( mpat ){ @@ -233,10 +131,6 @@ CandidateGenerator( qe ), d_match_pattern( mpat ){ d_match_pattern_type = d_match_pattern[0].getType(); } -void CandidateGeneratorQELitDeq::resetInstantiationRound(){ - -} - void CandidateGeneratorQELitDeq::reset( Node eqc ){ Node false_term = d_qe->getEqualityQuery()->getEngine()->getRepresentative( NodeManager::currentNM()->mkConst(false) ); d_eqc_false = eq::EqClassIterator( false_term, d_qe->getEqualityQuery()->getEngine() ); @@ -269,10 +163,6 @@ CandidateGeneratorQEAll::CandidateGeneratorQEAll( QuantifiersEngine* qe, Node mp d_firstTime = false; } -void CandidateGeneratorQEAll::resetInstantiationRound() { - -} - void CandidateGeneratorQEAll::reset( Node eqc ) { d_eq = eq::EqClassesIterator( d_qe->getEqualityQuery()->getEngine() ); d_firstTime = true; @@ -307,3 +197,56 @@ Node CandidateGeneratorQEAll::getNextCandidate() { } return Node::null(); } + +CandidateGeneratorConsExpand::CandidateGeneratorConsExpand( + QuantifiersEngine* qe, Node mpat) + : CandidateGeneratorQE(qe, mpat) +{ + Assert(mpat.getKind() == APPLY_CONSTRUCTOR); + d_mpat_type = static_cast(mpat.getType().toType()); +} + +void CandidateGeneratorConsExpand::reset(Node eqc) +{ + d_term_iter = 0; + if (eqc.isNull()) + { + d_mode = cand_term_db; + } + else + { + d_eqc = eqc; + d_mode = cand_term_ident; + Assert(d_eqc.getType().toType() == d_mpat_type); + } +} + +Node CandidateGeneratorConsExpand::getNextCandidate() +{ + // get the next term from the base class + Node curr = CandidateGeneratorQE::getNextCandidate(); + if (curr.isNull() || (curr.hasOperator() && curr.getOperator() == d_op)) + { + return curr; + } + // expand it + NodeManager* nm = NodeManager::currentNM(); + std::vector children; + const Datatype& dt = d_mpat_type.getDatatype(); + Assert(dt.getNumConstructors() == 1); + children.push_back(d_op); + for (unsigned i = 0, nargs = dt[0].getNumArgs(); i < nargs; i++) + { + Node sel = + nm->mkNode(APPLY_SELECTOR_TOTAL, + Node::fromExpr(dt[0].getSelectorInternal(d_mpat_type, i)), + curr); + children.push_back(sel); + } + return nm->mkNode(APPLY_CONSTRUCTOR, children); +} + +bool CandidateGeneratorConsExpand::isLegalOpCandidate(Node n) +{ + return isLegalCandidate(n); +} diff --git a/src/theory/quantifiers/ematching/candidate_generator.h b/src/theory/quantifiers/ematching/candidate_generator.h index dc188062f..da4ec2d83 100644 --- a/src/theory/quantifiers/ematching/candidate_generator.h +++ b/src/theory/quantifiers/ematching/candidate_generator.h @@ -23,133 +23,145 @@ namespace CVC4 { namespace theory { -namespace quantifiers { - class TermArgTrie; -} - class QuantifiersEngine; namespace inst { -/** base class for generating candidates for matching */ +/** Candidate generator + * + * This is the base class for generating a stream of candidate terms for + * E-matching. Depending on the kind of trigger we are processing and its + * overall context, we are interested in several different criteria for + * terms. This includes: + * - Generating a stream of all ground terms with a given operator, + * - Generating a stream of all ground terms with a given operator in a + * particular equivalence class, + * - Generating a stream of all terms of a particular type, + * - Generating all terms that are disequal from a fixed ground term, + * and so on. + * + * A typical use case of an instance cg of this class is the following. Given + * an equivalence class representative eqc: + * + * cg->reset( eqc ); + * do{ + * Node cand = cg->getNextCandidate(); + * ; ...if non-null, cand is a candidate... + * }while( !cand.isNull() ); + * + */ class CandidateGenerator { protected: QuantifiersEngine* d_qe; public: CandidateGenerator( QuantifiersEngine* qe ) : d_qe( qe ){} virtual ~CandidateGenerator(){} - - /** Get candidates functions. These set up a context to get all match candidates. - cg->reset( eqc ); - do{ - Node cand = cg->getNextCandidate(); - //....... - }while( !cand.isNull() ); - - eqc is the equivalence class you are searching in - */ + /** reset instantiation round + * + * This is called at the beginning of each instantiation round. + */ + virtual void resetInstantiationRound() {} + /** reset for equivalence class eqc + * + * This indicates that this class should generate a stream of candidate terms + * based on its criteria that occur in the equivalence class of eqc, or + * any equivalence class if eqc is null. + */ virtual void reset( Node eqc ) = 0; + /** get the next candidate */ virtual Node getNextCandidate() = 0; - /** add candidate to list of nodes returned by this generator */ - virtual void addCandidate( Node n ) {} - /** call this at the beginning of each instantiation round */ - virtual void resetInstantiationRound() = 0; public: - /** legal candidate */ - bool isLegalCandidate( Node n ); + /** is n a legal candidate? */ + bool isLegalCandidate(Node n); };/* class CandidateGenerator */ -/** candidate generator queue (for manual candidate generation) */ -class CandidateGeneratorQueue : public CandidateGenerator { - private: - std::vector< Node > d_candidates; - int d_candidate_index; - - public: - CandidateGeneratorQueue( QuantifiersEngine* qe ) : CandidateGenerator( qe ), d_candidate_index( 0 ){} - - void addCandidate(Node n) override; - - void resetInstantiationRound() override {} - void reset(Node eqc) override; - Node getNextCandidate() override; -};/* class CandidateGeneratorQueue */ - -//the default generator +/* the default candidate generator class + * + * This class may generate candidates for E-matching based on several modes: + * (1) cand_term_db: iterate over all ground terms for the given operator, + * (2) cand_term_ident: generate the given input term as a candidate, + * (3) cand_term_eqc: iterate over all terms in an equivalence class, returning + * those with the proper operator as candidates. + */ class CandidateGeneratorQE : public CandidateGenerator { friend class CandidateGeneratorQEDisequal; - private: - //operator you are looking for + public: + CandidateGeneratorQE(QuantifiersEngine* qe, Node pat); + /** reset instantiation round */ + void resetInstantiationRound() override; + /** reset */ + void reset(Node eqc) override; + /** get next candidate */ + Node getNextCandidate() override; + /** tell this class to exclude candidates from equivalence class r */ + void excludeEqc(Node r) { d_exclude_eqc[r] = true; } + /** is r an excluded equivalence class? */ + bool isExcludedEqc(Node r) + { + return d_exclude_eqc.find(r) != d_exclude_eqc.end(); + } + + protected: + /** operator you are looking for */ Node d_op; - //the equality class iterator - unsigned d_op_arity; - std::vector< quantifiers::TermArgTrie* > d_tindex; - std::vector< std::map< TNode, quantifiers::TermArgTrie >::iterator > d_tindex_iter; + /** the equality class iterator (for cand_term_eqc) */ eq::EqClassIterator d_eqc_iter; - //std::vector< Node > d_eqc; + /** the TermDb index of the current ground term (for cand_term_db) */ int d_term_iter; + /** the TermDb index of the current ground term (for cand_term_db) */ int d_term_iter_limit; - bool d_using_term_db; + /** the term we are matching (for cand_term_ident) */ + Node d_eqc; + /** candidate generation modes */ enum { cand_term_db, cand_term_ident, cand_term_eqc, - cand_term_tindex, cand_term_none, }; + /** the current mode of this candidate generator */ short d_mode; - bool isLegalOpCandidate( Node n ); - Node d_n; + /** is n a legal candidate of the required operator? */ + virtual bool isLegalOpCandidate(Node n); + /** the equivalence classes that we have excluded from candidate generation */ std::map< Node, bool > d_exclude_eqc; - public: - CandidateGeneratorQE( QuantifiersEngine* qe, Node pat ); - - void resetInstantiationRound() override; - void reset(Node eqc) override; - Node getNextCandidate() override; - void excludeEqc( Node r ) { d_exclude_eqc[r] = true; } - bool isExcludedEqc( Node r ) { return d_exclude_eqc.find( r )!=d_exclude_eqc.end(); } }; -class CandidateGeneratorQELitEq : public CandidateGenerator +/** + * Generate terms based on a disequality, that is, we match (= t[x] s[x]) + * with equalities (= g1 g2) in the equivalence class of false. + */ +class CandidateGeneratorQELitDeq : public CandidateGenerator { - private: - //the equality classes iterator - eq::EqClassesIterator d_eq; - //equality you are trying to match equalities for - Node d_match_pattern; - Node d_match_gterm; - bool d_do_mgt; - public: - CandidateGeneratorQELitEq( QuantifiersEngine* qe, Node mpat ); - - void resetInstantiationRound() override; + /** + * mpat is an equality that we are matching to equalities in the equivalence + * class of false + */ + CandidateGeneratorQELitDeq(QuantifiersEngine* qe, Node mpat); + /** reset */ void reset(Node eqc) override; + /** get next candidate */ Node getNextCandidate() override; -}; -class CandidateGeneratorQELitDeq : public CandidateGenerator -{ private: - //the equality class iterator for false + /** the equality class iterator for false */ eq::EqClassIterator d_eqc_false; - //equality you are trying to match disequalities for + /** + * equality you are trying to match against ground equalities that are + * assigned to false + */ Node d_match_pattern; - //type of disequality + /** type of the terms we are generating */ TypeNode d_match_pattern_type; - - public: - CandidateGeneratorQELitDeq( QuantifiersEngine* qe, Node mpat ); - - void resetInstantiationRound() override; - void reset(Node eqc) override; - Node getNextCandidate() override; }; +/** + * Generate all terms of the proper sort that occur in the current context. + */ class CandidateGeneratorQEAll : public CandidateGenerator { private: @@ -166,10 +178,34 @@ class CandidateGeneratorQEAll : public CandidateGenerator public: CandidateGeneratorQEAll( QuantifiersEngine* qe, Node mpat ); + /** reset */ + void reset(Node eqc) override; + /** get next candidate */ + Node getNextCandidate() override; +}; - void resetInstantiationRound() override; +/** candidate generation constructor expand + * + * This modifies the candidates t1, ..., tn generated by CandidateGeneratorQE + * so that they are "expansions" of a fixed datatype constructor C. Assuming + * C has arity m, we instead return the stream: + * C(sel_1( t1 ), ..., sel_m( tn )) ... C(sel_1( t1 ), ..., C( sel_m( tn )) + * where sel_1 ... sel_m are the selectors of C. + */ +class CandidateGeneratorConsExpand : public CandidateGeneratorQE +{ + public: + CandidateGeneratorConsExpand(QuantifiersEngine* qe, Node mpat); + /** reset */ void reset(Node eqc) override; + /** get next candidate */ Node getNextCandidate() override; + + protected: + /** the (datatype) type of the input match pattern */ + DatatypeType d_mpat_type; + /** we don't care about the operator of n */ + bool isLegalOpCandidate(Node n) override; }; }/* CVC4::theory::inst namespace */ diff --git a/src/theory/quantifiers/ematching/inst_match_generator.cpp b/src/theory/quantifiers/ematching/inst_match_generator.cpp index 90d1815a4..192a6b433 100644 --- a/src/theory/quantifiers/ematching/inst_match_generator.cpp +++ b/src/theory/quantifiers/ematching/inst_match_generator.cpp @@ -174,8 +174,24 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector< //create candidate generator if( Trigger::isAtomicTrigger( d_match_pattern ) ){ - //we will be scanning lists trying to find d_match_pattern.getOperator() - d_cg = new inst::CandidateGeneratorQE( qe, d_match_pattern ); + if (d_match_pattern.getKind() == APPLY_CONSTRUCTOR) + { + // 1-constructors have a trivial way of generating candidates in a + // given equivalence class + const Datatype& dt = + static_cast(d_match_pattern.getType().toType()) + .getDatatype(); + if (dt.getNumConstructors() == 1) + { + d_cg = new inst::CandidateGeneratorConsExpand(qe, d_match_pattern); + } + } + if (d_cg == nullptr) + { + // we will be scanning lists trying to find + // d_match_pattern.getOperator() + d_cg = new inst::CandidateGeneratorQE(qe, d_match_pattern); + } //if matching on disequality, inform the candidate generator not to match on eqc if( d_pattern.getKind()==NOT && d_pattern[0].getKind()==EQUAL ){ ((inst::CandidateGeneratorQE*)d_cg)->excludeEqc( d_eq_class_rel ); @@ -196,13 +212,9 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector< }else if( d_match_pattern.getKind()==EQUAL && d_match_pattern[0].getKind()==INST_CONSTANT && d_match_pattern[1].getKind()==INST_CONSTANT ){ //we will be producing candidates via literal matching heuristics - if( d_pattern.getKind()!=NOT ){ - //candidates will be all equalities - d_cg = new inst::CandidateGeneratorQELitEq( qe, d_match_pattern ); - }else{ - //candidates will be all disequalities - d_cg = new inst::CandidateGeneratorQELitDeq( qe, d_match_pattern ); - } + Assert(d_pattern.getKind() == NOT); + // candidates will be all disequalities + d_cg = new inst::CandidateGeneratorQELitDeq(qe, d_match_pattern); }else{ Trace("inst-match-gen-warn") << "(?) Unknown matching pattern is " << d_match_pattern << std::endl; }