Implement internal support for (definitional) satisfiability modulo oracles (#8618)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 14 Apr 2022 22:53:28 +0000 (17:53 -0500)
committerGitHub <noreply@github.com>
Thu, 14 Apr 2022 22:53:28 +0000 (17:53 -0500)
Adds implementation of OracleEngine, which adds lemmas of the form (= (f c) d) on demand for I/O pairs (c,d) from oracle calls.

src/expr/oracle_caller.cpp
src/expr/oracle_caller.h
src/theory/inference_id.cpp
src/theory/inference_id.h
src/theory/quantifiers/oracle_engine.cpp
src/theory/quantifiers/oracle_engine.h

index dccddf7f0e8bf4ab2a68de334f07e23b100ac11c..5feb5068e0ab7a04ddc93e0c51c422d67bc7ab7c 100644 (file)
@@ -69,6 +69,16 @@ bool OracleCaller::isOracleFunction(Node f)
   return f.hasAttribute(theory::OracleInterfaceAttribute());
 }
 
+bool OracleCaller::isOracleFunctionApp(Node n)
+{
+  if (n.getKind() == kind::APPLY_UF)
+  {
+    return isOracleFunction(n.getOperator());
+  }
+  // possibly 0-ary
+  return isOracleFunction(n);
+}
+
 std::string OracleCaller::getBinaryName() const { return d_binaryName; }
 
 std::string OracleCaller::getBinaryNameFor(const Node& n)
index 449bc0a08503803004a491f1eb85d01e76f25792..294ec6f43c7ff3f10acca30b387e503af97e281a 100644 (file)
@@ -63,6 +63,8 @@ class OracleCaller
 
   /** is f an oracle function? */
   static bool isOracleFunction(Node f);
+  /** is n an oracle function application? */
+  static bool isOracleFunctionApp(Node n);
 
  private:
   /** name of binary */
index 22e693f9ff2f1608c7c608c2348c7d89102971ed..d3a79b8e3e066154f41e61f7c66d1f80ea0dcec3 100644 (file)
@@ -241,6 +241,8 @@ const char* toString(InferenceId i)
       return "QUANTIFIERS_CEGQI_VTS_UB_DELTA";
     case InferenceId::QUANTIFIERS_CEGQI_VTS_LB_INF:
       return "QUANTIFIERS_CEGQI_VTS_LB_INF";
+    case InferenceId::QUANTIFIERS_ORACLE_INTERFACE:
+      return "QUANTIFIERS_ORACLE_INTERFACE";
     case InferenceId::QUANTIFIERS_SYQI_CEX: return "QUANTIFIERS_SYQI_CEX";
     case InferenceId::QUANTIFIERS_SYQI_EVAL_UNFOLD:
       return "QUANTIFIERS_SYQI_EVAL_UNFOLD";
index 9f0692a6fb39ee6f59dafd45c1f836f94f2e7145..51ce478c40fd29c756f0d2b944e0abf14a642958 100644 (file)
@@ -354,6 +354,11 @@ enum class InferenceId
   QUANTIFIERS_CEGQI_VTS_UB_DELTA,
   // infinity > c
   QUANTIFIERS_CEGQI_VTS_LB_INF,
+  //-------------------- oracles
+  // A lemma generated by an oracle interface quantified formula.
+  // For example, (= (f c) d) where (c, d) is an I/O pair obtained from calling
+  // the oracle associated with oracle function f.
+  QUANTIFIERS_ORACLE_INTERFACE,
   //-------------------- syntax-guided instantiation
   // a counterexample lemma
   QUANTIFIERS_SYQI_CEX,
index 19cec1de51dcf2dc165b4905f1976b1d778e4e43..84856f8e72a94e14dd804e726884ad836c7ca5e5 100644 (file)
@@ -50,8 +50,12 @@ OracleEngine::OracleEngine(Env& env,
                            QuantifiersInferenceManager& qim,
                            QuantifiersRegistry& qr,
                            TermRegistry& tr)
-    : QuantifiersModule(env, qs, qim, qr, tr), d_oracleFuns(userContext())
+    : QuantifiersModule(env, qs, qim, qr, tr),
+      d_oracleFuns(userContext()),
+      d_ochecker(tr.getOracleChecker()),
+      d_consistencyCheckPassed(false)
 {
+  Assert(d_ochecker != nullptr);
 }
 
 void OracleEngine::presolve() {}
@@ -67,15 +71,169 @@ OracleEngine::QEffort OracleEngine::needsModel(Theory::Effort e)
   return QEFFORT_MODEL;
 }
 
-void OracleEngine::reset_round(Theory::Effort e) {}
+void OracleEngine::reset_round(Theory::Effort e)
+{
+  d_consistencyCheckPassed = false;
+}
 
 void OracleEngine::registerQuantifier(Node q) {}
 
-void OracleEngine::check(Theory::Effort e, QEffort quant_e) {}
+void OracleEngine::check(Theory::Effort e, QEffort quant_e)
+{
+  if (quant_e != QEFFORT_MODEL)
+  {
+    return;
+  }
 
-bool OracleEngine::checkCompleteFor(Node q) { return false; }
+  double clSet = 0;
+  if (TraceIsOn("oracle-engine"))
+  {
+    clSet = double(clock()) / double(CLOCKS_PER_SEC);
+    Trace("oracle-engine") << "---Oracle Engine Round, effort = " << e << "---"
+                           << std::endl;
+  }
+  FirstOrderModel* fm = d_treg.getModel();
+  TermDb* termDatabase = d_treg.getTermDatabase();
+  eq::EqualityEngine* eq = getEqualityEngine();
+  NodeManager* nm = NodeManager::currentNM();
+  unsigned nquant = fm->getNumAssertedQuantifiers();
+  std::vector<Node> currInterfaces;
+  for (unsigned i = 0; i < nquant; i++)
+  {
+    Node q = fm->getAssertedQuantifier(i);
+    if (d_qreg.getOwner(q) != this)
+    {
+      continue;
+    }
+    currInterfaces.push_back(q);
+  }
+  // Note that we currently ignore oracle interface quantified formulas, and
+  // look directly at the oracle functions. Note that:
+  // (1) The lemmas with InferenceId QUANTIFIERS_ORACLE_INTERFACE are not
+  // guarded by a quantified formula. This means that we are assuming that all
+  // oracle interface quantified formulas are top-level assertions. This is
+  // correct because we do not expose a way of embedding oracle interfaces into
+  // formulas at the user level.
+  // (2) We assume that oracle functions have associated oracle interface
+  // quantified formulas that are in currInterfaces.
+  // (3) We currently ignore oracle interface quantified formulas that are
+  // not associated with oracle functions.
+  //
+  // The current design choices above are due to the fact that our support is
+  // limited to "definitional SMTO" (see Polgreen et al 2022). In particular,
+  // we only support oracles that define I/O equalities for oracle functions
+  // only. The net effect of this class hence is to check the consistency of
+  // oracle functions, and allow "sat" or otherwise add a lemma with id
+  // QUANTIFIERS_ORACLE_INTERFACE.
+  std::vector<Node> learnedLemmas;
+  bool allFappsConsistent = true;
+  // iterate over oracle functions
+  for (const Node& f : d_oracleFuns)
+  {
+    TNodeTrie* tat = termDatabase->getTermArgTrie(f);
+    if (!tat)
+    {
+      continue;
+    }
+    std::vector<Node> apps = tat->getLeaves(f.getType().getArgTypes().size());
+    Trace("oracle-calls") << "Oracle fun " << f << " with " << apps.size()
+                          << " applications." << std::endl;
+    for (const auto& fapp : apps)
+    {
+      std::vector<Node> arguments;
+      arguments.push_back(f);
+      // evaluate arguments
+      for (const auto& arg : fapp)
+      {
+        arguments.push_back(fm->getValue(arg));
+      }
+      // call oracle
+      Node fappWithValues = nm->mkNode(APPLY_UF, arguments);
+      Node predictedResponse = eq->getRepresentative(fapp);
+      if (!d_ochecker->checkConsistent(
+              fappWithValues, predictedResponse, learnedLemmas))
+      {
+        allFappsConsistent = false;
+      }
+    }
+  }
+  // if all were consistent, we can terminate
+  if (allFappsConsistent)
+  {
+    Trace("oracle-engine-state")
+        << "All responses consistent, no lemmas added" << std::endl;
+    d_consistencyCheckPassed = true;
+  }
+  else
+  {
+    for (const Node& l : learnedLemmas)
+    {
+      Trace("oracle-engine-state") << "adding lemma " << l << std::endl;
+      d_qim.lemma(l, InferenceId::QUANTIFIERS_ORACLE_INTERFACE);
+    }
+  }
+  // general SMTO: call constraint generators and assumption generators here
 
-void OracleEngine::checkOwnership(Node q) {}
+  if (TraceIsOn("oracle-engine"))
+  {
+    double clSet2 = double(clock()) / double(CLOCKS_PER_SEC);
+    Trace("oracle-engine") << "Finished oracle engine, time = "
+                           << (clSet2 - clSet) << std::endl;
+  }
+}
+
+bool OracleEngine::checkCompleteFor(Node q)
+{
+  if (d_qreg.getOwner(q) != this)
+  {
+    return false;
+  }
+  // Only true if oracle consistency check was successful. Notice that
+  // we can say true for *all* oracle interface quantified formulas in the
+  // case that the consistency check passed. In particular, the invocation
+  // of oracle interfaces does not need to be complete.
+  return d_consistencyCheckPassed;
+}
+
+void OracleEngine::checkOwnership(Node q)
+{
+  // take ownership of quantified formulas that are oracle interfaces
+  QuantAttributes& qa = d_qreg.getQuantAttributes();
+  if (!qa.isOracleInterface(q))
+  {
+    return;
+  }
+  d_qreg.setOwner(q, this);
+  // We expect oracle interfaces to be limited to definitional SMTO currently.
+  if (Configuration::isAssertionBuild())
+  {
+    std::vector<Node> inputs, outputs;
+    Node assume, constraint;
+    std::string binName;
+    getOracleInterface(q, inputs, outputs, assume, constraint, binName);
+    Assert(constraint.isConst() && constraint.getConst<bool>())
+        << "Unhandled oracle constraint " << q;
+    CVC5_UNUSED bool isOracleFun = false;
+    if (OracleCaller::isOracleFunctionApp(assume))
+    {
+      // predicate case
+      isOracleFun = true;
+    }
+    else if (assume.getKind() == EQUAL)
+    {
+      for (size_t i = 0; i < 2; i++)
+      {
+        if (OracleCaller::isOracleFunctionApp(assume[i])
+            && assume[1 - i].isConst())
+        {
+          isOracleFun = true;
+        }
+      }
+    }
+    Assert(isOracleFun)
+        << "Non-definitional oracle interface quantified formula " << q;
+  }
+}
 
 std::string OracleEngine::identify() const
 {
@@ -131,6 +289,43 @@ Node OracleEngine::mkOracleInterface(const std::vector<Node>& inputs,
   return nm->mkNode(FORALL, bvl, body, ipl);
 }
 
+bool OracleEngine::getOracleInterface(Node q,
+                                      std::vector<Node>& inputs,
+                                      std::vector<Node>& outputs,
+                                      Node& assume,
+                                      Node& constraint,
+                                      std::string& binName) const
+{
+  QuantAttributes& qa = d_qreg.getQuantAttributes();
+  if (qa.isOracleInterface(q))
+  {
+    // fill in data
+    OracleInputVarAttribute oiva;
+    for (const Node& v : q[0])
+    {
+      if (v.hasAttribute(oiva))
+      {
+        inputs.push_back(v);
+      }
+      else
+      {
+        Assert(v.hasAttribute(OracleOutputVarAttribute()));
+        outputs.push_back(v);
+      }
+    }
+    Assert(q[1].getKind() == ORACLE_FORMULA_GEN);
+    assume = q[1][0];
+    constraint = q[1][0];
+    Assert(q.getNumChildren() == 3);
+    Assert(q[2].getNumChildren() == 1);
+    OracleInterfaceAttribute oia;
+    Assert(q[2][0].hasAttribute(oia));
+    binName = q[2][0].getAttribute(oia);
+    return true;
+  }
+  return false;
+}
+
 }  // namespace quantifiers
 }  // namespace theory
 }  // namespace cvc5::internal
index 7b12e29e122c12afe733b71eaf13ecb1bd3d0217..a2cfc4d47e6aa80bfada07dc0fc1bd877829e51b 100644 (file)
@@ -18,6 +18,7 @@
 #ifndef CVC5__THEORY__QUANTIFIERS__ORACLE_ENGINE_H
 #define CVC5__THEORY__QUANTIFIERS__ORACLE_ENGINE_H
 
+#include "theory/quantifiers/oracle_checker.h"
 #include "theory/quantifiers/quant_module.h"
 
 namespace cvc5::internal {
@@ -88,10 +89,28 @@ class OracleEngine : public QuantifiersModule
                                 Node assume,
                                 Node constraint,
                                 const std::string& binName);
+  /**
+   * Get oracle interface, returns true if q is an oracle interface quantifier
+   * (constructed by the above method). Obtains the arguments for which q is
+   * constructed.
+   */
+  bool getOracleInterface(Node q,
+                          std::vector<Node>& inputs,
+                          std::vector<Node>& outputs,
+                          Node& assume,
+                          Node& constraint,
+                          std::string& binName) const;
 
  private:
   /** The oracle functions (user-context dependent) */
   context::CDList<Node> d_oracleFuns;
+  /** Pointer to the oracle checker */
+  OracleChecker* d_ochecker;
+  /**
+   * In a given instantiation round, did consistency checks pass for all
+   * oracle interface quantified formulas?
+   */
+  bool d_consistencyCheckPassed;
 };
 
 }  // namespace quantifiers