Refactor declare oracle command (#8742)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 17 May 2022 22:26:53 +0000 (17:26 -0500)
committerGitHub <noreply@github.com>
Tue, 17 May 2022 22:26:53 +0000 (22:26 +0000)
In preparation for adding oracle functions to API.

src/printer/printer.cpp
src/printer/printer.h
src/printer/smt2/smt2_printer.cpp
src/printer/smt2/smt2_printer.h
src/smt/command.cpp
src/smt/command.h

index eea164b20ad81550d61812183c62d71459142321..fa46350becdb12beb3b68dddbb6a77f97f5a0cfe 100644 (file)
@@ -215,7 +215,8 @@ void Printer::toStreamCmdDeclarePool(std::ostream& out,
 }
 
 void Printer::toStreamCmdDeclareOracleFun(std::ostream& out,
-                                          Node fun,
+                                          const std::string& id,
+                                          TypeNode type,
                                           const std::string& binName) const
 {
   printUnknownCommand(out, "declare-oracle-fun");
index 424726fddf8e53b0495937716831fafff161801e..f4ad5443d9be640c31bef39b8ee414678b689b49 100644 (file)
@@ -99,7 +99,8 @@ class Printer
                                       const std::vector<Node>& initValue) const;
   /** Print declare-oracle-fun command */
   virtual void toStreamCmdDeclareOracleFun(std::ostream& out,
-                                           Node fun,
+                                           const std::string& id,
+                                           TypeNode type,
                                            const std::string& binName) const;
 
   /** Print declare-sort command */
index 1276c89937df59a8130f912831770141dc01fbdf..62d6f3bb4c798806c798186f4a9d5c594034e246 100644 (file)
@@ -1589,11 +1589,12 @@ void Smt2Printer::toStreamCmdDeclareFunction(std::ostream& out,
 }
 
 void Smt2Printer::toStreamCmdDeclareOracleFun(std::ostream& out,
-                                              Node fun,
+                                              const std::string& id,
+                                              TypeNode type,
                                               const std::string& binName) const
 {
-  out << "(declare-oracle-fun " << fun << " ";
-  toStreamDeclareType(out, fun.getType());
+  out << "(declare-oracle-fun " << cvc5::internal::quoteSymbol(id) << " ";
+  toStreamDeclareType(out, type);
   out << " " << binName << ")" << std::endl;
 }
 
index 9198c4628ddb9f97a4ee19b70d8dec48ffe513fc..a65be9a32f13c4ab4d9872bdceaaa387f43ba677 100644 (file)
@@ -76,7 +76,8 @@ class Smt2Printer : public cvc5::internal::Printer
 
   /** Print declare-oracle-fun command */
   void toStreamCmdDeclareOracleFun(std::ostream& out,
-                                   Node fun,
+                                   const std::string& id,
+                                   TypeNode type,
                                    const std::string& binName) const override;
 
   /** Print declare-pool command */
index d3a5702f3c30e32acf3c76dc3e9e13a02f3c53df..885b79182a25bb5db39dd950a56cbd2b4bd348f3 100644 (file)
@@ -1177,17 +1177,25 @@ void DeclarePoolCommand::toStream(std::ostream& out,
 /* class DeclareOracleFunCommand */
 /* -------------------------------------------------------------------------- */
 
-DeclareOracleFunCommand::DeclareOracleFunCommand(Term func)
-    : d_func(func), d_binName("")
+DeclareOracleFunCommand::DeclareOracleFunCommand(const std::string& id,
+                                                 Sort sort)
+    : d_id(id), d_sort(sort), d_binName("")
 {
 }
-DeclareOracleFunCommand::DeclareOracleFunCommand(Term func,
+DeclareOracleFunCommand::DeclareOracleFunCommand(const std::string& id,
+                                                 Sort sort,
                                                  const std::string& binName)
-    : d_func(func), d_binName(binName)
+    : d_id(id), d_sort(sort), d_binName(binName)
 {
 }
 
-Term DeclareOracleFunCommand::getFunction() const { return d_func; }
+const std::string& DeclareOracleFunCommand::getIdentifier() const
+{
+  return d_id;
+}
+
+Sort DeclareOracleFunCommand::getSort() const { return d_sort; }
+
 const std::string& DeclareOracleFunCommand::getBinaryName() const
 {
   return d_binName;
@@ -1195,16 +1203,25 @@ const std::string& DeclareOracleFunCommand::getBinaryName() const
 
 void DeclareOracleFunCommand::invoke(Solver* solver, SymbolManager* sm)
 {
-  // Notice that the oracle function is already declared by the parser so that
-  // the symbol is bound eagerly.
-  // mark that it will be printed in the model
-  sm->addModelDeclarationTerm(d_func);
+  std::vector<Sort> args;
+  Sort ret;
+  if (d_sort.isFunction())
+  {
+    args = d_sort.getFunctionDomainSorts();
+    ret = d_sort.getFunctionCodomainSort();
+  }
+  else
+  {
+    ret = d_sort;
+  }
+  // will call solver declare oracle function when available in API
   d_commandStatus = CommandSuccess::instance();
 }
 
 Command* DeclareOracleFunCommand::clone() const
 {
-  DeclareOracleFunCommand* dfc = new DeclareOracleFunCommand(d_func, d_binName);
+  DeclareOracleFunCommand* dfc =
+      new DeclareOracleFunCommand(d_id, d_sort, d_binName);
   return dfc;
 }
 
@@ -1219,7 +1236,7 @@ void DeclareOracleFunCommand::toStream(std::ostream& out,
                                        Language language) const
 {
   Printer::getPrinter(language)->toStreamCmdDeclareOracleFun(
-      out, termToNode(d_func), d_binName);
+      out, d_id, sortToTypeNode(d_sort), d_binName);
 }
 
 /* -------------------------------------------------------------------------- */
index 9faae1fcaa59cc1c03cce9076b90ac8a4f08f75b..66854a03be9e6afd6535aca83e7be2636c394533 100644 (file)
@@ -467,9 +467,12 @@ class CVC5_EXPORT DeclarePoolCommand : public DeclarationDefinitionCommand
 class CVC5_EXPORT DeclareOracleFunCommand : public Command
 {
  public:
-  DeclareOracleFunCommand(Term func);
-  DeclareOracleFunCommand(Term func, const std::string& binName);
-  Term getFunction() const;
+  DeclareOracleFunCommand(const std::string& id, Sort sort);
+  DeclareOracleFunCommand(const std::string& id,
+                          Sort sort,
+                          const std::string& binName);
+  const std::string& getIdentifier() const;
+  Sort getSort() const;
   const std::string& getBinaryName() const;
 
   void invoke(Solver* solver, SymbolManager* sm) override;
@@ -482,8 +485,10 @@ class CVC5_EXPORT DeclareOracleFunCommand : public Command
                     internal::Language::LANG_AUTO) const override;
 
  protected:
-  /** The oracle function */
-  Term d_func;
+  /** The identifier */
+  std::string d_id;
+  /** The (possibly function) sort */
+  Sort d_sort;
   /** The binary name, or "" if none is provided */
   std::string d_binName;
 };