From: ajreynol Date: Tue, 16 Feb 2016 20:55:28 +0000 (-0600) Subject: Public interface for quantifier elimination. Minor changes to datatypes rewriter. X-Git-Tag: cvc5-1.0.0~6049^2~123 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=c603a047ac534ed4caafb128b5d333e05e1fd191;p=cvc5.git Public interface for quantifier elimination. Minor changes to datatypes rewriter. --- diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index fb3b5ec5e..9d2392715 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -1302,6 +1302,12 @@ extendedCommand[CVC4::Command*& cmd] | SIMPLIFY_TOK { PARSER_STATE->checkThatLogicIsSet(); } term[e,e2] { cmd = new SimplifyCommand(e); } + | GET_QE_TOK { PARSER_STATE->checkThatLogicIsSet(); } + term[e,e2] + { cmd = new GetQuantifierEliminationCommand(e, true); } + | GET_QE_DISJUNCT_TOK { PARSER_STATE->checkThatLogicIsSet(); } + term[e,e2] + { cmd = new GetQuantifierEliminationCommand(e, false); } ; @@ -2562,6 +2568,8 @@ DECLARE_CONST_TOK : 'declare-const'; DEFINE_CONST_TOK : 'define-const'; SIMPLIFY_TOK : 'simplify'; INCLUDE_TOK : 'include'; +GET_QE_TOK : 'get-qe'; +GET_QE_DISJUNCT_TOK : 'get-qe-disjunct'; // SyGuS commands SYNTH_FUN_TOK : 'synth-fun'; diff --git a/src/smt/command.cpp b/src/smt/command.cpp index d6ec0769a..89d7b5ca2 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -1233,6 +1233,60 @@ std::string GetSynthSolutionCommand::getCommandName() const throw() { return "get-instantiations"; } +/* class GetQuantifierEliminationCommand */ + +GetQuantifierEliminationCommand::GetQuantifierEliminationCommand() throw() : + d_expr() { +} + +GetQuantifierEliminationCommand::GetQuantifierEliminationCommand(const Expr& expr, bool doFull) throw() : + d_expr(expr), d_doFull(doFull) { +} + +Expr GetQuantifierEliminationCommand::getExpr() const throw() { + return d_expr; +} +bool GetQuantifierEliminationCommand::getDoFull() const throw() { + return d_doFull; +} + +void GetQuantifierEliminationCommand::invoke(SmtEngine* smtEngine) throw() { + try { + d_result = smtEngine->doQuantifierElimination(d_expr, d_doFull); + d_commandStatus = CommandSuccess::instance(); + } catch(exception& e) { + d_commandStatus = new CommandFailure(e.what()); + } +} + +Expr GetQuantifierEliminationCommand::getResult() const throw() { + return d_result; +} + +void GetQuantifierEliminationCommand::printResult(std::ostream& out, uint32_t verbosity) const throw() { + if(! ok()) { + this->Command::printResult(out, verbosity); + } else { + out << d_result << endl; + } +} + +Command* GetQuantifierEliminationCommand::exportTo(ExprManager* exprManager, ExprManagerMapCollection& variableMap) { + GetQuantifierEliminationCommand* c = new GetQuantifierEliminationCommand(d_expr.exportTo(exprManager, variableMap), d_doFull); + c->d_result = d_result; + return c; +} + +Command* GetQuantifierEliminationCommand::clone() const { + GetQuantifierEliminationCommand* c = new GetQuantifierEliminationCommand(d_expr, d_doFull); + c->d_result = d_result; + return c; +} + +std::string GetQuantifierEliminationCommand::getCommandName() const throw() { + return d_doFull ? "get-qe" : "get-qe-disjunct"; +} + /* class GetUnsatCoreCommand */ GetUnsatCoreCommand::GetUnsatCoreCommand() throw() { diff --git a/src/smt/command.h b/src/smt/command.h index 248e69b0e..512f147da 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -624,6 +624,25 @@ public: std::string getCommandName() const throw(); };/* class GetSynthSolutionCommand */ +class CVC4_PUBLIC GetQuantifierEliminationCommand : public Command { +protected: + Expr d_expr; + bool d_doFull; + Expr d_result; +public: + GetQuantifierEliminationCommand() throw(); + GetQuantifierEliminationCommand(const Expr& expr, bool doFull) throw(); + ~GetQuantifierEliminationCommand() throw() {} + Expr getExpr() const throw(); + bool getDoFull() const throw(); + void invoke(SmtEngine* smtEngine) throw(); + Expr getResult() const throw(); + void printResult(std::ostream& out, uint32_t verbosity = 2) const throw(); + Command* exportTo(ExprManager* exprManager, ExprManagerMapCollection& variableMap); + Command* clone() const; + std::string getCommandName() const throw(); +};/* class GetQuantifierEliminationCommand */ + class CVC4_PUBLIC GetUnsatCoreCommand : public Command { protected: UnsatCore d_result; diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 007c5e049..3dc64d61a 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -955,6 +955,45 @@ public: std::ostream* getReplayLog() const { return d_managedReplayLog.getReplayLog(); } + + Node replaceQuantifiersWithInstantiations( Node n, std::map< Node, std::vector< Node > >& insts, std::map< Node, Node >& visited ){ + std::map< Node, Node >::iterator itv = visited.find( n ); + if( itv!=visited.end() ){ + return itv->second; + }else{ + Node ret = n; + if( n.getKind()==kind::FORALL ){ + std::map< Node, std::vector< Node > >::iterator it = insts.find( n ); + if( it==insts.end() ){ + Trace("smt-qe-debug") << "* " << n << " has no instances" << std::endl; + ret = NodeManager::currentNM()->mkConst(true); + }else{ + Trace("smt-qe-debug") << "* " << n << " has " << it->second.size() << " instances" << std::endl; + Node reti = it->second.empty() ? NodeManager::currentNM()->mkConst(true) : ( it->second.size()==1 ? it->second[0] : NodeManager::currentNM()->mkNode( kind::AND, it->second ) ); + Trace("smt-qe-debug") << " return : " << ret << std::endl; + //recursive (for nested quantification) + ret = replaceQuantifiersWithInstantiations( reti, insts, visited ); + } + }else if( n.getNumChildren()>0 ){ + bool childChanged = false; + std::vector< Node > children; + if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){ + children.push_back( n.getOperator() ); + } + for( unsigned i=0; imkNode( n.getKind(), children ); + } + } + visited[n] = ret; + return ret; + } + } + };/* class SmtEnginePrivate */ }/* namespace CVC4::smt */ @@ -4318,80 +4357,80 @@ Result SmtEngine::query(const Expr& ex, bool inUnsatCore) throw(TypeCheckingExce Trace("smt") << "SMT query(" << ex << ")" << endl; try { - if(d_queryMade && !options::incrementalSolving()) { - throw ModalException("Cannot make multiple queries unless " - "incremental solving is enabled " - "(try --incremental)"); - } + if(d_queryMade && !options::incrementalSolving()) { + throw ModalException("Cannot make multiple queries unless " + "incremental solving is enabled " + "(try --incremental)"); + } - // Substitute out any abstract values in ex - Expr e = d_private->substituteAbstractValues(Node::fromExpr(ex)).toExpr(); - // Ensure that the expression is type-checked at this point, and Boolean - ensureBoolean(e); + // Substitute out any abstract values in ex + Expr e = d_private->substituteAbstractValues(Node::fromExpr(ex)).toExpr(); + // Ensure that the expression is type-checked at this point, and Boolean + ensureBoolean(e); - // check to see if a postsolve() is pending - if(d_needPostsolve) { - d_theoryEngine->postsolve(); - d_needPostsolve = false; - } + // check to see if a postsolve() is pending + if(d_needPostsolve) { + d_theoryEngine->postsolve(); + d_needPostsolve = false; + } - // Push the context - internalPush(); + // Push the context + internalPush(); - // Note that a query has been made - d_queryMade = true; + // Note that a query has been made + d_queryMade = true; - // Add the formula - d_problemExtended = true; - if(d_assertionList != NULL) { - d_assertionList->push_back(e.notExpr()); - } - d_private->addFormula(e.getNode().notNode(), inUnsatCore); + // Add the formula + d_problemExtended = true; + if(d_assertionList != NULL) { + d_assertionList->push_back(e.notExpr()); + } + d_private->addFormula(e.getNode().notNode(), inUnsatCore); - // Run the check - Result r(Result::SAT_UNKNOWN, Result::UNKNOWN_REASON); - r = check().asValidityResult(); - d_needPostsolve = true; + // Run the check + Result r(Result::SAT_UNKNOWN, Result::UNKNOWN_REASON); + r = check().asValidityResult(); + d_needPostsolve = true; - // Dump the query if requested - if(Dump.isOn("benchmark")) { - // the expr already got dumped out if assertion-dumping is on - Dump("benchmark") << QueryCommand(ex); - } + // Dump the query if requested + if(Dump.isOn("benchmark")) { + // the expr already got dumped out if assertion-dumping is on + Dump("benchmark") << QueryCommand(ex); + } - // Pop the context - internalPop(); + // Pop the context + internalPop(); - // Remember the status - d_status = r; + // Remember the status + d_status = r; - d_problemExtended = false; + d_problemExtended = false; - Trace("smt") << "SMT query(" << e << ") ==> " << r << endl; + Trace("smt") << "SMT query(" << e << ") ==> " << r << endl; - // Check that SAT results generate a model correctly. - if(options::checkModels()) { - if(r.asSatisfiabilityResult().isSat() == Result::SAT || - (r.isUnknown() && r.whyUnknown() == Result::INCOMPLETE) ){ - checkModel(/* hard failure iff */ ! r.isUnknown()); + // Check that SAT results generate a model correctly. + if(options::checkModels()) { + if(r.asSatisfiabilityResult().isSat() == Result::SAT || + (r.isUnknown() && r.whyUnknown() == Result::INCOMPLETE) ){ + checkModel(/* hard failure iff */ ! r.isUnknown()); + } } - } - // Check that UNSAT results generate a proof correctly. - if(options::checkProofs()) { - if(r.asSatisfiabilityResult().isSat() == Result::UNSAT) { - TimerStat::CodeTimer checkProofTimer(d_stats->d_checkProofTime); - checkProof(); + // Check that UNSAT results generate a proof correctly. + if(options::checkProofs()) { + if(r.asSatisfiabilityResult().isSat() == Result::UNSAT) { + TimerStat::CodeTimer checkProofTimer(d_stats->d_checkProofTime); + checkProof(); + } } - } - // Check that UNSAT results generate an unsat core correctly. - if(options::checkUnsatCores()) { - if(r.asSatisfiabilityResult().isSat() == Result::UNSAT) { - TimerStat::CodeTimer checkUnsatCoreTimer(d_stats->d_checkUnsatCoreTime); - checkUnsatCore(); + // Check that UNSAT results generate an unsat core correctly. + if(options::checkUnsatCores()) { + if(r.asSatisfiabilityResult().isSat() == Result::UNSAT) { + TimerStat::CodeTimer checkUnsatCoreTimer(d_stats->d_checkUnsatCoreTime); + checkUnsatCore(); + } } - } - return r; + return r; } catch (UnsafeInterruptException& e) { AlwaysAssert(d_private->getResourceManager()->out()); Result::UnknownExplanation why = d_private->getResourceManager()->outOfResources() ? @@ -5017,6 +5056,37 @@ void SmtEngine::printSynthSolution( std::ostream& out ) { } } +Expr SmtEngine::doQuantifierElimination(const Expr& e, bool doFull) { + SmtScope smts(this); + Trace("smt-qe") << "Do quantifier elimination " << e << std::endl; + Result r = query(e); + Trace("smt-qe") << "Query returned " << r << std::endl; + if(r.asSatisfiabilityResult().isSat() == Result::SAT) { + Node input = Node::fromExpr( e ); + input = Rewriter::rewrite( input ); + Trace("smt-qe") << "Replace instances in rewritten input: " << input << std::endl; + std::map< Node, std::vector< Node > > insts; + d_theoryEngine->getInstantiations( insts ); + std::map< Node, Node > visited; + Node en = d_private->replaceQuantifiersWithInstantiations( input, insts, visited ); + + //ensure all instantiations were accounted for + for( std::map< Node, std::vector< Node > >::iterator it = insts.begin(); it != insts.end(); ++it ){ + if( visited.find( it->first )==visited.end() ){ + stringstream ss; + ss << "While performing quantifier elimination, processed a quantified formula : " << it->first; + ss << " that was not related to the query. Try option --simplification=none." << std::endl; + InternalError(ss.str().c_str()); + } + } + Trace("smt-qe") << "Returned : " << en << std::endl; + en = Rewriter::rewrite( en ); + return en.toExpr(); + }else{ + return NodeManager::currentNM()->mkConst(false).toExpr(); + } +} + vector SmtEngine::getAssertions() throw(ModalException) { SmtScope smts(this); finalOptionsAreSet(); diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 3616762bc..5aa33731e 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -552,6 +552,11 @@ public: */ void printSynthSolution( std::ostream& out ); + /** + * Do quantifier elimination, doFull false means just output one disjunct + */ + Expr doQuantifierElimination(const Expr& e, bool doFull); + /** * Get an unsatisfiable core (only if immediately preceded by an * UNSAT or VALID query). Only permitted if CVC4 was built with diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index 0c00ed8df..ffec86477 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -200,20 +200,24 @@ public: } } - if(in.getKind() == kind::EQUAL && in[0] == in[1]) { - return RewriteResponse(REWRITE_DONE, - NodeManager::currentNM()->mkConst(true)); - } if(in.getKind() == kind::EQUAL ) { - std::vector< Node > rew; - if( checkClash(in[0], in[1], rew) ){ - Trace("datatypes-rewrite") << "Rewrite clashing equality " << in << " to false" << std::endl; - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(false)); - }else if( rew.size()==1 && rew[0]!=in ){ - Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " << rew[0] << std::endl; - return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] ); + if(in[0] == in[1]) { + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); }else{ - Trace("datatypes-rewrite-debug") << "Did not rewrite equality " << in << " " << in[0].getKind() << " " << in[1].getKind() << std::endl; + std::vector< Node > rew; + if( checkClash(in[0], in[1], rew) ){ + Trace("datatypes-rewrite") << "Rewrite clashing equality " << in << " to false" << std::endl; + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(false)); + }else if( rew.size()==1 && rew[0]!=in ){ + Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " << rew[0] << std::endl; + return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] ); + }else if( in[1]mkNode(in.getKind(), in[1], in[0]); + Trace("datatypes-rewrite") << "Swap equality " << in << " to " << ins << std::endl; + return RewriteResponse(REWRITE_DONE, ins); + }else{ + Trace("datatypes-rewrite-debug") << "Did not rewrite equality " << in << " " << in[0].getKind() << " " << in[1].getKind() << std::endl; + } } } diff --git a/src/theory/quantifiers/inst_match.cpp b/src/theory/quantifiers/inst_match.cpp index 5eca87903..f2d5c640d 100644 --- a/src/theory/quantifiers/inst_match.cpp +++ b/src/theory/quantifiers/inst_match.cpp @@ -212,6 +212,18 @@ void InstMatchTrie::print( std::ostream& out, Node q, std::vector< TNode >& term } } +void InstMatchTrie::getInstantiations( std::vector< Node >& insts, Node q, std::vector< TNode >& vars, std::vector< TNode >& terms ) const { + if( terms.size()==q[0].getNumChildren() ){ + insts.push_back( q[1].substitute( vars.begin(), vars.end(), terms.begin(), terms.end() ) ); + }else{ + for( std::map< Node, InstMatchTrie >::const_iterator it = d_data.begin(); it != d_data.end(); ++it ){ + terms.push_back( it->first ); + it->second.getInstantiations( insts, q, vars, terms ); + terms.pop_back(); + } + } +} + bool CDInstMatchTrie::addInstMatch( QuantifiersEngine* qe, Node f, std::vector< Node >& m, context::Context* c, bool modEq, bool modInst, int index, bool onlyExist ){ @@ -298,6 +310,20 @@ void CDInstMatchTrie::print( std::ostream& out, Node q, std::vector< TNode >& te } } +void CDInstMatchTrie::getInstantiations( std::vector< Node >& insts, Node q, std::vector< TNode >& vars, std::vector< TNode >& terms ) const{ + if( d_valid.get() ){ + if( terms.size()==q[0].getNumChildren() ){ + insts.push_back( q[1].substitute( vars.begin(), vars.end(), terms.begin(), terms.end() ) ); + }else{ + for( std::map< Node, CDInstMatchTrie* >::const_iterator it = d_data.begin(); it != d_data.end(); ++it ){ + terms.push_back( it->first ); + it->second->getInstantiations( insts, q, vars, terms ); + terms.pop_back(); + } + } + } +} + }/* CVC4::theory::inst namespace */ }/* CVC4::theory namespace */ }/* CVC4 namespace */ diff --git a/src/theory/quantifiers/inst_match.h b/src/theory/quantifiers/inst_match.h index f1c1c952a..abe31b48d 100644 --- a/src/theory/quantifiers/inst_match.h +++ b/src/theory/quantifiers/inst_match.h @@ -101,6 +101,7 @@ public: std::map< Node, InstMatchTrie > d_data; private: void print( std::ostream& out, Node q, std::vector< TNode >& terms ) const; + void getInstantiations( std::vector< Node >& insts, Node q, std::vector< TNode >& vars, std::vector< TNode >& terms ) const; public: InstMatchTrie(){} ~InstMatchTrie(){} @@ -131,6 +132,14 @@ public: std::vector< TNode > terms; print( out, q, terms ); } + void getInstantiations( std::vector< Node >& insts, Node q ) { + std::vector< TNode > terms; + std::vector< TNode > vars; + for( unsigned i=0; i d_valid; private: void print( std::ostream& out, Node q, std::vector< TNode >& terms ) const; + void getInstantiations( std::vector< Node >& insts, Node q, std::vector< TNode >& vars, std::vector< TNode >& terms ) const; public: CDInstMatchTrie( context::Context* c ) : d_valid( c, false ){} ~CDInstMatchTrie(){} @@ -173,6 +183,14 @@ public: std::vector< TNode > terms; print( out, q, terms ); } + void getInstantiations( std::vector< Node >& insts, Node q ) { + std::vector< TNode > terms; + std::vector< TNode > vars; + for( unsigned i=0; i >& insts ) { + if( options::incrementalSolving() ){ + for( std::map< Node, inst::CDInstMatchTrie* >::iterator it = d_c_inst_match_trie.begin(); it != d_c_inst_match_trie.end(); ++it ){ + it->second->getInstantiations( insts[it->first], it->first ); + } + }else{ + for( std::map< Node, inst::InstMatchTrie >::iterator it = d_inst_match_trie.begin(); it != d_inst_match_trie.end(); ++it ){ + it->second.getInstantiations( insts[it->first], it->first ); + } + } +} + QuantifiersEngine::Statistics::Statistics() : d_time("theory::QuantifiersEngine::time"), d_num_quant("QuantifiersEngine::Num_Quantifiers", 0), diff --git a/src/theory/quantifiers_engine.h b/src/theory/quantifiers_engine.h index aa770ad67..49c9eeff8 100644 --- a/src/theory/quantifiers_engine.h +++ b/src/theory/quantifiers_engine.h @@ -342,6 +342,8 @@ public: void printInstantiations( std::ostream& out ); /** print solution for synthesis conjectures */ void printSynthSolution( std::ostream& out ); + /** get instantiations */ + void getInstantiations( std::map< Node, std::vector< Node > >& insts ); /** statistics class */ class Statistics { public: diff --git a/src/theory/theory_engine.cpp b/src/theory/theory_engine.cpp index 45f7506de..dcb3fec0a 100644 --- a/src/theory/theory_engine.cpp +++ b/src/theory/theory_engine.cpp @@ -1212,6 +1212,8 @@ Node TheoryEngine::ensureLiteral(TNode n) { void TheoryEngine::printInstantiations( std::ostream& out ) { if( d_quantEngine ){ d_quantEngine->printInstantiations( out ); + }else{ + out << "Internal error : instantiations not available when quantifiers are not present." << std::endl; } } @@ -1223,6 +1225,15 @@ void TheoryEngine::printSynthSolution( std::ostream& out ) { } } +void TheoryEngine::getInstantiations( std::map< Node, std::vector< Node > >& insts ) { + if( d_quantEngine ){ + d_quantEngine->getInstantiations( insts ); + }else{ + Assert( false ); + } +} + + static Node mkExplanation(const std::vector& explanation) { std::set all; diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index 886aa6863..5061f3cb7 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -783,6 +783,11 @@ public: */ void printSynthSolution( std::ostream& out ); + /** + * Get instantiations + */ + void getInstantiations( std::map< Node, std::vector< Node > >& insts ); + /** * Forwards an entailment check according to the given theoryOfMode. * See theory.h for documentation on entailmentCheck().