From: Andrew Reynolds Date: Wed, 4 May 2022 15:03:16 +0000 (-0500) Subject: Refactor oracles using new std::function backend (#8717) X-Git-Tag: cvc5-1.0.1~174 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=67d02be826a12349eaf8420ab93495b8cfe5deca;p=cvc5.git Refactor oracles using new std::function backend (#8717) This also updates several places to be generalized to methods that return a vector of Nodes. Previously we had assumed a use case returning a single node. After this PR, #8622 will be updated to use the new std::function interface. --- diff --git a/src/expr/oracle_caller.cpp b/src/expr/oracle_caller.cpp index 5feb5068e..94d547845 100644 --- a/src/expr/oracle_caller.cpp +++ b/src/expr/oracle_caller.cpp @@ -15,21 +15,20 @@ #include "expr/oracle_caller.h" -#include - -#include "options/base_options.h" #include "theory/quantifiers/quantifiers_attributes.h" namespace cvc5::internal { -OracleCaller::OracleCaller(const Node& oracleInterfaceNode) - : d_binaryName(getBinaryNameFor(oracleInterfaceNode)) +OracleCaller::OracleCaller(const Node& n) + : d_oracleNode(getOracleFor(n)), + d_oracle(NodeManager::currentNM()->getOracleFor(d_oracleNode)) { + Assert(!d_oracleNode.isNull()); } -bool OracleCaller::callOracle(const Node& fapp, Node& res, int& runResult) +bool OracleCaller::callOracle(const Node& fapp, std::vector& res) { - std::map::iterator it = d_cachedResults.find(fapp); + std::map>::iterator it = d_cachedResults.find(fapp); if (it != d_cachedResults.end()) { Trace("oracle-calls") << "Using cached oracle result for " << fapp @@ -39,26 +38,15 @@ bool OracleCaller::callOracle(const Node& fapp, Node& res, int& runResult) return false; } Assert(fapp.getKind() == kind::APPLY_UF); - Assert(getBinaryNameFor(fapp.getOperator()) == d_binaryName); - std::vector sargs; - sargs.push_back(d_binaryName); + Assert(getOracleFor(fapp.getOperator()) == d_oracleNode); Trace("oracle-calls") << "Call oracle " << fapp << std::endl; - for (const Node& arg : fapp) - { - std::ostringstream oss; - oss << arg; - sargs.push_back(oss.str()); - } - - // Run the oracle binary for `sargs`, which indicates a list of - // smt2 terms as strings. - - // Parse response from the binary into a Node. The response from the binary - // should be a string that can be parsed as a (tuple of) terms in the smt2 - // format. - Node response = Node::null(); + // get the input arguments from the application + std::vector args(fapp.begin(), fapp.end()); + // run the oracle method + std::vector response = d_oracle.run(args); Trace("oracle-calls") << "response node " << response << std::endl; + // cache the response d_cachedResults[fapp] = response; res = response; return true; @@ -79,32 +67,32 @@ bool OracleCaller::isOracleFunctionApp(Node n) return isOracleFunction(n); } -std::string OracleCaller::getBinaryName() const { return d_binaryName; } - -std::string OracleCaller::getBinaryNameFor(const Node& n) +Node OracleCaller::getOracleFor(const Node& n) { // oracle functions have no children if (n.isVar()) { Assert(isOracleFunction(n)); - return n.getAttribute(theory::OracleInterfaceAttribute()); + Node o = n.getAttribute(theory::OracleInterfaceAttribute()); + Assert(o.getKind() == kind::ORACLE); + return o; } else if (n.getKind() == kind::FORALL) { // oracle interfaces have children, and the attribute is stored in 2nd child for (const Node& v : n[2][0]) { - if (v.getAttribute(theory::OracleInterfaceAttribute()) != "") + if (v.getKind() == kind::ORACLE) { - return v.getAttribute(theory::OracleInterfaceAttribute()); + return v; } } } - Assert(false) << "Unexpected node for binary name " << n; - return ""; + Assert(false) << "Unexpected node for oracle " << n; + return Node::null(); } -const std::map& OracleCaller::getCachedResults() const +const std::map>& OracleCaller::getCachedResults() const { return d_cachedResults; } diff --git a/src/expr/oracle_caller.h b/src/expr/oracle_caller.h index 294ec6f43..786e7b6ed 100644 --- a/src/expr/oracle_caller.h +++ b/src/expr/oracle_caller.h @@ -19,22 +19,21 @@ #define CVC5__EXPR__ORACLE_CALLER_H #include "expr/node.h" -#include "expr/node_trie.h" +#include "expr/oracle.h" namespace cvc5::internal { /** - * This class manages the calls to an (external) binary for a single oracle - * function symbol or oracle interface quantified formula. + * This class manages the calls to an (externally implemented) oracle for a + * single oracle function symbol or oracle interface quantified formula. */ class OracleCaller { public: /** - * @param oracleInterfaceNode The oracle function symbol or oracle interface - * quantified formula. + * @param n The oracle function or oracle interface quantified formula. */ - OracleCaller(const Node& oracleInterfaceNode); + OracleCaller(const Node& n); ~OracleCaller() {} @@ -43,37 +42,34 @@ class OracleCaller * fapp. Store in result res. * * Return true if the call was made, and false if it was already cached. - * - * If this method returns true, then runResult is set to the result returned - * from executing the binary. */ - bool callOracle(const Node& fapp, Node& res, int& runResult); - - /** Get the binary name for this oracle caller */ - std::string getBinaryName() const; + bool callOracle(const Node& fapp, std::vector& res); /** Get cached results for this oracle caller */ - const std::map& getCachedResults() const; - - /** - * Get binary from an oracle function, or an oracle interface quantified - * formula. - */ - static std::string getBinaryNameFor(const Node& n); + const std::map>& getCachedResults() const; /** is f an oracle function? */ static bool isOracleFunction(Node f); /** is n an oracle function application? */ static bool isOracleFunctionApp(Node n); + /** + * Get oracle from an oracle function, or an oracle interface quantified + * formula. Returns a node of kind ORACLE if the associated oracle exists, + * or null otherwise. + */ + static Node getOracleFor(const Node& n); + private: - /** name of binary */ - std::string d_binaryName; + /** The oracle node */ + Node d_oracleNode; + /** The oracle */ + const Oracle& d_oracle; /** * The cached results, mapping (APPLY_UF) applications of the oracle * function to the parsed output of the binary. */ - std::map d_cachedResults; + std::map> d_cachedResults; }; } // namespace cvc5::internal diff --git a/src/theory/quantifiers/oracle_checker.cpp b/src/theory/quantifiers/oracle_checker.cpp index bc0c1e40a..13a2e7630 100644 --- a/src/theory/quantifiers/oracle_checker.cpp +++ b/src/theory/quantifiers/oracle_checker.cpp @@ -50,9 +50,15 @@ Node OracleChecker::evaluateApp(Node app) OracleCaller& caller = d_callers.at(f); // get oracle result - Node ret; - int runResult; - caller.callOracle(app, ret, runResult); + std::vector retv; + caller.callOracle(app, retv); + if (retv.size() != 1) + { + Assert(false) << "Failed to evaluate " << app + << " to a single return value, got: " << retv << std::endl; + return app; + } + Node ret = retv[0]; Assert(!ret.isNull()); return ret; } @@ -111,7 +117,8 @@ bool OracleChecker::hasOracleCalls(Node f) const std::map::const_iterator it = d_callers.find(f); return it != d_callers.end(); } -const std::map& OracleChecker::getOracleCalls(Node f) const +const std::map>& OracleChecker::getOracleCalls( + Node f) const { Assert(hasOracleCalls(f)); std::map::const_iterator it = d_callers.find(f); diff --git a/src/theory/quantifiers/oracle_checker.h b/src/theory/quantifiers/oracle_checker.h index 7d46ffc6a..fa86f3ab2 100644 --- a/src/theory/quantifiers/oracle_checker.h +++ b/src/theory/quantifiers/oracle_checker.h @@ -78,8 +78,11 @@ class OracleChecker : protected EnvObj, public NodeConverter bool hasOracles() const; /** Has oracle calls for oracle function symbol f. */ bool hasOracleCalls(Node f) const; - /** Get the cached results for oracle function symbol f */ - const std::map& getOracleCalls(Node f) const; + /** + * Get the cached results for oracle function symbol f. Note the vectors + * in the range of this method are expected to have size 1. + */ + const std::map>& getOracleCalls(Node f) const; private: /** diff --git a/src/theory/quantifiers/oracle_engine.cpp b/src/theory/quantifiers/oracle_engine.cpp index 84856f8e7..4b0783fdb 100644 --- a/src/theory/quantifiers/oracle_engine.cpp +++ b/src/theory/quantifiers/oracle_engine.cpp @@ -208,9 +208,8 @@ void OracleEngine::checkOwnership(Node q) if (Configuration::isAssertionBuild()) { std::vector inputs, outputs; - Node assume, constraint; - std::string binName; - getOracleInterface(q, inputs, outputs, assume, constraint, binName); + Node assume, constraint, oracle; + getOracleInterface(q, inputs, outputs, assume, constraint, oracle); Assert(constraint.isConst() && constraint.getConst()) << "Unhandled oracle constraint " << q; CVC5_UNUSED bool isOracleFun = false; @@ -240,12 +239,7 @@ std::string OracleEngine::identify() const return std::string("OracleEngine"); } -void OracleEngine::declareOracleFun(Node f, const std::string& binName) -{ - OracleInterfaceAttribute oia; - f.setAttribute(oia, binName); - d_oracleFuns.push_back(f); -} +void OracleEngine::declareOracleFun(Node f) { d_oracleFuns.push_back(f); } std::vector OracleEngine::getOracleFuns() const { @@ -261,16 +255,14 @@ Node OracleEngine::mkOracleInterface(const std::vector& inputs, const std::vector& outputs, Node assume, Node constraint, - const std::string& binName) + Node oracleNode) { Assert(!assume.isNull()); Assert(!constraint.isNull()); + Assert(oracleNode.getKind() == ORACLE); NodeManager* nm = NodeManager::currentNM(); - SkolemManager* sm = nm->getSkolemManager(); - OracleInterfaceAttribute oia; - Node oiVar = sm->mkDummySkolem("oracle-interface", nm->booleanType()); - oiVar.setAttribute(oia, binName); - Node ipl = nm->mkNode(INST_PATTERN_LIST, nm->mkNode(INST_ATTRIBUTE, oiVar)); + Node ipl = + nm->mkNode(INST_PATTERN_LIST, nm->mkNode(INST_ATTRIBUTE, oracleNode)); std::vector vars; OracleInputVarAttribute oiva; for (Node v : inputs) @@ -294,7 +286,7 @@ bool OracleEngine::getOracleInterface(Node q, std::vector& outputs, Node& assume, Node& constraint, - std::string& binName) const + Node& oracleNode) const { QuantAttributes& qa = d_qreg.getQuantAttributes(); if (qa.isOracleInterface(q)) @@ -318,9 +310,8 @@ bool OracleEngine::getOracleInterface(Node q, constraint = q[1][0]; Assert(q.getNumChildren() == 3); Assert(q[2].getNumChildren() == 1); - OracleInterfaceAttribute oia; - Assert(q[2][0].hasAttribute(oia)); - binName = q[2][0].getAttribute(oia); + Assert(q[2][0].getKind() == ORACLE); + oracleNode = q[2][0]; return true; } return false; diff --git a/src/theory/quantifiers/oracle_engine.h b/src/theory/quantifiers/oracle_engine.h index a2cfc4d47..582f2465e 100644 --- a/src/theory/quantifiers/oracle_engine.h +++ b/src/theory/quantifiers/oracle_engine.h @@ -79,7 +79,7 @@ class OracleEngine : public QuantifiersModule std::string identify() const override; /** Declare oracle fun */ - void declareOracleFun(Node f, const std::string& binName); + void declareOracleFun(Node f); /** Get the list of all declared oracle functions */ std::vector getOracleFuns() const; @@ -88,7 +88,7 @@ class OracleEngine : public QuantifiersModule const std::vector& outputs, Node assume, Node constraint, - const std::string& binName); + Node oracleNode); /** * Get oracle interface, returns true if q is an oracle interface quantifier * (constructed by the above method). Obtains the arguments for which q is @@ -99,7 +99,7 @@ class OracleEngine : public QuantifiersModule std::vector& outputs, Node& assume, Node& constraint, - std::string& binName) const; + Node& oracleNode) const; private: /** The oracle functions (user-context dependent) */ diff --git a/src/theory/quantifiers/quantifiers_attributes.cpp b/src/theory/quantifiers/quantifiers_attributes.cpp index 8775c0f79..95ea7ce8b 100644 --- a/src/theory/quantifiers/quantifiers_attributes.cpp +++ b/src/theory/quantifiers/quantifiers_attributes.cpp @@ -259,13 +259,12 @@ void QuantAttributes::computeQuantAttributes( Node q, QAttributes& qa ){ Trace("quant-attr") << "Attribute : sygus : " << q << std::endl; qa.d_sygus = true; } - if (avar.hasAttribute(OracleInterfaceAttribute())) + // oracles are specified by a distinguished variable kind + if (avar.getKind() == kind::ORACLE) { - qa.d_oracleInterfaceBin = - avar.getAttribute(OracleInterfaceAttribute()); + qa.d_oracle = avar; Trace("quant-attr") - << "Attribute : oracle interface : " << qa.d_oracleInterfaceBin - << " : " << q << std::endl; + << "Attribute : oracle interface : " << q << std::endl; } if (avar.hasAttribute(SygusSideConditionAttribute())) { diff --git a/src/theory/quantifiers/quantifiers_attributes.h b/src/theory/quantifiers/quantifiers_attributes.h index 1fdbcb2fc..e44bbf2aa 100644 --- a/src/theory/quantifiers/quantifiers_attributes.h +++ b/src/theory/quantifiers/quantifiers_attributes.h @@ -58,7 +58,7 @@ typedef expr::Attribute< SygusAttributeId, bool > SygusAttribute; struct OracleInterfaceAttributeId { }; -typedef expr::Attribute +typedef expr::Attribute OracleInterfaceAttribute; /**Attribute to give names to quantified formulas */ @@ -125,7 +125,6 @@ struct QAttributes : d_hasPattern(false), d_hasPool(false), d_sygus(false), - d_oracleInterfaceBin(""), d_qinstLevel(-1), d_quant_elim(false), d_quant_elim_partial(false), @@ -142,8 +141,8 @@ struct QAttributes Node d_fundef_f; /** is this formula marked as a sygus conjecture? */ bool d_sygus; - /** the binary name, if this is an oracle interface quantifier */ - std::string d_oracleInterfaceBin; + /** the oracle, which stores an implementation */ + Node d_oracle; /** side condition for sygus conjectures */ Node d_sygusSideCondition; /** stores the maximum instantiation level allowed for this quantified formula @@ -165,7 +164,7 @@ struct QAttributes /** is this quantified formula a function definition? */ bool isFunDef() const { return !d_fundef_f.isNull(); } /** is this quantified formula an oracle interface quantifier? */ - bool isOracleInterface() const { return !d_oracleInterfaceBin.empty(); } + bool isOracleInterface() const { return !d_oracle.isNull(); } /** * Is this a standard quantifier? A standard quantifier is one that we can * perform destructive updates (variable elimination, miniscoping, etc). diff --git a/src/theory/quantifiers/sygus/synth_verify.cpp b/src/theory/quantifiers/sygus/synth_verify.cpp index 3f78a8d42..a6b051275 100644 --- a/src/theory/quantifiers/sygus/synth_verify.cpp +++ b/src/theory/quantifiers/sygus/synth_verify.cpp @@ -186,10 +186,15 @@ Node SynthVerify::preprocessQueryInternal(Node query) // to the query. if (ochecker != nullptr && ochecker->hasOracleCalls(f)) { - const std::map& ocalls = ochecker->getOracleCalls(f); - for (const std::pair& oc : ocalls) + const std::map>& ocalls = + ochecker->getOracleCalls(f); + for (const std::pair>& oc : ocalls) { - qconj.push_back(oc.first.eqNode(oc.second)); + // we ignore calls that had a size other than one + if (oc.second.size() == 1) + { + qconj.push_back(oc.first.eqNode(oc.second[0])); + } } } } diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp index 08d68a95f..656bcdf4e 100644 --- a/src/theory/quantifiers_engine.cpp +++ b/src/theory/quantifiers_engine.cpp @@ -703,7 +703,7 @@ void QuantifiersEngine::declarePool(Node p, const std::vector& initValue) d_treg.declarePool(p, initValue); } -void QuantifiersEngine::declareOracleFun(Node f, const std::string& binName) +void QuantifiersEngine::declareOracleFun(Node f) { if (d_qmodules->d_oracleEngine.get() == nullptr) { @@ -711,7 +711,7 @@ void QuantifiersEngine::declareOracleFun(Node f, const std::string& binName) << std::endl; return; } - d_qmodules->d_oracleEngine->declareOracleFun(f, binName); + d_qmodules->d_oracleEngine->declareOracleFun(f); } std::vector QuantifiersEngine::getOracleFuns() const { diff --git a/src/theory/quantifiers_engine.h b/src/theory/quantifiers_engine.h index 9da371430..c2f86d8f0 100644 --- a/src/theory/quantifiers_engine.h +++ b/src/theory/quantifiers_engine.h @@ -146,7 +146,7 @@ class QuantifiersEngine : protected EnvObj /** Declare pool */ void declarePool(Node p, const std::vector& initValue); /** Declare oracle fun */ - void declareOracleFun(Node f, const std::string& binName); + void declareOracleFun(Node f); /** Get the list of all declared oracle functions */ std::vector getOracleFuns() const; //----------end user interface for instantiations