From d0df704a60696d7f824eb01781b413d91a0e4202 Mon Sep 17 00:00:00 2001 From: ajreynol Date: Thu, 10 Mar 2016 17:49:13 -0600 Subject: [PATCH] Faster conditional rewriting for and/or beneath quantifiers. Improvements to sort inference, related to constants. Add several quantifiers options, minor refactoring. --- src/options/quantifiers_options | 6 + src/smt/smt_engine.cpp | 27 +- .../quantifiers/candidate_generator.cpp | 2 +- .../quantifiers/inst_match_generator.cpp | 15 +- .../quantifiers/inst_strategy_e_matching.cpp | 3 +- .../quantifiers/quant_conflict_find.cpp | 12 +- src/theory/quantifiers/quant_conflict_find.h | 2 +- src/theory/quantifiers/quant_util.h | 2 - .../quantifiers/quantifiers_rewriter.cpp | 261 ++++---- src/theory/quantifiers/relevant_domain.cpp | 2 +- src/theory/quantifiers/term_database.cpp | 12 +- src/theory/quantifiers/term_database.h | 5 +- src/theory/quantifiers/trigger.cpp | 2 +- src/theory/quantifiers_engine.cpp | 136 ++++- src/theory/quantifiers_engine.h | 5 +- src/theory/sort_inference.cpp | 572 +++++++++--------- src/theory/sort_inference.h | 7 +- 17 files changed, 605 insertions(+), 466 deletions(-) diff --git a/src/options/quantifiers_options b/src/options/quantifiers_options index e3f4e94f2..1363626c6 100644 --- a/src/options/quantifiers_options +++ b/src/options/quantifiers_options @@ -57,6 +57,8 @@ option elimTautQuant --elim-taut-quant bool :default true eliminate tautological disjuncts of quantified formulas option purifyQuant --purify-quant bool :default false purify quantified formulas +option elimExtArithQuant --elim-ext-arith-quant bool :default true + eliminate extended arithmetic symbols in quantified formulas #### E-matching options @@ -67,6 +69,8 @@ option termDbMode --term-db-mode CVC4::theory::quantifiers::TermDbMode :default which ground terms to consider for instantiation option registerQuantBodyTerms --register-quant-body-terms bool :default false consider ground terms within bodies of quantified formulas for matching +option inferArithTriggerEq --infer-arith-trigger-eq bool :default false + infer equalities for trigger terms based on solving arithmetic equalities option smartTriggers --smart-triggers bool :default true enable smart triggers @@ -95,6 +99,8 @@ option incrementTriggers --increment-triggers bool :default true option instWhenMode --inst-when=MODE CVC4::theory::quantifiers::InstWhenMode :default CVC4::theory::quantifiers::INST_WHEN_FULL_LAST_CALL :read-write :include "options/quantifiers_modes.h" :handler stringToInstWhenMode :predicate checkInstWhenMode when to apply instantiation +option instWhenDelayIncrement --inst-when-delay-increment bool :default false + delay incrementing counters for inst-when mode to ensure theory combination and standard quantifier effort strategies take turns option instMaxLevel --inst-max-level=N int :read-write :default -1 maximum inst level of terms used to instantiate quantified formulas with (-1 == no limit, default) diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 201585070..93623408e 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -1809,23 +1809,22 @@ void SmtEngine::setDefaults() { if( !options::rewriteDivk.wasSetByUser()) { options::rewriteDivk.set( true ); } - } - if( options::cbqi() && d_logic.isPure(THEORY_ARITH) ){ - options::cbqiAll.set( false ); - if( !options::quantConflictFind.wasSetByUser() ){ - options::quantConflictFind.set( false ); - } - if( !options::instNoEntail.wasSetByUser() ){ - options::instNoEntail.set( false ); - } - if( !options::instWhenMode.wasSetByUser() && options::cbqiModel() ){ - //only instantiation should happen at last call when model is avaiable - if( !options::instWhenMode.wasSetByUser() ){ - options::instWhenMode.set( quantifiers::INST_WHEN_LAST_CALL ); + if( d_logic.isPure(THEORY_ARITH) ){ + options::cbqiAll.set( false ); + if( !options::quantConflictFind.wasSetByUser() ){ + options::quantConflictFind.set( false ); + } + if( !options::instNoEntail.wasSetByUser() ){ + options::instNoEntail.set( false ); + } + if( !options::instWhenMode.wasSetByUser() && options::cbqiModel() ){ + //only instantiation should happen at last call when model is avaiable + if( !options::instWhenMode.wasSetByUser() ){ + options::instWhenMode.set( quantifiers::INST_WHEN_LAST_CALL ); + } } } } - //implied options... if( options::qcfMode.wasSetByUser() || options::qcfTConstraint() ){ options::quantConflictFind.set( true ); diff --git a/src/theory/quantifiers/candidate_generator.cpp b/src/theory/quantifiers/candidate_generator.cpp index 0cdb22be4..680be77da 100644 --- a/src/theory/quantifiers/candidate_generator.cpp +++ b/src/theory/quantifiers/candidate_generator.cpp @@ -90,7 +90,7 @@ void CandidateGeneratorQE::reset( Node eqc ){ bool CandidateGeneratorQE::isLegalOpCandidate( Node n ) { if( n.hasOperator() ){ if( isLegalCandidate( n ) ){ - return d_qe->getTermDatabase()->getOperator( n )==d_op; + return d_qe->getTermDatabase()->getMatchOperator( n )==d_op; } } return false; diff --git a/src/theory/quantifiers/inst_match_generator.cpp b/src/theory/quantifiers/inst_match_generator.cpp index 89c2d4868..41c62192f 100644 --- a/src/theory/quantifiers/inst_match_generator.cpp +++ b/src/theory/quantifiers/inst_match_generator.cpp @@ -97,11 +97,11 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector< } d_match_pattern_type = d_match_pattern.getType(); Trace("inst-match-gen") << "Pattern is " << d_pattern << ", match pattern is " << d_match_pattern << std::endl; - d_match_pattern_op = qe->getTermDatabase()->getOperator( d_match_pattern ); + d_match_pattern_op = qe->getTermDatabase()->getMatchOperator( d_match_pattern ); //now, collect children of d_match_pattern //int childMatchPolicy = MATCH_GEN_DEFAULT; - for( int i=0; i<(int)d_match_pattern.getNumChildren(); i++ ){ + for( unsigned i=0; igetTermDatabase()->getOperator( d_pattern ).toExpr(); + Expr selectorExpr = qe->getTermDatabase()->getMatchOperator( d_pattern ).toExpr(); size_t selectorIndex = Datatype::cindexOf(selectorExpr); const Datatype& dt = Datatype::datatypeOf(selectorExpr); const DatatypeConstructor& c = dt[selectorIndex]; @@ -197,7 +197,7 @@ bool InstMatchGenerator::getMatch( Node f, Node t, InstMatch& m, QuantifiersEngi Assert( !Trigger::isAtomicTrigger( d_match_pattern ) || t.getOperator()==d_match_pattern.getOperator() ); //first, check if ground arguments are not equal, or a match is in conflict Trace("matching-debug2") << "Setting immediate matches..." << std::endl; - for( int i=0; i<(int)d_match_pattern.getNumChildren(); i++ ){ + for( unsigned i=0; igetTermDatabase()->getOperator( t ); + Node t_op = qe->getTermDatabase()->getMatchOperator( t ); if( ((InstMatchGenerator*)d_children[i])->d_match_pattern_op==t_op ){ InstMatch m( q ); //if it produces a match, then process it with the rest @@ -709,7 +709,7 @@ InstMatchGeneratorSimple::InstMatchGeneratorSimple( Node q, Node pat ) : d_f( q } void InstMatchGeneratorSimple::resetInstantiationRound( QuantifiersEngine* qe ) { - d_op = qe->getTermDatabase()->getOperator( d_match_pattern ); + d_op = qe->getTermDatabase()->getMatchOperator( d_match_pattern ); } int InstMatchGeneratorSimple::addInstantiations( Node q, InstMatch& baseMatch, QuantifiersEngine* qe ){ @@ -763,9 +763,10 @@ void InstMatchGeneratorSimple::addInstantiations( InstMatch& m, QuantifiersEngin } int InstMatchGeneratorSimple::addTerm( Node q, Node t, QuantifiersEngine* qe ){ + //for eager instantiation only Assert( options::eagerInstQuant() ); InstMatch m( q ); - for( int i=0; i<(int)t.getNumChildren(); i++ ){ + for( unsigned i=0; igetEqualityQuery()->areEqual( d_match_pattern[i], t[i] ) ){ diff --git a/src/theory/quantifiers/inst_strategy_e_matching.cpp b/src/theory/quantifiers/inst_strategy_e_matching.cpp index 299eb51fd..621327c0b 100644 --- a/src/theory/quantifiers/inst_strategy_e_matching.cpp +++ b/src/theory/quantifiers/inst_strategy_e_matching.cpp @@ -257,9 +257,8 @@ void InstStrategyAutoGenTriggers::generateTriggers( Node f ){ Node bd = d_quantEngine->getTermDatabase()->getInstConstantBody( f ); Trigger::collectPatTerms( d_quantEngine, f, bd, patTermsF, d_tr_strategy, d_user_no_gen[f], true ); Trace("auto-gen-trigger-debug") << "Collected pat terms for " << bd << ", no-patterns : " << d_user_no_gen[f].size() << std::endl; - Trace("auto-gen-trigger-debug") << " "; for( int i=0; i<(int)patTermsF.size(); i++ ){ - Trace("auto-gen-trigger-debug") << patTermsF[i] << " "; + Trace("auto-gen-trigger-debug") << " " << patTermsF[i] << std::endl; } Trace("auto-gen-trigger-debug") << std::endl; if( ntrivTriggers ){ diff --git a/src/theory/quantifiers/quant_conflict_find.cpp b/src/theory/quantifiers/quant_conflict_find.cpp index 779c0c44e..93cd4be91 100644 --- a/src/theory/quantifiers/quant_conflict_find.cpp +++ b/src/theory/quantifiers/quant_conflict_find.cpp @@ -843,7 +843,7 @@ MatchGen::MatchGen( QuantInfo * qi, Node n, bool isVar ) d_qni_size++; d_type_not = false; d_n = n; - //Node f = getOperator( n ); + //Node f = getMatchOperator( n ); for( unsigned j=0; jgetTermDatabase()->getTermArgTrie( Node::null(), f ); if( qni!=NULL ){ @@ -1339,7 +1339,7 @@ bool MatchGen::getNextMatch( QuantConflictFind * p, QuantInfo * qi ) { /* if( d_type==typ_var && p->d_effort==QuantConflictFind::effort_mc && !d_matched_basis ){ d_matched_basis = true; - Node f = getOperator( d_n ); + Node f = getMatchOperator( d_n ); TNode mbo = p->getTermDatabase()->getModelBasisOpTerm( f ); if( qi->setMatch( p, d_qni_var_num[0], mbo ) ){ success = true; @@ -1702,9 +1702,9 @@ bool MatchGen::isHandledUfTerm( TNode n ) { return inst::Trigger::isAtomicTriggerKind( n.getKind() ); } -Node MatchGen::getOperator( QuantConflictFind * p, Node n ) { +Node MatchGen::getMatchOperator( QuantConflictFind * p, Node n ) { if( isHandledUfTerm( n ) ){ - return p->getTermDatabase()->getOperator( n ); + return p->getTermDatabase()->getMatchOperator( n ); }else{ return Node::null(); } @@ -1896,7 +1896,7 @@ void QuantConflictFind::assertNode( Node q ) { Node QuantConflictFind::evaluateTerm( Node n ) { if( MatchGen::isHandledUfTerm( n ) ){ - Node f = MatchGen::getOperator( this, n ); + Node f = MatchGen::getMatchOperator( this, n ); Node nn; if( getEqualityEngine()->hasTerm( n ) ){ nn = getTermDatabase()->existsTerm( f, n ); diff --git a/src/theory/quantifiers/quant_conflict_find.h b/src/theory/quantifiers/quant_conflict_find.h index 11299b532..4bcc59bde 100644 --- a/src/theory/quantifiers/quant_conflict_find.h +++ b/src/theory/quantifiers/quant_conflict_find.h @@ -98,7 +98,7 @@ public: // is this term treated as UF application? static bool isHandledBoolConnective( TNode n ); static bool isHandledUfTerm( TNode n ); - static Node getOperator( QuantConflictFind * p, Node n ); + static Node getMatchOperator( QuantConflictFind * p, Node n ); //can this node be handled by the algorithm static bool isHandled( TNode n ); }; diff --git a/src/theory/quantifiers/quant_util.h b/src/theory/quantifiers/quant_util.h index 566a09923..b4cf54dfd 100644 --- a/src/theory/quantifiers/quant_util.h +++ b/src/theory/quantifiers/quant_util.h @@ -99,8 +99,6 @@ class EqualityQuery { public: EqualityQuery(){} virtual ~EqualityQuery(){}; - /** reset */ - virtual void reset() = 0; /** contains term */ virtual bool hasTerm( Node a ) = 0; /** get the representative of the equivalence class of a */ diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp index ff55c5c9b..c10ba944b 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.cpp +++ b/src/theory/quantifiers/quantifiers_rewriter.cpp @@ -554,6 +554,16 @@ void setEntailedCond( Node n, bool pol, std::map< Node, bool >& currCond, std::v } } +void removeEntailedCond( std::map< Node, bool >& currCond, std::vector< Node >& new_cond, std::map< Node, Node >& cache ) { + if( !new_cond.empty() ){ + for( unsigned j=0; j& new_vars, std::vector< Node >& new_conds, Node q, QAttributes& qa ){ std::map< Node, bool > curr_cond; std::map< Node, Node > cache; @@ -582,96 +592,119 @@ Node QuantifiersRewriter::computeProcessTerms2( Node body, bool hasPol, bool pol ret = iti->second; Trace("quantifiers-rewrite-term-debug2") << "Return (cached) " << ret << " for " << body << std::endl; }else{ - bool firstTimeCD = true; + //only do context dependent processing up to depth 8 + bool doCD = nCurrCond<8; bool changed = false; std::vector< Node > children; - for( size_t i=0; i new_cond; + //set entailed conditions based on OR/AND + std::map< int, std::vector< Node > > new_cond_children; + if( doCD && ( body.getKind()==OR || body.getKind()==AND ) ){ + nCurrCond = nCurrCond + 1; bool conflict = false; - //only do context dependent processing up to depth 8 - if( nCurrCond<8 ){ - if( firstTimeCD ){ - firstTimeCD = false; - nCurrCond = nCurrCond + 1; - } - if( Trace.isOn("quantifiers-rewrite-term-debug") ){ - //if( ( body.getKind()==ITE && i>0 ) || ( hasPol && ( ( body.getKind()==OR && pol ) || (body.getKind()==AND && !pol ) ) ) ){ - if( ( body.getKind()==ITE && i>0 ) || body.getKind()==OR || body.getKind()==AND ){ - Trace("quantifiers-rewrite-term-debug") << "---rewrite " << body[i] << " under conditions:----" << std::endl; + bool use_pol = body.getKind()==AND; + for( unsigned j=0; jmkConst( !use_pol ); + } + } + if( ret.isNull() ){ + for( size_t i=0; i new_cond; + bool conflict = false; + if( doCD ){ + if( Trace.isOn("quantifiers-rewrite-term-debug") ){ + if( ( body.getKind()==ITE && i>0 ) || body.getKind()==OR || body.getKind()==AND ){ + Trace("quantifiers-rewrite-term-debug") << "---rewrite " << body[i] << " under conditions:----" << std::endl; + } } - } - if( body.getKind()==ITE && i>0 ){ - setEntailedCond( children[0], i==1, currCond, new_cond, conflict ); - //should not conflict (entailment check failed) - Assert( !conflict ); - } - //if( hasPol && ( ( body.getKind()==OR && pol ) || ( body.getKind()==AND && !pol ) ) ){ - // bool use_pol = !pol; - if( body.getKind()==OR || body.getKind()==AND ){ - bool use_pol = body.getKind()==AND; - for( unsigned j=0; ji ){ - setEntailedCond( body[j], use_pol, currCond, new_cond, conflict ); + if( body.getKind()==ITE && i>0 ){ + if( i==1 ){ + nCurrCond = nCurrCond + 1; + } + setEntailedCond( children[0], i==1, currCond, new_cond, conflict ); + //should not conflict (entailment check failed) + Assert( !conflict ); + } + if( body.getKind()==OR || body.getKind()==AND ){ + bool use_pol = body.getKind()==AND; + //remove the current condition + removeEntailedCond( currCond, new_cond_children[i], cache ); + if( i>0 ){ + //add the previous condition + setEntailedCond( children[i-1], use_pol, currCond, new_cond_children[i-1], conflict ); + } + if( conflict ){ + Trace("quantifiers-rewrite-term-debug") << "-------conflict, return " << !use_pol << std::endl; + ret = NodeManager::currentNM()->mkConst( !use_pol ); } } - if( conflict ){ - Trace("quantifiers-rewrite-term-debug") << "-------conflict, return " << !use_pol << std::endl; - ret = NodeManager::currentNM()->mkConst( !use_pol ); + if( !new_cond.empty() ){ + cache.clear(); } - } - if( !new_cond.empty() ){ - cache.clear(); - } - if( Trace.isOn("quantifiers-rewrite-term-debug") ){ - //if( ( body.getKind()==ITE && i>0 ) || ( hasPol && ( ( body.getKind()==OR && pol ) || (body.getKind()==AND && !pol ) ) ) ){ - if( ( body.getKind()==ITE && i>0 ) || body.getKind()==OR || body.getKind()==AND ){ - Trace("quantifiers-rewrite-term-debug") << "-------" << std::endl; + if( Trace.isOn("quantifiers-rewrite-term-debug") ){ + if( ( body.getKind()==ITE && i>0 ) || body.getKind()==OR || body.getKind()==AND ){ + Trace("quantifiers-rewrite-term-debug") << "-------" << std::endl; + } } } - } - if( !conflict ){ - bool newHasPol; - bool newPol; - QuantPhaseReq::getPolarity( body, i, hasPol, pol, newHasPol, newPol ); - Node nn = computeProcessTerms2( body[i], newHasPol, newPol, currCond, nCurrCond, cache, icache, new_vars, new_conds ); - if( body.getKind()==ITE && i==0 ){ - int res = getEntailedCond( nn, currCond ); - Trace("quantifiers-rewrite-term-debug") << "Condition for " << body << " is " << nn << ", entailment check=" << res << std::endl; - if( res==1 ){ - ret = computeProcessTerms2( body[1], hasPol, pol, currCond, nCurrCond, cache, icache, new_vars, new_conds ); - }else if( res==-1 ){ - ret = computeProcessTerms2( body[2], hasPol, pol, currCond, nCurrCond, cache, icache, new_vars, new_conds ); + + //do the recursive call on children + if( !conflict ){ + bool newHasPol; + bool newPol; + QuantPhaseReq::getPolarity( body, i, hasPol, pol, newHasPol, newPol ); + Node nn = computeProcessTerms2( body[i], newHasPol, newPol, currCond, nCurrCond, cache, icache, new_vars, new_conds ); + if( body.getKind()==ITE && i==0 ){ + int res = getEntailedCond( nn, currCond ); + Trace("quantifiers-rewrite-term-debug") << "Condition for " << body << " is " << nn << ", entailment check=" << res << std::endl; + if( res==1 ){ + ret = computeProcessTerms2( body[1], hasPol, pol, currCond, nCurrCond, cache, icache, new_vars, new_conds ); + }else if( res==-1 ){ + ret = computeProcessTerms2( body[2], hasPol, pol, currCond, nCurrCond, cache, icache, new_vars, new_conds ); + } } + children.push_back( nn ); + changed = changed || nn!=body[i]; + } + + //clean up entailed conditions + removeEntailedCond( currCond, new_cond, cache ); + + if( !ret.isNull() ){ + break; } - children.push_back( nn ); - changed = changed || nn!=body[i]; } - if( !new_cond.empty() ){ - for( unsigned j=0; jmkNode( body.getKind(), children ); + }else{ + ret = body; } - cache.clear(); - } - if( !ret.isNull() ){ - break; } } - if( ret.isNull() ){ - if( changed ){ - if( body.getMetaKind() == kind::metakind::PARAMETERIZED ){ - children.insert( children.begin(), body.getOperator() ); - } - ret = NodeManager::currentNM()->mkNode( body.getKind(), children ); - }else{ - ret = body; + + //clean up entailed conditions + if( body.getKind()==OR || body.getKind()==AND ){ + for( unsigned j=0; jsecond; @@ -701,46 +734,60 @@ Node QuantifiersRewriter::computeProcessTerms2( Node body, bool hasPol, bool pol } } } - }else if( ret.getKind()==INTS_DIVISION_TOTAL || ret.getKind()==INTS_MODULUS_TOTAL ){ - Node num = ret[0]; - Node den = ret[1]; - if(den.isConst()) { - const Rational& rat = den.getConst(); - Assert(!num.isConst()); - if(rat != 0) { - Node intVar = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType()); - new_vars.push_back( intVar ); - Node cond; - if(rat > 0) { - cond = NodeManager::currentNM()->mkNode(kind::AND, - NodeManager::currentNM()->mkNode(kind::LEQ, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar), num), - NodeManager::currentNM()->mkNode(kind::LT, num, - NodeManager::currentNM()->mkNode(kind::MULT, den, NodeManager::currentNM()->mkNode(kind::PLUS, intVar, NodeManager::currentNM()->mkConst(Rational(1)))))); - } else { - cond = NodeManager::currentNM()->mkNode(kind::AND, - NodeManager::currentNM()->mkNode(kind::LEQ, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar), num), - NodeManager::currentNM()->mkNode(kind::LT, num, - NodeManager::currentNM()->mkNode(kind::MULT, den, NodeManager::currentNM()->mkNode(kind::PLUS, intVar, NodeManager::currentNM()->mkConst(Rational(-1)))))); - } - new_conds.push_back( cond.negate() ); - if( ret.getKind()==INTS_DIVISION_TOTAL ){ - ret = intVar; - }else{ - ret = NodeManager::currentNM()->mkNode(kind::MINUS, num, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar)); + /* ITE lifting + if( ret.getKind()==ITE ){ + TypeNode ite_t = ret[1].getType(); + if( !ite_t.isBoolean() ){ + ite_t = TypeNode::leastCommonTypeNode( ite_t, ret[2].getType() ); + Node ite_v = NodeManager::currentNM()->mkBoundVar(ite_t); + new_vars.push_back( ite_v ); + Node cond = NodeManager::currentNM()->mkNode(kind::ITE, ret[0], ite_v.eqNode( ret[1] ), ite_v.eqNode( ret[2] ) ); + new_conds.push_back( cond.negate() ); + ret = ite_v; + } + */ + }else if( options::elimExtArithQuant() ){ + if( ret.getKind()==INTS_DIVISION_TOTAL || ret.getKind()==INTS_MODULUS_TOTAL ){ + Node num = ret[0]; + Node den = ret[1]; + if(den.isConst()) { + const Rational& rat = den.getConst(); + Assert(!num.isConst()); + if(rat != 0) { + Node intVar = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType()); + new_vars.push_back( intVar ); + Node cond; + if(rat > 0) { + cond = NodeManager::currentNM()->mkNode(kind::AND, + NodeManager::currentNM()->mkNode(kind::LEQ, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar), num), + NodeManager::currentNM()->mkNode(kind::LT, num, + NodeManager::currentNM()->mkNode(kind::MULT, den, NodeManager::currentNM()->mkNode(kind::PLUS, intVar, NodeManager::currentNM()->mkConst(Rational(1)))))); + } else { + cond = NodeManager::currentNM()->mkNode(kind::AND, + NodeManager::currentNM()->mkNode(kind::LEQ, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar), num), + NodeManager::currentNM()->mkNode(kind::LT, num, + NodeManager::currentNM()->mkNode(kind::MULT, den, NodeManager::currentNM()->mkNode(kind::PLUS, intVar, NodeManager::currentNM()->mkConst(Rational(-1)))))); + } + new_conds.push_back( cond.negate() ); + if( ret.getKind()==INTS_DIVISION_TOTAL ){ + ret = intVar; + }else{ + ret = NodeManager::currentNM()->mkNode(kind::MINUS, num, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar)); + } } } - } - }else if( ret.getKind()==TO_INTEGER || ret.getKind()==IS_INTEGER ){ - Node intVar = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType()); - new_vars.push_back( intVar ); - new_conds.push_back(NodeManager::currentNM()->mkNode(kind::AND, - NodeManager::currentNM()->mkNode(kind::LT, - NodeManager::currentNM()->mkNode(kind::MINUS, ret[0], NodeManager::currentNM()->mkConst(Rational(1))), intVar), - NodeManager::currentNM()->mkNode(kind::LEQ, intVar, ret[0])).negate()); - if( ret.getKind()==TO_INTEGER ){ - ret = intVar; - }else{ - ret = ret[0].eqNode( intVar ); + }else if( ret.getKind()==TO_INTEGER || ret.getKind()==IS_INTEGER ){ + Node intVar = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType()); + new_vars.push_back( intVar ); + new_conds.push_back(NodeManager::currentNM()->mkNode(kind::AND, + NodeManager::currentNM()->mkNode(kind::LT, + NodeManager::currentNM()->mkNode(kind::MINUS, ret[0], NodeManager::currentNM()->mkConst(Rational(1))), intVar), + NodeManager::currentNM()->mkNode(kind::LEQ, intVar, ret[0])).negate()); + if( ret.getKind()==TO_INTEGER ){ + ret = intVar; + }else{ + ret = ret[0].eqNode( intVar ); + } } } icache[prev] = ret; diff --git a/src/theory/quantifiers/relevant_domain.cpp b/src/theory/quantifiers/relevant_domain.cpp index 88793358e..ce7a38fa5 100644 --- a/src/theory/quantifiers/relevant_domain.cpp +++ b/src/theory/quantifiers/relevant_domain.cpp @@ -140,7 +140,7 @@ void RelevantDomain::compute(){ } void RelevantDomain::computeRelevantDomain( Node q, Node n, bool hasPol, bool pol ) { - Node op = d_qe->getTermDatabase()->getOperator( n ); + Node op = d_qe->getTermDatabase()->getMatchOperator( n ); for( unsigned i=0; i& added, bool withinQuant, bool wi //if this is an atomic trigger, consider adding it if( inst::Trigger::isAtomicTrigger( n ) ){ Trace("term-db") << "register term in db " << n << std::endl; - Node op = getOperator( n ); + Node op = getMatchOperator( n ); d_op_map[op].push_back( n ); added.insert( n ); - + if( options::eagerInstQuant() ){ for( unsigned i=0; i& subs, bool subsRe } }else{ if( n.hasOperator() ){ - TNode f = getOperator( n ); + TNode f = getMatchOperator( n ); if( !f.isNull() ){ std::vector< TNode > args; for( unsigned i=0; igetRepresentative( n ); }else if( n.getKind()!=BOUND_VARIABLE ){ if( n.hasOperator() ){ - TNode f = getOperator( n ); + TNode f = getMatchOperator( n ); if( !f.isNull() ){ std::vector< TNode > args; for( unsigned i=0; ihasTerm( n ) ){ diff --git a/src/theory/quantifiers/term_database.h b/src/theory/quantifiers/term_database.h index b1cfcf2ae..fcacbd686 100644 --- a/src/theory/quantifiers/term_database.h +++ b/src/theory/quantifiers/term_database.h @@ -188,7 +188,6 @@ public: /** map from type nodes to terms of that type */ std::map< TypeNode, std::vector< Node > > d_type_map; - /** count number of non-redundant ground terms per operator */ std::map< Node, int > d_op_nonred_count; /**mapping from UF terms to representatives of their arguments */ @@ -212,8 +211,8 @@ public: void presolve(); /** reset (calculate which terms are active) */ void reset( Theory::Effort effort ); - /** get operator*/ - Node getOperator( Node n ); + /** get match operator */ + Node getMatchOperator( Node n ); /** get term arg index */ TermArgTrie * getTermArgTrie( Node f ); TermArgTrie * getTermArgTrie( Node eqc, Node f ); diff --git a/src/theory/quantifiers/trigger.cpp b/src/theory/quantifiers/trigger.cpp index 9aee18317..0628b7fbc 100644 --- a/src/theory/quantifiers/trigger.cpp +++ b/src/theory/quantifiers/trigger.cpp @@ -73,7 +73,7 @@ d_quantEngine( qe ), d_f( f ){ //Notice() << "Trigger : " << (*this) << " for " << f << std::endl; if( options::eagerInstQuant() ){ for( int i=0; i<(int)d_nodes.size(); i++ ){ - Node op = qe->getTermDatabase()->getOperator( d_nodes[i] ); + Node op = qe->getTermDatabase()->getMatchOperator( d_nodes[i] ); qe->getTermDatabase()->registerTrigger( this, op ); } } diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp index 7cda713a1..5d19d603c 100644 --- a/src/theory/quantifiers_engine.cpp +++ b/src/theory/quantifiers_engine.cpp @@ -138,8 +138,10 @@ QuantifiersEngine::QuantifiersEngine(context::Context* c, context::UserContext* d_builder = NULL; d_total_inst_count_debug = 0; - d_ierCounter = 0; - d_ierCounter_lc = 0; + //allow theory combination to go first, once initially, when instWhenDelayIncrement = true + d_ierCounter = options::instWhenDelayIncrement() ? 1 : 0; + d_ierCounter_lc = options::instWhenDelayIncrement() ? 1 : 0; + d_ierCounterLastLc = -1; //if any strategy called only on last call, use phase 3 d_inst_when_phase = options::cbqi() ? 3 : 2; } @@ -338,10 +340,12 @@ void QuantifiersEngine::check( Theory::Effort e ){ Trace("quant-engine-debug") << "Master equality engine not consistent, return." << std::endl; return; } - if( e==Theory::EFFORT_FULL ){ - d_ierCounter++; - }else if( e==Theory::EFFORT_LAST_CALL ){ - d_ierCounter_lc++; + if( !options::instWhenDelayIncrement() ){ + if( e==Theory::EFFORT_FULL ){ + d_ierCounter++; + }else if( e==Theory::EFFORT_LAST_CALL ){ + d_ierCounter_lc++; + } } bool needsCheck = !d_lemmas_waiting.empty(); unsigned needsModelE = QEFFORT_NONE; @@ -392,14 +396,6 @@ void QuantifiersEngine::check( Theory::Effort e ){ Trace("quant-engine-debug") << " Needs model effort : " << needsModelE << std::endl; Trace("quant-engine-debug") << "Resetting all modules..." << std::endl; } - if( Trace.isOn("quant-engine-ee") ){ - Trace("quant-engine-ee") << "Equality engine : " << std::endl; - debugPrintEqualityEngine( "quant-engine-ee" ); - } - if( Trace.isOn("quant-engine-assert") ){ - Trace("quant-engine-assert") << "Assertions : " << std::endl; - getTheoryEngine()->printAssertions("quant-engine-assert"); - } //reset relevant information @@ -410,12 +406,22 @@ void QuantifiersEngine::check( Theory::Effort e ){ } Trace("quant-engine-debug2") << "Reset term db..." << std::endl; + d_eq_query->reset( e ); d_term_db->reset( e ); - d_eq_query->reset(); if( d_rel_dom ){ d_rel_dom->reset(); } d_model->reset_round(); + + if( Trace.isOn("quant-engine-ee") ){ + Trace("quant-engine-ee") << "Equality engine : " << std::endl; + debugPrintEqualityEngine( "quant-engine-ee" ); + } + if( Trace.isOn("quant-engine-assert") ){ + Trace("quant-engine-assert") << "Assertions : " << std::endl; + getTheoryEngine()->printAssertions("quant-engine-assert"); + } + for( unsigned i=0; iidentify().c_str() << std::endl; d_modules[i]->reset_round( e ); @@ -462,23 +468,39 @@ void QuantifiersEngine::check( Theory::Effort e ){ //if we have added one, stop if( d_hasAddedLemma ){ break; - }else if( e==Theory::EFFORT_LAST_CALL && quant_e==QEFFORT_MODEL ){ - //if we have a chance not to set incomplete - if( !setIncomplete ){ - setIncomplete = false; - //check if we should set the incomplete flag - for( unsigned i=0; icheckComplete() ){ - Trace("quant-engine-debug") << "Set incomplete because " << qm[i]->identify().c_str() << " was incomplete." << std::endl; - setIncomplete = true; + }else{ + if( quant_e==QEFFORT_CONFLICT ){ + if( options::instWhenDelayIncrement() ){ + if( e==Theory::EFFORT_FULL ){ + //increment if a last call happened, or if we already were in phase + if( d_ierCounterLastLc!=d_ierCounter_lc || d_ierCounter%d_inst_when_phase==0 ){ + d_ierCounter++; + d_ierCounterLastLc = d_ierCounter_lc; + } + }else if( e==Theory::EFFORT_LAST_CALL ){ + d_ierCounter_lc++; + } + } + }else if( quant_e==QEFFORT_MODEL ){ + if( e==Theory::EFFORT_LAST_CALL ){ + //if we have a chance not to set incomplete + if( !setIncomplete ){ + setIncomplete = false; + //check if we should set the incomplete flag + for( unsigned i=0; icheckComplete() ){ + Trace("quant-engine-debug") << "Set incomplete because " << qm[i]->identify().c_str() << " was incomplete." << std::endl; + setIncomplete = true; + break; + } + } + } + //if setIncomplete = false, we will answer SAT, otherwise we will run at quant_e QEFFORT_LAST_CALL + if( !setIncomplete ){ break; } } } - //if setIncomplete = false, we will answer SAT, otherwise we will run at quant_e QEFFORT_LAST_CALL - if( !setIncomplete ){ - break; - } } } Trace("quant-engine-debug") << "Done check modules that needed check." << std::endl; @@ -1022,6 +1044,7 @@ bool QuantifiersEngine::addSplitEquality( Node n1, Node n2, bool reqPhase, bool } bool QuantifiersEngine::getInstWhenNeedsCheck( Theory::Effort e ) { + Trace("quant-engine-debug2") << "Get inst when needs check, counts=" << d_ierCounter << ", " << d_ierCounter_lc << std::endl; //determine if we should perform check, based on instWhenMode bool performCheck = false; if( options::instWhenMode()==quantifiers::INST_WHEN_FULL ){ @@ -1225,9 +1248,64 @@ void QuantifiersEngine::debugPrintEqualityEngine( const char * c ) { } } -void EqualityQueryQuantifiersEngine::reset(){ +void EqualityQueryQuantifiersEngine::reset( Theory::Effort e ){ d_int_rep.clear(); d_reset_count++; + processInferences( e ); +} + +void EqualityQueryQuantifiersEngine::processInferences( Theory::Effort e ) { + if( options::inferArithTriggerEq() ){ + std::vector< Node > infer; + std::vector< Node > infer_exp; + eq::EqualityEngine* ee = getEngine(); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( ee ); + while( !eqcs_i.isFinished() ){ + TNode r = (*eqcs_i); + TypeNode tr = r.getType(); + if( tr.isReal() ){ + std::vector< Node > eqc; + eq::EqClassIterator eqc_i = eq::EqClassIterator( r, ee ); + while( !eqc_i.isFinished() ){ + TNode n = (*eqc_i); + //accumulate equivalence class + eqc.push_back( n ); + ++eqc_i; + } + for( unsigned i=0; i msum; + QuantArith::getMonomialSum( n, msum ); + for( std::map< Node, Node >::iterator it = msum.begin(); it != msum.end(); ++it ){ + //if the term is a trigger + Node t = it->first; + if( inst::Trigger::isAtomicTrigger( t ) ){ + for( unsigned j=0; jhasTerm( v ) && ee->getRepresentative( v )!=r ){ + Trace("quant-engine-ee-proc") << "processInferences : Infer : " << t << " == " << v << " from " << n << " == " << eqc[j] << std::endl; + infer.push_back( t.eqNode( v ) ); + infer_exp.push_back( n.eqNode( eqc[j] ) ); + } + } + } + } + } + } + } + } + ++eqcs_i; + } + for( unsigned i=0; iassertEquality( infer[i], true, infer_exp[i] ); + } + Assert( ee->consistent() ); + } } bool EqualityQueryQuantifiersEngine::hasTerm( Node a ){ diff --git a/src/theory/quantifiers_engine.h b/src/theory/quantifiers_engine.h index 92296ebac..0c43223d8 100644 --- a/src/theory/quantifiers_engine.h +++ b/src/theory/quantifiers_engine.h @@ -206,6 +206,7 @@ private: /** inst round counters */ int d_ierCounter; int d_ierCounter_lc; + int d_ierCounterLastLc; int d_inst_when_phase; /** has presolve been called */ context::CDO< bool > d_presolve; @@ -399,6 +400,8 @@ private: /** reset count */ int d_reset_count; + /** processInferences : will merge equivalence classes in master equality engine, if possible */ + void processInferences( Theory::Effort e ); /** node contains */ Node getInstance( Node n, const std::vector< Node >& eqc, std::hash_map& cache ); /** get score */ @@ -407,7 +410,7 @@ public: EqualityQueryQuantifiersEngine( QuantifiersEngine* qe ) : d_qe( qe ), d_reset_count( 0 ){} ~EqualityQueryQuantifiersEngine(){} /** reset */ - void reset(); + void reset( Theory::Effort e ); /** general queries about equality */ bool hasTerm( Node a ); Node getRepresentative( Node a ); diff --git a/src/theory/sort_inference.cpp b/src/theory/sort_inference.cpp index 060584fcf..4b29148a7 100644 --- a/src/theory/sort_inference.cpp +++ b/src/theory/sort_inference.cpp @@ -118,10 +118,11 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doSortInfere if( doSortInference ){ Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl; //process all assertions + std::map< Node, int > visited; for( unsigned i=0; i var_bound; - process( assertions[i], var_bound ); + process( assertions[i], var_bound, visited ); } Trace("sort-inference-proc") << "...done" << std::endl; for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){ @@ -155,10 +156,11 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doSortInfere bool rewritten = false; //determine monotonicity of sorts Trace("sort-inference-proc") << "Calculating monotonicty for subsorts..." << std::endl; + std::map< Node, std::map< int, bool > > visited; for( unsigned i=0; i var_bound; - processMonotonic( assertions[i], true, true, var_bound ); + processMonotonic( assertions[i], true, true, var_bound, visited ); } Trace("sort-inference-proc") << "...done" << std::endl; @@ -176,13 +178,16 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doSortInfere //simplify all assertions by introducing new symbols wherever necessary Trace("sort-inference-proc") << "Perform simplification..." << std::endl; + std::map< Node, std::map< TypeNode, Node > > visited2; for( unsigned i=0; i var_bound; - Trace("sort-inference-debug") << "Rewrite " << assertions[i] << std::endl; - Node curr = simplify( assertions[i], var_bound ); + Trace("sort-inference-debug") << "Simplify " << assertions[i] << std::endl; + TypeNode tnn; + Node curr = simplifyNode( assertions[i], var_bound, tnn, visited2 ); Trace("sort-inference-debug") << "Done." << std::endl; if( curr!=assertions[i] ){ + Trace("sort-inference-debug") << "Rewrite " << curr << std::endl; curr = theory::Rewriter::rewrite( curr ); rewritten = true; Trace("sort-inference-rewrite") << assertions << std::endl; @@ -196,9 +201,17 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doSortInfere for( std::map< TypeNode, std::map< Node, Node > >::iterator it = d_const_map.begin(); it != d_const_map.end(); ++it ){ std::vector< Node > consts; for( std::map< Node, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){ + Assert( it2->first.isConst() ); consts.push_back( it2->second ); } - //TODO: add lemma enforcing introduced constants to be distinct + //add lemma enforcing introduced constants to be distinct + if( consts.size()>1 ){ + Node distinct_const = NodeManager::currentNM()->mkNode( kind::DISTINCT, consts ); + Trace("sort-inference-rewrite") << "Add the constant distinctness lemma: " << std::endl; + Trace("sort-inference-rewrite") << " " << distinct_const << std::endl; + assertions.push_back( distinct_const ); + rewritten = true; + } } //enforce constraints based on monotonicity @@ -242,43 +255,15 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doSortInfere reset(); Trace("sort-inference-debug") << "Finished sort inference, rewritten = " << rewritten << std::endl; } - /* - else if( !options::ufssSymBreak() ){ - //just add the unit lemmas between constants - std::map< TypeNode, std::map< int, Node > > constants; - for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){ - int rt = d_type_union_find.getRepresentative( it->second ); - if( d_op_arg_types[ it->first ].empty() ){ - TypeNode tn = it->first.getType(); - if( constants[ tn ].find( rt )==constants[ tn ].end() ){ - constants[ tn ][ rt ] = it->first; - } - } - } - //add unit lemmas for each constant - for( std::map< TypeNode, std::map< int, Node > >::iterator it = constants.begin(); it != constants.end(); ++it ){ - Node first_const; - for( std::map< int, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){ - if( first_const.isNull() ){ - first_const = it2->second; - }else{ - Node eq = first_const.eqNode( it2->second ); - //eq = Rewriter::rewrite( eq ); - Trace("sort-inference-lemma") << "Sort inference lemma : " << eq << std::endl; - assertions.push_back( eq ); - } - } - } - } - */ initialSortCount = sortCount; } if( doMonotonicyInference ){ + std::map< Node, std::map< int, bool > > visited; Trace("sort-inference-proc") << "Calculating monotonicty for types..." << std::endl; for( unsigned i=0; i var_bound; - processMonotonic( assertions[i], true, true, var_bound, true ); + processMonotonic( assertions[i], true, true, var_bound, visited, true ); } Trace("sort-inference-proc") << "...done" << std::endl; } @@ -338,174 +323,185 @@ int SortInference::getIdForType( TypeNode tn ){ } } -int SortInference::process( Node n, std::map< Node, Node >& var_bound ){ - //add to variable bindings - if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){ - if( d_var_types.find( n )!=d_var_types.end() ){ - return getIdForType( n.getType() ); - }else{ - for( size_t i=0; i& var_bound, std::map< Node, int >& visited ){ + std::map< Node, int >::iterator itv = visited.find( n ); + if( itv!=visited.end() ){ + return itv->second; + }else{ + //add to variable bindings + bool use_new_visited = false; + std::map< Node, int > new_visited; + if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){ + if( d_var_types.find( n )!=d_var_types.end() ){ + return getIdForType( n.getType() ); + }else{ + for( size_t i=0; i children; - std::vector< int > child_types; - for( size_t i=0; i=1; - } - if( processChild ){ - children.push_back( n[i] ); - child_types.push_back( process( n[i], var_bound ) ); + //process children + std::vector< Node > children; + std::vector< int > child_types; + for( size_t i=0; i=1; + } + if( processChild ){ + children.push_back( n[i] ); + child_types.push_back( process( n[i], var_bound, use_new_visited ? new_visited : visited ) ); + } } - } - //remove from variable bindings - if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){ - //erase from variable bound - for( size_t i=0; i::iterator it = var_bound.find( n ); - if( it!=var_bound.end() ){ - Trace("sort-inference-debug") << n << " is a bound variable." << std::endl; - //the return type was specified while binding - retType = d_var_types[it->second][n]; - }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){ - Trace("sort-inference-debug") << n << " is a variable." << std::endl; - if( d_op_return_types.find( n )==d_op_return_types.end() ){ - //assign arbitrary sort - d_op_return_types[n] = sortCount; - sortCount++; - //d_type_eq_class[sortCount].push_back( n ); + for( size_t i=0; i::iterator it = var_bound.find( n ); + if( it!=var_bound.end() ){ + Trace("sort-inference-debug") << n << " is a bound variable." << std::endl; + //the return type was specified while binding + retType = d_var_types[it->second][n]; + }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){ + Trace("sort-inference-debug") << n << " is a variable." << std::endl; + if( d_op_return_types.find( n )==d_op_return_types.end() ){ + //assign arbitrary sort + d_op_return_types[n] = sortCount; + sortCount++; + //d_type_eq_class[sortCount].push_back( n ); + } + retType = d_op_return_types[n]; + }else if( n.isConst() ){ + Trace("sort-inference-debug") << n << " is a constant." << std::endl; + //can be any type we want + retType = sortCount; + sortCount++; + }else{ + Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl; + //it is an interpreted term + for( size_t i=0; i& var_bound, bool typeMode ) { - Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl; - if( n.getKind()==kind::FORALL ){ - //only consider variables universally if it is possible this quantified formula is asserted positively - if( !hasPol || pol ){ - for( unsigned i=0; i& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode ) { + int pindex = hasPol ? ( pol ? 1 : -1 ) : 0; + if( visited[n].find( pindex )==visited[n].end() ){ + visited[n][pindex] = true; + Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl; + if( n.getKind()==kind::FORALL ){ + //only consider variables universally if it is possible this quantified formula is asserted positively + if( !hasPol || pol ){ + for( unsigned i=0; i& var_bound ){ - Trace("sort-inference-debug2") << "Simplify " << n << std::endl; - std::vector< Node > children; - if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){ - //recreate based on types of variables - std::vector< Node > new_children; - for( size_t i=0; i& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited ){ + std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn ); + if( itv!=visited[n].end() ){ + return itv->second; + }else{ + Trace("sort-inference-debug2") << "Simplify " << n << ", type context=" << tnn << std::endl; + std::vector< Node > children; + std::map< Node, std::map< TypeNode, Node > > new_visited; + bool use_new_visited = false; + if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){ + //recreate based on types of variables + std::vector< Node > new_children; + for( size_t i=0; imkNode( n[0].getKind(), new_children ) ); + use_new_visited = true; } - children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) ); - } - //process children - if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){ - children.push_back( n.getOperator() ); - } - bool childChanged = false; - for( size_t i=0; i=1; + //process children + if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){ + children.push_back( n.getOperator() ); } - if( processChild ){ - Node nc = simplify( n[i], var_bound ); - Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl; - children.push_back( nc ); - childChanged = childChanged || nc!=n[i]; + Node op; + if( n.hasOperator() ){ + op = n.getOperator(); } - } - - //remove from variable bindings - if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){ - //erase from variable bound - for( size_t i=0; i=1; + } + if( processChild ){ + if( n.getKind()==kind::APPLY_UF ){ + Assert( d_op_arg_types.find( op )!=d_op_arg_types.end() ); + tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() ); + Assert( !tnnc.isNull() ); + }else if( n.getKind()==kind::EQUAL && i==0 ){ + Assert( d_equality_types.find( n )!=d_equality_types.end() ); + tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() ); + Assert( !tnnc.isNull() ); + } + Node nc = simplifyNode( n[i], var_bound, tnnc, use_new_visited ? new_visited : visited ); + Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl; + children.push_back( nc ); + childChanged = childChanged || nc!=n[i]; + } } - return NodeManager::currentNM()->mkNode( n.getKind(), children ); - }else if( n.getKind()==kind::EQUAL ){ - TypeNode tn1 = children[0].getType(); - TypeNode tn2 = children[1].getType(); - if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){ - if( children[0].isConst() ){ - children[0] = getNewSymbol( children[0], children[1].getType() ); - }else if( children[1].isConst() ){ - children[1] = getNewSymbol( children[1], children[0].getType() ); - }else{ + + //remove from variable bindings + Node ret; + if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){ + //erase from variable bound + for( size_t i=0; imkNode( n.getKind(), children ); + }else if( n.getKind()==kind::EQUAL ){ + TypeNode tn1 = children[0].getType(); + TypeNode tn2 = children[1].getType(); + if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){ Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl; Trace("sort-inference-warn") << " Types : " << children[0].getType() << " " << children[1].getType() << std::endl; Assert( false ); } - } - return NodeManager::currentNM()->mkNode( kind::EQUAL, children ); - }else if( n.getKind()==kind::APPLY_UF ){ - Node op = n.getOperator(); - if( d_symbol_map.find( op )==d_symbol_map.end() ){ - //make the new operator if necessary - bool opChanged = false; - std::vector< TypeNode > argTypes; - for( size_t i=0; imkNode( kind::EQUAL, children ); + }else if( n.getKind()==kind::APPLY_UF ){ + if( d_symbol_map.find( op )==d_symbol_map.end() ){ + //make the new operator if necessary + bool opChanged = false; + std::vector< TypeNode > argTypes; + for( size_t i=0; imkFunctionType( argTypes, retType ); - d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" ); - Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl; - d_model_replace_f[op] = d_symbol_map[op]; - }else{ - d_symbol_map[op] = op; - } - } - children[0] = d_symbol_map[op]; - //make sure all children have been taken care of - for( size_t i=0; imkFunctionType( argTypes, retType ); + d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" ); + Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl; + d_model_replace_f[op] = d_symbol_map[op]; }else{ + d_symbol_map[op] = op; + } + } + children[0] = d_symbol_map[op]; + //make sure all children have been taken care of + for( size_t i=0; imkNode( kind::APPLY_UF, children ); - }else{ - std::map< Node, Node >::iterator it = var_bound.find( n ); - if( it!=var_bound.end() ){ - return it->second; - }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){ - if( d_symbol_map.find( n )==d_symbol_map.end() ){ - TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() ); - d_symbol_map[n] = getNewSymbol( n, tn ); - } - return d_symbol_map[n]; - }else if( n.isConst() ){ - //just return n, we will fix at higher scope - return n; + ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children ); }else{ - if( childChanged ){ - return NodeManager::currentNM()->mkNode( n.getKind(), children ); + std::map< Node, Node >::iterator it = var_bound.find( n ); + if( it!=var_bound.end() ){ + ret = it->second; + }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){ + if( d_symbol_map.find( n )==d_symbol_map.end() ){ + TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() ); + d_symbol_map[n] = getNewSymbol( n, tn ); + } + ret = d_symbol_map[n]; + }else if( n.isConst() ){ + //type is determined by context + ret = getNewSymbol( n, tnn ); + }else if( childChanged ){ + ret = NodeManager::currentNM()->mkNode( n.getKind(), children ); }else{ - return n; + ret = n; } } + visited[n][tnn] = ret; + return ret; } - } Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) { @@ -728,7 +735,8 @@ void SortInference::setSkolemVar( Node f, Node v, Node sk ){ if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){ //calculate the sort for variables if not done so already std::map< Node, Node > var_bound; - process( f, var_bound ); + std::map< Node, int > visited; + process( f, var_bound, visited ); } d_op_return_types[sk] = getSortId( f, v ); Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl; diff --git a/src/theory/sort_inference.h b/src/theory/sort_inference.h index f926776de..163a3c53e 100644 --- a/src/theory/sort_inference.h +++ b/src/theory/sort_inference.h @@ -61,6 +61,7 @@ private: //for apply uf operators std::map< Node, int > d_op_return_types; std::map< Node, std::vector< int > > d_op_arg_types; + std::map< Node, int > d_equality_types; //for bound variables std::map< Node, std::map< Node, int > > d_var_types; //get representative @@ -68,10 +69,10 @@ private: int getIdForType( TypeNode tn ); void printSort( const char* c, int t ); //process - int process( Node n, std::map< Node, Node >& var_bound ); + int process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ); //for monotonicity inference private: - void processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, bool typeMode = false ); + void processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode = false ); //for rewriting private: @@ -84,7 +85,7 @@ private: TypeNode getTypeForId( int t ); Node getNewSymbol( Node old, TypeNode tn ); //simplify - Node simplify( Node n, std::map< Node, Node >& var_bound ); + Node simplifyNode( Node n, std::map< Node, Node >& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited ); //make injection Node mkInjection( TypeNode tn1, TypeNode tn2 ); //reset -- 2.30.2