From 0a8938d9c2774a2db4da7354d57251f262c0987f Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Tue, 17 May 2022 17:26:53 -0500 Subject: [PATCH] Refactor declare oracle command (#8742) In preparation for adding oracle functions to API. --- src/printer/printer.cpp | 3 ++- src/printer/printer.h | 3 ++- src/printer/smt2/smt2_printer.cpp | 7 +++--- src/printer/smt2/smt2_printer.h | 3 ++- src/smt/command.cpp | 39 ++++++++++++++++++++++--------- src/smt/command.h | 15 ++++++++---- 6 files changed, 48 insertions(+), 22 deletions(-) diff --git a/src/printer/printer.cpp b/src/printer/printer.cpp index eea164b20..fa46350be 100644 --- a/src/printer/printer.cpp +++ b/src/printer/printer.cpp @@ -215,7 +215,8 @@ void Printer::toStreamCmdDeclarePool(std::ostream& out, } void Printer::toStreamCmdDeclareOracleFun(std::ostream& out, - Node fun, + const std::string& id, + TypeNode type, const std::string& binName) const { printUnknownCommand(out, "declare-oracle-fun"); diff --git a/src/printer/printer.h b/src/printer/printer.h index 424726fdd..f4ad5443d 100644 --- a/src/printer/printer.h +++ b/src/printer/printer.h @@ -99,7 +99,8 @@ class Printer const std::vector& initValue) const; /** Print declare-oracle-fun command */ virtual void toStreamCmdDeclareOracleFun(std::ostream& out, - Node fun, + const std::string& id, + TypeNode type, const std::string& binName) const; /** Print declare-sort command */ diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 1276c8993..62d6f3bb4 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1589,11 +1589,12 @@ void Smt2Printer::toStreamCmdDeclareFunction(std::ostream& out, } void Smt2Printer::toStreamCmdDeclareOracleFun(std::ostream& out, - Node fun, + const std::string& id, + TypeNode type, const std::string& binName) const { - out << "(declare-oracle-fun " << fun << " "; - toStreamDeclareType(out, fun.getType()); + out << "(declare-oracle-fun " << cvc5::internal::quoteSymbol(id) << " "; + toStreamDeclareType(out, type); out << " " << binName << ")" << std::endl; } diff --git a/src/printer/smt2/smt2_printer.h b/src/printer/smt2/smt2_printer.h index 9198c4628..a65be9a32 100644 --- a/src/printer/smt2/smt2_printer.h +++ b/src/printer/smt2/smt2_printer.h @@ -76,7 +76,8 @@ class Smt2Printer : public cvc5::internal::Printer /** Print declare-oracle-fun command */ void toStreamCmdDeclareOracleFun(std::ostream& out, - Node fun, + const std::string& id, + TypeNode type, const std::string& binName) const override; /** Print declare-pool command */ diff --git a/src/smt/command.cpp b/src/smt/command.cpp index d3a5702f3..885b79182 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -1177,17 +1177,25 @@ void DeclarePoolCommand::toStream(std::ostream& out, /* class DeclareOracleFunCommand */ /* -------------------------------------------------------------------------- */ -DeclareOracleFunCommand::DeclareOracleFunCommand(Term func) - : d_func(func), d_binName("") +DeclareOracleFunCommand::DeclareOracleFunCommand(const std::string& id, + Sort sort) + : d_id(id), d_sort(sort), d_binName("") { } -DeclareOracleFunCommand::DeclareOracleFunCommand(Term func, +DeclareOracleFunCommand::DeclareOracleFunCommand(const std::string& id, + Sort sort, const std::string& binName) - : d_func(func), d_binName(binName) + : d_id(id), d_sort(sort), d_binName(binName) { } -Term DeclareOracleFunCommand::getFunction() const { return d_func; } +const std::string& DeclareOracleFunCommand::getIdentifier() const +{ + return d_id; +} + +Sort DeclareOracleFunCommand::getSort() const { return d_sort; } + const std::string& DeclareOracleFunCommand::getBinaryName() const { return d_binName; @@ -1195,16 +1203,25 @@ const std::string& DeclareOracleFunCommand::getBinaryName() const void DeclareOracleFunCommand::invoke(Solver* solver, SymbolManager* sm) { - // Notice that the oracle function is already declared by the parser so that - // the symbol is bound eagerly. - // mark that it will be printed in the model - sm->addModelDeclarationTerm(d_func); + std::vector args; + Sort ret; + if (d_sort.isFunction()) + { + args = d_sort.getFunctionDomainSorts(); + ret = d_sort.getFunctionCodomainSort(); + } + else + { + ret = d_sort; + } + // will call solver declare oracle function when available in API d_commandStatus = CommandSuccess::instance(); } Command* DeclareOracleFunCommand::clone() const { - DeclareOracleFunCommand* dfc = new DeclareOracleFunCommand(d_func, d_binName); + DeclareOracleFunCommand* dfc = + new DeclareOracleFunCommand(d_id, d_sort, d_binName); return dfc; } @@ -1219,7 +1236,7 @@ void DeclareOracleFunCommand::toStream(std::ostream& out, Language language) const { Printer::getPrinter(language)->toStreamCmdDeclareOracleFun( - out, termToNode(d_func), d_binName); + out, d_id, sortToTypeNode(d_sort), d_binName); } /* -------------------------------------------------------------------------- */ diff --git a/src/smt/command.h b/src/smt/command.h index 9faae1fca..66854a03b 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -467,9 +467,12 @@ class CVC5_EXPORT DeclarePoolCommand : public DeclarationDefinitionCommand class CVC5_EXPORT DeclareOracleFunCommand : public Command { public: - DeclareOracleFunCommand(Term func); - DeclareOracleFunCommand(Term func, const std::string& binName); - Term getFunction() const; + DeclareOracleFunCommand(const std::string& id, Sort sort); + DeclareOracleFunCommand(const std::string& id, + Sort sort, + const std::string& binName); + const std::string& getIdentifier() const; + Sort getSort() const; const std::string& getBinaryName() const; void invoke(Solver* solver, SymbolManager* sm) override; @@ -482,8 +485,10 @@ class CVC5_EXPORT DeclareOracleFunCommand : public Command internal::Language::LANG_AUTO) const override; protected: - /** The oracle function */ - Term d_func; + /** The identifier */ + std::string d_id; + /** The (possibly function) sort */ + Sort d_sort; /** The binary name, or "" if none is provided */ std::string d_binName; }; -- 2.30.2