From: Andrew Reynolds Date: Wed, 25 Nov 2020 16:46:41 +0000 (-0600) Subject: Use symbol manager for printing responses get-model (#5516) X-Git-Tag: cvc5-1.0.0~2553 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=de14432ebd850dab001bb860db102e86ec734f97;p=cvc5.git Use symbol manager for printing responses get-model (#5516) This makes symbol manager be in charge of determining which sorts and terms to print in response to get-model. This eliminates the need for the parser to call ExprManager::mkVar (and similar methods) with custom flags. This requires significant simplifications to the printers for models, where instead of a NodeCommand, we take a Sort or Term to print in the model. This is one of the last remaining steps for migrating the parser to the new API. The next step will be to remove a lot of the internal infrastructure for managing expression names, commands to print in models, node commands, node listeners, etc. --- diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index 1fc995fd6..1ca2e1c01 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -333,8 +333,7 @@ void Parser::defineParameterizedType(const std::string& name, api::Sort Parser::mkSort(const std::string& name, uint32_t flags) { Debug("parser") << "newSort(" << name << ")" << std::endl; - api::Sort type = - api::Sort(d_solver, d_solver->getExprManager()->mkSort(name, flags)); + api::Sort type = d_solver->mkUninterpretedSort(name); bool globalDecls = d_symman->getGlobalDeclarations(); defineType( name, type, globalDecls && !(flags & ExprManager::SORT_FLAG_PLACEHOLDER)); @@ -347,9 +346,7 @@ api::Sort Parser::mkSortConstructor(const std::string& name, { Debug("parser") << "newSortConstructor(" << name << ", " << arity << ")" << std::endl; - api::Sort type = api::Sort( - d_solver, - d_solver->getExprManager()->mkSortConstructor(name, arity, flags)); + api::Sort type = d_solver->mkSortConstructorSort(name, arity); bool globalDecls = d_symman->getGlobalDeclarations(); defineType(name, vector(arity), @@ -379,10 +376,7 @@ api::Sort Parser::mkUnresolvedTypeConstructor( { Debug("parser") << "newSortConstructor(P)(" << name << ", " << params.size() << ")" << std::endl; - api::Sort unresolved = - api::Sort(d_solver, - d_solver->getExprManager()->mkSortConstructor( - name, params.size(), ExprManager::SORT_FLAG_PLACEHOLDER)); + api::Sort unresolved = d_solver->mkSortConstructorSort(name, params.size()); defineType(name, params, unresolved); api::Sort t = getSort(name, params); d_unresolved.insert(unresolved); @@ -644,8 +638,7 @@ api::Term Parser::mkVar(const std::string& name, const api::Sort& type, uint32_t flags) { - return api::Term( - d_solver, d_solver->getExprManager()->mkVar(name, type.getType(), flags)); + return d_solver->mkConst(type, name); } //!!!!!!!!!!! temporary diff --git a/src/printer/ast/ast_printer.cpp b/src/printer/ast/ast_printer.cpp index 8bf3bd24e..4b9371181 100644 --- a/src/printer/ast/ast_printer.cpp +++ b/src/printer/ast/ast_printer.cpp @@ -148,9 +148,17 @@ void AstPrinter::toStream(std::ostream& out, const smt::Model& m) const out << "Model()"; } -void AstPrinter::toStream(std::ostream& out, - const smt::Model& m, - const NodeCommand* c) const +void AstPrinter::toStreamModelSort(std::ostream& out, + const smt::Model& m, + TypeNode tn) const +{ + // shouldn't be called; only the non-Command* version above should be + Unreachable(); +} + +void AstPrinter::toStreamModelTerm(std::ostream& out, + const smt::Model& m, + Node n) const { // shouldn't be called; only the non-Command* version above should be Unreachable(); @@ -272,12 +280,9 @@ void AstPrinter::toStreamCmdDefineFunction(std::ostream& out, } void AstPrinter::toStreamCmdDeclareType(std::ostream& out, - const std::string& id, - size_t arity, TypeNode type) const { - out << "DeclareType(" << id << "," << arity << "," << type << ')' - << std::endl; + out << "DeclareType(" << type << ')' << std::endl; } void AstPrinter::toStreamCmdDefineType(std::ostream& out, diff --git a/src/printer/ast/ast_printer.h b/src/printer/ast/ast_printer.h index ad20ffb79..e4251eba0 100644 --- a/src/printer/ast/ast_printer.h +++ b/src/printer/ast/ast_printer.h @@ -62,8 +62,6 @@ class AstPrinter : public CVC4::Printer /** Print declare-sort command */ void toStreamCmdDeclareType(std::ostream& out, - const std::string& id, - size_t arity, TypeNode type) const override; /** Print define-sort command */ @@ -165,9 +163,21 @@ class AstPrinter : public CVC4::Printer private: void toStream(std::ostream& out, TNode n, int toDepth) const; - void toStream(std::ostream& out, - const smt::Model& m, - const NodeCommand* c) const override; + /** + * To stream model sort. This prints the appropriate output for type + * tn declared via declare-sort or declare-datatype. + */ + void toStreamModelSort(std::ostream& out, + const smt::Model& m, + TypeNode tn) const override; + + /** + * To stream model term. This prints the appropriate output for term + * n declared via declare-fun. + */ + void toStreamModelTerm(std::ostream& out, + const smt::Model& m, + Node n) const override; }; /* class AstPrinter */ } // namespace ast diff --git a/src/printer/cvc/cvc_printer.cpp b/src/printer/cvc/cvc_printer.cpp index 44ff7be10..be530099b 100644 --- a/src/printer/cvc/cvc_printer.cpp +++ b/src/printer/cvc/cvc_printer.cpp @@ -1067,20 +1067,23 @@ void CvcPrinter::toStream(std::ostream& out, const CommandStatus* s) const }/* CvcPrinter::toStream(CommandStatus*) */ -namespace { - -void DeclareTypeNodeCommandToStream(std::ostream& out, - const theory::TheoryModel& model, - const DeclareTypeNodeCommand& command) +void CvcPrinter::toStreamModelSort(std::ostream& out, + const smt::Model& m, + TypeNode tn) const { - TypeNode type_node = command.getType(); - const std::vector* type_reps = - model.getRepSet()->getTypeRepsOrNull(type_node); + if (!tn.isSort()) + { + out << "ERROR: don't know how to print a non uninterpreted sort in model: " + << tn << std::endl; + return; + } + const theory::TheoryModel* tm = m.getTheoryModel(); + const std::vector* type_reps = tm->getRepSet()->getTypeRepsOrNull(tn); if (options::modelUninterpPrint() == options::ModelUninterpPrintMode::DtEnum - && type_node.isSort() && type_reps != nullptr) + && type_reps != nullptr) { out << "DATATYPE" << std::endl; - out << " " << command.getSymbol() << " = "; + out << " " << tn << " = "; for (size_t i = 0; i < type_reps->size(); i++) { if (i > 0) @@ -1091,16 +1094,16 @@ void DeclareTypeNodeCommandToStream(std::ostream& out, } out << std::endl << "END;" << std::endl; } - else if (type_node.isSort() && type_reps != nullptr) + else if (type_reps != nullptr) { - out << "% cardinality of " << type_node << " is " << type_reps->size() + out << "% cardinality of " << tn << " is " << type_reps->size() << std::endl; - out << command << std::endl; + toStreamCmdDeclareType(out, tn); for (Node type_rep : *type_reps) { if (type_rep.isVar()) { - out << type_rep << " : " << type_node << ";" << std::endl; + out << type_rep << " : " << tn << ";" << std::endl; } else { @@ -1110,21 +1113,15 @@ void DeclareTypeNodeCommandToStream(std::ostream& out, } else { - out << command << std::endl; + toStreamCmdDeclareType(out, tn); } } -void DeclareFunctionNodeCommandToStream( - std::ostream& out, - const theory::TheoryModel& model, - const DeclareFunctionNodeCommand& command) +void CvcPrinter::toStreamModelTerm(std::ostream& out, + const smt::Model& m, + Node n) const { - Node n = command.getFunction(); - if (n.getKind() == kind::SKOLEM) - { - // don't print out internal stuff - return; - } + const theory::TheoryModel* tm = m.getTheoryModel(); TypeNode tn = n.getType(); out << n << " : "; if (tn.isFunction() || tn.isPredicate()) @@ -1146,15 +1143,16 @@ void DeclareFunctionNodeCommandToStream( } // We get the value from the theory model directly, which notice // does not have to go through the standard SmtEngine::getValue interface. - Node val = model.getValue(n); + Node val = tm->getValue(n); if (options::modelUninterpPrint() == options::ModelUninterpPrintMode::DtEnum && val.getKind() == kind::STORE) { TypeNode type_node = val[1].getType(); if (tn.isSort()) { - if (const std::vector* type_reps = - model.getRepSet()->getTypeRepsOrNull(type_node)) + const std::vector* type_reps = + tm->getRepSet()->getTypeRepsOrNull(type_node); + if (type_reps != nullptr) { Cardinality indexCard(type_reps->size()); val = theory::arrays::TheoryArraysRewriter::normalizeConstant( @@ -1165,8 +1163,6 @@ void DeclareFunctionNodeCommandToStream( out << " = " << val << ";" << std::endl; } -} // namespace - void CvcPrinter::toStream(std::ostream& out, const smt::Model& m) const { const theory::TheoryModel* tm = m.getTheoryModel(); @@ -1185,28 +1181,6 @@ void CvcPrinter::toStream(std::ostream& out, const smt::Model& m) const out << "MODEL END;" << std::endl; } -void CvcPrinter::toStream(std::ostream& out, - const smt::Model& model, - const NodeCommand* command) const -{ - const auto* theory_model = model.getTheoryModel(); - AlwaysAssert(theory_model != nullptr); - if (const auto* declare_type_command = - dynamic_cast(command)) - { - DeclareTypeNodeCommandToStream(out, *theory_model, *declare_type_command); - } - else if (const auto* dfc = - dynamic_cast(command)) - { - DeclareFunctionNodeCommandToStream(out, *theory_model, *dfc); - } - else - { - out << *command << std::endl; - } -} - void CvcPrinter::toStreamCmdAssert(std::ostream& out, Node n) const { out << "ASSERT " << n << ';' << std::endl; @@ -1322,6 +1296,7 @@ void CvcPrinter::toStreamCmdDeclarationSequence( { DeclarationDefinitionCommand* dd = static_cast(*i++); + Assert(dd != nullptr); if (i != sequence.cend()) { out << dd->getSymbol() << ", "; @@ -1376,20 +1351,18 @@ void CvcPrinter::toStreamCmdDefineFunction(std::ostream& out, } void CvcPrinter::toStreamCmdDeclareType(std::ostream& out, - const std::string& id, - size_t arity, TypeNode type) const { + size_t arity = type.isSortConstructor() ? type.getSortConstructorArity() : 0; if (arity > 0) { - // TODO? out << "ERROR: Don't know how to print parameterized type declaration " "in CVC language." << std::endl; } else { - out << id << " : TYPE;" << std::endl; + out << type << " : TYPE;" << std::endl; } } diff --git a/src/printer/cvc/cvc_printer.h b/src/printer/cvc/cvc_printer.h index ee4750a61..b0328bc3c 100644 --- a/src/printer/cvc/cvc_printer.h +++ b/src/printer/cvc/cvc_printer.h @@ -63,8 +63,6 @@ class CvcPrinter : public CVC4::Printer /** Print declare-sort command */ void toStreamCmdDeclareType(std::ostream& out, - const std::string& id, - size_t arity, TypeNode type) const override; /** Print define-sort command */ @@ -166,9 +164,21 @@ class CvcPrinter : public CVC4::Printer private: void toStream(std::ostream& out, TNode n, int toDepth, bool bracket) const; - void toStream(std::ostream& out, - const smt::Model& m, - const NodeCommand* c) const override; + /** + * To stream model sort. This prints the appropriate output for type + * tn declared via declare-sort or declare-datatype. + */ + void toStreamModelSort(std::ostream& out, + const smt::Model& m, + TypeNode tn) const override; + + /** + * To stream model term. This prints the appropriate output for term + * n declared via declare-fun. + */ + void toStreamModelTerm(std::ostream& out, + const smt::Model& m, + Node n) const override; bool d_cvc3Mode; }; /* class CvcPrinter */ diff --git a/src/printer/printer.cpp b/src/printer/printer.cpp index b24025124..7225721c0 100644 --- a/src/printer/printer.cpp +++ b/src/printer/printer.cpp @@ -74,18 +74,34 @@ unique_ptr Printer::makePrinter(OutputLanguage lang) void Printer::toStream(std::ostream& out, const smt::Model& m) const { - for(size_t i = 0; i < m.getNumCommands(); ++i) { - const NodeCommand* cmd = m.getCommand(i); - const DeclareFunctionNodeCommand* dfc = - dynamic_cast(cmd); - if (dfc != NULL && !m.isModelCoreSymbol(dfc->getFunction())) + // print the declared sorts + const std::vector& dsorts = m.getDeclaredSorts(); + for (const TypeNode& tn : dsorts) + { + toStreamModelSort(out, m, tn); + } + + // print the declared terms + const std::vector& dterms = m.getDeclaredTerms(); + for (const Node& n : dterms) + { + // take into account model core, independently of the format + if (!m.isModelCoreSymbol(n)) { continue; } - toStream(out, m, cmd); + toStreamModelTerm(out, m, n); } + }/* Printer::toStream(Model) */ +void Printer::toStreamUsing(OutputLanguage lang, + std::ostream& out, + const smt::Model& m) const +{ + getPrinter(lang)->toStream(out, m); +} + void Printer::toStream(std::ostream& out, const UnsatCore& core) const { for(UnsatCore::iterator i = core.begin(); i != core.end(); ++i) { @@ -160,8 +176,6 @@ void Printer::toStreamCmdDeclareFunction(std::ostream& out, } void Printer::toStreamCmdDeclareType(std::ostream& out, - const std::string& id, - size_t arity, TypeNode type) const { printUnknownCommand(out, "declare-sort"); diff --git a/src/printer/printer.h b/src/printer/printer.h index d32418deb..5bcccedb8 100644 --- a/src/printer/printer.h +++ b/src/printer/printer.h @@ -86,8 +86,6 @@ class Printer /** Print declare-sort command */ virtual void toStreamCmdDeclareType(std::ostream& out, - const std::string& id, - size_t arity, TypeNode type) const; /** Print define-sort command */ @@ -266,19 +264,26 @@ class Printer /** Derived classes can construct, but no one else. */ Printer() {} - /** write model response to command */ - virtual void toStream(std::ostream& out, - const smt::Model& m, - const NodeCommand* c) const = 0; + /** + * To stream model sort. This prints the appropriate output for type + * tn declared via declare-sort or declare-datatype. + */ + virtual void toStreamModelSort(std::ostream& out, + const smt::Model& m, + TypeNode tn) const = 0; + + /** + * To stream model term. This prints the appropriate output for term + * n declared via declare-fun. + */ + virtual void toStreamModelTerm(std::ostream& out, + const smt::Model& m, + Node n) const = 0; /** write model response to command using another language printer */ void toStreamUsing(OutputLanguage lang, std::ostream& out, - const smt::Model& m, - const NodeCommand* c) const - { - getPrinter(lang)->toStream(out, m, c); - } + const smt::Model& m) const; /** * Write an error to `out` stating that command `name` is not supported by diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 747873bee..9e9500bdb 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1364,124 +1364,91 @@ void Smt2Printer::toStream(std::ostream& out, const smt::Model& m) const } } -void Smt2Printer::toStream(std::ostream& out, - const smt::Model& model, - const NodeCommand* command) const +void Smt2Printer::toStreamModelSort(std::ostream& out, + const smt::Model& m, + TypeNode tn) const { - const theory::TheoryModel* theory_model = model.getTheoryModel(); - AlwaysAssert(theory_model != nullptr); - if (const DeclareTypeNodeCommand* dtc = - dynamic_cast(command)) + if (!tn.isSort()) { - // print out the DeclareTypeCommand - TypeNode tn = dtc->getType(); - if (!tn.isSort()) + out << "ERROR: don't know how to print non uninterpreted sort in model: " + << tn << std::endl; + return; + } + const theory::TheoryModel* tm = m.getTheoryModel(); + std::vector elements = tm->getDomainElements(tn); + if (options::modelUninterpPrint() == options::ModelUninterpPrintMode::DtEnum) + { + if (isVariant_2_6(d_variant)) { - out << (*dtc) << endl; + out << "(declare-datatypes ((" << tn << " 0)) ("; } else { - std::vector elements = theory_model->getDomainElements(tn); - if (options::modelUninterpPrint() - == options::ModelUninterpPrintMode::DtEnum) - { - if (isVariant_2_6(d_variant)) - { - out << "(declare-datatypes ((" << (*dtc).getSymbol() << " 0)) ("; - } - else - { - out << "(declare-datatypes () ((" << (*dtc).getSymbol() << " "; - } - for (const Node& type_ref : elements) - { - out << "(" << type_ref << ")"; - } - out << ")))" << endl; - } - else - { - // print the cardinality - out << "; cardinality of " << tn << " is " << elements.size() << endl; - if (options::modelUninterpPrint() - == options::ModelUninterpPrintMode::DeclSortAndFun) - { - out << (*dtc) << endl; - } - // print the representatives - for (const Node& trn : elements) - { - if (trn.isVar()) - { - out << "(declare-fun " << quoteSymbol(trn) << " () " << tn << ")" - << endl; - } - else - { - out << "; rep: " << trn << endl; - } - } - } - } - } - else if (const DeclareFunctionNodeCommand* dfc = - dynamic_cast(command)) - { - // print out the DeclareFunctionCommand - Node n = dfc->getFunction(); - if ((*dfc).getPrintInModelSetByUser()) - { - if (!(*dfc).getPrintInModel()) - { - return; - } + out << "(declare-datatypes () ((" << tn << " "; } - else if (n.getKind() == kind::SKOLEM) + for (const Node& type_ref : elements) { - // don't print out internal stuff - return; + out << "(" << type_ref << ")"; } - // We get the value from the theory model directly, which notice - // does not have to go through the standard SmtEngine::getValue interface. - Node val = theory_model->getValue(n); - if (val.getKind() == kind::LAMBDA) + out << ")))" << endl; + return; + } + // print the cardinality + out << "; cardinality of " << tn << " is " << elements.size() << endl; + if (options::modelUninterpPrint() + == options::ModelUninterpPrintMode::DeclSortAndFun) + { + toStreamCmdDeclareType(out, tn); + } + // print the representatives + for (const Node& trn : elements) + { + if (trn.isVar()) { - TypeNode rangeType = n.getType().getRangeType(); - out << "(define-fun " << n << " " << val[0] << " " << rangeType << " "; - // call toStream and force its type to be proper - toStreamCastToType(out, val[1], -1, rangeType); - out << ")" << endl; + out << "(declare-fun " << quoteSymbol(trn) << " () " << tn << ")" << endl; } else { - if (options::modelUninterpPrint() - == options::ModelUninterpPrintMode::DtEnum - && val.getKind() == kind::STORE) - { - TypeNode tn = val[1].getType(); - const std::vector* type_refs = - theory_model->getRepSet()->getTypeRepsOrNull(tn); - if (tn.isSort() && type_refs != nullptr) - { - Cardinality indexCard(type_refs->size()); - val = theory::arrays::TheoryArraysRewriter::normalizeConstant( - val, indexCard); - } - } - out << "(define-fun " << n << " () " << n.getType() << " "; - // call toStream and force its type to be proper - toStreamCastToType(out, val, -1, n.getType()); - out << ")" << endl; + out << "; rep: " << trn << endl; } } - else if (const DeclareDatatypeNodeCommand* declare_datatype_command = - dynamic_cast(command)) +} + +void Smt2Printer::toStreamModelTerm(std::ostream& out, + const smt::Model& m, + Node n) const +{ + const theory::TheoryModel* tm = m.getTheoryModel(); + // We get the value from the theory model directly, which notice + // does not have to go through the standard SmtEngine::getValue interface. + Node val = tm->getValue(n); + if (val.getKind() == kind::LAMBDA) { - out << *declare_datatype_command; + TypeNode rangeType = n.getType().getRangeType(); + out << "(define-fun " << n << " " << val[0] << " " << rangeType << " "; + // call toStream and force its type to be proper + toStreamCastToType(out, val[1], -1, rangeType); + out << ")" << endl; } else { - Unreachable(); + if (options::modelUninterpPrint() == options::ModelUninterpPrintMode::DtEnum + && val.getKind() == kind::STORE) + { + TypeNode tn = val[1].getType(); + const std::vector* type_refs = + tm->getRepSet()->getTypeRepsOrNull(tn); + if (tn.isSort() && type_refs != nullptr) + { + Cardinality indexCard(type_refs->size()); + val = theory::arrays::TheoryArraysRewriter::normalizeConstant( + val, indexCard); + } + } + out << "(define-fun " << n << " () " << n.getType() << " "; + // call toStream and force its type to be proper + toStreamCastToType(out, val, -1, n.getType()); + out << ")" << endl; } } @@ -1694,11 +1661,13 @@ void Smt2Printer::toStreamCmdDefineFunctionRec( } void Smt2Printer::toStreamCmdDeclareType(std::ostream& out, - const std::string& id, - size_t arity, TypeNode type) const { - out << "(declare-sort " << CVC4::quoteSymbol(id) << " " << arity << ")" + Assert(type.isSort() || type.isSortConstructor()); + std::stringstream id; + id << type; + size_t arity = type.isSortConstructor() ? type.getSortConstructorArity() : 0; + out << "(declare-sort " << CVC4::quoteSymbol(id.str()) << " " << arity << ")" << std::endl; } diff --git a/src/printer/smt2/smt2_printer.h b/src/printer/smt2/smt2_printer.h index c83a74d97..3d90cee06 100644 --- a/src/printer/smt2/smt2_printer.h +++ b/src/printer/smt2/smt2_printer.h @@ -78,8 +78,6 @@ class Smt2Printer : public CVC4::Printer /** Print declare-sort command */ void toStreamCmdDeclareType(std::ostream& out, - const std::string& id, - size_t arity, TypeNode type) const override; /** Print define-sort command */ @@ -243,11 +241,24 @@ class Smt2Printer : public CVC4::Printer TNode n, int toDepth, TypeNode tn) const; - void toStream(std::ostream& out, - const smt::Model& m, - const NodeCommand* c) const override; void toStream(std::ostream& out, const SExpr& sexpr) const; void toStream(std::ostream& out, const DType& dt) const; + /** + * To stream model sort. This prints the appropriate output for type + * tn declared via declare-sort or declare-datatype. + */ + void toStreamModelSort(std::ostream& out, + const smt::Model& m, + TypeNode tn) const override; + + /** + * To stream model term. This prints the appropriate output for term + * n declared via declare-fun. + */ + void toStreamModelTerm(std::ostream& out, + const smt::Model& m, + Node n) const override; + /** * To stream with let binding. This prints n, possibly in the scope * of letification generated by this method based on lbind. diff --git a/src/printer/tptp/tptp_printer.cpp b/src/printer/tptp/tptp_printer.cpp index c93f3593a..f9384b0cb 100644 --- a/src/printer/tptp/tptp_printer.cpp +++ b/src/printer/tptp/tptp_printer.cpp @@ -54,20 +54,27 @@ void TptpPrinter::toStream(std::ostream& out, const smt::Model& m) const : "CandidateFiniteModel"); out << "% SZS output start " << statusName << " for " << m.getInputName() << endl; - for(size_t i = 0; i < m.getNumCommands(); ++i) { - this->Printer::toStreamUsing(language::output::LANG_SMTLIB_V2_5, out, m, m.getCommand(i)); - } + this->Printer::toStreamUsing(language::output::LANG_SMTLIB_V2_5, out, m); out << "% SZS output end " << statusName << " for " << m.getInputName() << endl; } -void TptpPrinter::toStream(std::ostream& out, - const smt::Model& m, - const NodeCommand* c) const +void TptpPrinter::toStreamModelSort(std::ostream& out, + const smt::Model& m, + TypeNode tn) const +{ + // shouldn't be called; only the non-Command* version above should be + Unreachable(); +} + +void TptpPrinter::toStreamModelTerm(std::ostream& out, + const smt::Model& m, + Node n) const { // shouldn't be called; only the non-Command* version above should be Unreachable(); } + void TptpPrinter::toStream(std::ostream& out, const UnsatCore& core) const { out << "% SZS output start UnsatCore " << std::endl; diff --git a/src/printer/tptp/tptp_printer.h b/src/printer/tptp/tptp_printer.h index 449fe409c..38a56bcb5 100644 --- a/src/printer/tptp/tptp_printer.h +++ b/src/printer/tptp/tptp_printer.h @@ -44,9 +44,21 @@ class TptpPrinter : public CVC4::Printer void toStream(std::ostream& out, const UnsatCore& core) const override; private: - void toStream(std::ostream& out, - const smt::Model& m, - const NodeCommand* c) const override; + /** + * To stream model sort. This prints the appropriate output for type + * tn declared via declare-sort or declare-datatype. + */ + void toStreamModelSort(std::ostream& out, + const smt::Model& m, + TypeNode tn) const override; + + /** + * To stream model term. This prints the appropriate output for term + * n declared via declare-fun. + */ + void toStreamModelTerm(std::ostream& out, + const smt::Model& m, + Node n) const override; }; /* class TptpPrinter */ diff --git a/src/smt/command.cpp b/src/smt/command.cpp index 717d423fe..154166eb7 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -1091,6 +1091,8 @@ void DeclareFunctionCommand::setPrintInModel(bool p) void DeclareFunctionCommand::invoke(api::Solver* solver, SymbolManager* sm) { + // mark that it will be printed in the model + sm->addModelDeclarationTerm(d_func); d_commandStatus = CommandSuccess::instance(); } @@ -1132,6 +1134,8 @@ size_t DeclareSortCommand::getArity() const { return d_arity; } api::Sort DeclareSortCommand::getSort() const { return d_sort; } void DeclareSortCommand::invoke(api::Solver* solver, SymbolManager* sm) { + // mark that it will be printed in the model + sm->addModelDeclarationSort(d_sort); d_commandStatus = CommandSuccess::instance(); } @@ -1150,8 +1154,8 @@ void DeclareSortCommand::toStream(std::ostream& out, size_t dag, OutputLanguage language) const { - Printer::getPrinter(language)->toStreamCmdDeclareType( - out, d_sort.toString(), d_arity, d_sort.getTypeNode()); + Printer::getPrinter(language)->toStreamCmdDeclareType(out, + d_sort.getTypeNode()); } /* -------------------------------------------------------------------------- */ @@ -1693,6 +1697,18 @@ void GetModelCommand::invoke(api::Solver* solver, SymbolManager* sm) try { d_result = solver->getSmtEngine()->getModel(); + // set the model declarations, which determines what is printed in the model + d_result->clearModelDeclarations(); + std::vector declareSorts = sm->getModelDeclareSorts(); + for (const api::Sort& s : declareSorts) + { + d_result->addDeclarationSort(s.getTypeNode()); + } + std::vector declareTerms = sm->getModelDeclareTerms(); + for (const api::Term& t : declareTerms) + { + d_result->addDeclarationTerm(t.getNode()); + } d_commandStatus = CommandSuccess::instance(); } catch (RecoverableModalException& e) diff --git a/src/smt/node_command.cpp b/src/smt/node_command.cpp index eb2493c87..91184d27d 100644 --- a/src/smt/node_command.cpp +++ b/src/smt/node_command.cpp @@ -104,8 +104,7 @@ void DeclareTypeNodeCommand::toStream(std::ostream& out, size_t dag, OutputLanguage language) const { - Printer::getPrinter(language)->toStreamCmdDeclareType( - out, d_id, d_arity, d_type); + Printer::getPrinter(language)->toStreamCmdDeclareType(out, d_type); } NodeCommand* DeclareTypeNodeCommand::clone() const diff --git a/test/regress/regress0/datatypes/dt-param-2.6-print.smt2 b/test/regress/regress0/datatypes/dt-param-2.6-print.smt2 index 2b706478f..ce92821c1 100644 --- a/test/regress/regress0/datatypes/dt-param-2.6-print.smt2 +++ b/test/regress/regress0/datatypes/dt-param-2.6-print.smt2 @@ -1,6 +1,5 @@ ; EXPECT: sat ; EXPECT: ( -; EXPECT: (declare-datatypes ((Pair 2)) ((par (X Y)((mkPair (first X) (second Y)))))) ; EXPECT: (define-fun x () (Pair Int Real) ((as mkPair (Pair Int Real)) 2 (/ 3 2))) ; EXPECT: ) diff --git a/test/unit/parser/parser_black.h b/test/unit/parser/parser_black.h index 3b0bbb139..ef8f2e3cf 100644 --- a/test/unit/parser/parser_black.h +++ b/test/unit/parser/parser_black.h @@ -259,8 +259,6 @@ public: tryGoodInput("a : INT = 5; a: INT;"); // decl after define, compatible tryGoodInput("a : TYPE; a : INT;"); // ok, sort and variable symbol spaces distinct tryGoodInput("a : TYPE; a : INT; b : a;"); // ok except a is both INT and sort `a' - //tryGoodInput("a : [0..0]; b : [-5..5]; c : [-1..1]; d : [ _ .._];"); // subranges - tryGoodInput("a : [ _..1]; b : [_.. 0]; c :[_..-1];"); tryGoodInput("DATATYPE list = nil | cons(car:INT,cdr:list) END; DATATYPE cons = null END;"); tryGoodInput("DATATYPE tree = node(data:list), list = cons(car:tree,cdr:list) | nil END;"); //tryGoodInput("DATATYPE tree = node(data:[list,list,ARRAY tree OF list]), list = cons(car:ARRAY list OF tree,cdr:BITVECTOR(32)) END;");