Add a `define-fun` command for each `:named` term. (#7308)
authorAbdalrhman Mohamed <32971963+abdoo8080@users.noreply.github.com>
Thu, 28 Oct 2021 17:04:06 +0000 (12:04 -0500)
committerGitHub <noreply@github.com>
Thu, 28 Oct 2021 17:04:06 +0000 (17:04 +0000)
This PR is step towards enabling -o raw-benchmark for regressions. It creates a define-fun command for each named term. This allows us to reparse dumped benchmarks containing named terms, but we still lose track of those terms and do not print them in response to (get-assignment) and (get-unsat-core) commands. This PR also simplifies the interface for DefineFunCommand interface and removes support for (define ...) command.

18 files changed:
NEWS
examples/api/python/sygus-inv.py
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5.h
src/api/cpp/cvc5_checks.h
src/api/java/io/github/cvc5/api/Solver.java
src/api/java/jni/solver.cpp
src/api/python/cvc5.pxd
src/api/python/cvc5.pxi
src/main/command_executor.cpp
src/parser/smt2/Smt2.g
src/smt/command.cpp
src/smt/command.h
test/python/unit/api/test_solver.py
test/regress/regress0/push-pop/inc-define.smt2
test/regress/regress0/smtlib/issue4552.smt2
test/unit/api/java/SolverTest.java
test/unit/api/solver_black.cpp

diff --git a/NEWS b/NEWS
index f0164364392a1e4a80673181a931e6e5b7337bcf..c9716b33f9e67e6e6042f3d05df24223549733de 100644 (file)
--- a/NEWS
+++ b/NEWS
@@ -31,7 +31,8 @@ Improvements:
 
 Changes:
 * SyGuS: Removed support for SyGuS-IF 1.0.
-* Removed Java and Python bindings for the legacy API
+* Removed support for the (non-standard) `define` command.
+* Removed Java and Python bindings for the legacy API.
 * Interactive shell: the GPL-licensed Readline library has been replaced the
   BSD-licensed Editline. Compiling with `--best` now enables Editline, instead
   of Readline. Without selecting optional GPL components, Editline-enabled CVC4
index 8273aa298f332003a19bf5fb2387bf3d635943de..50fa3f04ff3eb3fb96ccb4dd4fa83888090e8c23 100644 (file)
@@ -48,9 +48,9 @@ if __name__ == "__main__":
                         slv.mkTerm(kinds.Equal, xp, x))
 
   # define the pre-conditions, transition relations, and post-conditions
-  pre_f = slv.defineFun("pre-f", {x}, boolean, slv.mkTerm(kinds.Equal, x, zero))
-  trans_f = slv.defineFun("trans-f", {x, xp}, boolean, ite)
-  post_f = slv.defineFun("post-f", {x}, boolean, slv.mkTerm(kinds.Leq, x, ten))
+  pre_f = slv.defineFun("pre-f", [x], boolean, slv.mkTerm(kinds.Equal, x, zero))
+  trans_f = slv.defineFun("trans-f", [x, xp], boolean, ite)
+  post_f = slv.defineFun("post-f", [x], boolean, slv.mkTerm(kinds.Leq, x, ten))
 
   # declare the invariant-to-synthesize
   inv_f = slv.synthInv("inv-f", {x})
index a342cea53ec9efdf526c56a51a54663e32e43fbf..addfeb0dae61baffb90a699a0f471e33d15b053f 100644 (file)
@@ -6716,7 +6716,7 @@ Term Solver::defineFun(const std::string& symbol,
   CVC5_API_TRY_CATCH_BEGIN;
   CVC5_API_SOLVER_CHECK_CODOMAIN_SORT(sort);
   CVC5_API_SOLVER_CHECK_TERM(term);
-  CVC5_API_CHECK(sort == term.getSort())
+  CVC5_API_CHECK(term.getSort().isSubsortOf(sort))
       << "Invalid sort of function body '" << term << "', expected '" << sort
       << "'";
 
@@ -6743,37 +6743,6 @@ Term Solver::defineFun(const std::string& symbol,
   CVC5_API_TRY_CATCH_END;
 }
 
-Term Solver::defineFun(const Term& fun,
-                       const std::vector<Term>& bound_vars,
-                       const Term& term,
-                       bool global) const
-{
-  CVC5_API_TRY_CATCH_BEGIN;
-  CVC5_API_SOLVER_CHECK_TERM(fun);
-  CVC5_API_SOLVER_CHECK_TERM(term);
-  if (fun.getSort().isFunction())
-  {
-    std::vector<Sort> domain_sorts = fun.getSort().getFunctionDomainSorts();
-    CVC5_API_SOLVER_CHECK_BOUND_VARS_DEF_FUN(fun, bound_vars, domain_sorts);
-    Sort codomain = fun.getSort().getFunctionCodomainSort();
-    CVC5_API_CHECK(codomain == term.getSort())
-        << "Invalid sort of function body '" << term << "', expected '"
-        << codomain << "'";
-  }
-  else
-  {
-    CVC5_API_SOLVER_CHECK_BOUND_VARS(bound_vars);
-    CVC5_API_ARG_CHECK_EXPECTED(bound_vars.size() == 0, fun)
-        << "function or nullary symbol";
-  }
-  //////// all checks before this line
-  std::vector<Node> ebound_vars = Term::termVectorToNodes(bound_vars);
-  d_slv->defineFunction(*fun.d_node, ebound_vars, *term.d_node, global);
-  return fun;
-  ////////
-  CVC5_API_TRY_CATCH_END;
-}
-
 Term Solver::defineFunRec(const std::string& symbol,
                           const std::vector<Term>& bound_vars,
                           const Sort& sort,
index f0a36b79298733138044abdeaa5fa3d2b771d491..63dadaef630b86e0214993972bea60ba15185549 100644 (file)
@@ -3813,24 +3813,6 @@ class CVC5_EXPORT Solver
                  const Sort& sort,
                  const Term& term,
                  bool global = false) const;
-  /**
-   * Define n-ary function.
-   * SMT-LIB:
-   * \verbatim
-   * ( define-fun <function_def> )
-   * \endverbatim
-   * Create parameter 'fun' with mkConst().
-   * @param fun the sorted function
-   * @param bound_vars the parameters to this function
-   * @param term the function body
-   * @param global determines whether this definition is global (i.e. persists
-   *               when popping the context)
-   * @return the function
-   */
-  Term defineFun(const Term& fun,
-                 const std::vector<Term>& bound_vars,
-                 const Term& term,
-                 bool global = false) const;
 
   /**
    * Define recursive function.
index c30237ecdba454fdf2564c5a8971f31ac5072434..35c21df9c9e0e8c1deae07ca8b0323b8cf469ca7 100644 (file)
@@ -438,8 +438,6 @@ namespace api {
     CVC5_API_ARG_CHECK_NOT_NULL(sort);                      \
     CVC5_API_CHECK(this == sort.d_solver)                   \
         << "Given sort is not associated with this solver"; \
-    CVC5_API_ARG_CHECK_EXPECTED(sort.isFirstClass(), sort)  \
-        << "first-class sort as codomain sort";             \
     CVC5_API_ARG_CHECK_EXPECTED(!sort.isFunction(), sort)   \
         << "function sort as codomain sort";                \
   } while (0)
index 3a9f450a438f359ee06d158c4e6847392cabedd9..c57ba48b8ce5f0fb6f0a6348a9f8d434288757d3 100644 (file)
@@ -1593,47 +1593,6 @@ public class Solver implements IPointer
       long termPointer,
       boolean global);
 
-  /**
-   * Define n-ary function in the current context.
-   * SMT-LIB:
-   * {@code
-   * ( define-fun <function_def> )
-   * }
-   * Create parameter 'fun' with mkConst().
-   * @param fun the sorted function
-   * @param boundVars the parameters to this function
-   * @param term the function body
-   * @return the function
-   */
-  public Term defineFun(Term fun, Term[] boundVars, Term term)
-  {
-    return defineFun(fun, boundVars, term, false);
-  }
-  /**
-   * Define n-ary function.
-   * SMT-LIB:
-   * {@code
-   * ( define-fun <function_def> )
-   * }
-   * Create parameter 'fun' with mkConst().
-   * @param fun the sorted function
-   * @param boundVars the parameters to this function
-   * @param term the function body
-   * @param global determines whether this definition is global (i.e. persists
-   *               when popping the context)
-   * @return the function
-   */
-  public Term defineFun(Term fun, Term[] boundVars, Term term, boolean global)
-  {
-    long[] boundVarPointers = Utils.getPointers(boundVars);
-    long termPointer =
-        defineFun(pointer, fun.getPointer(), boundVarPointers, term.getPointer(), global);
-    return new Term(this, termPointer);
-  }
-
-  private native long defineFun(
-      long pointer, long funPointer, long[] boundVarPointers, long termPointer, boolean global);
-
   /**
    * Define recursive function in the current context.
    * SMT-LIB:
index af3d7e59e8a59e1758ec6daa1e207f401e42729c..bc4d7ff43f97ffd4e66f05d1ef99def12d8b5538 100644 (file)
@@ -1703,31 +1703,6 @@ Java_io_github_cvc5_api_Solver_defineFun__JLjava_lang_String_2_3JJJZ(
   CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, 0);
 }
 
-/*
- * Class:     io_github_cvc5_api_Solver
- * Method:    defineFun
- * Signature: (JJ[JJZ)J
- */
-JNIEXPORT jlong JNICALL
-Java_io_github_cvc5_api_Solver_defineFun__JJ_3JJZ(JNIEnv* env,
-                                                  jobject,
-                                                  jlong pointer,
-                                                  jlong funPointer,
-                                                  jlongArray jVars,
-                                                  jlong termPointer,
-                                                  jboolean global)
-{
-  CVC5_JAVA_API_TRY_CATCH_BEGIN;
-  Solver* solver = reinterpret_cast<Solver*>(pointer);
-  Term* fun = reinterpret_cast<Term*>(funPointer);
-  Term* term = reinterpret_cast<Term*>(termPointer);
-  std::vector<Term> vars = getObjectsFromPointers<Term>(env, jVars);
-  Term* retPointer =
-      new Term(solver->defineFun(*fun, vars, *term, (bool)global));
-  return reinterpret_cast<jlong>(retPointer);
-  CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, 0);
-}
-
 /*
  * Class:     io_github_cvc5_api_Solver
  * Method:    defineFunRec
index 42aee08b0de97a6693e974d8ef284df25838539a..9504bccae0778ff47db93127f47566e47801e012 100644 (file)
@@ -266,7 +266,6 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         Sort declareSort(const string& symbol, uint32_t arity) except +
         Term defineFun(const string& symbol, const vector[Term]& bound_vars,
                        Sort sort, Term term, bint glbl) except +
-        Term defineFun(Term fun, const vector[Term]& bound_vars, Term term, bint glbl) except +
         Term defineFunRec(const string& symbol, const vector[Term]& bound_vars,
                           Sort sort, Term term, bint glbl) except +
         Term defineFunRec(Term fun, const vector[Term]& bound_vars,
index 3367bf47bda69b6b5039e40257967c2b30a418fb..6f50b840182b0905f646ec59472a6a4311e30867 100644 (file)
@@ -1811,14 +1811,8 @@ cdef class Solver:
         sort.csort = self.csolver.declareSort(symbol.encode(), arity)
         return sort
 
-    def defineFun(self, sym_or_fun, bound_vars, sort_or_term, t=None, glbl=False):
-        """
-        Define n-ary function.
-        Supports two uses:
-
-        - ``Term defineFun(str symbol, List[Term] bound_vars, Sort sort, Term term, bool glbl)``
-        - ``Term defineFun(Term fun, List[Term] bound_vars, Term term, bool glbl)``
-
+    def defineFun(self, str symbol, list bound_vars, Sort sort, Term term, glbl=False):
+        """Define n-ary function.
 
         SMT-LIB:
 
@@ -1830,27 +1824,20 @@ cdef class Solver:
         :param bound_vars: the parameters to this function
         :param sort: the sort of the return value of this function
         :param term: the function body
-        :param global: determines whether this definition is global (i.e. persists when popping the context)
+        :param glbl: determines whether this definition is global (i.e. persists when popping the context)
         :return: the function
         """
-        cdef Term term = Term(self)
+        cdef Term fun = Term(self)
         cdef vector[c_Term] v
         for bv in bound_vars:
             v.push_back((<Term?> bv).cterm)
 
-        if t is not None:
-            term.cterm = self.csolver.defineFun((<str?> sym_or_fun).encode(),
-                                                <const vector[c_Term] &> v,
-                                                (<Sort?> sort_or_term).csort,
-                                                (<Term?> t).cterm,
-                                                <bint> glbl)
-        else:
-            term.cterm = self.csolver.defineFun((<Term?> sym_or_fun).cterm,
-                                                <const vector[c_Term]&> v,
-                                                (<Term?> sort_or_term).cterm,
-                                                <bint> glbl)
-
-        return term
+        fun.cterm = self.csolver.defineFun(symbol.encode(),
+                                           <const vector[c_Term] &> v,
+                                           sort.csort,
+                                           term.cterm,
+                                           <bint> glbl)
+        return fun
 
     def defineFunRec(self, sym_or_fun, bound_vars, sort_or_term, t=None, glbl=False):
         """Define recursive functions.
index 8188f247fb8b1748cea4190d47b4cc87176f9a82..f08d9adde520312ee634c664fddd47b86ff94e92 100644 (file)
@@ -204,7 +204,11 @@ bool solverInvoke(api::Solver* solver,
     cmd->toStream(ss);
   }
 
-  if (solver->getOptionInfo("parse-only").boolValue())
+  // In parse-only mode, we do not invoke any of the commands except define-fun
+  // commands. We invoke define-fun commands because they add function names
+  // to the symbol table.
+  if (solver->getOptionInfo("parse-only").boolValue()
+      && dynamic_cast<DefineFunctionCommand*>(cmd) == nullptr)
   {
     return true;
   }
index 696eb6ad4ce41833882131fc738509d21ef1047c..bf58e786ddabc43df16c776545557fe22ac342b7 100644 (file)
@@ -314,6 +314,10 @@ command [std::unique_ptr<cvc5::Command>* cmd]
       }
 
       t = PARSER_STATE->mkFlatFunctionType(sorts, t, flattenVars);
+      if (t.isFunction())
+      {
+        t = t.getFunctionCodomainSort();
+      }
       if (sortedVarNames.size() > 0)
       {
         PARSER_STATE->pushScope();
@@ -332,13 +336,7 @@ command [std::unique_ptr<cvc5::Command>* cmd]
       {
         PARSER_STATE->popScope();
       }
-      // declare the name down here (while parsing term, signature
-      // must not be extended with the name itself; no recursion
-      // permitted)
-      // we allow overloading for function definitions
-      api::Term func = PARSER_STATE->bindVar(name, t, false, true);
-      cmd->reset(new DefineFunctionCommand(
-          name, func, terms, expr, SYM_MAN->getGlobalDeclarations()));
+      cmd->reset(new DefineFunctionCommand(name, terms, t, expr));
     }
   | DECLARE_DATATYPE_TOK datatypeDefCommand[false, cmd]
   | DECLARE_DATATYPES_TOK datatypesDefCommand[false, cmd]
@@ -834,8 +832,7 @@ smt25Command[std::unique_ptr<cvc5::Command>* cmd]
       if( !flattenVars.empty() ){
         expr = PARSER_STATE->mkHoApply( expr, flattenVars );
       }
-      cmd->reset(new DefineFunctionRecCommand(
-          func, bvs, expr, SYM_MAN->getGlobalDeclarations()));
+      cmd->reset(new DefineFunctionRecCommand(func, bvs, expr));
     }
   | DEFINE_FUNS_REC_TOK
     { PARSER_STATE->checkThatLogicIsSet();}
@@ -898,8 +895,7 @@ smt25Command[std::unique_ptr<cvc5::Command>* cmd]
             "Number of functions defined does not match number listed in "
             "define-funs-rec"));
       }
-      cmd->reset(new DefineFunctionRecCommand(
-          funcs, formals, func_defs, SYM_MAN->getGlobalDeclarations()));
+      cmd->reset(new DefineFunctionRecCommand(funcs, formals, func_defs));
     }
   ;
 
@@ -984,48 +980,6 @@ extendedCommand[std::unique_ptr<cvc5::Command>* cmd]
     )+
     RPAREN_TOK
     { cmd->reset(seq.release()); }
-
-  | DEFINE_TOK { PARSER_STATE->checkThatLogicIsSet(); }
-    ( // (define f t)
-      symbol[name,CHECK_UNDECLARED,SYM_VARIABLE]
-      { PARSER_STATE->checkUserSymbol(name); }
-      term[e,e2]
-      {
-        api::Term func = PARSER_STATE->bindVar(name, e.getSort());
-        cmd->reset(new DefineFunctionCommand(
-            name, func, e, SYM_MAN->getGlobalDeclarations()));
-      }
-    | // (define (f (v U) ...) t)
-      LPAREN_TOK
-      symbol[name,CHECK_UNDECLARED,SYM_VARIABLE]
-      { PARSER_STATE->checkUserSymbol(name); }
-      sortedVarList[sortedVarNames] RPAREN_TOK
-      { /* add variables to parser state before parsing term */
-        Debug("parser") << "define fun: '" << name << "'" << std::endl;
-        PARSER_STATE->pushScope();
-        terms = PARSER_STATE->bindBoundVars(sortedVarNames);
-      }
-      term[e,e2]
-      {
-        PARSER_STATE->popScope();
-        // declare the name down here (while parsing term, signature
-        // must not be extended with the name itself; no recursion
-        // permitted)
-        api::Sort tt = e.getSort();
-        if( sortedVarNames.size() > 0 ) {
-          sorts.reserve(sortedVarNames.size());
-          for(std::vector<std::pair<std::string, api::Sort> >::const_iterator
-                i = sortedVarNames.begin(), iend = sortedVarNames.end();
-              i != iend; ++i) {
-            sorts.push_back((*i).second);
-          }
-          tt = SOLVER->mkFunctionSort(sorts, tt);
-        }
-        api::Term func = PARSER_STATE->bindVar(name, tt);
-        cmd->reset(new DefineFunctionCommand(
-            name, func, terms, e, SYM_MAN->getGlobalDeclarations()));
-      }
-    )
   | // (define-const x U t)
     DEFINE_CONST_TOK { PARSER_STATE->checkThatLogicIsSet(); }
     symbol[name,CHECK_UNDECLARED,SYM_VARIABLE]
@@ -1036,9 +990,7 @@ extendedCommand[std::unique_ptr<cvc5::Command>* cmd]
       // declare the name down here (while parsing term, signature
       // must not be extended with the name itself; no recursion
       // permitted)
-      api::Term func = PARSER_STATE->bindVar(name, t);
-      cmd->reset(new DefineFunctionCommand(
-          name, func, terms, e, SYM_MAN->getGlobalDeclarations()));
+      cmd->reset(new DefineFunctionCommand(name, t, e));
     }
 
   | SIMPLIFY_TOK { PARSER_STATE->checkThatLogicIsSet(); }
@@ -1850,6 +1802,8 @@ attribute[cvc5::api::Term& expr, cvc5::api::Term& retExpr]
   | ATTRIBUTE_NAMED_TOK symbol[s,CHECK_UNDECLARED,SYM_VARIABLE]
     {
       // notify that expression was given a name
+      PARSER_STATE->preemptCommand(
+          new DefineFunctionCommand(s, expr.getSort(), expr));
       PARSER_STATE->notifyNamedExpression(expr, s);
     }
   ;
@@ -2277,7 +2231,6 @@ ECHO_TOK : 'echo';
 DECLARE_SORTS_TOK : 'declare-sorts';
 DECLARE_FUNS_TOK : 'declare-funs';
 DECLARE_PREDS_TOK : 'declare-preds';
-DEFINE_TOK : 'define';
 DECLARE_CONST_TOK : 'declare-const';
 DEFINE_CONST_TOK : 'define-const';
 SIMPLIFY_TOK : 'simplify';
index e8ee6d59ce60b746dcaed458e949f5a732a8ecc0..419b925c4e1a735c63bc396db51678c03afdf32b 100644 (file)
@@ -1313,46 +1313,44 @@ void DefineSortCommand::toStream(std::ostream& out,
 /* -------------------------------------------------------------------------- */
 
 DefineFunctionCommand::DefineFunctionCommand(const std::string& id,
-                                             api::Term func,
-                                             api::Term formula,
-                                             bool global)
+                                             api::Sort sort,
+                                             api::Term formula)
     : DeclarationDefinitionCommand(id),
-      d_func(func),
       d_formals(),
-      d_formula(formula),
-      d_global(global)
+      d_sort(sort),
+      d_formula(formula)
 {
 }
 
 DefineFunctionCommand::DefineFunctionCommand(
     const std::string& id,
-    api::Term func,
     const std::vector<api::Term>& formals,
-    api::Term formula,
-    bool global)
+    api::Sort sort,
+    api::Term formula)
     : DeclarationDefinitionCommand(id),
-      d_func(func),
       d_formals(formals),
-      d_formula(formula),
-      d_global(global)
+      d_sort(sort),
+      d_formula(formula)
 {
 }
 
-api::Term DefineFunctionCommand::getFunction() const { return d_func; }
 const std::vector<api::Term>& DefineFunctionCommand::getFormals() const
 {
   return d_formals;
 }
 
+api::Sort DefineFunctionCommand::getSort() const { return d_sort; }
+
 api::Term DefineFunctionCommand::getFormula() const { return d_formula; }
+
 void DefineFunctionCommand::invoke(api::Solver* solver, SymbolManager* sm)
 {
   try
   {
-    if (!d_func.isNull())
-    {
-      solver->defineFun(d_func, d_formals, d_formula, d_global);
-    }
+    bool global = sm->getGlobalDeclarations();
+    api::Term fun =
+        solver->defineFun(d_symbol, d_formals, d_sort, d_formula, global);
+    sm->getSymbolTable()->bind(fun.toString(), fun, global);
     d_commandStatus = CommandSuccess::instance();
   }
   catch (exception& e)
@@ -1363,8 +1361,7 @@ void DefineFunctionCommand::invoke(api::Solver* solver, SymbolManager* sm)
 
 Command* DefineFunctionCommand::clone() const
 {
-  return new DefineFunctionCommand(
-      d_symbol, d_func, d_formals, d_formula, d_global);
+  return new DefineFunctionCommand(d_symbol, d_formals, d_sort, d_formula);
 }
 
 std::string DefineFunctionCommand::getCommandName() const
@@ -1377,16 +1374,11 @@ void DefineFunctionCommand::toStream(std::ostream& out,
                                      size_t dag,
                                      Language language) const
 {
-  TypeNode rangeType = termToNode(d_func).getType();
-  if (rangeType.isFunction())
-  {
-    rangeType = rangeType.getRangeType();
-  }
   Printer::getPrinter(language)->toStreamCmdDefineFunction(
       out,
-      d_func.toString(),
+      d_symbol,
       termVectorToNodes(d_formals),
-      rangeType,
+      sortToTypeNode(d_sort),
       termToNode(d_formula));
 }
 
@@ -1395,12 +1387,7 @@ void DefineFunctionCommand::toStream(std::ostream& out,
 /* -------------------------------------------------------------------------- */
 
 DefineFunctionRecCommand::DefineFunctionRecCommand(
-
-    api::Term func,
-    const std::vector<api::Term>& formals,
-    api::Term formula,
-    bool global)
-    : d_global(global)
+    api::Term func, const std::vector<api::Term>& formals, api::Term formula)
 {
   d_funcs.push_back(func);
   d_formals.push_back(formals);
@@ -1408,12 +1395,10 @@ DefineFunctionRecCommand::DefineFunctionRecCommand(
 }
 
 DefineFunctionRecCommand::DefineFunctionRecCommand(
-
     const std::vector<api::Term>& funcs,
     const std::vector<std::vector<api::Term>>& formals,
-    const std::vector<api::Term>& formulas,
-    bool global)
-    : d_funcs(funcs), d_formals(formals), d_formulas(formulas), d_global(global)
+    const std::vector<api::Term>& formulas)
+    : d_funcs(funcs), d_formals(formals), d_formulas(formulas)
 {
 }
 
@@ -1437,7 +1422,8 @@ void DefineFunctionRecCommand::invoke(api::Solver* solver, SymbolManager* sm)
 {
   try
   {
-    solver->defineFunsRec(d_funcs, d_formals, d_formulas, d_global);
+    bool global = sm->getGlobalDeclarations();
+    solver->defineFunsRec(d_funcs, d_formals, d_formulas, global);
     d_commandStatus = CommandSuccess::instance();
   }
   catch (exception& e)
@@ -1448,7 +1434,7 @@ void DefineFunctionRecCommand::invoke(api::Solver* solver, SymbolManager* sm)
 
 Command* DefineFunctionRecCommand::clone() const
 {
-  return new DefineFunctionRecCommand(d_funcs, d_formals, d_formulas, d_global);
+  return new DefineFunctionRecCommand(d_funcs, d_formals, d_formulas);
 }
 
 std::string DefineFunctionRecCommand::getCommandName() const
index 400f3492e1263da15f62bf341e8581fd9880375c..5733eb3e5b6f26a9dfdcc65dd9130b84bdc17e1e 100644 (file)
@@ -494,17 +494,15 @@ class CVC5_EXPORT DefineFunctionCommand : public DeclarationDefinitionCommand
 {
  public:
   DefineFunctionCommand(const std::string& id,
-                        api::Term func,
-                        api::Term formula,
-                        bool global);
+                        api::Sort sort,
+                        api::Term formula);
   DefineFunctionCommand(const std::string& id,
-                        api::Term func,
                         const std::vector<api::Term>& formals,
-                        api::Term formula,
-                        bool global);
+                        api::Sort sort,
+                        api::Term formula);
 
-  api::Term getFunction() const;
   const std::vector<api::Term>& getFormals() const;
+  api::Sort getSort() const;
   api::Term getFormula() const;
 
   void invoke(api::Solver* solver, SymbolManager* sm) override;
@@ -516,17 +514,12 @@ class CVC5_EXPORT DefineFunctionCommand : public DeclarationDefinitionCommand
                 Language language = Language::LANG_AUTO) const override;
 
  protected:
-  /** The function we are defining */
-  api::Term d_func;
   /** The formal arguments for the function we are defining */
   std::vector<api::Term> d_formals;
+  /** The co-domain sort of the function we are defining */
+  api::Sort d_sort;
   /** The formula corresponding to the body of the function we are defining */
   api::Term d_formula;
-  /**
-   * Stores whether this definition is global (i.e. should persist when
-   * popping the user context.
-   */
-  bool d_global;
 }; /* class DefineFunctionCommand */
 
 /**
@@ -539,12 +532,10 @@ class CVC5_EXPORT DefineFunctionRecCommand : public Command
  public:
   DefineFunctionRecCommand(api::Term func,
                            const std::vector<api::Term>& formals,
-                           api::Term formula,
-                           bool global);
+                           api::Term formula);
   DefineFunctionRecCommand(const std::vector<api::Term>& funcs,
                            const std::vector<std::vector<api::Term> >& formals,
-                           const std::vector<api::Term>& formula,
-                           bool global);
+                           const std::vector<api::Term>& formula);
 
   const std::vector<api::Term>& getFunctions() const;
   const std::vector<std::vector<api::Term> >& getFormals() const;
@@ -565,11 +556,6 @@ class CVC5_EXPORT DefineFunctionRecCommand : public Command
   std::vector<std::vector<api::Term> > d_formals;
   /** formulas corresponding to the bodies of the functions we are defining */
   std::vector<api::Term> d_formulas;
-  /**
-   * Stores whether this definition is global (i.e. should persist when
-   * popping the user context.
-   */
-  bool d_global;
 }; /* class DefineFunctionRecCommand */
 
 /**
index 04a2757414824f6038f236264a15ac8edb51a577..8e8ed0d9b286cfb6d272a8dbe7ea0666ce25f107 100644 (file)
@@ -950,42 +950,23 @@ def test_declare_sort(solver):
 
 def test_define_fun(solver):
     bvSort = solver.mkBitVectorSort(32)
-    funSort1 = solver.mkFunctionSort([bvSort, bvSort], bvSort)
-    funSort2 = solver.mkFunctionSort(solver.mkUninterpretedSort("u"),\
-                                     solver.getIntegerSort())
+    funSort = solver.mkFunctionSort(solver.mkUninterpretedSort("u"),
+                                    solver.getIntegerSort())
     b1 = solver.mkVar(bvSort, "b1")
-    b11 = solver.mkVar(bvSort, "b1")
     b2 = solver.mkVar(solver.getIntegerSort(), "b2")
-    b3 = solver.mkVar(funSort2, "b3")
+    b3 = solver.mkVar(funSort, "b3")
     v1 = solver.mkConst(bvSort, "v1")
-    v2 = solver.mkConst(solver.getIntegerSort(), "v2")
-    v3 = solver.mkConst(funSort2, "v3")
-    f1 = solver.mkConst(funSort1, "f1")
-    f2 = solver.mkConst(funSort2, "f2")
-    f3 = solver.mkConst(bvSort, "f3")
+    v2 = solver.mkConst(funSort, "v2")
     solver.defineFun("f", [], bvSort, v1)
     solver.defineFun("ff", [b1, b2], bvSort, v1)
-    solver.defineFun(f1, [b1, b11], v1)
     with pytest.raises(RuntimeError):
         solver.defineFun("ff", [v1, b2], bvSort, v1)
     with pytest.raises(RuntimeError):
-        solver.defineFun("fff", [b1], bvSort, v3)
+        solver.defineFun("fff", [b1], bvSort, v2)
     with pytest.raises(RuntimeError):
-        solver.defineFun("ffff", [b1], funSort2, v3)
+        solver.defineFun("ffff", [b1], funSort, v2)
     # b3 has function sort, which is allowed as an argument
     solver.defineFun("fffff", [b1, b3], bvSort, v1)
-    with pytest.raises(RuntimeError):
-        solver.defineFun(f1, [v1, b11], v1)
-    with pytest.raises(RuntimeError):
-        solver.defineFun(f1, [b1], v1)
-    with pytest.raises(RuntimeError):
-        solver.defineFun(f1, [b1, b11], v2)
-    with pytest.raises(RuntimeError):
-        solver.defineFun(f1, [b1, b11], v3)
-    with pytest.raises(RuntimeError):
-        solver.defineFun(f2, [b1], v2)
-    with pytest.raises(RuntimeError):
-        solver.defineFun(f3, [b1], v1)
 
     slv = pycvc5.Solver()
     bvSort2 = slv.mkBitVectorSort(32)
index 27261eff676feef73665274a5355110b070d6496..57f4d711c083687041af0935eaa02dbcd2b6dfbe 100644 (file)
@@ -4,6 +4,6 @@
 (set-logic QF_LIA)
 (declare-fun x () Int)
 (check-sat)
-(define t (not (= x 0)))
+(define-const t Bool (not (= x 0)))
 (assert t)
 (check-sat)
index af8e0b948e8f036d9a476076aac2ba119dc6e016..8fcfabd5e05480811a4145d2f12bc6fb99455d04 100644 (file)
@@ -6,8 +6,8 @@
 (set-option :global-declarations true)
 
 (push)
-(define a true)
-(define (f (b Bool)) b)
+(define-const a Bool true)
+(define-fun f ((b Bool)) Bool b)
 (define-const a2 Bool true)
 
 (define-fun a3 () Bool true)
index 1f88add2d154cb3547d10f85771e31c75a3d790a..58b8d03aa0f50a469fc753da42334e845c018ee4 100644 (file)
@@ -959,38 +959,25 @@ class SolverTest
   @Test void defineFun() throws CVC5ApiException
   {
     Sort bvSort = d_solver.mkBitVectorSort(32);
-    Sort funSort1 = d_solver.mkFunctionSort(new Sort[] {bvSort, bvSort}, bvSort);
-    Sort funSort2 =
+    Sort funSort =
         d_solver.mkFunctionSort(d_solver.mkUninterpretedSort("u"), d_solver.getIntegerSort());
     Term b1 = d_solver.mkVar(bvSort, "b1");
-    Term b11 = d_solver.mkVar(bvSort, "b1");
     Term b2 = d_solver.mkVar(d_solver.getIntegerSort(), "b2");
-    Term b3 = d_solver.mkVar(funSort2, "b3");
+    Term b3 = d_solver.mkVar(funSort, "b3");
     Term v1 = d_solver.mkConst(bvSort, "v1");
-    Term v2 = d_solver.mkConst(d_solver.getIntegerSort(), "v2");
-    Term v3 = d_solver.mkConst(funSort2, "v3");
-    Term f1 = d_solver.mkConst(funSort1, "f1");
-    Term f2 = d_solver.mkConst(funSort2, "f2");
-    Term f3 = d_solver.mkConst(bvSort, "f3");
+    Term v2 = d_solver.mkConst(funSort, "v2");
     assertDoesNotThrow(() -> d_solver.defineFun("f", new Term[] {}, bvSort, v1));
     assertDoesNotThrow(() -> d_solver.defineFun("ff", new Term[] {b1, b2}, bvSort, v1));
-    assertDoesNotThrow(() -> d_solver.defineFun(f1, new Term[] {b1, b11}, v1));
     assertThrows(
         CVC5ApiException.class, () -> d_solver.defineFun("ff", new Term[] {v1, b2}, bvSort, v1));
 
     assertThrows(
-        CVC5ApiException.class, () -> d_solver.defineFun("fff", new Term[] {b1}, bvSort, v3));
+        CVC5ApiException.class, () -> d_solver.defineFun("fff", new Term[] {b1}, bvSort, v2));
     assertThrows(
-        CVC5ApiException.class, () -> d_solver.defineFun("ffff", new Term[] {b1}, funSort2, v3));
+        CVC5ApiException.class, () -> d_solver.defineFun("ffff", new Term[] {b1}, funSort2, v2));
 
     // b3 has function sort, which is allowed as an argument
     assertDoesNotThrow(() -> d_solver.defineFun("fffff", new Term[] {b1, b3}, bvSort, v1));
-    assertThrows(CVC5ApiException.class, () -> d_solver.defineFun(f1, new Term[] {v1, b11}, v1));
-    assertThrows(CVC5ApiException.class, () -> d_solver.defineFun(f1, new Term[] {b1}, v1));
-    assertThrows(CVC5ApiException.class, () -> d_solver.defineFun(f1, new Term[] {b1, b11}, v2));
-    assertThrows(CVC5ApiException.class, () -> d_solver.defineFun(f1, new Term[] {b1, b11}, v3));
-    assertThrows(CVC5ApiException.class, () -> d_solver.defineFun(f2, new Term[] {b1}, v2));
-    assertThrows(CVC5ApiException.class, () -> d_solver.defineFun(f3, new Term[] {b1}, v1));
 
     Solver slv = new Solver();
     Sort bvSort2 = slv.mkBitVectorSort(32);
@@ -1012,15 +999,13 @@ class SolverTest
   @Test void defineFunGlobal()
   {
     Sort bSort = d_solver.getBooleanSort();
-    Sort fSort = d_solver.mkFunctionSort(bSort, bSort);
 
     Term bTrue = d_solver.mkBoolean(true);
     // (define-fun f () Bool true)
     Term f = d_solver.defineFun("f", new Term[] {}, bSort, bTrue, true);
     Term b = d_solver.mkVar(bSort, "b");
-    Term gSym = d_solver.mkConst(fSort, "g");
     // (define-fun g (b Bool) Bool b)
-    Term g = d_solver.defineFun(gSym, new Term[] {b}, b, true);
+    Term g = d_solver.defineFun("g", new Term[] {b}, bSort, b, true);
 
     // (assert (or (not f) (not (g true))))
     d_solver.assertFormula(
index 8dcb0fde606743a838f8fd7fff38633633114984..8435a63be3c8a4fa091dae5d105d5234143d2943 100644 (file)
@@ -922,35 +922,21 @@ TEST_F(TestApiBlackSolver, defineSort)
 TEST_F(TestApiBlackSolver, defineFun)
 {
   Sort bvSort = d_solver.mkBitVectorSort(32);
-  Sort funSort1 = d_solver.mkFunctionSort({bvSort, bvSort}, bvSort);
-  Sort funSort2 = d_solver.mkFunctionSort(d_solver.mkUninterpretedSort("u"),
-                                          d_solver.getIntegerSort());
+  Sort funSort = d_solver.mkFunctionSort(d_solver.mkUninterpretedSort("u"),
+                                         d_solver.getIntegerSort());
   Term b1 = d_solver.mkVar(bvSort, "b1");
-  Term b11 = d_solver.mkVar(bvSort, "b1");
   Term b2 = d_solver.mkVar(d_solver.getIntegerSort(), "b2");
-  Term b3 = d_solver.mkVar(funSort2, "b3");
+  Term b3 = d_solver.mkVar(funSort, "b3");
   Term v1 = d_solver.mkConst(bvSort, "v1");
-  Term v2 = d_solver.mkConst(d_solver.getIntegerSort(), "v2");
-  Term v3 = d_solver.mkConst(funSort2, "v3");
-  Term f1 = d_solver.mkConst(funSort1, "f1");
-  Term f2 = d_solver.mkConst(funSort2, "f2");
-  Term f3 = d_solver.mkConst(bvSort, "f3");
+  Term v2 = d_solver.mkConst(funSort, "v2");
   ASSERT_NO_THROW(d_solver.defineFun("f", {}, bvSort, v1));
   ASSERT_NO_THROW(d_solver.defineFun("ff", {b1, b2}, bvSort, v1));
-  ASSERT_NO_THROW(d_solver.defineFun(f1, {b1, b11}, v1));
   ASSERT_THROW(d_solver.defineFun("ff", {v1, b2}, bvSort, v1),
                CVC5ApiException);
-  ASSERT_THROW(d_solver.defineFun("fff", {b1}, bvSort, v3), CVC5ApiException);
-  ASSERT_THROW(d_solver.defineFun("ffff", {b1}, funSort2, v3),
-               CVC5ApiException);
+  ASSERT_THROW(d_solver.defineFun("fff", {b1}, bvSort, v2), CVC5ApiException);
+  ASSERT_THROW(d_solver.defineFun("ffff", {b1}, funSort, v2), CVC5ApiException);
   // b3 has function sort, which is allowed as an argument
   ASSERT_NO_THROW(d_solver.defineFun("fffff", {b1, b3}, bvSort, v1));
-  ASSERT_THROW(d_solver.defineFun(f1, {v1, b11}, v1), CVC5ApiException);
-  ASSERT_THROW(d_solver.defineFun(f1, {b1}, v1), CVC5ApiException);
-  ASSERT_THROW(d_solver.defineFun(f1, {b1, b11}, v2), CVC5ApiException);
-  ASSERT_THROW(d_solver.defineFun(f1, {b1, b11}, v3), CVC5ApiException);
-  ASSERT_THROW(d_solver.defineFun(f2, {b1}, v2), CVC5ApiException);
-  ASSERT_THROW(d_solver.defineFun(f3, {b1}, v1), CVC5ApiException);
 
   Solver slv;
   Sort bvSort2 = slv.mkBitVectorSort(32);
@@ -968,15 +954,13 @@ TEST_F(TestApiBlackSolver, defineFun)
 TEST_F(TestApiBlackSolver, defineFunGlobal)
 {
   Sort bSort = d_solver.getBooleanSort();
-  Sort fSort = d_solver.mkFunctionSort(bSort, bSort);
 
   Term bTrue = d_solver.mkBoolean(true);
   // (define-fun f () Bool true)
   Term f = d_solver.defineFun("f", {}, bSort, bTrue, true);
   Term b = d_solver.mkVar(bSort, "b");
-  Term gSym = d_solver.mkConst(fSort, "g");
   // (define-fun g (b Bool) Bool b)
-  Term g = d_solver.defineFun(gSym, {b}, b, true);
+  Term g = d_solver.defineFun("g", {b}, bSort, b, true);
 
   // (assert (or (not f) (not (g true))))
   d_solver.assertFormula(d_solver.mkTerm(