This also updates several places to be generalized to methods that return a vector of Nodes. Previously we had assumed a use case returning a single node.
After this PR, #8622 will be updated to use the new std::function interface.
#include "expr/oracle_caller.h"
-#include <sstream>
-
-#include "options/base_options.h"
#include "theory/quantifiers/quantifiers_attributes.h"
namespace cvc5::internal {
-OracleCaller::OracleCaller(const Node& oracleInterfaceNode)
- : d_binaryName(getBinaryNameFor(oracleInterfaceNode))
+OracleCaller::OracleCaller(const Node& n)
+ : d_oracleNode(getOracleFor(n)),
+ d_oracle(NodeManager::currentNM()->getOracleFor(d_oracleNode))
{
+ Assert(!d_oracleNode.isNull());
}
-bool OracleCaller::callOracle(const Node& fapp, Node& res, int& runResult)
+bool OracleCaller::callOracle(const Node& fapp, std::vector<Node>& res)
{
- std::map<Node, Node>::iterator it = d_cachedResults.find(fapp);
+ std::map<Node, std::vector<Node>>::iterator it = d_cachedResults.find(fapp);
if (it != d_cachedResults.end())
{
Trace("oracle-calls") << "Using cached oracle result for " << fapp
return false;
}
Assert(fapp.getKind() == kind::APPLY_UF);
- Assert(getBinaryNameFor(fapp.getOperator()) == d_binaryName);
- std::vector<std::string> sargs;
- sargs.push_back(d_binaryName);
+ Assert(getOracleFor(fapp.getOperator()) == d_oracleNode);
Trace("oracle-calls") << "Call oracle " << fapp << std::endl;
- for (const Node& arg : fapp)
- {
- std::ostringstream oss;
- oss << arg;
- sargs.push_back(oss.str());
- }
-
- // Run the oracle binary for `sargs`, which indicates a list of
- // smt2 terms as strings.
-
- // Parse response from the binary into a Node. The response from the binary
- // should be a string that can be parsed as a (tuple of) terms in the smt2
- // format.
- Node response = Node::null();
+ // get the input arguments from the application
+ std::vector<Node> args(fapp.begin(), fapp.end());
+ // run the oracle method
+ std::vector<Node> response = d_oracle.run(args);
Trace("oracle-calls") << "response node " << response << std::endl;
+ // cache the response
d_cachedResults[fapp] = response;
res = response;
return true;
return isOracleFunction(n);
}
-std::string OracleCaller::getBinaryName() const { return d_binaryName; }
-
-std::string OracleCaller::getBinaryNameFor(const Node& n)
+Node OracleCaller::getOracleFor(const Node& n)
{
// oracle functions have no children
if (n.isVar())
{
Assert(isOracleFunction(n));
- return n.getAttribute(theory::OracleInterfaceAttribute());
+ Node o = n.getAttribute(theory::OracleInterfaceAttribute());
+ Assert(o.getKind() == kind::ORACLE);
+ return o;
}
else if (n.getKind() == kind::FORALL)
{
// oracle interfaces have children, and the attribute is stored in 2nd child
for (const Node& v : n[2][0])
{
- if (v.getAttribute(theory::OracleInterfaceAttribute()) != "")
+ if (v.getKind() == kind::ORACLE)
{
- return v.getAttribute(theory::OracleInterfaceAttribute());
+ return v;
}
}
}
- Assert(false) << "Unexpected node for binary name " << n;
- return "";
+ Assert(false) << "Unexpected node for oracle " << n;
+ return Node::null();
}
-const std::map<Node, Node>& OracleCaller::getCachedResults() const
+const std::map<Node, std::vector<Node>>& OracleCaller::getCachedResults() const
{
return d_cachedResults;
}
#define CVC5__EXPR__ORACLE_CALLER_H
#include "expr/node.h"
-#include "expr/node_trie.h"
+#include "expr/oracle.h"
namespace cvc5::internal {
/**
- * This class manages the calls to an (external) binary for a single oracle
- * function symbol or oracle interface quantified formula.
+ * This class manages the calls to an (externally implemented) oracle for a
+ * single oracle function symbol or oracle interface quantified formula.
*/
class OracleCaller
{
public:
/**
- * @param oracleInterfaceNode The oracle function symbol or oracle interface
- * quantified formula.
+ * @param n The oracle function or oracle interface quantified formula.
*/
- OracleCaller(const Node& oracleInterfaceNode);
+ OracleCaller(const Node& n);
~OracleCaller() {}
* fapp. Store in result res.
*
* Return true if the call was made, and false if it was already cached.
- *
- * If this method returns true, then runResult is set to the result returned
- * from executing the binary.
*/
- bool callOracle(const Node& fapp, Node& res, int& runResult);
-
- /** Get the binary name for this oracle caller */
- std::string getBinaryName() const;
+ bool callOracle(const Node& fapp, std::vector<Node>& res);
/** Get cached results for this oracle caller */
- const std::map<Node, Node>& getCachedResults() const;
-
- /**
- * Get binary from an oracle function, or an oracle interface quantified
- * formula.
- */
- static std::string getBinaryNameFor(const Node& n);
+ const std::map<Node, std::vector<Node>>& getCachedResults() const;
/** is f an oracle function? */
static bool isOracleFunction(Node f);
/** is n an oracle function application? */
static bool isOracleFunctionApp(Node n);
+ /**
+ * Get oracle from an oracle function, or an oracle interface quantified
+ * formula. Returns a node of kind ORACLE if the associated oracle exists,
+ * or null otherwise.
+ */
+ static Node getOracleFor(const Node& n);
+
private:
- /** name of binary */
- std::string d_binaryName;
+ /** The oracle node */
+ Node d_oracleNode;
+ /** The oracle */
+ const Oracle& d_oracle;
/**
* The cached results, mapping (APPLY_UF) applications of the oracle
* function to the parsed output of the binary.
*/
- std::map<Node, Node> d_cachedResults;
+ std::map<Node, std::vector<Node>> d_cachedResults;
};
} // namespace cvc5::internal
OracleCaller& caller = d_callers.at(f);
// get oracle result
- Node ret;
- int runResult;
- caller.callOracle(app, ret, runResult);
+ std::vector<Node> retv;
+ caller.callOracle(app, retv);
+ if (retv.size() != 1)
+ {
+ Assert(false) << "Failed to evaluate " << app
+ << " to a single return value, got: " << retv << std::endl;
+ return app;
+ }
+ Node ret = retv[0];
Assert(!ret.isNull());
return ret;
}
std::map<Node, OracleCaller>::const_iterator it = d_callers.find(f);
return it != d_callers.end();
}
-const std::map<Node, Node>& OracleChecker::getOracleCalls(Node f) const
+const std::map<Node, std::vector<Node>>& OracleChecker::getOracleCalls(
+ Node f) const
{
Assert(hasOracleCalls(f));
std::map<Node, OracleCaller>::const_iterator it = d_callers.find(f);
bool hasOracles() const;
/** Has oracle calls for oracle function symbol f. */
bool hasOracleCalls(Node f) const;
- /** Get the cached results for oracle function symbol f */
- const std::map<Node, Node>& getOracleCalls(Node f) const;
+ /**
+ * Get the cached results for oracle function symbol f. Note the vectors
+ * in the range of this method are expected to have size 1.
+ */
+ const std::map<Node, std::vector<Node>>& getOracleCalls(Node f) const;
private:
/**
if (Configuration::isAssertionBuild())
{
std::vector<Node> inputs, outputs;
- Node assume, constraint;
- std::string binName;
- getOracleInterface(q, inputs, outputs, assume, constraint, binName);
+ Node assume, constraint, oracle;
+ getOracleInterface(q, inputs, outputs, assume, constraint, oracle);
Assert(constraint.isConst() && constraint.getConst<bool>())
<< "Unhandled oracle constraint " << q;
CVC5_UNUSED bool isOracleFun = false;
return std::string("OracleEngine");
}
-void OracleEngine::declareOracleFun(Node f, const std::string& binName)
-{
- OracleInterfaceAttribute oia;
- f.setAttribute(oia, binName);
- d_oracleFuns.push_back(f);
-}
+void OracleEngine::declareOracleFun(Node f) { d_oracleFuns.push_back(f); }
std::vector<Node> OracleEngine::getOracleFuns() const
{
const std::vector<Node>& outputs,
Node assume,
Node constraint,
- const std::string& binName)
+ Node oracleNode)
{
Assert(!assume.isNull());
Assert(!constraint.isNull());
+ Assert(oracleNode.getKind() == ORACLE);
NodeManager* nm = NodeManager::currentNM();
- SkolemManager* sm = nm->getSkolemManager();
- OracleInterfaceAttribute oia;
- Node oiVar = sm->mkDummySkolem("oracle-interface", nm->booleanType());
- oiVar.setAttribute(oia, binName);
- Node ipl = nm->mkNode(INST_PATTERN_LIST, nm->mkNode(INST_ATTRIBUTE, oiVar));
+ Node ipl =
+ nm->mkNode(INST_PATTERN_LIST, nm->mkNode(INST_ATTRIBUTE, oracleNode));
std::vector<Node> vars;
OracleInputVarAttribute oiva;
for (Node v : inputs)
std::vector<Node>& outputs,
Node& assume,
Node& constraint,
- std::string& binName) const
+ Node& oracleNode) const
{
QuantAttributes& qa = d_qreg.getQuantAttributes();
if (qa.isOracleInterface(q))
constraint = q[1][0];
Assert(q.getNumChildren() == 3);
Assert(q[2].getNumChildren() == 1);
- OracleInterfaceAttribute oia;
- Assert(q[2][0].hasAttribute(oia));
- binName = q[2][0].getAttribute(oia);
+ Assert(q[2][0].getKind() == ORACLE);
+ oracleNode = q[2][0];
return true;
}
return false;
std::string identify() const override;
/** Declare oracle fun */
- void declareOracleFun(Node f, const std::string& binName);
+ void declareOracleFun(Node f);
/** Get the list of all declared oracle functions */
std::vector<Node> getOracleFuns() const;
const std::vector<Node>& outputs,
Node assume,
Node constraint,
- const std::string& binName);
+ Node oracleNode);
/**
* Get oracle interface, returns true if q is an oracle interface quantifier
* (constructed by the above method). Obtains the arguments for which q is
std::vector<Node>& outputs,
Node& assume,
Node& constraint,
- std::string& binName) const;
+ Node& oracleNode) const;
private:
/** The oracle functions (user-context dependent) */
Trace("quant-attr") << "Attribute : sygus : " << q << std::endl;
qa.d_sygus = true;
}
- if (avar.hasAttribute(OracleInterfaceAttribute()))
+ // oracles are specified by a distinguished variable kind
+ if (avar.getKind() == kind::ORACLE)
{
- qa.d_oracleInterfaceBin =
- avar.getAttribute(OracleInterfaceAttribute());
+ qa.d_oracle = avar;
Trace("quant-attr")
- << "Attribute : oracle interface : " << qa.d_oracleInterfaceBin
- << " : " << q << std::endl;
+ << "Attribute : oracle interface : " << q << std::endl;
}
if (avar.hasAttribute(SygusSideConditionAttribute()))
{
struct OracleInterfaceAttributeId
{
};
-typedef expr::Attribute<OracleInterfaceAttributeId, std::string>
+typedef expr::Attribute<OracleInterfaceAttributeId, Node>
OracleInterfaceAttribute;
/**Attribute to give names to quantified formulas */
: d_hasPattern(false),
d_hasPool(false),
d_sygus(false),
- d_oracleInterfaceBin(""),
d_qinstLevel(-1),
d_quant_elim(false),
d_quant_elim_partial(false),
Node d_fundef_f;
/** is this formula marked as a sygus conjecture? */
bool d_sygus;
- /** the binary name, if this is an oracle interface quantifier */
- std::string d_oracleInterfaceBin;
+ /** the oracle, which stores an implementation */
+ Node d_oracle;
/** side condition for sygus conjectures */
Node d_sygusSideCondition;
/** stores the maximum instantiation level allowed for this quantified formula
/** is this quantified formula a function definition? */
bool isFunDef() const { return !d_fundef_f.isNull(); }
/** is this quantified formula an oracle interface quantifier? */
- bool isOracleInterface() const { return !d_oracleInterfaceBin.empty(); }
+ bool isOracleInterface() const { return !d_oracle.isNull(); }
/**
* Is this a standard quantifier? A standard quantifier is one that we can
* perform destructive updates (variable elimination, miniscoping, etc).
// to the query.
if (ochecker != nullptr && ochecker->hasOracleCalls(f))
{
- const std::map<Node, Node>& ocalls = ochecker->getOracleCalls(f);
- for (const std::pair<const Node, Node>& oc : ocalls)
+ const std::map<Node, std::vector<Node>>& ocalls =
+ ochecker->getOracleCalls(f);
+ for (const std::pair<const Node, std::vector<Node>>& oc : ocalls)
{
- qconj.push_back(oc.first.eqNode(oc.second));
+ // we ignore calls that had a size other than one
+ if (oc.second.size() == 1)
+ {
+ qconj.push_back(oc.first.eqNode(oc.second[0]));
+ }
}
}
}
d_treg.declarePool(p, initValue);
}
-void QuantifiersEngine::declareOracleFun(Node f, const std::string& binName)
+void QuantifiersEngine::declareOracleFun(Node f)
{
if (d_qmodules->d_oracleEngine.get() == nullptr)
{
<< std::endl;
return;
}
- d_qmodules->d_oracleEngine->declareOracleFun(f, binName);
+ d_qmodules->d_oracleEngine->declareOracleFun(f);
}
std::vector<Node> QuantifiersEngine::getOracleFuns() const
{
/** Declare pool */
void declarePool(Node p, const std::vector<Node>& initValue);
/** Declare oracle fun */
- void declareOracleFun(Node f, const std::string& binName);
+ void declareOracleFun(Node f);
/** Get the list of all declared oracle functions */
std::vector<Node> getOracleFuns() const;
//----------end user interface for instantiations