From 7c249b3efdeeb51fd3dfc2571bc529c55880cf5c Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 16 Oct 2020 13:32:42 -0500 Subject: [PATCH] Refactor SMT-level model object (#5277) This refactors the SMT-level model object so that it is a wrapper around TheoryModel instead of a base class. This inheritance was unnecessary. Moreover, it removes the virtual base models of the SMT-level model which were based on Expr. Now the interface is more minimal and in terms of Node only. This PR further simplifies a few places in the code that interface with the SmtEngine with things related to models. --- src/api/cvc4cpp.cpp | 16 +---- src/printer/ast/ast_printer.cpp | 4 +- src/printer/ast/ast_printer.h | 4 +- src/printer/cvc/cvc_printer.cpp | 13 ++-- src/printer/cvc/cvc_printer.h | 4 +- src/printer/printer.cpp | 4 +- src/printer/printer.h | 6 +- src/printer/smt2/smt2_printer.cpp | 26 ++++---- src/printer/smt2/smt2_printer.h | 4 +- src/printer/tptp/tptp_printer.cpp | 4 +- src/printer/tptp/tptp_printer.h | 4 +- src/smt/command.h | 5 +- src/smt/model.cpp | 27 +++++++-- src/smt/model.h | 93 +++++++++++------------------ src/smt/model_blocker.cpp | 4 +- src/smt/model_core_builder.cpp | 6 +- src/smt/model_core_builder.h | 4 +- src/smt/smt_engine.cpp | 65 ++++++++++++-------- src/smt/smt_engine.h | 21 ++++--- src/theory/theory_model.cpp | 87 +++++++++------------------ src/theory/theory_model.h | 25 ++++---- src/theory/theory_model_builder.cpp | 7 +-- src/theory/theory_model_builder.h | 2 +- 23 files changed, 205 insertions(+), 230 deletions(-) diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index 0384b573e..2417936a7 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -5211,13 +5211,6 @@ Term Solver::getSeparationHeap() const "(try --produce-models)"; CVC4_API_CHECK(d_smtEngine->getSmtMode() != SmtMode::UNSAT) << "Cannot get separtion heap term when in unsat mode."; - - theory::TheoryModel* m = - d_smtEngine->getAvailableModel("get separation logic heap and nil"); - Expr heap, nil; - bool hasHeapModel = m->getHeapModel(heap, nil); - CVC4_API_CHECK(hasHeapModel) - << "Failed to obtain heap term from theory model."; return Term(this, d_smtEngine->getSepHeapExpr()); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -5235,14 +5228,7 @@ Term Solver::getSeparationNilTerm() const "(try --produce-models)"; CVC4_API_CHECK(d_smtEngine->getSmtMode() != SmtMode::UNSAT) << "Cannot get separtion nil term when in unsat mode."; - - theory::TheoryModel* m = - d_smtEngine->getAvailableModel("get separation logic heap and nil"); - Expr heap, nil; - bool hasHeapModel = m->getHeapModel(heap, nil); - CVC4_API_CHECK(hasHeapModel) - << "Failed to obtain nil term from theory model."; - return Term(this, nil); + return Term(this, d_smtEngine->getSepNilExpr()); CVC4_API_SOLVER_TRY_CATCH_END; } diff --git a/src/printer/ast/ast_printer.cpp b/src/printer/ast/ast_printer.cpp index 062ebf037..1ed9d146c 100644 --- a/src/printer/ast/ast_printer.cpp +++ b/src/printer/ast/ast_printer.cpp @@ -150,13 +150,13 @@ void AstPrinter::toStream(std::ostream& out, const CommandStatus* s) const }/* AstPrinter::toStream(CommandStatus*) */ -void AstPrinter::toStream(std::ostream& out, const Model& m) const +void AstPrinter::toStream(std::ostream& out, const smt::Model& m) const { out << "Model()"; } void AstPrinter::toStream(std::ostream& out, - const Model& m, + const smt::Model& m, const NodeCommand* c) const { // shouldn't be called; only the non-Command* version above should be diff --git a/src/printer/ast/ast_printer.h b/src/printer/ast/ast_printer.h index b5feccdfa..f01436b8a 100644 --- a/src/printer/ast/ast_printer.h +++ b/src/printer/ast/ast_printer.h @@ -37,7 +37,7 @@ class AstPrinter : public CVC4::Printer bool types, size_t dag) const override; void toStream(std::ostream& out, const CommandStatus* s) const override; - void toStream(std::ostream& out, const Model& m) const override; + void toStream(std::ostream& out, const smt::Model& m) const override; /** Print empty command */ void toStreamCmdEmpty(std::ostream& out, @@ -174,7 +174,7 @@ class AstPrinter : public CVC4::Printer private: void toStream(std::ostream& out, TNode n, int toDepth, bool types) const; void toStream(std::ostream& out, - const Model& m, + const smt::Model& m, const NodeCommand* c) const override; }; /* class AstPrinter */ diff --git a/src/printer/cvc/cvc_printer.cpp b/src/printer/cvc/cvc_printer.cpp index 7fd26e1a0..bab619dce 100644 --- a/src/printer/cvc/cvc_printer.cpp +++ b/src/printer/cvc/cvc_printer.cpp @@ -1142,7 +1142,9 @@ void DeclareFunctionNodeCommandToStream( { out << tn; } - Node val = model.getSmtEngine()->getValue(n); + // We get the value from the theory model directly, which notice + // does not have to go through the standard SmtEngine::getValue interface. + Node val = model.getValue(n); if (options::modelUninterpDtEnum() && val.getKind() == kind::STORE) { TypeNode type_node = val[1].getType(); @@ -1162,11 +1164,12 @@ void DeclareFunctionNodeCommandToStream( } // namespace -void CvcPrinter::toStream(std::ostream& out, const Model& m) const +void CvcPrinter::toStream(std::ostream& out, const smt::Model& m) const { + const theory::TheoryModel* tm = m.getTheoryModel(); // print the model comments std::stringstream c; - m.getComments(c); + tm->getComments(c); std::string ln; while (std::getline(c, ln)) { @@ -1180,10 +1183,10 @@ void CvcPrinter::toStream(std::ostream& out, const Model& m) const } void CvcPrinter::toStream(std::ostream& out, - const Model& model, + const smt::Model& model, const NodeCommand* command) const { - const auto* theory_model = dynamic_cast(&model); + const auto* theory_model = model.getTheoryModel(); AlwaysAssert(theory_model != nullptr); if (const auto* declare_type_command = dynamic_cast(command)) diff --git a/src/printer/cvc/cvc_printer.h b/src/printer/cvc/cvc_printer.h index a7bacb803..4047f0d8b 100644 --- a/src/printer/cvc/cvc_printer.h +++ b/src/printer/cvc/cvc_printer.h @@ -38,7 +38,7 @@ class CvcPrinter : public CVC4::Printer bool types, size_t dag) const override; void toStream(std::ostream& out, const CommandStatus* s) const override; - void toStream(std::ostream& out, const Model& m) const override; + void toStream(std::ostream& out, const smt::Model& m) const override; /** Print empty command */ void toStreamCmdEmpty(std::ostream& out, @@ -176,7 +176,7 @@ class CvcPrinter : public CVC4::Printer void toStream( std::ostream& out, TNode n, int toDepth, bool types, bool bracket) const; void toStream(std::ostream& out, - const Model& m, + const smt::Model& m, const NodeCommand* c) const override; bool d_cvc3Mode; diff --git a/src/printer/printer.cpp b/src/printer/printer.cpp index 952caf89e..ba062c20f 100644 --- a/src/printer/printer.cpp +++ b/src/printer/printer.cpp @@ -71,13 +71,13 @@ unique_ptr Printer::makePrinter(OutputLanguage lang) } } -void Printer::toStream(std::ostream& out, const Model& m) const +void Printer::toStream(std::ostream& out, const smt::Model& m) const { for(size_t i = 0; i < m.getNumCommands(); ++i) { const NodeCommand* cmd = m.getCommand(i); const DeclareFunctionNodeCommand* dfc = dynamic_cast(cmd); - if (dfc != NULL && !m.isModelCoreSymbol(dfc->getFunction().toExpr())) + if (dfc != NULL && !m.isModelCoreSymbol(dfc->getFunction())) { continue; } diff --git a/src/printer/printer.h b/src/printer/printer.h index c10e1db04..b95b02ca8 100644 --- a/src/printer/printer.h +++ b/src/printer/printer.h @@ -58,7 +58,7 @@ class Printer virtual void toStream(std::ostream& out, const CommandStatus* s) const = 0; /** Write a Model out to a stream with this Printer. */ - virtual void toStream(std::ostream& out, const Model& m) const; + virtual void toStream(std::ostream& out, const smt::Model& m) const; /** Write an UnsatCore out to a stream with this Printer. */ virtual void toStream(std::ostream& out, const UnsatCore& core) const; @@ -275,13 +275,13 @@ class Printer /** write model response to command */ virtual void toStream(std::ostream& out, - const Model& m, + const smt::Model& m, const NodeCommand* c) const = 0; /** write model response to command using another language printer */ void toStreamUsing(OutputLanguage lang, std::ostream& out, - const Model& m, + const smt::Model& m, const NodeCommand* c) const { getPrinter(lang)->toStream(out, m, c); diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 6d75279e5..2024c87b6 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1323,11 +1323,12 @@ void Smt2Printer::toStream(std::ostream& out, const UnsatCore& core) const out << ")" << endl; }/* Smt2Printer::toStream(UnsatCore, map) */ -void Smt2Printer::toStream(std::ostream& out, const Model& m) const +void Smt2Printer::toStream(std::ostream& out, const smt::Model& m) const { + const theory::TheoryModel* tm = m.getTheoryModel(); //print the model comments std::stringstream c; - m.getComments( c ); + tm->getComments(c); std::string ln; while( std::getline( c, ln ) ){ out << "; " << ln << std::endl; @@ -1339,8 +1340,9 @@ void Smt2Printer::toStream(std::ostream& out, const Model& m) const this->Printer::toStream(out, m); out << ")" << endl; //print the heap model, if it exists - Expr h, neq; - if( m.getHeapModel( h, neq ) ){ + Node h, neq; + if (tm->getHeapModel(h, neq)) + { // description of the heap+what nil is equal to fully describes model out << "(heap" << endl; out << h << endl; @@ -1350,11 +1352,10 @@ void Smt2Printer::toStream(std::ostream& out, const Model& m) const } void Smt2Printer::toStream(std::ostream& out, - const Model& model, + const smt::Model& model, const NodeCommand* command) const { - const theory::TheoryModel* theory_model = - dynamic_cast(&model); + const theory::TheoryModel* theory_model = model.getTheoryModel(); AlwaysAssert(theory_model != nullptr); if (const DeclareTypeNodeCommand* dtc = dynamic_cast(command)) @@ -1367,7 +1368,7 @@ void Smt2Printer::toStream(std::ostream& out, } else { - std::vector elements = theory_model->getDomainElements(tn.toType()); + std::vector elements = theory_model->getDomainElements(tn); if (options::modelUninterpDtEnum()) { if (isVariant_2_6(d_variant)) @@ -1378,7 +1379,7 @@ void Smt2Printer::toStream(std::ostream& out, { out << "(declare-datatypes () ((" << (*dtc).getSymbol() << " "; } - for (const Expr& type_ref : elements) + for (const Node& type_ref : elements) { out << "(" << type_ref << ")"; } @@ -1390,9 +1391,8 @@ void Smt2Printer::toStream(std::ostream& out, out << "; cardinality of " << tn << " is " << elements.size() << endl; out << (*dtc) << endl; // print the representatives - for (const Expr& type_ref : elements) + for (const Node& trn : elements) { - Node trn = Node::fromExpr(type_ref); if (trn.isVar()) { out << "(declare-fun " << quoteSymbol(trn) << " () " << tn << ")" @@ -1423,7 +1423,9 @@ void Smt2Printer::toStream(std::ostream& out, // don't print out internal stuff return; } - Node val = theory_model->getSmtEngine()->getValue(n); + // We get the value from the theory model directly, which notice + // does not have to go through the standard SmtEngine::getValue interface. + Node val = theory_model->getValue(n); if (val.getKind() == kind::LAMBDA) { out << "(define-fun " << n << " " << val[0] << " " diff --git a/src/printer/smt2/smt2_printer.h b/src/printer/smt2/smt2_printer.h index 3160771da..ed04da983 100644 --- a/src/printer/smt2/smt2_printer.h +++ b/src/printer/smt2/smt2_printer.h @@ -45,7 +45,7 @@ class Smt2Printer : public CVC4::Printer bool types, size_t dag) const override; void toStream(std::ostream& out, const CommandStatus* s) const override; - void toStream(std::ostream& out, const Model& m) const override; + void toStream(std::ostream& out, const smt::Model& m) const override; /** * Writes the unsat core to the stream out. * We use the expression names that are stored in the SMT engine associated @@ -231,7 +231,7 @@ class Smt2Printer : public CVC4::Printer void toStream( std::ostream& out, TNode n, int toDepth, bool types, TypeNode nt) const; void toStream(std::ostream& out, - const Model& m, + const smt::Model& m, const NodeCommand* c) const override; void toStream(std::ostream& out, const SExpr& sexpr) const; void toStream(std::ostream& out, const DType& dt) const; diff --git a/src/printer/tptp/tptp_printer.cpp b/src/printer/tptp/tptp_printer.cpp index fa0fc3c46..009f78a1d 100644 --- a/src/printer/tptp/tptp_printer.cpp +++ b/src/printer/tptp/tptp_printer.cpp @@ -45,7 +45,7 @@ void TptpPrinter::toStream(std::ostream& out, const CommandStatus* s) const s->toStream(out, language::output::LANG_SMTLIB_V2_5); }/* TptpPrinter::toStream() */ -void TptpPrinter::toStream(std::ostream& out, const Model& m) const +void TptpPrinter::toStream(std::ostream& out, const smt::Model& m) const { std::string statusName(m.isKnownSat() ? "FiniteModel" : "CandidateFiniteModel"); @@ -59,7 +59,7 @@ void TptpPrinter::toStream(std::ostream& out, const Model& m) const } void TptpPrinter::toStream(std::ostream& out, - const Model& m, + const smt::Model& m, const NodeCommand* c) const { // shouldn't be called; only the non-Command* version above should be diff --git a/src/printer/tptp/tptp_printer.h b/src/printer/tptp/tptp_printer.h index 0c961d39b..84bb3e576 100644 --- a/src/printer/tptp/tptp_printer.h +++ b/src/printer/tptp/tptp_printer.h @@ -37,7 +37,7 @@ class TptpPrinter : public CVC4::Printer bool types, size_t dag) const override; void toStream(std::ostream& out, const CommandStatus* s) const override; - void toStream(std::ostream& out, const Model& m) const override; + void toStream(std::ostream& out, const smt::Model& m) const override; /** print unsat core to stream * We use the expression names stored in the SMT engine associated with the * unsat core with UnsatCore::getSmtEngine. @@ -46,7 +46,7 @@ class TptpPrinter : public CVC4::Printer private: void toStream(std::ostream& out, - const Model& m, + const smt::Model& m, const NodeCommand* c) const override; }; /* class TptpPrinter */ diff --git a/src/smt/command.h b/src/smt/command.h index b823b5730..41776cee5 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -46,7 +46,10 @@ class Term; class SmtEngine; class Command; class CommandStatus; + +namespace smt { class Model; +} std::ostream& operator<<(std::ostream&, const Command&) CVC4_PUBLIC; std::ostream& operator<<(std::ostream&, const Command*) CVC4_PUBLIC; @@ -995,7 +998,7 @@ class CVC4_PUBLIC GetModelCommand : public Command OutputLanguage language = language::output::LANG_AUTO) const override; protected: - Model* d_result; + smt::Model* d_result; }; /* class GetModelCommand */ /** The command to block models. */ diff --git a/src/smt/model.cpp b/src/smt/model.cpp index 60640def1..fc9ea8fbb 100644 --- a/src/smt/model.cpp +++ b/src/smt/model.cpp @@ -14,8 +14,6 @@ #include "smt/model.h" -#include - #include "expr/expr_iomanip.h" #include "options/base_options.h" #include "printer/printer.h" @@ -23,10 +21,16 @@ #include "smt/node_command.h" #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" - -using namespace std; +#include "theory/theory_model.h" namespace CVC4 { +namespace smt { + +Model::Model(SmtEngine& smt, theory::TheoryModel* tm) + : d_smt(smt), d_isKnownSat(false), d_tmodel(tm) +{ + Assert(d_tmodel != nullptr); +} std::ostream& operator<<(std::ostream& out, const Model& m) { smt::SmtScope smts(&m.d_smt); @@ -35,8 +39,6 @@ std::ostream& operator<<(std::ostream& out, const Model& m) { return out; } -Model::Model() : d_smt(*smt::currentSmtEngine()), d_isKnownSat(false) {} - size_t Model::getNumCommands() const { return d_smt.getDumpManager()->getNumModelCommands(); @@ -47,4 +49,17 @@ const NodeCommand* Model::getCommand(size_t i) const return d_smt.getDumpManager()->getModelCommand(i); } +theory::TheoryModel* Model::getTheoryModel() { return d_tmodel; } + +const theory::TheoryModel* Model::getTheoryModel() const { return d_tmodel; } + +bool Model::isModelCoreSymbol(TNode sym) const +{ + return d_tmodel->isModelCoreSymbol(sym); +} +Node Model::getValue(TNode n) const { return d_tmodel->getValue(n); } + +bool Model::hasApproximations() const { return d_tmodel->hasApproximations(); } + +} // namespace smt }/* CVC4 namespace */ diff --git a/src/smt/model.h b/src/smt/model.h index eb959ba7e..dc36b5d29 100644 --- a/src/smt/model.h +++ b/src/smt/model.h @@ -21,30 +21,34 @@ #include #include "expr/expr.h" +#include "theory/theory_model.h" #include "util/cardinality.h" namespace CVC4 { -class NodeCommand; class SmtEngine; +class NodeCommand; + +namespace smt { + class Model; std::ostream& operator<<(std::ostream&, const Model&); +/** + * This is the SMT-level model object, that is responsible for maintaining + * the necessary information for how to print the model, as well as + * holding a pointer to the underlying implementation of the theory model. + */ class Model { friend std::ostream& operator<<(std::ostream&, const Model&); - friend class SmtEngine; - - protected: - /** The SmtEngine we're associated with */ - SmtEngine& d_smt; - - /** construct the base class; users cannot do this, only CVC4 internals */ - Model(); + friend class ::CVC4::SmtEngine; public: + /** construct */ + Model(SmtEngine& smt, theory::TheoryModel* tm); /** virtual destructor */ - virtual ~Model() { } + ~Model() {} /** get number of commands to report */ size_t getNumCommands() const; /** get command */ @@ -62,54 +66,21 @@ class Model { * only a candidate solution. */ bool isKnownSat() const { return d_isKnownSat; } - //--------------------------- model cores - /** set using model core - * - * This sets that this model is minimized to be a "model core" for some - * formula (typically the input formula). - * - * For example, given formula ( a>5 OR b>5 ) AND f( c ) = 0, - * a model for this formula is: a -> 6, b -> 0, c -> 0, f -> lambda x. 0. - * A "model core" is a subset of this model that suffices to show the - * above formula is true, for example { a -> 6, f -> lambda x. 0 } is a - * model core for this formula. - */ - virtual void setUsingModelCore() = 0; - /** record model core symbol - * - * This marks that sym is a "model core symbol". In other words, its value is - * critical to the satisfiability of the formula this model is for. - */ - virtual void recordModelCoreSymbol(Expr sym) = 0; - /** Check whether this expr is in the model core */ - virtual bool isModelCoreSymbol(Expr expr) const = 0; - //--------------------------- end model cores - /** get value for expression */ - virtual Expr getValue(Expr expr) const = 0; - /** get cardinality for sort */ - virtual Cardinality getCardinality(Type t) const = 0; - /** print comments */ - virtual void getComments(std::ostream& out) const {} - /** get heap model (for separation logic) */ - virtual bool getHeapModel( Expr& h, Expr& ne ) const { return false; } - /** are there any approximations in this model? */ - virtual bool hasApproximations() const { return false; } - /** get the list of approximations - * - * This is a list of pairs of the form (t,p), where t is a term and p - * is a predicate over t that indicates a property that t satisfies. - */ - virtual std::vector > getApproximations() const = 0; - /** get the domain elements for uninterpreted sort t - * - * This method gets the interpretation of an uninterpreted sort t. - * All models interpret uninterpreted sorts t as finite sets - * of domain elements v_1, ..., v_n. This method returns this list for t in - * this model. - */ - virtual std::vector getDomainElements(Type t) const = 0; - + /** Get the underlying theory model */ + theory::TheoryModel* getTheoryModel(); + /** Get the underlying theory model (const version) */ + const theory::TheoryModel* getTheoryModel() const; + //----------------------- helper methods in the underlying theory model + /** Is the node n a model core symbol? */ + bool isModelCoreSymbol(TNode sym) const; + /** Get value */ + Node getValue(TNode n) const; + /** Does this model have approximations? */ + bool hasApproximations() const; + //----------------------- end helper methods protected: + /** The SmtEngine we're associated with */ + SmtEngine& d_smt; /** the input name (file name, etc.) this model is associated to */ std::string d_inputName; /** @@ -117,8 +88,14 @@ class Model { * from the solver. */ bool d_isKnownSat; -};/* class Model */ + /** + * Pointer to the underlying theory model, which maintains all data regarding + * the values of sorts and terms. + */ + theory::TheoryModel* d_tmodel; +}; +} // namespace smt }/* CVC4 namespace */ #endif /* CVC4__MODEL_H */ diff --git a/src/smt/model_blocker.cpp b/src/smt/model_blocker.cpp index 9d15b5690..cabd7bd20 100644 --- a/src/smt/model_blocker.cpp +++ b/src/smt/model_blocker.cpp @@ -66,7 +66,7 @@ Node ModelBlocker::getModelBlocker(const std::vector& assertions, Node blockTriv = nm->mkConst(false); Trace("model-blocker") << "...model blocker is (trivially) " << blockTriv << std::endl; - return blockTriv.toExpr(); + return blockTriv; } Node formula = asserts.size() > 1 ? nm->mkNode(AND, asserts) : asserts[0]; @@ -152,7 +152,7 @@ Node ModelBlocker::getModelBlocker(const std::vector& assertions, std::vector children; for (const Node& cn : catom) { - Node vn = Node::fromExpr(m->getValue(cn.toExpr())); + Node vn = m->getValue(cn); Assert(vn.isConst()); children.push_back(vn.getConst() ? cn : cn.negate()); } diff --git a/src/smt/model_core_builder.cpp b/src/smt/model_core_builder.cpp index 59dac63e8..cb8494e85 100644 --- a/src/smt/model_core_builder.cpp +++ b/src/smt/model_core_builder.cpp @@ -21,7 +21,7 @@ using namespace CVC4::kind; namespace CVC4 { bool ModelCoreBuilder::setModelCore(const std::vector& assertions, - Model* m, + theory::TheoryModel* m, options::ModelCoresMode mode) { if (Trace.isOn("model-core")) @@ -53,7 +53,7 @@ bool ModelCoreBuilder::setModelCore(const std::vector& assertions, visited.insert(cur); if (cur.isVar()) { - Node vcur = Node::fromExpr(m->getValue(cur.toExpr())); + Node vcur = m->getValue(cur); Trace("model-core") << " " << cur << " -> " << vcur << std::endl; vars.push_back(cur); subs.push_back(vcur); @@ -95,7 +95,7 @@ bool ModelCoreBuilder::setModelCore(const std::vector& assertions, for (const Node& cv : coreVars) { - m->recordModelCoreSymbol(cv.toExpr()); + m->recordModelCoreSymbol(cv); } return true; } diff --git a/src/smt/model_core_builder.h b/src/smt/model_core_builder.h index 984c61d04..7a28c47f2 100644 --- a/src/smt/model_core_builder.h +++ b/src/smt/model_core_builder.h @@ -21,7 +21,7 @@ #include "expr/expr.h" #include "options/smt_options.h" -#include "smt/model.h" +#include "theory/theory_model.h" namespace CVC4 { @@ -55,7 +55,7 @@ class ModelCoreBuilder * left unchanged. */ static bool setModelCore(const std::vector& assertions, - Model* m, + theory::TheoryModel* m, options::ModelCoresMode mode); }; /* class TheoryModelCoreBuilder */ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 205865e16..2a771ce76 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -127,6 +127,7 @@ SmtEngine::SmtEngine(ExprManager* em, Options* optr) d_snmListener(new SmtNodeManagerListener(*d_dumpm.get(), d_outMgr)), d_smtSolver(nullptr), d_proofManager(nullptr), + d_model(nullptr), d_pfManager(nullptr), d_rewriter(new theory::Rewriter()), d_definedFunctions(nullptr), @@ -271,6 +272,15 @@ void SmtEngine::finishInit() Trace("smt-debug") << "SmtEngine::finishInit" << std::endl; d_smtSolver->finishInit(const_cast(d_logic)); + // now can construct the SMT-level model object + TheoryEngine* te = d_smtSolver->getTheoryEngine(); + Assert(te != nullptr); + TheoryModel* tm = te->getModel(); + if (tm != nullptr) + { + d_model.reset(new Model(*this, tm)); + } + // global push/pop around everything, to ensure proper destruction // of context-dependent data structures d_state->setup(); @@ -839,7 +849,7 @@ Result SmtEngine::quickCheck() { Result::ENTAILMENT_UNKNOWN, Result::REQUIRES_FULL_CHECK, filename); } -theory::TheoryModel* SmtEngine::getAvailableModel(const char* c) const +Model* SmtEngine::getAvailableModel(const char* c) const { if (!options::assignFunctionValues()) { @@ -878,7 +888,7 @@ theory::TheoryModel* SmtEngine::getAvailableModel(const char* c) const throw RecoverableModalException(ss.str().c_str()); } - return m; + return d_model.get(); } void SmtEngine::notifyPushPre() { d_smtSolver->processAssertions(*d_asserts); } @@ -1210,11 +1220,9 @@ Node SmtEngine::getValue(const Node& ex) const } Trace("smt") << "--- getting value of " << n << endl; - TheoryModel* m = getAvailableModel("get-value"); - Node resultNode; - if(m != NULL) { - resultNode = m->getValue(n); - } + Model* m = getAvailableModel("get-value"); + Assert(m != nullptr); + Node resultNode = m->getValue(n); Trace("smt") << "--- got value " << n << " = " << resultNode << endl; Trace("smt") << "--- type " << resultNode.getType() << endl; Trace("smt") << "--- expected type " << expectedType << endl; @@ -1301,7 +1309,7 @@ vector> SmtEngine::getAssignment() // Get the model here, regardless of whether d_assignments is null, since // we should throw errors related to model availability whether or not // assignments is null. - TheoryModel* m = getAvailableModel("get assignment"); + Model* m = getAvailableModel("get assignment"); vector> res; if (d_assignments != nullptr) @@ -1354,7 +1362,7 @@ Model* SmtEngine::getModel() { getOutputManager().getDumpOut()); } - TheoryModel* m = getAvailableModel("get model"); + Model* m = getAvailableModel("get model"); // Since model m is being returned to the user, we must ensure that this // model object remains valid with future check-sat calls. Hence, we set @@ -1368,8 +1376,11 @@ Model* SmtEngine::getModel() { // If we enabled model cores, we compute a model core for m based on our // (expanded) assertions using the model core builder utility std::vector eassertsProc = getExpandedAssertions(); - ModelCoreBuilder::setModelCore(eassertsProc, m, options::modelCoresMode()); + ModelCoreBuilder::setModelCore( + eassertsProc, m->getTheoryModel(), options::modelCoresMode()); } + // set the information on the SMT-level model + Assert(m != nullptr); m->d_inputName = d_state->getFilename(); m->d_isKnownSat = (d_state->getMode() == SmtMode::SAT); return m; @@ -1388,19 +1399,19 @@ Result SmtEngine::blockModel() getOutputManager().getDumpOut()); } - TheoryModel* m = getAvailableModel("block model"); + Model* m = getAvailableModel("block model"); if (options::blockModelsMode() == options::BlockModelsMode::NONE) { std::stringstream ss; ss << "Cannot block model when block-models is set to none."; - throw ModalException(ss.str().c_str()); + throw RecoverableModalException(ss.str().c_str()); } // get expanded assertions std::vector eassertsProc = getExpandedAssertions(); Node eblocker = ModelBlocker::getModelBlocker( - eassertsProc, m, options::blockModelsMode()); + eassertsProc, m->getTheoryModel(), options::blockModelsMode()); return assertFormula(eblocker); } @@ -1417,13 +1428,16 @@ Result SmtEngine::blockModelValues(const std::vector& exprs) getOutputManager().getDumpOut(), exprs); } - TheoryModel* m = getAvailableModel("block model values"); + Model* m = getAvailableModel("block model values"); // get expanded assertions std::vector eassertsProc = getExpandedAssertions(); // we always do block model values mode here - Node eblocker = ModelBlocker::getModelBlocker( - eassertsProc, m, options::BlockModelsMode::VALUES, exprs); + Node eblocker = + ModelBlocker::getModelBlocker(eassertsProc, + m->getTheoryModel(), + options::BlockModelsMode::VALUES, + exprs); return assertFormula(eblocker); } @@ -1437,16 +1451,18 @@ std::pair SmtEngine::getSepHeapAndNilExpr(void) throw RecoverableModalException(msg); } NodeManagerScope nms(d_nodeManager); - Expr heap; - Expr nil; + Node heap; + Node nil; Model* m = getAvailableModel("get separation logic heap and nil"); - if (!m->getHeapModel(heap, nil)) + TheoryModel* tm = m->getTheoryModel(); + if (!tm->getHeapModel(heap, nil)) { - InternalError() - << "SmtEngine::getSepHeapAndNilExpr(): failed to obtain heap/nil " - "expressions from theory model."; + const char* msg = + "Failed to obtain heap/nil " + "expressions from theory model."; + throw RecoverableModalException(msg); } - return std::make_pair(Node::fromExpr(heap), Node::fromExpr(nil)); + return std::make_pair(heap, nil); } std::vector SmtEngine::getExpandedAssertions() @@ -1544,7 +1560,8 @@ void SmtEngine::checkModel(bool hardFailure) { // and if Notice() is on, the user gave --verbose (or equivalent). Notice() << "SmtEngine::checkModel(): generating model" << endl; - TheoryModel* m = getAvailableModel("check model"); + Model* m = getAvailableModel("check model"); + Assert(m != nullptr); // check-model is not guaranteed to succeed if approximate values were used. // Thus, we intentionally abort here. diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 62e54a0c1..da12d336b 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -60,7 +60,6 @@ class TheoryEngine; class ProofManager; -class Model; class LogicRequest; class StatisticsRegistry; @@ -95,6 +94,7 @@ namespace prop { namespace smt { /** Utilities */ +class Model; class SmtEngineState; class AbstractValues; class Assertions; @@ -280,7 +280,7 @@ class CVC4_PUBLIC SmtEngine * Get the model (only if immediately preceded by a SAT or NOT_ENTAILED * query). Only permitted if produce-models is on. */ - Model* getModel(); + smt::Model* getModel(); /** * Block the current model. Can be called only if immediately preceded by @@ -969,16 +969,17 @@ class CVC4_PUBLIC SmtEngine Result quickCheck(); /** - * Get the model, if it is available and return a pointer to it + * Get the (SMT-level) model pointer, if we are in SAT mode. Otherwise, + * return nullptr. * - * This ensures that the model is currently available, which means that - * CVC4 is producing models, and is in "SAT mode", otherwise an exception - * is thrown. + * This ensures that the underlying theory model of the SmtSolver maintained + * by this class is currently available, which means that CVC4 is producing + * models, and is in "SAT mode", otherwise a recoverable exception is thrown. * * The flag c is used for giving an error message to indicate the context * this method was called. */ - theory::TheoryModel* getAvailableModel(const char* c) const; + smt::Model* getAvailableModel(const char* c) const; // --------------------------------------- callbacks from the state /** @@ -1088,6 +1089,12 @@ class CVC4_PUBLIC SmtEngine /** The (old) proof manager TODO (project #37): delete this */ std::unique_ptr d_proofManager; + /** + * The SMT-level model object, which contains information about how to + * print the model, as well as a pointer to the underlying TheoryModel + * implementation maintained by the SmtSolver. + */ + std::unique_ptr d_model; /** * The proof manager, which manages all things related to checking, diff --git a/src/theory/theory_model.cpp b/src/theory/theory_model.cpp index b8cbdf6b3..f240d5113 100644 --- a/src/theory/theory_model.cpp +++ b/src/theory/theory_model.cpp @@ -109,48 +109,27 @@ bool TheoryModel::getHeapModel(Node& h, Node& neq) const return true; } -bool TheoryModel::getHeapModel( Expr& h, Expr& neq ) const { - if( d_sep_heap.isNull() || d_sep_nil_eq.isNull() ){ - return false; - }else{ - h = d_sep_heap.toExpr(); - neq = d_sep_nil_eq.toExpr(); - return true; - } -} - bool TheoryModel::hasApproximations() const { return !d_approx_list.empty(); } -std::vector > TheoryModel::getApproximations() const +std::vector > TheoryModel::getApproximations() const { - std::vector > approx; - for (const std::pair& ap : d_approx_list) - { - approx.push_back( - std::pair(ap.first.toExpr(), ap.second.toExpr())); - } - return approx; + return d_approx_list; } -std::vector TheoryModel::getDomainElements(Type t) const +std::vector TheoryModel::getDomainElements(TypeNode tn) const { // must be an uninterpreted sort - Assert(t.isSort()); - std::vector elements; - TypeNode tn = TypeNode::fromType(t); + Assert(tn.isSort()); + std::vector elements; const std::vector* type_refs = d_rep_set.getTypeRepsOrNull(tn); if (type_refs == nullptr || type_refs->empty()) { // This is called when t is a sort that does not occur in this model. // Sorts are always interpreted as non-empty, thus we add a single element. - elements.push_back(t.mkGroundTerm()); + elements.push_back(tn.mkGroundTerm()); return elements; } - for (const Node& n : *type_refs) - { - elements.push_back(n.toExpr()); - } - return elements; + return *type_refs; } Node TheoryModel::getValue(TNode n) const @@ -170,39 +149,35 @@ Node TheoryModel::getValue(TNode n) const return nn; } -bool TheoryModel::isModelCoreSymbol(Expr sym) const +bool TheoryModel::isModelCoreSymbol(Node s) const { if (!d_using_model_core) { return true; } - Node s = Node::fromExpr(sym); Assert(s.isVar() && s.getKind() != BOUND_VARIABLE); return d_model_core.find(s) != d_model_core.end(); } -Expr TheoryModel::getValue( Expr expr ) const{ - Node n = Node::fromExpr( expr ); - Node ret = getValue( n ); - return ret.toExpr(); -} - -/** get cardinality for sort */ -Cardinality TheoryModel::getCardinality( Type t ) const{ - TypeNode tn = TypeNode::fromType( t ); +Cardinality TheoryModel::getCardinality(TypeNode tn) const +{ //for now, we only handle cardinalities for uninterpreted sorts - if( tn.isSort() ){ - if( d_rep_set.hasType( tn ) ){ - Debug("model-getvalue-debug") << "Get cardinality sort, #rep : " << d_rep_set.getNumRepresentatives( tn ) << std::endl; - return Cardinality( d_rep_set.getNumRepresentatives( tn ) ); - }else{ - Debug("model-getvalue-debug") << "Get cardinality sort, unconstrained, return 1." << std::endl; - return Cardinality( 1 ); - } - }else{ - Debug("model-getvalue-debug") << "Get cardinality other sort, unknown." << std::endl; + if (!tn.isSort()) + { + Debug("model-getvalue-debug") + << "Get cardinality other sort, unknown." << std::endl; return Cardinality( CardinalityUnknown() ); } + if (d_rep_set.hasType(tn)) + { + Debug("model-getvalue-debug") + << "Get cardinality sort, #rep : " + << d_rep_set.getNumRepresentatives(tn) << std::endl; + return Cardinality(d_rep_set.getNumRepresentatives(tn)); + } + Debug("model-getvalue-debug") + << "Get cardinality sort, unconstrained, return 1." << std::endl; + return Cardinality(1); } Node TheoryModel::getModelValue(TNode n) const @@ -258,16 +233,15 @@ Node TheoryModel::getModelValue(TNode n) const { Debug("model-getvalue-debug") << "get cardinality constraint " << ret[0].getType() << std::endl; - ret = nm->mkConst( - getCardinality(ret[0].getType().toType()).getFiniteCardinality() - <= ret[1].getConst().getNumerator()); + ret = nm->mkConst(getCardinality(ret[0].getType()).getFiniteCardinality() + <= ret[1].getConst().getNumerator()); } else if (ret.getKind() == kind::CARDINALITY_VALUE) { Debug("model-getvalue-debug") << "get cardinality value " << ret[0].getType() << std::endl; - ret = nm->mkConst(Rational( - getCardinality(ret[0].getType().toType()).getFiniteCardinality())); + ret = nm->mkConst( + Rational(getCardinality(ret[0].getType()).getFiniteCardinality())); } d_modelCache[n] = ret; return ret; @@ -621,10 +595,7 @@ void TheoryModel::setUsingModelCore() d_model_core.clear(); } -void TheoryModel::recordModelCoreSymbol(Expr sym) -{ - d_model_core.insert(Node::fromExpr(sym)); -} +void TheoryModel::recordModelCoreSymbol(Node sym) { d_model_core.insert(sym); } void TheoryModel::setUnevaluatedKind(Kind k) { d_unevaluated_kinds.insert(k); } diff --git a/src/theory/theory_model.h b/src/theory/theory_model.h index 9f330ff6c..e8665bb83 100644 --- a/src/theory/theory_model.h +++ b/src/theory/theory_model.h @@ -20,7 +20,6 @@ #include #include -#include "smt/model.h" #include "theory/ee_setup_info.h" #include "theory/rep_set.h" #include "theory/substitutions.h" @@ -76,12 +75,12 @@ namespace theory { * above functions such as getRepresentative() when assigning total * interpretations for uninterpreted functions. */ -class TheoryModel : public Model +class TheoryModel { friend class TheoryEngineModelBuilder; public: TheoryModel(context::Context* c, std::string name, bool enableFuncModels); - ~TheoryModel() override; + virtual ~TheoryModel(); /** * Finish init, where ee is the equality engine the model should use. */ @@ -295,23 +294,21 @@ public: */ Node getValue(TNode n) const; /** get comments */ - void getComments(std::ostream& out) const override; + void getComments(std::ostream& out) const; //---------------------------- separation logic /** set the heap and value sep.nil is equal to */ void setHeapModel(Node h, Node neq); /** get the heap and value sep.nil is equal to */ bool getHeapModel(Node& h, Node& neq) const; - /** get the heap and value sep.nil is equal to */ - bool getHeapModel(Expr& h, Expr& neq) const override; //---------------------------- end separation logic /** is the list of approximations non-empty? */ - bool hasApproximations() const override; + bool hasApproximations() const; /** get approximations */ - std::vector > getApproximations() const override; + std::vector > getApproximations() const; /** get domain elements for uninterpreted sort t */ - std::vector getDomainElements(Type t) const override; + std::vector getDomainElements(TypeNode t) const; /** get the representative set object */ const RepSet* getRepSet() const { return &d_rep_set; } /** get the representative set object (FIXME: remove this, see #1199) */ @@ -319,17 +316,15 @@ public: //---------------------------- model cores /** set using model core */ - void setUsingModelCore() override; + void setUsingModelCore(); /** record model core symbol */ - void recordModelCoreSymbol(Expr sym) override; + void recordModelCoreSymbol(Node sym); /** Return whether symbol expr is in the model core. */ - bool isModelCoreSymbol(Expr sym) const override; + bool isModelCoreSymbol(Node sym) const; //---------------------------- end model cores - /** get value function for Exprs. */ - Expr getValue(Expr expr) const override; /** get cardinality for sort */ - Cardinality getCardinality(Type t) const override; + Cardinality getCardinality(TypeNode t) const; //---------------------------- function values /** a map from functions f to a list of all APPLY_UF terms with operator f */ diff --git a/src/theory/theory_model_builder.cpp b/src/theory/theory_model_builder.cpp index 0f69566d6..2f9e168c9 100644 --- a/src/theory/theory_model_builder.cpp +++ b/src/theory/theory_model_builder.cpp @@ -1082,7 +1082,7 @@ void TheoryEngineModelBuilder::computeAssignableInfo( } } -void TheoryEngineModelBuilder::postProcessModel(bool incomplete, Model* m) +void TheoryEngineModelBuilder::postProcessModel(bool incomplete, TheoryModel* m) { // if we are incomplete, there is no guarantee on the model. // thus, we do not check the model here. @@ -1090,12 +1090,11 @@ void TheoryEngineModelBuilder::postProcessModel(bool incomplete, Model* m) { return; } - TheoryModel* tm = static_cast(m); - Assert(tm != nullptr); + Assert(m != nullptr); // debug-check the model if the checkModels() is enabled. if (options::debugCheckModels()) { - debugCheckModel(tm); + debugCheckModel(m); } } diff --git a/src/theory/theory_model_builder.h b/src/theory/theory_model_builder.h index 996609dd3..4ffcbeee7 100644 --- a/src/theory/theory_model_builder.h +++ b/src/theory/theory_model_builder.h @@ -81,7 +81,7 @@ class TheoryEngineModelBuilder * method checks the internal consistency of the model if we are in a debug * build. */ - void postProcessModel(bool incomplete, Model* m); + void postProcessModel(bool incomplete, TheoryModel* m); protected: /** pointer to theory engine */ -- 2.30.2