From 7a3aa7033719b14b34c0334d6956834b850fa9eb Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 1 Sep 2021 13:05:48 -0500 Subject: [PATCH] Print response to get-model using the API (#7084) This changes our implementation of GetModelCommand so that we use the API to print the model. It simplifies smt::Model so that this is a pretty printing utility, and not a layer on top of TheoryModel. It adds getModel as an API method for returning the string representation of the model, analogous to our current support for getProof. This eliminates the last call to getSmtEngine() from the command layer. --- src/api/cpp/cvc5.cpp | 36 +++++-- src/api/cpp/cvc5.h | 20 +++- src/main/command_executor.h | 6 -- src/printer/ast/ast_printer.cpp | 8 +- src/printer/ast/ast_printer.h | 8 +- src/printer/cvc/cvc_printer.cpp | 42 +++----- src/printer/cvc/cvc_printer.h | 10 +- src/printer/printer.cpp | 13 +-- src/printer/printer.h | 10 +- src/printer/smt2/smt2_printer.cpp | 25 ++--- src/printer/smt2/smt2_printer.h | 8 +- src/printer/tptp/tptp_printer.cpp | 8 +- src/printer/tptp/tptp_printer.h | 8 +- src/smt/check_models.cpp | 4 +- src/smt/check_models.h | 10 +- src/smt/command.cpp | 28 +----- src/smt/command.h | 3 +- src/smt/model.cpp | 55 +++++++---- src/smt/model.h | 78 +++++++-------- src/smt/smt_engine.cpp | 109 +++++++++++---------- src/smt/smt_engine.h | 40 ++++---- test/regress/regress0/cvc-rerror-print.cvc | 2 +- test/unit/api/solver_black.cpp | 41 ++++++++ 23 files changed, 307 insertions(+), 265 deletions(-) diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index d03b8975e..d6c0a58ee 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -7381,6 +7381,36 @@ bool Solver::isModelCoreSymbol(const Term& v) const CVC5_API_TRY_CATCH_END; } +std::string Solver::getModel(const std::vector& sorts, + const std::vector& vars) const +{ + CVC5_API_TRY_CATCH_BEGIN; + NodeManagerScope scope(getNodeManager()); + CVC5_API_RECOVERABLE_CHECK(d_smtEngine->getOptions().smt.produceModels) + << "Cannot get model unless model generation is enabled " + "(try --produce-models)"; + CVC5_API_RECOVERABLE_CHECK(d_smtEngine->isSmtModeSat()) + << "Cannot get model unless after a SAT or unknown response."; + CVC5_API_SOLVER_CHECK_SORTS(sorts); + for (const Sort& s : sorts) + { + CVC5_API_RECOVERABLE_CHECK(s.isUninterpretedSort()) + << "Expecting an uninterpreted sort as argument to " + "getModel."; + } + CVC5_API_SOLVER_CHECK_TERMS(vars); + for (const Term& v : vars) + { + CVC5_API_RECOVERABLE_CHECK(v.getKind() == CONSTANT) + << "Expecting a free constant as argument to getModel."; + } + //////// all checks before this line + return d_smtEngine->getModel(Sort::sortVectorToTypeNodes(sorts), + Term::termVectorToNodes(vars)); + //////// + CVC5_API_TRY_CATCH_END; +} + Term Solver::getQuantifierElimination(const Term& q) const { NodeManagerScope scope(getNodeManager()); @@ -7900,12 +7930,6 @@ std::vector Solver::getSynthSolutions( CVC5_API_TRY_CATCH_END; } -/* - * !!! This is only temporarily available until the parser is fully migrated to - * the new API. !!! - */ -SmtEngine* Solver::getSmtEngine(void) const { return d_smtEngine.get(); } - Statistics Solver::getStatistics() const { return Statistics(d_smtEngine->getStatisticsRegistry()); diff --git a/src/api/cpp/cvc5.h b/src/api/cpp/cvc5.h index 11b138a50..a221f3711 100644 --- a/src/api/cpp/cvc5.h +++ b/src/api/cpp/cvc5.h @@ -3950,6 +3950,22 @@ class CVC5_EXPORT Solver */ bool isModelCoreSymbol(const Term& v) const; + /** + * Get the model + * SMT-LIB: + * \verbatim + * ( get-model ) + * \endverbatim + * Requires to enable option 'produce-models'. + * @param sorts The list of uninterpreted sorts that should be printed in the + * model. + * @param vars The list of free constants that should be printed in the + * model. A subset of these may be printed based on isModelCoreSymbol. + * @return a string representing the model. + */ + std::string getModel(const std::vector& sorts, + const std::vector& vars) const; + /** * Do quantifier elimination. * SMT-LIB: @@ -4329,10 +4345,6 @@ class CVC5_EXPORT Solver */ std::vector getSynthSolutions(const std::vector& terms) const; - // !!! This is only temporarily available until the parser is fully migrated - // to the new API. !!! - SmtEngine* getSmtEngine(void) const; - /** * Returns a snapshot of the current state of the statistic values of this * solver. The returned object is completely decoupled from the solver and diff --git a/src/main/command_executor.h b/src/main/command_executor.h index 0a7a56e5b..1e8d848a4 100644 --- a/src/main/command_executor.h +++ b/src/main/command_executor.h @@ -27,10 +27,6 @@ namespace cvc5 { class Command; -namespace smt { -class SmtEngine; -} - namespace main { class CommandExecutor @@ -81,8 +77,6 @@ class CommandExecutor api::Result getResult() const { return d_result; } void reset(); - SmtEngine* getSmtEngine() const { return d_solver->getSmtEngine(); } - /** Store the current options as the original options */ void storeOptionsAsOriginal(); diff --git a/src/printer/ast/ast_printer.cpp b/src/printer/ast/ast_printer.cpp index 7c1a0e887..75219840a 100644 --- a/src/printer/ast/ast_printer.cpp +++ b/src/printer/ast/ast_printer.cpp @@ -141,16 +141,16 @@ void AstPrinter::toStream(std::ostream& out, const smt::Model& m) const } void AstPrinter::toStreamModelSort(std::ostream& out, - const smt::Model& m, - TypeNode tn) const + TypeNode tn, + const std::vector& elements) const { // shouldn't be called; only the non-Command* version above should be Unreachable(); } void AstPrinter::toStreamModelTerm(std::ostream& out, - const smt::Model& m, - Node n) const + const Node& n, + const Node& value) const { // shouldn't be called; only the non-Command* version above should be Unreachable(); diff --git a/src/printer/ast/ast_printer.h b/src/printer/ast/ast_printer.h index da5785f9f..fd4775da4 100644 --- a/src/printer/ast/ast_printer.h +++ b/src/printer/ast/ast_printer.h @@ -173,16 +173,16 @@ class AstPrinter : public cvc5::Printer * tn declared via declare-sort or declare-datatype. */ void toStreamModelSort(std::ostream& out, - const smt::Model& m, - TypeNode tn) const override; + TypeNode tn, + const std::vector& elements) const override; /** * To stream model term. This prints the appropriate output for term * n declared via declare-fun. */ void toStreamModelTerm(std::ostream& out, - const smt::Model& m, - Node n) const override; + const Node& n, + const Node& value) const override; /** * To stream with let binding. This prints n, possibly in the scope * of letification generated by this method based on lbind. diff --git a/src/printer/cvc/cvc_printer.cpp b/src/printer/cvc/cvc_printer.cpp index 1f1296b7f..04274ddc3 100644 --- a/src/printer/cvc/cvc_printer.cpp +++ b/src/printer/cvc/cvc_printer.cpp @@ -1029,8 +1029,8 @@ void CvcPrinter::toStream(std::ostream& out, const CommandStatus* s) const }/* CvcPrinter::toStream(CommandStatus*) */ void CvcPrinter::toStreamModelSort(std::ostream& out, - const smt::Model& m, - TypeNode tn) const + TypeNode tn, + const std::vector& elements) const { if (!tn.isSort()) { @@ -1038,36 +1038,25 @@ void CvcPrinter::toStreamModelSort(std::ostream& out, << tn << std::endl; return; } - const theory::TheoryModel* tm = m.getTheoryModel(); - const std::vector* type_reps = tm->getRepSet()->getTypeRepsOrNull(tn); - if (type_reps != nullptr) + out << "% cardinality of " << tn << " is " << elements.size() << std::endl; + toStreamCmdDeclareType(out, tn); + for (const Node& type_rep : elements) { - out << "% cardinality of " << tn << " is " << type_reps->size() - << std::endl; - toStreamCmdDeclareType(out, tn); - for (Node type_rep : *type_reps) + if (type_rep.isVar()) { - if (type_rep.isVar()) - { - out << type_rep << " : " << tn << ";" << std::endl; - } - else - { - out << "% rep: " << type_rep << std::endl; - } + out << type_rep << " : " << tn << ";" << std::endl; + } + else + { + out << "% rep: " << type_rep << std::endl; } - } - else - { - toStreamCmdDeclareType(out, tn); } } void CvcPrinter::toStreamModelTerm(std::ostream& out, - const smt::Model& m, - Node n) const + const Node& n, + const Node& value) const { - const theory::TheoryModel* tm = m.getTheoryModel(); TypeNode tn = n.getType(); out << n << " : "; if (tn.isFunction() || tn.isPredicate()) @@ -1087,10 +1076,7 @@ void CvcPrinter::toStreamModelTerm(std::ostream& out, { out << tn; } - // We get the value from the theory model directly, which notice - // does not have to go through the standard SmtEngine::getValue interface. - Node val = tm->getValue(n); - out << " = " << val << ";" << std::endl; + out << " = " << value << ";" << std::endl; } void CvcPrinter::toStream(std::ostream& out, const smt::Model& m) const diff --git a/src/printer/cvc/cvc_printer.h b/src/printer/cvc/cvc_printer.h index 555237177..4851868d3 100644 --- a/src/printer/cvc/cvc_printer.h +++ b/src/printer/cvc/cvc_printer.h @@ -175,19 +175,19 @@ class CvcPrinter : public cvc5::Printer LetBinding* lbind) const; /** * To stream model sort. This prints the appropriate output for type - * tn declared via declare-sort or declare-datatype. + * tn declared via declare-sort. */ void toStreamModelSort(std::ostream& out, - const smt::Model& m, - TypeNode tn) const override; + TypeNode tn, + const std::vector& elements) const override; /** * To stream model term. This prints the appropriate output for term * n declared via declare-fun. */ void toStreamModelTerm(std::ostream& out, - const smt::Model& m, - Node n) const override; + const Node& n, + const Node& value) const override; /** * To stream with let binding. This prints n, possibly in the scope * of letification generated by this method based on lbind. diff --git a/src/printer/printer.cpp b/src/printer/printer.cpp index 59122cf3d..e038952c4 100644 --- a/src/printer/printer.cpp +++ b/src/printer/printer.cpp @@ -66,22 +66,15 @@ void Printer::toStream(std::ostream& out, const smt::Model& m) const const std::vector& dsorts = m.getDeclaredSorts(); for (const TypeNode& tn : dsorts) { - toStreamModelSort(out, m, tn); + toStreamModelSort(out, tn, m.getDomainElements(tn)); } - // print the declared terms const std::vector& dterms = m.getDeclaredTerms(); for (const Node& n : dterms) { - // take into account model core, independently of the format - if (!m.isModelCoreSymbol(n)) - { - continue; - } - toStreamModelTerm(out, m, n); + toStreamModelTerm(out, n, m.getValue(n)); } - -}/* Printer::toStream(Model) */ +} void Printer::toStreamUsing(Language lang, std::ostream& out, diff --git a/src/printer/printer.h b/src/printer/printer.h index 05ac8879b..5e141fe8f 100644 --- a/src/printer/printer.h +++ b/src/printer/printer.h @@ -271,19 +271,19 @@ class Printer /** * To stream model sort. This prints the appropriate output for type - * tn declared via declare-sort or declare-datatype. + * tn declared via declare-sort. */ virtual void toStreamModelSort(std::ostream& out, - const smt::Model& m, - TypeNode tn) const = 0; + TypeNode tn, + const std::vector& elements) const = 0; /** * To stream model term. This prints the appropriate output for term * n declared via declare-fun. */ virtual void toStreamModelTerm(std::ostream& out, - const smt::Model& m, - Node n) const = 0; + const Node& n, + const Node& value) const = 0; /** write model response to command using another language printer */ void toStreamUsing(Language lang, diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 8a23a59ea..0d556c1dc 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1262,7 +1262,6 @@ void Smt2Printer::toStream(std::ostream& out, const UnsatCore& core) const void Smt2Printer::toStream(std::ostream& out, const smt::Model& m) const { - const theory::TheoryModel* tm = m.getTheoryModel(); //print the model out << "(" << endl; // don't need to print approximations since they are built into choice @@ -1271,7 +1270,7 @@ void Smt2Printer::toStream(std::ostream& out, const smt::Model& m) const out << ")" << endl; //print the heap model, if it exists Node h, neq; - if (tm->getHeapModel(h, neq)) + if (m.getHeapModel(h, neq)) { // description of the heap+what nil is equal to fully describes model out << "(heap" << endl; @@ -1282,8 +1281,8 @@ void Smt2Printer::toStream(std::ostream& out, const smt::Model& m) const } void Smt2Printer::toStreamModelSort(std::ostream& out, - const smt::Model& m, - TypeNode tn) const + TypeNode tn, + const std::vector& elements) const { if (!tn.isSort()) { @@ -1291,8 +1290,6 @@ void Smt2Printer::toStreamModelSort(std::ostream& out, << tn << std::endl; return; } - const theory::TheoryModel* tm = m.getTheoryModel(); - std::vector elements = tm->getDomainElements(tn); // print the cardinality out << "; cardinality of " << tn << " is " << elements.size() << endl; if (options::modelUninterpPrint() @@ -1322,26 +1319,22 @@ void Smt2Printer::toStreamModelSort(std::ostream& out, } void Smt2Printer::toStreamModelTerm(std::ostream& out, - const smt::Model& m, - Node n) const + const Node& n, + const Node& value) const { - const theory::TheoryModel* tm = m.getTheoryModel(); - // We get the value from the theory model directly, which notice - // does not have to go through the standard SmtEngine::getValue interface. - Node val = tm->getValue(n); - if (val.getKind() == kind::LAMBDA) + if (value.getKind() == kind::LAMBDA) { TypeNode rangeType = n.getType().getRangeType(); - out << "(define-fun " << n << " " << val[0] << " " << rangeType << " "; + out << "(define-fun " << n << " " << value[0] << " " << rangeType << " "; // call toStream and force its type to be proper - toStreamCastToType(out, val[1], -1, rangeType); + toStreamCastToType(out, value[1], -1, rangeType); out << ")" << endl; } else { out << "(define-fun " << n << " () " << n.getType() << " "; // call toStream and force its type to be proper - toStreamCastToType(out, val, -1, n.getType()); + toStreamCastToType(out, value, -1, n.getType()); out << ")" << endl; } } diff --git a/src/printer/smt2/smt2_printer.h b/src/printer/smt2/smt2_printer.h index 15f45a10e..729caebf4 100644 --- a/src/printer/smt2/smt2_printer.h +++ b/src/printer/smt2/smt2_printer.h @@ -252,16 +252,16 @@ class Smt2Printer : public cvc5::Printer * tn declared via declare-sort or declare-datatype. */ void toStreamModelSort(std::ostream& out, - const smt::Model& m, - TypeNode tn) const override; + TypeNode tn, + const std::vector& elements) const override; /** * To stream model term. This prints the appropriate output for term * n declared via declare-fun. */ void toStreamModelTerm(std::ostream& out, - const smt::Model& m, - Node n) const override; + const Node& n, + const Node& value) const override; /** * To stream with let binding. This prints n, possibly in the scope diff --git a/src/printer/tptp/tptp_printer.cpp b/src/printer/tptp/tptp_printer.cpp index bb8df120e..6c8746706 100644 --- a/src/printer/tptp/tptp_printer.cpp +++ b/src/printer/tptp/tptp_printer.cpp @@ -58,16 +58,16 @@ void TptpPrinter::toStream(std::ostream& out, const smt::Model& m) const } void TptpPrinter::toStreamModelSort(std::ostream& out, - const smt::Model& m, - TypeNode tn) const + TypeNode tn, + const std::vector& elements) const { // shouldn't be called; only the non-Command* version above should be Unreachable(); } void TptpPrinter::toStreamModelTerm(std::ostream& out, - const smt::Model& m, - Node n) const + const Node& n, + const Node& value) const { // shouldn't be called; only the non-Command* version above should be Unreachable(); diff --git a/src/printer/tptp/tptp_printer.h b/src/printer/tptp/tptp_printer.h index db86de8cf..9d288b4ed 100644 --- a/src/printer/tptp/tptp_printer.h +++ b/src/printer/tptp/tptp_printer.h @@ -48,16 +48,16 @@ class TptpPrinter : public cvc5::Printer * tn declared via declare-sort or declare-datatype. */ void toStreamModelSort(std::ostream& out, - const smt::Model& m, - TypeNode tn) const override; + TypeNode tn, + const std::vector& elements) const override; /** * To stream model term. This prints the appropriate output for term * n declared via declare-fun. */ void toStreamModelTerm(std::ostream& out, - const smt::Model& m, - Node n) const override; + const Node& n, + const Node& value) const override; }; /* class TptpPrinter */ diff --git a/src/smt/check_models.cpp b/src/smt/check_models.cpp index 0bc7ce99b..d3a2dfefa 100644 --- a/src/smt/check_models.cpp +++ b/src/smt/check_models.cpp @@ -18,13 +18,13 @@ #include "base/modal_exception.h" #include "options/smt_options.h" #include "smt/env.h" -#include "smt/model.h" #include "smt/node_command.h" #include "smt/preprocessor.h" #include "smt/smt_solver.h" #include "theory/rewriter.h" #include "theory/substitutions.h" #include "theory/theory_engine.h" +#include "theory/theory_model.h" using namespace cvc5::theory; @@ -34,7 +34,7 @@ namespace smt { CheckModels::CheckModels(Env& e) : d_env(e) {} CheckModels::~CheckModels() {} -void CheckModels::checkModel(Model* m, +void CheckModels::checkModel(TheoryModel* m, context::CDList* al, bool hardFailure) { diff --git a/src/smt/check_models.h b/src/smt/check_models.h index ce06bae07..fbfb1c2f5 100644 --- a/src/smt/check_models.h +++ b/src/smt/check_models.h @@ -25,9 +25,11 @@ namespace cvc5 { class Env; -namespace smt { +namespace theory { +class TheoryModel; +} -class Model; +namespace smt { /** * This utility is responsible for checking the current model. @@ -43,7 +45,9 @@ class CheckModels * This throws an exception if we fail to verify that m is a proper model * given assertion list al based on the model checking policy. */ - void checkModel(Model* m, context::CDList* al, bool hardFailure); + void checkModel(theory::TheoryModel* m, + context::CDList* al, + bool hardFailure); private: /** Reference to the environment */ diff --git a/src/smt/command.cpp b/src/smt/command.cpp index e6be0a646..008d7a6d8 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -38,8 +38,6 @@ #include "proof/unsat_core.h" #include "smt/dump.h" #include "smt/model.h" -#include "smt/smt_engine.h" -#include "smt/smt_engine_scope.h" #include "util/unsafe_interrupt_exception.h" #include "util/utility.h" @@ -1748,27 +1746,17 @@ void GetAssignmentCommand::toStream(std::ostream& out, /* class GetModelCommand */ /* -------------------------------------------------------------------------- */ -GetModelCommand::GetModelCommand() : d_result(nullptr) {} +GetModelCommand::GetModelCommand() {} void GetModelCommand::invoke(api::Solver* solver, SymbolManager* sm) { try { - d_result = solver->getSmtEngine()->getModel(); - // set the model declarations, which determines what is printed in the model - d_result->clearModelDeclarations(); std::vector declareSorts = sm->getModelDeclareSorts(); - for (const api::Sort& s : declareSorts) - { - d_result->addDeclarationSort(sortToTypeNode(s)); - } std::vector declareTerms = sm->getModelDeclareTerms(); - for (const api::Term& t : declareTerms) - { - d_result->addDeclarationTerm(termToNode(t)); - } + d_result = solver->getModel(declareSorts, declareTerms); d_commandStatus = CommandSuccess::instance(); } - catch (RecoverableModalException& e) + catch (api::CVC5ApiRecoverableException& e) { d_commandStatus = new CommandRecoverableFailure(e.what()); } @@ -1782,12 +1770,6 @@ void GetModelCommand::invoke(api::Solver* solver, SymbolManager* sm) } } -/* Model is private to the library -- for now -Model* GetModelCommand::getResult() const { - return d_result; -} -*/ - void GetModelCommand::printResult(std::ostream& out, uint32_t verbosity) const { if (!ok()) @@ -1796,13 +1778,13 @@ void GetModelCommand::printResult(std::ostream& out, uint32_t verbosity) const } else { - out << *d_result; + out << d_result; } } Command* GetModelCommand::clone() const { - GetModelCommand* c = new GetModelCommand(); + GetModelCommand* c = new GetModelCommand; c->d_result = d_result; return c; } diff --git a/src/smt/command.h b/src/smt/command.h index d3e3679d2..627cb13c9 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -950,7 +950,8 @@ class CVC5_EXPORT GetModelCommand : public Command Language language = Language::LANG_AUTO) const override; protected: - smt::Model* d_result; + /** Result of printing the model */ + std::string d_result; }; /* class GetModelCommand */ /** The command to block models. */ diff --git a/src/smt/model.cpp b/src/smt/model.cpp index cf6a90f12..9a195cb51 100644 --- a/src/smt/model.cpp +++ b/src/smt/model.cpp @@ -18,18 +18,13 @@ #include "expr/expr_iomanip.h" #include "options/base_options.h" #include "printer/printer.h" -#include "smt/dump_manager.h" -#include "smt/node_command.h" -#include "smt/smt_engine.h" -#include "smt/smt_engine_scope.h" -#include "theory/theory_model.h" namespace cvc5 { namespace smt { -Model::Model(theory::TheoryModel* tm) : d_isKnownSat(false), d_tmodel(tm) +Model::Model(bool isKnownSat, const std::string& inputName) + : d_inputName(inputName), d_isKnownSat(isKnownSat) { - Assert(d_tmodel != nullptr); } std::ostream& operator<<(std::ostream& out, const Model& m) { @@ -38,31 +33,55 @@ std::ostream& operator<<(std::ostream& out, const Model& m) { return out; } -theory::TheoryModel* Model::getTheoryModel() { return d_tmodel; } +const std::vector& Model::getDomainElements(TypeNode tn) const +{ + std::map>::const_iterator it = + d_domainElements.find(tn); + Assert(it != d_domainElements.end()); + return it->second; +} -const theory::TheoryModel* Model::getTheoryModel() const { return d_tmodel; } +Node Model::getValue(TNode n) const +{ + std::map::const_iterator it = d_declareTermValues.find(n); + Assert(it != d_declareTermValues.end()); + return it->second; +} -bool Model::isModelCoreSymbol(TNode sym) const +bool Model::getHeapModel(Node& h, Node& nilEq) const { - return d_tmodel->isModelCoreSymbol(sym); + if (d_sepHeap.isNull() || d_sepNilEq.isNull()) + { + return false; + } + h = d_sepHeap; + nilEq = d_sepNilEq; + return true; } -Node Model::getValue(TNode n) const { return d_tmodel->getValue(n); } -bool Model::hasApproximations() const { return d_tmodel->hasApproximations(); } +void Model::addDeclarationSort(TypeNode tn, const std::vector& elements) +{ + d_declareSorts.push_back(tn); + d_domainElements[tn] = elements; +} -void Model::clearModelDeclarations() +void Model::addDeclarationTerm(Node n, Node value) { - d_declareTerms.clear(); - d_declareSorts.clear(); + d_declareTerms.push_back(n); + d_declareTermValues[n] = value; } -void Model::addDeclarationSort(TypeNode tn) { d_declareSorts.push_back(tn); } +void Model::setHeapModel(Node h, Node nilEq) +{ + d_sepHeap = h; + d_sepNilEq = nilEq; +} -void Model::addDeclarationTerm(Node n) { d_declareTerms.push_back(n); } const std::vector& Model::getDeclaredSorts() const { return d_declareSorts; } + const std::vector& Model::getDeclaredTerms() const { return d_declareTerms; diff --git a/src/smt/model.h b/src/smt/model.h index 342a9f3b0..5275ea680 100644 --- a/src/smt/model.h +++ b/src/smt/model.h @@ -15,8 +15,8 @@ #include "cvc5_private.h" -#ifndef CVC5__MODEL_H -#define CVC5__MODEL_H +#ifndef CVC5__SMT__MODEL_H +#define CVC5__SMT__MODEL_H #include #include @@ -24,13 +24,6 @@ #include "expr/node.h" namespace cvc5 { - -class SmtEngine; - -namespace theory { -class TheoryModel; -} - namespace smt { class Model; @@ -38,22 +31,15 @@ class Model; std::ostream& operator<<(std::ostream&, const Model&); /** - * This is the SMT-level model object, that is responsible for maintaining - * the necessary information for how to print the model, as well as - * holding a pointer to the underlying implementation of the theory model. - * - * The model declarations maintained by this class are context-independent - * and should be updated when this model is printed. + * A utility for representing a model for pretty printing. */ class Model { - friend std::ostream& operator<<(std::ostream&, const Model&); - friend class ::cvc5::SmtEngine; - public: - /** construct */ - Model(theory::TheoryModel* tm); - /** virtual destructor */ - ~Model() {} + /** Constructor + * @param isKnownSat True if this model is associated with a "sat" response, + * or false if it is associated with an "unknown" response. + */ + Model(bool isKnownSat, const std::string& inputName); /** get the input name (file name, etc.) this model is associated to */ std::string getInputName() const { return d_inputName; } /** @@ -63,31 +49,37 @@ class Model { * only a candidate solution. */ bool isKnownSat() const { return d_isKnownSat; } - /** Get the underlying theory model */ - theory::TheoryModel* getTheoryModel(); - /** Get the underlying theory model (const version) */ - const theory::TheoryModel* getTheoryModel() const; - //----------------------- helper methods in the underlying theory model - /** Is the node n a model core symbol? */ - bool isModelCoreSymbol(TNode sym) const; + /** Get domain elements */ + const std::vector& getDomainElements(TypeNode tn) const; /** Get value */ Node getValue(TNode n) const; - /** Does this model have approximations? */ - bool hasApproximations() const; - //----------------------- end helper methods + /** Get separation logic heap and nil, return true if they have been set */ + bool getHeapModel(Node& h, Node& nilEq) const; //----------------------- model declarations - /** Clear the current model declarations. */ - void clearModelDeclarations(); /** * Set that tn is a sort that should be printed in the model, when applicable, * based on the output language. + * + * @param tn The uninterpreted sort + * @param elements The domain elements of tn in the model */ - void addDeclarationSort(TypeNode tn); + void addDeclarationSort(TypeNode tn, const std::vector& elements); /** * Set that n is a variable that should be printed in the model, when * applicable, based on the output language. + * + * @param n The variable + * @param value The value of the variable in the model + */ + void addDeclarationTerm(Node n, Node value); + /** + * Set the separation logic model information where h is the heap and nilEq + * is the value of sep.nil. + * + * @param h The value of heap in the heap model + * @param nilEq The value of sep.nil in the heap model */ - void addDeclarationTerm(Node n); + void setHeapModel(Node h, Node nilEq); /** get declared sorts */ const std::vector& getDeclaredSorts() const; /** get declared terms */ @@ -101,24 +93,26 @@ class Model { * from the solver. */ bool d_isKnownSat; - /** - * Pointer to the underlying theory model, which maintains all data regarding - * the values of sorts and terms. - */ - theory::TheoryModel* d_tmodel; /** * The list of types to print, generally corresponding to declare-sort * commands. */ std::vector d_declareSorts; + /** The interpretation of the above sorts, as a list of domain elements. */ + std::map> d_domainElements; /** * The list of terms to print, is typically one-to-one with declare-fun * commands. */ std::vector d_declareTerms; + /** Mapping terms to values */ + std::map d_declareTermValues; + /** Separation logic heap and nil */ + Node d_sepHeap; + Node d_sepNilEq; }; } // namespace smt } // namespace cvc5 -#endif /* CVC5__MODEL_H */ +#endif /* CVC5__SMT__MODEL_H */ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 2276956b5..27e7b8530 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -92,7 +92,6 @@ SmtEngine::SmtEngine(NodeManager* nm, const Options* optr) d_routListener(new ResourceOutListener(*this)), d_snmListener(new SmtNodeManagerListener(*getDumpManager(), d_outMgr)), d_smtSolver(nullptr), - d_model(nullptr), d_checkModels(nullptr), d_pfManager(nullptr), d_ucManager(nullptr), @@ -224,7 +223,6 @@ void SmtEngine::finishInit() TheoryModel* tm = te->getModel(); if (tm != nullptr) { - d_model.reset(new Model(tm)); // make the check models utility d_checkModels.reset(new CheckModels(*d_env.get())); } @@ -305,7 +303,6 @@ SmtEngine::~SmtEngine() d_absValues.reset(nullptr); d_asserts.reset(nullptr); - d_model.reset(nullptr); d_abductSolver.reset(nullptr); d_interpolSolver.reset(nullptr); @@ -712,7 +709,7 @@ Result SmtEngine::quickCheck() { Result::ENTAILMENT_UNKNOWN, Result::REQUIRES_FULL_CHECK, filename); } -Model* SmtEngine::getAvailableModel(const char* c) const +TheoryModel* SmtEngine::getAvailableModel(const char* c) const { if (!d_env->getOptions().theory.assignFunctionValues) { @@ -751,7 +748,7 @@ Model* SmtEngine::getAvailableModel(const char* c) const throw RecoverableModalException(ss.str().c_str()); } - return d_model.get(); + return m; } QuantifiersEngine* SmtEngine::getAvailableQuantifiersEngine(const char* c) const @@ -1121,7 +1118,7 @@ Node SmtEngine::getValue(const Node& ex) const } Trace("smt") << "--- getting value of " << n << endl; - Model* m = getAvailableModel("get-value"); + TheoryModel* m = getAvailableModel("get-value"); Assert(m != nullptr); Node resultNode = m->getValue(n); Trace("smt") << "--- got value " << n << " = " << resultNode << endl; @@ -1163,8 +1160,8 @@ std::vector SmtEngine::getValues(const std::vector& exprs) const std::vector SmtEngine::getModelDomainElements(TypeNode tn) const { Assert(tn.isSort()); - Model* m = getAvailableModel("getModelDomainElements"); - return m->getTheoryModel()->getDomainElements(tn); + TheoryModel* m = getAvailableModel("getModelDomainElements"); + return m->getDomainElements(tn); } bool SmtEngine::isModelCoreSymbol(Node n) @@ -1177,8 +1174,7 @@ bool SmtEngine::isModelCoreSymbol(Node n) // if the model core mode is none, we are always a model core symbol return true; } - Model* m = getAvailableModel("isModelCoreSymbol"); - TheoryModel* tm = m->getTheoryModel(); + TheoryModel* tm = getAvailableModel("isModelCoreSymbol"); // compute the model core if not done so already if (!tm->isUsingModelCore()) { @@ -1193,41 +1189,54 @@ bool SmtEngine::isModelCoreSymbol(Node n) return tm->isModelCoreSymbol(n); } -// TODO(#1108): Simplify the error reporting of this method. -Model* SmtEngine::getModel() { - Trace("smt") << "SMT getModel()" << endl; +std::string SmtEngine::getModel(const std::vector& declaredSorts, + const std::vector& declaredFuns) +{ SmtScope smts(this); - - finishInit(); - - if (Dump.isOn("benchmark")) + // !!! Note that all methods called here should have a version at the API + // level. This is to ensure that the information associated with a model is + // completely accessible by the user. This is currently not rigorously + // enforced. An alternative design would be to have this method implemented + // at the API level, but this makes exceptions in the text interface less + // intuitive and makes it impossible to implement raw-benchmark at the + // SmtEngine level. + if (Dump.isOn("raw-benchmark")) { getPrinter().toStreamCmdGetModel(d_env->getDumpOut()); } - - Model* m = getAvailableModel("get model"); - - // Notice that the returned model is (currently) accessed by the - // GetModelCommand only, and is not returned to the user. The information - // in that model may become stale after it is returned. This is safe - // since GetModelCommand always calls this command again when it prints - // a model. - - if (d_env->getOptions().smt.modelCoresMode - != options::ModelCoresMode::NONE) + TheoryModel* tm = getAvailableModel("get model"); + // use the smt::Model model utility for printing + const Options& opts = d_env->getOptions(); + bool isKnownSat = (d_state->getMode() == SmtMode::SAT); + Model m(isKnownSat, opts.driver.filename); + // set the model declarations, which determines what is printed in the model + for (const TypeNode& tn : declaredSorts) { - // If we enabled model cores, we compute a model core for m based on our - // (expanded) assertions using the model core builder utility - std::vector asserts = getAssertionsInternal(); - d_pp->expandDefinitions(asserts); - ModelCoreBuilder::setModelCore( - asserts, m->getTheoryModel(), d_env->getOptions().smt.modelCoresMode); + m.addDeclarationSort(tn, getModelDomainElements(tn)); } - // set the information on the SMT-level model - Assert(m != nullptr); - m->d_inputName = d_env->getOptions().driver.filename; - m->d_isKnownSat = (d_state->getMode() == SmtMode::SAT); - return m; + bool usingModelCores = + (opts.smt.modelCoresMode != options::ModelCoresMode::NONE); + for (const Node& n : declaredFuns) + { + if (usingModelCores && !tm->isModelCoreSymbol(n)) + { + // skip if not in model core + continue; + } + Node value = tm->getValue(n); + m.addDeclarationTerm(n, value); + } + // for separation logic + TypeNode locT, dataT; + if (getSepHeapTypes(locT, dataT)) + { + std::pair sh = getSepHeapAndNilExpr(); + m.setHeapModel(sh.first, sh.second); + } + // print the model + std::stringstream ssm; + ssm << m; + return ssm.str(); } Result SmtEngine::blockModel() @@ -1242,7 +1251,7 @@ Result SmtEngine::blockModel() getPrinter().toStreamCmdBlockModel(d_env->getDumpOut()); } - Model* m = getAvailableModel("block model"); + TheoryModel* m = getAvailableModel("block model"); if (d_env->getOptions().smt.blockModelsMode == options::BlockModelsMode::NONE) @@ -1254,10 +1263,8 @@ Result SmtEngine::blockModel() // get expanded assertions std::vector eassertsProc = getExpandedAssertions(); - Node eblocker = - ModelBlocker::getModelBlocker(eassertsProc, - m->getTheoryModel(), - d_env->getOptions().smt.blockModelsMode); + Node eblocker = ModelBlocker::getModelBlocker( + eassertsProc, m, d_env->getOptions().smt.blockModelsMode); Trace("smt") << "Block formula: " << eblocker << std::endl; return assertFormula(eblocker); } @@ -1274,16 +1281,13 @@ Result SmtEngine::blockModelValues(const std::vector& exprs) getPrinter().toStreamCmdBlockModelValues(d_env->getDumpOut(), exprs); } - Model* m = getAvailableModel("block model values"); + TheoryModel* m = getAvailableModel("block model values"); // get expanded assertions std::vector eassertsProc = getExpandedAssertions(); // we always do block model values mode here - Node eblocker = - ModelBlocker::getModelBlocker(eassertsProc, - m->getTheoryModel(), - options::BlockModelsMode::VALUES, - exprs); + Node eblocker = ModelBlocker::getModelBlocker( + eassertsProc, m, options::BlockModelsMode::VALUES, exprs); return assertFormula(eblocker); } @@ -1299,8 +1303,7 @@ std::pair SmtEngine::getSepHeapAndNilExpr(void) NodeManagerScope nms(getNodeManager()); Node heap; Node nil; - Model* m = getAvailableModel("get separation logic heap and nil"); - TheoryModel* tm = m->getTheoryModel(); + TheoryModel* tm = getAvailableModel("get separation logic heap and nil"); if (!tm->getHeapModel(heap, nil)) { const char* msg = @@ -1548,7 +1551,7 @@ void SmtEngine::checkModel(bool hardFailure) { TimerStat::CodeTimer checkModelTimer(d_stats->d_checkModelTime); Notice() << "SmtEngine::checkModel(): generating model" << endl; - Model* m = getAvailableModel("check model"); + TheoryModel* m = getAvailableModel("check model"); Assert(m != nullptr); // check the model with the theory engine for debugging diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 84501d35e..06a1c9ae4 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -42,7 +42,6 @@ class Env; class NodeManager; class TheoryEngine; class UnsatCore; -class LogicRequest; class StatisticsRegistry; class Printer; class ResourceManager; @@ -77,7 +76,6 @@ namespace prop { namespace smt { /** Utilities */ -class Model; class SmtEngineState; class AbstractValues; class Assertions; @@ -104,9 +102,10 @@ class UnsatCoreManager; /* -------------------------------------------------------------------------- */ namespace theory { - class Rewriter; - class QuantifiersEngine; - } // namespace theory +class TheoryModel; +class Rewriter; +class QuantifiersEngine; +} // namespace theory /* -------------------------------------------------------------------------- */ @@ -115,7 +114,6 @@ class CVC5_EXPORT SmtEngine friend class ::cvc5::api::Solver; friend class ::cvc5::smt::SmtEngineState; friend class ::cvc5::smt::SmtScope; - friend class ::cvc5::LogicRequest; /* ....................................................................... */ public: @@ -226,14 +224,6 @@ class CVC5_EXPORT SmtEngine /** Is this an internal subsolver? */ bool isInternalSubsolver() const; - /** - * Get the model (only if immediately preceded by a SAT or NOT_ENTAILED - * query). Only permitted if produce-models is on. - * - * TODO (issues#287): eliminate this method. - */ - smt::Model* getModel(); - /** * Block the current model. Can be called only if immediately preceded by * a SAT or INVALID query. Only permitted if produce-models is on, and the @@ -523,6 +513,19 @@ class CVC5_EXPORT SmtEngine */ bool isModelCoreSymbol(Node v); + /** + * Get a model (only if immediately preceded by an SAT or unknown query). + * Only permitted if the model option is on. + * + * @param declaredSorts The sorts to print in the model + * @param declaredFuns The free constants to print in the model. A subset + * of these may be printed based on isModelCoreSymbol. + * @return the string corresponding to the model. If the output language is + * smt2, then this corresponds to a response to the get-model command. + */ + std::string getModel(const std::vector& declaredSorts, + const std::vector& declaredFuns); + /** print instantiations * * Print all instantiations for all quantified formulas on out, @@ -936,7 +939,7 @@ class CVC5_EXPORT SmtEngine * @param c used for giving an error message to indicate the context * this method was called. */ - smt::Model* getAvailableModel(const char* c) const; + theory::TheoryModel* getAvailableModel(const char* c) const; /** * Get available quantifiers engine, which throws a modal exception if it * does not exist. This can happen if a quantifiers-specific call (e.g. @@ -1046,13 +1049,6 @@ class CVC5_EXPORT SmtEngine /** The SMT solver */ std::unique_ptr d_smtSolver; - /** - * The SMT-level model object, which contains information about how to - * print the model, as well as a pointer to the underlying TheoryModel - * implementation maintained by the SmtSolver. - */ - std::unique_ptr d_model; - /** * The utility used for checking models */ diff --git a/test/regress/regress0/cvc-rerror-print.cvc b/test/regress/regress0/cvc-rerror-print.cvc index e134b5666..728db28d8 100644 --- a/test/regress/regress0/cvc-rerror-print.cvc +++ b/test/regress/regress0/cvc-rerror-print.cvc @@ -1,5 +1,5 @@ % EXPECT: entailed -% EXPECT: Cannot get model unless immediately preceded by SAT/NOT_ENTAILED or UNKNOWN response. +% EXPECT: Cannot get model unless after a SAT or unknown response. OPTION "logic" "ALL"; OPTION "produce-models" true; x : INT; diff --git a/test/unit/api/solver_black.cpp b/test/unit/api/solver_black.cpp index 1daa3fba4..5ca96f035 100644 --- a/test/unit/api/solver_black.cpp +++ b/test/unit/api/solver_black.cpp @@ -1571,6 +1571,47 @@ TEST_F(TestApiBlackSolver, isModelCoreSymbol) ASSERT_THROW(d_solver.isModelCoreSymbol(zero), CVC5ApiException); } +TEST_F(TestApiBlackSolver, getModel) +{ + d_solver.setOption("produce-models", "true"); + Sort uSort = d_solver.mkUninterpretedSort("u"); + Term x = d_solver.mkConst(uSort, "x"); + Term y = d_solver.mkConst(uSort, "y"); + Term z = d_solver.mkConst(uSort, "z"); + Term f = d_solver.mkTerm(NOT, d_solver.mkTerm(EQUAL, x, y)); + d_solver.assertFormula(f); + d_solver.checkSat(); + std::vector sorts; + sorts.push_back(uSort); + std::vector terms; + terms.push_back(x); + terms.push_back(y); + ASSERT_NO_THROW(d_solver.getModel(sorts, terms)); + Term null; + terms.push_back(null); + ASSERT_THROW(d_solver.getModel(sorts, terms), CVC5ApiException); +} + +TEST_F(TestApiBlackSolver, getModel2) +{ + d_solver.setOption("produce-models", "true"); + std::vector sorts; + std::vector terms; + ASSERT_THROW(d_solver.getModel(sorts, terms), CVC5ApiException); +} + +TEST_F(TestApiBlackSolver, getModel3) +{ + d_solver.setOption("produce-models", "true"); + std::vector sorts; + std::vector terms; + d_solver.checkSat(); + ASSERT_NO_THROW(d_solver.getModel(sorts, terms)); + Sort integer = d_solver.getIntegerSort(); + sorts.push_back(integer); + ASSERT_THROW(d_solver.getModel(sorts, terms), CVC5ApiException); +} + TEST_F(TestApiBlackSolver, getQuantifierElimination) { Term x = d_solver.mkVar(d_solver.getBooleanSort(), "x"); -- 2.30.2