From 700689a4e4ed42b5198816611eac5bcc1278284d Mon Sep 17 00:00:00 2001 From: Morgan Deters Date: Sat, 9 Jun 2012 00:35:38 +0000 Subject: [PATCH] Dagification of output expressions. By default, common subexpressions are dagified if they appear > 1 time and are not constants or variables. This can be changed with --default-expr-dag=N --- N is a threshold such that if the subexpression occurs > N times, it is dagified; a setting of 0 turns off dagification entirely. If you notice strange dumping behavior (taking too long to print anything, e.g.), revert to the old behavior with --default-expr-dag=0 and let me know of the problem. --- src/expr/command.cpp | 5 +- src/expr/command.h | 2 +- src/expr/expr.i | 1 + src/expr/expr_template.cpp | 5 +- src/expr/expr_template.h | 119 +++++++++++++++++-- src/expr/node.h | 30 ++++- src/expr/node_value.cpp | 4 +- src/expr/node_value.h | 14 ++- src/expr/type_node.h | 15 ++- src/printer/Makefile.am | 1 + src/printer/ast/ast_printer.cpp | 47 +++++++- src/printer/ast/ast_printer.h | 5 +- src/printer/cvc/cvc_printer.cpp | 36 +++++- src/printer/cvc/cvc_printer.h | 9 +- src/printer/dagification_visitor.h | 180 +++++++++++++++++++++++++++++ src/printer/printer.h | 4 +- src/printer/smt/smt_printer.cpp | 8 +- src/printer/smt/smt_printer.h | 4 +- src/printer/smt2/smt2_printer.cpp | 53 +++++++-- src/printer/smt2/smt2_printer.h | 5 +- src/theory/substitutions.h | 8 +- src/util/node_visitor.h | 6 +- src/util/options.cpp | 20 ++++ test/unit/expr/node_black.h | 3 +- 24 files changed, 523 insertions(+), 61 deletions(-) create mode 100644 src/printer/dagification_visitor.h diff --git a/src/expr/command.cpp b/src/expr/command.cpp index 78d04f000..ae24f4984 100644 --- a/src/expr/command.cpp +++ b/src/expr/command.cpp @@ -43,6 +43,7 @@ std::ostream& operator<<(std::ostream& out, const Command& c) throw() { c.toStream(out, Node::setdepth::getDepth(out), Node::printtypes::getPrintTypes(out), + Node::dag::getDag(out), Node::setlanguage::getLanguage(out)); return out; } @@ -101,9 +102,9 @@ std::string Command::toString() const throw() { return ss.str(); } -void Command::toStream(std::ostream& out, int toDepth, bool types, +void Command::toStream(std::ostream& out, int toDepth, bool types, size_t dag, OutputLanguage language) const throw() { - Printer::getPrinter(language)->toStream(out, this, toDepth, types); + Printer::getPrinter(language)->toStream(out, this, toDepth, types, dag); } void CommandStatus::toStream(std::ostream& out, OutputLanguage language) const throw() { diff --git a/src/expr/command.h b/src/expr/command.h index 19d1f16e7..98046c242 100644 --- a/src/expr/command.h +++ b/src/expr/command.h @@ -205,7 +205,7 @@ public: virtual void invoke(SmtEngine* smtEngine) throw() = 0; virtual void invoke(SmtEngine* smtEngine, std::ostream& out) throw(); - virtual void toStream(std::ostream& out, int toDepth = -1, bool types = false, + virtual void toStream(std::ostream& out, int toDepth = -1, bool types = false, size_t dag = 1, OutputLanguage language = language::output::LANG_AST) const throw(); std::string toString() const throw(); diff --git a/src/expr/expr.i b/src/expr/expr.i index 6ed7f2d25..9b6c55703 100644 --- a/src/expr/expr.i +++ b/src/expr/expr.i @@ -9,6 +9,7 @@ %ignore CVC4::expr::operator<<(std::ostream&, ExprSetDepth); %ignore CVC4::expr::operator<<(std::ostream&, ExprPrintTypes); +%ignore CVC4::expr::operator<<(std::ostream&, ExprDag); %ignore CVC4::expr::operator<<(std::ostream&, ExprSetLanguage); %rename(assign) CVC4::Expr::operator=(const Expr&); diff --git a/src/expr/expr_template.cpp b/src/expr/expr_template.cpp index d0f5fde9e..c70fed889 100644 --- a/src/expr/expr_template.cpp +++ b/src/expr/expr_template.cpp @@ -45,6 +45,7 @@ namespace expr { const int ExprSetDepth::s_iosIndex = std::ios_base::xalloc(); const int ExprPrintTypes::s_iosIndex = std::ios_base::xalloc(); +const int ExprDag::s_iosIndex = std::ios_base::xalloc(); const int ExprSetLanguage::s_iosIndex = std::ios_base::xalloc(); }/* CVC4::expr namespace */ @@ -416,10 +417,10 @@ bool Expr::isConst() const { return d_node->isConst(); } -void Expr::toStream(std::ostream& out, int depth, bool types, +void Expr::toStream(std::ostream& out, int depth, bool types, size_t dag, OutputLanguage language) const { ExprManagerScope ems(*this); - d_node->toStream(out, depth, types, language); + d_node->toStream(out, depth, types, dag, language); } Node Expr::getNode() const throw() { diff --git a/src/expr/expr_template.h b/src/expr/expr_template.h index 7a6c0179d..6cd476a5f 100644 --- a/src/expr/expr_template.h +++ b/src/expr/expr_template.h @@ -79,6 +79,7 @@ namespace smt { namespace expr { class CVC4_PUBLIC ExprSetDepth; class CVC4_PUBLIC ExprPrintTypes; + class CVC4_PUBLIC ExprDag; class CVC4_PUBLIC ExprSetLanguage; NodeTemplate exportInternal(NodeTemplate n, ExprManager* from, ExprManager* to, ExprManagerMapCollection& vmap); @@ -400,7 +401,7 @@ public: * debugging expressions) * @param language the language in which to output */ - void toStream(std::ostream& out, int toDepth = -1, bool types = false, + void toStream(std::ostream& out, int toDepth = -1, bool types = false, size_t dag = 1, OutputLanguage language = language::output::LANG_AST) const; /** @@ -492,6 +493,11 @@ public: */ typedef expr::ExprPrintTypes printtypes; + /** + * IOStream manipulator to print expressions as a DAG (or not). + */ + typedef expr::ExprDag dag; + /** * IOStream manipulator to set the output language for Exprs. */ @@ -722,13 +728,13 @@ public: */ class CVC4_PUBLIC ExprPrintTypes { /** - * The allocated index in ios_base for our depth setting. + * The allocated index in ios_base for our setting. */ static const int s_iosIndex; /** - * The default depth to print, for ostreams that haven't yet had a - * setdepth() applied to them. + * The default printtypes setting, for ostreams that haven't yet had a + * printtypes() applied to them. */ static const int s_defaultPrintTypes = false; @@ -781,6 +787,85 @@ public: };/* class ExprPrintTypes */ +/** + * IOStream manipulator to print expressions as a dag (or not). + */ +class CVC4_PUBLIC ExprDag { + /** + * The allocated index in ios_base for our setting. + */ + static const int s_iosIndex; + + /** + * The default setting, for ostreams that haven't yet had a + * dag() applied to them. + */ + static const size_t s_defaultDag = 1; + + /** + * When this manipulator is used, the setting is stored here. + */ + size_t d_dag; + +public: + /** + * Construct a ExprDag with the given setting (dagification on or off). + */ + explicit ExprDag(bool dag) : d_dag(dag ? 1 : 0) {} + + /** + * Construct a ExprDag with the given setting (letify only common + * subexpressions that appear more than 'dag' times). dag==0 means + * don't dagify. + */ + ExprDag(size_t dag) : d_dag(dag) {} + + inline void applyDag(std::ostream& out) { + // (offset by one to detect whether default has been set yet) + out.iword(s_iosIndex) = static_cast(d_dag) + 1; + } + + static inline size_t getDag(std::ostream& out) { + long& l = out.iword(s_iosIndex); + if(l == 0) { + // set the default dag setting on this ostream + // (offset by one to detect whether default has been set yet) + l = s_defaultDag + 1; + } + return static_cast(l - 1); + } + + static inline void setDag(std::ostream& out, size_t dag) { + // (offset by one to detect whether default has been set yet) + out.iword(s_iosIndex) = static_cast(dag) + 1; + } + + /** + * Set the dag state on the output stream for the current + * stack scope. This makes sure the old state is reset on the + * stream after normal OR exceptional exit from the scope, using the + * RAII C++ idiom. + */ + class Scope { + std::ostream& d_out; + size_t d_oldDag; + + public: + + inline Scope(std::ostream& out, size_t dag) : + d_out(out), + d_oldDag(ExprDag::getDag(out)) { + ExprDag::setDag(out, dag); + } + + inline ~Scope() { + ExprDag::setDag(d_out, d_oldDag); + } + + };/* class ExprDag::Scope */ + +};/* class ExprDag */ + /** * IOStream manipulator to set the output language for Exprs. */ @@ -857,13 +942,13 @@ public: ${getConst_instantiations} -#line 861 "${template}" +#line 938 "${template}" namespace expr { /** - * Sets the default print-types setting when pretty-printing an Expr - * to an ostream. Use like this: + * Sets the default depth when pretty-printing a Expr to an ostream. + * Use like this: * * // let out be an ostream, e an Expr * out << Expr::setdepth(n) << e << endl; @@ -876,11 +961,11 @@ inline std::ostream& operator<<(std::ostream& out, ExprSetDepth sd) { } /** - * Sets the default depth when pretty-printing a Expr to an ostream. - * Use like this: + * Sets the default print-types setting when pretty-printing an Expr + * to an ostream. Use like this: * * // let out be an ostream, e an Expr - * out << Expr::setprinttypes(true) << e << endl; + * out << Expr::printtypes(true) << e << endl; * * The setting stays permanently (until set again) with the stream. */ @@ -889,6 +974,20 @@ inline std::ostream& operator<<(std::ostream& out, ExprPrintTypes pt) { return out; } +/** + * Sets the default dag setting when pretty-printing a Expr to an ostream. + * Use like this: + * + * // let out be an ostream, e an Expr + * out << Expr::dag(true) << e << endl; + * + * The setting stays permanently (until set again) with the stream. + */ +inline std::ostream& operator<<(std::ostream& out, ExprDag d) { + d.applyDag(out); + return out; +} + /** * Sets the output language when pretty-printing a Expr to an ostream. * Use like this: diff --git a/src/expr/node.h b/src/expr/node.h index 3532116bc..a61944433 100644 --- a/src/expr/node.h +++ b/src/expr/node.h @@ -803,10 +803,10 @@ public: * (might break language compliance, but good for debugging expressions) * @param language the language in which to output */ - inline void toStream(std::ostream& out, int toDepth = -1, bool types = false, + inline void toStream(std::ostream& out, int toDepth = -1, bool types = false, size_t dag = 1, OutputLanguage language = language::output::LANG_AST) const { assertTNodeNotExpired(); - d_nv->toStream(out, toDepth, types, language); + d_nv->toStream(out, toDepth, types, dag, language); } /** @@ -836,6 +836,11 @@ public: */ typedef expr::ExprPrintTypes printtypes; + /** + * IOStream manipulator to print expressions as DAGs (or not). + */ + typedef expr::ExprDag dag; + /** * IOStream manipulator to set the output language for Exprs. */ @@ -885,6 +890,7 @@ inline std::ostream& operator<<(std::ostream& out, TNode n) { n.toStream(out, Node::setdepth::getDepth(out), Node::printtypes::getPrintTypes(out), + Node::dag::getDag(out), Node::setlanguage::getLanguage(out)); return out; } @@ -1468,6 +1474,16 @@ bool NodeTemplate::hasSubterm(NodeTemplate t, bool strict) con */ static void __attribute__((used)) debugPrintNode(const NodeTemplate& n) { Warning() << Node::setdepth(-1) + << Node::printtypes(false) + << Node::dag(true) + << Node::setlanguage(language::output::LANG_AST) + << n << std::endl; + Warning().flush(); +} +static void __attribute__((used)) debugPrintNodeNoDag(const NodeTemplate& n) { + Warning() << Node::setdepth(-1) + << Node::printtypes(false) + << Node::dag(false) << Node::setlanguage(language::output::LANG_AST) << n << std::endl; Warning().flush(); @@ -1479,6 +1495,16 @@ static void __attribute__((used)) debugPrintRawNode(const NodeTemplate& n) static void __attribute__((used)) debugPrintTNode(const NodeTemplate& n) { Warning() << Node::setdepth(-1) + << Node::printtypes(false) + << Node::dag(true) + << Node::setlanguage(language::output::LANG_AST) + << n << std::endl; + Warning().flush(); +} +static void __attribute__((used)) debugPrintTNodeNoDag(const NodeTemplate& n) { + Warning() << Node::setdepth(-1) + << Node::printtypes(false) + << Node::dag(false) << Node::setlanguage(language::output::LANG_AST) << n << std::endl; Warning().flush(); diff --git a/src/expr/node_value.cpp b/src/expr/node_value.cpp index 970d2e0fc..dbf706c45 100644 --- a/src/expr/node_value.cpp +++ b/src/expr/node_value.cpp @@ -45,14 +45,14 @@ string NodeValue::toString() const { return ss.str(); } -void NodeValue::toStream(std::ostream& out, int toDepth, bool types, +void NodeValue::toStream(std::ostream& out, int toDepth, bool types, size_t dag, OutputLanguage language) const { // Ensure that this node value is live for the length of this call. // It really breaks things badly if we don't have a nonzero ref // count, even just for printing. RefCountGuard guard(this); - Printer::getPrinter(language)->toStream(out, TNode(this), toDepth, types); + Printer::getPrinter(language)->toStream(out, TNode(this), toDepth, types, dag); } void NodeValue::printAst(std::ostream& out, int ind) const { diff --git a/src/expr/node_value.h b/src/expr/node_value.h index e5ecfbc48..657fabeb5 100644 --- a/src/expr/node_value.h +++ b/src/expr/node_value.h @@ -267,7 +267,7 @@ public: } std::string toString() const; - void toStream(std::ostream& out, int toDepth = -1, bool types = false, + void toStream(std::ostream& out, int toDepth = -1, bool types = false, size_t dag = 1, OutputLanguage = language::output::LANG_AST) const; static inline unsigned kindToDKind(Kind k) { @@ -487,6 +487,7 @@ inline std::ostream& operator<<(std::ostream& out, const NodeValue& nv) { nv.toStream(out, Node::setdepth::getDepth(out), Node::printtypes::getPrintTypes(out), + Node::dag::getDag(out), Node::setlanguage::getLanguage(out)); return out; } @@ -501,11 +502,20 @@ inline std::ostream& operator<<(std::ostream& out, const NodeValue& nv) { */ static void __attribute__((used)) debugPrintNodeValue(const expr::NodeValue* nv) { Warning() << Node::setdepth(-1) + << Node::printtypes(false) + << Node::dag(true) + << Node::setlanguage(language::output::LANG_AST) + << *nv << std::endl; + Warning().flush(); +} +static void __attribute__((used)) debugPrintNodeValueNoDag(const expr::NodeValue* nv) { + Warning() << Node::setdepth(-1) + << Node::printtypes(false) + << Node::dag(false) << Node::setlanguage(language::output::LANG_AST) << *nv << std::endl; Warning().flush(); } - static void __attribute__((used)) debugPrintRawNodeValue(const expr::NodeValue* nv) { nv->printAst(Warning(), 0); Warning().flush(); diff --git a/src/expr/type_node.h b/src/expr/type_node.h index 482da2814..bfbedde88 100644 --- a/src/expr/type_node.h +++ b/src/expr/type_node.h @@ -389,9 +389,9 @@ public: * (might break language compliance, but good for debugging expressions) * @param language the language in which to output */ - inline void toStream(std::ostream& out, int toDepth = -1, bool types = false, + inline void toStream(std::ostream& out, int toDepth = -1, bool types = false, size_t dag = 1, OutputLanguage language = language::output::LANG_AST) const { - d_nv->toStream(out, toDepth, types, language); + d_nv->toStream(out, toDepth, types, dag, language); } /** @@ -636,6 +636,7 @@ inline std::ostream& operator<<(std::ostream& out, const TypeNode& n) { n.toStream(out, Node::setdepth::getDepth(out), Node::printtypes::getPrintTypes(out), + Node::dag::getDag(out), Node::setlanguage::getLanguage(out)); return out; } @@ -980,6 +981,16 @@ inline const SubrangeBounds& TypeNode::getSubrangeBounds() const { */ static void __attribute__((used)) debugPrintTypeNode(const TypeNode& n) { Warning() << Node::setdepth(-1) + << Node::printtypes(false) + << Node::dag(true) + << Node::setlanguage(language::output::LANG_AST) + << n << std::endl; + Warning().flush(); +} +static void __attribute__((used)) debugPrintTypeNodeNoDag(const TypeNode& n) { + Warning() << Node::setdepth(-1) + << Node::printtypes(false) + << Node::dag(false) << Node::setlanguage(language::output::LANG_AST) << n << std::endl; Warning().flush(); diff --git a/src/printer/Makefile.am b/src/printer/Makefile.am index ca3cbacf1..3f0eba12d 100644 --- a/src/printer/Makefile.am +++ b/src/printer/Makefile.am @@ -6,6 +6,7 @@ AM_CXXFLAGS = -Wall -Wno-unknown-pragmas $(FLAG_VISIBILITY_HIDDEN) noinst_LTLIBRARIES = libprinter.la libprinter_la_SOURCES = \ + dagification_visitor.h \ printer.h \ printer.cpp \ ast/ast_printer.h \ diff --git a/src/printer/ast/ast_printer.cpp b/src/printer/ast/ast_printer.cpp index b941957c4..5a7b2e834 100644 --- a/src/printer/ast/ast_printer.cpp +++ b/src/printer/ast/ast_printer.cpp @@ -21,6 +21,8 @@ #include "util/language.h" // for LANG_AST #include "expr/node_manager.h" // for VarNameAttr #include "expr/command.h" +#include "printer/dagification_visitor.h" +#include "util/node_visitor.h" #include #include @@ -33,6 +35,40 @@ namespace CVC4 { namespace printer { namespace ast { +void AstPrinter::toStream(std::ostream& out, TNode n, + int toDepth, bool types, size_t dag) const throw() { + if(dag != 0) { + DagificationVisitor dv(dag); + NodeVisitor visitor; + visitor.run(dv, n); + const theory::SubstitutionMap& lets = dv.getLets(); + if(!lets.empty()) { + out << "(LET "; + bool first = true; + for(theory::SubstitutionMap::const_iterator i = lets.begin(); + i != lets.end(); + ++i) { + if(! first) { + out << ", "; + } else { + first = false; + } + toStream(out, (*i).second, toDepth, types, false); + out << " := "; + toStream(out, (*i).first, toDepth, types, false); + } + out << " IN "; + } + Node body = dv.getDagifiedBody(); + toStream(out, body, toDepth, types); + if(!lets.empty()) { + out << ')'; + } + } else { + toStream(out, n, toDepth, types); + } +} + void AstPrinter::toStream(std::ostream& out, TNode n, int toDepth, bool types) const throw() { // null @@ -57,7 +93,7 @@ void AstPrinter::toStream(std::ostream& out, TNode n, if(types) { // print the whole type, but not *its* type out << ":"; - n.getType().toStream(out, -1, false, language::output::LANG_AST); + n.getType().toStream(out, -1, false, 0, language::output::LANG_AST); } return; @@ -73,8 +109,7 @@ void AstPrinter::toStream(std::ostream& out, TNode n, if(n.getMetaKind() == kind::metakind::PARAMETERIZED) { out << ' '; if(toDepth != 0) { - n.getOperator().toStream(out, toDepth < 0 ? toDepth : toDepth - 1, - types, language::output::LANG_AST); + toStream(out, n.getOperator(), toDepth < 0 ? toDepth : toDepth - 1, types); } else { out << "(...)"; } @@ -87,8 +122,7 @@ void AstPrinter::toStream(std::ostream& out, TNode n, out << ' '; } if(toDepth != 0) { - (*i).toStream(out, toDepth < 0 ? toDepth : toDepth - 1, - types, language::output::LANG_AST); + toStream(out, *i, toDepth < 0 ? toDepth : toDepth - 1, types); } else { out << "(...)"; } @@ -101,9 +135,10 @@ template static bool tryToStream(std::ostream& out, const Command* c) throw(); void AstPrinter::toStream(std::ostream& out, const Command* c, - int toDepth, bool types) const throw() { + int toDepth, bool types, size_t dag) const throw() { expr::ExprSetDepth::Scope sdScope(out, toDepth); expr::ExprPrintTypes::Scope ptScope(out, types); + expr::ExprDag::Scope dagScope(out, dag); if(tryToStream(out, c) || tryToStream(out, c) || diff --git a/src/printer/ast/ast_printer.h b/src/printer/ast/ast_printer.h index 2cae4c672..4dfb2c0d5 100644 --- a/src/printer/ast/ast_printer.h +++ b/src/printer/ast/ast_printer.h @@ -30,9 +30,10 @@ namespace printer { namespace ast { class AstPrinter : public CVC4::Printer { -public: void toStream(std::ostream& out, TNode n, int toDepth, bool types) const throw(); - void toStream(std::ostream& out, const Command* c, int toDepth, bool types) const throw(); +public: + void toStream(std::ostream& out, TNode n, int toDepth, bool types, size_t dag) const throw(); + void toStream(std::ostream& out, const Command* c, int toDepth, bool types, size_t dag) const throw(); void toStream(std::ostream& out, const CommandStatus* s) const throw(); };/* class AstPrinter */ diff --git a/src/printer/cvc/cvc_printer.cpp b/src/printer/cvc/cvc_printer.cpp index f779a1bdc..cc95d72b0 100644 --- a/src/printer/cvc/cvc_printer.cpp +++ b/src/printer/cvc/cvc_printer.cpp @@ -36,8 +36,37 @@ namespace CVC4 { namespace printer { namespace cvc { -void CvcPrinter::toStream(std::ostream& out, TNode n, int depth, bool types, bool bracket) const throw() -{ +void CvcPrinter::toStream(std::ostream& out, TNode n, int toDepth, bool types, size_t dag) const throw() { + if(dag != 0) { + DagificationVisitor dv(dag); + NodeVisitor visitor; + visitor.run(dv, n); + const theory::SubstitutionMap& lets = dv.getLets(); + if(!lets.empty()) { + out << "LET "; + bool first = true; + for(theory::SubstitutionMap::const_iterator i = lets.begin(); + i != lets.end(); + ++i) { + if(! first) { + out << ", "; + } else { + first = false; + } + toStream(out, (*i).second, toDepth, types, false); + out << " = "; + toStream(out, (*i).first, toDepth, types, false); + } + out << " IN "; + } + Node body = dv.getDagifiedBody(); + toStream(out, body, toDepth, types, false); + } else { + toStream(out, n, toDepth, types, false); + } +} + +void CvcPrinter::toStream(std::ostream& out, TNode n, int depth, bool types, bool bracket) const throw() { if (depth == 0) { out << "(...)"; } else { @@ -568,9 +597,10 @@ template static bool tryToStream(std::ostream& out, const Command* c) throw(); void CvcPrinter::toStream(std::ostream& out, const Command* c, - int toDepth, bool types) const throw() { + int toDepth, bool types, size_t dag) const throw() { expr::ExprSetDepth::Scope sdScope(out, toDepth); expr::ExprPrintTypes::Scope ptScope(out, types); + expr::ExprDag::Scope dagScope(out, dag); if(tryToStream(out, c) || tryToStream(out, c) || diff --git a/src/printer/cvc/cvc_printer.h b/src/printer/cvc/cvc_printer.h index 3db3f2c66..7fb611a79 100644 --- a/src/printer/cvc/cvc_printer.h +++ b/src/printer/cvc/cvc_printer.h @@ -24,6 +24,9 @@ #include #include "printer/printer.h" +#include "printer/dagification_visitor.h" +#include "theory/substitutions.h" +#include "util/node_visitor.h" namespace CVC4 { namespace printer { @@ -32,10 +35,8 @@ namespace cvc { class CvcPrinter : public CVC4::Printer { void toStream(std::ostream& out, TNode n, int toDepth, bool types, bool bracket) const throw(); public: - void toStream(std::ostream& out, TNode n, int toDepth, bool types) const throw() { - toStream(out, n, toDepth, types, false); - } - void toStream(std::ostream& out, const Command* c, int toDepth, bool types) const throw(); + void toStream(std::ostream& out, TNode n, int toDepth, bool types, size_t dag) const throw(); + void toStream(std::ostream& out, const Command* c, int toDepth, bool types, size_t dag) const throw(); void toStream(std::ostream& out, const CommandStatus* s) const throw(); };/* class CvcPrinter */ diff --git a/src/printer/dagification_visitor.h b/src/printer/dagification_visitor.h new file mode 100644 index 000000000..8e17f6027 --- /dev/null +++ b/src/printer/dagification_visitor.h @@ -0,0 +1,180 @@ +/********************* */ +/*! \file dagification_visitor.h + ** \verbatim + ** Original author: mdeters + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011, 2012 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__PRINTER__DAGIFICATION_VISITOR_H +#define __CVC4__PRINTER__DAGIFICATION_VISITOR_H + +#include "context/context.h" +#include "theory/substitutions.h" +#include "expr/node.h" +#include "util/hash.h" + +#include +#include +#include + +namespace CVC4 { +namespace printer { + +class DagificationVisitor { + + unsigned d_threshold; + std::string d_letVarPrefix; + std::hash_map d_nodeCount; + TNode d_top; + context::Context* d_context; + theory::SubstitutionMap* d_substitutions; + unsigned d_letVar; + bool d_done; + std::hash_map d_uniqueParent; + std::vector d_substNodes; + +public: + + typedef void return_type; + + DagificationVisitor(unsigned threshold, std::string letVarPrefix = "_let_") : + d_threshold(threshold), + d_letVarPrefix(letVarPrefix), + d_nodeCount(), + d_top(), + d_context(new context::Context()), + d_substitutions(new theory::SubstitutionMap(d_context)), + d_letVar(0), + d_done(false), + d_uniqueParent(), + d_substNodes() { + + // 0 doesn't make sense + CheckArgument(threshold > 0, threshold); + } + + ~DagificationVisitor() { + delete d_substitutions; + delete d_context; + } + + /** + * Returns true if current has already been dagified. + */ + bool alreadyVisited(TNode current, TNode parent) { + // don't visit variables, constants, or those exprs that we've + // already seen more than the threshold: if we've increased + // the count beyond the threshold already, we've done the same + // for all subexpressions, so it isn't useful to traverse and + // increment again (they'll be dagified anyway). + return current.getMetaKind() == kind::metakind::VARIABLE || + current.getMetaKind() == kind::metakind::CONSTANT || + ( ( current.getKind() == kind::NOT || + current.getKind() == kind::UMINUS ) && + ( current[0].getMetaKind() == kind::metakind::VARIABLE || + current[0].getMetaKind() == kind::metakind::CONSTANT ) ) || + current.getKind() == kind::SORT_TYPE || + d_nodeCount[current] > d_threshold; + } + + /** + * Dagify the "current" expression. + */ + void visit(TNode current, TNode parent) { + if(d_uniqueParent.find(current) != d_uniqueParent.end()) { + TNode& uniqueParent = d_uniqueParent[current]; + + if(!uniqueParent.isNull() && uniqueParent != parent) { + // there is not a unique parent for this expr + uniqueParent = TNode::null(); + } + + unsigned count = ++d_nodeCount[current]; + + if(count > d_threshold) { + d_substNodes.push_back(current); + } + } else { + Assert(d_nodeCount[current] == 0); + d_nodeCount[current] = 1; + d_uniqueParent[current] = parent; + } + } + + /** + * Marks the node as the starting literal. + */ + void start(TNode node) { + Assert(!d_done, "DagificationVisitor cannot be re-used"); + d_top = node; + } + + /** + * Called when we're done with all visitation. + */ + void done(TNode node) { + Assert(!d_done); + + d_done = true; + + // letify subexprs before parents (cascading LETs) + std::sort(d_substNodes.begin(), d_substNodes.end()); + + for(std::vector::iterator i = d_substNodes.begin(); + i != d_substNodes.end(); + ++i) { + Assert(d_nodeCount[*i] > d_threshold); + TNode parent = d_uniqueParent[*i]; + if(!parent.isNull() && d_nodeCount[parent] > d_threshold) { + // no need to letify this expr, because it only occurs in + // a single super-expression, and that one will be letified + continue; + } + + std::stringstream ss; + ss << d_letVarPrefix << d_letVar++; + Node letvar = NodeManager::currentNM()->mkVar(ss.str(), (*i).getType()); + + Node n = d_substitutions->apply(*i); + // the three last arguments to addSubstitution are: + // invalidateCache -- the rhs of our substitution is a letvar, + // we're not going to use it on lhs so no cache problem + // backSub - no need for SubstitutionMap to do internal substitution, + // we did our own above + // forwardSub - ditto + Assert(! d_substitutions->hasSubstitution(n)); + d_substitutions->addSubstitution(n, letvar); + } + } + + /** + * Get the let substitutions. + */ + const theory::SubstitutionMap& getLets() { + Assert(d_done, "DagificationVisitor must be used as a visitor before getting the dagified version out!"); + return *d_substitutions; + } + + /** + * Return the let-substituted expression. + */ + Node getDagifiedBody() { + Assert(d_done, "DagificationVisitor must be used as a visitor before getting the dagified version out!"); + return d_substitutions->apply(d_top); + } + +};/* class DagificationVisitor */ + +}/* CVC4::printer namespace */ +}/* CVC4 namespace */ + +#endif /* __CVC4__PRINTER__DAGIFICATION_VISITOR_H */ diff --git a/src/printer/printer.h b/src/printer/printer.h index 04b435060..8d1931a83 100644 --- a/src/printer/printer.h +++ b/src/printer/printer.h @@ -53,11 +53,11 @@ public: /** Write a Node out to a stream with this Printer. */ virtual void toStream(std::ostream& out, TNode n, - int toDepth, bool types) const throw() = 0; + int toDepth, bool types, size_t dag) const throw() = 0; /** Write a Command out to a stream with this Printer. */ virtual void toStream(std::ostream& out, const Command* c, - int toDepth, bool types) const throw() = 0; + int toDepth, bool types, size_t dag) const throw() = 0; /** Write a CommandStatus out to a stream with this Printer. */ virtual void toStream(std::ostream& out, const CommandStatus* s) const throw() = 0; diff --git a/src/printer/smt/smt_printer.cpp b/src/printer/smt/smt_printer.cpp index e6490de63..f74a1e07d 100644 --- a/src/printer/smt/smt_printer.cpp +++ b/src/printer/smt/smt_printer.cpp @@ -34,13 +34,13 @@ namespace printer { namespace smt { void SmtPrinter::toStream(std::ostream& out, TNode n, - int toDepth, bool types) const throw() { - n.toStream(out, toDepth, types, language::output::LANG_SMTLIB_V2); + int toDepth, bool types, size_t dag) const throw() { + n.toStream(out, toDepth, types, dag, language::output::LANG_SMTLIB_V2); }/* SmtPrinter::toStream() */ void SmtPrinter::toStream(std::ostream& out, const Command* c, - int toDepth, bool types) const throw() { - c->toStream(out, toDepth, types, language::output::LANG_SMTLIB_V2); + int toDepth, bool types, size_t dag) const throw() { + c->toStream(out, toDepth, types, dag, language::output::LANG_SMTLIB_V2); }/* SmtPrinter::toStream() */ void SmtPrinter::toStream(std::ostream& out, const CommandStatus* s) const throw() { diff --git a/src/printer/smt/smt_printer.h b/src/printer/smt/smt_printer.h index 370e0908c..612dfd19e 100644 --- a/src/printer/smt/smt_printer.h +++ b/src/printer/smt/smt_printer.h @@ -31,8 +31,8 @@ namespace smt { class SmtPrinter : public CVC4::Printer { public: - void toStream(std::ostream& out, TNode n, int toDepth, bool types) const throw(); - void toStream(std::ostream& out, const Command* c, int toDepth, bool types) const throw(); + void toStream(std::ostream& out, TNode n, int toDepth, bool types, size_t dag) const throw(); + void toStream(std::ostream& out, const Command* c, int toDepth, bool types, size_t dag) const throw(); void toStream(std::ostream& out, const CommandStatus* s) const throw(); };/* class SmtPrinter */ diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index a1ee99d8f..d3ec376ae 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -24,6 +24,8 @@ #include #include "util/boolean_simplification.h" +#include "printer/dagification_visitor.h" +#include "util/node_visitor.h" using namespace std; @@ -35,6 +37,42 @@ static string smtKindString(Kind k) throw(); static void printBvParameterizedOp(std::ostream& out, TNode n) throw(); +void Smt2Printer::toStream(std::ostream& out, TNode n, + int toDepth, bool types, size_t dag) const throw() { + if(dag != 0) { + DagificationVisitor dv(dag); + NodeVisitor visitor; + visitor.run(dv, n); + const theory::SubstitutionMap& lets = dv.getLets(); + if(!lets.empty()) { + out << "(let ("; + bool first = true; + for(theory::SubstitutionMap::const_iterator i = lets.begin(); + i != lets.end(); + ++i) { + if(!first) { + out << ' '; + } else { + first = false; + } + out << '('; + toStream(out, (*i).second, toDepth, types); + out << ' '; + toStream(out, (*i).first, toDepth, types); + out << ')'; + } + out << ") "; + } + Node body = dv.getDagifiedBody(); + toStream(out, body, toDepth, types); + if(!lets.empty()) { + out << ')'; + } + } else { + toStream(out, n, toDepth, types); + } +} + void Smt2Printer::toStream(std::ostream& out, TNode n, int toDepth, bool types) const throw() { // null @@ -59,7 +97,7 @@ void Smt2Printer::toStream(std::ostream& out, TNode n, if(types) { // print the whole type, but not *its* type out << ":"; - n.getType().toStream(out, -1, false, language::output::LANG_SMTLIB_V2); + n.getType().toStream(out, -1, false, 0, language::output::LANG_SMTLIB_V2); } return; @@ -251,8 +289,7 @@ void Smt2Printer::toStream(std::ostream& out, TNode n, if(n.getMetaKind() == kind::metakind::PARAMETERIZED && stillNeedToPrintParams) { if(toDepth != 0) { - n.getOperator().toStream(out, toDepth < 0 ? toDepth : toDepth - 1, - types, language::output::LANG_SMTLIB_V2); + toStream(out, n.getOperator(), toDepth < 0 ? toDepth : toDepth - 1, types); } else { out << "(...)"; } @@ -264,8 +301,7 @@ void Smt2Printer::toStream(std::ostream& out, TNode n, iend = n.end(); i != iend; ) { if(toDepth != 0) { - (*i).toStream(out, toDepth < 0 ? toDepth : toDepth - 1, - types, language::output::LANG_SMTLIB_V2); + toStream(out, *i, toDepth < 0 ? toDepth : toDepth - 1, types); } else { out << "(...)"; } @@ -273,7 +309,9 @@ void Smt2Printer::toStream(std::ostream& out, TNode n, out << ' '; } } - if(n.getNumChildren() != 0) out << ')'; + if(n.getNumChildren() != 0) { + out << ')'; + } }/* Smt2Printer::toStream(TNode) */ static string smtKindString(Kind k) throw() { @@ -395,9 +433,10 @@ template static bool tryToStream(std::ostream& out, const Command* c) throw(); void Smt2Printer::toStream(std::ostream& out, const Command* c, - int toDepth, bool types) const throw() { + int toDepth, bool types, size_t dag) const throw() { expr::ExprSetDepth::Scope sdScope(out, toDepth); expr::ExprPrintTypes::Scope ptScope(out, types); + expr::ExprDag::Scope dagScope(out, dag); if(tryToStream(out, c) || tryToStream(out, c) || diff --git a/src/printer/smt2/smt2_printer.h b/src/printer/smt2/smt2_printer.h index a48104e45..fd65a1efa 100644 --- a/src/printer/smt2/smt2_printer.h +++ b/src/printer/smt2/smt2_printer.h @@ -30,9 +30,10 @@ namespace printer { namespace smt2 { class Smt2Printer : public CVC4::Printer { -public: void toStream(std::ostream& out, TNode n, int toDepth, bool types) const throw(); - void toStream(std::ostream& out, const Command* c, int toDepth, bool types) const throw(); +public: + void toStream(std::ostream& out, TNode n, int toDepth, bool types, size_t dag) const throw(); + void toStream(std::ostream& out, const Command* c, int toDepth, bool types, size_t dag) const throw(); void toStream(std::ostream& out, const CommandStatus* s) const throw(); };/* class Smt2Printer */ diff --git a/src/theory/substitutions.h b/src/theory/substitutions.h index c32dee635..ee2a15f6f 100644 --- a/src/theory/substitutions.h +++ b/src/theory/substitutions.h @@ -148,10 +148,14 @@ public: return d_substitutions.end(); } + bool empty() const { + return d_substitutions.empty(); + } + // NOTE [MGD]: removed clear() and swap() from the interface // when this data structure became context-dependent // because they weren't used---and it's not clear how they - // should // best interact with cache invalidation on context + // should best interact with cache invalidation on context // pops. /** @@ -159,7 +163,7 @@ public: */ void print(std::ostream& out) const; -}; +};/* class SubstitutionMap */ inline std::ostream& operator << (std::ostream& out, const SubstitutionMap& subst) { subst.print(out); diff --git a/src/util/node_visitor.h b/src/util/node_visitor.h index 5e04f820d..3714fcccc 100644 --- a/src/util/node_visitor.h +++ b/src/util/node_visitor.h @@ -32,7 +32,7 @@ template class NodeVisitor { /** For re-entry checking */ - static CVC4_THREADLOCAL(bool) d_inRun; + static CVC4_THREADLOCAL(bool) s_inRun; class GuardReentry { bool& d_guard; @@ -69,7 +69,7 @@ public: */ static typename Visitor::return_type run(Visitor& visitor, TNode node) { - GuardReentry guard(bool(d_inRun)); + GuardReentry guard(bool(s_inRun)); // Notify of a start visitor.start(node); @@ -111,7 +111,7 @@ public: }; template -CVC4_THREADLOCAL(bool) NodeVisitor::d_inRun = false; +CVC4_THREADLOCAL(bool) NodeVisitor::s_inRun = false; } diff --git a/src/util/options.cpp b/src/util/options.cpp index 78eea71ad..a6bd9d09a 100644 --- a/src/util/options.cpp +++ b/src/util/options.cpp @@ -194,6 +194,8 @@ Additional CVC4 options:\n\ --show-trace-tags show all avalable tags for tracing\n\ --show-sat-solvers show all available SAT solvers\n\ --default-expr-depth=N print exprs to depth N (0 == default, -1 == no limit)\n\ + --default-dag-thresh=N dagify common subexprs appearing > N times\n\ + (1 == default, 0 == don't dagify)\n\ --print-expr-types print types with variables when printing exprs\n\ --lazy-definition-expansion expand define-funs/LAMBDAs lazily\n\ --simplification=MODE choose simplification mode, see --simplification=help\n\ @@ -454,6 +456,7 @@ enum OptionValue { SHOW_CONFIG, STRICT_PARSING, DEFAULT_EXPR_DEPTH, + DEFAULT_DAG_THRESH, PRINT_EXPR_TYPES, UF_THEORY, LAZY_DEFINITION_EXPANSION, @@ -557,6 +560,7 @@ static struct option cmdlineOptions[] = { { "mmap" , no_argument , NULL, USE_MMAP }, { "strict-parsing", no_argument , NULL, STRICT_PARSING }, { "default-expr-depth", required_argument, NULL, DEFAULT_EXPR_DEPTH }, + { "default-dag-thresh", required_argument, NULL, DEFAULT_DAG_THRESH }, { "print-expr-types", no_argument , NULL, PRINT_EXPR_TYPES }, { "uf" , required_argument, NULL, UF_THEORY }, { "lazy-definition-expansion", no_argument, NULL, LAZY_DEFINITION_EXPANSION }, @@ -832,6 +836,22 @@ throw(OptionException) { } break; + case DEFAULT_DAG_THRESH: + { + int dag = atoi(optarg); + if(dag < 0) { + throw OptionException("--default-dag-thresh requires a nonnegative argument."); + } + Debug.getStream() << Expr::dag(size_t(dag)); + Trace.getStream() << Expr::dag(size_t(dag)); + Notice.getStream() << Expr::dag(size_t(dag)); + Chat.getStream() << Expr::dag(size_t(dag)); + Message.getStream() << Expr::dag(size_t(dag)); + Warning.getStream() << Expr::dag(size_t(dag)); + Dump.getStream() << Expr::dag(size_t(dag)); + } + break; + case PRINT_EXPR_TYPES: Debug.getStream() << Expr::printtypes(true); Trace.getStream() << Expr::printtypes(true); diff --git a/test/unit/expr/node_black.h b/test/unit/expr/node_black.h index 40545e5e8..36a92ec2f 100644 --- a/test/unit/expr/node_black.h +++ b/test/unit/expr/node_black.h @@ -547,11 +547,12 @@ public: Node o = NodeBuilder<>() << n << n << kind::XOR; stringstream sstr; + sstr << Node::dag(false); n.toStream(sstr); TS_ASSERT(sstr.str() == "(AND w (OR x y) z)"); sstr.str(string()); - o.toStream(sstr); + o.toStream(sstr, -1, false, 0); TS_ASSERT(sstr.str() == "(XOR (AND w (OR x y) z) (AND w (OR x y) z))"); sstr.str(string()); -- 2.30.2