From: Andres Noetzli Date: Fri, 19 Jun 2020 05:51:50 +0000 (-0700) Subject: Add logic check for define-fun(s)-rec (#4577) X-Git-Tag: cvc5-1.0.0~3202 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=e8884b9b8ba86ce71807887cab87a5188cce4003;p=cvc5.git Add logic check for define-fun(s)-rec (#4577) This commit adds a logic check for `define-fun-rec`/`define-funs-rec` at the level of the new API that checks whether the logic is quantified and includes UF. To make sure that the parser actually executes that check, this commit converts the `DefineFunctionRecCommand` command to use the new API instead of the old one. This temporarily breaks the `exportTo` functionality for `DefineFunctionRecCommand` but this is not currently used within the CVC4 code base (and it seems unlikely that users use commands). --- diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index 321e284f9..2fd5cb444 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -959,20 +959,20 @@ CVC4::Type Sort::getType(void) const { return *d_type; } size_t Sort::getConstructorArity() const { - CVC4_API_CHECK(isConstructor()) << "Not a function sort."; + CVC4_API_CHECK(isConstructor()) << "Not a function sort: " << (*this); return ConstructorType(*d_type).getArity(); } std::vector Sort::getConstructorDomainSorts() const { - CVC4_API_CHECK(isConstructor()) << "Not a function sort."; + CVC4_API_CHECK(isConstructor()) << "Not a function sort: " << (*this); std::vector types = ConstructorType(*d_type).getArgTypes(); return typeVectorToSorts(d_solver, types); } Sort Sort::getConstructorCodomainSort() const { - CVC4_API_CHECK(isConstructor()) << "Not a function sort."; + CVC4_API_CHECK(isConstructor()) << "Not a function sort: " << (*this); return Sort(d_solver, ConstructorType(*d_type).getRangeType()); } @@ -980,20 +980,20 @@ Sort Sort::getConstructorCodomainSort() const size_t Sort::getFunctionArity() const { - CVC4_API_CHECK(isFunction()) << "Not a function sort."; + CVC4_API_CHECK(isFunction()) << "Not a function sort: " << (*this); return FunctionType(*d_type).getArity(); } std::vector Sort::getFunctionDomainSorts() const { - CVC4_API_CHECK(isFunction()) << "Not a function sort."; + CVC4_API_CHECK(isFunction()) << "Not a function sort: " << (*this); std::vector types = FunctionType(*d_type).getArgTypes(); return typeVectorToSorts(d_solver, types); } Sort Sort::getFunctionCodomainSort() const { - CVC4_API_CHECK(isFunction()) << "Not a function sort."; + CVC4_API_CHECK(isFunction()) << "Not a function sort" << (*this); return Sort(d_solver, FunctionType(*d_type).getRangeType()); } @@ -2599,6 +2599,7 @@ Solver::Solver(Options* opts) Options* o = opts == nullptr ? new Options() : opts; d_exprMgr.reset(new ExprManager(*o)); d_smtEngine.reset(new SmtEngine(d_exprMgr.get())); + d_smtEngine->setSolver(this); d_rng.reset(new Random((*o)[options::seed])); if (opts == nullptr) delete o; } @@ -4269,34 +4270,43 @@ Term Solver::defineFun(Term fun, bool global) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - CVC4_API_ARG_CHECK_EXPECTED(fun.getSort().isFunction(), fun) << "function"; - std::vector domain_sorts = fun.getSort().getFunctionDomainSorts(); - size_t size = bound_vars.size(); - CVC4_API_ARG_SIZE_CHECK_EXPECTED(size == domain_sorts.size(), bound_vars) - << "'" << domain_sorts.size() << "'"; - for (size_t i = 0; i < size; ++i) + + if (fun.getSort().isFunction()) { - CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( - this == bound_vars[i].d_solver, "bound variable", bound_vars[i], i) - << "bound variable associated to this solver object"; - CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( - bound_vars[i].d_expr->getKind() == CVC4::Kind::BOUND_VARIABLE, - "bound variable", - bound_vars[i], - i) - << "a bound variable"; - CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( - domain_sorts[i] == bound_vars[i].getSort(), - "sort of parameter", - bound_vars[i], - i) - << "'" << domain_sorts[i] << "'"; + std::vector domain_sorts = fun.getSort().getFunctionDomainSorts(); + size_t size = bound_vars.size(); + CVC4_API_ARG_SIZE_CHECK_EXPECTED(size == domain_sorts.size(), bound_vars) + << "'" << domain_sorts.size() << "'"; + for (size_t i = 0; i < size; ++i) + { + CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( + this == bound_vars[i].d_solver, "bound variable", bound_vars[i], i) + << "bound variable associated to this solver object"; + CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( + bound_vars[i].d_expr->getKind() == CVC4::Kind::BOUND_VARIABLE, + "bound variable", + bound_vars[i], + i) + << "a bound variable"; + CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( + domain_sorts[i] == bound_vars[i].getSort(), + "sort of parameter", + bound_vars[i], + i) + << "'" << domain_sorts[i] << "'"; + } + Sort codomain = fun.getSort().getFunctionCodomainSort(); + CVC4_API_CHECK(codomain == term.getSort()) + << "Invalid sort of function body '" << term << "', expected '" + << codomain << "'"; + } + else + { + CVC4_API_ARG_CHECK_EXPECTED(bound_vars.size() == 0, fun) + << "function or nullary symbol"; } - Sort codomain = fun.getSort().getFunctionCodomainSort(); + CVC4_API_SOLVER_CHECK_TERM(term); - CVC4_API_CHECK(codomain == term.getSort()) - << "Invalid sort of function body '" << term << "', expected '" - << codomain << "'"; std::vector ebound_vars = termVectorToExprs(bound_vars); d_smtEngine->defineFunction(*fun.d_expr, ebound_vars, *term.d_expr, global); @@ -4314,6 +4324,14 @@ Term Solver::defineFunRec(const std::string& symbol, bool global) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; + + CVC4_API_CHECK(d_smtEngine->getUserLogicInfo().isQuantified()) + << "recursive function definitions require a logic with quantifiers"; + CVC4_API_CHECK( + d_smtEngine->getUserLogicInfo().isTheoryEnabled(theory::THEORY_UF)) + << "recursive function definitions require a logic with uninterpreted " + "functions"; + CVC4_API_ARG_CHECK_EXPECTED(sort.isFirstClass(), sort) << "first-class sort as function codomain sort"; Assert(!sort.isFunction()); /* A function sort is not first-class. */ @@ -4358,34 +4376,50 @@ Term Solver::defineFunRec(Term fun, bool global) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - CVC4_API_ARG_CHECK_EXPECTED(fun.getSort().isFunction(), fun) << "function"; - std::vector domain_sorts = fun.getSort().getFunctionDomainSorts(); - size_t size = bound_vars.size(); - CVC4_API_ARG_SIZE_CHECK_EXPECTED(size == domain_sorts.size(), bound_vars) - << "'" << domain_sorts.size() << "'"; - for (size_t i = 0; i < size; ++i) + + CVC4_API_CHECK(d_smtEngine->getUserLogicInfo().isQuantified()) + << "recursive function definitions require a logic with quantifiers"; + CVC4_API_CHECK( + d_smtEngine->getUserLogicInfo().isTheoryEnabled(theory::THEORY_UF)) + << "recursive function definitions require a logic with uninterpreted " + "functions"; + + if (fun.getSort().isFunction()) { - CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( - this == bound_vars[i].d_solver, "bound variable", bound_vars[i], i) - << "bound variable associated to this solver object"; - CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( - bound_vars[i].d_expr->getKind() == CVC4::Kind::BOUND_VARIABLE, - "bound variable", - bound_vars[i], - i) - << "a bound variable"; - CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( - domain_sorts[i] == bound_vars[i].getSort(), - "sort of parameter", - bound_vars[i], - i) - << "'" << domain_sorts[i] << "'"; + std::vector domain_sorts = fun.getSort().getFunctionDomainSorts(); + size_t size = bound_vars.size(); + CVC4_API_ARG_SIZE_CHECK_EXPECTED(size == domain_sorts.size(), bound_vars) + << "'" << domain_sorts.size() << "'"; + for (size_t i = 0; i < size; ++i) + { + CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( + this == bound_vars[i].d_solver, "bound variable", bound_vars[i], i) + << "bound variable associated to this solver object"; + CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( + bound_vars[i].d_expr->getKind() == CVC4::Kind::BOUND_VARIABLE, + "bound variable", + bound_vars[i], + i) + << "a bound variable"; + CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( + domain_sorts[i] == bound_vars[i].getSort(), + "sort of parameter", + bound_vars[i], + i) + << "'" << domain_sorts[i] << "'"; + } + Sort codomain = fun.getSort().getFunctionCodomainSort(); + CVC4_API_CHECK(codomain == term.getSort()) + << "Invalid sort of function body '" << term << "', expected '" + << codomain << "'"; } + else + { + CVC4_API_ARG_CHECK_EXPECTED(bound_vars.size() == 0, fun) + << "function or nullary symbol"; + } + CVC4_API_SOLVER_CHECK_TERM(term); - Sort codomain = fun.getSort().getFunctionCodomainSort(); - CVC4_API_CHECK(codomain == term.getSort()) - << "Invalid sort of function body '" << term << "', expected '" - << codomain << "'"; std::vector ebound_vars = termVectorToExprs(bound_vars); d_smtEngine->defineFunctionRec( *fun.d_expr, ebound_vars, *term.d_expr, global); @@ -4402,6 +4436,14 @@ void Solver::defineFunsRec(const std::vector& funs, bool global) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; + + CVC4_API_CHECK(d_smtEngine->getUserLogicInfo().isQuantified()) + << "recursive function definitions require a logic with quantifiers"; + CVC4_API_CHECK( + d_smtEngine->getUserLogicInfo().isTheoryEnabled(theory::THEORY_UF)) + << "recursive function definitions require a logic with uninterpreted " + "functions"; + size_t funs_size = funs.size(); CVC4_API_ARG_SIZE_CHECK_EXPECTED(funs_size == bound_vars.size(), bound_vars) << "'" << funs_size << "'"; @@ -4414,38 +4456,46 @@ void Solver::defineFunsRec(const std::vector& funs, CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( this == fun.d_solver, "function", fun, j) << "function associated to this solver object"; - CVC4_API_ARG_CHECK_EXPECTED(fun.getSort().isFunction(), fun) << "function"; CVC4_API_SOLVER_CHECK_TERM(term); - std::vector domain_sorts = fun.getSort().getFunctionDomainSorts(); - size_t size = bvars.size(); - CVC4_API_ARG_SIZE_CHECK_EXPECTED(size == domain_sorts.size(), bvars) - << "'" << domain_sorts.size() << "'"; - for (size_t i = 0; i < size; ++i) + if (fun.getSort().isFunction()) { - for (size_t k = 0, nbvars = bvars.size(); k < nbvars; ++k) + std::vector domain_sorts = fun.getSort().getFunctionDomainSorts(); + size_t size = bvars.size(); + CVC4_API_ARG_SIZE_CHECK_EXPECTED(size == domain_sorts.size(), bvars) + << "'" << domain_sorts.size() << "'"; + for (size_t i = 0; i < size; ++i) { + for (size_t k = 0, nbvars = bvars.size(); k < nbvars; ++k) + { + CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( + this == bvars[k].d_solver, "bound variable", bvars[k], k) + << "bound variable associated to this solver object"; + CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( + bvars[k].d_expr->getKind() == CVC4::Kind::BOUND_VARIABLE, + "bound variable", + bvars[k], + k) + << "a bound variable"; + } CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( - this == bvars[k].d_solver, "bound variable", bvars[k], k) - << "bound variable associated to this solver object"; - CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( - bvars[k].d_expr->getKind() == CVC4::Kind::BOUND_VARIABLE, - "bound variable", - bvars[k], - k) - << "a bound variable"; + domain_sorts[i] == bvars[i].getSort(), + "sort of parameter", + bvars[i], + i) + << "'" << domain_sorts[i] << "' in parameter bound_vars[" << j + << "]"; } + Sort codomain = fun.getSort().getFunctionCodomainSort(); CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( - domain_sorts[i] == bvars[i].getSort(), - "sort of parameter", - bvars[i], - i) - << "'" << domain_sorts[i] << "' in parameter bound_vars[" << j << "]"; + codomain == term.getSort(), "sort of function body", term, j) + << "'" << codomain << "'"; + } + else + { + CVC4_API_ARG_CHECK_EXPECTED(bvars.size() == 0, fun) + << "function or nullary symbol"; } - Sort codomain = fun.getSort().getFunctionCodomainSort(); - CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( - codomain == term.getSort(), "sort of function body", term, j) - << "'" << codomain << "'"; } std::vector efuns = termVectorToExprs(funs); std::vector> ebound_vars; diff --git a/src/parser/cvc/Cvc.g b/src/parser/cvc/Cvc.g index d1fcc4e06..b504d290b 100644 --- a/src/parser/cvc/Cvc.g +++ b/src/parser/cvc/Cvc.g @@ -935,15 +935,8 @@ mainCommand[std::unique_ptr* cmd] PARSER_STATE->parseError("Type mismatch in definition"); } } - std::vector> eformals; - for (unsigned i=0, fsize = formals.size(); ireset( - new DefineFunctionRecCommand(api::termVectorToExprs(funcs), - eformals, - api::termVectorToExprs(formulas), true)); + new DefineFunctionRecCommand(SOLVER, funcs, formals, formulas, true)); } | toplevelDeclaration[cmd] ; diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 98e416969..a11b9670b 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -1208,7 +1208,7 @@ smt25Command[std::unique_ptr* cmd] expr = PARSER_STATE->mkHoApply( expr, flattenVars ); } cmd->reset(new DefineFunctionRecCommand( - func.getExpr(), api::termVectorToExprs(bvs), expr.getExpr(), PARSER_STATE->getGlobalDeclarations())); + SOLVER, func, bvs, expr, PARSER_STATE->getGlobalDeclarations())); } | DEFINE_FUNS_REC_TOK { PARSER_STATE->checkThatLogicIsSet();} @@ -1271,15 +1271,12 @@ smt25Command[std::unique_ptr* cmd] "Number of functions defined does not match number listed in " "define-funs-rec")); } - std::vector> eformals; - for (unsigned i=0, fsize = formals.size(); ireset( - new DefineFunctionRecCommand(api::termVectorToExprs(funcs), - eformals, - api::termVectorToExprs(func_defs), PARSER_STATE->getGlobalDeclarations())); + new DefineFunctionRecCommand(SOLVER, + funcs, + formals, + func_defs, + PARSER_STATE->getGlobalDeclarations())); } ; diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 6a2f220ec..b88e53788 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -21,6 +21,7 @@ #include #include +#include "api/cvc4cpp.h" #include "expr/dtype.h" #include "expr/node_manager_attributes.h" #include "expr/node_visitor.h" @@ -1666,8 +1667,8 @@ static void toStream(std::ostream& out, const DefineFunctionCommand* c) static void toStream(std::ostream& out, const DefineFunctionRecCommand* c) { - const vector& funcs = c->getFunctions(); - const vector >& formals = c->getFormals(); + const vector& funcs = c->getFunctions(); + const vector >& formals = c->getFormals(); out << "(define-fun"; if (funcs.size() > 1) { @@ -1690,10 +1691,10 @@ static void toStream(std::ostream& out, const DefineFunctionRecCommand* c) } out << funcs[i] << " ("; // print its type signature - vector::const_iterator itf = formals[i].begin(); + vector::const_iterator itf = formals[i].begin(); for (;;) { - out << "(" << (*itf) << " " << (*itf).getType() << ")"; + out << "(" << (*itf) << " " << (*itf).getSort() << ")"; ++itf; if (itf != formals[i].end()) { @@ -1704,8 +1705,8 @@ static void toStream(std::ostream& out, const DefineFunctionRecCommand* c) break; } } - Type type = funcs[i].getType(); - type = static_cast(type).getRangeType(); + api::Sort type = funcs[i].getSort(); + type = type.getFunctionCodomainSort(); out << ") " << type; if (funcs.size() > 1) { @@ -1716,7 +1717,7 @@ static void toStream(std::ostream& out, const DefineFunctionRecCommand* c) { out << ") ("; } - const vector& formulas = c->getFormulas(); + const vector& formulas = c->getFormulas(); for (unsigned i = 0, size = formulas.size(); i < size; i++) { if (i > 0) diff --git a/src/smt/command.cpp b/src/smt/command.cpp index 566772508..962882309 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -23,6 +23,7 @@ #include #include +#include "api/cvc4cpp.h" #include "base/check.h" #include "base/output.h" #include "expr/expr_iomanip.h" @@ -134,7 +135,13 @@ std::ostream& operator<<(std::ostream& out, CommandPrintSuccess cps) /* class Command */ /* -------------------------------------------------------------------------- */ -Command::Command() : d_commandStatus(NULL), d_muted(false) {} +Command::Command() : d_commandStatus(nullptr), d_muted(false) {} + +Command::Command(api::Solver* solver) + : d_solver(solver), d_commandStatus(nullptr), d_muted(false) +{ +} + Command::Command(const Command& cmd) { d_commandStatus = @@ -1387,8 +1394,12 @@ Command* DefineNamedFunctionCommand::clone() const /* -------------------------------------------------------------------------- */ DefineFunctionRecCommand::DefineFunctionRecCommand( - Expr func, const std::vector& formals, Expr formula, bool global) - : d_global(global) + api::Solver* solver, + api::Term func, + const std::vector& formals, + api::Term formula, + bool global) + : Command(solver), d_global(global) { d_funcs.push_back(func); d_formals.push_back(formals); @@ -1396,26 +1407,31 @@ DefineFunctionRecCommand::DefineFunctionRecCommand( } DefineFunctionRecCommand::DefineFunctionRecCommand( - const std::vector& funcs, - const std::vector>& formals, - const std::vector& formulas, + api::Solver* solver, + const std::vector& funcs, + const std::vector>& formals, + const std::vector& formulas, bool global) - : d_funcs(funcs), d_formals(formals), d_formulas(formulas), d_global(global) + : Command(solver), + d_funcs(funcs), + d_formals(formals), + d_formulas(formulas), + d_global(global) { } -const std::vector& DefineFunctionRecCommand::getFunctions() const +const std::vector& DefineFunctionRecCommand::getFunctions() const { return d_funcs; } -const std::vector>& DefineFunctionRecCommand::getFormals() - const +const std::vector>& +DefineFunctionRecCommand::getFormals() const { return d_formals; } -const std::vector& DefineFunctionRecCommand::getFormulas() const +const std::vector& DefineFunctionRecCommand::getFormulas() const { return d_formulas; } @@ -1424,7 +1440,7 @@ void DefineFunctionRecCommand::invoke(SmtEngine* smtEngine) { try { - smtEngine->defineFunctionsRec(d_funcs, d_formals, d_formulas, d_global); + d_solver->defineFunsRec(d_funcs, d_formals, d_formulas, d_global); d_commandStatus = CommandSuccess::instance(); } catch (exception& e) @@ -1436,35 +1452,13 @@ void DefineFunctionRecCommand::invoke(SmtEngine* smtEngine) Command* DefineFunctionRecCommand::exportTo( ExprManager* exprManager, ExprManagerMapCollection& variableMap) { - std::vector funcs; - for (unsigned i = 0, size = d_funcs.size(); i < size; i++) - { - Expr func = d_funcs[i].exportTo( - exprManager, variableMap, /* flags = */ ExprManager::VAR_FLAG_DEFINED); - funcs.push_back(func); - } - std::vector> formals; - for (unsigned i = 0, size = d_formals.size(); i < size; i++) - { - std::vector formals_c; - transform(d_formals[i].begin(), - d_formals[i].end(), - back_inserter(formals_c), - ExportTransformer(exprManager, variableMap)); - formals.push_back(formals_c); - } - std::vector formulas; - for (unsigned i = 0, size = d_formulas.size(); i < size; i++) - { - Expr formula = d_formulas[i].exportTo(exprManager, variableMap); - formulas.push_back(formula); - } - return new DefineFunctionRecCommand(funcs, formals, formulas, d_global); + Unimplemented(); } Command* DefineFunctionRecCommand::clone() const { - return new DefineFunctionRecCommand(d_funcs, d_formals, d_formulas, d_global); + return new DefineFunctionRecCommand( + d_solver, d_funcs, d_formals, d_formulas, d_global); } std::string DefineFunctionRecCommand::getCommandName() const diff --git a/src/smt/command.h b/src/smt/command.h index e09dfe490..f7c780dee 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -39,6 +39,11 @@ namespace CVC4 { +namespace api { +class Solver; +class Term; +} // namespace api + class SmtEngine; class Command; class CommandStatus; @@ -186,27 +191,11 @@ class CVC4_PUBLIC CommandRecoverableFailure : public CommandStatus class CVC4_PUBLIC Command { - protected: - /** - * This field contains a command status if the command has been - * invoked, or NULL if it has not. This field is either a - * dynamically-allocated pointer, or it's a pointer to the singleton - * CommandSuccess instance. Doing so is somewhat asymmetric, but - * it avoids the need to dynamically allocate memory in the common - * case of a successful command. - */ - const CommandStatus* d_commandStatus; - - /** - * True if this command is "muted"---i.e., don't print "success" on - * successful execution. - */ - bool d_muted; - public: typedef CommandPrintSuccess printsuccess; Command(); + Command(api::Solver* solver); Command(const Command& cmd); virtual ~Command(); @@ -281,6 +270,25 @@ class CVC4_PUBLIC Command Expr operator()(Expr e) { return e.exportTo(d_exprManager, d_variableMap); } Type operator()(Type t) { return t.exportTo(d_exprManager, d_variableMap); } }; /* class Command::ExportTransformer */ + + /** The solver instance that this command is associated with. */ + api::Solver* d_solver; + + /** + * This field contains a command status if the command has been + * invoked, or NULL if it has not. This field is either a + * dynamically-allocated pointer, or it's a pointer to the singleton + * CommandSuccess instance. Doing so is somewhat asymmetric, but + * it avoids the need to dynamically allocate memory in the common + * case of a successful command. + */ + const CommandStatus* d_commandStatus; + + /** + * True if this command is "muted"---i.e., don't print "success" on + * successful execution. + */ + bool d_muted; }; /* class Command */ /** @@ -498,18 +506,20 @@ class CVC4_PUBLIC DefineNamedFunctionCommand : public DefineFunctionCommand class CVC4_PUBLIC DefineFunctionRecCommand : public Command { public: - DefineFunctionRecCommand(Expr func, - const std::vector& formals, - Expr formula, + DefineFunctionRecCommand(api::Solver* solver, + api::Term func, + const std::vector& formals, + api::Term formula, bool global); - DefineFunctionRecCommand(const std::vector& funcs, - const std::vector >& formals, - const std::vector& formula, + DefineFunctionRecCommand(api::Solver* solver, + const std::vector& funcs, + const std::vector >& formals, + const std::vector& formula, bool global); - const std::vector& getFunctions() const; - const std::vector >& getFormals() const; - const std::vector& getFormulas() const; + const std::vector& getFunctions() const; + const std::vector >& getFormals() const; + const std::vector& getFormulas() const; void invoke(SmtEngine* smtEngine) override; Command* exportTo(ExprManager* exprManager, @@ -519,11 +529,11 @@ class CVC4_PUBLIC DefineFunctionRecCommand : public Command protected: /** functions we are defining */ - std::vector d_funcs; + std::vector d_funcs; /** formal arguments for each of the functions we are defining */ - std::vector > d_formals; + std::vector > d_formals; /** formulas corresponding to the bodies of the functions we are defining */ - std::vector d_formulas; + std::vector d_formulas; /** * Stores whether this definition is global (i.e. should persist when * popping the user context. diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index ccdf40393..b826ef23d 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -29,6 +29,7 @@ #include #include +#include "api/cvc4cpp.h" #include "base/check.h" #include "base/configuration.h" #include "base/configuration_private.h" @@ -906,6 +907,7 @@ void SmtEngine::setLogic(const LogicInfo& logic) "finished initializing."); } d_logic = logic; + d_userLogic = logic; setLogicInternal(); } @@ -932,6 +934,14 @@ void SmtEngine::setLogic(const char* logic) { setLogic(string(logic)); } LogicInfo SmtEngine::getLogicInfo() const { return d_logic; } +LogicInfo SmtEngine::getUserLogicInfo() const +{ + // Lock the logic to make sure that this logic can be queried. We create a + // copy of the user logic here to keep this method const. + LogicInfo res = d_userLogic; + res.lock(); + return res; +} void SmtEngine::setFilename(std::string filename) { d_filename = filename; } std::string SmtEngine::getFilename() const { return d_filename; } void SmtEngine::setLogicInternal() @@ -940,6 +950,7 @@ void SmtEngine::setLogicInternal() << "setting logic in SmtEngine but the engine has already" " finished initializing for this run"; d_logic.lock(); + d_userLogic.lock(); } void SmtEngine::setProblemExtended() @@ -1267,8 +1278,16 @@ void SmtEngine::defineFunctionsRec( if (Dump.isOn("raw-benchmark")) { + std::vector tFuncs = api::exprVectorToTerms(d_solver, funcs); + std::vector> tFormals; + for (const std::vector& formal : formals) + { + tFormals.emplace_back(api::exprVectorToTerms(d_solver, formal)); + } + std::vector tFormulas = + api::exprVectorToTerms(d_solver, formulas); Dump("raw-benchmark") << DefineFunctionRecCommand( - funcs, formals, formulas, global); + d_solver, tFuncs, tFormals, tFormulas, global); } ExprManager* em = getExprManager(); diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index c7a37c007..b1e3313d8 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -65,6 +65,12 @@ class StatisticsRegistry; /* -------------------------------------------------------------------------- */ +namespace api { +class Solver; +} // namespace api + +/* -------------------------------------------------------------------------- */ + namespace context { class Context; class UserContext; @@ -126,6 +132,7 @@ namespace theory { class CVC4_PUBLIC SmtEngine { + friend class ::CVC4::api::Solver; // TODO (Issue #1096): Remove this friend relationship. friend class ::CVC4::preprocessing::PreprocessingPassContext; friend class ::CVC4::smt::SmtEnginePrivate; @@ -207,6 +214,9 @@ class CVC4_PUBLIC SmtEngine /** Get the logic information currently set. */ LogicInfo getLogicInfo() const; + /** Get the logic information set by the user. */ + LogicInfo getUserLogicInfo() const; + /** * Set information about the script executing. * @throw OptionException, ModalException @@ -875,6 +885,9 @@ class CVC4_PUBLIC SmtEngine SmtEngine(const SmtEngine&) = delete; SmtEngine& operator=(const SmtEngine&) = delete; + /** Set solver instance that owns this SmtEngine. */ + void setSolver(api::Solver* solver) { d_solver = solver; } + /** Get a pointer to the TheoryEngine owned by this SmtEngine. */ TheoryEngine* getTheoryEngine() { return d_theoryEngine.get(); } @@ -1082,6 +1095,9 @@ class CVC4_PUBLIC SmtEngine /* Members -------------------------------------------------------------- */ + /** Solver instance that owns this SmtEngine instance. */ + api::Solver* d_solver = nullptr; + /** Expr manager context */ std::unique_ptr d_context; /** User level context */ @@ -1197,10 +1213,14 @@ class CVC4_PUBLIC SmtEngine std::vector d_defineCommands; /** - * The logic we're in. + * The logic we're in. This logic may be an extension of the logic set by the + * user. */ LogicInfo d_logic; + /** The logic set by the user. */ + LogicInfo d_userLogic; + /** * Keep a copy of the original option settings (for reset()). */ diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index cdac5fc6c..5ae66f203 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -920,6 +920,7 @@ set(regress_0_tests regress0/simplification_bug2.smtv1.smt2 regress0/smallcnf.cvc regress0/smt2output.smt2 + regress0/smtlib/define-fun-rec-logic.smt2 regress0/smtlib/get-unsat-assumptions.smt2 regress0/smtlib/global-decls.smt2 regress0/smtlib/issue4028.smt2 diff --git a/test/regress/regress0/smtlib/define-fun-rec-logic.smt2 b/test/regress/regress0/smtlib/define-fun-rec-logic.smt2 new file mode 100644 index 000000000..1af16210d --- /dev/null +++ b/test/regress/regress0/smtlib/define-fun-rec-logic.smt2 @@ -0,0 +1,12 @@ +; EXPECT: unsat +; EXPECT: (error "recursive function definitions require a logic with quantifiers") +; EXIT: 1 +(set-logic UFBV) +(define-funs-rec ((f ((b Bool)) Bool) (g ((b Bool)) Bool)) (b b)) +(assert (g false)) +(check-sat) +(reset) +(set-logic QF_BV) +(define-funs-rec ((f ((b Bool)) Bool) (g ((b Bool)) Bool)) (b b)) +(assert (g false)) +(check-sat) diff --git a/test/unit/api/solver_black.h b/test/unit/api/solver_black.h index ff0040024..6faab6075 100644 --- a/test/unit/api/solver_black.h +++ b/test/unit/api/solver_black.h @@ -86,8 +86,10 @@ class SolverBlack : public CxxTest::TestSuite void testDefineFun(); void testDefineFunGlobal(); void testDefineFunRec(); + void testDefineFunRecWrongLogic(); void testDefineFunRecGlobal(); void testDefineFunsRec(); + void testDefineFunsRecWrongLogic(); void testDefineFunsRecGlobal(); void testUFIteration(); @@ -1117,6 +1119,19 @@ void SolverBlack::testDefineFunRec() CVC4ApiException&); } +void SolverBlack::testDefineFunRecWrongLogic() +{ + d_solver->setLogic("QF_BV"); + Sort bvSort = d_solver->mkBitVectorSort(32); + Sort funSort = d_solver->mkFunctionSort({bvSort, bvSort}, bvSort); + Term b = d_solver->mkVar(bvSort, "b"); + Term v = d_solver->mkConst(bvSort, "v"); + Term f = d_solver->mkConst(funSort, "f"); + TS_ASSERT_THROWS(d_solver->defineFunRec("f", {}, bvSort, v), + CVC4ApiException&); + TS_ASSERT_THROWS(d_solver->defineFunRec(f, {b, b}, v), CVC4ApiException&); +} + void SolverBlack::testDefineFunRecGlobal() { Sort bSort = d_solver->getBooleanSort(); @@ -1214,6 +1229,23 @@ void SolverBlack::testDefineFunsRec() CVC4ApiException&); } +void SolverBlack::testDefineFunsRecWrongLogic() +{ + d_solver->setLogic("QF_BV"); + Sort uSort = d_solver->mkUninterpretedSort("u"); + Sort bvSort = d_solver->mkBitVectorSort(32); + Sort funSort1 = d_solver->mkFunctionSort({bvSort, bvSort}, bvSort); + Sort funSort2 = d_solver->mkFunctionSort(uSort, d_solver->getIntegerSort()); + Term b = d_solver->mkVar(bvSort, "b"); + Term u = d_solver->mkVar(uSort, "u"); + Term v1 = d_solver->mkConst(bvSort, "v1"); + Term v2 = d_solver->mkConst(d_solver->getIntegerSort(), "v2"); + Term f1 = d_solver->mkConst(funSort1, "f1"); + Term f2 = d_solver->mkConst(funSort2, "f2"); + TS_ASSERT_THROWS(d_solver->defineFunsRec({f1, f2}, {{b, b}, {u}}, {v1, v2}), + CVC4ApiException&); +} + void SolverBlack::testDefineFunsRecGlobal() { Sort bSort = d_solver->getBooleanSort();