From: Andrew Reynolds Date: Mon, 12 Aug 2019 19:23:31 +0000 (-0500) Subject: Give rewrite engine pointer to conflict-based instantiation module (#3174) X-Git-Tag: cvc5-1.0.0~4029 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=75d70649e2d72d6d6bb46f47cf96ee523b718cb9;p=cvc5.git Give rewrite engine pointer to conflict-based instantiation module (#3174) --- diff --git a/src/theory/quantifiers/rewrite_engine.cpp b/src/theory/quantifiers/rewrite_engine.cpp index 10157c3b3..07ff9ee46 100644 --- a/src/theory/quantifiers/rewrite_engine.cpp +++ b/src/theory/quantifiers/rewrite_engine.cpp @@ -17,17 +17,7 @@ #include "theory/quantifiers/rewrite_engine.h" #include "options/quantifiers_options.h" -#include "theory/quantifiers/ematching/inst_match_generator.h" -#include "theory/quantifiers/first_order_model.h" -#include "theory/quantifiers/fmf/model_engine.h" -#include "theory/quantifiers/instantiate.h" -#include "theory/quantifiers/quant_conflict_find.h" -#include "theory/quantifiers/quant_util.h" -#include "theory/quantifiers/quantifiers_attributes.h" -#include "theory/quantifiers/term_database.h" -#include "theory/quantifiers/term_util.h" #include "theory/quantifiers_engine.h" -#include "theory/theory_engine.h" using namespace CVC4; using namespace std; @@ -42,8 +32,11 @@ struct PrioritySort { } }; - -RewriteEngine::RewriteEngine( context::Context* c, QuantifiersEngine* qe ) : QuantifiersModule(qe) { +RewriteEngine::RewriteEngine(context::Context* c, + QuantifiersEngine* qe, + QuantConflictFind* qcf) + : QuantifiersModule(qe), d_qcf(qcf) +{ d_needsSort = false; } @@ -115,101 +108,98 @@ void RewriteEngine::check(Theory::Effort e, QEffort quant_e) } int RewriteEngine::checkRewriteRule( Node f, Theory::Effort e ) { - int addedLemmas = 0; Trace("rewrite-engine-inst") << "Check " << d_qinfo_n[f] << ", priority = " << getPriority( f ) << ", effort = " << e << "..." << std::endl; - QuantConflictFind * qcf = d_quantEngine->getConflictFind(); - if( qcf ){ - //reset QCF module - qcf->setEffort(QuantConflictFind::EFFORT_CONFLICT); - //get the proper quantifiers info - std::map< Node, QuantInfo >::iterator it = d_qinfo.find( f ); - if( it!=d_qinfo.end() ){ - QuantInfo * qi = &it->second; - if( qi->matchGeneratorIsValid() ){ - Node rr = QuantAttributes::getRewriteRule( f ); - Trace("rewrite-engine-inst-debug") << " Reset round..." << std::endl; - qi->reset_round( qcf ); - Trace("rewrite-engine-inst-debug") << " Get matches..." << std::endl; - while( !d_quantEngine->inConflict() && qi->getNextMatch( qcf ) && - ( addedLemmas==0 || !options::rrOneInstPerRound() ) ){ - Trace("rewrite-engine-inst-debug") << " Got match to complete..." << std::endl; - qi->debugPrintMatch( "rewrite-engine-inst-debug" ); - std::vector< int > assigned; - if( !qi->isMatchSpurious( qcf ) ){ - bool doContinue = false; - bool success = true; - int tempAddedLemmas = 0; - while( !d_quantEngine->inConflict() && tempAddedLemmas==0 && success && ( addedLemmas==0 || !options::rrOneInstPerRound() ) ){ - success = qi->completeMatch( qcf, assigned, doContinue ); - doContinue = true; - if( success ){ - Trace("rewrite-engine-inst-debug") << " Construct match..." << std::endl; - std::vector< Node > inst; - qi->getMatch( inst ); - Trace("rewrite-engine-inst-debug") << " Add instantiation..." << std::endl; - for( unsigned i=0; i "; - if( if[0].getNumChildren() ){ - inst.resize( f[0].getNumChildren() ); - } - if (d_quantEngine->getInstantiate()->addInstantiation(f, inst)) - { - addedLemmas++; - tempAddedLemmas++; - /* - //remove rewritten terms from consideration - std::vector< Node > to_remove; - switch( rr[2].getKind() ){ - case kind::RR_REWRITE: - to_remove.push_back( rr[2][0] ); - break; - case kind::RR_REDUCTION: - for( unsigned i=0; igetSubstitute( to_remove[j], inst ); - Trace("rewrite-engine-inst-debug") << "Will remove : " << ns << std::endl; - d_quantEngine->getTermDatabase()->setTermInactive( ns ); - } - */ - }else{ - Trace("rewrite-engine-inst-debug") << " - failed." << std::endl; - } - Trace("rewrite-engine-inst-debug") << " Get next completion..." << std::endl; + + // get the proper quantifiers info + std::map::iterator it = d_qinfo.find(f); + if (it == d_qinfo.end()) + { + Trace("rewrite-engine-inst-debug") << "...No qinfo." << std::endl; + return 0; + } + // reset QCF module + QuantInfo* qi = &it->second; + if (!qi->matchGeneratorIsValid()) + { + Trace("rewrite-engine-inst-debug") << "...Invalid qinfo." << std::endl; + return 0; + } + d_qcf->setEffort(QuantConflictFind::EFFORT_CONFLICT); + Node rr = QuantAttributes::getRewriteRule(f); + Trace("rewrite-engine-inst-debug") << " Reset round..." << std::endl; + qi->reset_round(d_qcf); + Trace("rewrite-engine-inst-debug") << " Get matches..." << std::endl; + int addedLemmas = 0; + while (!d_quantEngine->inConflict() && qi->getNextMatch(d_qcf) + && (addedLemmas == 0 || !options::rrOneInstPerRound())) + { + Trace("rewrite-engine-inst-debug") + << " Got match to complete..." << std::endl; + qi->debugPrintMatch("rewrite-engine-inst-debug"); + std::vector assigned; + if (!qi->isMatchSpurious(d_qcf)) + { + bool doContinue = false; + bool success = true; + int tempAddedLemmas = 0; + while (!d_quantEngine->inConflict() && tempAddedLemmas == 0 && success + && (addedLemmas == 0 || !options::rrOneInstPerRound())) + { + success = qi->completeMatch(d_qcf, assigned, doContinue); + doContinue = true; + if (success) + { + Trace("rewrite-engine-inst-debug") + << " Construct match..." << std::endl; + std::vector inst; + qi->getMatch(inst); + if (Trace.isOn("rewrite-engine-inst-debug")) + { + Trace("rewrite-engine-inst-debug") + << " Add instantiation..." << std::endl; + for (unsigned i = 0, nchild = f[0].getNumChildren(); i < nchild; + i++) + { + Trace("rewrite-engine-inst-debug") << " " << f[0][i] << " -> "; + if (i < inst.size()) + { + Trace("rewrite-engine-inst-debug") << inst[i] << std::endl; + } + else + { + Trace("rewrite-engine-inst-debug") + << "OUT_OF_RANGE" << std::endl; + Assert(false); } } - //Trace("rewrite-engine-inst-debug") << " Reverted assigned variables : "; - //for( unsigned a=0; arevertMatch( assigned ); - //Assert( assigned.empty() ); - Trace("rewrite-engine-inst-debug") << " - failed to complete." << std::endl; - }else{ - Trace("rewrite-engine-inst-debug") << " - match is spurious." << std::endl; } - Trace("rewrite-engine-inst-debug") << " Get next match..." << std::endl; + // resize to remove auxiliary variables + if (inst.size() > f[0].getNumChildren()) + { + inst.resize(f[0].getNumChildren()); + } + if (d_quantEngine->getInstantiate()->addInstantiation(f, inst)) + { + addedLemmas++; + tempAddedLemmas++; + } + else + { + Trace("rewrite-engine-inst-debug") << " - failed." << std::endl; + } + Trace("rewrite-engine-inst-debug") + << " Get next completion..." << std::endl; } - }else{ - Trace("rewrite-engine-inst-debug") << "...Invalid qinfo." << std::endl; } - }else{ - Trace("rewrite-engine-inst-debug") << "...No qinfo." << std::endl; + Trace("rewrite-engine-inst-debug") + << " - failed to complete." << std::endl; } + else + { + Trace("rewrite-engine-inst-debug") + << " - match is spurious." << std::endl; + } + Trace("rewrite-engine-inst-debug") << " Get next match..." << std::endl; } d_quantEngine->d_statistics.d_instantiations_rr += addedLemmas; Trace("rewrite-engine-inst") << "-> Generated " << addedLemmas << " lemmas." << std::endl; @@ -218,71 +208,82 @@ int RewriteEngine::checkRewriteRule( Node f, Theory::Effort e ) { void RewriteEngine::registerQuantifier( Node f ) { Node rr = QuantAttributes::getRewriteRule( f ); - if( !rr.isNull() ){ - Trace("rr-register") << "Register quantifier " << f << std::endl; - Trace("rr-register") << " rewrite rule is : " << rr << std::endl; - d_rr_quant.push_back( f ); - d_rr[f] = rr; - d_needsSort = true; - Trace("rr-register") << " guard is : " << d_rr[f][1] << std::endl; + if (rr.isNull()) + { + return; + } + Trace("rr-register") << "Register quantifier " << f << std::endl; + Trace("rr-register") << " rewrite rule is : " << rr << std::endl; + d_rr_quant.push_back(f); + d_rr[f] = rr; + d_needsSort = true; + Trace("rr-register") << " guard is : " << d_rr[f][1] << std::endl; - QuantConflictFind * qcf = d_quantEngine->getConflictFind(); - if( qcf ){ - std::vector< Node > qcfn_c; + std::vector qcfn_c; - std::vector< Node > bvl; - for( unsigned i=0; i bvl; + bvl.insert(bvl.end(), f[0].begin(), f[0].end()); - std::vector< Node > cc; - //add patterns - for( unsigned i=1; i nc; - for( unsigned j=0; jmkBoundVar( f[2][i][j].getType() ); - if( f[2][i][j].getType().isBoolean() && f[2][i][j].getKind()!=APPLY_UF ){ - nn = f[2][i][j].negate(); - }else{ - nn = f[2][i][j].eqNode( nbv ).negate(); - bvl.push_back( nbv ); - } - nc.push_back( nn ); - } - if( !nc.empty() ){ - Node n = nc.size()==1 ? nc[0] : NodeManager::currentNM()->mkNode( AND, nc ); - Trace("rr-register-debug") << " pattern is " << n << std::endl; - if( std::find( cc.begin(), cc.end(), n )==cc.end() ){ - cc.push_back( n ); - } - } + NodeManager* nm = NodeManager::currentNM(); + std::vector cc; + // add patterns + for (unsigned i = 1, nchild = f[2].getNumChildren(); i < nchild; i++) + { + std::vector nc; + for (const Node& pat : f[2][i]) + { + Node nn; + Node nbv = nm->mkBoundVar(pat.getType()); + if (pat.getType().isBoolean() && pat.getKind() != APPLY_UF) + { + nn = pat.negate(); } - qcfn_c.push_back( NodeManager::currentNM()->mkNode( BOUND_VAR_LIST, bvl ) ); + else + { + nn = pat.eqNode(nbv).negate(); + bvl.push_back(nbv); + } + nc.push_back(nn); + } + if (!nc.empty()) + { + Node n = nc.size() == 1 ? nc[0] : nm->mkNode(AND, nc); + Trace("rr-register-debug") << " pattern is " << n << std::endl; + if (std::find(cc.begin(), cc.end(), n) == cc.end()) + { + cc.push_back(n); + } + } + } + qcfn_c.push_back(nm->mkNode(BOUND_VAR_LIST, bvl)); - std::vector< Node > body_c; - //add the guards - if( d_rr[f][1].getKind()==AND ){ - for( unsigned j=0; jgetTermUtil()->d_true ){ - if( MatchGen::isHandled( d_rr[f][1] ) ){ - body_c.push_back( d_rr[f][1].negate() ); - } + std::vector body_c; + // add the guards + if (d_rr[f][1].getKind() == AND) + { + for (const Node& g : d_rr[f][1]) + { + if (MatchGen::isHandled(g)) + { + body_c.push_back(g.negate()); } - //add the patterns to the body - body_c.push_back( cc.size()==1 ? cc[0] : NodeManager::currentNM()->mkNode( AND, cc ) ); - //make the body - qcfn_c.push_back( body_c.size()==1 ? body_c[0] : NodeManager::currentNM()->mkNode( OR, body_c ) ); - //make the quantified formula - d_qinfo_n[f] = NodeManager::currentNM()->mkNode( FORALL, qcfn_c ); - Trace("rr-register") << " qcf formula is : " << d_qinfo_n[f] << std::endl; - d_qinfo[f].initialize( qcf, d_qinfo_n[f], d_qinfo_n[f][1] ); } } + else if (d_rr[f][1] != d_quantEngine->getTermUtil()->d_true) + { + if (MatchGen::isHandled(d_rr[f][1])) + { + body_c.push_back(d_rr[f][1].negate()); + } + } + // add the patterns to the body + body_c.push_back(cc.size() == 1 ? cc[0] : nm->mkNode(AND, cc)); + // make the body + qcfn_c.push_back(body_c.size() == 1 ? body_c[0] : nm->mkNode(OR, body_c)); + // make the quantified formula + d_qinfo_n[f] = nm->mkNode(FORALL, qcfn_c); + Trace("rr-register") << " qcf formula is : " << d_qinfo_n[f] << std::endl; + d_qinfo[f].initialize(d_qcf, d_qinfo_n[f], d_qinfo_n[f][1]); } void RewriteEngine::assertNode( Node n ) { diff --git a/src/theory/quantifiers/rewrite_engine.h b/src/theory/quantifiers/rewrite_engine.h index 717f4009b..5832d2817 100644 --- a/src/theory/quantifiers/rewrite_engine.h +++ b/src/theory/quantifiers/rewrite_engine.h @@ -18,8 +18,9 @@ #ifndef CVC4__REWRITE_ENGINE_H #define CVC4__REWRITE_ENGINE_H -#include "context/context.h" -#include "context/context_mm.h" +#include +#include + #include "theory/quantifiers/ematching/trigger.h" #include "theory/quantifiers/quant_conflict_find.h" #include "theory/quantifiers/quant_util.h" @@ -30,9 +31,6 @@ namespace quantifiers { class RewriteEngine : public QuantifiersModule { - typedef context::CDHashMap NodeBoolMap; - typedef context::CDHashMap NodeIntMap; - typedef context::CDHashMap NodeNodeMap; std::vector< Node > d_rr_quant; std::vector< Node > d_priority_order; std::map< Node, Node > d_rr; @@ -45,10 +43,14 @@ class RewriteEngine : public QuantifiersModule bool d_needsSort; std::map< Node, std::map< Node, Node > > d_inst_const_node; Node getInstConstNode( Node n, Node q ); -private: + + private: int checkRewriteRule( Node f, Theory::Effort e ); -public: - RewriteEngine( context::Context* c, QuantifiersEngine* qe ); + + public: + RewriteEngine(context::Context* c, + QuantifiersEngine* qe, + QuantConflictFind* qcf); bool needsCheck(Theory::Effort e) override; void check(Theory::Effort e, QEffort quant_e) override; @@ -57,6 +59,14 @@ public: bool checkCompleteFor(Node q) override; /** Identify this module */ std::string identify() const override { return "RewriteEngine"; } + + private: + /** + * A pointer to the quantifiers conflict find module of the quantifiers + * engine. This is the module that computes instantiations for rewrite rule + * quantifiers. + */ + QuantConflictFind* d_qcf; }; } diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp index f0b0c31df..c17af1e1f 100644 --- a/src/theory/quantifiers_engine.cpp +++ b/src/theory/quantifiers_engine.cpp @@ -149,7 +149,7 @@ class QuantifiersEnginePrivate } if (options::quantRewriteRules()) { - d_rr_engine.reset(new quantifiers::RewriteEngine(c, qe)); + d_rr_engine.reset(new quantifiers::RewriteEngine(c, qe, d_qcf.get())); modules.push_back(d_rr_engine.get()); } if (options::ltePartialInst()) @@ -408,10 +408,6 @@ quantifiers::BoundedIntegers* QuantifiersEngine::getBoundedIntegers() const { return d_private->d_bint.get(); } -quantifiers::QuantConflictFind* QuantifiersEngine::getConflictFind() const -{ - return d_private->d_qcf.get(); -} quantifiers::SynthEngine* QuantifiersEngine::getSynthEngine() const { return d_private->d_synth_e.get(); diff --git a/src/theory/quantifiers_engine.h b/src/theory/quantifiers_engine.h index 7e5fe9102..7a9f5e7da 100644 --- a/src/theory/quantifiers_engine.h +++ b/src/theory/quantifiers_engine.h @@ -123,8 +123,6 @@ public: //---------------------- modules (TODO remove these #1163) /** get bounded integers utility */ quantifiers::BoundedIntegers* getBoundedIntegers() const; - /** Conflict find mechanism for quantifiers */ - quantifiers::QuantConflictFind* getConflictFind() const; /** ceg instantiation */ quantifiers::SynthEngine* getSynthEngine() const; /** get inst strategy cbqi */