From fd60da4a22f02f6f5b82cef3585240c1b33595e9 Mon Sep 17 00:00:00 2001 From: Andres Noetzli Date: Sat, 6 Jun 2020 01:24:17 -0700 Subject: [PATCH] Keep definitions when global-declarations enabled (#4572) Fixes #4552. Fixes #4555. The SMT-LIB standard mandates that definitions are kept when `:global-declarations` are enabled. Until now, CVC4 was keeping track of the symbols of a definition correctly but lost the body of the definition when the user context was popped. This commit fixes the issue by adding a `global` parameter to `SmtEngine::defineFunction()` and `SmtEngine::defineFunctionRec()`. If that parameter is set, the definitions of functions are added at level 0 to `d_definedFunctions` and the lemmas for recursive function definitions are kept in an additional list and asserted during each `checkSat` call. The commit also updates new API, the commands, and the parsers to reflect this change. --- NEWS | 4 ++ src/api/cvc4cpp.cpp | 26 +++++--- src/api/cvc4cpp.h | 25 +++++-- src/api/python/cvc4.pxd | 10 +-- src/api/python/cvc4.pxi | 24 ++++--- src/parser/cvc/Cvc.g | 4 +- src/parser/parser.h | 2 + src/parser/smt2/Smt2.g | 54 ++++++++++----- src/smt/command.cpp | 47 +++++++------ src/smt/command.h | 40 ++++++++--- src/smt/smt_engine.cpp | 56 +++++++++++++--- src/smt/smt_engine.h | 36 ++++++---- test/regress/CMakeLists.txt | 1 + test/regress/regress0/smtlib/issue4552.smt2 | 27 ++++++++ test/unit/api/solver_black.h | 73 +++++++++++++++++++++ 15 files changed, 330 insertions(+), 99 deletions(-) create mode 100644 test/regress/regress0/smtlib/issue4552.smt2 diff --git a/NEWS b/NEWS index a7d6d3f40..ac9f0747e 100644 --- a/NEWS +++ b/NEWS @@ -3,6 +3,10 @@ This file contains a summary of important user-visible changes. Changes since 1.7 ================= +Improvements: +* API: Function definitions can now be requested to be global. If the `global` + parameter is set to true, they persist after popping the user context. + Changes: * API change: `SmtEngine::query()` has been renamed to `SmtEngine::checkEntailed()` and `Result::Validity` has been renamed to diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index 2c65f1ca6..88974dc69 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -4219,7 +4219,8 @@ Sort Solver::declareSort(const std::string& symbol, uint32_t arity) const Term Solver::defineFun(const std::string& symbol, const std::vector& bound_vars, Sort sort, - Term term) const + Term term, + bool global) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_EXPECTED(sort.isFirstClass(), sort) @@ -4253,14 +4254,15 @@ Term Solver::defineFun(const std::string& symbol, } Expr fun = d_exprMgr->mkVar(symbol, type); std::vector ebound_vars = termVectorToExprs(bound_vars); - d_smtEngine->defineFunction(fun, ebound_vars, *term.d_expr); + d_smtEngine->defineFunction(fun, ebound_vars, *term.d_expr, global); return Term(this, fun); CVC4_API_SOLVER_TRY_CATCH_END; } Term Solver::defineFun(Term fun, const std::vector& bound_vars, - Term term) const + Term term, + bool global) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_EXPECTED(fun.getSort().isFunction(), fun) << "function"; @@ -4293,7 +4295,7 @@ Term Solver::defineFun(Term fun, << codomain << "'"; std::vector ebound_vars = termVectorToExprs(bound_vars); - d_smtEngine->defineFunction(*fun.d_expr, ebound_vars, *term.d_expr); + d_smtEngine->defineFunction(*fun.d_expr, ebound_vars, *term.d_expr, global); return fun; CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4304,7 +4306,8 @@ Term Solver::defineFun(Term fun, Term Solver::defineFunRec(const std::string& symbol, const std::vector& bound_vars, Sort sort, - Term term) const + Term term, + bool global) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_EXPECTED(sort.isFirstClass(), sort) @@ -4340,14 +4343,15 @@ Term Solver::defineFunRec(const std::string& symbol, } Expr fun = d_exprMgr->mkVar(symbol, type); std::vector ebound_vars = termVectorToExprs(bound_vars); - d_smtEngine->defineFunctionRec(fun, ebound_vars, *term.d_expr); + d_smtEngine->defineFunctionRec(fun, ebound_vars, *term.d_expr, global); return Term(this, fun); CVC4_API_SOLVER_TRY_CATCH_END; } Term Solver::defineFunRec(Term fun, const std::vector& bound_vars, - Term term) const + Term term, + bool global) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_EXPECTED(fun.getSort().isFunction(), fun) << "function"; @@ -4379,7 +4383,8 @@ Term Solver::defineFunRec(Term fun, << "Invalid sort of function body '" << term << "', expected '" << codomain << "'"; std::vector ebound_vars = termVectorToExprs(bound_vars); - d_smtEngine->defineFunctionRec(*fun.d_expr, ebound_vars, *term.d_expr); + d_smtEngine->defineFunctionRec( + *fun.d_expr, ebound_vars, *term.d_expr, global); return fun; CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4389,7 +4394,8 @@ Term Solver::defineFunRec(Term fun, */ void Solver::defineFunsRec(const std::vector& funs, const std::vector>& bound_vars, - const std::vector& terms) const + const std::vector& terms, + bool global) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; size_t funs_size = funs.size(); @@ -4444,7 +4450,7 @@ void Solver::defineFunsRec(const std::vector& funs, ebound_vars.push_back(termVectorToExprs(v)); } std::vector exprs = termVectorToExprs(terms); - d_smtEngine->defineFunctionsRec(efuns, ebound_vars, exprs); + d_smtEngine->defineFunctionsRec(efuns, ebound_vars, exprs, global); CVC4_API_SOLVER_TRY_CATCH_END; } diff --git a/src/api/cvc4cpp.h b/src/api/cvc4cpp.h index adf3691ab..aa51a4134 100644 --- a/src/api/cvc4cpp.h +++ b/src/api/cvc4cpp.h @@ -2860,12 +2860,15 @@ class CVC4_PUBLIC Solver * @param bound_vars the parameters to this function * @param sort the sort of the return value of this function * @param term the function body + * @param global determines whether this definition is global (i.e. persists + * when popping the context) * @return the function */ Term defineFun(const std::string& symbol, const std::vector& bound_vars, Sort sort, - Term term) const; + Term term, + bool global = false) const; /** * Define n-ary function. * SMT-LIB: ( define-fun ) @@ -2873,11 +2876,14 @@ class CVC4_PUBLIC Solver * @param fun the sorted function * @param bound_vars the parameters to this function * @param term the function body + * @param global determines whether this definition is global (i.e. persists + * when popping the context) * @return the function */ Term defineFun(Term fun, const std::vector& bound_vars, - Term term) const; + Term term, + bool global = false) const; /** * Define recursive function. @@ -2886,12 +2892,15 @@ class CVC4_PUBLIC Solver * @param bound_vars the parameters to this function * @param sort the sort of the return value of this function * @param term the function body + * @param global determines whether this definition is global (i.e. persists + * when popping the context) * @return the function */ Term defineFunRec(const std::string& symbol, const std::vector& bound_vars, Sort sort, - Term term) const; + Term term, + bool global = false) const; /** * Define recursive function. @@ -2900,11 +2909,14 @@ class CVC4_PUBLIC Solver * @param fun the sorted function * @param bound_vars the parameters to this function * @param term the function body + * @param global determines whether this definition is global (i.e. persists + * when popping the context) * @return the function */ Term defineFunRec(Term fun, const std::vector& bound_vars, - Term term) const; + Term term, + bool global = false) const; /** * Define recursive functions. @@ -2913,11 +2925,14 @@ class CVC4_PUBLIC Solver * @param funs the sorted functions * @param bound_vars the list of parameters to the functions * @param term the list of function bodies of the functions + * @param global determines whether this definition is global (i.e. persists + * when popping the context) * @return the function */ void defineFunsRec(const std::vector& funs, const std::vector>& bound_vars, - const std::vector& terms) const; + const std::vector& terms, + bool global = false) const; /** * Echo a given string to the given output stream. diff --git a/src/api/python/cvc4.pxd b/src/api/python/cvc4.pxd index cc998306d..624b3c365 100644 --- a/src/api/python/cvc4.pxd +++ b/src/api/python/cvc4.pxd @@ -181,14 +181,14 @@ cdef extern from "api/cvc4cpp.h" namespace "CVC4::api": Term declareFun(const string& symbol, const vector[Sort]& sorts, Sort sort) except + Sort declareSort(const string& symbol, uint32_t arity) except + Term defineFun(const string& symbol, const vector[Term]& bound_vars, - Sort sort, Term term) except + - Term defineFun(Term fun, const vector[Term]& bound_vars, Term term) except + + Sort sort, Term term, bint glbl) except + + Term defineFun(Term fun, const vector[Term]& bound_vars, Term term, bint glbl) except + Term defineFunRec(const string& symbol, const vector[Term]& bound_vars, - Sort sort, Term term) except + + Sort sort, Term term, bint glbl) except + Term defineFunRec(Term fun, const vector[Term]& bound_vars, - Term term) except + + Term term, bint glbl) except + Term defineFunsRec(vector[Term]& funs, vector[vector[Term]]& bound_vars, - vector[Term]& terms) except + + vector[Term]& terms, bint glbl) except + vector[Term] getAssertions() except + vector[pair[Term, Term]] getAssignment() except + string getInfo(const string& flag) except + diff --git a/src/api/python/cvc4.pxi b/src/api/python/cvc4.pxi index 9dd9c1cde..b7593f6f1 100644 --- a/src/api/python/cvc4.pxi +++ b/src/api/python/cvc4.pxi @@ -797,13 +797,13 @@ cdef class Solver: sort.csort = self.csolver.declareSort(symbol.encode(), arity) return sort - def defineFun(self, sym_or_fun, bound_vars, sort_or_term, t=None): + def defineFun(self, sym_or_fun, bound_vars, sort_or_term, t=None, glbl=False): ''' Supports two uses: Term defineFun(str symbol, List[Term] bound_vars, - Sort sort, Term term) + Sort sort, Term term, bool glbl) Term defineFun(Term fun, List[Term] bound_vars, - Term term) + Term term, bool glbl) ''' cdef Term term = Term() cdef vector[c_Term] v @@ -814,21 +814,23 @@ cdef class Solver: term.cterm = self.csolver.defineFun(( sym_or_fun).encode(), v, ( sort_or_term).csort, - ( t).cterm) + ( t).cterm, + glbl) else: term.cterm = self.csolver.defineFun(( sym_or_fun).cterm, v, - ( sort_or_term).cterm) + ( sort_or_term).cterm, + glbl) return term - def defineFunRec(self, sym_or_fun, bound_vars, sort_or_term, t=None): + def defineFunRec(self, sym_or_fun, bound_vars, sort_or_term, t=None, glbl=False): ''' Supports two uses: Term defineFunRec(str symbol, List[Term] bound_vars, - Sort sort, Term term) + Sort sort, Term term, bool glbl) Term defineFunRec(Term fun, List[Term] bound_vars, - Term term) + Term term, bool glbl) ''' cdef Term term = Term() cdef vector[c_Term] v @@ -839,11 +841,13 @@ cdef class Solver: term.cterm = self.csolver.defineFunRec(( sym_or_fun).encode(), v, ( sort_or_term).csort, - ( t).cterm) + ( t).cterm, + glbl) else: term.cterm = self.csolver.defineFunRec(( sym_or_fun).cterm, v, - ( sort_or_term).cterm) + ( sort_or_term).cterm, + glbl) return term diff --git a/src/parser/cvc/Cvc.g b/src/parser/cvc/Cvc.g index e604c7769..5d04a8cc0 100644 --- a/src/parser/cvc/Cvc.g +++ b/src/parser/cvc/Cvc.g @@ -943,7 +943,7 @@ mainCommand[std::unique_ptr* cmd] cmd->reset( new DefineFunctionRecCommand(api::termVectorToExprs(funcs), eformals, - api::termVectorToExprs(formulas))); + api::termVectorToExprs(formulas), true)); } | toplevelDeclaration[cmd] ; @@ -1163,7 +1163,7 @@ declareVariables[std::unique_ptr* cmd, CVC4::api::Sort& t, ExprManager::VAR_FLAG_GLOBAL | ExprManager::VAR_FLAG_DEFINED); PARSER_STATE->defineVar(*i, f); Command* decl = - new DefineFunctionCommand(*i, func.getExpr(), f.getExpr()); + new DefineFunctionCommand(*i, func.getExpr(), f.getExpr(), true); seq->addCommand(decl); } } diff --git a/src/parser/parser.h b/src/parser/parser.h index 681404efa..0bdf23dcd 100644 --- a/src/parser/parser.h +++ b/src/parser/parser.h @@ -812,6 +812,8 @@ public: d_globalDeclarations = flag; } + bool getGlobalDeclarations() { return d_globalDeclarations; } + inline SymbolTable* getSymbolTable() const { return d_symtab; } diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 436700826..dd261dcb6 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -361,8 +361,12 @@ command [std::unique_ptr* cmd] // we allow overloading for function definitions api::Term func = PARSER_STATE->bindVar(name, t, ExprManager::VAR_FLAG_DEFINED, true); - cmd->reset(new DefineFunctionCommand( - name, func.getExpr(), api::termVectorToExprs(terms), expr.getExpr())); + cmd->reset( + new DefineFunctionCommand(name, + func.getExpr(), + api::termVectorToExprs(terms), + expr.getExpr(), + PARSER_STATE->getGlobalDeclarations())); } | DECLARE_DATATYPE_TOK datatypeDefCommand[false, cmd] | DECLARE_DATATYPES_TOK datatypesDefCommand[false, cmd] @@ -1204,7 +1208,7 @@ smt25Command[std::unique_ptr* cmd] expr = PARSER_STATE->mkHoApply( expr, flattenVars ); } cmd->reset(new DefineFunctionRecCommand( - func.getExpr(), api::termVectorToExprs(bvs), expr.getExpr())); + func.getExpr(), api::termVectorToExprs(bvs), expr.getExpr(), PARSER_STATE->getGlobalDeclarations())); } | DEFINE_FUNS_REC_TOK { PARSER_STATE->checkThatLogicIsSet();} @@ -1275,7 +1279,7 @@ smt25Command[std::unique_ptr* cmd] cmd->reset( new DefineFunctionRecCommand(api::termVectorToExprs(funcs), eformals, - api::termVectorToExprs(func_defs))); + api::termVectorToExprs(func_defs), PARSER_STATE->getGlobalDeclarations())); } ; @@ -1365,14 +1369,21 @@ extendedCommand[std::unique_ptr* cmd] { cmd->reset(seq.release()); } | DEFINE_TOK { PARSER_STATE->checkThatLogicIsSet(); } - ( symbol[name,CHECK_UNDECLARED,SYM_VARIABLE] + ( // (define f t) + symbol[name,CHECK_UNDECLARED,SYM_VARIABLE] { PARSER_STATE->checkUserSymbol(name); } term[e,e2] - { api::Term func = PARSER_STATE->bindVar(name, e.getSort(), + { + api::Term func = PARSER_STATE->bindVar(name, e.getSort(), ExprManager::VAR_FLAG_DEFINED); - cmd->reset(new DefineFunctionCommand(name, func.getExpr(), e.getExpr())); + cmd->reset( + new DefineFunctionCommand(name, + func.getExpr(), + e.getExpr(), + PARSER_STATE->getGlobalDeclarations())); } - | LPAREN_TOK + | // (define (f (v U) ...) t) + LPAREN_TOK symbol[name,CHECK_UNDECLARED,SYM_VARIABLE] { PARSER_STATE->checkUserSymbol(name); } sortedVarList[sortedVarNames] RPAREN_TOK @@ -1382,7 +1393,8 @@ extendedCommand[std::unique_ptr* cmd] terms = PARSER_STATE->bindBoundVars(sortedVarNames); } term[e,e2] - { PARSER_STATE->popScope(); + { + PARSER_STATE->popScope(); // declare the name down here (while parsing term, signature // must not be extended with the name itself; no recursion // permitted) @@ -1398,11 +1410,16 @@ extendedCommand[std::unique_ptr* cmd] } api::Term func = PARSER_STATE->bindVar(name, tt, ExprManager::VAR_FLAG_DEFINED); - cmd->reset(new DefineFunctionCommand( - name, func.getExpr(), api::termVectorToExprs(terms), e.getExpr())); + cmd->reset( + new DefineFunctionCommand(name, + func.getExpr(), + api::termVectorToExprs(terms), + e.getExpr(), + PARSER_STATE->getGlobalDeclarations())); } ) - | DEFINE_CONST_TOK { PARSER_STATE->checkThatLogicIsSet(); } + | // (define-const x U t) + DEFINE_CONST_TOK { PARSER_STATE->checkThatLogicIsSet(); } symbol[name,CHECK_UNDECLARED,SYM_VARIABLE] { PARSER_STATE->checkUserSymbol(name); } sortSymbol[t,CHECK_DECLARED] @@ -1412,14 +1429,19 @@ extendedCommand[std::unique_ptr* cmd] terms = PARSER_STATE->bindBoundVars(sortedVarNames); } term[e, e2] - { PARSER_STATE->popScope(); + { + PARSER_STATE->popScope(); // declare the name down here (while parsing term, signature // must not be extended with the name itself; no recursion // permitted) api::Term func = PARSER_STATE->bindVar(name, t, ExprManager::VAR_FLAG_DEFINED); - cmd->reset(new DefineFunctionCommand( - name, func.getExpr(), api::termVectorToExprs(terms), e.getExpr())); + cmd->reset( + new DefineFunctionCommand(name, + func.getExpr(), + api::termVectorToExprs(terms), + e.getExpr(), + PARSER_STATE->getGlobalDeclarations())); } | SIMPLIFY_TOK { PARSER_STATE->checkThatLogicIsSet(); } @@ -2217,7 +2239,7 @@ attribute[CVC4::api::Term& expr, CVC4::api::Term& retExpr, std::string& attr] std::string name = sexpr.getValue(); // bind name to expr with define-fun Command* c = new DefineNamedFunctionCommand( - name, func.getExpr(), std::vector(), expr.getExpr()); + name, func.getExpr(), std::vector(), expr.getExpr(), PARSER_STATE->getGlobalDeclarations()); c->setMuted(true); PARSER_STATE->preemptCommand(c); } diff --git a/src/smt/command.cpp b/src/smt/command.cpp index 20f2dcff9..9fd0122fc 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -1266,22 +1266,27 @@ std::string DefineTypeCommand::getCommandName() const { return "define-sort"; } DefineFunctionCommand::DefineFunctionCommand(const std::string& id, Expr func, - Expr formula) + Expr formula, + bool global) : DeclarationDefinitionCommand(id), d_func(func), d_formals(), - d_formula(formula) + d_formula(formula), + d_global(global) { } DefineFunctionCommand::DefineFunctionCommand(const std::string& id, Expr func, const std::vector& formals, - Expr formula) + Expr formula, + bool global) : DeclarationDefinitionCommand(id), d_func(func), d_formals(formals), - d_formula(formula) + d_formula(formula), + d_global(global) + { } @@ -1298,7 +1303,7 @@ void DefineFunctionCommand::invoke(SmtEngine* smtEngine) { if (!d_func.isNull()) { - smtEngine->defineFunction(d_func, d_formals, d_formula); + smtEngine->defineFunction(d_func, d_formals, d_formula, d_global); } d_commandStatus = CommandSuccess::instance(); } @@ -1319,12 +1324,13 @@ Command* DefineFunctionCommand::exportTo(ExprManager* exprManager, back_inserter(formals), ExportTransformer(exprManager, variableMap)); Expr formula = d_formula.exportTo(exprManager, variableMap); - return new DefineFunctionCommand(d_symbol, func, formals, formula); + return new DefineFunctionCommand(d_symbol, func, formals, formula, d_global); } Command* DefineFunctionCommand::clone() const { - return new DefineFunctionCommand(d_symbol, d_func, d_formals, d_formula); + return new DefineFunctionCommand( + d_symbol, d_func, d_formals, d_formula, d_global); } std::string DefineFunctionCommand::getCommandName() const @@ -1340,8 +1346,9 @@ DefineNamedFunctionCommand::DefineNamedFunctionCommand( const std::string& id, Expr func, const std::vector& formals, - Expr formula) - : DefineFunctionCommand(id, func, formals, formula) + Expr formula, + bool global) + : DefineFunctionCommand(id, func, formals, formula, global) { } @@ -1365,12 +1372,14 @@ Command* DefineNamedFunctionCommand::exportTo( back_inserter(formals), ExportTransformer(exprManager, variableMap)); Expr formula = d_formula.exportTo(exprManager, variableMap); - return new DefineNamedFunctionCommand(d_symbol, func, formals, formula); + return new DefineNamedFunctionCommand( + d_symbol, func, formals, formula, d_global); } Command* DefineNamedFunctionCommand::clone() const { - return new DefineNamedFunctionCommand(d_symbol, d_func, d_formals, d_formula); + return new DefineNamedFunctionCommand( + d_symbol, d_func, d_formals, d_formula, d_global); } /* -------------------------------------------------------------------------- */ @@ -1378,7 +1387,8 @@ Command* DefineNamedFunctionCommand::clone() const /* -------------------------------------------------------------------------- */ DefineFunctionRecCommand::DefineFunctionRecCommand( - Expr func, const std::vector& formals, Expr formula) + Expr func, const std::vector& formals, Expr formula, bool global) + : d_global(global) { d_funcs.push_back(func); d_formals.push_back(formals); @@ -1388,11 +1398,10 @@ DefineFunctionRecCommand::DefineFunctionRecCommand( DefineFunctionRecCommand::DefineFunctionRecCommand( const std::vector& funcs, const std::vector>& formals, - const std::vector& formulas) + const std::vector& formulas, + bool global) + : d_funcs(funcs), d_formals(formals), d_formulas(formulas), d_global(global) { - d_funcs.insert(d_funcs.end(), funcs.begin(), funcs.end()); - d_formals.insert(d_formals.end(), formals.begin(), formals.end()); - d_formulas.insert(d_formulas.end(), formulas.begin(), formulas.end()); } const std::vector& DefineFunctionRecCommand::getFunctions() const @@ -1415,7 +1424,7 @@ void DefineFunctionRecCommand::invoke(SmtEngine* smtEngine) { try { - smtEngine->defineFunctionsRec(d_funcs, d_formals, d_formulas); + smtEngine->defineFunctionsRec(d_funcs, d_formals, d_formulas, d_global); d_commandStatus = CommandSuccess::instance(); } catch (exception& e) @@ -1450,12 +1459,12 @@ Command* DefineFunctionRecCommand::exportTo( Expr formula = d_formulas[i].exportTo(exprManager, variableMap); formulas.push_back(formula); } - return new DefineFunctionRecCommand(funcs, formals, formulas); + return new DefineFunctionRecCommand(funcs, formals, formulas, d_global); } Command* DefineFunctionRecCommand::clone() const { - return new DefineFunctionRecCommand(d_funcs, d_formals, d_formulas); + return new DefineFunctionRecCommand(d_funcs, d_formals, d_formulas, d_global); } std::string DefineFunctionRecCommand::getCommandName() const diff --git a/src/smt/command.h b/src/smt/command.h index 63f1f0f33..0582cee34 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -436,17 +436,16 @@ class CVC4_PUBLIC DefineTypeCommand : public DeclarationDefinitionCommand class CVC4_PUBLIC DefineFunctionCommand : public DeclarationDefinitionCommand { - protected: - Expr d_func; - std::vector d_formals; - Expr d_formula; - public: - DefineFunctionCommand(const std::string& id, Expr func, Expr formula); + DefineFunctionCommand(const std::string& id, + Expr func, + Expr formula, + bool global); DefineFunctionCommand(const std::string& id, Expr func, const std::vector& formals, - Expr formula); + Expr formula, + bool global); Expr getFunction() const; const std::vector& getFormals() const; @@ -457,6 +456,19 @@ class CVC4_PUBLIC DefineFunctionCommand : public DeclarationDefinitionCommand ExprManagerMapCollection& variableMap) override; Command* clone() const override; std::string getCommandName() const override; + + protected: + /** The function we are defining */ + Expr d_func; + /** The formal arguments for the function we are defining */ + std::vector d_formals; + /** The formula corresponding to the body of the function we are defining */ + Expr d_formula; + /** + * Stores whether this definition is global (i.e. should persist when + * popping the user context. + */ + bool d_global; }; /* class DefineFunctionCommand */ /** @@ -470,7 +482,8 @@ class CVC4_PUBLIC DefineNamedFunctionCommand : public DefineFunctionCommand DefineNamedFunctionCommand(const std::string& id, Expr func, const std::vector& formals, - Expr formula); + Expr formula, + bool global); void invoke(SmtEngine* smtEngine) override; Command* exportTo(ExprManager* exprManager, ExprManagerMapCollection& variableMap) override; @@ -487,10 +500,12 @@ class CVC4_PUBLIC DefineFunctionRecCommand : public Command public: DefineFunctionRecCommand(Expr func, const std::vector& formals, - Expr formula); + Expr formula, + bool global); DefineFunctionRecCommand(const std::vector& funcs, const std::vector >& formals, - const std::vector& formula); + const std::vector& formula, + bool global); const std::vector& getFunctions() const; const std::vector >& getFormals() const; @@ -509,6 +524,11 @@ class CVC4_PUBLIC DefineFunctionRecCommand : public Command std::vector > d_formals; /** formulas corresponding to the bodies of the functions we are defining */ std::vector d_formulas; + /** + * Stores whether this definition is global (i.e. should persist when + * popping the user context. + */ + bool d_global; }; /* class DefineFunctionRecCommand */ /** diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 9e382cdcf..e7ef23c16 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -759,6 +759,7 @@ void SmtEngine::finishInit() // In the case of incremental solving, we appear to need these to // ensure the relevant Nodes remain live. d_assertionList = new (true) AssertionList(getUserContext()); + d_globalDefineFunRecLemmas.reset(new std::vector()); } // dump out a set-logic command only when raw-benchmark is disabled to avoid @@ -847,6 +848,8 @@ SmtEngine::~SmtEngine() d_assignments->deleteSelf(); } + d_globalDefineFunRecLemmas.reset(); + if(d_assertionList != NULL) { d_assertionList->deleteSelf(); } @@ -1179,7 +1182,8 @@ void SmtEngine::debugCheckFunctionBody(Expr formula, void SmtEngine::defineFunction(Expr func, const std::vector& formals, - Expr formula) + Expr formula, + bool global) { SmtScope smts(this); finalOptionsAreSet(); @@ -1191,7 +1195,7 @@ void SmtEngine::defineFunction(Expr func, ss << language::SetLanguage( language::SetLanguage::getLanguage(Dump.getStream())) << func; - DefineFunctionCommand c(ss.str(), func, formals, formula); + DefineFunctionCommand c(ss.str(), func, formals, formula, global); addToModelCommandAndDump( c, ExprManager::VAR_FLAG_DEFINED, true, "declarations"); @@ -1220,13 +1224,22 @@ void SmtEngine::defineFunction(Expr func, // Otherwise, (check-sat) (get-value ((! foo :named bar))) breaks // d_haveAdditions = true; Debug("smt") << "definedFunctions insert " << funcNode << " " << formNode << endl; - d_definedFunctions->insert(funcNode, def); + + if (global) + { + d_definedFunctions->insertAtContextLevelZero(funcNode, def); + } + else + { + d_definedFunctions->insert(funcNode, def); + } } void SmtEngine::defineFunctionsRec( const std::vector& funcs, - const std::vector >& formals, - const std::vector& formulas) + const std::vector>& formals, + const std::vector& formulas, + bool global) { SmtScope smts(this); finalOptionsAreSet(); @@ -1254,7 +1267,8 @@ void SmtEngine::defineFunctionsRec( if (Dump.isOn("raw-benchmark")) { - Dump("raw-benchmark") << DefineFunctionRecCommand(funcs, formals, formulas); + Dump("raw-benchmark") << DefineFunctionRecCommand( + funcs, formals, formulas, global); } ExprManager* em = getExprManager(); @@ -1294,17 +1308,28 @@ void SmtEngine::defineFunctionsRec( // notice we don't call assertFormula directly, since this would // duplicate the output on raw-benchmark. Expr e = d_private->substituteAbstractValues(Node::fromExpr(lem)).toExpr(); - if (d_assertionList != NULL) + if (d_assertionList != nullptr) { d_assertionList->push_back(e); } - d_private->addFormula(e.getNode(), false, true, false, maybeHasFv); + if (global && d_globalDefineFunRecLemmas != nullptr) + { + // Global definitions are asserted at check-sat-time because we have to + // make sure that they are always present + Assert(!language::isInputLangSygus(options::inputLanguage())); + d_globalDefineFunRecLemmas->emplace_back(Node::fromExpr(e)); + } + else + { + d_private->addFormula(e.getNode(), false, true, false, maybeHasFv); + } } } void SmtEngine::defineFunctionRec(Expr func, const std::vector& formals, - Expr formula) + Expr formula, + bool global) { std::vector funcs; funcs.push_back(func); @@ -1312,7 +1337,7 @@ void SmtEngine::defineFunctionRec(Expr func, formals_multi.push_back(formals); std::vector formulas; formulas.push_back(formula); - defineFunctionsRec(funcs, formals_multi, formulas); + defineFunctionsRec(funcs, formals_multi, formulas, global); } bool SmtEngine::isDefinedFunction( Expr func ){ @@ -1652,6 +1677,17 @@ Result SmtEngine::checkSatisfiability(const vector& assumptions, d_private->addFormula(e.getNode(), inUnsatCore, true, true); } + if (d_globalDefineFunRecLemmas != nullptr) + { + // Global definitions are asserted at check-sat-time because we have to + // make sure that they are always present (they are essentially level + // zero assertions) + for (const Node& lemma : *d_globalDefineFunRecLemmas) + { + d_private->addFormula(lemma, false, true, false, false); + } + } + r = check(); if ((options::solveRealAsInt() || options::solveIntAsBV() > 0) diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 75737b603..29d25c103 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -290,15 +290,18 @@ class CVC4_PUBLIC SmtEngine * (lambda (formals) formula) * This adds func to the list of defined functions, which indicates that * all occurrences of func should be expanded during expandDefinitions. - * This method expects input such that: - * - func : a variable of function type that expects the arguments in - * formals, - * - formals : a list of BOUND_VARIABLE expressions, - * - formula does not contain func. + * + * @param func a variable of function type that expects the arguments in + * formal + * @param formals a list of BOUND_VARIABLE expressions + * @param formula The body of the function, must not contain func + * @param global True if this definition is global (i.e. should persist when + * popping the user context) */ void defineFunction(Expr func, const std::vector& formals, - Expr formula); + Expr formula, + bool global = false); /** Return true if given expression is a defined function. */ bool isDefinedFunction(Expr func); @@ -317,17 +320,22 @@ class CVC4_PUBLIC SmtEngine * - func[i] : a variable of function type that expects the arguments in * formals[i], and * - formals[i] : a list of BOUND_VARIABLE expressions. + * + * @param global True if this definition is global (i.e. should persist when + * popping the user context) */ void defineFunctionsRec(const std::vector& funcs, - const std::vector >& formals, - const std::vector& formulas); + const std::vector>& formals, + const std::vector& formulas, + bool global = false); /** * Define function recursive * Same as above, but for a single function. */ void defineFunctionRec(Expr func, const std::vector& formals, - Expr formula); + Expr formula, + bool global = false); /** * Add a formula to the current context: preprocess, do per-theory * setup, use processAssertionList(), asserting to T-solver for @@ -862,8 +870,6 @@ class CVC4_PUBLIC SmtEngine typedef context::CDList AssertionList; /** The type of our internal assignment set */ typedef context::CDHashSet AssignmentSet; - /** The types for the recursive function definitions */ - typedef context::CDList NodeList; // disallow copy/assignment SmtEngine(const SmtEngine&) = delete; @@ -1139,10 +1145,16 @@ class CVC4_PUBLIC SmtEngine /** * The assertion list (before any conversion) for supporting - * getAssertions(). Only maintained if in interactive mode. + * getAssertions(). Only maintained if in incremental mode. */ AssertionList* d_assertionList; + /** + * List of lemmas generated for global recursive function definitions. We + * assert this list of definitions in each check-sat call. + */ + std::unique_ptr> d_globalDefineFunRecLemmas; + /** * The list of assumptions from the previous call to checkSatisfiability. * Note that if the last call to checkSatisfiability was an entailment check, diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index 4bc9d2705..e0ce456bc 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -921,6 +921,7 @@ set(regress_0_tests regress0/smtlib/issue4028.smt2 regress0/smtlib/issue4077.smt2 regress0/smtlib/issue4151.smt2 + regress0/smtlib/issue4552.smt2 regress0/smtlib/reason-unknown.smt2 regress0/smtlib/reset.smt2 regress0/smtlib/reset-assertions1.smt2 diff --git a/test/regress/regress0/smtlib/issue4552.smt2 b/test/regress/regress0/smtlib/issue4552.smt2 new file mode 100644 index 000000000..af8e0b948 --- /dev/null +++ b/test/regress/regress0/smtlib/issue4552.smt2 @@ -0,0 +1,27 @@ +; COMMAND-LINE: --incremental +; EXPECT: unsat +; EXPECT: unsat +; EXPECT: unsat +(set-logic UF) +(set-option :global-declarations true) + +(push) +(define a true) +(define (f (b Bool)) b) +(define-const a2 Bool true) + +(define-fun a3 () Bool true) + +(define-fun-rec b () Bool true) +(define-funs-rec ((g ((b Bool)) Bool)) (b)) +(assert (or (not a) (not a2) (not a3) (not b) (g false))) +(check-sat) +(pop) + +(assert (or (not a) (not a2) (not a3) (not b) (g false))) +(check-sat) + +(reset-assertions) + +(assert (or (not a) (not a2) (not a3) (not b) (g false))) +(check-sat) diff --git a/test/unit/api/solver_black.h b/test/unit/api/solver_black.h index 3dcf18f78..257c28669 100644 --- a/test/unit/api/solver_black.h +++ b/test/unit/api/solver_black.h @@ -84,8 +84,11 @@ class SolverBlack : public CxxTest::TestSuite void testDeclareSort(); void testDefineFun(); + void testDefineFunGlobal(); void testDefineFunRec(); + void testDefineFunRecGlobal(); void testDefineFunsRec(); + void testDefineFunsRecGlobal(); void testUFIteration(); @@ -1036,6 +1039,30 @@ void SolverBlack::testDefineFun() CVC4ApiException&); } +void SolverBlack::testDefineFunGlobal() +{ + Sort bSort = d_solver->getBooleanSort(); + Sort fSort = d_solver->mkFunctionSort({bSort}, bSort); + + Term bTrue = d_solver->mkBoolean(true); + // (define-fun f () Bool true) + Term f = d_solver->defineFun("f", {}, bSort, bTrue, true); + Term b = d_solver->mkVar(bSort, "b"); + Term gSym = d_solver->mkConst(fSort, "g"); + // (define-fun g (b Bool) Bool b) + Term g = d_solver->defineFun(gSym, {b}, b, true); + + // (assert (or (not f) (not (g true)))) + d_solver->assertFormula(d_solver->mkTerm( + OR, f.notTerm(), d_solver->mkTerm(APPLY_UF, g, bTrue).notTerm())); + TS_ASSERT(d_solver->checkSat().isUnsat()); + d_solver->resetAssertions(); + // (assert (or (not f) (not (g true)))) + d_solver->assertFormula(d_solver->mkTerm( + OR, f.notTerm(), d_solver->mkTerm(APPLY_UF, g, bTrue).notTerm())); + TS_ASSERT(d_solver->checkSat().isUnsat()); +} + void SolverBlack::testDefineFunRec() { Sort bvSort = d_solver->mkBitVectorSort(32); @@ -1090,6 +1117,31 @@ void SolverBlack::testDefineFunRec() CVC4ApiException&); } +void SolverBlack::testDefineFunRecGlobal() +{ + Sort bSort = d_solver->getBooleanSort(); + Sort fSort = d_solver->mkFunctionSort({bSort}, bSort); + + d_solver->push(); + Term bTrue = d_solver->mkBoolean(true); + // (define-fun f () Bool true) + Term f = d_solver->defineFunRec("f", {}, bSort, bTrue, true); + Term b = d_solver->mkVar(bSort, "b"); + Term gSym = d_solver->mkConst(fSort, "g"); + // (define-fun g (b Bool) Bool b) + Term g = d_solver->defineFunRec(gSym, {b}, b, true); + + // (assert (or (not f) (not (g true)))) + d_solver->assertFormula(d_solver->mkTerm( + OR, f.notTerm(), d_solver->mkTerm(APPLY_UF, g, bTrue).notTerm())); + TS_ASSERT(d_solver->checkSat().isUnsat()); + d_solver->pop(); + // (assert (or (not f) (not (g true)))) + d_solver->assertFormula(d_solver->mkTerm( + OR, f.notTerm(), d_solver->mkTerm(APPLY_UF, g, bTrue).notTerm())); + TS_ASSERT(d_solver->checkSat().isUnsat()); +} + void SolverBlack::testDefineFunsRec() { Sort uSort = d_solver->mkUninterpretedSort("u"); @@ -1162,6 +1214,27 @@ void SolverBlack::testDefineFunsRec() CVC4ApiException&); } +void SolverBlack::testDefineFunsRecGlobal() +{ + Sort bSort = d_solver->getBooleanSort(); + Sort fSort = d_solver->mkFunctionSort({bSort}, bSort); + + d_solver->push(); + Term bTrue = d_solver->mkBoolean(true); + Term b = d_solver->mkVar(bSort, "b"); + Term gSym = d_solver->mkConst(fSort, "g"); + // (define-funs-rec ((g ((b Bool)) Bool)) (b)) + d_solver->defineFunsRec({gSym}, {{b}}, {b}, true); + + // (assert (not (g true))) + d_solver->assertFormula(d_solver->mkTerm(APPLY_UF, gSym, bTrue).notTerm()); + TS_ASSERT(d_solver->checkSat().isUnsat()); + d_solver->pop(); + // (assert (not (g true))) + d_solver->assertFormula(d_solver->mkTerm(APPLY_UF, gSym, bTrue).notTerm()); + TS_ASSERT(d_solver->checkSat().isUnsat()); +} + void SolverBlack::testUFIteration() { Sort intSort = d_solver->getIntegerSort(); -- 2.30.2