Cleaning up friend relationships for commands (#6254)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 2 Apr 2021 16:55:16 +0000 (11:55 -0500)
committerGitHub <noreply@github.com>
Fri, 2 Apr 2021 16:55:16 +0000 (16:55 +0000)
src/api/cvc4cpp.h
src/smt/command.cpp
src/smt/command.h

index c446fcaf5cbeac7bbf093d17842572c419a01785..8f4977b28469e2cc5ab9d1adb684122221ca45e8 100644 (file)
@@ -36,38 +36,14 @@ template <bool ref_count>
 class NodeTemplate;
 typedef NodeTemplate<true> Node;
 
-class AssertCommand;
-class BlockModelValuesCommand;
-class CheckSatCommand;
-class CheckSatAssumingCommand;
-class DatatypeDeclarationCommand;
-class DeclareFunctionCommand;
-class DeclareHeapCommand;
-class DeclareSortCommand;
-class DeclareSygusVarCommand;
-class DefineFunctionCommand;
-class DefineFunctionRecCommand;
-class DefineSortCommand;
+class Command;
 class DType;
 class DTypeConstructor;
 class DTypeSelector;
-class GetAbductCommand;
-class GetInterpolCommand;
-class GetModelCommand;
-class GetQuantifierEliminationCommand;
-class GetUnsatCoreCommand;
-class GetValueCommand;
 class NodeManager;
-class ResetCommand;
-class SetUserAttributeCommand;
-class SimplifyCommand;
 class SmtEngine;
-class SygusConstraintCommand;
-class SygusInvConstraintCommand;
-class SynthFunCommand;
 class TypeNode;
 class Options;
-class QueryCommand;
 class Random;
 class Result;
 
@@ -240,16 +216,7 @@ class Datatype;
  */
 class CVC4_EXPORT Sort
 {
-  friend class cvc5::DatatypeDeclarationCommand;
-  friend class cvc5::DeclareFunctionCommand;
-  friend class cvc5::DeclareHeapCommand;
-  friend class cvc5::DeclareSortCommand;
-  friend class cvc5::DeclareSygusVarCommand;
-  friend class cvc5::DefineSortCommand;
-  friend class cvc5::GetAbductCommand;
-  friend class cvc5::GetInterpolCommand;
-  friend class cvc5::GetModelCommand;
-  friend class cvc5::SynthFunCommand;
+  friend class cvc5::Command;
   friend class DatatypeConstructor;
   friend class DatatypeConstructorDecl;
   friend class DatatypeSelector;
@@ -890,25 +857,7 @@ class CVC4_EXPORT Op
  */
 class CVC4_EXPORT Term
 {
-  friend class cvc5::AssertCommand;
-  friend class cvc5::BlockModelValuesCommand;
-  friend class cvc5::CheckSatCommand;
-  friend class cvc5::CheckSatAssumingCommand;
-  friend class cvc5::DeclareSygusVarCommand;
-  friend class cvc5::DefineFunctionCommand;
-  friend class cvc5::DefineFunctionRecCommand;
-  friend class cvc5::GetAbductCommand;
-  friend class cvc5::GetInterpolCommand;
-  friend class cvc5::GetModelCommand;
-  friend class cvc5::GetQuantifierEliminationCommand;
-  friend class cvc5::GetUnsatCoreCommand;
-  friend class cvc5::GetValueCommand;
-  friend class cvc5::SetUserAttributeCommand;
-  friend class cvc5::SimplifyCommand;
-  friend class cvc5::SygusConstraintCommand;
-  friend class cvc5::SygusInvConstraintCommand;
-  friend class cvc5::SynthFunCommand;
-  friend class cvc5::QueryCommand;
+  friend class cvc5::Command;
   friend class Datatype;
   friend class DatatypeConstructor;
   friend class DatatypeSelector;
@@ -2136,9 +2085,7 @@ std::ostream& operator<<(std::ostream& out,
  */
 class CVC4_EXPORT Grammar
 {
-  friend class cvc5::GetAbductCommand;
-  friend class cvc5::GetInterpolCommand;
-  friend class cvc5::SynthFunCommand;
+  friend class cvc5::Command;
   friend class Solver;
 
  public:
@@ -2323,7 +2270,7 @@ class CVC4_EXPORT Solver
   friend class DatatypeSelector;
   friend class Grammar;
   friend class Op;
-  friend class cvc5::ResetCommand;
+  friend class cvc5::Command;
   friend class Sort;
   friend class Term;
 
index e4b179cf46345024423cecf84ffdbddb23fde5da..4a6efe7136f6002509e27e1d681ee34607ccf91e 100644 (file)
@@ -237,6 +237,36 @@ void Command::printResult(std::ostream& out, uint32_t verbosity) const
   }
 }
 
+Node Command::termToNode(const api::Term& term) { return term.getNode(); }
+
+std::vector<Node> Command::termVectorToNodes(
+    const std::vector<api::Term>& terms)
+{
+  return api::Term::termVectorToNodes(terms);
+}
+
+TypeNode Command::sortToTypeNode(const api::Sort& sort)
+{
+  return sort.getTypeNode();
+}
+
+std::vector<TypeNode> Command::sortVectorToTypeNodes(
+    const std::vector<api::Sort>& sorts)
+{
+  return api::Sort::sortVectorToTypeNodes(sorts);
+}
+
+TypeNode Command::grammarToTypeNode(api::Grammar* grammar)
+{
+  return grammar == nullptr ? TypeNode::null()
+                            : sortToTypeNode(grammar->resolve());
+}
+
+Options& Command::getOriginalOptionsFrom(api::Solver* s)
+{
+  return *s->d_originalOptions.get();
+}
+
 /* -------------------------------------------------------------------------- */
 /* class EmptyCommand                                                         */
 /* -------------------------------------------------------------------------- */
@@ -310,7 +340,7 @@ void AssertCommand::invoke(api::Solver* solver, SymbolManager* sm)
 {
   try
   {
-    solver->getSmtEngine()->assertFormula(d_term.getNode(), d_inUnsatCore);
+    solver->getSmtEngine()->assertFormula(termToNode(d_term), d_inUnsatCore);
     d_commandStatus = CommandSuccess::instance();
   }
   catch (UnsafeInterruptException& e)
@@ -335,7 +365,7 @@ void AssertCommand::toStream(std::ostream& out,
                              size_t dag,
                              OutputLanguage language) const
 {
-  Printer::getPrinter(language)->toStreamCmdAssert(out, d_term.getNode());
+  Printer::getPrinter(language)->toStreamCmdAssert(out, termToNode(d_term));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -456,7 +486,7 @@ void CheckSatCommand::toStream(std::ostream& out,
                                size_t dag,
                                OutputLanguage language) const
 {
-  Printer::getPrinter(language)->toStreamCmdCheckSat(out, d_term.getNode());
+  Printer::getPrinter(language)->toStreamCmdCheckSat(out, termToNode(d_term));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -531,7 +561,7 @@ void CheckSatAssumingCommand::toStream(std::ostream& out,
                                        OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdCheckSatAssuming(
-      out, api::Term::termVectorToNodes(d_terms));
+      out, termVectorToNodes(d_terms));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -584,7 +614,7 @@ void QueryCommand::toStream(std::ostream& out,
                             size_t dag,
                             OutputLanguage language) const
 {
-  Printer::getPrinter(language)->toStreamCmdQuery(out, d_term.getNode());
+  Printer::getPrinter(language)->toStreamCmdQuery(out, termToNode(d_term));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -622,7 +652,7 @@ void DeclareSygusVarCommand::toStream(std::ostream& out,
                                       OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdDeclareVar(
-      out, d_var.getNode(), d_sort.getTypeNode());
+      out, termToNode(d_var), sortToTypeNode(d_sort));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -677,14 +707,13 @@ void SynthFunCommand::toStream(std::ostream& out,
                                size_t dag,
                                OutputLanguage language) const
 {
-  std::vector<Node> nodeVars = api::Term::termVectorToNodes(d_vars);
+  std::vector<Node> nodeVars = termVectorToNodes(d_vars);
   Printer::getPrinter(language)->toStreamCmdSynthFun(
       out,
-      d_fun.getNode(),
+      termToNode(d_fun),
       nodeVars,
       d_isInv,
-      d_grammar == nullptr ? TypeNode::null()
-                           : d_grammar->resolve().getTypeNode());
+      d_grammar == nullptr ? TypeNode::null() : grammarToTypeNode(d_grammar));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -725,7 +754,7 @@ void SygusConstraintCommand::toStream(std::ostream& out,
                                       size_t dag,
                                       OutputLanguage language) const
 {
-  Printer::getPrinter(language)->toStreamCmdConstraint(out, d_term.getNode());
+  Printer::getPrinter(language)->toStreamCmdConstraint(out, termToNode(d_term));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -782,10 +811,10 @@ void SygusInvConstraintCommand::toStream(std::ostream& out,
 {
   Printer::getPrinter(language)->toStreamCmdInvConstraint(
       out,
-      d_predicates[0].getNode(),
-      d_predicates[1].getNode(),
-      d_predicates[2].getNode(),
-      d_predicates[3].getNode());
+      termToNode(d_predicates[0]),
+      termToNode(d_predicates[1]),
+      termToNode(d_predicates[2]),
+      termToNode(d_predicates[3]));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -866,7 +895,7 @@ void ResetCommand::invoke(api::Solver* solver, SymbolManager* sm)
   {
     sm->reset();
     Options opts;
-    opts.copyValues(*solver->d_originalOptions);
+    opts.copyValues(getOriginalOptionsFrom(solver));
     // This reconstructs a new solver object at the same memory location as the
     // current one. Note that this command does not own the solver object!
     // It may be safer to instead make the ResetCommand a special case in the
@@ -1136,7 +1165,7 @@ void DeclareFunctionCommand::toStream(std::ostream& out,
                                       OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdDeclareFunction(
-      out, d_func.toString(), d_sort.getTypeNode());
+      out, d_func.toString(), sortToTypeNode(d_sort));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -1175,7 +1204,7 @@ void DeclareSortCommand::toStream(std::ostream& out,
                                   OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdDeclareType(out,
-                                                        d_sort.getTypeNode());
+                                                        sortToTypeNode(d_sort));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -1218,10 +1247,7 @@ void DefineSortCommand::toStream(std::ostream& out,
                                  OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdDefineType(
-      out,
-      d_symbol,
-      api::Sort::sortVectorToTypeNodes(d_params),
-      d_sort.getTypeNode());
+      out, d_symbol, sortVectorToTypeNodes(d_params), sortToTypeNode(d_sort));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -1296,9 +1322,9 @@ void DefineFunctionCommand::toStream(std::ostream& out,
   Printer::getPrinter(language)->toStreamCmdDefineFunction(
       out,
       d_func.toString(),
-      api::Term::termVectorToNodes(d_formals),
-      d_func.getNode().getType().getRangeType(),
-      d_formula.getNode());
+      termVectorToNodes(d_formals),
+      termToNode(d_func).getType().getRangeType(),
+      termToNode(d_formula));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -1376,14 +1402,11 @@ void DefineFunctionRecCommand::toStream(std::ostream& out,
   formals.reserve(d_formals.size());
   for (const std::vector<api::Term>& formal : d_formals)
   {
-    formals.push_back(api::Term::termVectorToNodes(formal));
+    formals.push_back(termVectorToNodes(formal));
   }
 
   Printer::getPrinter(language)->toStreamCmdDefineFunctionRec(
-      out,
-      api::Term::termVectorToNodes(d_funcs),
-      formals,
-      api::Term::termVectorToNodes(d_formulas));
+      out, termVectorToNodes(d_funcs), formals, termVectorToNodes(d_formulas));
 }
 /* -------------------------------------------------------------------------- */
 /* class DeclareHeapCommand                                                   */
@@ -1417,7 +1440,7 @@ void DeclareHeapCommand::toStream(std::ostream& out,
                                   OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdDeclareHeap(
-      out, d_locSort.getTypeNode(), d_dataSort.getTypeNode());
+      out, sortToTypeNode(d_locSort), sortToTypeNode(d_dataSort));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -1460,11 +1483,10 @@ void SetUserAttributeCommand::invoke(api::Solver* solver, SymbolManager* sm)
   {
     if (!d_term.isNull())
     {
-      solver->getSmtEngine()->setUserAttribute(
-          d_attr,
-          d_term.getNode(),
-          api::Term::termVectorToNodes(d_termValues),
-          d_strValue);
+      solver->getSmtEngine()->setUserAttribute(d_attr,
+                                               termToNode(d_term),
+                                               termVectorToNodes(d_termValues),
+                                               d_strValue);
     }
     d_commandStatus = CommandSuccess::instance();
   }
@@ -1490,7 +1512,7 @@ void SetUserAttributeCommand::toStream(std::ostream& out,
                                        OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdSetUserAttribute(
-      out, d_attr, d_term.getNode());
+      out, d_attr, termToNode(d_term));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -1543,7 +1565,7 @@ void SimplifyCommand::toStream(std::ostream& out,
                                size_t dag,
                                OutputLanguage language) const
 {
-  Printer::getPrinter(language)->toStreamCmdSimplify(out, d_term.getNode());
+  Printer::getPrinter(language)->toStreamCmdSimplify(out, termToNode(d_term));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -1624,7 +1646,7 @@ void GetValueCommand::toStream(std::ostream& out,
                                OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdGetValue(
-      out, api::Term::termVectorToNodes(d_terms));
+      out, termVectorToNodes(d_terms));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -1722,12 +1744,12 @@ void GetModelCommand::invoke(api::Solver* solver, SymbolManager* sm)
     std::vector<api::Sort> declareSorts = sm->getModelDeclareSorts();
     for (const api::Sort& s : declareSorts)
     {
-      d_result->addDeclarationSort(s.getTypeNode());
+      d_result->addDeclarationSort(sortToTypeNode(s));
     }
     std::vector<api::Term> declareTerms = sm->getModelDeclareTerms();
     for (const api::Term& t : declareTerms)
     {
-      d_result->addDeclarationTerm(t.getNode());
+      d_result->addDeclarationTerm(termToNode(t));
     }
     d_commandStatus = CommandSuccess::instance();
   }
@@ -1877,7 +1899,7 @@ void BlockModelValuesCommand::toStream(std::ostream& out,
                                        OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdBlockModelValues(
-      out, api::Term::termVectorToNodes(d_terms));
+      out, termVectorToNodes(d_terms));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -2120,7 +2142,7 @@ void GetInterpolCommand::toStream(std::ostream& out,
                                   OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdGetInterpol(
-      out, d_name, d_conj.getNode(), d_sygus_grammar->resolve().getTypeNode());
+      out, d_name, termToNode(d_conj), grammarToTypeNode(d_sygus_grammar));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -2205,7 +2227,7 @@ void GetAbductCommand::toStream(std::ostream& out,
                                 OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdGetAbduct(
-      out, d_name, d_conj.getNode(), d_sygus_grammar->resolve().getTypeNode());
+      out, d_name, termToNode(d_conj), grammarToTypeNode(d_sygus_grammar));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -2281,7 +2303,7 @@ void GetQuantifierEliminationCommand::toStream(std::ostream& out,
                                                OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdGetQuantifierElimination(
-      out, d_term.getNode());
+      out, termToNode(d_term));
 }
 
 /* -------------------------------------------------------------------------- */
@@ -2381,7 +2403,7 @@ void GetUnsatCoreCommand::printResult(std::ostream& out,
     if (options::dumpUnsatCoresFull())
     {
       // use the assertions
-      UnsatCore ucr(api::Term::termVectorToNodes(d_result));
+      UnsatCore ucr(termVectorToNodes(d_result));
       ucr.toStream(out);
     }
     else
@@ -2816,7 +2838,7 @@ void DatatypeDeclarationCommand::toStream(std::ostream& out,
                                           OutputLanguage language) const
 {
   Printer::getPrinter(language)->toStreamCmdDatatypeDeclaration(
-      out, api::Sort::sortVectorToTypeNodes(d_datatypes));
+      out, sortVectorToTypeNodes(d_datatypes));
 }
 
 }  // namespace cvc5
index 2d13a22468260c31a685f4bc17bd8dbbb4104616..6c3b4f0e47ca5483a7698c3c0d73245d3dac9114 100644 (file)
@@ -277,6 +277,25 @@ class CVC4_EXPORT Command
    * successful execution.
    */
   bool d_muted;
+
+ protected:
+  // These methods rely on Command being a friend of classes in the API.
+  // Subclasses of command should use these methods for conversions,
+  // which is currently necessary for e.g. printing commands.
+  /** Helper to convert a Term to an internal Node */
+  static Node termToNode(const api::Term& term);
+  /** Helper to convert a vector of Terms to internal Nodes. */
+  static std::vector<Node> termVectorToNodes(
+      const std::vector<api::Term>& terms);
+  /** Helper to convert a Sort to an internal TypeNode */
+  static TypeNode sortToTypeNode(const api::Sort& sort);
+  /** Helper to convert a vector of Sorts to internal TypeNodes. */
+  static std::vector<TypeNode> sortVectorToTypeNodes(
+      const std::vector<api::Sort>& sorts);
+  /** Helper to convert a Grammar to an internal TypeNode */
+  static TypeNode grammarToTypeNode(api::Grammar* grammar);
+  /** Get original options from the solver (for ResetCommand) */
+  Options& getOriginalOptionsFrom(api::Solver* s);
 }; /* class Command */
 
 /**