From d9c4e43fe7314aa5ddeeca3ca710612e8b1d92a9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Dejan=20Jovanovi=C4=87?= Date: Wed, 16 Feb 2011 01:26:26 +0000 Subject: [PATCH] updates for the rewriter, added some statistics --- src/theory/bv/Makefile.am | 2 +- src/theory/bv/theory_bv_rewrite_rules.h | 203 +++++++++++++----- ...les.cpp => theory_bv_rewrite_rules_core.h} | 198 +++++++---------- src/theory/bv/theory_bv_rewriter.cpp | 37 ++-- src/theory/bv/theory_bv_rewriter.h | 11 +- 5 files changed, 255 insertions(+), 196 deletions(-) rename src/theory/bv/{theory_bv_rewrite_rules.cpp => theory_bv_rewrite_rules_core.h} (53%) diff --git a/src/theory/bv/Makefile.am b/src/theory/bv/Makefile.am index 3e84f482c..fdace42b4 100644 --- a/src/theory/bv/Makefile.am +++ b/src/theory/bv/Makefile.am @@ -10,7 +10,7 @@ libbv_la_SOURCES = \ theory_bv.cpp \ theory_bv_utils.h \ theory_bv_rewrite_rules.h \ - theory_bv_rewrite_rules.cpp \ + theory_bv_rewrite_rules_core.h \ theory_bv_type_rules.h \ theory_bv_rewriter.h \ theory_bv_rewriter.cpp \ diff --git a/src/theory/bv/theory_bv_rewrite_rules.h b/src/theory/bv/theory_bv_rewrite_rules.h index eba8f917c..32d0f92a0 100644 --- a/src/theory/bv/theory_bv_rewrite_rules.h +++ b/src/theory/bv/theory_bv_rewrite_rules.h @@ -22,108 +22,193 @@ #include "cvc4_private.h" #include "theory/theory.h" #include "context/context.h" +#include "util/stats.h" +#include namespace CVC4 { namespace theory { namespace bv { -struct CoreRewriteRules { +enum RewriteRuleId { + EmptyRule, + ConcatFlatten, + ConcatExtractMerge, + ConcatConstantMerge, + ExtractExtract, + ExtractWhole, + ExtractConcat, + ExtractConstant, + FailEq, + SimplifyEq, + ReflexivityEq, +}; - struct EmptyRule { - static inline Node apply(Node node) { return node; } - static inline bool applies(Node node) { return false; } - }; +inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { + switch (ruleId) { + case EmptyRule: out << "EmptyRule"; return out; + case ConcatFlatten: out << "ConcatFlatten"; return out; + case ConcatExtractMerge: out << "ConcatExtractMerge"; return out; + case ConcatConstantMerge: out << "ConcatConstantMerge"; return out; + case ExtractExtract: out << "ExtractExtract"; return out; + case ExtractWhole: out << "ExtractWhole"; return out; + case ExtractConcat: out << "ExtractConcat"; return out; + case ExtractConstant: out << "ExtractConstant"; return out; + case FailEq: out << "FailEq"; return out; + case SimplifyEq: out << "SimplifyEq"; return out; + case ReflexivityEq: out << "ReflexivityEq"; return out; + default: + Unreachable(); + } +}; - struct ConcatFlatten { - static Node apply(Node node); - static bool applies(Node node); - }; +template +class RewriteRule { - struct ConcatExtractMerge { - static Node apply(Node node); - static bool applies(Node node); - }; + class RuleStatistics { - struct ConcatConstantMerge { - static Node apply(Node node); - static bool applies(Node node); - }; + /** The name of the rule prefixed with the prefix */ + static std::string getStatName(const char* prefix) { + std::stringstream statName; + statName << prefix << rule; + return statName.str(); + } - struct ExtractExtract { - static Node apply(Node node); - static bool applies(Node node); - }; + public: - struct ExtractWhole { - static Node apply(Node node); - static bool applies(Node node); - }; + /** Number of applications of this rule */ + IntStat d_ruleApplications; - struct ExtractConcat { - static Node apply(Node node); - static bool applies(Node node); - }; + /** Constructor */ + RuleStatistics() + : d_ruleApplications(getStatName("theory::bv::count"), 0) { + StatisticsRegistry::registerStat(&d_ruleApplications); + } - struct ExtractConstant { - static Node apply(Node node); - static bool applies(Node node); + /** Destructor */ + ~RuleStatistics() { + StatisticsRegistry::unregisterStat(&d_ruleApplications); + } }; - struct FailEq { - static Node apply(Node node); - static bool applies(Node node); - }; + /** Statistics about the rule */ + static RuleStatistics* s_statictics; - struct SimplifyEq { - static Node apply(Node node); - static bool applies(Node node); - }; + /** Actually apply the rewrite rule */ + static inline Node apply(Node node) { + Unreachable(); + } + +public: + RewriteRule() { + if (s_statictics == NULL) { + s_statictics = new RuleStatistics(); + } + } + + ~RewriteRule() { + delete s_statictics; + s_statictics = NULL; + } + + static inline bool applies(Node node) { + Unreachable(); + } + + template + static inline Node run(Node node) { + if (!checkApplies || applies(node)) { + Debug("theory::bv::rewrite") << "RewriteRule<" << rule << ">(" << node << ")" << std::endl; + Assert(checkApplies || applies(node)); + ++ s_statictics->d_ruleApplications; + Node result = apply(node); + Debug("theory::bv::rewrite") << "RewriteRule<" << rule << ">(" << node << ") => " << result << std::endl; + return result; + } else { + return node; + } + } +}; + +template +typename RewriteRule::RuleStatistics* RewriteRule::s_statictics = NULL; + +/** Have to list all the rewrite rules to get the statistics out */ +struct AllRewriteRules { + RewriteRule rule00; + RewriteRule rule01; + RewriteRule rule02; + RewriteRule rule03; + RewriteRule rule04; + RewriteRule rule05; + RewriteRule rule06; + RewriteRule rule07; + RewriteRule rule08; + RewriteRule rule09; + RewriteRule rule10; }; -template +template<> +bool RewriteRule::applies(Node node) { + return false; +} + +template<> +Node RewriteRule::apply(Node node) { + Unreachable(); + return node; +} + +template struct ApplyRuleToChildren { static Node apply(Node node) { if (node.getKind() != kind) { - if (Rule::applies(node)) return Rule::apply(node); - else return node; + return RewriteRule::template run(node); } NodeBuilder<> result(kind); for (unsigned i = 0, end = node.getNumChildren(); i < end; ++ i) { - if (Rule::applies(node[i])) result << Rule::apply(node[i]); - else result << node[i]; + result << RewriteRule::template run(node[i]); } return result; } static bool applies(Node node) { if (node.getKind() == kind) return true; - return Rule::applies(node); + return RewriteRule::applies(node); } + template + static Node run(Node node) { + if (!checkApplies || applies(node)) { + return apply(node); + } else { + return node; + } + } }; - template < typename R1, typename R2, - typename R3 = CoreRewriteRules::EmptyRule, - typename R4 = CoreRewriteRules::EmptyRule, - typename R5 = CoreRewriteRules::EmptyRule, - typename R6 = CoreRewriteRules::EmptyRule, - typename R7 = CoreRewriteRules::EmptyRule + typename R3 = RewriteRule, + typename R4 = RewriteRule, + typename R5 = RewriteRule, + typename R6 = RewriteRule, + typename R7 = RewriteRule, + typename R8 = RewriteRule > struct LinearRewriteStrategy { static Node apply(Node node) { Node current = node; - if (R1::applies(current)) current = R1::apply(current); - if (R2::applies(current)) current = R2::apply(current); - if (R3::applies(current)) current = R3::apply(current); - if (R4::applies(current)) current = R4::apply(current); - if (R5::applies(current)) current = R5::apply(current); - if (R6::applies(current)) current = R6::apply(current); - if (R7::applies(current)) current = R7::apply(current); + if (R1::applies(current)) current = R1::template run(current); + if (R2::applies(current)) current = R2::template run(current); + if (R3::applies(current)) current = R3::template run(current); + if (R4::applies(current)) current = R4::template run(current); + if (R5::applies(current)) current = R5::template run(current); + if (R6::applies(current)) current = R6::template run(current); + if (R7::applies(current)) current = R7::template run(current); + if (R8::applies(current)) current = R8::template run(current); return current; } }; diff --git a/src/theory/bv/theory_bv_rewrite_rules.cpp b/src/theory/bv/theory_bv_rewrite_rules_core.h similarity index 53% rename from src/theory/bv/theory_bv_rewrite_rules.cpp rename to src/theory/bv/theory_bv_rewrite_rules_core.h index d2fb621f9..e75f53711 100644 --- a/src/theory/bv/theory_bv_rewrite_rules.cpp +++ b/src/theory/bv/theory_bv_rewrite_rules_core.h @@ -1,5 +1,5 @@ /********************* */ -/*! \file theory_bv_rewrite_rules.cpp +/*! \file theory_bv_rewrite_rules_core.h ** \verbatim ** Original author: dejan ** Major contributors: none @@ -17,26 +17,22 @@ ** \todo document this file **/ -#include -#include "expr/node_builder.h" -#include "theory_bv_rewrite_rules.h" -#include "theory_bv_utils.h" +#pragma once -using namespace std; -using namespace CVC4; -using namespace CVC4::theory; -using namespace CVC4::theory::bv; -using namespace CVC4::theory::bv::utils; +#include "theory/bv/theory_bv_rewrite_rules.h" +#include "theory/bv/theory_bv_utils.h" -bool CoreRewriteRules::ConcatFlatten::applies(Node node) { +namespace CVC4 { +namespace theory { +namespace bv { + +template<> +bool RewriteRule::applies(Node node) { return (node.getKind() == kind::BITVECTOR_CONCAT); } -Node CoreRewriteRules::ConcatFlatten::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ConcatFlatten(" << node << ")" << endl; - +template<> +Node RewriteRule::apply(Node node) { NodeBuilder<> result(kind::BITVECTOR_CONCAT); vector processing_stack; processing_stack.push_back(node); @@ -50,22 +46,17 @@ Node CoreRewriteRules::ConcatFlatten::apply(Node node) { result << current; } } - Node resultNode = result; - Debug("bitvector") << "ConcatFlatten(" << node << ") => " << resultNode << endl; - return resultNode; } -bool CoreRewriteRules::ConcatExtractMerge::applies(Node node) { +template<> +bool RewriteRule::applies(Node node) { return (node.getKind() == kind::BITVECTOR_CONCAT); } -Node CoreRewriteRules::ConcatExtractMerge::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ConcatExtractMerge(" << node << ")" << endl; - +template<> +Node RewriteRule::apply(Node node) { vector mergedExtracts; Node current = node[0]; @@ -84,16 +75,16 @@ Node CoreRewriteRules::ConcatExtractMerge::apply(Node node) { } // If it is an extract and the first one, get the extract parameters else if (!mergeStarted) { - currentHigh = getExtractHigh(current); - currentLow = getExtractLow(current); + currentHigh = utils::getExtractHigh(current); + currentLow = utils::getExtractLow(current); } // If the next one can be merged, try to merge bool merged = false; if (next.getKind() == kind::BITVECTOR_EXTRACT && current[0] == next[0]) { //x[i : j] @ x[j − 1 : k] -> c x[i : k] - unsigned nextHigh = getExtractHigh(next); - unsigned nextLow = getExtractLow(next); + unsigned nextHigh = utils::getExtractHigh(next); + unsigned nextLow = utils::getExtractLow(next); if(nextHigh + 1 == currentLow) { currentLow = nextLow; mergeStarted = true; @@ -103,7 +94,7 @@ Node CoreRewriteRules::ConcatExtractMerge::apply(Node node) { // If we haven't merged anything, add the previous merge and continue with the next if (!merged) { if (!mergeStarted) mergedExtracts.push_back(current); - else mergedExtracts.push_back(mkExtract(current[0], currentHigh, currentLow)); + else mergedExtracts.push_back(utils::mkExtract(current[0], currentHigh, currentLow)); current = next; mergeStarted = false; } @@ -111,26 +102,19 @@ Node CoreRewriteRules::ConcatExtractMerge::apply(Node node) { // Add the last child if (!mergeStarted) mergedExtracts.push_back(current); - else mergedExtracts.push_back(mkExtract(current[0], currentHigh, currentLow)); - - // Create the result - Node result = mkConcat(mergedExtracts); - - Debug("bitvector") << "ConcatExtractMerge(" << node << ") =>" << result << endl; + else mergedExtracts.push_back(utils::mkExtract(current[0], currentHigh, currentLow)); // Return the result - return result; + return utils::mkConcat(mergedExtracts); } -bool CoreRewriteRules::ConcatConstantMerge::applies(Node node) { +template<> +bool RewriteRule::applies(Node node) { return node.getKind() == kind::BITVECTOR_CONCAT; } -Node CoreRewriteRules::ConcatConstantMerge::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ConcatConstantMerge(" << node << ")" << endl; - +template<> +Node RewriteRule::apply(Node node) { vector mergedConstants; for (unsigned i = 0, end = node.getNumChildren(); i < end;) { if (node[i].getKind() != kind::CONST_BITVECTOR) { @@ -153,82 +137,66 @@ Node CoreRewriteRules::ConcatConstantMerge::apply(Node node) { current = current.concat(node[k].getConst()); } // Add the new merged constant - mergedConstants.push_back(mkConst(current)); + mergedConstants.push_back(utils::mkConst(current)); i = j + 1; } } - Node result = mkConcat(mergedConstants); - - Debug("bitvector") << "ConcatConstantMerge(" << node << ") => " << result << endl; - - return result; + return utils::mkConcat(mergedConstants); } -bool CoreRewriteRules::ExtractWhole::applies(Node node) { +template<> +bool RewriteRule::applies(Node node) { if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; - unsigned length = getSize(node[0]); - unsigned extractHigh = getExtractHigh(node); + unsigned length = utils::getSize(node[0]); + unsigned extractHigh = utils::getExtractHigh(node); if (extractHigh != length - 1) return false; - unsigned extractLow = getExtractLow(node); + unsigned extractLow = utils::getExtractLow(node); if (extractLow != 0) return false; return true; } -Node CoreRewriteRules::ExtractWhole::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ExtractWhole(" << node << ")" << endl; - Debug("bitvector") << "ExtractWhole(" << node << ") => " << node[0] << endl; - +template<> +Node RewriteRule::apply(Node node) { return node[0]; } -bool CoreRewriteRules::ExtractConstant::applies(Node node) { +template<> +bool RewriteRule::applies(Node node) { if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; if (node[0].getKind() != kind::CONST_BITVECTOR) return false; return true; } -Node CoreRewriteRules::ExtractConstant::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ExtractConstant(" << node << ")" << endl; - +template<> +Node RewriteRule::apply(Node node) { Node child = node[0]; BitVector childValue = child.getConst(); - - Node result = mkConst(childValue.extract(getExtractHigh(node), getExtractLow(node))); - - Debug("bitvector") << "ExtractConstant(" << node << ") => " << result << endl; - - return result; + return utils::mkConst(childValue.extract(utils::getExtractHigh(node), utils::getExtractLow(node))); } -bool CoreRewriteRules::ExtractConcat::applies(Node node) { +template<> +bool RewriteRule::applies(Node node) { if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; if (node[0].getKind() != kind::BITVECTOR_CONCAT) return false; return true; } -Node CoreRewriteRules::ExtractConcat::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ExtractConcat(" << node << ")" << endl; - - int extract_high = getExtractHigh(node); - int extract_low = getExtractLow(node); +template<> +Node RewriteRule::apply(Node node) { + int extract_high = utils::getExtractHigh(node); + int extract_low = utils::getExtractLow(node); vector resultChildren; Node concat = node[0]; for (int i = concat.getNumChildren() - 1; i >= 0 && extract_high >= 0; i--) { Node concatChild = concat[i]; - int concatChildSize = getSize(concatChild); + int concatChildSize = utils::getSize(concatChild); if (extract_low < concatChildSize) { int extract_start = extract_low < 0 ? 0 : extract_low; int extract_end = extract_high < concatChildSize ? extract_high : concatChildSize - 1; - resultChildren.push_back(mkExtract(concatChild, extract_end, extract_start)); + resultChildren.push_back(utils::mkExtract(concatChild, extract_end, extract_start)); } extract_low -= concatChildSize; extract_high -= concatChildSize; @@ -236,70 +204,62 @@ Node CoreRewriteRules::ExtractConcat::apply(Node node) { std::reverse(resultChildren.begin(), resultChildren.end()); - Node result = mkConcat(resultChildren); - - Debug("bitvector") << "ExtractConcat(" << node << ") => " << result << endl; - - return result; + return utils::mkConcat(resultChildren); } -bool CoreRewriteRules::ExtractExtract::applies(Node node) { +template<> +bool RewriteRule::applies(Node node) { if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; if (node[0].getKind() != kind::BITVECTOR_EXTRACT) return false; return true; } -Node CoreRewriteRules::ExtractExtract::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ExtractExtract(" << node << ")" << endl; - +template<> +Node RewriteRule::apply(Node node) { // x[i:j][k:l] ~> x[k+j:l+j] Node child = node[0]; - unsigned k = getExtractHigh(node); - unsigned l = getExtractLow(node); - unsigned j = getExtractLow(child); - - Node result = mkExtract(child[0], k + j, l + j); - - Debug("bitvector") << "ExtractExtract(" << node << ") => " << result << endl; + unsigned k = utils::getExtractHigh(node); + unsigned l = utils::getExtractLow(node); + unsigned j = utils::getExtractLow(child); + Node result = utils::mkExtract(child[0], k + j, l + j); return result; } -bool CoreRewriteRules::FailEq::applies(Node node) { +template<> +bool RewriteRule::applies(Node node) { if (node.getKind() != kind::EQUAL) return false; if (node[0].getKind() != kind::CONST_BITVECTOR) return false; if (node[1].getKind() != kind::CONST_BITVECTOR) return false; return node[0] != node[1]; } -Node CoreRewriteRules::FailEq::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "FailEq(" << node << ")" << endl; - - Node result = mkFalse(); - - Debug("bitvector") << "FailEq(" << node << ") => " << result << endl; - - return result; +template<> +Node RewriteRule::apply(Node node) { + return utils::mkFalse(); } -bool CoreRewriteRules::SimplifyEq::applies(Node node) { +template<> +bool RewriteRule::applies(Node node) { if (node.getKind() != kind::EQUAL) return false; return node[0] == node[1]; } -Node CoreRewriteRules::SimplifyEq::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "FailEq(" << node << ")" << endl; - - Node result = mkTrue(); +template<> +Node RewriteRule::apply(Node node) { + return utils::mkTrue(); +} - Debug("bitvector") << "FailEq(" << node << ") => " << result << endl; +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::EQUAL && node[0] < node[1]); +} - return result; +template<> +Node RewriteRule::apply(Node node) { + return node[1].eqNode(node[0]);; } +} +} +} diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index cd2efd64f..08245afcb 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -5,8 +5,10 @@ * Author: dejan */ +#include "theory/theory.h" #include "theory/bv/theory_bv_rewriter.h" #include "theory/bv/theory_bv_rewrite_rules.h" +#include "theory/bv/theory_bv_rewrite_rules_core.h" using namespace CVC4; using namespace CVC4::theory; @@ -18,41 +20,43 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { Node result; - if (node.getKind() == kind::CONST_BITVECTOR /* || isLeaf(n)) */) + if (node.getKind() == kind::CONST_BITVECTOR || (node.getKind() != kind::EQUAL && Theory::isLeafOf(node, THEORY_BV))) { result = node; - else { + } else { switch (node.getKind()) { case kind::BITVECTOR_CONCAT: result = LinearRewriteStrategy< // Flatten the top level concatenations - CoreRewriteRules::ConcatFlatten, + RewriteRule, // Merge the adjacent extracts on non-constants - CoreRewriteRules::ConcatExtractMerge, + RewriteRule, // Merge the adjacent extracts on constants - CoreRewriteRules::ConcatConstantMerge, + RewriteRule, // At this point only Extract-Whole could apply, if the result is only one extract // or at some sub-expression if the result is a concatenation. - ApplyRuleToChildren + ApplyRuleToChildren >::apply(node); break; case kind::BITVECTOR_EXTRACT: result = LinearRewriteStrategy< // Extract over a constant gives a constant - CoreRewriteRules::ExtractConstant, + RewriteRule, // Extract over an extract is simplified to one extract - CoreRewriteRules::ExtractExtract, + RewriteRule, // Extract over a concatenation is distributed to the appropriate concatenations - CoreRewriteRules::ExtractConcat, + RewriteRule, // At this point only Extract-Whole could apply - CoreRewriteRules::ExtractWhole + RewriteRule >::apply(node); break; case kind::EQUAL: result = LinearRewriteStrategy< // Two distinct values rewrite to false - CoreRewriteRules::FailEq, + RewriteRule, // If both sides are equal equality is true - CoreRewriteRules::SimplifyEq + RewriteRule, + // Normalize the equalities + RewriteRule >::apply(node); break; default: @@ -68,3 +72,12 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { return RewriteResponse(REWRITE_DONE, result); } +AllRewriteRules* TheoryBVRewriter::s_allRules = NULL; + +void TheoryBVRewriter::init() { + s_allRules = new AllRewriteRules; +} + +void TheoryBVRewriter::shutdown() { + delete s_allRules; +} diff --git a/src/theory/bv/theory_bv_rewriter.h b/src/theory/bv/theory_bv_rewriter.h index 741b9fcbc..437ac49d3 100644 --- a/src/theory/bv/theory_bv_rewriter.h +++ b/src/theory/bv/theory_bv_rewriter.h @@ -7,16 +7,18 @@ #pragma once - - #include "theory/rewriter.h" namespace CVC4 { namespace theory { namespace bv { +class AllRewriteRules; + class TheoryBVRewriter { + static AllRewriteRules* s_allRules; + public: static RewriteResponse postRewrite(TNode node); @@ -25,9 +27,8 @@ public: return RewriteResponse(REWRITE_DONE, node); } - static inline void init() {} - static inline void shutdown() {} - + static void init(); + static void shutdown(); }; } -- 2.30.2