Print response to get-model using the API (#7084)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 1 Sep 2021 18:05:48 +0000 (13:05 -0500)
committerGitHub <noreply@github.com>
Wed, 1 Sep 2021 18:05:48 +0000 (18:05 +0000)
This changes our implementation of GetModelCommand so that we use the API to print the model.

It simplifies smt::Model so that this is a pretty printing utility, and not a layer on top of TheoryModel.

It adds getModel as an API method for returning the string representation of the model, analogous to our current support for getProof.

This eliminates the last call to getSmtEngine() from the command layer.

23 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5.h
src/main/command_executor.h
src/printer/ast/ast_printer.cpp
src/printer/ast/ast_printer.h
src/printer/cvc/cvc_printer.cpp
src/printer/cvc/cvc_printer.h
src/printer/printer.cpp
src/printer/printer.h
src/printer/smt2/smt2_printer.cpp
src/printer/smt2/smt2_printer.h
src/printer/tptp/tptp_printer.cpp
src/printer/tptp/tptp_printer.h
src/smt/check_models.cpp
src/smt/check_models.h
src/smt/command.cpp
src/smt/command.h
src/smt/model.cpp
src/smt/model.h
src/smt/smt_engine.cpp
src/smt/smt_engine.h
test/regress/regress0/cvc-rerror-print.cvc
test/unit/api/solver_black.cpp

index d03b8975e39501ffab13169e39fcacbdee42bf1e..d6c0a58ee26bfe7b3f67666ad82143e9a1c56067 100644 (file)
@@ -7381,6 +7381,36 @@ bool Solver::isModelCoreSymbol(const Term& v) const
   CVC5_API_TRY_CATCH_END;
 }
 
+std::string Solver::getModel(const std::vector<Sort>& sorts,
+                             const std::vector<Term>& vars) const
+{
+  CVC5_API_TRY_CATCH_BEGIN;
+  NodeManagerScope scope(getNodeManager());
+  CVC5_API_RECOVERABLE_CHECK(d_smtEngine->getOptions().smt.produceModels)
+      << "Cannot get model unless model generation is enabled "
+         "(try --produce-models)";
+  CVC5_API_RECOVERABLE_CHECK(d_smtEngine->isSmtModeSat())
+      << "Cannot get model unless after a SAT or unknown response.";
+  CVC5_API_SOLVER_CHECK_SORTS(sorts);
+  for (const Sort& s : sorts)
+  {
+    CVC5_API_RECOVERABLE_CHECK(s.isUninterpretedSort())
+        << "Expecting an uninterpreted sort as argument to "
+           "getModel.";
+  }
+  CVC5_API_SOLVER_CHECK_TERMS(vars);
+  for (const Term& v : vars)
+  {
+    CVC5_API_RECOVERABLE_CHECK(v.getKind() == CONSTANT)
+        << "Expecting a free constant as argument to getModel.";
+  }
+  //////// all checks before this line
+  return d_smtEngine->getModel(Sort::sortVectorToTypeNodes(sorts),
+                               Term::termVectorToNodes(vars));
+  ////////
+  CVC5_API_TRY_CATCH_END;
+}
+
 Term Solver::getQuantifierElimination(const Term& q) const
 {
   NodeManagerScope scope(getNodeManager());
@@ -7900,12 +7930,6 @@ std::vector<Term> Solver::getSynthSolutions(
   CVC5_API_TRY_CATCH_END;
 }
 
-/*
- * !!! This is only temporarily available until the parser is fully migrated to
- * the new API. !!!
- */
-SmtEngine* Solver::getSmtEngine(void) const { return d_smtEngine.get(); }
-
 Statistics Solver::getStatistics() const
 {
   return Statistics(d_smtEngine->getStatisticsRegistry());
index 11b138a50d21ede03f596d5d04436ff868ed9e67..a221f37116763b565e8a4575b5de471d178f5d78 100644 (file)
@@ -3950,6 +3950,22 @@ class CVC5_EXPORT Solver
    */
   bool isModelCoreSymbol(const Term& v) const;
 
+  /**
+   * Get the model
+   * SMT-LIB:
+   * \verbatim
+   * ( get-model )
+   * \endverbatim
+   * Requires to enable option 'produce-models'.
+   * @param sorts The list of uninterpreted sorts that should be printed in the
+   * model.
+   * @param vars The list of free constants that should be printed in the
+   * model. A subset of these may be printed based on isModelCoreSymbol.
+   * @return a string representing the model.
+   */
+  std::string getModel(const std::vector<Sort>& sorts,
+                       const std::vector<Term>& vars) const;
+
   /**
    * Do quantifier elimination.
    * SMT-LIB:
@@ -4329,10 +4345,6 @@ class CVC5_EXPORT Solver
    */
   std::vector<Term> getSynthSolutions(const std::vector<Term>& terms) const;
 
-  // !!! This is only temporarily available until the parser is fully migrated
-  // to the new API. !!!
-  SmtEngine* getSmtEngine(void) const;
-
   /**
    * Returns a snapshot of the current state of the statistic values of this
    * solver. The returned object is completely decoupled from the solver and
index 0a7a56e5bb08d63d850d2f11bfcea76c7726f2ec..1e8d848a453fab3662c31b60db9a0be4aa3d42c4 100644 (file)
@@ -27,10 +27,6 @@ namespace cvc5 {
 
 class Command;
 
-namespace smt {
-class SmtEngine;
-}
-
 namespace main {
 
 class CommandExecutor
@@ -81,8 +77,6 @@ class CommandExecutor
   api::Result getResult() const { return d_result; }
   void reset();
 
-  SmtEngine* getSmtEngine() const { return d_solver->getSmtEngine(); }
-
   /** Store the current options as the original options */
   void storeOptionsAsOriginal();
 
index 7c1a0e8874f38a4c314fe79903dc1461cffff2b1..75219840ad8e386d6fc54cd12db9c296df071f81 100644 (file)
@@ -141,16 +141,16 @@ void AstPrinter::toStream(std::ostream& out, const smt::Model& m) const
 }
 
 void AstPrinter::toStreamModelSort(std::ostream& out,
-                                   const smt::Model& m,
-                                   TypeNode tn) const
+                                   TypeNode tn,
+                                   const std::vector<Node>& elements) const
 {
   // shouldn't be called; only the non-Command* version above should be
   Unreachable();
 }
 
 void AstPrinter::toStreamModelTerm(std::ostream& out,
-                                   const smt::Model& m,
-                                   Node n) const
+                                   const Node& n,
+                                   const Node& value) const
 {
   // shouldn't be called; only the non-Command* version above should be
   Unreachable();
index da5785f9f00380be437388235b8c3aaf447b9af8..fd4775da44b1314104e01b8c6911d85565d20c28 100644 (file)
@@ -173,16 +173,16 @@ class AstPrinter : public cvc5::Printer
    * tn declared via declare-sort or declare-datatype.
    */
   void toStreamModelSort(std::ostream& out,
-                         const smt::Model& m,
-                         TypeNode tn) const override;
+                         TypeNode tn,
+                         const std::vector<Node>& elements) const override;
 
   /**
    * To stream model term. This prints the appropriate output for term
    * n declared via declare-fun.
    */
   void toStreamModelTerm(std::ostream& out,
-                         const smt::Model& m,
-                         Node n) const override;
+                         const Node& n,
+                         const Node& value) const override;
   /**
    * To stream with let binding. This prints n, possibly in the scope
    * of letification generated by this method based on lbind.
index 1f1296b7f0c8ece7ce953fb9ca3d9bb9b4a2b060..04274ddc3a8791193165f1811c31c72d40e59d8a 100644 (file)
@@ -1029,8 +1029,8 @@ void CvcPrinter::toStream(std::ostream& out, const CommandStatus* s) const
 }/* CvcPrinter::toStream(CommandStatus*) */
 
 void CvcPrinter::toStreamModelSort(std::ostream& out,
-                                   const smt::Model& m,
-                                   TypeNode tn) const
+                                   TypeNode tn,
+                                   const std::vector<Node>& elements) const
 {
   if (!tn.isSort())
   {
@@ -1038,36 +1038,25 @@ void CvcPrinter::toStreamModelSort(std::ostream& out,
         << tn << std::endl;
     return;
   }
-  const theory::TheoryModel* tm = m.getTheoryModel();
-  const std::vector<Node>* type_reps = tm->getRepSet()->getTypeRepsOrNull(tn);
-  if (type_reps != nullptr)
+  out << "% cardinality of " << tn << " is " << elements.size() << std::endl;
+  toStreamCmdDeclareType(out, tn);
+  for (const Node& type_rep : elements)
   {
-    out << "% cardinality of " << tn << " is " << type_reps->size()
-        << std::endl;
-    toStreamCmdDeclareType(out, tn);
-    for (Node type_rep : *type_reps)
+    if (type_rep.isVar())
     {
-      if (type_rep.isVar())
-      {
-        out << type_rep << " : " << tn << ";" << std::endl;
-      }
-      else
-      {
-        out << "% rep: " << type_rep << std::endl;
-      }
+      out << type_rep << " : " << tn << ";" << std::endl;
+    }
+    else
+    {
+      out << "% rep: " << type_rep << std::endl;
     }
-  }
-  else
-  {
-    toStreamCmdDeclareType(out, tn);
   }
 }
 
 void CvcPrinter::toStreamModelTerm(std::ostream& out,
-                                   const smt::Model& m,
-                                   Node n) const
+                                   const Node& n,
+                                   const Node& value) const
 {
-  const theory::TheoryModel* tm = m.getTheoryModel();
   TypeNode tn = n.getType();
   out << n << " : ";
   if (tn.isFunction() || tn.isPredicate())
@@ -1087,10 +1076,7 @@ void CvcPrinter::toStreamModelTerm(std::ostream& out,
   {
     out << tn;
   }
-  // We get the value from the theory model directly, which notice
-  // does not have to go through the standard SmtEngine::getValue interface.
-  Node val = tm->getValue(n);
-  out << " = " << val << ";" << std::endl;
+  out << " = " << value << ";" << std::endl;
 }
 
 void CvcPrinter::toStream(std::ostream& out, const smt::Model& m) const
index 5552371775516d3332c06ca620eb654ff2615f2b..4851868d37d5c656b50d9dcd623a6e780e256504 100644 (file)
@@ -175,19 +175,19 @@ class CvcPrinter : public cvc5::Printer
                     LetBinding* lbind) const;
   /**
    * To stream model sort. This prints the appropriate output for type
-   * tn declared via declare-sort or declare-datatype.
+   * tn declared via declare-sort.
    */
   void toStreamModelSort(std::ostream& out,
-                         const smt::Model& m,
-                         TypeNode tn) const override;
+                         TypeNode tn,
+                         const std::vector<Node>& elements) const override;
 
   /**
    * To stream model term. This prints the appropriate output for term
    * n declared via declare-fun.
    */
   void toStreamModelTerm(std::ostream& out,
-                         const smt::Model& m,
-                         Node n) const override;
+                         const Node& n,
+                         const Node& value) const override;
   /**
    * To stream with let binding. This prints n, possibly in the scope
    * of letification generated by this method based on lbind.
index 59122cf3d8f3680a6a3e3f32e8b23589c2c76c31..e038952c48d52e6a3d6a04ba332549e9a3f50a7d 100644 (file)
@@ -66,22 +66,15 @@ void Printer::toStream(std::ostream& out, const smt::Model& m) const
   const std::vector<TypeNode>& dsorts = m.getDeclaredSorts();
   for (const TypeNode& tn : dsorts)
   {
-    toStreamModelSort(out, m, tn);
+    toStreamModelSort(out, tn, m.getDomainElements(tn));
   }
-
   // print the declared terms
   const std::vector<Node>& dterms = m.getDeclaredTerms();
   for (const Node& n : dterms)
   {
-    // take into account model core, independently of the format
-    if (!m.isModelCoreSymbol(n))
-    {
-      continue;
-    }
-    toStreamModelTerm(out, m, n);
+    toStreamModelTerm(out, n, m.getValue(n));
   }
-
-}/* Printer::toStream(Model) */
+}
 
 void Printer::toStreamUsing(Language lang,
                             std::ostream& out,
index 05ac8879b6b2e6641b37458cb95c0956453d6fce..5e141fe8fdfd0d228b410fdff4499911fb1c1061 100644 (file)
@@ -271,19 +271,19 @@ class Printer
 
   /**
    * To stream model sort. This prints the appropriate output for type
-   * tn declared via declare-sort or declare-datatype.
+   * tn declared via declare-sort.
    */
   virtual void toStreamModelSort(std::ostream& out,
-                                 const smt::Model& m,
-                                 TypeNode tn) const = 0;
+                                 TypeNode tn,
+                                 const std::vector<Node>& elements) const = 0;
 
   /**
    * To stream model term. This prints the appropriate output for term
    * n declared via declare-fun.
    */
   virtual void toStreamModelTerm(std::ostream& out,
-                                 const smt::Model& m,
-                                 Node n) const = 0;
+                                 const Node& n,
+                                 const Node& value) const = 0;
 
   /** write model response to command using another language printer */
   void toStreamUsing(Language lang,
index 8a23a59ea4e3adfb23d216f12a40534a08bbe6f6..0d556c1dc46e345248869eac7e239eacfafc9424 100644 (file)
@@ -1262,7 +1262,6 @@ void Smt2Printer::toStream(std::ostream& out, const UnsatCore& core) const
 
 void Smt2Printer::toStream(std::ostream& out, const smt::Model& m) const
 {
-  const theory::TheoryModel* tm = m.getTheoryModel();
   //print the model
   out << "(" << endl;
   // don't need to print approximations since they are built into choice
@@ -1271,7 +1270,7 @@ void Smt2Printer::toStream(std::ostream& out, const smt::Model& m) const
   out << ")" << endl;
   //print the heap model, if it exists
   Node h, neq;
-  if (tm->getHeapModel(h, neq))
+  if (m.getHeapModel(h, neq))
   {
     // description of the heap+what nil is equal to fully describes model
     out << "(heap" << endl;
@@ -1282,8 +1281,8 @@ void Smt2Printer::toStream(std::ostream& out, const smt::Model& m) const
 }
 
 void Smt2Printer::toStreamModelSort(std::ostream& out,
-                                    const smt::Model& m,
-                                    TypeNode tn) const
+                                    TypeNode tn,
+                                    const std::vector<Node>& elements) const
 {
   if (!tn.isSort())
   {
@@ -1291,8 +1290,6 @@ void Smt2Printer::toStreamModelSort(std::ostream& out,
         << tn << std::endl;
     return;
   }
-  const theory::TheoryModel* tm = m.getTheoryModel();
-  std::vector<Node> elements = tm->getDomainElements(tn);
   // print the cardinality
   out << "; cardinality of " << tn << " is " << elements.size() << endl;
   if (options::modelUninterpPrint()
@@ -1322,26 +1319,22 @@ void Smt2Printer::toStreamModelSort(std::ostream& out,
 }
 
 void Smt2Printer::toStreamModelTerm(std::ostream& out,
-                                    const smt::Model& m,
-                                    Node n) const
+                                    const Node& n,
+                                    const Node& value) const
 {
-  const theory::TheoryModel* tm = m.getTheoryModel();
-  // We get the value from the theory model directly, which notice
-  // does not have to go through the standard SmtEngine::getValue interface.
-  Node val = tm->getValue(n);
-  if (val.getKind() == kind::LAMBDA)
+  if (value.getKind() == kind::LAMBDA)
   {
     TypeNode rangeType = n.getType().getRangeType();
-    out << "(define-fun " << n << " " << val[0] << " " << rangeType << " ";
+    out << "(define-fun " << n << " " << value[0] << " " << rangeType << " ";
     // call toStream and force its type to be proper
-    toStreamCastToType(out, val[1], -1, rangeType);
+    toStreamCastToType(out, value[1], -1, rangeType);
     out << ")" << endl;
   }
   else
   {
     out << "(define-fun " << n << " () " << n.getType() << " ";
     // call toStream and force its type to be proper
-    toStreamCastToType(out, val, -1, n.getType());
+    toStreamCastToType(out, value, -1, n.getType());
     out << ")" << endl;
   }
 }
index 15f45a10e199fbf6a3999f1fec50fad8a7f6eb47..729caebf4502d8fae5b50d86d5f6d47aece8f7c2 100644 (file)
@@ -252,16 +252,16 @@ class Smt2Printer : public cvc5::Printer
    * tn declared via declare-sort or declare-datatype.
    */
   void toStreamModelSort(std::ostream& out,
-                         const smt::Model& m,
-                         TypeNode tn) const override;
+                         TypeNode tn,
+                         const std::vector<Node>& elements) const override;
 
   /**
    * To stream model term. This prints the appropriate output for term
    * n declared via declare-fun.
    */
   void toStreamModelTerm(std::ostream& out,
-                         const smt::Model& m,
-                         Node n) const override;
+                         const Node& n,
+                         const Node& value) const override;
 
   /**
    * To stream with let binding. This prints n, possibly in the scope
index bb8df120e89520e8e233fc53af04a14af1c9c551..6c874670615e1862cc5e2fd7a146f0db11e2c3cb 100644 (file)
@@ -58,16 +58,16 @@ void TptpPrinter::toStream(std::ostream& out, const smt::Model& m) const
 }
 
 void TptpPrinter::toStreamModelSort(std::ostream& out,
-                                    const smt::Model& m,
-                                    TypeNode tn) const
+                                    TypeNode tn,
+                                    const std::vector<Node>& elements) const
 {
   // shouldn't be called; only the non-Command* version above should be
   Unreachable();
 }
 
 void TptpPrinter::toStreamModelTerm(std::ostream& out,
-                                    const smt::Model& m,
-                                    Node n) const
+                                    const Node& n,
+                                    const Node& value) const
 {
   // shouldn't be called; only the non-Command* version above should be
   Unreachable();
index db86de8cfca6e7ac308688c7cd9a7323901a8aca..9d288b4ed296eb65d10c8ddcf3eb0e266efb1ef1 100644 (file)
@@ -48,16 +48,16 @@ class TptpPrinter : public cvc5::Printer
    * tn declared via declare-sort or declare-datatype.
    */
   void toStreamModelSort(std::ostream& out,
-                         const smt::Model& m,
-                         TypeNode tn) const override;
+                         TypeNode tn,
+                         const std::vector<Node>& elements) const override;
 
   /**
    * To stream model term. This prints the appropriate output for term
    * n declared via declare-fun.
    */
   void toStreamModelTerm(std::ostream& out,
-                         const smt::Model& m,
-                         Node n) const override;
+                         const Node& n,
+                         const Node& value) const override;
 
 }; /* class TptpPrinter */
 
index 0bc7ce99bdc9b755e571ccbe401ec1ac1796986e..d3a2dfefa46d386e2f94a04cb5159e04aec355c4 100644 (file)
 #include "base/modal_exception.h"
 #include "options/smt_options.h"
 #include "smt/env.h"
-#include "smt/model.h"
 #include "smt/node_command.h"
 #include "smt/preprocessor.h"
 #include "smt/smt_solver.h"
 #include "theory/rewriter.h"
 #include "theory/substitutions.h"
 #include "theory/theory_engine.h"
+#include "theory/theory_model.h"
 
 using namespace cvc5::theory;
 
@@ -34,7 +34,7 @@ namespace smt {
 CheckModels::CheckModels(Env& e) : d_env(e) {}
 CheckModels::~CheckModels() {}
 
-void CheckModels::checkModel(Model* m,
+void CheckModels::checkModel(TheoryModel* m,
                              context::CDList<Node>* al,
                              bool hardFailure)
 {
index ce06bae075215d6958d04b7a984915424d99f394..fbfb1c2f5313c90b9a1f56e699c00102f839cecf 100644 (file)
@@ -25,9 +25,11 @@ namespace cvc5 {
 
 class Env;
 
-namespace smt {
+namespace theory {
+class TheoryModel;
+}
 
-class Model;
+namespace smt {
 
 /**
  * This utility is responsible for checking the current model.
@@ -43,7 +45,9 @@ class CheckModels
    * This throws an exception if we fail to verify that m is a proper model
    * given assertion list al based on the model checking policy.
    */
-  void checkModel(Model* m, context::CDList<Node>* al, bool hardFailure);
+  void checkModel(theory::TheoryModel* m,
+                  context::CDList<Node>* al,
+                  bool hardFailure);
 
  private:
   /** Reference to the environment */
index e6be0a646aed450650f72a79478f0205a4d68a01..008d7a6d87efdd40ca555d0107f2e74b86c4cb8b 100644 (file)
@@ -38,8 +38,6 @@
 #include "proof/unsat_core.h"
 #include "smt/dump.h"
 #include "smt/model.h"
-#include "smt/smt_engine.h"
-#include "smt/smt_engine_scope.h"
 #include "util/unsafe_interrupt_exception.h"
 #include "util/utility.h"
 
@@ -1748,27 +1746,17 @@ void GetAssignmentCommand::toStream(std::ostream& out,
 /* class GetModelCommand                                                      */
 /* -------------------------------------------------------------------------- */
 
-GetModelCommand::GetModelCommand() : d_result(nullptr) {}
+GetModelCommand::GetModelCommand() {}
 void GetModelCommand::invoke(api::Solver* solver, SymbolManager* sm)
 {
   try
   {
-    d_result = solver->getSmtEngine()->getModel();
-    // set the model declarations, which determines what is printed in the model
-    d_result->clearModelDeclarations();
     std::vector<api::Sort> declareSorts = sm->getModelDeclareSorts();
-    for (const api::Sort& s : declareSorts)
-    {
-      d_result->addDeclarationSort(sortToTypeNode(s));
-    }
     std::vector<api::Term> declareTerms = sm->getModelDeclareTerms();
-    for (const api::Term& t : declareTerms)
-    {
-      d_result->addDeclarationTerm(termToNode(t));
-    }
+    d_result = solver->getModel(declareSorts, declareTerms);
     d_commandStatus = CommandSuccess::instance();
   }
-  catch (RecoverableModalException& e)
+  catch (api::CVC5ApiRecoverableException& e)
   {
     d_commandStatus = new CommandRecoverableFailure(e.what());
   }
@@ -1782,12 +1770,6 @@ void GetModelCommand::invoke(api::Solver* solver, SymbolManager* sm)
   }
 }
 
-/* Model is private to the library -- for now
-Model* GetModelCommand::getResult() const  {
-  return d_result;
-}
-*/
-
 void GetModelCommand::printResult(std::ostream& out, uint32_t verbosity) const
 {
   if (!ok())
@@ -1796,13 +1778,13 @@ void GetModelCommand::printResult(std::ostream& out, uint32_t verbosity) const
   }
   else
   {
-    out << *d_result;
+    out << d_result;
   }
 }
 
 Command* GetModelCommand::clone() const
 {
-  GetModelCommand* c = new GetModelCommand();
+  GetModelCommand* c = new GetModelCommand;
   c->d_result = d_result;
   return c;
 }
index d3e3679d2919b879cb5a6074f5c393f2887e0105..627cb13c953029097d393038e302c82ce39542fb 100644 (file)
@@ -950,7 +950,8 @@ class CVC5_EXPORT GetModelCommand : public Command
                 Language language = Language::LANG_AUTO) const override;
 
  protected:
-  smt::Model* d_result;
+  /** Result of printing the model */
+  std::string d_result;
 }; /* class GetModelCommand */
 
 /** The command to block models. */
index cf6a90f12ddefadde69aa0d35389c27dde13f689..9a195cb511a8db26970904833da0a6e8f5a5c678 100644 (file)
 #include "expr/expr_iomanip.h"
 #include "options/base_options.h"
 #include "printer/printer.h"
-#include "smt/dump_manager.h"
-#include "smt/node_command.h"
-#include "smt/smt_engine.h"
-#include "smt/smt_engine_scope.h"
-#include "theory/theory_model.h"
 
 namespace cvc5 {
 namespace smt {
 
-Model::Model(theory::TheoryModel* tm) : d_isKnownSat(false), d_tmodel(tm)
+Model::Model(bool isKnownSat, const std::string& inputName)
+    : d_inputName(inputName), d_isKnownSat(isKnownSat)
 {
-  Assert(d_tmodel != nullptr);
 }
 
 std::ostream& operator<<(std::ostream& out, const Model& m) {
@@ -38,31 +33,55 @@ std::ostream& operator<<(std::ostream& out, const Model& m) {
   return out;
 }
 
-theory::TheoryModel* Model::getTheoryModel() { return d_tmodel; }
+const std::vector<Node>& Model::getDomainElements(TypeNode tn) const
+{
+  std::map<TypeNode, std::vector<Node>>::const_iterator it =
+      d_domainElements.find(tn);
+  Assert(it != d_domainElements.end());
+  return it->second;
+}
 
-const theory::TheoryModel* Model::getTheoryModel() const { return d_tmodel; }
+Node Model::getValue(TNode n) const
+{
+  std::map<Node, Node>::const_iterator it = d_declareTermValues.find(n);
+  Assert(it != d_declareTermValues.end());
+  return it->second;
+}
 
-bool Model::isModelCoreSymbol(TNode sym) const
+bool Model::getHeapModel(Node& h, Node& nilEq) const
 {
-  return d_tmodel->isModelCoreSymbol(sym);
+  if (d_sepHeap.isNull() || d_sepNilEq.isNull())
+  {
+    return false;
+  }
+  h = d_sepHeap;
+  nilEq = d_sepNilEq;
+  return true;
 }
-Node Model::getValue(TNode n) const { return d_tmodel->getValue(n); }
 
-bool Model::hasApproximations() const { return d_tmodel->hasApproximations(); }
+void Model::addDeclarationSort(TypeNode tn, const std::vector<Node>& elements)
+{
+  d_declareSorts.push_back(tn);
+  d_domainElements[tn] = elements;
+}
 
-void Model::clearModelDeclarations()
+void Model::addDeclarationTerm(Node n, Node value)
 {
-  d_declareTerms.clear();
-  d_declareSorts.clear();
+  d_declareTerms.push_back(n);
+  d_declareTermValues[n] = value;
 }
 
-void Model::addDeclarationSort(TypeNode tn) { d_declareSorts.push_back(tn); }
+void Model::setHeapModel(Node h, Node nilEq)
+{
+  d_sepHeap = h;
+  d_sepNilEq = nilEq;
+}
 
-void Model::addDeclarationTerm(Node n) { d_declareTerms.push_back(n); }
 const std::vector<TypeNode>& Model::getDeclaredSorts() const
 {
   return d_declareSorts;
 }
+
 const std::vector<Node>& Model::getDeclaredTerms() const
 {
   return d_declareTerms;
index 342a9f3b08b686e33f4021ace094d08d8f8c2c7d..5275ea6807b62a62069bcd29677ec309d7c3afad 100644 (file)
@@ -15,8 +15,8 @@
 
 #include "cvc5_private.h"
 
-#ifndef CVC5__MODEL_H
-#define CVC5__MODEL_H
+#ifndef CVC5__SMT__MODEL_H
+#define CVC5__SMT__MODEL_H
 
 #include <iosfwd>
 #include <vector>
 #include "expr/node.h"
 
 namespace cvc5 {
-
-class SmtEngine;
-
-namespace theory {
-class TheoryModel;
-}
-
 namespace smt {
 
 class Model;
@@ -38,22 +31,15 @@ class Model;
 std::ostream& operator<<(std::ostream&, const Model&);
 
 /**
- * This is the SMT-level model object, that is responsible for maintaining
- * the necessary information for how to print the model, as well as
- * holding a pointer to the underlying implementation of the theory model.
- *
- * The model declarations maintained by this class are context-independent
- * and should be updated when this model is printed.
+ * A utility for representing a model for pretty printing.
  */
 class Model {
-  friend std::ostream& operator<<(std::ostream&, const Model&);
-  friend class ::cvc5::SmtEngine;
-
  public:
-  /** construct */
-  Model(theory::TheoryModel* tm);
-  /** virtual destructor */
-  ~Model() {}
+  /** Constructor
+   * @param isKnownSat True if this model is associated with a "sat" response,
+   * or false if it is associated with an "unknown" response.
+   */
+  Model(bool isKnownSat, const std::string& inputName);
   /** get the input name (file name, etc.) this model is associated to */
   std::string getInputName() const { return d_inputName; }
   /**
@@ -63,31 +49,37 @@ class Model {
    * only a candidate solution.
    */
   bool isKnownSat() const { return d_isKnownSat; }
-  /** Get the underlying theory model */
-  theory::TheoryModel* getTheoryModel();
-  /** Get the underlying theory model (const version) */
-  const theory::TheoryModel* getTheoryModel() const;
-  //----------------------- helper methods in the underlying theory model
-  /** Is the node n a model core symbol? */
-  bool isModelCoreSymbol(TNode sym) const;
+  /** Get domain elements */
+  const std::vector<Node>& getDomainElements(TypeNode tn) const;
   /** Get value */
   Node getValue(TNode n) const;
-  /** Does this model have approximations? */
-  bool hasApproximations() const;
-  //----------------------- end helper methods
+  /** Get separation logic heap and nil, return true if they have been set */
+  bool getHeapModel(Node& h, Node& nilEq) const;
   //----------------------- model declarations
-  /** Clear the current model declarations. */
-  void clearModelDeclarations();
   /**
    * Set that tn is a sort that should be printed in the model, when applicable,
    * based on the output language.
+   *
+   * @param tn The uninterpreted sort
+   * @param elements The domain elements of tn in the model
    */
-  void addDeclarationSort(TypeNode tn);
+  void addDeclarationSort(TypeNode tn, const std::vector<Node>& elements);
   /**
    * Set that n is a variable that should be printed in the model, when
    * applicable, based on the output language.
+   *
+   * @param n The variable
+   * @param value The value of the variable in the model
+   */
+  void addDeclarationTerm(Node n, Node value);
+  /**
+   * Set the separation logic model information where h is the heap and nilEq
+   * is the value of sep.nil.
+   *
+   * @param h The value of heap in the heap model
+   * @param nilEq The value of sep.nil in the heap model
    */
-  void addDeclarationTerm(Node n);
+  void setHeapModel(Node h, Node nilEq);
   /** get declared sorts */
   const std::vector<TypeNode>& getDeclaredSorts() const;
   /** get declared terms */
@@ -101,24 +93,26 @@ class Model {
    * from the solver.
    */
   bool d_isKnownSat;
-  /**
-   * Pointer to the underlying theory model, which maintains all data regarding
-   * the values of sorts and terms.
-   */
-  theory::TheoryModel* d_tmodel;
   /**
    * The list of types to print, generally corresponding to declare-sort
    * commands.
    */
   std::vector<TypeNode> d_declareSorts;
+  /** The interpretation of the above sorts, as a list of domain elements. */
+  std::map<TypeNode, std::vector<Node>> d_domainElements;
   /**
    * The list of terms to print, is typically one-to-one with declare-fun
    * commands.
    */
   std::vector<Node> d_declareTerms;
+  /** Mapping terms to values */
+  std::map<Node, Node> d_declareTermValues;
+  /** Separation logic heap and nil */
+  Node d_sepHeap;
+  Node d_sepNilEq;
 };
 
 }  // namespace smt
 }  // namespace cvc5
 
-#endif /* CVC5__MODEL_H */
+#endif /* CVC5__SMT__MODEL_H */
index 2276956b59fcbf7532c2904b2d346f0815f729e2..27e7b85300fb7174fd16eb85eb7daf1fb9abbfc2 100644 (file)
@@ -92,7 +92,6 @@ SmtEngine::SmtEngine(NodeManager* nm, const Options* optr)
       d_routListener(new ResourceOutListener(*this)),
       d_snmListener(new SmtNodeManagerListener(*getDumpManager(), d_outMgr)),
       d_smtSolver(nullptr),
-      d_model(nullptr),
       d_checkModels(nullptr),
       d_pfManager(nullptr),
       d_ucManager(nullptr),
@@ -224,7 +223,6 @@ void SmtEngine::finishInit()
   TheoryModel* tm = te->getModel();
   if (tm != nullptr)
   {
-    d_model.reset(new Model(tm));
     // make the check models utility
     d_checkModels.reset(new CheckModels(*d_env.get()));
   }
@@ -305,7 +303,6 @@ SmtEngine::~SmtEngine()
 
     d_absValues.reset(nullptr);
     d_asserts.reset(nullptr);
-    d_model.reset(nullptr);
 
     d_abductSolver.reset(nullptr);
     d_interpolSolver.reset(nullptr);
@@ -712,7 +709,7 @@ Result SmtEngine::quickCheck() {
       Result::ENTAILMENT_UNKNOWN, Result::REQUIRES_FULL_CHECK, filename);
 }
 
-Model* SmtEngine::getAvailableModel(const char* c) const
+TheoryModel* SmtEngine::getAvailableModel(const char* c) const
 {
   if (!d_env->getOptions().theory.assignFunctionValues)
   {
@@ -751,7 +748,7 @@ Model* SmtEngine::getAvailableModel(const char* c) const
     throw RecoverableModalException(ss.str().c_str());
   }
 
-  return d_model.get();
+  return m;
 }
 
 QuantifiersEngine* SmtEngine::getAvailableQuantifiersEngine(const char* c) const
@@ -1121,7 +1118,7 @@ Node SmtEngine::getValue(const Node& ex) const
   }
 
   Trace("smt") << "--- getting value of " << n << endl;
-  Model* m = getAvailableModel("get-value");
+  TheoryModel* m = getAvailableModel("get-value");
   Assert(m != nullptr);
   Node resultNode = m->getValue(n);
   Trace("smt") << "--- got value " << n << " = " << resultNode << endl;
@@ -1163,8 +1160,8 @@ std::vector<Node> SmtEngine::getValues(const std::vector<Node>& exprs) const
 std::vector<Node> SmtEngine::getModelDomainElements(TypeNode tn) const
 {
   Assert(tn.isSort());
-  Model* m = getAvailableModel("getModelDomainElements");
-  return m->getTheoryModel()->getDomainElements(tn);
+  TheoryModel* m = getAvailableModel("getModelDomainElements");
+  return m->getDomainElements(tn);
 }
 
 bool SmtEngine::isModelCoreSymbol(Node n)
@@ -1177,8 +1174,7 @@ bool SmtEngine::isModelCoreSymbol(Node n)
     // if the model core mode is none, we are always a model core symbol
     return true;
   }
-  Model* m = getAvailableModel("isModelCoreSymbol");
-  TheoryModel* tm = m->getTheoryModel();
+  TheoryModel* tm = getAvailableModel("isModelCoreSymbol");
   // compute the model core if not done so already
   if (!tm->isUsingModelCore())
   {
@@ -1193,41 +1189,54 @@ bool SmtEngine::isModelCoreSymbol(Node n)
   return tm->isModelCoreSymbol(n);
 }
 
-// TODO(#1108): Simplify the error reporting of this method.
-Model* SmtEngine::getModel() {
-  Trace("smt") << "SMT getModel()" << endl;
+std::string SmtEngine::getModel(const std::vector<TypeNode>& declaredSorts,
+                                const std::vector<Node>& declaredFuns)
+{
   SmtScope smts(this);
-
-  finishInit();
-
-  if (Dump.isOn("benchmark"))
+  // !!! Note that all methods called here should have a version at the API
+  // level. This is to ensure that the information associated with a model is
+  // completely accessible by the user. This is currently not rigorously
+  // enforced. An alternative design would be to have this method implemented
+  // at the API level, but this makes exceptions in the text interface less
+  // intuitive and makes it impossible to implement raw-benchmark at the
+  // SmtEngine level.
+  if (Dump.isOn("raw-benchmark"))
   {
     getPrinter().toStreamCmdGetModel(d_env->getDumpOut());
   }
-
-  Model* m = getAvailableModel("get model");
-
-  // Notice that the returned model is (currently) accessed by the
-  // GetModelCommand only, and is not returned to the user. The information
-  // in that model may become stale after it is returned. This is safe
-  // since GetModelCommand always calls this command again when it prints
-  // a model.
-
-  if (d_env->getOptions().smt.modelCoresMode
-      != options::ModelCoresMode::NONE)
+  TheoryModel* tm = getAvailableModel("get model");
+  // use the smt::Model model utility for printing
+  const Options& opts = d_env->getOptions();
+  bool isKnownSat = (d_state->getMode() == SmtMode::SAT);
+  Model m(isKnownSat, opts.driver.filename);
+  // set the model declarations, which determines what is printed in the model
+  for (const TypeNode& tn : declaredSorts)
   {
-    // If we enabled model cores, we compute a model core for m based on our
-    // (expanded) assertions using the model core builder utility
-    std::vector<Node> asserts = getAssertionsInternal();
-    d_pp->expandDefinitions(asserts);
-    ModelCoreBuilder::setModelCore(
-        asserts, m->getTheoryModel(), d_env->getOptions().smt.modelCoresMode);
+    m.addDeclarationSort(tn, getModelDomainElements(tn));
   }
-  // set the information on the SMT-level model
-  Assert(m != nullptr);
-  m->d_inputName = d_env->getOptions().driver.filename;
-  m->d_isKnownSat = (d_state->getMode() == SmtMode::SAT);
-  return m;
+  bool usingModelCores =
+      (opts.smt.modelCoresMode != options::ModelCoresMode::NONE);
+  for (const Node& n : declaredFuns)
+  {
+    if (usingModelCores && !tm->isModelCoreSymbol(n))
+    {
+      // skip if not in model core
+      continue;
+    }
+    Node value = tm->getValue(n);
+    m.addDeclarationTerm(n, value);
+  }
+  // for separation logic
+  TypeNode locT, dataT;
+  if (getSepHeapTypes(locT, dataT))
+  {
+    std::pair<Node, Node> sh = getSepHeapAndNilExpr();
+    m.setHeapModel(sh.first, sh.second);
+  }
+  // print the model
+  std::stringstream ssm;
+  ssm << m;
+  return ssm.str();
 }
 
 Result SmtEngine::blockModel()
@@ -1242,7 +1251,7 @@ Result SmtEngine::blockModel()
     getPrinter().toStreamCmdBlockModel(d_env->getDumpOut());
   }
 
-  Model* m = getAvailableModel("block model");
+  TheoryModel* m = getAvailableModel("block model");
 
   if (d_env->getOptions().smt.blockModelsMode
       == options::BlockModelsMode::NONE)
@@ -1254,10 +1263,8 @@ Result SmtEngine::blockModel()
 
   // get expanded assertions
   std::vector<Node> eassertsProc = getExpandedAssertions();
-  Node eblocker =
-      ModelBlocker::getModelBlocker(eassertsProc,
-                                    m->getTheoryModel(),
-                                    d_env->getOptions().smt.blockModelsMode);
+  Node eblocker = ModelBlocker::getModelBlocker(
+      eassertsProc, m, d_env->getOptions().smt.blockModelsMode);
   Trace("smt") << "Block formula: " << eblocker << std::endl;
   return assertFormula(eblocker);
 }
@@ -1274,16 +1281,13 @@ Result SmtEngine::blockModelValues(const std::vector<Node>& exprs)
     getPrinter().toStreamCmdBlockModelValues(d_env->getDumpOut(), exprs);
   }
 
-  Model* m = getAvailableModel("block model values");
+  TheoryModel* m = getAvailableModel("block model values");
 
   // get expanded assertions
   std::vector<Node> eassertsProc = getExpandedAssertions();
   // we always do block model values mode here
-  Node eblocker =
-      ModelBlocker::getModelBlocker(eassertsProc,
-                                    m->getTheoryModel(),
-                                    options::BlockModelsMode::VALUES,
-                                    exprs);
+  Node eblocker = ModelBlocker::getModelBlocker(
+      eassertsProc, m, options::BlockModelsMode::VALUES, exprs);
   return assertFormula(eblocker);
 }
 
@@ -1299,8 +1303,7 @@ std::pair<Node, Node> SmtEngine::getSepHeapAndNilExpr(void)
   NodeManagerScope nms(getNodeManager());
   Node heap;
   Node nil;
-  Model* m = getAvailableModel("get separation logic heap and nil");
-  TheoryModel* tm = m->getTheoryModel();
+  TheoryModel* tm = getAvailableModel("get separation logic heap and nil");
   if (!tm->getHeapModel(heap, nil))
   {
     const char* msg =
@@ -1548,7 +1551,7 @@ void SmtEngine::checkModel(bool hardFailure) {
   TimerStat::CodeTimer checkModelTimer(d_stats->d_checkModelTime);
 
   Notice() << "SmtEngine::checkModel(): generating model" << endl;
-  Model* m = getAvailableModel("check model");
+  TheoryModel* m = getAvailableModel("check model");
   Assert(m != nullptr);
 
   // check the model with the theory engine for debugging
index 84501d35e660d985b165ee1bf9d954a1f286a501..06a1c9ae4785cca242a6e7981f0a9bd260415eeb 100644 (file)
@@ -42,7 +42,6 @@ class Env;
 class NodeManager;
 class TheoryEngine;
 class UnsatCore;
-class LogicRequest;
 class StatisticsRegistry;
 class Printer;
 class ResourceManager;
@@ -77,7 +76,6 @@ namespace prop {
 
 namespace smt {
 /** Utilities */
-class Model;
 class SmtEngineState;
 class AbstractValues;
 class Assertions;
@@ -104,9 +102,10 @@ class UnsatCoreManager;
 /* -------------------------------------------------------------------------- */
 
 namespace theory {
-  class Rewriter;
-  class QuantifiersEngine;
-  }  // namespace theory
+class TheoryModel;
+class Rewriter;
+class QuantifiersEngine;
+}  // namespace theory
 
 /* -------------------------------------------------------------------------- */
 
@@ -115,7 +114,6 @@ class CVC5_EXPORT SmtEngine
   friend class ::cvc5::api::Solver;
   friend class ::cvc5::smt::SmtEngineState;
   friend class ::cvc5::smt::SmtScope;
-  friend class ::cvc5::LogicRequest;
 
   /* .......................................................................  */
  public:
@@ -226,14 +224,6 @@ class CVC5_EXPORT SmtEngine
   /** Is this an internal subsolver? */
   bool isInternalSubsolver() const;
 
-  /**
-   * Get the model (only if immediately preceded by a SAT or NOT_ENTAILED
-   * query).  Only permitted if produce-models is on.
-   *
-   * TODO (issues#287): eliminate this method.
-   */
-  smt::Model* getModel();
-
   /**
    * Block the current model. Can be called only if immediately preceded by
    * a SAT or INVALID query. Only permitted if produce-models is on, and the
@@ -523,6 +513,19 @@ class CVC5_EXPORT SmtEngine
    */
   bool isModelCoreSymbol(Node v);
 
+  /**
+   * Get a model (only if immediately preceded by an SAT or unknown query).
+   * Only permitted if the model option is on.
+   *
+   * @param declaredSorts The sorts to print in the model
+   * @param declaredFuns The free constants to print in the model. A subset
+   * of these may be printed based on isModelCoreSymbol.
+   * @return the string corresponding to the model. If the output language is
+   * smt2, then this corresponds to a response to the get-model command.
+   */
+  std::string getModel(const std::vector<TypeNode>& declaredSorts,
+                       const std::vector<Node>& declaredFuns);
+
   /** print instantiations
    *
    * Print all instantiations for all quantified formulas on out,
@@ -936,7 +939,7 @@ class CVC5_EXPORT SmtEngine
    * @param c used for giving an error message to indicate the context
    * this method was called.
    */
-  smt::Model* getAvailableModel(const char* c) const;
+  theory::TheoryModel* getAvailableModel(const char* c) const;
   /**
    * Get available quantifiers engine, which throws a modal exception if it
    * does not exist. This can happen if a quantifiers-specific call (e.g.
@@ -1046,13 +1049,6 @@ class CVC5_EXPORT SmtEngine
   /** The SMT solver */
   std::unique_ptr<smt::SmtSolver> d_smtSolver;
 
-  /**
-   * The SMT-level model object, which contains information about how to
-   * print the model, as well as a pointer to the underlying TheoryModel
-   * implementation maintained by the SmtSolver.
-   */
-  std::unique_ptr<smt::Model> d_model;
-
   /**
    * The utility used for checking models
    */
index e134b56667cac782036eb6920a880f26a9e8f488..728db28d87cdebdbf1fe93228466e3d6cf104bfa 100644 (file)
@@ -1,5 +1,5 @@
 % EXPECT: entailed
-% EXPECT: Cannot get model unless immediately preceded by SAT/NOT_ENTAILED or UNKNOWN response.
+% EXPECT: Cannot get model unless after a SAT or unknown response.
 OPTION "logic" "ALL";
 OPTION "produce-models" true;
 x : INT;
index 1daa3fba4e481689a4f0351eeb72d938891e3de0..5ca96f0352ce87424df309cf15baa27f8d1355df 100644 (file)
@@ -1571,6 +1571,47 @@ TEST_F(TestApiBlackSolver, isModelCoreSymbol)
   ASSERT_THROW(d_solver.isModelCoreSymbol(zero), CVC5ApiException);
 }
 
+TEST_F(TestApiBlackSolver, getModel)
+{
+  d_solver.setOption("produce-models", "true");
+  Sort uSort = d_solver.mkUninterpretedSort("u");
+  Term x = d_solver.mkConst(uSort, "x");
+  Term y = d_solver.mkConst(uSort, "y");
+  Term z = d_solver.mkConst(uSort, "z");
+  Term f = d_solver.mkTerm(NOT, d_solver.mkTerm(EQUAL, x, y));
+  d_solver.assertFormula(f);
+  d_solver.checkSat();
+  std::vector<Sort> sorts;
+  sorts.push_back(uSort);
+  std::vector<Term> terms;
+  terms.push_back(x);
+  terms.push_back(y);
+  ASSERT_NO_THROW(d_solver.getModel(sorts, terms));
+  Term null;
+  terms.push_back(null);
+  ASSERT_THROW(d_solver.getModel(sorts, terms), CVC5ApiException);
+}
+
+TEST_F(TestApiBlackSolver, getModel2)
+{
+  d_solver.setOption("produce-models", "true");
+  std::vector<Sort> sorts;
+  std::vector<Term> terms;
+  ASSERT_THROW(d_solver.getModel(sorts, terms), CVC5ApiException);
+}
+
+TEST_F(TestApiBlackSolver, getModel3)
+{
+  d_solver.setOption("produce-models", "true");
+  std::vector<Sort> sorts;
+  std::vector<Term> terms;
+  d_solver.checkSat();
+  ASSERT_NO_THROW(d_solver.getModel(sorts, terms));
+  Sort integer = d_solver.getIntegerSort();
+  sorts.push_back(integer);
+  ASSERT_THROW(d_solver.getModel(sorts, terms), CVC5ApiException);
+}
+
 TEST_F(TestApiBlackSolver, getQuantifierElimination)
 {
   Term x = d_solver.mkVar(d_solver.getBooleanSort(), "x");