Add declare model symbol methods to SymbolManager and Model (#5480)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 23 Nov 2020 12:40:47 +0000 (06:40 -0600)
committerGitHub <noreply@github.com>
Mon, 23 Nov 2020 12:40:47 +0000 (06:40 -0600)
This is in preparation for the symbol manager determining which symbols are printed in the model.

src/expr/symbol_manager.cpp
src/expr/symbol_manager.h
src/smt/model.cpp
src/smt/model.h

index a163e503d240518852fedb42ab5bb9ef170776cf..f82845fe349c3b8fdcee2acfb90cdd5f9bc2feaf 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "context/cdhashmap.h"
 #include "context/cdhashset.h"
+#include "context/cdlist.h"
 #include "context/cdo.h"
 
 using namespace CVC4::context;
@@ -28,12 +29,16 @@ class SymbolManager::Implementation
 {
   using TermStringMap = CDHashMap<api::Term, std::string, api::TermHashFunction>;
   using TermSet = CDHashSet<api::Term, api::TermHashFunction>;
+  using SortList = CDList<api::Sort>;
+  using TermList = CDList<api::Term>;
 
  public:
   Implementation()
       : d_context(),
         d_names(&d_context),
         d_namedAsserts(&d_context),
+        d_declareSorts(&d_context),
+        d_declareTerms(&d_context),
         d_hasPushedScope(&d_context, false)
   {
   }
@@ -53,6 +58,14 @@ class SymbolManager::Implementation
                           bool areAssertions = false) const;
   /** get expression names */
   std::map<api::Term, std::string> getExpressionNames(bool areAssertions) const;
+  /** get model declare sorts */
+  std::vector<api::Sort> getModelDeclareSorts() const;
+  /** get model declare terms */
+  std::vector<api::Term> getModelDeclareTerms() const;
+  /** Add declared sort to the list of model declarations. */
+  void addModelDeclarationSort(api::Sort s);
+  /** Add declared term to the list of model declarations. */
+  void addModelDeclarationTerm(api::Term t);
   /** reset */
   void reset();
   /** Push a scope in the expression names. */
@@ -67,6 +80,10 @@ class SymbolManager::Implementation
   TermStringMap d_names;
   /** The set of terms with assertion names */
   TermSet d_namedAsserts;
+  /** Declared sorts (for model printing) */
+  SortList d_declareSorts;
+  /** Declared terms (for model printing) */
+  TermList d_declareTerms;
   /**
    * Have we pushed a scope (e.g. a let or quantifier) in the current context?
    */
@@ -150,6 +167,34 @@ SymbolManager::Implementation::getExpressionNames(bool areAssertions) const
   return emap;
 }
 
+std::vector<api::Sort> SymbolManager::Implementation::getModelDeclareSorts()
+    const
+{
+  std::vector<api::Sort> declareSorts(d_declareSorts.begin(),
+                                      d_declareSorts.end());
+  return declareSorts;
+}
+
+std::vector<api::Term> SymbolManager::Implementation::getModelDeclareTerms()
+    const
+{
+  std::vector<api::Term> declareTerms(d_declareTerms.begin(),
+                                      d_declareTerms.end());
+  return declareTerms;
+}
+
+void SymbolManager::Implementation::addModelDeclarationSort(api::Sort s)
+{
+  Trace("sym-manager") << "addModelDeclarationSort " << s << std::endl;
+  d_declareSorts.push_back(s);
+}
+
+void SymbolManager::Implementation::addModelDeclarationTerm(api::Term t)
+{
+  Trace("sym-manager") << "addModelDeclarationTerm " << t << std::endl;
+  d_declareTerms.push_back(t);
+}
+
 void SymbolManager::Implementation::pushScope(bool isUserContext)
 {
   Trace("sym-manager") << "pushScope, isUserContext = " << isUserContext
@@ -219,6 +264,24 @@ std::map<api::Term, std::string> SymbolManager::getExpressionNames(
 {
   return d_implementation->getExpressionNames(areAssertions);
 }
+std::vector<api::Sort> SymbolManager::getModelDeclareSorts() const
+{
+  return d_implementation->getModelDeclareSorts();
+}
+std::vector<api::Term> SymbolManager::getModelDeclareTerms() const
+{
+  return d_implementation->getModelDeclareTerms();
+}
+
+void SymbolManager::addModelDeclarationSort(api::Sort s)
+{
+  d_implementation->addModelDeclarationSort(s);
+}
+
+void SymbolManager::addModelDeclarationTerm(api::Term t)
+{
+  d_implementation->addModelDeclarationTerm(t);
+}
 
 size_t SymbolManager::scopeLevel() const
 {
index a3ca8e780d2eda13dad0bb8c666365a575e299ce..06b01da8b81e8855d0d1422fab5da656443839fb 100644 (file)
@@ -92,6 +92,23 @@ class CVC4_PUBLIC SymbolManager
    */
   std::map<api::Term, std::string> getExpressionNames(
       bool areAssertions = false) const;
+  /**
+   * @return The sorts we have declared that should be printed in the model.
+   */
+  std::vector<api::Sort> getModelDeclareSorts() const;
+  /**
+   * @return The terms we have declared that should be printed in the model.
+   */
+  std::vector<api::Term> getModelDeclareTerms() const;
+  /**
+   * Add declared sort to the list of model declarations.
+   */
+  void addModelDeclarationSort(api::Sort s);
+  /**
+   * Add declared term to the list of model declarations.
+   */
+  void addModelDeclarationTerm(api::Term t);
+
   //---------------------------- end named expressions
   /**
    * Get the scope level of the symbol table.
index fc9ea8fbb930b2053b3feddc50cab6258f349d92..b734ad9e911e69aaa24e152807e7f7d59ea10f7f 100644 (file)
@@ -61,5 +61,19 @@ Node Model::getValue(TNode n) const { return d_tmodel->getValue(n); }
 
 bool Model::hasApproximations() const { return d_tmodel->hasApproximations(); }
 
+void Model::clearModelDeclarations() { d_declareSorts.clear(); }
+
+void Model::addDeclarationSort(TypeNode tn) { d_declareSorts.push_back(tn); }
+
+void Model::addDeclarationTerm(Node n) { d_declareTerms.push_back(n); }
+const std::vector<TypeNode>& Model::getDeclaredSorts() const
+{
+  return d_declareSorts;
+}
+const std::vector<Node>& Model::getDeclaredTerms() const
+{
+  return d_declareTerms;
+}
+
 }  // namespace smt
 }/* CVC4 namespace */
index dc36b5d291535bc6a3a14ecb1e8f0f8ac5532bb4..0913922d17c890774ff2136e45282197fae38aef 100644 (file)
@@ -39,6 +39,9 @@ std::ostream& operator<<(std::ostream&, const Model&);
  * This is the SMT-level model object, that is responsible for maintaining
  * the necessary information for how to print the model, as well as
  * holding a pointer to the underlying implementation of the theory model.
+ *
+ * The model declarations maintained by this class are context-independent
+ * and should be updated when this model is printed.
  */
 class Model {
   friend std::ostream& operator<<(std::ostream&, const Model&);
@@ -49,10 +52,6 @@ class Model {
   Model(SmtEngine& smt, theory::TheoryModel* tm);
   /** virtual destructor */
   ~Model() {}
-  /** get number of commands to report */
-  size_t getNumCommands() const;
-  /** get command */
-  const NodeCommand* getCommand(size_t i) const;
   /** get the smt engine that this model is hooked up to */
   SmtEngine* getSmtEngine() { return &d_smt; }
   /** get the smt engine (as a pointer-to-const) that this model is hooked up to */
@@ -78,6 +77,28 @@ class Model {
   /** Does this model have approximations? */
   bool hasApproximations() const;
   //----------------------- end helper methods
+  /** get number of commands to report */
+  size_t getNumCommands() const;
+  /** get command */
+  const NodeCommand* getCommand(size_t i) const;
+  //----------------------- model declarations
+  /** Clear the current model declarations. */
+  void clearModelDeclarations();
+  /**
+   * Set that tn is a sort that should be printed in the model, when applicable,
+   * based on the output language.
+   */
+  void addDeclarationSort(TypeNode tn);
+  /**
+   * Set that n is a variable that should be printed in the model, when
+   * applicable, based on the output language.
+   */
+  void addDeclarationTerm(Node n);
+  /** get declared sorts */
+  const std::vector<TypeNode>& getDeclaredSorts() const;
+  /** get declared terms */
+  const std::vector<Node>& getDeclaredTerms() const;
+  //----------------------- end model declarations
  protected:
   /** The SmtEngine we're associated with */
   SmtEngine& d_smt;
@@ -93,6 +114,16 @@ class Model {
    * the values of sorts and terms.
    */
   theory::TheoryModel* d_tmodel;
+  /**
+   * The list of types to print, generally corresponding to declare-sort
+   * commands.
+   */
+  std::vector<TypeNode> d_declareSorts;
+  /**
+   * The list of terms to print, is typically one-to-one with declare-fun
+   * commands.
+   */
+  std::vector<Node> d_declareTerms;
 };
 
 }  // namespace smt