From c4a2d444a601ab8131d2088065bbc8bd24ed7696 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Mon, 16 Apr 2018 08:14:53 -0500 Subject: [PATCH] Skolemize candidate rewrite rule checks (#1777) --- .../quantifiers/cegqi/inst_strategy_cbqi.cpp | 2 +- .../ematching/inst_match_generator.cpp | 12 ++- src/theory/quantifiers/ematching/trigger.cpp | 15 ++- src/theory/quantifiers/macros.cpp | 2 +- .../sygus/ce_guided_conjecture.cpp | 55 +++++++++-- .../quantifiers/sygus/ce_guided_conjecture.h | 5 + src/theory/quantifiers/term_util.cpp | 91 ++++++++++++------- src/theory/quantifiers/term_util.h | 37 ++++---- 8 files changed, 153 insertions(+), 66 deletions(-) diff --git a/src/theory/quantifiers/cegqi/inst_strategy_cbqi.cpp b/src/theory/quantifiers/cegqi/inst_strategy_cbqi.cpp index df04a743b..d2aa75288 100644 --- a/src/theory/quantifiers/cegqi/inst_strategy_cbqi.cpp +++ b/src/theory/quantifiers/cegqi/inst_strategy_cbqi.cpp @@ -113,7 +113,7 @@ bool InstStrategyCbqi::registerCbqiLemma( Node q ) { //compute dependencies between quantified formulas if( options::cbqiLitDepend() || options::cbqiInnermost() ){ std::vector< Node > ics; - TermUtil::computeVarContains( q, ics ); + TermUtil::computeInstConstContains(q, ics); d_parent_quant[q].clear(); d_children_quant[q].clear(); std::vector< Node > dep; diff --git a/src/theory/quantifiers/ematching/inst_match_generator.cpp b/src/theory/quantifiers/ematching/inst_match_generator.cpp index 0252def60..9c3095e59 100644 --- a/src/theory/quantifiers/ematching/inst_match_generator.cpp +++ b/src/theory/quantifiers/ematching/inst_match_generator.cpp @@ -597,7 +597,11 @@ int VarMatchGeneratorTermSubs::getNextMatch(Node q, InstMatchGeneratorMultiLinear::InstMatchGeneratorMultiLinear( Node q, std::vector< Node >& pats, QuantifiersEngine* qe ) { //order patterns to maximize eager matching failures std::map< Node, std::vector< Node > > var_contains; - qe->getTermUtil()->getVarContains( q, pats, var_contains ); + for (const Node& pat : pats) + { + quantifiers::TermUtil::computeInstConstContainsForQuant( + q, pat, var_contains[pat]); + } std::map< Node, std::vector< Node > > var_to_node; for( std::map< Node, std::vector< Node > >::iterator it = var_contains.begin(); it != var_contains.end(); ++it ){ for( unsigned i=0; isecond.size(); i++ ){ @@ -710,7 +714,11 @@ InstMatchGeneratorMulti::InstMatchGeneratorMulti(Node q, { Trace("multi-trigger-cache") << "Making smart multi-trigger for " << q << std::endl; std::map< Node, std::vector< Node > > var_contains; - qe->getTermUtil()->getVarContains( q, pats, var_contains ); + for (const Node& pat : pats) + { + quantifiers::TermUtil::computeInstConstContainsForQuant( + q, pat, var_contains[pat]); + } //convert to indicies for( std::map< Node, std::vector< Node > >::iterator it = var_contains.begin(); it != var_contains.end(); ++it ){ Trace("multi-trigger-cache") << "Pattern " << it->first << " contains: "; diff --git a/src/theory/quantifiers/ematching/trigger.cpp b/src/theory/quantifiers/ematching/trigger.cpp index cb5afbfab..3928cf485 100644 --- a/src/theory/quantifiers/ematching/trigger.cpp +++ b/src/theory/quantifiers/ematching/trigger.cpp @@ -36,7 +36,7 @@ namespace inst { void TriggerTermInfo::init( Node q, Node n, int reqPol, Node reqPolEq ){ if( d_fv.empty() ){ - quantifiers::TermUtil::getVarContainsNode( q, n, d_fv ); + quantifiers::TermUtil::computeInstConstContainsForQuant(q, n, d_fv); } if( d_reqPol==0 ){ d_reqPol = reqPol; @@ -134,7 +134,11 @@ bool Trigger::mkTriggerTerms( Node q, std::vector< Node >& nodes, unsigned n_var std::map< Node, std::vector< Node > > patterns; size_t varCount = 0; std::map< Node, std::vector< Node > > varContains; - quantifiers::TermUtil::getVarContains( q, temp, varContains ); + for (const Node& pat : temp) + { + quantifiers::TermUtil::computeInstConstContainsForQuant( + q, pat, varContains[pat]); + } for( unsigned i=0; i& nodes) std::map > fvs; for (unsigned i = 0, size = nodes.size(); i < size; i++) { - quantifiers::TermUtil::computeVarContains(nodes[i], fvs[i]); + quantifiers::TermUtil::computeInstConstContains(nodes[i], fvs[i]); } std::vector active; active.resize(nodes.size(), true); @@ -870,8 +874,9 @@ void Trigger::getTriggerVariables(Node n, Node q, std::vector& t_vars) std::vector< Node > exclude; collectPatTerms(q, n, patTerms, quantifiers::TRIGGER_SEL_ALL, exclude, tinfo); //collect all variables from all patterns in patTerms, add to t_vars - for( unsigned i=0; igetTermUtil()->substituteBoundVariablesToInstConstants(n, f); Trace("macros-debug2") << "Get free variables in " << icn << std::endl; std::vector< Node > var; - d_qe->getTermUtil()->getVarContainsNode( f, icn, var ); + quantifiers::TermUtil::computeInstConstContainsForQuant(f, icn, var); Trace("macros-debug2") << "Get trigger variables for " << icn << std::endl; std::vector< Node > trigger_var; inst::Trigger::getTriggerVariables( icn, f, trigger_var ); diff --git a/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp b/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp index d160581bf..1e0f72817 100644 --- a/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp +++ b/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp @@ -620,11 +620,39 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation // Notice we don't set produce-models. rrChecker takes the same // options as the SmtEngine we belong to, where we ensure that // produce-models is set. - SmtEngine rrChecker(NodeManager::currentNM()->toExprManager()); + NodeManager* nm = NodeManager::currentNM(); + SmtEngine rrChecker(nm->toExprManager()); rrChecker.setLogic(smt::currentSmtEngine()->getLogicInfo()); Node crr = solbr.eqNode(eq_solr).negate(); - Trace("rr-check") - << "Check candidate rewrite : " << crr << std::endl; + Trace("rr-check") << "Check candidate rewrite : " << crr + << std::endl; + // quantify over the free variables in crr + std::vector fvs; + TermUtil::computeVarContains(crr, fvs); + std::map fv_index; + std::vector sks; + if (!fvs.empty()) + { + // map to skolems + for (unsigned i = 0, size = fvs.size(); i < size; i++) + { + Node v = fvs[i]; + fv_index[v] = i; + std::map::iterator itf = d_fv_to_skolem.find(v); + if (itf == d_fv_to_skolem.end()) + { + Node sk = nm->mkSkolem("rrck", v.getType()); + d_fv_to_skolem[v] = sk; + sks.push_back(sk); + } + else + { + sks.push_back(itf->second); + } + } + crr = crr.substitute( + fvs.begin(), fvs.end(), sks.begin(), sks.end()); + } rrChecker.assertFormula(crr.toExpr()); Result r = rrChecker.checkSat(); Trace("rr-check") << "...result : " << r << std::endl; @@ -639,15 +667,28 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation std::vector pt; for (const Node& v : vars) { - Node val = Node::fromExpr(rrChecker.getValue(v.toExpr())); - Trace("rr-check") << " " << v << " -> " << val << std::endl; + std::map::iterator itf = fv_index.find(v); + Node val; + if (itf == fv_index.end()) + { + // not in conjecture, can use arbitrary value + val = v.getType().mkGroundTerm(); + } + else + { + // get the model value of its skolem + Node sk = sks[itf->second]; + val = Node::fromExpr(rrChecker.getValue(sk.toExpr())); + Trace("rr-check") << " " << v << " -> " << val + << std::endl; + } pt.push_back(val); } d_sampler[prog].addSamplePoint(pt); // add the solution again + // by construction of the above point, we should be unique now Node eq_sol_new = its->second.registerTerm(sol); - Assert(!r.asSatisfiabilityResult().isSat() - || eq_sol_new == sol); + Assert(eq_sol_new == sol); } else { diff --git a/src/theory/quantifiers/sygus/ce_guided_conjecture.h b/src/theory/quantifiers/sygus/ce_guided_conjecture.h index 215a4d161..b6812a18a 100644 --- a/src/theory/quantifiers/sygus/ce_guided_conjecture.h +++ b/src/theory/quantifiers/sygus/ce_guided_conjecture.h @@ -247,6 +247,11 @@ private: * rewrite rules. */ std::map d_sampler; + /** + * Cache of skolems for each free variable that appears in a synthesis check + * (for --sygus-rr-synth-check). + */ + std::map d_fv_to_skolem; }; } /* namespace CVC4::theory::quantifiers */ diff --git a/src/theory/quantifiers/term_util.cpp b/src/theory/quantifiers/term_util.cpp index 7cebf0e1e..b3915bd5d 100644 --- a/src/theory/quantifiers/term_util.cpp +++ b/src/theory/quantifiers/term_util.cpp @@ -267,51 +267,74 @@ Node TermUtil::substituteInstConstants(Node n, Node q, std::vector& terms) terms.end()); } -void TermUtil::computeVarContains( Node n, std::vector< Node >& varContains ) { - std::map< Node, bool > visited; - computeVarContains2( n, INST_CONSTANT, varContains, visited ); +void TermUtil::computeInstConstContains(Node n, std::vector& ics) +{ + computeVarContainsInternal(n, INST_CONSTANT, ics); } -void TermUtil::computeQuantContains( Node n, std::vector< Node >& quantContains ) { - std::map< Node, bool > visited; - computeVarContains2( n, FORALL, quantContains, visited ); +void TermUtil::computeVarContains(Node n, std::vector& vars) +{ + computeVarContainsInternal(n, BOUND_VARIABLE, vars); } +void TermUtil::computeQuantContains(Node n, std::vector& quants) +{ + computeVarContainsInternal(n, FORALL, quants); +} -void TermUtil::computeVarContains2( Node n, Kind k, std::vector< Node >& varContains, std::map< Node, bool >& visited ){ - if( visited.find( n )==visited.end() ){ - visited[n] = true; - if( n.getKind()==k ){ - if( std::find( varContains.begin(), varContains.end(), n )==varContains.end() ){ - varContains.push_back( n ); - } - }else{ - if (n.hasOperator()) +void TermUtil::computeVarContainsInternal(Node n, + Kind k, + std::vector& vars) +{ + std::unordered_set visited; + std::unordered_set::iterator it; + std::vector visit; + TNode cur; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + + if (it == visited.end()) + { + visited.insert(cur); + if (cur.getKind() == k) { - computeVarContains2(n.getOperator(), k, varContains, visited); + if (std::find(vars.begin(), vars.end(), cur) == vars.end()) + { + vars.push_back(cur); + } } - for( unsigned i=0; i& pats, std::map< Node, std::vector< Node > >& varContains ){ - for( unsigned i=0; i& varContains ){ - std::vector< Node > vars; - computeVarContains( n, vars ); - for( unsigned j=0; j& vars) +{ + std::vector ics; + computeInstConstContains(n, ics); + for (const Node& v : ics) + { + if (v.getAttribute(InstConstantAttribute()) == q) + { + if (std::find(vars.begin(), vars.end(), v) == vars.end()) + { + vars.push_back(v); } } } diff --git a/src/theory/quantifiers/term_util.h b/src/theory/quantifiers/term_util.h index 6b83ad639..df88c1b30 100644 --- a/src/theory/quantifiers/term_util.h +++ b/src/theory/quantifiers/term_util.h @@ -180,23 +180,28 @@ public: static Node getQuantSimplify( Node n ); private: - /** helper function for compute var contains */ - static void computeVarContains2( Node n, Kind k, std::vector< Node >& varContains, std::map< Node, bool >& visited ); + /** adds the set of nodes of kind k in n to vars */ + static void computeVarContainsInternal(Node n, + Kind k, + std::vector& vars); + public: - /** compute var contains */ - static void computeVarContains( Node n, std::vector< Node >& varContains ); - /** get var contains for each of the patterns in pats */ - static void getVarContains( Node f, std::vector< Node >& pats, std::map< Node, std::vector< Node > >& varContains ); - /** get var contains for node n */ - static void getVarContainsNode( Node f, Node n, std::vector< Node >& varContains ); - /** compute quant contains */ - static void computeQuantContains( Node n, std::vector< Node >& quantContains ); - // TODO (#1216) : this should be in trigger.h - /** filter all nodes that have instances */ - static void filterInstances( std::vector< Node >& nodes ); - -//for term ordering -private: + /** adds the set of nodes of kind INST_CONSTANT in n to ics */ + static void computeInstConstContains(Node n, std::vector& ics); + /** adds the set of nodes of kind BOUND_VARIABLE in n to vars */ + static void computeVarContains(Node n, std::vector& vars); + /** adds the set of (top-level) nodes of kind FORALL in n to quants */ + static void computeQuantContains(Node n, std::vector& quants); + /** + * Adds the set of nodes of kind INST_CONSTANT in n that belong to quantified + * formula q to vars. + */ + static void computeInstConstContainsForQuant(Node q, + Node n, + std::vector& vars); + + // for term ordering + private: /** operator id count */ int d_op_id_count; /** map from operators to id */ -- 2.30.2