From 4da459b5bed88e7898ab030b8c2e8b2386e30b4e Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Tue, 24 May 2022 09:07:42 -0500 Subject: [PATCH] Add declareOracleFun to API (#8794) Java and Python will be added in followup PRs. --- NEWS.md | 5 ++ src/api/cpp/cvc5.cpp | 36 ++++++++- src/api/cpp/cvc5.h | 31 ++++++++ src/smt/set_defaults.cpp | 5 -- src/theory/quantifiers/oracle_engine.cpp | 31 ++++---- test/unit/api/cpp/solver_black.cpp | 94 ++++++++++++++++++++++++ 6 files changed, 182 insertions(+), 20 deletions(-) diff --git a/NEWS.md b/NEWS.md index fd4ae1b7f..1e6e08e29 100644 --- a/NEWS.md +++ b/NEWS.md @@ -20,6 +20,11 @@ cvc5 1.0.1 - The API method `mkTuple` no longer supports casting integers to reals when constructing tuples. +**New Features** + +- Support for declaring oracle functions in the API via the method + `declareOracleFun`. This allows users to declare functions whose semantics + are associated with a provided executable implementation. cvc5 1.0 ========= diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 35be11a64..980e6468b 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -62,8 +62,8 @@ #include "options/option_exception.h" #include "options/options.h" #include "options/options_public.h" -#include "options/smt_options.h" #include "options/quantifiers_options.h" +#include "options/smt_options.h" #include "proof/unsat_core.h" #include "smt/env.h" #include "smt/model.h" @@ -7027,6 +7027,40 @@ Term Solver::declarePool(const std::string& symbol, CVC5_API_TRY_CATCH_END; } +Term Solver::declareOracleFun( + const std::string& symbol, + const std::vector& sorts, + const Sort& sort, + std::function&)> fn) const +{ + CVC5_API_TRY_CATCH_BEGIN; + CVC5_API_SOLVER_CHECK_DOMAIN_SORTS(sorts); + CVC5_API_SOLVER_CHECK_CODOMAIN_SORT(sort); + CVC5_API_CHECK(d_slv->getOptions().quantifiers.oracles) + << "Cannot call declareOracleFun unless oracles is enabled (use " + "--oracles)"; + //////// all checks before this line + internal::TypeNode type = *sort.d_type; + if (!sorts.empty()) + { + std::vector types = Sort::sortVectorToTypeNodes(sorts); + type = d_nodeMgr->mkFunctionType(types, type); + } + internal::Node fun = d_nodeMgr->mkVar(symbol, type); + // Wrap the terms-to-term function so that it is nodes-to-nodes. Note we + // make the method return a vector of size one to conform to the interface + // at the SolverEngine level. + d_slv->declareOracleFun( + fun, [&, fn](const std::vector nodes) { + std::vector terms = Term::nodeVectorToTerms(this, nodes); + Term output = fn(terms); + return Term::termVectorToNodes({output}); + }); + return Term(this, fun); + //////// + CVC5_API_TRY_CATCH_END; +} + void Solver::pop(uint32_t nscopes) const { CVC5_API_TRY_CATCH_BEGIN; diff --git a/src/api/cpp/cvc5.h b/src/api/cpp/cvc5.h index 6ad7b6a39..5ba737d7b 100644 --- a/src/api/cpp/cvc5.h +++ b/src/api/cpp/cvc5.h @@ -18,6 +18,7 @@ #ifndef CVC5__API__CVC5_H #define CVC5__API__CVC5_H +#include #include #include #include @@ -4435,6 +4436,36 @@ class CVC5_EXPORT Solver Term declarePool(const std::string& symbol, const Sort& sort, const std::vector& initValue) const; + /** + * Declare an oracle function with reference to an implementation. + * + * Oracle functions have a different semantics with respect to ordinary + * declared functions. In particular, for an input to be satisfiable, + * its oracle functions are implicitly universally quantified. + * + * This method is used in part for implementing this command: + * + * \verbatim embed:rst:leading-asterisk + * .. code:: smtlib + * + * (declare-oracle-fun (*) ) + * \endverbatim + * + * In particular, the above command is implemented by constructing a + * function over terms that wraps a call to binary sym via a text interface. + * + * @warning This method is experimental and may change in future versions. + * + * @param symbol The name of the oracle + * @param sorts The sorts of the parameters to this function + * @param sort The sort of the return value of this function + * @param fn The function that implements the oracle function. + * @return The oracle function + */ + Term declareOracleFun(const std::string& symbol, + const std::vector& sorts, + const Sort& sort, + std::function&)> fn) const; /** * Pop (a) level(s) from the assertion stack. * diff --git a/src/smt/set_defaults.cpp b/src/smt/set_defaults.cpp index dc26303ad..d9a6f5fbc 100644 --- a/src/smt/set_defaults.cpp +++ b/src/smt/set_defaults.cpp @@ -63,11 +63,6 @@ void SetDefaults::setDefaults(LogicInfo& logic, Options& opts) void SetDefaults::setDefaultsPre(Options& opts) { - - if (opts.quantifiers.oracles) - { - throw OptionException(std::string("Oracles not yet supported")); - } // implied options if (opts.smt.debugCheckModels) { diff --git a/src/theory/quantifiers/oracle_engine.cpp b/src/theory/quantifiers/oracle_engine.cpp index 4b0783fdb..9ad68ed88 100644 --- a/src/theory/quantifiers/oracle_engine.cpp +++ b/src/theory/quantifiers/oracle_engine.cpp @@ -209,21 +209,23 @@ void OracleEngine::checkOwnership(Node q) { std::vector inputs, outputs; Node assume, constraint, oracle; - getOracleInterface(q, inputs, outputs, assume, constraint, oracle); - Assert(constraint.isConst() && constraint.getConst()) - << "Unhandled oracle constraint " << q; - CVC5_UNUSED bool isOracleFun = false; - if (OracleCaller::isOracleFunctionApp(assume)) + if (!getOracleInterface(q, inputs, outputs, assume, constraint, oracle)) + { + Assert(false) << "Not an oracle interface " << q; + } + else { - // predicate case - isOracleFun = true; + Assert(outputs.size() == 1) << "Unhandled oracle constraint " << q; + Assert(constraint.isConst() && constraint.getConst()) + << "Unhandled oracle constraint " << q; } - else if (assume.getKind() == EQUAL) + CVC5_UNUSED bool isOracleFun = false; + if (assume.getKind() == EQUAL) { for (size_t i = 0; i < 2; i++) { if (OracleCaller::isOracleFunctionApp(assume[i]) - && assume[1 - i].isConst()) + && assume[1 - i] == outputs[0]) { isOracleFun = true; } @@ -295,23 +297,24 @@ bool OracleEngine::getOracleInterface(Node q, OracleInputVarAttribute oiva; for (const Node& v : q[0]) { - if (v.hasAttribute(oiva)) + if (v.getAttribute(oiva)) { inputs.push_back(v); } else { - Assert(v.hasAttribute(OracleOutputVarAttribute())); + Assert(v.getAttribute(OracleOutputVarAttribute())); outputs.push_back(v); } } Assert(q[1].getKind() == ORACLE_FORMULA_GEN); assume = q[1][0]; - constraint = q[1][0]; + constraint = q[1][1]; Assert(q.getNumChildren() == 3); Assert(q[2].getNumChildren() == 1); - Assert(q[2][0].getKind() == ORACLE); - oracleNode = q[2][0]; + Assert(q[2][0].getNumChildren() == 1); + Assert(q[2][0][0].getKind() == ORACLE); + oracleNode = q[2][0][0]; return true; } return false; diff --git a/test/unit/api/cpp/solver_black.cpp b/test/unit/api/cpp/solver_black.cpp index b9627073c..e3bf8c074 100644 --- a/test/unit/api/cpp/solver_black.cpp +++ b/test/unit/api/cpp/solver_black.cpp @@ -3439,5 +3439,99 @@ TEST_F(TestApiBlackSolver, projIssue337) ASSERT_EQ(t.getSort(), tt.getSort()); } +TEST_F(TestApiBlackSolver, declareOracleFunError) +{ + Sort iSort = d_solver.getIntegerSort(); + // cannot declare without option + ASSERT_THROW(d_solver.declareOracleFun( + "f", + {iSort}, + iSort, + [&](const std::vector& input) { return d_solver.mkInteger(0); }); + , CVC5ApiException); + d_solver.setOption("oracles", "true"); + Sort nullSort; + // bad sort + ASSERT_THROW(d_solver.declareOracleFun( + "f", + {nullSort}, + iSort, + [&](const std::vector& input) { return d_solver.mkInteger(0); }); + , CVC5ApiException); +} + +TEST_F(TestApiBlackSolver, declareOracleFunUnsat) +{ + d_solver.setOption("oracles", "true"); + Sort iSort = d_solver.getIntegerSort(); + // f is the function implementing (lambda ((x Int)) (+ x 1)) + Term f = d_solver.declareOracleFun( + "f", {iSort}, iSort, [&](const std::vector& input) { + if (input[0].isUInt32Value()) + { + return d_solver.mkInteger(input[0].getUInt32Value() + 1); + } + return d_solver.mkInteger(0); + }); + Term three = d_solver.mkInteger(3); + Term five = d_solver.mkInteger(5); + Term eq = + d_solver.mkTerm(EQUAL, {d_solver.mkTerm(APPLY_UF, {f, three}), five}); + d_solver.assertFormula(eq); + // (f 3) = 5 + ASSERT_TRUE(d_solver.checkSat().isUnsat()); +} + +TEST_F(TestApiBlackSolver, declareOracleFunSat) +{ + d_solver.setOption("oracles", "true"); + d_solver.setOption("produce-models", "true"); + Sort iSort = d_solver.getIntegerSort(); + // f is the function implementing (lambda ((x Int)) (% x 10)) + Term f = d_solver.declareOracleFun( + "f", {iSort}, iSort, [&](const std::vector& input) { + if (input[0].isUInt32Value()) + { + return d_solver.mkInteger(input[0].getUInt32Value() % 10); + } + return d_solver.mkInteger(0); + }); + Term seven = d_solver.mkInteger(7); + Term x = d_solver.mkConst(iSort, "x"); + Term lb = d_solver.mkTerm(GEQ, {x, d_solver.mkInteger(0)}); + d_solver.assertFormula(lb); + Term ub = d_solver.mkTerm(LEQ, {x, d_solver.mkInteger(100)}); + d_solver.assertFormula(ub); + Term eq = d_solver.mkTerm(EQUAL, {d_solver.mkTerm(APPLY_UF, {f, x}), seven}); + d_solver.assertFormula(eq); + // x >= 0 ^ x <= 100 ^ (f x) = 7 + ASSERT_TRUE(d_solver.checkSat().isSat()); + Term xval = d_solver.getValue(x); + ASSERT_TRUE(xval.isUInt32Value()); + ASSERT_TRUE(xval.getUInt32Value() % 10 == 7); +} + +TEST_F(TestApiBlackSolver, declareOracleFunSat2) +{ + d_solver.setOption("oracles", "true"); + d_solver.setOption("produce-models", "true"); + Sort iSort = d_solver.getIntegerSort(); + Sort bSort = d_solver.getBooleanSort(); + // f is the function implementing (lambda ((x Int) (y Int)) (= x y)) + Term eq = d_solver.declareOracleFun( + "eq", {iSort, iSort}, bSort, [&](const std::vector& input) { + return d_solver.mkBoolean(input[0] == input[1]); + }); + Term x = d_solver.mkConst(iSort, "x"); + Term y = d_solver.mkConst(iSort, "y"); + Term neq = d_solver.mkTerm(NOT, {d_solver.mkTerm(APPLY_UF, {eq, x, y})}); + d_solver.assertFormula(neq); + // (not (eq x y)) + ASSERT_TRUE(d_solver.checkSat().isSat()); + Term xval = d_solver.getValue(x); + Term yval = d_solver.getValue(y); + ASSERT_TRUE(xval != yval); +} + } // namespace test } // namespace cvc5::internal -- 2.30.2