Keep definitions when global-declarations enabled (#4572)
authorAndres Noetzli <andres.noetzli@gmail.com>
Sat, 6 Jun 2020 08:24:17 +0000 (01:24 -0700)
committerGitHub <noreply@github.com>
Sat, 6 Jun 2020 08:24:17 +0000 (01:24 -0700)
Fixes #4552. Fixes #4555. The SMT-LIB standard mandates that definitions
are kept when `:global-declarations` are enabled. Until now, CVC4 was
keeping track of the symbols of a definition correctly but lost the body
of the definition when the user context was popped. This commit fixes
the issue by adding a `global` parameter to
`SmtEngine::defineFunction()` and `SmtEngine::defineFunctionRec()`. If
that parameter is set, the definitions of functions are added at level 0
to `d_definedFunctions` and the lemmas for recursive function
definitions are kept in an additional list and asserted during each
`checkSat` call. The commit also updates new API, the commands, and the
parsers to reflect this change.

15 files changed:
NEWS
src/api/cvc4cpp.cpp
src/api/cvc4cpp.h
src/api/python/cvc4.pxd
src/api/python/cvc4.pxi
src/parser/cvc/Cvc.g
src/parser/parser.h
src/parser/smt2/Smt2.g
src/smt/command.cpp
src/smt/command.h
src/smt/smt_engine.cpp
src/smt/smt_engine.h
test/regress/CMakeLists.txt
test/regress/regress0/smtlib/issue4552.smt2 [new file with mode: 0644]
test/unit/api/solver_black.h

diff --git a/NEWS b/NEWS
index a7d6d3f40195a3e6f1dca447b95bdfd0e4866449..ac9f0747eb4225bc8d9fe3d9756a5f4df1193b18 100644 (file)
--- a/NEWS
+++ b/NEWS
@@ -3,6 +3,10 @@ This file contains a summary of important user-visible changes.
 Changes since 1.7
 =================
 
+Improvements:
+* API: Function definitions can now be requested to be global. If the `global`
+  parameter is set to true, they persist after popping the user context.
+
 Changes:
 * API change: `SmtEngine::query()` has been renamed to
   `SmtEngine::checkEntailed()` and `Result::Validity` has been renamed to
index 2c65f1ca682f62ca3b0f3da3cc24c9650b845934..88974dc69965dbf717fce1b874ea7b244690eaec 100644 (file)
@@ -4219,7 +4219,8 @@ Sort Solver::declareSort(const std::string& symbol, uint32_t arity) const
 Term Solver::defineFun(const std::string& symbol,
                        const std::vector<Term>& bound_vars,
                        Sort sort,
-                       Term term) const
+                       Term term,
+                       bool global) const
 {
   CVC4_API_SOLVER_TRY_CATCH_BEGIN;
   CVC4_API_ARG_CHECK_EXPECTED(sort.isFirstClass(), sort)
@@ -4253,14 +4254,15 @@ Term Solver::defineFun(const std::string& symbol,
   }
   Expr fun = d_exprMgr->mkVar(symbol, type);
   std::vector<Expr> ebound_vars = termVectorToExprs(bound_vars);
-  d_smtEngine->defineFunction(fun, ebound_vars, *term.d_expr);
+  d_smtEngine->defineFunction(fun, ebound_vars, *term.d_expr, global);
   return Term(this, fun);
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
 
 Term Solver::defineFun(Term fun,
                        const std::vector<Term>& bound_vars,
-                       Term term) const
+                       Term term,
+                       bool global) const
 {
   CVC4_API_SOLVER_TRY_CATCH_BEGIN;
   CVC4_API_ARG_CHECK_EXPECTED(fun.getSort().isFunction(), fun) << "function";
@@ -4293,7 +4295,7 @@ Term Solver::defineFun(Term fun,
       << codomain << "'";
 
   std::vector<Expr> ebound_vars = termVectorToExprs(bound_vars);
-  d_smtEngine->defineFunction(*fun.d_expr, ebound_vars, *term.d_expr);
+  d_smtEngine->defineFunction(*fun.d_expr, ebound_vars, *term.d_expr, global);
   return fun;
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
@@ -4304,7 +4306,8 @@ Term Solver::defineFun(Term fun,
 Term Solver::defineFunRec(const std::string& symbol,
                           const std::vector<Term>& bound_vars,
                           Sort sort,
-                          Term term) const
+                          Term term,
+                          bool global) const
 {
   CVC4_API_SOLVER_TRY_CATCH_BEGIN;
   CVC4_API_ARG_CHECK_EXPECTED(sort.isFirstClass(), sort)
@@ -4340,14 +4343,15 @@ Term Solver::defineFunRec(const std::string& symbol,
   }
   Expr fun = d_exprMgr->mkVar(symbol, type);
   std::vector<Expr> ebound_vars = termVectorToExprs(bound_vars);
-  d_smtEngine->defineFunctionRec(fun, ebound_vars, *term.d_expr);
+  d_smtEngine->defineFunctionRec(fun, ebound_vars, *term.d_expr, global);
   return Term(this, fun);
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
 
 Term Solver::defineFunRec(Term fun,
                           const std::vector<Term>& bound_vars,
-                          Term term) const
+                          Term term,
+                          bool global) const
 {
   CVC4_API_SOLVER_TRY_CATCH_BEGIN;
   CVC4_API_ARG_CHECK_EXPECTED(fun.getSort().isFunction(), fun) << "function";
@@ -4379,7 +4383,8 @@ Term Solver::defineFunRec(Term fun,
       << "Invalid sort of function body '" << term << "', expected '"
       << codomain << "'";
   std::vector<Expr> ebound_vars = termVectorToExprs(bound_vars);
-  d_smtEngine->defineFunctionRec(*fun.d_expr, ebound_vars, *term.d_expr);
+  d_smtEngine->defineFunctionRec(
+      *fun.d_expr, ebound_vars, *term.d_expr, global);
   return fun;
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
@@ -4389,7 +4394,8 @@ Term Solver::defineFunRec(Term fun,
  */
 void Solver::defineFunsRec(const std::vector<Term>& funs,
                            const std::vector<std::vector<Term>>& bound_vars,
-                           const std::vector<Term>& terms) const
+                           const std::vector<Term>& terms,
+                           bool global) const
 {
   CVC4_API_SOLVER_TRY_CATCH_BEGIN;
   size_t funs_size = funs.size();
@@ -4444,7 +4450,7 @@ void Solver::defineFunsRec(const std::vector<Term>& funs,
     ebound_vars.push_back(termVectorToExprs(v));
   }
   std::vector<Expr> exprs = termVectorToExprs(terms);
-  d_smtEngine->defineFunctionsRec(efuns, ebound_vars, exprs);
+  d_smtEngine->defineFunctionsRec(efuns, ebound_vars, exprs, global);
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
 
index adf3691ab748acb49af5477071b6402a2e8201bb..aa51a4134260f7a0c6e65ff1ec69628929f3740e 100644 (file)
@@ -2860,12 +2860,15 @@ class CVC4_PUBLIC Solver
    * @param bound_vars the parameters to this function
    * @param sort the sort of the return value of this function
    * @param term the function body
+   * @param global determines whether this definition is global (i.e. persists
+   *               when popping the context)
    * @return the function
    */
   Term defineFun(const std::string& symbol,
                  const std::vector<Term>& bound_vars,
                  Sort sort,
-                 Term term) const;
+                 Term term,
+                 bool global = false) const;
   /**
    * Define n-ary function.
    * SMT-LIB: ( define-fun <function_def> )
@@ -2873,11 +2876,14 @@ class CVC4_PUBLIC Solver
    * @param fun the sorted function
    * @param bound_vars the parameters to this function
    * @param term the function body
+   * @param global determines whether this definition is global (i.e. persists
+   *               when popping the context)
    * @return the function
    */
   Term defineFun(Term fun,
                  const std::vector<Term>& bound_vars,
-                 Term term) const;
+                 Term term,
+                 bool global = false) const;
 
   /**
    * Define recursive function.
@@ -2886,12 +2892,15 @@ class CVC4_PUBLIC Solver
    * @param bound_vars the parameters to this function
    * @param sort the sort of the return value of this function
    * @param term the function body
+   * @param global determines whether this definition is global (i.e. persists
+   *               when popping the context)
    * @return the function
    */
   Term defineFunRec(const std::string& symbol,
                     const std::vector<Term>& bound_vars,
                     Sort sort,
-                    Term term) const;
+                    Term term,
+                    bool global = false) const;
 
   /**
    * Define recursive function.
@@ -2900,11 +2909,14 @@ class CVC4_PUBLIC Solver
    * @param fun the sorted function
    * @param bound_vars the parameters to this function
    * @param term the function body
+   * @param global determines whether this definition is global (i.e. persists
+   *               when popping the context)
    * @return the function
    */
   Term defineFunRec(Term fun,
                     const std::vector<Term>& bound_vars,
-                    Term term) const;
+                    Term term,
+                    bool global = false) const;
 
   /**
    * Define recursive functions.
@@ -2913,11 +2925,14 @@ class CVC4_PUBLIC Solver
    * @param funs the sorted functions
    * @param bound_vars the list of parameters to the functions
    * @param term the list of function bodies of the functions
+   * @param global determines whether this definition is global (i.e. persists
+   *               when popping the context)
    * @return the function
    */
   void defineFunsRec(const std::vector<Term>& funs,
                      const std::vector<std::vector<Term>>& bound_vars,
-                     const std::vector<Term>& terms) const;
+                     const std::vector<Term>& terms,
+                     bool global = false) const;
 
   /**
    * Echo a given string to the given output stream.
index cc998306d7fbfdfb782972e3be95d6d9fc45355b..624b3c365f01abf92d1543af062c04418a0d7346 100644 (file)
@@ -181,14 +181,14 @@ cdef extern from "api/cvc4cpp.h" namespace "CVC4::api":
         Term declareFun(const string& symbol, const vector[Sort]& sorts, Sort sort) except +
         Sort declareSort(const string& symbol, uint32_t arity) except +
         Term defineFun(const string& symbol, const vector[Term]& bound_vars,
-                       Sort sort, Term term) except +
-        Term defineFun(Term fun, const vector[Term]& bound_vars, Term term) except +
+                       Sort sort, Term term, bint glbl) except +
+        Term defineFun(Term fun, const vector[Term]& bound_vars, Term term, bint glbl) except +
         Term defineFunRec(const string& symbol, const vector[Term]& bound_vars,
-                          Sort sort, Term term) except +
+                          Sort sort, Term term, bint glbl) except +
         Term defineFunRec(Term fun, const vector[Term]& bound_vars,
-                          Term term) except +
+                          Term term, bint glbl) except +
         Term defineFunsRec(vector[Term]& funs, vector[vector[Term]]& bound_vars,
-                           vector[Term]& terms) except +
+                           vector[Term]& terms, bint glbl) except +
         vector[Term] getAssertions() except +
         vector[pair[Term, Term]] getAssignment() except +
         string getInfo(const string& flag) except +
index 9dd9c1cde26cff4a22aed63abb1ff1fce91ded42..b7593f6f1366b168b91fbe37c7dd10214b9479b9 100644 (file)
@@ -797,13 +797,13 @@ cdef class Solver:
         sort.csort = self.csolver.declareSort(symbol.encode(), arity)
         return sort
 
-    def defineFun(self, sym_or_fun, bound_vars, sort_or_term, t=None):
+    def defineFun(self, sym_or_fun, bound_vars, sort_or_term, t=None, glbl=False):
         '''
         Supports two uses:
                 Term defineFun(str symbol, List[Term] bound_vars,
-                               Sort sort, Term term)
+                               Sort sort, Term term, bool glbl)
                 Term defineFun(Term fun, List[Term] bound_vars,
-                               Term term)
+                               Term term, bool glbl)
         '''
         cdef Term term = Term()
         cdef vector[c_Term] v
@@ -814,21 +814,23 @@ cdef class Solver:
             term.cterm = self.csolver.defineFun((<str?> sym_or_fun).encode(),
                                                 <const vector[c_Term] &> v,
                                                 (<Sort?> sort_or_term).csort,
-                                                (<Term?> t).cterm)
+                                                (<Term?> t).cterm,
+                                                <bint> glbl)
         else:
             term.cterm = self.csolver.defineFun((<Term?> sym_or_fun).cterm,
                                                 <const vector[c_Term]&> v,
-                                                (<Term?> sort_or_term).cterm)
+                                                (<Term?> sort_or_term).cterm,
+                                                <bint> glbl)
 
         return term
 
-    def defineFunRec(self, sym_or_fun, bound_vars, sort_or_term, t=None):
+    def defineFunRec(self, sym_or_fun, bound_vars, sort_or_term, t=None, glbl=False):
         '''
         Supports two uses:
                 Term defineFunRec(str symbol, List[Term] bound_vars,
-                               Sort sort, Term term)
+                               Sort sort, Term term, bool glbl)
                 Term defineFunRec(Term fun, List[Term] bound_vars,
-                               Term term)
+                               Term term, bool glbl)
         '''
         cdef Term term = Term()
         cdef vector[c_Term] v
@@ -839,11 +841,13 @@ cdef class Solver:
             term.cterm = self.csolver.defineFunRec((<str?> sym_or_fun).encode(),
                                                 <const vector[c_Term] &> v,
                                                 (<Sort?> sort_or_term).csort,
-                                                (<Term?> t).cterm)
+                                                (<Term?> t).cterm,
+                                                <bint> glbl)
         else:
             term.cterm = self.csolver.defineFunRec((<Term?> sym_or_fun).cterm,
                                                    <const vector[c_Term]&> v,
-                                                   (<Term?> sort_or_term).cterm)
+                                                   (<Term?> sort_or_term).cterm,
+                                                   <bint> glbl)
 
         return term
 
index e604c77693a1f9bc634fbf13230d6dde7ef3a897..5d04a8cc0a569882e2c4c5babf5574ae6c49c249 100644 (file)
@@ -943,7 +943,7 @@ mainCommand[std::unique_ptr<CVC4::Command>* cmd]
       cmd->reset(
           new DefineFunctionRecCommand(api::termVectorToExprs(funcs),
                                        eformals,
-                                       api::termVectorToExprs(formulas)));
+                                       api::termVectorToExprs(formulas), true));
     }
   | toplevelDeclaration[cmd]
   ;
@@ -1163,7 +1163,7 @@ declareVariables[std::unique_ptr<CVC4::Command>* cmd, CVC4::api::Sort& t,
               ExprManager::VAR_FLAG_GLOBAL | ExprManager::VAR_FLAG_DEFINED);
           PARSER_STATE->defineVar(*i, f);
           Command* decl =
-              new DefineFunctionCommand(*i, func.getExpr(), f.getExpr());
+              new DefineFunctionCommand(*i, func.getExpr(), f.getExpr(), true);
           seq->addCommand(decl);
         }
       }
index 681404efa57248086d7ca9562dc7e197b6348f19..0bdf23dcdb886d7f8d6d82f7c69d26cecf04986d 100644 (file)
@@ -812,6 +812,8 @@ public:
     d_globalDeclarations = flag;
   }
 
+  bool getGlobalDeclarations() { return d_globalDeclarations; }
+
   inline SymbolTable* getSymbolTable() const {
     return d_symtab;
   }
index 436700826722607feeb86366d4c49046dd23de85..dd261dcb6b7cf8b6b3993727707532fda1c765fc 100644 (file)
@@ -361,8 +361,12 @@ command [std::unique_ptr<CVC4::Command>* cmd]
       // we allow overloading for function definitions
       api::Term func = PARSER_STATE->bindVar(name, t,
                                       ExprManager::VAR_FLAG_DEFINED, true);
-      cmd->reset(new DefineFunctionCommand(
-          name, func.getExpr(), api::termVectorToExprs(terms), expr.getExpr()));
+      cmd->reset(
+          new DefineFunctionCommand(name,
+                                    func.getExpr(),
+                                    api::termVectorToExprs(terms),
+                                    expr.getExpr(),
+                                    PARSER_STATE->getGlobalDeclarations()));
     }
   | DECLARE_DATATYPE_TOK datatypeDefCommand[false, cmd]
   | DECLARE_DATATYPES_TOK datatypesDefCommand[false, cmd]
@@ -1204,7 +1208,7 @@ smt25Command[std::unique_ptr<CVC4::Command>* cmd]
         expr = PARSER_STATE->mkHoApply( expr, flattenVars );
       }
       cmd->reset(new DefineFunctionRecCommand(
-          func.getExpr(), api::termVectorToExprs(bvs), expr.getExpr()));
+          func.getExpr(), api::termVectorToExprs(bvs), expr.getExpr(), PARSER_STATE->getGlobalDeclarations()));
     }
   | DEFINE_FUNS_REC_TOK
     { PARSER_STATE->checkThatLogicIsSet();}
@@ -1275,7 +1279,7 @@ smt25Command[std::unique_ptr<CVC4::Command>* cmd]
       cmd->reset(
           new DefineFunctionRecCommand(api::termVectorToExprs(funcs),
                                        eformals,
-                                       api::termVectorToExprs(func_defs)));
+                                       api::termVectorToExprs(func_defs), PARSER_STATE->getGlobalDeclarations()));
     }
   ;
 
@@ -1365,14 +1369,21 @@ extendedCommand[std::unique_ptr<CVC4::Command>* cmd]
     { cmd->reset(seq.release()); }
 
   | DEFINE_TOK { PARSER_STATE->checkThatLogicIsSet(); }
-    ( symbol[name,CHECK_UNDECLARED,SYM_VARIABLE]
+    ( // (define f t)
+      symbol[name,CHECK_UNDECLARED,SYM_VARIABLE]
       { PARSER_STATE->checkUserSymbol(name); }
       term[e,e2]
-      { api::Term func = PARSER_STATE->bindVar(name, e.getSort(),
+      {
+        api::Term func = PARSER_STATE->bindVar(name, e.getSort(),
                                         ExprManager::VAR_FLAG_DEFINED);
-        cmd->reset(new DefineFunctionCommand(name, func.getExpr(), e.getExpr()));
+        cmd->reset(
+            new DefineFunctionCommand(name,
+                                      func.getExpr(),
+                                      e.getExpr(),
+                                      PARSER_STATE->getGlobalDeclarations()));
       }
-    | LPAREN_TOK
+    | // (define (f (v U) ...) t)
+      LPAREN_TOK
       symbol[name,CHECK_UNDECLARED,SYM_VARIABLE]
       { PARSER_STATE->checkUserSymbol(name); }
       sortedVarList[sortedVarNames] RPAREN_TOK
@@ -1382,7 +1393,8 @@ extendedCommand[std::unique_ptr<CVC4::Command>* cmd]
         terms = PARSER_STATE->bindBoundVars(sortedVarNames);
       }
       term[e,e2]
-      { PARSER_STATE->popScope();
+      {
+        PARSER_STATE->popScope();
         // declare the name down here (while parsing term, signature
         // must not be extended with the name itself; no recursion
         // permitted)
@@ -1398,11 +1410,16 @@ extendedCommand[std::unique_ptr<CVC4::Command>* cmd]
         }
         api::Term func = PARSER_STATE->bindVar(name, tt,
                                         ExprManager::VAR_FLAG_DEFINED);
-        cmd->reset(new DefineFunctionCommand(
-            name, func.getExpr(), api::termVectorToExprs(terms), e.getExpr()));
+        cmd->reset(
+            new DefineFunctionCommand(name,
+                                      func.getExpr(),
+                                      api::termVectorToExprs(terms),
+                                      e.getExpr(),
+                                      PARSER_STATE->getGlobalDeclarations()));
       }
     )
-  | DEFINE_CONST_TOK { PARSER_STATE->checkThatLogicIsSet(); }
+  | // (define-const x U t)
+    DEFINE_CONST_TOK { PARSER_STATE->checkThatLogicIsSet(); }
     symbol[name,CHECK_UNDECLARED,SYM_VARIABLE]
     { PARSER_STATE->checkUserSymbol(name); }
     sortSymbol[t,CHECK_DECLARED]
@@ -1412,14 +1429,19 @@ extendedCommand[std::unique_ptr<CVC4::Command>* cmd]
       terms = PARSER_STATE->bindBoundVars(sortedVarNames);
     }
     term[e, e2]
-    { PARSER_STATE->popScope();
+    {
+      PARSER_STATE->popScope();
       // declare the name down here (while parsing term, signature
       // must not be extended with the name itself; no recursion
       // permitted)
       api::Term func = PARSER_STATE->bindVar(name, t,
                                       ExprManager::VAR_FLAG_DEFINED);
-      cmd->reset(new DefineFunctionCommand(
-          name, func.getExpr(), api::termVectorToExprs(terms), e.getExpr()));
+      cmd->reset(
+          new DefineFunctionCommand(name,
+                                    func.getExpr(),
+                                    api::termVectorToExprs(terms),
+                                    e.getExpr(),
+                                    PARSER_STATE->getGlobalDeclarations()));
     }
 
   | SIMPLIFY_TOK { PARSER_STATE->checkThatLogicIsSet(); }
@@ -2217,7 +2239,7 @@ attribute[CVC4::api::Term& expr, CVC4::api::Term& retExpr, std::string& attr]
       std::string name = sexpr.getValue();
       // bind name to expr with define-fun
       Command* c = new DefineNamedFunctionCommand(
-          name, func.getExpr(), std::vector<Expr>(), expr.getExpr());
+          name, func.getExpr(), std::vector<Expr>(), expr.getExpr(), PARSER_STATE->getGlobalDeclarations());
       c->setMuted(true);
       PARSER_STATE->preemptCommand(c);
     }
index 20f2dcff9fade82eba2ffebbbbabf601cd35601f..9fd0122fce0c4dabfd86d4f274d8efe75339f272 100644 (file)
@@ -1266,22 +1266,27 @@ std::string DefineTypeCommand::getCommandName() const { return "define-sort"; }
 
 DefineFunctionCommand::DefineFunctionCommand(const std::string& id,
                                              Expr func,
-                                             Expr formula)
+                                             Expr formula,
+                                             bool global)
     : DeclarationDefinitionCommand(id),
       d_func(func),
       d_formals(),
-      d_formula(formula)
+      d_formula(formula),
+      d_global(global)
 {
 }
 
 DefineFunctionCommand::DefineFunctionCommand(const std::string& id,
                                              Expr func,
                                              const std::vector<Expr>& formals,
-                                             Expr formula)
+                                             Expr formula,
+                                             bool global)
     : DeclarationDefinitionCommand(id),
       d_func(func),
       d_formals(formals),
-      d_formula(formula)
+      d_formula(formula),
+      d_global(global)
+
 {
 }
 
@@ -1298,7 +1303,7 @@ void DefineFunctionCommand::invoke(SmtEngine* smtEngine)
   {
     if (!d_func.isNull())
     {
-      smtEngine->defineFunction(d_func, d_formals, d_formula);
+      smtEngine->defineFunction(d_func, d_formals, d_formula, d_global);
     }
     d_commandStatus = CommandSuccess::instance();
   }
@@ -1319,12 +1324,13 @@ Command* DefineFunctionCommand::exportTo(ExprManager* exprManager,
             back_inserter(formals),
             ExportTransformer(exprManager, variableMap));
   Expr formula = d_formula.exportTo(exprManager, variableMap);
-  return new DefineFunctionCommand(d_symbol, func, formals, formula);
+  return new DefineFunctionCommand(d_symbol, func, formals, formula, d_global);
 }
 
 Command* DefineFunctionCommand::clone() const
 {
-  return new DefineFunctionCommand(d_symbol, d_func, d_formals, d_formula);
+  return new DefineFunctionCommand(
+      d_symbol, d_func, d_formals, d_formula, d_global);
 }
 
 std::string DefineFunctionCommand::getCommandName() const
@@ -1340,8 +1346,9 @@ DefineNamedFunctionCommand::DefineNamedFunctionCommand(
     const std::string& id,
     Expr func,
     const std::vector<Expr>& formals,
-    Expr formula)
-    : DefineFunctionCommand(id, func, formals, formula)
+    Expr formula,
+    bool global)
+    : DefineFunctionCommand(id, func, formals, formula, global)
 {
 }
 
@@ -1365,12 +1372,14 @@ Command* DefineNamedFunctionCommand::exportTo(
             back_inserter(formals),
             ExportTransformer(exprManager, variableMap));
   Expr formula = d_formula.exportTo(exprManager, variableMap);
-  return new DefineNamedFunctionCommand(d_symbol, func, formals, formula);
+  return new DefineNamedFunctionCommand(
+      d_symbol, func, formals, formula, d_global);
 }
 
 Command* DefineNamedFunctionCommand::clone() const
 {
-  return new DefineNamedFunctionCommand(d_symbol, d_func, d_formals, d_formula);
+  return new DefineNamedFunctionCommand(
+      d_symbol, d_func, d_formals, d_formula, d_global);
 }
 
 /* -------------------------------------------------------------------------- */
@@ -1378,7 +1387,8 @@ Command* DefineNamedFunctionCommand::clone() const
 /* -------------------------------------------------------------------------- */
 
 DefineFunctionRecCommand::DefineFunctionRecCommand(
-    Expr func, const std::vector<Expr>& formals, Expr formula)
+    Expr func, const std::vector<Expr>& formals, Expr formula, bool global)
+    : d_global(global)
 {
   d_funcs.push_back(func);
   d_formals.push_back(formals);
@@ -1388,11 +1398,10 @@ DefineFunctionRecCommand::DefineFunctionRecCommand(
 DefineFunctionRecCommand::DefineFunctionRecCommand(
     const std::vector<Expr>& funcs,
     const std::vector<std::vector<Expr>>& formals,
-    const std::vector<Expr>& formulas)
+    const std::vector<Expr>& formulas,
+    bool global)
+    : d_funcs(funcs), d_formals(formals), d_formulas(formulas), d_global(global)
 {
-  d_funcs.insert(d_funcs.end(), funcs.begin(), funcs.end());
-  d_formals.insert(d_formals.end(), formals.begin(), formals.end());
-  d_formulas.insert(d_formulas.end(), formulas.begin(), formulas.end());
 }
 
 const std::vector<Expr>& DefineFunctionRecCommand::getFunctions() const
@@ -1415,7 +1424,7 @@ void DefineFunctionRecCommand::invoke(SmtEngine* smtEngine)
 {
   try
   {
-    smtEngine->defineFunctionsRec(d_funcs, d_formals, d_formulas);
+    smtEngine->defineFunctionsRec(d_funcs, d_formals, d_formulas, d_global);
     d_commandStatus = CommandSuccess::instance();
   }
   catch (exception& e)
@@ -1450,12 +1459,12 @@ Command* DefineFunctionRecCommand::exportTo(
     Expr formula = d_formulas[i].exportTo(exprManager, variableMap);
     formulas.push_back(formula);
   }
-  return new DefineFunctionRecCommand(funcs, formals, formulas);
+  return new DefineFunctionRecCommand(funcs, formals, formulas, d_global);
 }
 
 Command* DefineFunctionRecCommand::clone() const
 {
-  return new DefineFunctionRecCommand(d_funcs, d_formals, d_formulas);
+  return new DefineFunctionRecCommand(d_funcs, d_formals, d_formulas, d_global);
 }
 
 std::string DefineFunctionRecCommand::getCommandName() const
index 63f1f0f33c083620498530c056b8730e5ddece5d..0582cee34095a4cd23982866e3c4d75910c5727d 100644 (file)
@@ -436,17 +436,16 @@ class CVC4_PUBLIC DefineTypeCommand : public DeclarationDefinitionCommand
 
 class CVC4_PUBLIC DefineFunctionCommand : public DeclarationDefinitionCommand
 {
- protected:
-  Expr d_func;
-  std::vector<Expr> d_formals;
-  Expr d_formula;
-
  public:
-  DefineFunctionCommand(const std::string& id, Expr func, Expr formula);
+  DefineFunctionCommand(const std::string& id,
+                        Expr func,
+                        Expr formula,
+                        bool global);
   DefineFunctionCommand(const std::string& id,
                         Expr func,
                         const std::vector<Expr>& formals,
-                        Expr formula);
+                        Expr formula,
+                        bool global);
 
   Expr getFunction() const;
   const std::vector<Expr>& getFormals() const;
@@ -457,6 +456,19 @@ class CVC4_PUBLIC DefineFunctionCommand : public DeclarationDefinitionCommand
                     ExprManagerMapCollection& variableMap) override;
   Command* clone() const override;
   std::string getCommandName() const override;
+
+ protected:
+  /** The function we are defining */
+  Expr d_func;
+  /** The formal arguments for the function we are defining */
+  std::vector<Expr> d_formals;
+  /** The formula corresponding to the body of the function we are defining */
+  Expr d_formula;
+  /**
+   * Stores whether this definition is global (i.e. should persist when
+   * popping the user context.
+   */
+  bool d_global;
 }; /* class DefineFunctionCommand */
 
 /**
@@ -470,7 +482,8 @@ class CVC4_PUBLIC DefineNamedFunctionCommand : public DefineFunctionCommand
   DefineNamedFunctionCommand(const std::string& id,
                              Expr func,
                              const std::vector<Expr>& formals,
-                             Expr formula);
+                             Expr formula,
+                             bool global);
   void invoke(SmtEngine* smtEngine) override;
   Command* exportTo(ExprManager* exprManager,
                     ExprManagerMapCollection& variableMap) override;
@@ -487,10 +500,12 @@ class CVC4_PUBLIC DefineFunctionRecCommand : public Command
  public:
   DefineFunctionRecCommand(Expr func,
                            const std::vector<Expr>& formals,
-                           Expr formula);
+                           Expr formula,
+                           bool global);
   DefineFunctionRecCommand(const std::vector<Expr>& funcs,
                            const std::vector<std::vector<Expr> >& formals,
-                           const std::vector<Expr>& formula);
+                           const std::vector<Expr>& formula,
+                           bool global);
 
   const std::vector<Expr>& getFunctions() const;
   const std::vector<std::vector<Expr> >& getFormals() const;
@@ -509,6 +524,11 @@ class CVC4_PUBLIC DefineFunctionRecCommand : public Command
   std::vector<std::vector<Expr> > d_formals;
   /** formulas corresponding to the bodies of the functions we are defining */
   std::vector<Expr> d_formulas;
+  /**
+   * Stores whether this definition is global (i.e. should persist when
+   * popping the user context.
+   */
+  bool d_global;
 }; /* class DefineFunctionRecCommand */
 
 /**
index 9e382cdcf70f75c9c0bd154dc30100b1548c6b2c..e7ef23c160155de9bc0b74c718333ff70e380be3 100644 (file)
@@ -759,6 +759,7 @@ void SmtEngine::finishInit()
     // In the case of incremental solving, we appear to need these to
     // ensure the relevant Nodes remain live.
     d_assertionList = new (true) AssertionList(getUserContext());
+    d_globalDefineFunRecLemmas.reset(new std::vector<Node>());
   }
 
   // dump out a set-logic command only when raw-benchmark is disabled to avoid
@@ -847,6 +848,8 @@ SmtEngine::~SmtEngine()
       d_assignments->deleteSelf();
     }
 
+    d_globalDefineFunRecLemmas.reset();
+
     if(d_assertionList != NULL) {
       d_assertionList->deleteSelf();
     }
@@ -1179,7 +1182,8 @@ void SmtEngine::debugCheckFunctionBody(Expr formula,
 
 void SmtEngine::defineFunction(Expr func,
                                const std::vector<Expr>& formals,
-                               Expr formula)
+                               Expr formula,
+                               bool global)
 {
   SmtScope smts(this);
   finalOptionsAreSet();
@@ -1191,7 +1195,7 @@ void SmtEngine::defineFunction(Expr func,
   ss << language::SetLanguage(
             language::SetLanguage::getLanguage(Dump.getStream()))
      << func;
-  DefineFunctionCommand c(ss.str(), func, formals, formula);
+  DefineFunctionCommand c(ss.str(), func, formals, formula, global);
   addToModelCommandAndDump(
       c, ExprManager::VAR_FLAG_DEFINED, true, "declarations");
 
@@ -1220,13 +1224,22 @@ void SmtEngine::defineFunction(Expr func,
   // Otherwise, (check-sat) (get-value ((! foo :named bar))) breaks
   // d_haveAdditions = true;
   Debug("smt") << "definedFunctions insert " << funcNode << " " << formNode << endl;
-  d_definedFunctions->insert(funcNode, def);
+
+  if (global)
+  {
+    d_definedFunctions->insertAtContextLevelZero(funcNode, def);
+  }
+  else
+  {
+    d_definedFunctions->insert(funcNode, def);
+  }
 }
 
 void SmtEngine::defineFunctionsRec(
     const std::vector<Expr>& funcs,
-    const std::vector<std::vector<Expr> >& formals,
-    const std::vector<Expr>& formulas)
+    const std::vector<std::vector<Expr>>& formals,
+    const std::vector<Expr>& formulas,
+    bool global)
 {
   SmtScope smts(this);
   finalOptionsAreSet();
@@ -1254,7 +1267,8 @@ void SmtEngine::defineFunctionsRec(
 
   if (Dump.isOn("raw-benchmark"))
   {
-    Dump("raw-benchmark") << DefineFunctionRecCommand(funcs, formals, formulas);
+    Dump("raw-benchmark") << DefineFunctionRecCommand(
+        funcs, formals, formulas, global);
   }
 
   ExprManager* em = getExprManager();
@@ -1294,17 +1308,28 @@ void SmtEngine::defineFunctionsRec(
     //   notice we don't call assertFormula directly, since this would
     //   duplicate the output on raw-benchmark.
     Expr e = d_private->substituteAbstractValues(Node::fromExpr(lem)).toExpr();
-    if (d_assertionList != NULL)
+    if (d_assertionList != nullptr)
     {
       d_assertionList->push_back(e);
     }
-    d_private->addFormula(e.getNode(), false, true, false, maybeHasFv);
+    if (global && d_globalDefineFunRecLemmas != nullptr)
+    {
+      // Global definitions are asserted at check-sat-time because we have to
+      // make sure that they are always present
+      Assert(!language::isInputLangSygus(options::inputLanguage()));
+      d_globalDefineFunRecLemmas->emplace_back(Node::fromExpr(e));
+    }
+    else
+    {
+      d_private->addFormula(e.getNode(), false, true, false, maybeHasFv);
+    }
   }
 }
 
 void SmtEngine::defineFunctionRec(Expr func,
                                   const std::vector<Expr>& formals,
-                                  Expr formula)
+                                  Expr formula,
+                                  bool global)
 {
   std::vector<Expr> funcs;
   funcs.push_back(func);
@@ -1312,7 +1337,7 @@ void SmtEngine::defineFunctionRec(Expr func,
   formals_multi.push_back(formals);
   std::vector<Expr> formulas;
   formulas.push_back(formula);
-  defineFunctionsRec(funcs, formals_multi, formulas);
+  defineFunctionsRec(funcs, formals_multi, formulas, global);
 }
 
 bool SmtEngine::isDefinedFunction( Expr func ){
@@ -1652,6 +1677,17 @@ Result SmtEngine::checkSatisfiability(const vector<Expr>& assumptions,
       d_private->addFormula(e.getNode(), inUnsatCore, true, true);
     }
 
+    if (d_globalDefineFunRecLemmas != nullptr)
+    {
+      // Global definitions are asserted at check-sat-time because we have to
+      // make sure that they are always present (they are essentially level
+      // zero assertions)
+      for (const Node& lemma : *d_globalDefineFunRecLemmas)
+      {
+        d_private->addFormula(lemma, false, true, false, false);
+      }
+    }
+
     r = check();
 
     if ((options::solveRealAsInt() || options::solveIntAsBV() > 0)
index 75737b603f12b7fefabc1fdb44ccb943737f6045..29d25c10330a8e5f04869b7f2a147ad6c13351b4 100644 (file)
@@ -290,15 +290,18 @@ class CVC4_PUBLIC SmtEngine
    *   (lambda (formals) formula)
    * This adds func to the list of defined functions, which indicates that
    * all occurrences of func should be expanded during expandDefinitions.
-   * This method expects input such that:
-   * - func : a variable of function type that expects the arguments in
-   *          formals,
-   * - formals : a list of BOUND_VARIABLE expressions,
-   * - formula does not contain func.
+   *
+   * @param func a variable of function type that expects the arguments in
+   *             formal
+   * @param formals a list of BOUND_VARIABLE expressions
+   * @param formula The body of the function, must not contain func
+   * @param global True if this definition is global (i.e. should persist when
+   *               popping the user context)
    */
   void defineFunction(Expr func,
                       const std::vector<Expr>& formals,
-                      Expr formula);
+                      Expr formula,
+                      bool global = false);
 
   /** Return true if given expression is a defined function. */
   bool isDefinedFunction(Expr func);
@@ -317,17 +320,22 @@ class CVC4_PUBLIC SmtEngine
    * - func[i] : a variable of function type that expects the arguments in
    *             formals[i], and
    * - formals[i] : a list of BOUND_VARIABLE expressions.
+   *
+   * @param global True if this definition is global (i.e. should persist when
+   *               popping the user context)
    */
   void defineFunctionsRec(const std::vector<Expr>& funcs,
-                          const std::vector<std::vector<Expr> >& formals,
-                          const std::vector<Expr>& formulas);
+                          const std::vector<std::vector<Expr>>& formals,
+                          const std::vector<Expr>& formulas,
+                          bool global = false);
   /**
    * Define function recursive
    * Same as above, but for a single function.
    */
   void defineFunctionRec(Expr func,
                          const std::vector<Expr>& formals,
-                         Expr formula);
+                         Expr formula,
+                         bool global = false);
   /**
    * Add a formula to the current context: preprocess, do per-theory
    * setup, use processAssertionList(), asserting to T-solver for
@@ -862,8 +870,6 @@ class CVC4_PUBLIC SmtEngine
   typedef context::CDList<Expr> AssertionList;
   /** The type of our internal assignment set */
   typedef context::CDHashSet<Node, NodeHashFunction> AssignmentSet;
-  /** The types for the recursive function definitions */
-  typedef context::CDList<Node> NodeList;
 
   // disallow copy/assignment
   SmtEngine(const SmtEngine&) = delete;
@@ -1139,10 +1145,16 @@ class CVC4_PUBLIC SmtEngine
 
   /**
    * The assertion list (before any conversion) for supporting
-   * getAssertions().  Only maintained if in interactive mode.
+   * getAssertions().  Only maintained if in incremental mode.
    */
   AssertionList* d_assertionList;
 
+  /**
+   * List of lemmas generated for global recursive function definitions. We
+   * assert this list of definitions in each check-sat call.
+   */
+  std::unique_ptr<std::vector<Node>> d_globalDefineFunRecLemmas;
+
   /**
    * The list of assumptions from the previous call to checkSatisfiability.
    * Note that if the last call to checkSatisfiability was an entailment check,
index 4bc9d2705ee70fdb83ad4760c80d3ba28b47319d..e0ce456bc6cd19bec87ee84d3fece01bb37ae343 100644 (file)
@@ -921,6 +921,7 @@ set(regress_0_tests
   regress0/smtlib/issue4028.smt2
   regress0/smtlib/issue4077.smt2
   regress0/smtlib/issue4151.smt2
+  regress0/smtlib/issue4552.smt2
   regress0/smtlib/reason-unknown.smt2
   regress0/smtlib/reset.smt2
   regress0/smtlib/reset-assertions1.smt2
diff --git a/test/regress/regress0/smtlib/issue4552.smt2 b/test/regress/regress0/smtlib/issue4552.smt2
new file mode 100644 (file)
index 0000000..af8e0b9
--- /dev/null
@@ -0,0 +1,27 @@
+; COMMAND-LINE: --incremental
+; EXPECT: unsat
+; EXPECT: unsat
+; EXPECT: unsat
+(set-logic UF)
+(set-option :global-declarations true)
+
+(push)
+(define a true)
+(define (f (b Bool)) b)
+(define-const a2 Bool true)
+
+(define-fun a3 () Bool true)
+
+(define-fun-rec b () Bool true)
+(define-funs-rec ((g ((b Bool)) Bool)) (b))
+(assert (or (not a) (not a2) (not a3) (not b) (g false)))
+(check-sat)
+(pop)
+
+(assert (or (not a) (not a2) (not a3) (not b) (g false)))
+(check-sat)
+
+(reset-assertions)
+
+(assert (or (not a) (not a2) (not a3) (not b) (g false)))
+(check-sat)
index 3dcf18f788177246d12247b94149ffeb75cfc6d5..257c286694c94e3ffebb306996679a4080c11200 100644 (file)
@@ -84,8 +84,11 @@ class SolverBlack : public CxxTest::TestSuite
   void testDeclareSort();
 
   void testDefineFun();
+  void testDefineFunGlobal();
   void testDefineFunRec();
+  void testDefineFunRecGlobal();
   void testDefineFunsRec();
+  void testDefineFunsRecGlobal();
 
   void testUFIteration();
 
@@ -1036,6 +1039,30 @@ void SolverBlack::testDefineFun()
                    CVC4ApiException&);
 }
 
+void SolverBlack::testDefineFunGlobal()
+{
+  Sort bSort = d_solver->getBooleanSort();
+  Sort fSort = d_solver->mkFunctionSort({bSort}, bSort);
+
+  Term bTrue = d_solver->mkBoolean(true);
+  // (define-fun f () Bool true)
+  Term f = d_solver->defineFun("f", {}, bSort, bTrue, true);
+  Term b = d_solver->mkVar(bSort, "b");
+  Term gSym = d_solver->mkConst(fSort, "g");
+  // (define-fun g (b Bool) Bool b)
+  Term g = d_solver->defineFun(gSym, {b}, b, true);
+
+  // (assert (or (not f) (not (g true))))
+  d_solver->assertFormula(d_solver->mkTerm(
+      OR, f.notTerm(), d_solver->mkTerm(APPLY_UF, g, bTrue).notTerm()));
+  TS_ASSERT(d_solver->checkSat().isUnsat());
+  d_solver->resetAssertions();
+  // (assert (or (not f) (not (g true))))
+  d_solver->assertFormula(d_solver->mkTerm(
+      OR, f.notTerm(), d_solver->mkTerm(APPLY_UF, g, bTrue).notTerm()));
+  TS_ASSERT(d_solver->checkSat().isUnsat());
+}
+
 void SolverBlack::testDefineFunRec()
 {
   Sort bvSort = d_solver->mkBitVectorSort(32);
@@ -1090,6 +1117,31 @@ void SolverBlack::testDefineFunRec()
                    CVC4ApiException&);
 }
 
+void SolverBlack::testDefineFunRecGlobal()
+{
+  Sort bSort = d_solver->getBooleanSort();
+  Sort fSort = d_solver->mkFunctionSort({bSort}, bSort);
+
+  d_solver->push();
+  Term bTrue = d_solver->mkBoolean(true);
+  // (define-fun f () Bool true)
+  Term f = d_solver->defineFunRec("f", {}, bSort, bTrue, true);
+  Term b = d_solver->mkVar(bSort, "b");
+  Term gSym = d_solver->mkConst(fSort, "g");
+  // (define-fun g (b Bool) Bool b)
+  Term g = d_solver->defineFunRec(gSym, {b}, b, true);
+
+  // (assert (or (not f) (not (g true))))
+  d_solver->assertFormula(d_solver->mkTerm(
+      OR, f.notTerm(), d_solver->mkTerm(APPLY_UF, g, bTrue).notTerm()));
+  TS_ASSERT(d_solver->checkSat().isUnsat());
+  d_solver->pop();
+  // (assert (or (not f) (not (g true))))
+  d_solver->assertFormula(d_solver->mkTerm(
+      OR, f.notTerm(), d_solver->mkTerm(APPLY_UF, g, bTrue).notTerm()));
+  TS_ASSERT(d_solver->checkSat().isUnsat());
+}
+
 void SolverBlack::testDefineFunsRec()
 {
   Sort uSort = d_solver->mkUninterpretedSort("u");
@@ -1162,6 +1214,27 @@ void SolverBlack::testDefineFunsRec()
       CVC4ApiException&);
 }
 
+void SolverBlack::testDefineFunsRecGlobal()
+{
+  Sort bSort = d_solver->getBooleanSort();
+  Sort fSort = d_solver->mkFunctionSort({bSort}, bSort);
+
+  d_solver->push();
+  Term bTrue = d_solver->mkBoolean(true);
+  Term b = d_solver->mkVar(bSort, "b");
+  Term gSym = d_solver->mkConst(fSort, "g");
+  // (define-funs-rec ((g ((b Bool)) Bool)) (b))
+  d_solver->defineFunsRec({gSym}, {{b}}, {b}, true);
+
+  // (assert (not (g true)))
+  d_solver->assertFormula(d_solver->mkTerm(APPLY_UF, gSym, bTrue).notTerm());
+  TS_ASSERT(d_solver->checkSat().isUnsat());
+  d_solver->pop();
+  // (assert (not (g true)))
+  d_solver->assertFormula(d_solver->mkTerm(APPLY_UF, gSym, bTrue).notTerm());
+  TS_ASSERT(d_solver->checkSat().isUnsat());
+}
+
 void SolverBlack::testUFIteration()
 {
   Sort intSort = d_solver->getIntegerSort();