--- /dev/null
+/******************************************************************************
+ * Top contributors (to current version):
+ * Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved. See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Oracle checker
+ */
+
+#include "theory/quantifiers/oracle_checker.h"
+
+#include "expr/node_algorithm.h"
+#include "options/base_options.h"
+#include "smt/env.h"
+#include "theory/rewriter.h"
+
+namespace cvc5::internal {
+namespace theory {
+namespace quantifiers {
+
+bool OracleChecker::checkConsistent(Node app,
+ Node val,
+ std::vector<Node>& lemmas)
+{
+ Node result = evaluateApp(app);
+ if (result != val)
+ {
+ lemmas.push_back(result.eqNode(app));
+ return false;
+ }
+ return true;
+}
+
+Node OracleChecker::evaluateApp(Node app)
+{
+ Assert(app.getKind() == kind::APPLY_UF);
+ Node f = app.getOperator();
+ Assert(OracleCaller::isOracleFunction(f));
+ // get oracle caller
+ if (d_callers.find(f) == d_callers.end())
+ {
+ d_callers.insert(std::pair<Node, OracleCaller>(f, OracleCaller(f)));
+ }
+ OracleCaller& caller = d_callers.at(f);
+
+ // get oracle result
+ Node ret;
+ int runResult;
+ caller.callOracle(app, ret, runResult);
+ Assert(!ret.isNull());
+ return ret;
+}
+
+Node OracleChecker::evaluate(Node n)
+{
+ // same as convert
+ return convert(n);
+}
+
+Node OracleChecker::postConvert(Node n)
+{
+ Trace("oracle-checker-debug") << "postConvert: " << n << std::endl;
+ // if it is an oracle function applied to constant arguments
+ if (n.getKind() == kind::APPLY_UF
+ && OracleCaller::isOracleFunction(n.getOperator()))
+ {
+ bool allConst = true;
+ for (const Node& nc : n)
+ {
+ if (nc.isConst())
+ {
+ continue;
+ }
+ // special case: assume all closed lambdas are constants
+ if (nc.getKind() == kind::LAMBDA)
+ {
+ // if the lambda does not have a free variable (BOUND_VARIABLE)
+ if (!expr::hasFreeVar(nc))
+ {
+ // it also cannot have any free symbol
+ std::unordered_set<Node> syms;
+ expr::getSymbols(nc, syms);
+ if (syms.empty())
+ {
+ continue;
+ }
+ }
+ }
+ // non-constant argument, fail
+ allConst = false;
+ break;
+ }
+ if (allConst)
+ {
+ // evaluate the application
+ return evaluateApp(n);
+ }
+ }
+ // otherwise, always rewrite
+ return Rewriter::rewrite(n);
+}
+bool OracleChecker::hasOracles() const { return !d_callers.empty(); }
+bool OracleChecker::hasOracleCalls(Node f) const
+{
+ std::map<Node, OracleCaller>::const_iterator it = d_callers.find(f);
+ return it != d_callers.end();
+}
+const std::map<Node, Node>& OracleChecker::getOracleCalls(Node f) const
+{
+ Assert(hasOracleCalls(f));
+ std::map<Node, OracleCaller>::const_iterator it = d_callers.find(f);
+ return it->second.getCachedResults();
+}
+
+} // namespace quantifiers
+} // namespace theory
+} // namespace cvc5::internal
--- /dev/null
+/******************************************************************************
+ * Top contributors (to current version):
+ * Andrew Reynolds, Elizabeth Polgreen
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved. See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Oracle checker, caches oracle caller objects
+ */
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__THEORY__QUANTIFIERS__ORACLE_CHECKER_H
+#define CVC5__THEORY__QUANTIFIERS__ORACLE_CHECKER_H
+
+#include <vector>
+
+#include "expr/node.h"
+#include "expr/node_converter.h"
+#include "expr/oracle_caller.h"
+#include "smt/env_obj.h"
+
+namespace cvc5::internal {
+namespace theory {
+namespace quantifiers {
+
+/**
+ * Oracle checker.
+ *
+ * This maintains callers for all oracle functions, and can be used to evaluate
+ * terms that contain oracle functions. In particular, all oracle functions
+ * that are applied to "value-like" arguments only are invoked and replaced
+ * by their return. A term is "value-like" if it is constant (Node::isConst())
+ * or a closed lambda term.
+ *
+ * For example, if f is an oracle function, where evaluating the oracle
+ * f on 4 returns 5, and evaluating on 7 returns 10, this class acts as a
+ * node converter that may transform:
+ * f(f(4)+2) ---> f(5+2) ---> f(7) ---> 10
+ */
+class OracleChecker : protected EnvObj, public NodeConverter
+{
+ public:
+ OracleChecker(Env& env) : EnvObj(env) {}
+ ~OracleChecker() {}
+
+ /**
+ * Check predicted io pair is consistent, generate a lemma if
+ * not. This is used to check whether a definition of an oracle function
+ * is consistent in the model.
+ *
+ * For example, calling this method with app = f(c) and val = d will
+ * check whether we have evalauted the oracle associated with f on input
+ * c. If not, we invoke the oracle; otherwise we retrieve its cached value.
+ * If this output d' is not d, then this method adds d' = f(c) to lemmas.
+ */
+ bool checkConsistent(Node app, Node val, std::vector<Node>& lemmas);
+ /**
+ * Evaluate an oracle application. Given input f(c), where f is an oracle
+ * function symbol, this returns the result of invoking the oracle associated
+ * with f. This may either correspond to a cached value, or otherwise will
+ * invoke the oracle.
+ */
+ Node evaluateApp(Node app);
+
+ /**
+ * Evaluate all oracle function applications (recursively) in n. This is an
+ * alias for convert.
+ */
+ Node evaluate(Node n);
+
+ /** Has oracles? Have we invoked any oracle calls */
+ bool hasOracles() const;
+ /** Has oracle calls for oracle function symbol f. */
+ bool hasOracleCalls(Node f) const;
+ /** Get the cached results for oracle function symbol f */
+ const std::map<Node, Node>& getOracleCalls(Node f) const;
+
+ private:
+ /**
+ * Call back to convert, which evaluates oracle function applications and
+ * rewrites all other nodes.
+ */
+ Node postConvert(Node n) override;
+ /** map of oracle interface nodes to oracle callers **/
+ std::map<Node, OracleCaller> d_callers;
+};
+
+} // namespace quantifiers
+} // namespace theory
+} // namespace cvc5::internal
+
+#endif