From 2c8f7a72fbd01c06e717a02a34d610e39cc20042 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Thu, 14 Apr 2022 17:53:28 -0500 Subject: [PATCH] Implement internal support for (definitional) satisfiability modulo oracles (#8618) Adds implementation of OracleEngine, which adds lemmas of the form (= (f c) d) on demand for I/O pairs (c,d) from oracle calls. --- src/expr/oracle_caller.cpp | 10 ++ src/expr/oracle_caller.h | 2 + src/theory/inference_id.cpp | 2 + src/theory/inference_id.h | 5 + src/theory/quantifiers/oracle_engine.cpp | 205 ++++++++++++++++++++++- src/theory/quantifiers/oracle_engine.h | 19 +++ 6 files changed, 238 insertions(+), 5 deletions(-) diff --git a/src/expr/oracle_caller.cpp b/src/expr/oracle_caller.cpp index dccddf7f0..5feb5068e 100644 --- a/src/expr/oracle_caller.cpp +++ b/src/expr/oracle_caller.cpp @@ -69,6 +69,16 @@ bool OracleCaller::isOracleFunction(Node f) return f.hasAttribute(theory::OracleInterfaceAttribute()); } +bool OracleCaller::isOracleFunctionApp(Node n) +{ + if (n.getKind() == kind::APPLY_UF) + { + return isOracleFunction(n.getOperator()); + } + // possibly 0-ary + return isOracleFunction(n); +} + std::string OracleCaller::getBinaryName() const { return d_binaryName; } std::string OracleCaller::getBinaryNameFor(const Node& n) diff --git a/src/expr/oracle_caller.h b/src/expr/oracle_caller.h index 449bc0a08..294ec6f43 100644 --- a/src/expr/oracle_caller.h +++ b/src/expr/oracle_caller.h @@ -63,6 +63,8 @@ class OracleCaller /** is f an oracle function? */ static bool isOracleFunction(Node f); + /** is n an oracle function application? */ + static bool isOracleFunctionApp(Node n); private: /** name of binary */ diff --git a/src/theory/inference_id.cpp b/src/theory/inference_id.cpp index 22e693f9f..d3a79b8e3 100644 --- a/src/theory/inference_id.cpp +++ b/src/theory/inference_id.cpp @@ -241,6 +241,8 @@ const char* toString(InferenceId i) return "QUANTIFIERS_CEGQI_VTS_UB_DELTA"; case InferenceId::QUANTIFIERS_CEGQI_VTS_LB_INF: return "QUANTIFIERS_CEGQI_VTS_LB_INF"; + case InferenceId::QUANTIFIERS_ORACLE_INTERFACE: + return "QUANTIFIERS_ORACLE_INTERFACE"; case InferenceId::QUANTIFIERS_SYQI_CEX: return "QUANTIFIERS_SYQI_CEX"; case InferenceId::QUANTIFIERS_SYQI_EVAL_UNFOLD: return "QUANTIFIERS_SYQI_EVAL_UNFOLD"; diff --git a/src/theory/inference_id.h b/src/theory/inference_id.h index 9f0692a6f..51ce478c4 100644 --- a/src/theory/inference_id.h +++ b/src/theory/inference_id.h @@ -354,6 +354,11 @@ enum class InferenceId QUANTIFIERS_CEGQI_VTS_UB_DELTA, // infinity > c QUANTIFIERS_CEGQI_VTS_LB_INF, + //-------------------- oracles + // A lemma generated by an oracle interface quantified formula. + // For example, (= (f c) d) where (c, d) is an I/O pair obtained from calling + // the oracle associated with oracle function f. + QUANTIFIERS_ORACLE_INTERFACE, //-------------------- syntax-guided instantiation // a counterexample lemma QUANTIFIERS_SYQI_CEX, diff --git a/src/theory/quantifiers/oracle_engine.cpp b/src/theory/quantifiers/oracle_engine.cpp index 19cec1de5..84856f8e7 100644 --- a/src/theory/quantifiers/oracle_engine.cpp +++ b/src/theory/quantifiers/oracle_engine.cpp @@ -50,8 +50,12 @@ OracleEngine::OracleEngine(Env& env, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr) - : QuantifiersModule(env, qs, qim, qr, tr), d_oracleFuns(userContext()) + : QuantifiersModule(env, qs, qim, qr, tr), + d_oracleFuns(userContext()), + d_ochecker(tr.getOracleChecker()), + d_consistencyCheckPassed(false) { + Assert(d_ochecker != nullptr); } void OracleEngine::presolve() {} @@ -67,15 +71,169 @@ OracleEngine::QEffort OracleEngine::needsModel(Theory::Effort e) return QEFFORT_MODEL; } -void OracleEngine::reset_round(Theory::Effort e) {} +void OracleEngine::reset_round(Theory::Effort e) +{ + d_consistencyCheckPassed = false; +} void OracleEngine::registerQuantifier(Node q) {} -void OracleEngine::check(Theory::Effort e, QEffort quant_e) {} +void OracleEngine::check(Theory::Effort e, QEffort quant_e) +{ + if (quant_e != QEFFORT_MODEL) + { + return; + } -bool OracleEngine::checkCompleteFor(Node q) { return false; } + double clSet = 0; + if (TraceIsOn("oracle-engine")) + { + clSet = double(clock()) / double(CLOCKS_PER_SEC); + Trace("oracle-engine") << "---Oracle Engine Round, effort = " << e << "---" + << std::endl; + } + FirstOrderModel* fm = d_treg.getModel(); + TermDb* termDatabase = d_treg.getTermDatabase(); + eq::EqualityEngine* eq = getEqualityEngine(); + NodeManager* nm = NodeManager::currentNM(); + unsigned nquant = fm->getNumAssertedQuantifiers(); + std::vector currInterfaces; + for (unsigned i = 0; i < nquant; i++) + { + Node q = fm->getAssertedQuantifier(i); + if (d_qreg.getOwner(q) != this) + { + continue; + } + currInterfaces.push_back(q); + } + // Note that we currently ignore oracle interface quantified formulas, and + // look directly at the oracle functions. Note that: + // (1) The lemmas with InferenceId QUANTIFIERS_ORACLE_INTERFACE are not + // guarded by a quantified formula. This means that we are assuming that all + // oracle interface quantified formulas are top-level assertions. This is + // correct because we do not expose a way of embedding oracle interfaces into + // formulas at the user level. + // (2) We assume that oracle functions have associated oracle interface + // quantified formulas that are in currInterfaces. + // (3) We currently ignore oracle interface quantified formulas that are + // not associated with oracle functions. + // + // The current design choices above are due to the fact that our support is + // limited to "definitional SMTO" (see Polgreen et al 2022). In particular, + // we only support oracles that define I/O equalities for oracle functions + // only. The net effect of this class hence is to check the consistency of + // oracle functions, and allow "sat" or otherwise add a lemma with id + // QUANTIFIERS_ORACLE_INTERFACE. + std::vector learnedLemmas; + bool allFappsConsistent = true; + // iterate over oracle functions + for (const Node& f : d_oracleFuns) + { + TNodeTrie* tat = termDatabase->getTermArgTrie(f); + if (!tat) + { + continue; + } + std::vector apps = tat->getLeaves(f.getType().getArgTypes().size()); + Trace("oracle-calls") << "Oracle fun " << f << " with " << apps.size() + << " applications." << std::endl; + for (const auto& fapp : apps) + { + std::vector arguments; + arguments.push_back(f); + // evaluate arguments + for (const auto& arg : fapp) + { + arguments.push_back(fm->getValue(arg)); + } + // call oracle + Node fappWithValues = nm->mkNode(APPLY_UF, arguments); + Node predictedResponse = eq->getRepresentative(fapp); + if (!d_ochecker->checkConsistent( + fappWithValues, predictedResponse, learnedLemmas)) + { + allFappsConsistent = false; + } + } + } + // if all were consistent, we can terminate + if (allFappsConsistent) + { + Trace("oracle-engine-state") + << "All responses consistent, no lemmas added" << std::endl; + d_consistencyCheckPassed = true; + } + else + { + for (const Node& l : learnedLemmas) + { + Trace("oracle-engine-state") << "adding lemma " << l << std::endl; + d_qim.lemma(l, InferenceId::QUANTIFIERS_ORACLE_INTERFACE); + } + } + // general SMTO: call constraint generators and assumption generators here -void OracleEngine::checkOwnership(Node q) {} + if (TraceIsOn("oracle-engine")) + { + double clSet2 = double(clock()) / double(CLOCKS_PER_SEC); + Trace("oracle-engine") << "Finished oracle engine, time = " + << (clSet2 - clSet) << std::endl; + } +} + +bool OracleEngine::checkCompleteFor(Node q) +{ + if (d_qreg.getOwner(q) != this) + { + return false; + } + // Only true if oracle consistency check was successful. Notice that + // we can say true for *all* oracle interface quantified formulas in the + // case that the consistency check passed. In particular, the invocation + // of oracle interfaces does not need to be complete. + return d_consistencyCheckPassed; +} + +void OracleEngine::checkOwnership(Node q) +{ + // take ownership of quantified formulas that are oracle interfaces + QuantAttributes& qa = d_qreg.getQuantAttributes(); + if (!qa.isOracleInterface(q)) + { + return; + } + d_qreg.setOwner(q, this); + // We expect oracle interfaces to be limited to definitional SMTO currently. + if (Configuration::isAssertionBuild()) + { + std::vector inputs, outputs; + Node assume, constraint; + std::string binName; + getOracleInterface(q, inputs, outputs, assume, constraint, binName); + Assert(constraint.isConst() && constraint.getConst()) + << "Unhandled oracle constraint " << q; + CVC5_UNUSED bool isOracleFun = false; + if (OracleCaller::isOracleFunctionApp(assume)) + { + // predicate case + isOracleFun = true; + } + else if (assume.getKind() == EQUAL) + { + for (size_t i = 0; i < 2; i++) + { + if (OracleCaller::isOracleFunctionApp(assume[i]) + && assume[1 - i].isConst()) + { + isOracleFun = true; + } + } + } + Assert(isOracleFun) + << "Non-definitional oracle interface quantified formula " << q; + } +} std::string OracleEngine::identify() const { @@ -131,6 +289,43 @@ Node OracleEngine::mkOracleInterface(const std::vector& inputs, return nm->mkNode(FORALL, bvl, body, ipl); } +bool OracleEngine::getOracleInterface(Node q, + std::vector& inputs, + std::vector& outputs, + Node& assume, + Node& constraint, + std::string& binName) const +{ + QuantAttributes& qa = d_qreg.getQuantAttributes(); + if (qa.isOracleInterface(q)) + { + // fill in data + OracleInputVarAttribute oiva; + for (const Node& v : q[0]) + { + if (v.hasAttribute(oiva)) + { + inputs.push_back(v); + } + else + { + Assert(v.hasAttribute(OracleOutputVarAttribute())); + outputs.push_back(v); + } + } + Assert(q[1].getKind() == ORACLE_FORMULA_GEN); + assume = q[1][0]; + constraint = q[1][0]; + Assert(q.getNumChildren() == 3); + Assert(q[2].getNumChildren() == 1); + OracleInterfaceAttribute oia; + Assert(q[2][0].hasAttribute(oia)); + binName = q[2][0].getAttribute(oia); + return true; + } + return false; +} + } // namespace quantifiers } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/quantifiers/oracle_engine.h b/src/theory/quantifiers/oracle_engine.h index 7b12e29e1..a2cfc4d47 100644 --- a/src/theory/quantifiers/oracle_engine.h +++ b/src/theory/quantifiers/oracle_engine.h @@ -18,6 +18,7 @@ #ifndef CVC5__THEORY__QUANTIFIERS__ORACLE_ENGINE_H #define CVC5__THEORY__QUANTIFIERS__ORACLE_ENGINE_H +#include "theory/quantifiers/oracle_checker.h" #include "theory/quantifiers/quant_module.h" namespace cvc5::internal { @@ -88,10 +89,28 @@ class OracleEngine : public QuantifiersModule Node assume, Node constraint, const std::string& binName); + /** + * Get oracle interface, returns true if q is an oracle interface quantifier + * (constructed by the above method). Obtains the arguments for which q is + * constructed. + */ + bool getOracleInterface(Node q, + std::vector& inputs, + std::vector& outputs, + Node& assume, + Node& constraint, + std::string& binName) const; private: /** The oracle functions (user-context dependent) */ context::CDList d_oracleFuns; + /** Pointer to the oracle checker */ + OracleChecker* d_ochecker; + /** + * In a given instantiation round, did consistency checks pass for all + * oracle interface quantified formulas? + */ + bool d_consistencyCheckPassed; }; } // namespace quantifiers -- 2.30.2