Refactor SMT-level model object (#5277)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 16 Oct 2020 18:32:42 +0000 (13:32 -0500)
committerGitHub <noreply@github.com>
Fri, 16 Oct 2020 18:32:42 +0000 (13:32 -0500)
This refactors the SMT-level model object so that it is a wrapper around TheoryModel instead of a base class. This inheritance was unnecessary.

Moreover, it removes the virtual base models of the SMT-level model which were based on Expr. Now the interface is more minimal and in terms of Node only.

This PR further simplifies a few places in the code that interface with the SmtEngine with things related to models.

23 files changed:
src/api/cvc4cpp.cpp
src/printer/ast/ast_printer.cpp
src/printer/ast/ast_printer.h
src/printer/cvc/cvc_printer.cpp
src/printer/cvc/cvc_printer.h
src/printer/printer.cpp
src/printer/printer.h
src/printer/smt2/smt2_printer.cpp
src/printer/smt2/smt2_printer.h
src/printer/tptp/tptp_printer.cpp
src/printer/tptp/tptp_printer.h
src/smt/command.h
src/smt/model.cpp
src/smt/model.h
src/smt/model_blocker.cpp
src/smt/model_core_builder.cpp
src/smt/model_core_builder.h
src/smt/smt_engine.cpp
src/smt/smt_engine.h
src/theory/theory_model.cpp
src/theory/theory_model.h
src/theory/theory_model_builder.cpp
src/theory/theory_model_builder.h

index 0384b573eaa4eb38739196788f5f73949e8f9553..2417936a755fcd61cb3295a59d6731b7eb7f64a0 100644 (file)
@@ -5211,13 +5211,6 @@ Term Solver::getSeparationHeap() const
          "(try --produce-models)";
   CVC4_API_CHECK(d_smtEngine->getSmtMode() != SmtMode::UNSAT)
       << "Cannot get separtion heap term when in unsat mode.";
-
-  theory::TheoryModel* m =
-      d_smtEngine->getAvailableModel("get separation logic heap and nil");
-  Expr heap, nil;
-  bool hasHeapModel = m->getHeapModel(heap, nil);
-  CVC4_API_CHECK(hasHeapModel)
-      << "Failed to obtain heap term from theory model.";
   return Term(this, d_smtEngine->getSepHeapExpr());
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
@@ -5235,14 +5228,7 @@ Term Solver::getSeparationNilTerm() const
          "(try --produce-models)";
   CVC4_API_CHECK(d_smtEngine->getSmtMode() != SmtMode::UNSAT)
       << "Cannot get separtion nil term when in unsat mode.";
-
-  theory::TheoryModel* m =
-      d_smtEngine->getAvailableModel("get separation logic heap and nil");
-  Expr heap, nil;
-  bool hasHeapModel = m->getHeapModel(heap, nil);
-  CVC4_API_CHECK(hasHeapModel)
-      << "Failed to obtain nil term from theory model.";
-  return Term(this, nil);
+  return Term(this, d_smtEngine->getSepNilExpr());
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
 
index 062ebf0377f1e599c9b0378339a0411a02cfded3..1ed9d146cc1f92fc9b8fccdd6bdb3d4d89bc7d5e 100644 (file)
@@ -150,13 +150,13 @@ void AstPrinter::toStream(std::ostream& out, const CommandStatus* s) const
 
 }/* AstPrinter::toStream(CommandStatus*) */
 
-void AstPrinter::toStream(std::ostream& out, const Model& m) const
+void AstPrinter::toStream(std::ostream& out, const smt::Model& m) const
 {
   out << "Model()";
 }
 
 void AstPrinter::toStream(std::ostream& out,
-                          const Model& m,
+                          const smt::Model& m,
                           const NodeCommand* c) const
 {
   // shouldn't be called; only the non-Command* version above should be
index b5feccdfa62feac82346c3b9bf95313c5eced0cd..f01436b8a221365c559df96370d1a481324c1e6e 100644 (file)
@@ -37,7 +37,7 @@ class AstPrinter : public CVC4::Printer
                 bool types,
                 size_t dag) const override;
   void toStream(std::ostream& out, const CommandStatus* s) const override;
-  void toStream(std::ostream& out, const Model& m) const override;
+  void toStream(std::ostream& out, const smt::Model& m) const override;
 
   /** Print empty command */
   void toStreamCmdEmpty(std::ostream& out,
@@ -174,7 +174,7 @@ class AstPrinter : public CVC4::Printer
  private:
   void toStream(std::ostream& out, TNode n, int toDepth, bool types) const;
   void toStream(std::ostream& out,
-                const Model& m,
+                const smt::Model& m,
                 const NodeCommand* c) const override;
 }; /* class AstPrinter */
 
index 7fd26e1a02f1f3ed8ed744713932647be73833a0..bab619dcead0708d2212c6e7ce26bed7200f78e1 100644 (file)
@@ -1142,7 +1142,9 @@ void DeclareFunctionNodeCommandToStream(
   {
     out << tn;
   }
-  Node val = model.getSmtEngine()->getValue(n);
+  // We get the value from the theory model directly, which notice
+  // does not have to go through the standard SmtEngine::getValue interface.
+  Node val = model.getValue(n);
   if (options::modelUninterpDtEnum() && val.getKind() == kind::STORE)
   {
     TypeNode type_node = val[1].getType();
@@ -1162,11 +1164,12 @@ void DeclareFunctionNodeCommandToStream(
 
 }  // namespace
 
-void CvcPrinter::toStream(std::ostream& out, const Model& m) const
+void CvcPrinter::toStream(std::ostream& out, const smt::Model& m) const
 {
+  const theory::TheoryModel* tm = m.getTheoryModel();
   // print the model comments
   std::stringstream c;
-  m.getComments(c);
+  tm->getComments(c);
   std::string ln;
   while (std::getline(c, ln))
   {
@@ -1180,10 +1183,10 @@ void CvcPrinter::toStream(std::ostream& out, const Model& m) const
 }
 
 void CvcPrinter::toStream(std::ostream& out,
-                          const Model& model,
+                          const smt::Model& model,
                           const NodeCommand* command) const
 {
-  const auto* theory_model = dynamic_cast<const theory::TheoryModel*>(&model);
+  const auto* theory_model = model.getTheoryModel();
   AlwaysAssert(theory_model != nullptr);
   if (const auto* declare_type_command =
           dynamic_cast<const DeclareTypeNodeCommand*>(command))
index a7bacb8031b29f4fb8ca203e871c181ac745159d..4047f0d8b93d61973c44feafdf4038d4e90bc352 100644 (file)
@@ -38,7 +38,7 @@ class CvcPrinter : public CVC4::Printer
                 bool types,
                 size_t dag) const override;
   void toStream(std::ostream& out, const CommandStatus* s) const override;
-  void toStream(std::ostream& out, const Model& m) const override;
+  void toStream(std::ostream& out, const smt::Model& m) const override;
 
   /** Print empty command */
   void toStreamCmdEmpty(std::ostream& out,
@@ -176,7 +176,7 @@ class CvcPrinter : public CVC4::Printer
   void toStream(
       std::ostream& out, TNode n, int toDepth, bool types, bool bracket) const;
   void toStream(std::ostream& out,
-                const Model& m,
+                const smt::Model& m,
                 const NodeCommand* c) const override;
 
   bool d_cvc3Mode;
index 952caf89e26e2001536ab6688ba92ca77dc97993..ba062c20fcb0bd0e42de13b62cac4b597b8ae35b 100644 (file)
@@ -71,13 +71,13 @@ unique_ptr<Printer> Printer::makePrinter(OutputLanguage lang)
   }
 }
 
-void Printer::toStream(std::ostream& out, const Model& m) const
+void Printer::toStream(std::ostream& out, const smt::Model& m) const
 {
   for(size_t i = 0; i < m.getNumCommands(); ++i) {
     const NodeCommand* cmd = m.getCommand(i);
     const DeclareFunctionNodeCommand* dfc =
         dynamic_cast<const DeclareFunctionNodeCommand*>(cmd);
-    if (dfc != NULL && !m.isModelCoreSymbol(dfc->getFunction().toExpr()))
+    if (dfc != NULL && !m.isModelCoreSymbol(dfc->getFunction()))
     {
       continue;
     }
index c10e1db049ea6858c45624af331b398e634b7883..b95b02ca836c8adaca10a0366d17c416972fd019 100644 (file)
@@ -58,7 +58,7 @@ class Printer
   virtual void toStream(std::ostream& out, const CommandStatus* s) const = 0;
 
   /** Write a Model out to a stream with this Printer. */
-  virtual void toStream(std::ostream& out, const Model& m) const;
+  virtual void toStream(std::ostream& out, const smt::Model& m) const;
 
   /** Write an UnsatCore out to a stream with this Printer. */
   virtual void toStream(std::ostream& out, const UnsatCore& core) const;
@@ -275,13 +275,13 @@ class Printer
 
   /** write model response to command */
   virtual void toStream(std::ostream& out,
-                        const Model& m,
+                        const smt::Model& m,
                         const NodeCommand* c) const = 0;
 
   /** write model response to command using another language printer */
   void toStreamUsing(OutputLanguage lang,
                      std::ostream& out,
-                     const Model& m,
+                     const smt::Model& m,
                      const NodeCommand* c) const
   {
     getPrinter(lang)->toStream(out, m, c);
index 6d75279e5cc64518d543c050a6aee35add2d82cf..2024c87b6a141f917c274147d6bd6bb376f96ff2 100644 (file)
@@ -1323,11 +1323,12 @@ void Smt2Printer::toStream(std::ostream& out, const UnsatCore& core) const
   out << ")" << endl;
 }/* Smt2Printer::toStream(UnsatCore, map<Expr, string>) */
 
-void Smt2Printer::toStream(std::ostream& out, const Model& m) const
+void Smt2Printer::toStream(std::ostream& out, const smt::Model& m) const
 {
+  const theory::TheoryModel* tm = m.getTheoryModel();
   //print the model comments
   std::stringstream c;
-  m.getComments( c );
+  tm->getComments(c);
   std::string ln;
   while( std::getline( c, ln ) ){
     out << "; " << ln << std::endl;
@@ -1339,8 +1340,9 @@ void Smt2Printer::toStream(std::ostream& out, const Model& m) const
   this->Printer::toStream(out, m);
   out << ")" << endl;
   //print the heap model, if it exists
-  Expr h, neq;
-  if( m.getHeapModel( h, neq ) ){
+  Node h, neq;
+  if (tm->getHeapModel(h, neq))
+  {
     // description of the heap+what nil is equal to fully describes model
     out << "(heap" << endl;
     out << h << endl;
@@ -1350,11 +1352,10 @@ void Smt2Printer::toStream(std::ostream& out, const Model& m) const
 }
 
 void Smt2Printer::toStream(std::ostream& out,
-                           const Model& model,
+                           const smt::Model& model,
                            const NodeCommand* command) const
 {
-  const theory::TheoryModel* theory_model =
-      dynamic_cast<const theory::TheoryModel*>(&model);
+  const theory::TheoryModel* theory_model = model.getTheoryModel();
   AlwaysAssert(theory_model != nullptr);
   if (const DeclareTypeNodeCommand* dtc =
           dynamic_cast<const DeclareTypeNodeCommand*>(command))
@@ -1367,7 +1368,7 @@ void Smt2Printer::toStream(std::ostream& out,
     }
     else
     {
-      std::vector<Expr> elements = theory_model->getDomainElements(tn.toType());
+      std::vector<Node> elements = theory_model->getDomainElements(tn);
       if (options::modelUninterpDtEnum())
       {
         if (isVariant_2_6(d_variant))
@@ -1378,7 +1379,7 @@ void Smt2Printer::toStream(std::ostream& out,
         {
           out << "(declare-datatypes () ((" << (*dtc).getSymbol() << " ";
         }
-        for (const Expr& type_ref : elements)
+        for (const Node& type_ref : elements)
         {
           out << "(" << type_ref << ")";
         }
@@ -1390,9 +1391,8 @@ void Smt2Printer::toStream(std::ostream& out,
         out << "; cardinality of " << tn << " is " << elements.size() << endl;
         out << (*dtc) << endl;
         // print the representatives
-        for (const Expr& type_ref : elements)
+        for (const Node& trn : elements)
         {
-          Node trn = Node::fromExpr(type_ref);
           if (trn.isVar())
           {
             out << "(declare-fun " << quoteSymbol(trn) << " () " << tn << ")"
@@ -1423,7 +1423,9 @@ void Smt2Printer::toStream(std::ostream& out,
       // don't print out internal stuff
       return;
     }
-    Node val = theory_model->getSmtEngine()->getValue(n);
+    // We get the value from the theory model directly, which notice
+    // does not have to go through the standard SmtEngine::getValue interface.
+    Node val = theory_model->getValue(n);
     if (val.getKind() == kind::LAMBDA)
     {
       out << "(define-fun " << n << " " << val[0] << " "
index 3160771da53fdd4ce2a3a27c75afaed2fd1c35f7..ed04da983ae73f3c9863d65aaca9379f80b23a1b 100644 (file)
@@ -45,7 +45,7 @@ class Smt2Printer : public CVC4::Printer
                 bool types,
                 size_t dag) const override;
   void toStream(std::ostream& out, const CommandStatus* s) const override;
-  void toStream(std::ostream& out, const Model& m) const override;
+  void toStream(std::ostream& out, const smt::Model& m) const override;
   /**
    * Writes the unsat core to the stream out.
    * We use the expression names that are stored in the SMT engine associated
@@ -231,7 +231,7 @@ class Smt2Printer : public CVC4::Printer
   void toStream(
       std::ostream& out, TNode n, int toDepth, bool types, TypeNode nt) const;
   void toStream(std::ostream& out,
-                const Model& m,
+                const smt::Model& m,
                 const NodeCommand* c) const override;
   void toStream(std::ostream& out, const SExpr& sexpr) const;
   void toStream(std::ostream& out, const DType& dt) const;
index fa0fc3c46b5caa2e41fb74c36092c47e512428b9..009f78a1da646a3686eddd800dbc7efed387463e 100644 (file)
@@ -45,7 +45,7 @@ void TptpPrinter::toStream(std::ostream& out, const CommandStatus* s) const
   s->toStream(out, language::output::LANG_SMTLIB_V2_5);
 }/* TptpPrinter::toStream() */
 
-void TptpPrinter::toStream(std::ostream& out, const Model& m) const
+void TptpPrinter::toStream(std::ostream& out, const smt::Model& m) const
 {
   std::string statusName(m.isKnownSat() ? "FiniteModel"
                                         : "CandidateFiniteModel");
@@ -59,7 +59,7 @@ void TptpPrinter::toStream(std::ostream& out, const Model& m) const
 }
 
 void TptpPrinter::toStream(std::ostream& out,
-                           const Model& m,
+                           const smt::Model& m,
                            const NodeCommand* c) const
 {
   // shouldn't be called; only the non-Command* version above should be
index 0c961d39bf967a86c82d893cee56c887a9a9653f..84bb3e576cd44796bdd635652765cf5957660332 100644 (file)
@@ -37,7 +37,7 @@ class TptpPrinter : public CVC4::Printer
                 bool types,
                 size_t dag) const override;
   void toStream(std::ostream& out, const CommandStatus* s) const override;
-  void toStream(std::ostream& out, const Model& m) const override;
+  void toStream(std::ostream& out, const smt::Model& m) const override;
   /** print unsat core to stream
    * We use the expression names stored in the SMT engine associated with the
    * unsat core with UnsatCore::getSmtEngine.
@@ -46,7 +46,7 @@ class TptpPrinter : public CVC4::Printer
 
  private:
   void toStream(std::ostream& out,
-                const Model& m,
+                const smt::Model& m,
                 const NodeCommand* c) const override;
 
 }; /* class TptpPrinter */
index b823b57302da0e51853003bdff45f192b75d74a4..41776cee50db3b1687424834d8795993eaf4f1f3 100644 (file)
@@ -46,7 +46,10 @@ class Term;
 class SmtEngine;
 class Command;
 class CommandStatus;
+
+namespace smt {
 class Model;
+}
 
 std::ostream& operator<<(std::ostream&, const Command&) CVC4_PUBLIC;
 std::ostream& operator<<(std::ostream&, const Command*) CVC4_PUBLIC;
@@ -995,7 +998,7 @@ class CVC4_PUBLIC GetModelCommand : public Command
       OutputLanguage language = language::output::LANG_AUTO) const override;
 
  protected:
-  Model* d_result;
+  smt::Model* d_result;
 }; /* class GetModelCommand */
 
 /** The command to block models. */
index 60640def19d5e0bc3b0b1afc3645958ff21fb497..fc9ea8fbb930b2053b3feddc50cab6258f349d92 100644 (file)
@@ -14,8 +14,6 @@
 
 #include "smt/model.h"
 
-#include <vector>
-
 #include "expr/expr_iomanip.h"
 #include "options/base_options.h"
 #include "printer/printer.h"
 #include "smt/node_command.h"
 #include "smt/smt_engine.h"
 #include "smt/smt_engine_scope.h"
-
-using namespace std;
+#include "theory/theory_model.h"
 
 namespace CVC4 {
+namespace smt {
+
+Model::Model(SmtEngine& smt, theory::TheoryModel* tm)
+    : d_smt(smt), d_isKnownSat(false), d_tmodel(tm)
+{
+  Assert(d_tmodel != nullptr);
+}
 
 std::ostream& operator<<(std::ostream& out, const Model& m) {
   smt::SmtScope smts(&m.d_smt);
@@ -35,8 +39,6 @@ std::ostream& operator<<(std::ostream& out, const Model& m) {
   return out;
 }
 
-Model::Model() : d_smt(*smt::currentSmtEngine()), d_isKnownSat(false) {}
-
 size_t Model::getNumCommands() const
 {
   return d_smt.getDumpManager()->getNumModelCommands();
@@ -47,4 +49,17 @@ const NodeCommand* Model::getCommand(size_t i) const
   return d_smt.getDumpManager()->getModelCommand(i);
 }
 
+theory::TheoryModel* Model::getTheoryModel() { return d_tmodel; }
+
+const theory::TheoryModel* Model::getTheoryModel() const { return d_tmodel; }
+
+bool Model::isModelCoreSymbol(TNode sym) const
+{
+  return d_tmodel->isModelCoreSymbol(sym);
+}
+Node Model::getValue(TNode n) const { return d_tmodel->getValue(n); }
+
+bool Model::hasApproximations() const { return d_tmodel->hasApproximations(); }
+
+}  // namespace smt
 }/* CVC4 namespace */
index eb959ba7ef9ec8737901caa2bf63aecb94909230..dc36b5d291535bc6a3a14ecb1e8f0f8ac5532bb4 100644 (file)
 #include <vector>
 
 #include "expr/expr.h"
+#include "theory/theory_model.h"
 #include "util/cardinality.h"
 
 namespace CVC4 {
 
-class NodeCommand;
 class SmtEngine;
+class NodeCommand;
+
+namespace smt {
+
 class Model;
 
 std::ostream& operator<<(std::ostream&, const Model&);
 
+/**
+ * This is the SMT-level model object, that is responsible for maintaining
+ * the necessary information for how to print the model, as well as
+ * holding a pointer to the underlying implementation of the theory model.
+ */
 class Model {
   friend std::ostream& operator<<(std::ostream&, const Model&);
-  friend class SmtEngine;
-
- protected:
-  /** The SmtEngine we're associated with */
-  SmtEngine& d_smt;
-
-  /** construct the base class; users cannot do this, only CVC4 internals */
-  Model();
+  friend class ::CVC4::SmtEngine;
 
  public:
+  /** construct */
+  Model(SmtEngine& smt, theory::TheoryModel* tm);
   /** virtual destructor */
-  virtual ~Model() { }
+  ~Model() {}
   /** get number of commands to report */
   size_t getNumCommands() const;
   /** get command */
@@ -62,54 +66,21 @@ class Model {
    * only a candidate solution.
    */
   bool isKnownSat() const { return d_isKnownSat; }
-  //--------------------------- model cores
-  /** set using model core
-   *
-   * This sets that this model is minimized to be a "model core" for some
-   * formula (typically the input formula).
-   *
-   * For example, given formula ( a>5 OR b>5 ) AND f( c ) = 0,
-   * a model for this formula is: a -> 6, b -> 0, c -> 0, f -> lambda x. 0.
-   * A "model core" is a subset of this model that suffices to show the
-   * above formula is true, for example { a -> 6, f -> lambda x. 0 } is a
-   * model core for this formula.
-   */
-  virtual void setUsingModelCore() = 0;
-  /** record model core symbol
-   *
-   * This marks that sym is a "model core symbol". In other words, its value is
-   * critical to the satisfiability of the formula this model is for.
-   */
-  virtual void recordModelCoreSymbol(Expr sym) = 0;
-  /** Check whether this expr is in the model core */
-  virtual bool isModelCoreSymbol(Expr expr) const = 0;
-  //--------------------------- end model cores
-  /** get value for expression */
-  virtual Expr getValue(Expr expr) const = 0;
-  /** get cardinality for sort */
-  virtual Cardinality getCardinality(Type t) const = 0;
-  /** print comments */
-  virtual void getComments(std::ostream& out) const {}
-  /** get heap model (for separation logic) */
-  virtual bool getHeapModel( Expr& h, Expr& ne ) const { return false; }
-  /** are there any approximations in this model? */
-  virtual bool hasApproximations() const { return false; }
-  /** get the list of approximations
-   *
-   * This is a list of pairs of the form (t,p), where t is a term and p
-   * is a predicate over t that indicates a property that t satisfies.
-   */
-  virtual std::vector<std::pair<Expr, Expr> > getApproximations() const = 0;
-  /** get the domain elements for uninterpreted sort t
-   *
-   * This method gets the interpretation of an uninterpreted sort t.
-   * All models interpret uninterpreted sorts t as finite sets
-   * of domain elements v_1, ..., v_n. This method returns this list for t in
-   * this model.
-   */
-  virtual std::vector<Expr> getDomainElements(Type t) const = 0;
-
+  /** Get the underlying theory model */
+  theory::TheoryModel* getTheoryModel();
+  /** Get the underlying theory model (const version) */
+  const theory::TheoryModel* getTheoryModel() const;
+  //----------------------- helper methods in the underlying theory model
+  /** Is the node n a model core symbol? */
+  bool isModelCoreSymbol(TNode sym) const;
+  /** Get value */
+  Node getValue(TNode n) const;
+  /** Does this model have approximations? */
+  bool hasApproximations() const;
+  //----------------------- end helper methods
  protected:
+  /** The SmtEngine we're associated with */
+  SmtEngine& d_smt;
   /** the input name (file name, etc.) this model is associated to */
   std::string d_inputName;
   /**
@@ -117,8 +88,14 @@ class Model {
    * from the solver.
    */
   bool d_isKnownSat;
-};/* class Model */
+  /**
+   * Pointer to the underlying theory model, which maintains all data regarding
+   * the values of sorts and terms.
+   */
+  theory::TheoryModel* d_tmodel;
+};
 
+}  // namespace smt
 }/* CVC4 namespace */
 
 #endif  /* CVC4__MODEL_H */
index 9d15b5690c81987fd5af6e1c3b99221b92651cee..cabd7bd203481219f70a165ab44aac0ee661a072 100644 (file)
@@ -66,7 +66,7 @@ Node ModelBlocker::getModelBlocker(const std::vector<Node>& assertions,
       Node blockTriv = nm->mkConst(false);
       Trace("model-blocker")
           << "...model blocker is (trivially) " << blockTriv << std::endl;
-      return blockTriv.toExpr();
+      return blockTriv;
     }
 
     Node formula = asserts.size() > 1 ? nm->mkNode(AND, asserts) : asserts[0];
@@ -152,7 +152,7 @@ Node ModelBlocker::getModelBlocker(const std::vector<Node>& assertions,
           std::vector<Node> children;
           for (const Node& cn : catom)
           {
-            Node vn = Node::fromExpr(m->getValue(cn.toExpr()));
+            Node vn = m->getValue(cn);
             Assert(vn.isConst());
             children.push_back(vn.getConst<bool>() ? cn : cn.negate());
           }
index 59dac63e89f4ef4b7767626ec61b91d5b2bf1f52..cb8494e850b561dc3c76f8926e82287cf472fb9e 100644 (file)
@@ -21,7 +21,7 @@ using namespace CVC4::kind;
 namespace CVC4 {
 
 bool ModelCoreBuilder::setModelCore(const std::vector<Node>& assertions,
-                                    Model* m,
+                                    theory::TheoryModel* m,
                                     options::ModelCoresMode mode)
 {
   if (Trace.isOn("model-core"))
@@ -53,7 +53,7 @@ bool ModelCoreBuilder::setModelCore(const std::vector<Node>& assertions,
       visited.insert(cur);
       if (cur.isVar())
       {
-        Node vcur = Node::fromExpr(m->getValue(cur.toExpr()));
+        Node vcur = m->getValue(cur);
         Trace("model-core") << "  " << cur << " -> " << vcur << std::endl;
         vars.push_back(cur);
         subs.push_back(vcur);
@@ -95,7 +95,7 @@ bool ModelCoreBuilder::setModelCore(const std::vector<Node>& assertions,
 
     for (const Node& cv : coreVars)
     {
-      m->recordModelCoreSymbol(cv.toExpr());
+      m->recordModelCoreSymbol(cv);
     }
     return true;
   }
index 984c61d043278806f07caf79c1cb85f0ba68173d..7a28c47f202b5fdef64162380c6e637a342513f4 100644 (file)
@@ -21,7 +21,7 @@
 
 #include "expr/expr.h"
 #include "options/smt_options.h"
-#include "smt/model.h"
+#include "theory/theory_model.h"
 
 namespace CVC4 {
 
@@ -55,7 +55,7 @@ class ModelCoreBuilder
    * left unchanged.
    */
   static bool setModelCore(const std::vector<Node>& assertions,
-                           Model* m,
+                           theory::TheoryModel* m,
                            options::ModelCoresMode mode);
 }; /* class TheoryModelCoreBuilder */
 
index 205865e168781afed96dd0b95b16f8c33b5f8851..2a771ce768d58ae23362e45c21445c560c673cd4 100644 (file)
@@ -127,6 +127,7 @@ SmtEngine::SmtEngine(ExprManager* em, Options* optr)
       d_snmListener(new SmtNodeManagerListener(*d_dumpm.get(), d_outMgr)),
       d_smtSolver(nullptr),
       d_proofManager(nullptr),
+      d_model(nullptr),
       d_pfManager(nullptr),
       d_rewriter(new theory::Rewriter()),
       d_definedFunctions(nullptr),
@@ -271,6 +272,15 @@ void SmtEngine::finishInit()
   Trace("smt-debug") << "SmtEngine::finishInit" << std::endl;
   d_smtSolver->finishInit(const_cast<const LogicInfo&>(d_logic));
 
+  // now can construct the SMT-level model object
+  TheoryEngine* te = d_smtSolver->getTheoryEngine();
+  Assert(te != nullptr);
+  TheoryModel* tm = te->getModel();
+  if (tm != nullptr)
+  {
+    d_model.reset(new Model(*this, tm));
+  }
+
   // global push/pop around everything, to ensure proper destruction
   // of context-dependent data structures
   d_state->setup();
@@ -839,7 +849,7 @@ Result SmtEngine::quickCheck() {
       Result::ENTAILMENT_UNKNOWN, Result::REQUIRES_FULL_CHECK, filename);
 }
 
-theory::TheoryModel* SmtEngine::getAvailableModel(const char* c) const
+Model* SmtEngine::getAvailableModel(const char* c) const
 {
   if (!options::assignFunctionValues())
   {
@@ -878,7 +888,7 @@ theory::TheoryModel* SmtEngine::getAvailableModel(const char* c) const
     throw RecoverableModalException(ss.str().c_str());
   }
 
-  return m;
+  return d_model.get();
 }
 
 void SmtEngine::notifyPushPre() { d_smtSolver->processAssertions(*d_asserts); }
@@ -1210,11 +1220,9 @@ Node SmtEngine::getValue(const Node& ex) const
   }
 
   Trace("smt") << "--- getting value of " << n << endl;
-  TheoryModel* m = getAvailableModel("get-value");
-  Node resultNode;
-  if(m != NULL) {
-    resultNode = m->getValue(n);
-  }
+  Model* m = getAvailableModel("get-value");
+  Assert(m != nullptr);
+  Node resultNode = m->getValue(n);
   Trace("smt") << "--- got value " << n << " = " << resultNode << endl;
   Trace("smt") << "--- type " << resultNode.getType() << endl;
   Trace("smt") << "--- expected type " << expectedType << endl;
@@ -1301,7 +1309,7 @@ vector<pair<Expr, Expr>> SmtEngine::getAssignment()
   // Get the model here, regardless of whether d_assignments is null, since
   // we should throw errors related to model availability whether or not
   // assignments is null.
-  TheoryModel* m = getAvailableModel("get assignment");
+  Model* m = getAvailableModel("get assignment");
 
   vector<pair<Expr,Expr>> res;
   if (d_assignments != nullptr)
@@ -1354,7 +1362,7 @@ Model* SmtEngine::getModel() {
         getOutputManager().getDumpOut());
   }
 
-  TheoryModel* m = getAvailableModel("get model");
+  Model* m = getAvailableModel("get model");
 
   // Since model m is being returned to the user, we must ensure that this
   // model object remains valid with future check-sat calls. Hence, we set
@@ -1368,8 +1376,11 @@ Model* SmtEngine::getModel() {
     // If we enabled model cores, we compute a model core for m based on our
     // (expanded) assertions using the model core builder utility
     std::vector<Node> eassertsProc = getExpandedAssertions();
-    ModelCoreBuilder::setModelCore(eassertsProc, m, options::modelCoresMode());
+    ModelCoreBuilder::setModelCore(
+        eassertsProc, m->getTheoryModel(), options::modelCoresMode());
   }
+  // set the information on the SMT-level model
+  Assert(m != nullptr);
   m->d_inputName = d_state->getFilename();
   m->d_isKnownSat = (d_state->getMode() == SmtMode::SAT);
   return m;
@@ -1388,19 +1399,19 @@ Result SmtEngine::blockModel()
         getOutputManager().getDumpOut());
   }
 
-  TheoryModel* m = getAvailableModel("block model");
+  Model* m = getAvailableModel("block model");
 
   if (options::blockModelsMode() == options::BlockModelsMode::NONE)
   {
     std::stringstream ss;
     ss << "Cannot block model when block-models is set to none.";
-    throw ModalException(ss.str().c_str());
+    throw RecoverableModalException(ss.str().c_str());
   }
 
   // get expanded assertions
   std::vector<Node> eassertsProc = getExpandedAssertions();
   Node eblocker = ModelBlocker::getModelBlocker(
-      eassertsProc, m, options::blockModelsMode());
+      eassertsProc, m->getTheoryModel(), options::blockModelsMode());
   return assertFormula(eblocker);
 }
 
@@ -1417,13 +1428,16 @@ Result SmtEngine::blockModelValues(const std::vector<Node>& exprs)
         getOutputManager().getDumpOut(), exprs);
   }
 
-  TheoryModel* m = getAvailableModel("block model values");
+  Model* m = getAvailableModel("block model values");
 
   // get expanded assertions
   std::vector<Node> eassertsProc = getExpandedAssertions();
   // we always do block model values mode here
-  Node eblocker = ModelBlocker::getModelBlocker(
-      eassertsProc, m, options::BlockModelsMode::VALUES, exprs);
+  Node eblocker =
+      ModelBlocker::getModelBlocker(eassertsProc,
+                                    m->getTheoryModel(),
+                                    options::BlockModelsMode::VALUES,
+                                    exprs);
   return assertFormula(eblocker);
 }
 
@@ -1437,16 +1451,18 @@ std::pair<Node, Node> SmtEngine::getSepHeapAndNilExpr(void)
     throw RecoverableModalException(msg);
   }
   NodeManagerScope nms(d_nodeManager);
-  Expr heap;
-  Expr nil;
+  Node heap;
+  Node nil;
   Model* m = getAvailableModel("get separation logic heap and nil");
-  if (!m->getHeapModel(heap, nil))
+  TheoryModel* tm = m->getTheoryModel();
+  if (!tm->getHeapModel(heap, nil))
   {
-    InternalError()
-        << "SmtEngine::getSepHeapAndNilExpr(): failed to obtain heap/nil "
-           "expressions from theory model.";
+    const char* msg =
+        "Failed to obtain heap/nil "
+        "expressions from theory model.";
+    throw RecoverableModalException(msg);
   }
-  return std::make_pair(Node::fromExpr(heap), Node::fromExpr(nil));
+  return std::make_pair(heap, nil);
 }
 
 std::vector<Node> SmtEngine::getExpandedAssertions()
@@ -1544,7 +1560,8 @@ void SmtEngine::checkModel(bool hardFailure) {
   // and if Notice() is on, the user gave --verbose (or equivalent).
 
   Notice() << "SmtEngine::checkModel(): generating model" << endl;
-  TheoryModel* m = getAvailableModel("check model");
+  Model* m = getAvailableModel("check model");
+  Assert(m != nullptr);
 
   // check-model is not guaranteed to succeed if approximate values were used.
   // Thus, we intentionally abort here.
index 62e54a0c11f0dc4a220f6adbcee1d932f09d3d7c..da12d336b913ffd20dc24c019e3e91cd29f98bc1 100644 (file)
@@ -60,7 +60,6 @@ class TheoryEngine;
 
 class ProofManager;
 
-class Model;
 class LogicRequest;
 class StatisticsRegistry;
 
@@ -95,6 +94,7 @@ namespace prop {
 
 namespace smt {
 /** Utilities */
+class Model;
 class SmtEngineState;
 class AbstractValues;
 class Assertions;
@@ -280,7 +280,7 @@ class CVC4_PUBLIC SmtEngine
    * Get the model (only if immediately preceded by a SAT or NOT_ENTAILED
    * query).  Only permitted if produce-models is on.
    */
-  Model* getModel();
+  smt::Model* getModel();
 
   /**
    * Block the current model. Can be called only if immediately preceded by
@@ -969,16 +969,17 @@ class CVC4_PUBLIC SmtEngine
   Result quickCheck();
 
   /**
-   * Get the model, if it is available and return a pointer to it
+   * Get the (SMT-level) model pointer, if we are in SAT mode. Otherwise,
+   * return nullptr.
    *
-   * This ensures that the model is currently available, which means that
-   * CVC4 is producing models, and is in "SAT mode", otherwise an exception
-   * is thrown.
+   * This ensures that the underlying theory model of the SmtSolver maintained
+   * by this class is currently available, which means that CVC4 is producing
+   * models, and is in "SAT mode", otherwise a recoverable exception is thrown.
    *
    * The flag c is used for giving an error message to indicate the context
    * this method was called.
    */
-  theory::TheoryModel* getAvailableModel(const char* c) const;
+  smt::Model* getAvailableModel(const char* c) const;
 
   // --------------------------------------- callbacks from the state
   /**
@@ -1088,6 +1089,12 @@ class CVC4_PUBLIC SmtEngine
 
   /** The (old) proof manager TODO (project #37): delete this */
   std::unique_ptr<ProofManager> d_proofManager;
+  /**
+   * The SMT-level model object, which contains information about how to
+   * print the model, as well as a pointer to the underlying TheoryModel
+   * implementation maintained by the SmtSolver.
+   */
+  std::unique_ptr<smt::Model> d_model;
 
   /**
    * The proof manager, which manages all things related to checking,
index b8cbdf6b38272ec0d2e46bfddfa0756a3d6c87e5..f240d5113890648022cf8902493f87ad368373a0 100644 (file)
@@ -109,48 +109,27 @@ bool TheoryModel::getHeapModel(Node& h, Node& neq) const
   return true;
 }
 
-bool TheoryModel::getHeapModel( Expr& h, Expr& neq ) const {
-  if( d_sep_heap.isNull() || d_sep_nil_eq.isNull() ){
-    return false;
-  }else{
-    h = d_sep_heap.toExpr();
-    neq = d_sep_nil_eq.toExpr();
-    return true;
-  }
-}
-
 bool TheoryModel::hasApproximations() const { return !d_approx_list.empty(); }
 
-std::vector<std::pair<Expr, Expr> > TheoryModel::getApproximations() const
+std::vector<std::pair<Node, Node> > TheoryModel::getApproximations() const
 {
-  std::vector<std::pair<Expr, Expr> > approx;
-  for (const std::pair<Node, Node>& ap : d_approx_list)
-  {
-    approx.push_back(
-        std::pair<Expr, Expr>(ap.first.toExpr(), ap.second.toExpr()));
-  }
-  return approx;
+  return d_approx_list;
 }
 
-std::vector<Expr> TheoryModel::getDomainElements(Type t) const
+std::vector<Node> TheoryModel::getDomainElements(TypeNode tn) const
 {
   // must be an uninterpreted sort
-  Assert(t.isSort());
-  std::vector<Expr> elements;
-  TypeNode tn = TypeNode::fromType(t);
+  Assert(tn.isSort());
+  std::vector<Node> elements;
   const std::vector<Node>* type_refs = d_rep_set.getTypeRepsOrNull(tn);
   if (type_refs == nullptr || type_refs->empty())
   {
     // This is called when t is a sort that does not occur in this model.
     // Sorts are always interpreted as non-empty, thus we add a single element.
-    elements.push_back(t.mkGroundTerm());
+    elements.push_back(tn.mkGroundTerm());
     return elements;
   }
-  for (const Node& n : *type_refs)
-  {
-    elements.push_back(n.toExpr());
-  }
-  return elements;
+  return *type_refs;
 }
 
 Node TheoryModel::getValue(TNode n) const
@@ -170,39 +149,35 @@ Node TheoryModel::getValue(TNode n) const
   return nn;
 }
 
-bool TheoryModel::isModelCoreSymbol(Expr sym) const
+bool TheoryModel::isModelCoreSymbol(Node s) const
 {
   if (!d_using_model_core)
   {
     return true;
   }
-  Node s = Node::fromExpr(sym);
   Assert(s.isVar() && s.getKind() != BOUND_VARIABLE);
   return d_model_core.find(s) != d_model_core.end();
 }
 
-Expr TheoryModel::getValue( Expr expr ) const{
-  Node n = Node::fromExpr( expr );
-  Node ret = getValue( n );
-  return ret.toExpr();
-}
-
-/** get cardinality for sort */
-Cardinality TheoryModel::getCardinality( Type t ) const{
-  TypeNode tn = TypeNode::fromType( t );
+Cardinality TheoryModel::getCardinality(TypeNode tn) const
+{
   //for now, we only handle cardinalities for uninterpreted sorts
-  if( tn.isSort() ){
-    if( d_rep_set.hasType( tn ) ){
-      Debug("model-getvalue-debug") << "Get cardinality sort, #rep : " << d_rep_set.getNumRepresentatives( tn ) << std::endl;
-      return Cardinality( d_rep_set.getNumRepresentatives( tn ) );
-    }else{
-      Debug("model-getvalue-debug") << "Get cardinality sort, unconstrained, return 1." << std::endl;
-      return Cardinality( 1 );
-    }
-  }else{
-      Debug("model-getvalue-debug") << "Get cardinality other sort, unknown." << std::endl;
+  if (!tn.isSort())
+  {
+    Debug("model-getvalue-debug")
+        << "Get cardinality other sort, unknown." << std::endl;
     return Cardinality( CardinalityUnknown() );
   }
+  if (d_rep_set.hasType(tn))
+  {
+    Debug("model-getvalue-debug")
+        << "Get cardinality sort, #rep : "
+        << d_rep_set.getNumRepresentatives(tn) << std::endl;
+    return Cardinality(d_rep_set.getNumRepresentatives(tn));
+  }
+  Debug("model-getvalue-debug")
+      << "Get cardinality sort, unconstrained, return 1." << std::endl;
+  return Cardinality(1);
 }
 
 Node TheoryModel::getModelValue(TNode n) const
@@ -258,16 +233,15 @@ Node TheoryModel::getModelValue(TNode n) const
     {
       Debug("model-getvalue-debug")
           << "get cardinality constraint " << ret[0].getType() << std::endl;
-      ret = nm->mkConst(
-          getCardinality(ret[0].getType().toType()).getFiniteCardinality()
-          <= ret[1].getConst<Rational>().getNumerator());
+      ret = nm->mkConst(getCardinality(ret[0].getType()).getFiniteCardinality()
+                        <= ret[1].getConst<Rational>().getNumerator());
     }
     else if (ret.getKind() == kind::CARDINALITY_VALUE)
     {
       Debug("model-getvalue-debug")
           << "get cardinality value " << ret[0].getType() << std::endl;
-      ret = nm->mkConst(Rational(
-          getCardinality(ret[0].getType().toType()).getFiniteCardinality()));
+      ret = nm->mkConst(
+          Rational(getCardinality(ret[0].getType()).getFiniteCardinality()));
     }
     d_modelCache[n] = ret;
     return ret;
@@ -621,10 +595,7 @@ void TheoryModel::setUsingModelCore()
   d_model_core.clear();
 }
 
-void TheoryModel::recordModelCoreSymbol(Expr sym)
-{
-  d_model_core.insert(Node::fromExpr(sym));
-}
+void TheoryModel::recordModelCoreSymbol(Node sym) { d_model_core.insert(sym); }
 
 void TheoryModel::setUnevaluatedKind(Kind k) { d_unevaluated_kinds.insert(k); }
 
index 9f330ff6ca984d53abfc806bb1b36d5ed10dd07e..e8665bb838fba0aa85fac0f2beec061266a0c964 100644 (file)
@@ -20,7 +20,6 @@
 #include <unordered_map>
 #include <unordered_set>
 
-#include "smt/model.h"
 #include "theory/ee_setup_info.h"
 #include "theory/rep_set.h"
 #include "theory/substitutions.h"
@@ -76,12 +75,12 @@ namespace theory {
  * above functions such as getRepresentative() when assigning total
  * interpretations for uninterpreted functions.
  */
-class TheoryModel : public Model
+class TheoryModel
 {
   friend class TheoryEngineModelBuilder;
 public:
   TheoryModel(context::Context* c, std::string name, bool enableFuncModels);
-  ~TheoryModel() override;
+  virtual ~TheoryModel();
   /**
    * Finish init, where ee is the equality engine the model should use.
    */
@@ -295,23 +294,21 @@ public:
    */
   Node getValue(TNode n) const;
   /** get comments */
-  void getComments(std::ostream& out) const override;
+  void getComments(std::ostream& out) const;
 
   //---------------------------- separation logic
   /** set the heap and value sep.nil is equal to */
   void setHeapModel(Node h, Node neq);
   /** get the heap and value sep.nil is equal to */
   bool getHeapModel(Node& h, Node& neq) const;
-  /** get the heap and value sep.nil is equal to */
-  bool getHeapModel(Expr& h, Expr& neq) const override;
   //---------------------------- end separation logic
 
   /** is the list of approximations non-empty? */
-  bool hasApproximations() const override;
+  bool hasApproximations() const;
   /** get approximations */
-  std::vector<std::pair<Expr, Expr> > getApproximations() const override;
+  std::vector<std::pair<Node, Node> > getApproximations() const;
   /** get domain elements for uninterpreted sort t */
-  std::vector<Expr> getDomainElements(Type t) const override;
+  std::vector<Node> getDomainElements(TypeNode t) const;
   /** get the representative set object */
   const RepSet* getRepSet() const { return &d_rep_set; }
   /** get the representative set object (FIXME: remove this, see #1199) */
@@ -319,17 +316,15 @@ public:
 
   //---------------------------- model cores
   /** set using model core */
-  void setUsingModelCore() override;
+  void setUsingModelCore();
   /** record model core symbol */
-  void recordModelCoreSymbol(Expr sym) override;
+  void recordModelCoreSymbol(Node sym);
   /** Return whether symbol expr is in the model core. */
-  bool isModelCoreSymbol(Expr sym) const override;
+  bool isModelCoreSymbol(Node sym) const;
   //---------------------------- end model cores
 
-  /** get value function for Exprs. */
-  Expr getValue(Expr expr) const override;
   /** get cardinality for sort */
-  Cardinality getCardinality(Type t) const override;
+  Cardinality getCardinality(TypeNode t) const;
 
   //---------------------------- function values
   /** a map from functions f to a list of all APPLY_UF terms with operator f */
index 0f69566d6b0e10be4e453d8b8c7aeeb3208d9f53..2f9e168c90bb3137ad767a1ba830fe792eb64138 100644 (file)
@@ -1082,7 +1082,7 @@ void TheoryEngineModelBuilder::computeAssignableInfo(
   }
 }
 
-void TheoryEngineModelBuilder::postProcessModel(bool incomplete, Model* m)
+void TheoryEngineModelBuilder::postProcessModel(bool incomplete, TheoryModel* m)
 {
   // if we are incomplete, there is no guarantee on the model.
   // thus, we do not check the model here.
@@ -1090,12 +1090,11 @@ void TheoryEngineModelBuilder::postProcessModel(bool incomplete, Model* m)
   {
     return;
   }
-  TheoryModel* tm = static_cast<TheoryModel*>(m);
-  Assert(tm != nullptr);
+  Assert(m != nullptr);
   // debug-check the model if the checkModels() is enabled.
   if (options::debugCheckModels())
   {
-    debugCheckModel(tm);
+    debugCheckModel(m);
   }
 }
 
index 996609dd3b152461a06727c29cc4ebaefa7a37fc..4ffcbeee774e8ed585015f4cae45bfa6125fdad3 100644 (file)
@@ -81,7 +81,7 @@ class TheoryEngineModelBuilder
    * method checks the internal consistency of the model if we are in a debug
    * build.
    */
-  void postProcessModel(bool incomplete, Model* m);
+  void postProcessModel(bool incomplete, TheoryModel* m);
 
  protected:
   /** pointer to theory engine */