Add declareOracleFun to API (#8794)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 24 May 2022 14:07:42 +0000 (09:07 -0500)
committerGitHub <noreply@github.com>
Tue, 24 May 2022 14:07:42 +0000 (14:07 +0000)
Java and Python will be added in followup PRs.

NEWS.md
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5.h
src/smt/set_defaults.cpp
src/theory/quantifiers/oracle_engine.cpp
test/unit/api/cpp/solver_black.cpp

diff --git a/NEWS.md b/NEWS.md
index fd4ae1b7fd361097e466735c709a95c90d480c1c..1e6e08e29fcb68c7d8eb46c65a2a7faf08022359 100644 (file)
--- a/NEWS.md
+++ b/NEWS.md
@@ -20,6 +20,11 @@ cvc5 1.0.1
 - The API method `mkTuple` no longer supports casting integers to reals when
   constructing tuples.
 
+**New Features**
+
+- Support for declaring oracle functions in the API via the method
+  `declareOracleFun`. This allows users to declare functions whose semantics
+  are associated with a provided executable implementation.
 
 cvc5 1.0
 =========
index 35be11a649c8fc76d4768a6386af59692882a178..980e6468b4131b4c3fb9b7657e2c10781e2079e0 100644 (file)
@@ -62,8 +62,8 @@
 #include "options/option_exception.h"
 #include "options/options.h"
 #include "options/options_public.h"
-#include "options/smt_options.h"
 #include "options/quantifiers_options.h"
+#include "options/smt_options.h"
 #include "proof/unsat_core.h"
 #include "smt/env.h"
 #include "smt/model.h"
@@ -7027,6 +7027,40 @@ Term Solver::declarePool(const std::string& symbol,
   CVC5_API_TRY_CATCH_END;
 }
 
+Term Solver::declareOracleFun(
+    const std::string& symbol,
+    const std::vector<Sort>& sorts,
+    const Sort& sort,
+    std::function<Term(const std::vector<Term>&)> fn) const
+{
+  CVC5_API_TRY_CATCH_BEGIN;
+  CVC5_API_SOLVER_CHECK_DOMAIN_SORTS(sorts);
+  CVC5_API_SOLVER_CHECK_CODOMAIN_SORT(sort);
+  CVC5_API_CHECK(d_slv->getOptions().quantifiers.oracles)
+      << "Cannot call declareOracleFun unless oracles is enabled (use "
+         "--oracles)";
+  //////// all checks before this line
+  internal::TypeNode type = *sort.d_type;
+  if (!sorts.empty())
+  {
+    std::vector<internal::TypeNode> types = Sort::sortVectorToTypeNodes(sorts);
+    type = d_nodeMgr->mkFunctionType(types, type);
+  }
+  internal::Node fun = d_nodeMgr->mkVar(symbol, type);
+  // Wrap the terms-to-term function so that it is nodes-to-nodes. Note we
+  // make the method return a vector of size one to conform to the interface
+  // at the SolverEngine level.
+  d_slv->declareOracleFun(
+      fun, [&, fn](const std::vector<internal::Node> nodes) {
+        std::vector<Term> terms = Term::nodeVectorToTerms(this, nodes);
+        Term output = fn(terms);
+        return Term::termVectorToNodes({output});
+      });
+  return Term(this, fun);
+  ////////
+  CVC5_API_TRY_CATCH_END;
+}
+
 void Solver::pop(uint32_t nscopes) const
 {
   CVC5_API_TRY_CATCH_BEGIN;
index 6ad7b6a392443fad6b911a4308654a084775dfcd..5ba737d7bbe45a65f345c1987a83ef2462513f1e 100644 (file)
@@ -18,6 +18,7 @@
 #ifndef CVC5__API__CVC5_H
 #define CVC5__API__CVC5_H
 
+#include <functional>
 #include <map>
 #include <memory>
 #include <optional>
@@ -4435,6 +4436,36 @@ class CVC5_EXPORT Solver
   Term declarePool(const std::string& symbol,
                    const Sort& sort,
                    const std::vector<Term>& initValue) const;
+  /**
+   * Declare an oracle function with reference to an implementation.
+   *
+   * Oracle functions have a different semantics with respect to ordinary
+   * declared functions. In particular, for an input to be satisfiable,
+   * its oracle functions are implicitly universally quantified.
+   *
+   * This method is used in part for implementing this command:
+   *
+   * \verbatim embed:rst:leading-asterisk
+   * .. code:: smtlib
+   *
+   * (declare-oracle-fun <sym> (<sort>*) <sort> <sym>)
+   * \endverbatim
+   *
+   * In particular, the above command is implemented by constructing a
+   * function over terms that wraps a call to binary sym via a text interface.
+   *
+   * @warning This method is experimental and may change in future versions.
+   *
+   * @param symbol The name of the oracle
+   * @param sorts The sorts of the parameters to this function
+   * @param sort The sort of the return value of this function
+   * @param fn The function that implements the oracle function.
+   * @return The oracle function
+   */
+  Term declareOracleFun(const std::string& symbol,
+                        const std::vector<Sort>& sorts,
+                        const Sort& sort,
+                        std::function<Term(const std::vector<Term>&)> fn) const;
   /**
    * Pop (a) level(s) from the assertion stack.
    *
index dc26303adfe5d9642baff82b227dd58689cb7f34..d9a6f5fbcd21295d7ea9d029e1bf979f173a6204 100644 (file)
@@ -63,11 +63,6 @@ void SetDefaults::setDefaults(LogicInfo& logic, Options& opts)
 
 void SetDefaults::setDefaultsPre(Options& opts)
 {
-
-  if (opts.quantifiers.oracles)
-  {
-    throw OptionException(std::string("Oracles not yet supported"));
-  }
   // implied options
   if (opts.smt.debugCheckModels)
   {
index 4b0783fdbd0faa78d8a511213a2a01ed45ce3d6d..9ad68ed885654583957a43916f05134ff91f7c48 100644 (file)
@@ -209,21 +209,23 @@ void OracleEngine::checkOwnership(Node q)
   {
     std::vector<Node> inputs, outputs;
     Node assume, constraint, oracle;
-    getOracleInterface(q, inputs, outputs, assume, constraint, oracle);
-    Assert(constraint.isConst() && constraint.getConst<bool>())
-        << "Unhandled oracle constraint " << q;
-    CVC5_UNUSED bool isOracleFun = false;
-    if (OracleCaller::isOracleFunctionApp(assume))
+    if (!getOracleInterface(q, inputs, outputs, assume, constraint, oracle))
+    {
+      Assert(false) << "Not an oracle interface " << q;
+    }
+    else
     {
-      // predicate case
-      isOracleFun = true;
+      Assert(outputs.size() == 1) << "Unhandled oracle constraint " << q;
+      Assert(constraint.isConst() && constraint.getConst<bool>())
+          << "Unhandled oracle constraint " << q;
     }
-    else if (assume.getKind() == EQUAL)
+    CVC5_UNUSED bool isOracleFun = false;
+    if (assume.getKind() == EQUAL)
     {
       for (size_t i = 0; i < 2; i++)
       {
         if (OracleCaller::isOracleFunctionApp(assume[i])
-            && assume[1 - i].isConst())
+            && assume[1 - i] == outputs[0])
         {
           isOracleFun = true;
         }
@@ -295,23 +297,24 @@ bool OracleEngine::getOracleInterface(Node q,
     OracleInputVarAttribute oiva;
     for (const Node& v : q[0])
     {
-      if (v.hasAttribute(oiva))
+      if (v.getAttribute(oiva))
       {
         inputs.push_back(v);
       }
       else
       {
-        Assert(v.hasAttribute(OracleOutputVarAttribute()));
+        Assert(v.getAttribute(OracleOutputVarAttribute()));
         outputs.push_back(v);
       }
     }
     Assert(q[1].getKind() == ORACLE_FORMULA_GEN);
     assume = q[1][0];
-    constraint = q[1][0];
+    constraint = q[1][1];
     Assert(q.getNumChildren() == 3);
     Assert(q[2].getNumChildren() == 1);
-    Assert(q[2][0].getKind() == ORACLE);
-    oracleNode = q[2][0];
+    Assert(q[2][0].getNumChildren() == 1);
+    Assert(q[2][0][0].getKind() == ORACLE);
+    oracleNode = q[2][0][0];
     return true;
   }
   return false;
index b9627073c5bd7caf1bd5aefb66aa29adff9ac67a..e3bf8c074129341397d6cd351a17389276ecd1d1 100644 (file)
@@ -3439,5 +3439,99 @@ TEST_F(TestApiBlackSolver, projIssue337)
   ASSERT_EQ(t.getSort(), tt.getSort());
 }
 
+TEST_F(TestApiBlackSolver, declareOracleFunError)
+{
+  Sort iSort = d_solver.getIntegerSort();
+  // cannot declare without option
+  ASSERT_THROW(d_solver.declareOracleFun(
+      "f",
+      {iSort},
+      iSort,
+      [&](const std::vector<Term>& input) { return d_solver.mkInteger(0); });
+               , CVC5ApiException);
+  d_solver.setOption("oracles", "true");
+  Sort nullSort;
+  // bad sort
+  ASSERT_THROW(d_solver.declareOracleFun(
+      "f",
+      {nullSort},
+      iSort,
+      [&](const std::vector<Term>& input) { return d_solver.mkInteger(0); });
+               , CVC5ApiException);
+}
+
+TEST_F(TestApiBlackSolver, declareOracleFunUnsat)
+{
+  d_solver.setOption("oracles", "true");
+  Sort iSort = d_solver.getIntegerSort();
+  // f is the function implementing (lambda ((x Int)) (+ x 1))
+  Term f = d_solver.declareOracleFun(
+      "f", {iSort}, iSort, [&](const std::vector<Term>& input) {
+        if (input[0].isUInt32Value())
+        {
+          return d_solver.mkInteger(input[0].getUInt32Value() + 1);
+        }
+        return d_solver.mkInteger(0);
+      });
+  Term three = d_solver.mkInteger(3);
+  Term five = d_solver.mkInteger(5);
+  Term eq =
+      d_solver.mkTerm(EQUAL, {d_solver.mkTerm(APPLY_UF, {f, three}), five});
+  d_solver.assertFormula(eq);
+  // (f 3) = 5
+  ASSERT_TRUE(d_solver.checkSat().isUnsat());
+}
+
+TEST_F(TestApiBlackSolver, declareOracleFunSat)
+{
+  d_solver.setOption("oracles", "true");
+  d_solver.setOption("produce-models", "true");
+  Sort iSort = d_solver.getIntegerSort();
+  // f is the function implementing (lambda ((x Int)) (% x 10))
+  Term f = d_solver.declareOracleFun(
+      "f", {iSort}, iSort, [&](const std::vector<Term>& input) {
+        if (input[0].isUInt32Value())
+        {
+          return d_solver.mkInteger(input[0].getUInt32Value() % 10);
+        }
+        return d_solver.mkInteger(0);
+      });
+  Term seven = d_solver.mkInteger(7);
+  Term x = d_solver.mkConst(iSort, "x");
+  Term lb = d_solver.mkTerm(GEQ, {x, d_solver.mkInteger(0)});
+  d_solver.assertFormula(lb);
+  Term ub = d_solver.mkTerm(LEQ, {x, d_solver.mkInteger(100)});
+  d_solver.assertFormula(ub);
+  Term eq = d_solver.mkTerm(EQUAL, {d_solver.mkTerm(APPLY_UF, {f, x}), seven});
+  d_solver.assertFormula(eq);
+  // x >= 0 ^ x <= 100 ^ (f x) = 7
+  ASSERT_TRUE(d_solver.checkSat().isSat());
+  Term xval = d_solver.getValue(x);
+  ASSERT_TRUE(xval.isUInt32Value());
+  ASSERT_TRUE(xval.getUInt32Value() % 10 == 7);
+}
+
+TEST_F(TestApiBlackSolver, declareOracleFunSat2)
+{
+  d_solver.setOption("oracles", "true");
+  d_solver.setOption("produce-models", "true");
+  Sort iSort = d_solver.getIntegerSort();
+  Sort bSort = d_solver.getBooleanSort();
+  // f is the function implementing (lambda ((x Int) (y Int)) (= x y))
+  Term eq = d_solver.declareOracleFun(
+      "eq", {iSort, iSort}, bSort, [&](const std::vector<Term>& input) {
+        return d_solver.mkBoolean(input[0] == input[1]);
+      });
+  Term x = d_solver.mkConst(iSort, "x");
+  Term y = d_solver.mkConst(iSort, "y");
+  Term neq = d_solver.mkTerm(NOT, {d_solver.mkTerm(APPLY_UF, {eq, x, y})});
+  d_solver.assertFormula(neq);
+  // (not (eq x y))
+  ASSERT_TRUE(d_solver.checkSat().isSat());
+  Term xval = d_solver.getValue(x);
+  Term yval = d_solver.getValue(y);
+  ASSERT_TRUE(xval != yval);
+}
+
 }  // namespace test
 }  // namespace cvc5::internal