Refactor oracles using new std::function backend (#8717)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 4 May 2022 15:03:16 +0000 (10:03 -0500)
committerGitHub <noreply@github.com>
Wed, 4 May 2022 15:03:16 +0000 (15:03 +0000)
This also updates several places to be generalized to methods that return a vector of Nodes. Previously we had assumed a use case returning a single node.

After this PR, #8622 will be updated to use the new std::function interface.

src/expr/oracle_caller.cpp
src/expr/oracle_caller.h
src/theory/quantifiers/oracle_checker.cpp
src/theory/quantifiers/oracle_checker.h
src/theory/quantifiers/oracle_engine.cpp
src/theory/quantifiers/oracle_engine.h
src/theory/quantifiers/quantifiers_attributes.cpp
src/theory/quantifiers/quantifiers_attributes.h
src/theory/quantifiers/sygus/synth_verify.cpp
src/theory/quantifiers_engine.cpp
src/theory/quantifiers_engine.h

index 5feb5068e0ab7a04ddc93e0c51c422d67bc7ab7c..94d547845e5e73226204572fafed247be76c7f3d 100644 (file)
 
 #include "expr/oracle_caller.h"
 
-#include <sstream>
-
-#include "options/base_options.h"
 #include "theory/quantifiers/quantifiers_attributes.h"
 
 namespace cvc5::internal {
 
-OracleCaller::OracleCaller(const Node& oracleInterfaceNode)
-    : d_binaryName(getBinaryNameFor(oracleInterfaceNode))
+OracleCaller::OracleCaller(const Node& n)
+    : d_oracleNode(getOracleFor(n)),
+      d_oracle(NodeManager::currentNM()->getOracleFor(d_oracleNode))
 {
+  Assert(!d_oracleNode.isNull());
 }
 
-bool OracleCaller::callOracle(const Node& fapp, Node& res, int& runResult)
+bool OracleCaller::callOracle(const Node& fapp, std::vector<Node>& res)
 {
-  std::map<Node, Node>::iterator it = d_cachedResults.find(fapp);
+  std::map<Node, std::vector<Node>>::iterator it = d_cachedResults.find(fapp);
   if (it != d_cachedResults.end())
   {
     Trace("oracle-calls") << "Using cached oracle result for " << fapp
@@ -39,26 +38,15 @@ bool OracleCaller::callOracle(const Node& fapp, Node& res, int& runResult)
     return false;
   }
   Assert(fapp.getKind() == kind::APPLY_UF);
-  Assert(getBinaryNameFor(fapp.getOperator()) == d_binaryName);
-  std::vector<std::string> sargs;
-  sargs.push_back(d_binaryName);
+  Assert(getOracleFor(fapp.getOperator()) == d_oracleNode);
 
   Trace("oracle-calls") << "Call oracle " << fapp << std::endl;
-  for (const Node& arg : fapp)
-  {
-    std::ostringstream oss;
-    oss << arg;
-    sargs.push_back(oss.str());
-  }
-
-  // Run the oracle binary for `sargs`, which indicates a list of
-  // smt2 terms as strings.
-
-  // Parse response from the binary into a Node. The response from the binary
-  // should be a string that can be parsed as a (tuple of) terms in the smt2
-  // format.
-  Node response = Node::null();
+  // get the input arguments from the application
+  std::vector<Node> args(fapp.begin(), fapp.end());
+  // run the oracle method
+  std::vector<Node> response = d_oracle.run(args);
   Trace("oracle-calls") << "response node " << response << std::endl;
+  // cache the response
   d_cachedResults[fapp] = response;
   res = response;
   return true;
@@ -79,32 +67,32 @@ bool OracleCaller::isOracleFunctionApp(Node n)
   return isOracleFunction(n);
 }
 
-std::string OracleCaller::getBinaryName() const { return d_binaryName; }
-
-std::string OracleCaller::getBinaryNameFor(const Node& n)
+Node OracleCaller::getOracleFor(const Node& n)
 {
   // oracle functions have no children
   if (n.isVar())
   {
     Assert(isOracleFunction(n));
-    return n.getAttribute(theory::OracleInterfaceAttribute());
+    Node o = n.getAttribute(theory::OracleInterfaceAttribute());
+    Assert(o.getKind() == kind::ORACLE);
+    return o;
   }
   else if (n.getKind() == kind::FORALL)
   {
     // oracle interfaces have children, and the attribute is stored in 2nd child
     for (const Node& v : n[2][0])
     {
-      if (v.getAttribute(theory::OracleInterfaceAttribute()) != "")
+      if (v.getKind() == kind::ORACLE)
       {
-        return v.getAttribute(theory::OracleInterfaceAttribute());
+        return v;
       }
     }
   }
-  Assert(false) << "Unexpected node for binary name " << n;
-  return "";
+  Assert(false) << "Unexpected node for oracle " << n;
+  return Node::null();
 }
 
-const std::map<Node, Node>& OracleCaller::getCachedResults() const
+const std::map<Node, std::vector<Node>>& OracleCaller::getCachedResults() const
 {
   return d_cachedResults;
 }
index 294ec6f43c7ff3f10acca30b387e503af97e281a..786e7b6ed83e585893ccd9e813fa58688cd5a17e 100644 (file)
 #define CVC5__EXPR__ORACLE_CALLER_H
 
 #include "expr/node.h"
-#include "expr/node_trie.h"
+#include "expr/oracle.h"
 
 namespace cvc5::internal {
 
 /**
- * This class manages the calls to an (external) binary for a single oracle
- * function symbol or oracle interface quantified formula.
+ * This class manages the calls to an (externally implemented) oracle for a
+ * single oracle function symbol or oracle interface quantified formula.
  */
 class OracleCaller
 {
  public:
   /**
-   * @param oracleInterfaceNode The oracle function symbol or oracle interface
-   * quantified formula.
+   * @param n The oracle function or oracle interface quantified formula.
    */
-  OracleCaller(const Node& oracleInterfaceNode);
+  OracleCaller(const Node& n);
 
   ~OracleCaller() {}
 
@@ -43,37 +42,34 @@ class OracleCaller
    * fapp. Store in result res.
    *
    * Return true if the call was made, and false if it was already cached.
-   *
-   * If this method returns true, then runResult is set to the result returned
-   * from executing the binary.
    */
-  bool callOracle(const Node& fapp, Node& res, int& runResult);
-
-  /** Get the binary name for this oracle caller */
-  std::string getBinaryName() const;
+  bool callOracle(const Node& fapp, std::vector<Node>& res);
 
   /** Get cached results for this oracle caller */
-  const std::map<Node, Node>& getCachedResults() const;
-
-  /**
-   * Get binary from an oracle function, or an oracle interface quantified
-   * formula.
-   */
-  static std::string getBinaryNameFor(const Node& n);
+  const std::map<Node, std::vector<Node>>& getCachedResults() const;
 
   /** is f an oracle function? */
   static bool isOracleFunction(Node f);
   /** is n an oracle function application? */
   static bool isOracleFunctionApp(Node n);
 
+  /**
+   * Get oracle from an oracle function, or an oracle interface quantified
+   * formula. Returns a node of kind ORACLE if the associated oracle exists,
+   * or null otherwise.
+   */
+  static Node getOracleFor(const Node& n);
+
  private:
-  /** name of binary */
-  std::string d_binaryName;
+  /** The oracle node */
+  Node d_oracleNode;
+  /** The oracle */
+  const Oracle& d_oracle;
   /**
    * The cached results, mapping (APPLY_UF) applications of the oracle
    * function to the parsed output of the binary.
    */
-  std::map<Node, Node> d_cachedResults;
+  std::map<Node, std::vector<Node>> d_cachedResults;
 };
 
 }  // namespace cvc5::internal
index bc0c1e40afeb48b55bee1ef3a56459b7e68784f9..13a2e7630f6f7f1aae28ffac32e3a9fda63396e2 100644 (file)
@@ -50,9 +50,15 @@ Node OracleChecker::evaluateApp(Node app)
   OracleCaller& caller = d_callers.at(f);
 
   // get oracle result
-  Node ret;
-  int runResult;
-  caller.callOracle(app, ret, runResult);
+  std::vector<Node> retv;
+  caller.callOracle(app, retv);
+  if (retv.size() != 1)
+  {
+    Assert(false) << "Failed to evaluate " << app
+                  << " to a single return value, got: " << retv << std::endl;
+    return app;
+  }
+  Node ret = retv[0];
   Assert(!ret.isNull());
   return ret;
 }
@@ -111,7 +117,8 @@ bool OracleChecker::hasOracleCalls(Node f) const
   std::map<Node, OracleCaller>::const_iterator it = d_callers.find(f);
   return it != d_callers.end();
 }
-const std::map<Node, Node>& OracleChecker::getOracleCalls(Node f) const
+const std::map<Node, std::vector<Node>>& OracleChecker::getOracleCalls(
+    Node f) const
 {
   Assert(hasOracleCalls(f));
   std::map<Node, OracleCaller>::const_iterator it = d_callers.find(f);
index 7d46ffc6a26d219a5d19fe90bf14cda33c08756e..fa86f3ab2682e765b531180e2e2ef32e20349b85 100644 (file)
@@ -78,8 +78,11 @@ class OracleChecker : protected EnvObj, public NodeConverter
   bool hasOracles() const;
   /** Has oracle calls for oracle function symbol f. */
   bool hasOracleCalls(Node f) const;
-  /** Get the cached results for oracle function symbol f */
-  const std::map<Node, Node>& getOracleCalls(Node f) const;
+  /**
+   * Get the cached results for oracle function symbol f. Note the vectors
+   * in the range of this method are expected to have size 1.
+   */
+  const std::map<Node, std::vector<Node>>& getOracleCalls(Node f) const;
 
  private:
   /**
index 84856f8e72a94e14dd804e726884ad836c7ca5e5..4b0783fdbd0faa78d8a511213a2a01ed45ce3d6d 100644 (file)
@@ -208,9 +208,8 @@ void OracleEngine::checkOwnership(Node q)
   if (Configuration::isAssertionBuild())
   {
     std::vector<Node> inputs, outputs;
-    Node assume, constraint;
-    std::string binName;
-    getOracleInterface(q, inputs, outputs, assume, constraint, binName);
+    Node assume, constraint, oracle;
+    getOracleInterface(q, inputs, outputs, assume, constraint, oracle);
     Assert(constraint.isConst() && constraint.getConst<bool>())
         << "Unhandled oracle constraint " << q;
     CVC5_UNUSED bool isOracleFun = false;
@@ -240,12 +239,7 @@ std::string OracleEngine::identify() const
   return std::string("OracleEngine");
 }
 
-void OracleEngine::declareOracleFun(Node f, const std::string& binName)
-{
-  OracleInterfaceAttribute oia;
-  f.setAttribute(oia, binName);
-  d_oracleFuns.push_back(f);
-}
+void OracleEngine::declareOracleFun(Node f) { d_oracleFuns.push_back(f); }
 
 std::vector<Node> OracleEngine::getOracleFuns() const
 {
@@ -261,16 +255,14 @@ Node OracleEngine::mkOracleInterface(const std::vector<Node>& inputs,
                                      const std::vector<Node>& outputs,
                                      Node assume,
                                      Node constraint,
-                                     const std::string& binName)
+                                     Node oracleNode)
 {
   Assert(!assume.isNull());
   Assert(!constraint.isNull());
+  Assert(oracleNode.getKind() == ORACLE);
   NodeManager* nm = NodeManager::currentNM();
-  SkolemManager* sm = nm->getSkolemManager();
-  OracleInterfaceAttribute oia;
-  Node oiVar = sm->mkDummySkolem("oracle-interface", nm->booleanType());
-  oiVar.setAttribute(oia, binName);
-  Node ipl = nm->mkNode(INST_PATTERN_LIST, nm->mkNode(INST_ATTRIBUTE, oiVar));
+  Node ipl =
+      nm->mkNode(INST_PATTERN_LIST, nm->mkNode(INST_ATTRIBUTE, oracleNode));
   std::vector<Node> vars;
   OracleInputVarAttribute oiva;
   for (Node v : inputs)
@@ -294,7 +286,7 @@ bool OracleEngine::getOracleInterface(Node q,
                                       std::vector<Node>& outputs,
                                       Node& assume,
                                       Node& constraint,
-                                      std::string& binName) const
+                                      Node& oracleNode) const
 {
   QuantAttributes& qa = d_qreg.getQuantAttributes();
   if (qa.isOracleInterface(q))
@@ -318,9 +310,8 @@ bool OracleEngine::getOracleInterface(Node q,
     constraint = q[1][0];
     Assert(q.getNumChildren() == 3);
     Assert(q[2].getNumChildren() == 1);
-    OracleInterfaceAttribute oia;
-    Assert(q[2][0].hasAttribute(oia));
-    binName = q[2][0].getAttribute(oia);
+    Assert(q[2][0].getKind() == ORACLE);
+    oracleNode = q[2][0];
     return true;
   }
   return false;
index a2cfc4d47e6aa80bfada07dc0fc1bd877829e51b..582f2465e6d116ae7bffcd5d08fe30052582c714 100644 (file)
@@ -79,7 +79,7 @@ class OracleEngine : public QuantifiersModule
   std::string identify() const override;
 
   /** Declare oracle fun */
-  void declareOracleFun(Node f, const std::string& binName);
+  void declareOracleFun(Node f);
   /** Get the list of all declared oracle functions */
   std::vector<Node> getOracleFuns() const;
 
@@ -88,7 +88,7 @@ class OracleEngine : public QuantifiersModule
                                 const std::vector<Node>& outputs,
                                 Node assume,
                                 Node constraint,
-                                const std::string& binName);
+                                Node oracleNode);
   /**
    * Get oracle interface, returns true if q is an oracle interface quantifier
    * (constructed by the above method). Obtains the arguments for which q is
@@ -99,7 +99,7 @@ class OracleEngine : public QuantifiersModule
                           std::vector<Node>& outputs,
                           Node& assume,
                           Node& constraint,
-                          std::string& binName) const;
+                          Node& oracleNode) const;
 
  private:
   /** The oracle functions (user-context dependent) */
index 8775c0f79b03542258b87e9a9ad65be3f542a360..95ea7ce8b8126f243f0f417d915bc7464a5d444b 100644 (file)
@@ -259,13 +259,12 @@ void QuantAttributes::computeQuantAttributes( Node q, QAttributes& qa ){
           Trace("quant-attr") << "Attribute : sygus : " << q << std::endl;
           qa.d_sygus = true;
         }
-        if (avar.hasAttribute(OracleInterfaceAttribute()))
+        // oracles are specified by a distinguished variable kind
+        if (avar.getKind() == kind::ORACLE)
         {
-          qa.d_oracleInterfaceBin =
-              avar.getAttribute(OracleInterfaceAttribute());
+          qa.d_oracle = avar;
           Trace("quant-attr")
-              << "Attribute : oracle interface : " << qa.d_oracleInterfaceBin
-              << " : " << q << std::endl;
+              << "Attribute : oracle interface : " << q << std::endl;
         }
         if (avar.hasAttribute(SygusSideConditionAttribute()))
         {
index 1fdbcb2fc040817c8b397458fb6f630d3aa9a15f..e44bbf2aabdcb2949ad0fd05200de960e7ed7db1 100644 (file)
@@ -58,7 +58,7 @@ typedef expr::Attribute< SygusAttributeId, bool > SygusAttribute;
 struct OracleInterfaceAttributeId
 {
 };
-typedef expr::Attribute<OracleInterfaceAttributeId, std::string>
+typedef expr::Attribute<OracleInterfaceAttributeId, Node>
     OracleInterfaceAttribute;
 
 /**Attribute to give names to quantified formulas */
@@ -125,7 +125,6 @@ struct QAttributes
       : d_hasPattern(false),
         d_hasPool(false),
         d_sygus(false),
-        d_oracleInterfaceBin(""),
         d_qinstLevel(-1),
         d_quant_elim(false),
         d_quant_elim_partial(false),
@@ -142,8 +141,8 @@ struct QAttributes
   Node d_fundef_f;
   /** is this formula marked as a sygus conjecture? */
   bool d_sygus;
-  /** the binary name, if this is an oracle interface quantifier */
-  std::string d_oracleInterfaceBin;
+  /** the oracle, which stores an implementation */
+  Node d_oracle;
   /** side condition for sygus conjectures */
   Node d_sygusSideCondition;
   /** stores the maximum instantiation level allowed for this quantified formula
@@ -165,7 +164,7 @@ struct QAttributes
   /** is this quantified formula a function definition? */
   bool isFunDef() const { return !d_fundef_f.isNull(); }
   /** is this quantified formula an oracle interface quantifier? */
-  bool isOracleInterface() const { return !d_oracleInterfaceBin.empty(); }
+  bool isOracleInterface() const { return !d_oracle.isNull(); }
   /**
    * Is this a standard quantifier? A standard quantifier is one that we can
    * perform destructive updates (variable elimination, miniscoping, etc).
index 3f78a8d42b4ed637f9a7fae3fc76d3dbb522566f..a6b0512750bae8d69709a2b444867aefd60c687b 100644 (file)
@@ -186,10 +186,15 @@ Node SynthVerify::preprocessQueryInternal(Node query)
         // to the query.
         if (ochecker != nullptr && ochecker->hasOracleCalls(f))
         {
-          const std::map<Node, Node>& ocalls = ochecker->getOracleCalls(f);
-          for (const std::pair<const Node, Node>& oc : ocalls)
+          const std::map<Node, std::vector<Node>>& ocalls =
+              ochecker->getOracleCalls(f);
+          for (const std::pair<const Node, std::vector<Node>>& oc : ocalls)
           {
-            qconj.push_back(oc.first.eqNode(oc.second));
+            // we ignore calls that had a size other than one
+            if (oc.second.size() == 1)
+            {
+              qconj.push_back(oc.first.eqNode(oc.second[0]));
+            }
           }
         }
       }
index 08d68a95ff480f0011cd1f37c328b85d88e146f1..656bcdf4e9ad348e9ebc1ea3a8e1c9bca35a4be2 100644 (file)
@@ -703,7 +703,7 @@ void QuantifiersEngine::declarePool(Node p, const std::vector<Node>& initValue)
   d_treg.declarePool(p, initValue);
 }
 
-void QuantifiersEngine::declareOracleFun(Node f, const std::string& binName)
+void QuantifiersEngine::declareOracleFun(Node f)
 {
   if (d_qmodules->d_oracleEngine.get() == nullptr)
   {
@@ -711,7 +711,7 @@ void QuantifiersEngine::declareOracleFun(Node f, const std::string& binName)
               << std::endl;
     return;
   }
-  d_qmodules->d_oracleEngine->declareOracleFun(f, binName);
+  d_qmodules->d_oracleEngine->declareOracleFun(f);
 }
 std::vector<Node> QuantifiersEngine::getOracleFuns() const
 {
index 9da371430e38ddbdaf49bd2b3c5de059e941168d..c2f86d8f0d4f15a7a5fac16f01d5f118036e4547 100644 (file)
@@ -146,7 +146,7 @@ class QuantifiersEngine : protected EnvObj
   /** Declare pool */
   void declarePool(Node p, const std::vector<Node>& initValue);
   /** Declare oracle fun */
-  void declareOracleFun(Node f, const std::string& binName);
+  void declareOracleFun(Node f);
   /** Get the list of all declared oracle functions */
   std::vector<Node> getOracleFuns() const;
   //----------end user interface for instantiations