From: Dejan Jovanović Date: Wed, 16 Feb 2011 01:26:26 +0000 (+0000) Subject: updates for the rewriter, added some statistics X-Git-Tag: cvc5-1.0.0~8712 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=d9c4e43fe7314aa5ddeeca3ca710612e8b1d92a9;p=cvc5.git updates for the rewriter, added some statistics --- diff --git a/src/theory/bv/Makefile.am b/src/theory/bv/Makefile.am index 3e84f482c..fdace42b4 100644 --- a/src/theory/bv/Makefile.am +++ b/src/theory/bv/Makefile.am @@ -10,7 +10,7 @@ libbv_la_SOURCES = \ theory_bv.cpp \ theory_bv_utils.h \ theory_bv_rewrite_rules.h \ - theory_bv_rewrite_rules.cpp \ + theory_bv_rewrite_rules_core.h \ theory_bv_type_rules.h \ theory_bv_rewriter.h \ theory_bv_rewriter.cpp \ diff --git a/src/theory/bv/theory_bv_rewrite_rules.cpp b/src/theory/bv/theory_bv_rewrite_rules.cpp deleted file mode 100644 index d2fb621f9..000000000 --- a/src/theory/bv/theory_bv_rewrite_rules.cpp +++ /dev/null @@ -1,305 +0,0 @@ -/********************* */ -/*! \file theory_bv_rewrite_rules.cpp - ** \verbatim - ** Original author: dejan - ** Major contributors: none - ** Minor contributors (to current version): none - ** This file is part of the CVC4 prototype. - ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys) - ** Courant Institute of Mathematical Sciences - ** New York University - ** See the file COPYING in the top-level source directory for licensing - ** information.\endverbatim - ** - ** \brief [[ Add one-line brief description here ]] - ** - ** [[ Add lengthier description here ]] - ** \todo document this file - **/ - -#include -#include "expr/node_builder.h" -#include "theory_bv_rewrite_rules.h" -#include "theory_bv_utils.h" - -using namespace std; -using namespace CVC4; -using namespace CVC4::theory; -using namespace CVC4::theory::bv; -using namespace CVC4::theory::bv::utils; - -bool CoreRewriteRules::ConcatFlatten::applies(Node node) { - return (node.getKind() == kind::BITVECTOR_CONCAT); -} - -Node CoreRewriteRules::ConcatFlatten::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ConcatFlatten(" << node << ")" << endl; - - NodeBuilder<> result(kind::BITVECTOR_CONCAT); - vector processing_stack; - processing_stack.push_back(node); - while (!processing_stack.empty()) { - Node current = processing_stack.back(); - processing_stack.pop_back(); - if (current.getKind() == kind::BITVECTOR_CONCAT) { - for (int i = current.getNumChildren() - 1; i >= 0; i--) - processing_stack.push_back(current[i]); - } else { - result << current; - } - } - - Node resultNode = result; - Debug("bitvector") << "ConcatFlatten(" << node << ") => " << resultNode << endl; - - return resultNode; -} - -bool CoreRewriteRules::ConcatExtractMerge::applies(Node node) { - return (node.getKind() == kind::BITVECTOR_CONCAT); -} - -Node CoreRewriteRules::ConcatExtractMerge::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ConcatExtractMerge(" << node << ")" << endl; - - vector mergedExtracts; - - Node current = node[0]; - bool mergeStarted = false; - unsigned currentHigh = 0; - unsigned currentLow = 0; - - for(size_t i = 1, end = node.getNumChildren(); i < end; ++ i) { - // The next candidate for merging - Node next = node[i]; - // If the current is not an extract we just go to the next - if (current.getKind() != kind::BITVECTOR_EXTRACT) { - mergedExtracts.push_back(current); - current = next; - continue; - } - // If it is an extract and the first one, get the extract parameters - else if (!mergeStarted) { - currentHigh = getExtractHigh(current); - currentLow = getExtractLow(current); - } - - // If the next one can be merged, try to merge - bool merged = false; - if (next.getKind() == kind::BITVECTOR_EXTRACT && current[0] == next[0]) { - //x[i : j] @ x[j − 1 : k] -> c x[i : k] - unsigned nextHigh = getExtractHigh(next); - unsigned nextLow = getExtractLow(next); - if(nextHigh + 1 == currentLow) { - currentLow = nextLow; - mergeStarted = true; - merged = true; - } - } - // If we haven't merged anything, add the previous merge and continue with the next - if (!merged) { - if (!mergeStarted) mergedExtracts.push_back(current); - else mergedExtracts.push_back(mkExtract(current[0], currentHigh, currentLow)); - current = next; - mergeStarted = false; - } - } - - // Add the last child - if (!mergeStarted) mergedExtracts.push_back(current); - else mergedExtracts.push_back(mkExtract(current[0], currentHigh, currentLow)); - - // Create the result - Node result = mkConcat(mergedExtracts); - - Debug("bitvector") << "ConcatExtractMerge(" << node << ") =>" << result << endl; - - // Return the result - return result; -} - -bool CoreRewriteRules::ConcatConstantMerge::applies(Node node) { - return node.getKind() == kind::BITVECTOR_CONCAT; -} - -Node CoreRewriteRules::ConcatConstantMerge::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ConcatConstantMerge(" << node << ")" << endl; - - vector mergedConstants; - for (unsigned i = 0, end = node.getNumChildren(); i < end;) { - if (node[i].getKind() != kind::CONST_BITVECTOR) { - // If not a constant, just add it - mergedConstants.push_back(node[i]); - ++i; - } else { - // Find the longest sequence of constants - unsigned j = i + 1; - while (j < end) { - if (node[j].getKind() != kind::CONST_BITVECTOR) { - break; - } else { - ++ j; - } - } - // Append all the constants - BitVector current = node[i].getConst(); - for (unsigned k = i + 1; k < j; ++ k) { - current = current.concat(node[k].getConst()); - } - // Add the new merged constant - mergedConstants.push_back(mkConst(current)); - i = j + 1; - } - } - - Node result = mkConcat(mergedConstants); - - Debug("bitvector") << "ConcatConstantMerge(" << node << ") => " << result << endl; - - return result; -} - -bool CoreRewriteRules::ExtractWhole::applies(Node node) { - if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; - unsigned length = getSize(node[0]); - unsigned extractHigh = getExtractHigh(node); - if (extractHigh != length - 1) return false; - unsigned extractLow = getExtractLow(node); - if (extractLow != 0) return false; - return true; -} - -Node CoreRewriteRules::ExtractWhole::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ExtractWhole(" << node << ")" << endl; - Debug("bitvector") << "ExtractWhole(" << node << ") => " << node[0] << endl; - - return node[0]; -} - -bool CoreRewriteRules::ExtractConstant::applies(Node node) { - if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; - if (node[0].getKind() != kind::CONST_BITVECTOR) return false; - return true; -} - -Node CoreRewriteRules::ExtractConstant::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ExtractConstant(" << node << ")" << endl; - - Node child = node[0]; - BitVector childValue = child.getConst(); - - Node result = mkConst(childValue.extract(getExtractHigh(node), getExtractLow(node))); - - Debug("bitvector") << "ExtractConstant(" << node << ") => " << result << endl; - - return result; -} - -bool CoreRewriteRules::ExtractConcat::applies(Node node) { - if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; - if (node[0].getKind() != kind::BITVECTOR_CONCAT) return false; - return true; -} - -Node CoreRewriteRules::ExtractConcat::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ExtractConcat(" << node << ")" << endl; - - int extract_high = getExtractHigh(node); - int extract_low = getExtractLow(node); - - vector resultChildren; - - Node concat = node[0]; - for (int i = concat.getNumChildren() - 1; i >= 0 && extract_high >= 0; i--) { - Node concatChild = concat[i]; - int concatChildSize = getSize(concatChild); - if (extract_low < concatChildSize) { - int extract_start = extract_low < 0 ? 0 : extract_low; - int extract_end = extract_high < concatChildSize ? extract_high : concatChildSize - 1; - resultChildren.push_back(mkExtract(concatChild, extract_end, extract_start)); - } - extract_low -= concatChildSize; - extract_high -= concatChildSize; - } - - std::reverse(resultChildren.begin(), resultChildren.end()); - - Node result = mkConcat(resultChildren); - - Debug("bitvector") << "ExtractConcat(" << node << ") => " << result << endl; - - return result; -} - -bool CoreRewriteRules::ExtractExtract::applies(Node node) { - if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; - if (node[0].getKind() != kind::BITVECTOR_EXTRACT) return false; - return true; -} - -Node CoreRewriteRules::ExtractExtract::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "ExtractExtract(" << node << ")" << endl; - - // x[i:j][k:l] ~> x[k+j:l+j] - Node child = node[0]; - unsigned k = getExtractHigh(node); - unsigned l = getExtractLow(node); - unsigned j = getExtractLow(child); - - Node result = mkExtract(child[0], k + j, l + j); - - Debug("bitvector") << "ExtractExtract(" << node << ") => " << result << endl; - - return result; -} - -bool CoreRewriteRules::FailEq::applies(Node node) { - if (node.getKind() != kind::EQUAL) return false; - if (node[0].getKind() != kind::CONST_BITVECTOR) return false; - if (node[1].getKind() != kind::CONST_BITVECTOR) return false; - return node[0] != node[1]; -} - -Node CoreRewriteRules::FailEq::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "FailEq(" << node << ")" << endl; - - Node result = mkFalse(); - - Debug("bitvector") << "FailEq(" << node << ") => " << result << endl; - - return result; -} - -bool CoreRewriteRules::SimplifyEq::applies(Node node) { - if (node.getKind() != kind::EQUAL) return false; - return node[0] == node[1]; -} - -Node CoreRewriteRules::SimplifyEq::apply(Node node) { - Assert(applies(node)); - - Debug("bitvector") << "FailEq(" << node << ")" << endl; - - Node result = mkTrue(); - - Debug("bitvector") << "FailEq(" << node << ") => " << result << endl; - - return result; -} - diff --git a/src/theory/bv/theory_bv_rewrite_rules.h b/src/theory/bv/theory_bv_rewrite_rules.h index eba8f917c..32d0f92a0 100644 --- a/src/theory/bv/theory_bv_rewrite_rules.h +++ b/src/theory/bv/theory_bv_rewrite_rules.h @@ -22,108 +22,193 @@ #include "cvc4_private.h" #include "theory/theory.h" #include "context/context.h" +#include "util/stats.h" +#include namespace CVC4 { namespace theory { namespace bv { -struct CoreRewriteRules { +enum RewriteRuleId { + EmptyRule, + ConcatFlatten, + ConcatExtractMerge, + ConcatConstantMerge, + ExtractExtract, + ExtractWhole, + ExtractConcat, + ExtractConstant, + FailEq, + SimplifyEq, + ReflexivityEq, +}; - struct EmptyRule { - static inline Node apply(Node node) { return node; } - static inline bool applies(Node node) { return false; } - }; +inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { + switch (ruleId) { + case EmptyRule: out << "EmptyRule"; return out; + case ConcatFlatten: out << "ConcatFlatten"; return out; + case ConcatExtractMerge: out << "ConcatExtractMerge"; return out; + case ConcatConstantMerge: out << "ConcatConstantMerge"; return out; + case ExtractExtract: out << "ExtractExtract"; return out; + case ExtractWhole: out << "ExtractWhole"; return out; + case ExtractConcat: out << "ExtractConcat"; return out; + case ExtractConstant: out << "ExtractConstant"; return out; + case FailEq: out << "FailEq"; return out; + case SimplifyEq: out << "SimplifyEq"; return out; + case ReflexivityEq: out << "ReflexivityEq"; return out; + default: + Unreachable(); + } +}; - struct ConcatFlatten { - static Node apply(Node node); - static bool applies(Node node); - }; +template +class RewriteRule { - struct ConcatExtractMerge { - static Node apply(Node node); - static bool applies(Node node); - }; + class RuleStatistics { - struct ConcatConstantMerge { - static Node apply(Node node); - static bool applies(Node node); - }; + /** The name of the rule prefixed with the prefix */ + static std::string getStatName(const char* prefix) { + std::stringstream statName; + statName << prefix << rule; + return statName.str(); + } - struct ExtractExtract { - static Node apply(Node node); - static bool applies(Node node); - }; + public: - struct ExtractWhole { - static Node apply(Node node); - static bool applies(Node node); - }; + /** Number of applications of this rule */ + IntStat d_ruleApplications; - struct ExtractConcat { - static Node apply(Node node); - static bool applies(Node node); - }; + /** Constructor */ + RuleStatistics() + : d_ruleApplications(getStatName("theory::bv::count"), 0) { + StatisticsRegistry::registerStat(&d_ruleApplications); + } - struct ExtractConstant { - static Node apply(Node node); - static bool applies(Node node); + /** Destructor */ + ~RuleStatistics() { + StatisticsRegistry::unregisterStat(&d_ruleApplications); + } }; - struct FailEq { - static Node apply(Node node); - static bool applies(Node node); - }; + /** Statistics about the rule */ + static RuleStatistics* s_statictics; - struct SimplifyEq { - static Node apply(Node node); - static bool applies(Node node); - }; + /** Actually apply the rewrite rule */ + static inline Node apply(Node node) { + Unreachable(); + } + +public: + RewriteRule() { + if (s_statictics == NULL) { + s_statictics = new RuleStatistics(); + } + } + + ~RewriteRule() { + delete s_statictics; + s_statictics = NULL; + } + + static inline bool applies(Node node) { + Unreachable(); + } + + template + static inline Node run(Node node) { + if (!checkApplies || applies(node)) { + Debug("theory::bv::rewrite") << "RewriteRule<" << rule << ">(" << node << ")" << std::endl; + Assert(checkApplies || applies(node)); + ++ s_statictics->d_ruleApplications; + Node result = apply(node); + Debug("theory::bv::rewrite") << "RewriteRule<" << rule << ">(" << node << ") => " << result << std::endl; + return result; + } else { + return node; + } + } +}; + +template +typename RewriteRule::RuleStatistics* RewriteRule::s_statictics = NULL; + +/** Have to list all the rewrite rules to get the statistics out */ +struct AllRewriteRules { + RewriteRule rule00; + RewriteRule rule01; + RewriteRule rule02; + RewriteRule rule03; + RewriteRule rule04; + RewriteRule rule05; + RewriteRule rule06; + RewriteRule rule07; + RewriteRule rule08; + RewriteRule rule09; + RewriteRule rule10; }; -template +template<> +bool RewriteRule::applies(Node node) { + return false; +} + +template<> +Node RewriteRule::apply(Node node) { + Unreachable(); + return node; +} + +template struct ApplyRuleToChildren { static Node apply(Node node) { if (node.getKind() != kind) { - if (Rule::applies(node)) return Rule::apply(node); - else return node; + return RewriteRule::template run(node); } NodeBuilder<> result(kind); for (unsigned i = 0, end = node.getNumChildren(); i < end; ++ i) { - if (Rule::applies(node[i])) result << Rule::apply(node[i]); - else result << node[i]; + result << RewriteRule::template run(node[i]); } return result; } static bool applies(Node node) { if (node.getKind() == kind) return true; - return Rule::applies(node); + return RewriteRule::applies(node); } + template + static Node run(Node node) { + if (!checkApplies || applies(node)) { + return apply(node); + } else { + return node; + } + } }; - template < typename R1, typename R2, - typename R3 = CoreRewriteRules::EmptyRule, - typename R4 = CoreRewriteRules::EmptyRule, - typename R5 = CoreRewriteRules::EmptyRule, - typename R6 = CoreRewriteRules::EmptyRule, - typename R7 = CoreRewriteRules::EmptyRule + typename R3 = RewriteRule, + typename R4 = RewriteRule, + typename R5 = RewriteRule, + typename R6 = RewriteRule, + typename R7 = RewriteRule, + typename R8 = RewriteRule > struct LinearRewriteStrategy { static Node apply(Node node) { Node current = node; - if (R1::applies(current)) current = R1::apply(current); - if (R2::applies(current)) current = R2::apply(current); - if (R3::applies(current)) current = R3::apply(current); - if (R4::applies(current)) current = R4::apply(current); - if (R5::applies(current)) current = R5::apply(current); - if (R6::applies(current)) current = R6::apply(current); - if (R7::applies(current)) current = R7::apply(current); + if (R1::applies(current)) current = R1::template run(current); + if (R2::applies(current)) current = R2::template run(current); + if (R3::applies(current)) current = R3::template run(current); + if (R4::applies(current)) current = R4::template run(current); + if (R5::applies(current)) current = R5::template run(current); + if (R6::applies(current)) current = R6::template run(current); + if (R7::applies(current)) current = R7::template run(current); + if (R8::applies(current)) current = R8::template run(current); return current; } }; diff --git a/src/theory/bv/theory_bv_rewrite_rules_core.h b/src/theory/bv/theory_bv_rewrite_rules_core.h new file mode 100644 index 000000000..e75f53711 --- /dev/null +++ b/src/theory/bv/theory_bv_rewrite_rules_core.h @@ -0,0 +1,265 @@ +/********************* */ +/*! \file theory_bv_rewrite_rules_core.h + ** \verbatim + ** Original author: dejan + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + +#pragma once + +#include "theory/bv/theory_bv_rewrite_rules.h" +#include "theory/bv/theory_bv_utils.h" + +namespace CVC4 { +namespace theory { +namespace bv { + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_CONCAT); +} + +template<> +Node RewriteRule::apply(Node node) { + NodeBuilder<> result(kind::BITVECTOR_CONCAT); + vector processing_stack; + processing_stack.push_back(node); + while (!processing_stack.empty()) { + Node current = processing_stack.back(); + processing_stack.pop_back(); + if (current.getKind() == kind::BITVECTOR_CONCAT) { + for (int i = current.getNumChildren() - 1; i >= 0; i--) + processing_stack.push_back(current[i]); + } else { + result << current; + } + } + Node resultNode = result; + return resultNode; +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::BITVECTOR_CONCAT); +} + +template<> +Node RewriteRule::apply(Node node) { + vector mergedExtracts; + + Node current = node[0]; + bool mergeStarted = false; + unsigned currentHigh = 0; + unsigned currentLow = 0; + + for(size_t i = 1, end = node.getNumChildren(); i < end; ++ i) { + // The next candidate for merging + Node next = node[i]; + // If the current is not an extract we just go to the next + if (current.getKind() != kind::BITVECTOR_EXTRACT) { + mergedExtracts.push_back(current); + current = next; + continue; + } + // If it is an extract and the first one, get the extract parameters + else if (!mergeStarted) { + currentHigh = utils::getExtractHigh(current); + currentLow = utils::getExtractLow(current); + } + + // If the next one can be merged, try to merge + bool merged = false; + if (next.getKind() == kind::BITVECTOR_EXTRACT && current[0] == next[0]) { + //x[i : j] @ x[j − 1 : k] -> c x[i : k] + unsigned nextHigh = utils::getExtractHigh(next); + unsigned nextLow = utils::getExtractLow(next); + if(nextHigh + 1 == currentLow) { + currentLow = nextLow; + mergeStarted = true; + merged = true; + } + } + // If we haven't merged anything, add the previous merge and continue with the next + if (!merged) { + if (!mergeStarted) mergedExtracts.push_back(current); + else mergedExtracts.push_back(utils::mkExtract(current[0], currentHigh, currentLow)); + current = next; + mergeStarted = false; + } + } + + // Add the last child + if (!mergeStarted) mergedExtracts.push_back(current); + else mergedExtracts.push_back(utils::mkExtract(current[0], currentHigh, currentLow)); + + // Return the result + return utils::mkConcat(mergedExtracts); +} + +template<> +bool RewriteRule::applies(Node node) { + return node.getKind() == kind::BITVECTOR_CONCAT; +} + +template<> +Node RewriteRule::apply(Node node) { + vector mergedConstants; + for (unsigned i = 0, end = node.getNumChildren(); i < end;) { + if (node[i].getKind() != kind::CONST_BITVECTOR) { + // If not a constant, just add it + mergedConstants.push_back(node[i]); + ++i; + } else { + // Find the longest sequence of constants + unsigned j = i + 1; + while (j < end) { + if (node[j].getKind() != kind::CONST_BITVECTOR) { + break; + } else { + ++ j; + } + } + // Append all the constants + BitVector current = node[i].getConst(); + for (unsigned k = i + 1; k < j; ++ k) { + current = current.concat(node[k].getConst()); + } + // Add the new merged constant + mergedConstants.push_back(utils::mkConst(current)); + i = j + 1; + } + } + + return utils::mkConcat(mergedConstants); +} + +template<> +bool RewriteRule::applies(Node node) { + if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; + unsigned length = utils::getSize(node[0]); + unsigned extractHigh = utils::getExtractHigh(node); + if (extractHigh != length - 1) return false; + unsigned extractLow = utils::getExtractLow(node); + if (extractLow != 0) return false; + return true; +} + +template<> +Node RewriteRule::apply(Node node) { + return node[0]; +} + +template<> +bool RewriteRule::applies(Node node) { + if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; + if (node[0].getKind() != kind::CONST_BITVECTOR) return false; + return true; +} + +template<> +Node RewriteRule::apply(Node node) { + Node child = node[0]; + BitVector childValue = child.getConst(); + return utils::mkConst(childValue.extract(utils::getExtractHigh(node), utils::getExtractLow(node))); +} + +template<> +bool RewriteRule::applies(Node node) { + if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; + if (node[0].getKind() != kind::BITVECTOR_CONCAT) return false; + return true; +} + +template<> +Node RewriteRule::apply(Node node) { + int extract_high = utils::getExtractHigh(node); + int extract_low = utils::getExtractLow(node); + + vector resultChildren; + + Node concat = node[0]; + for (int i = concat.getNumChildren() - 1; i >= 0 && extract_high >= 0; i--) { + Node concatChild = concat[i]; + int concatChildSize = utils::getSize(concatChild); + if (extract_low < concatChildSize) { + int extract_start = extract_low < 0 ? 0 : extract_low; + int extract_end = extract_high < concatChildSize ? extract_high : concatChildSize - 1; + resultChildren.push_back(utils::mkExtract(concatChild, extract_end, extract_start)); + } + extract_low -= concatChildSize; + extract_high -= concatChildSize; + } + + std::reverse(resultChildren.begin(), resultChildren.end()); + + return utils::mkConcat(resultChildren); +} + +template<> +bool RewriteRule::applies(Node node) { + if (node.getKind() != kind::BITVECTOR_EXTRACT) return false; + if (node[0].getKind() != kind::BITVECTOR_EXTRACT) return false; + return true; +} + +template<> +Node RewriteRule::apply(Node node) { + // x[i:j][k:l] ~> x[k+j:l+j] + Node child = node[0]; + unsigned k = utils::getExtractHigh(node); + unsigned l = utils::getExtractLow(node); + unsigned j = utils::getExtractLow(child); + + Node result = utils::mkExtract(child[0], k + j, l + j); + return result; +} + +template<> +bool RewriteRule::applies(Node node) { + if (node.getKind() != kind::EQUAL) return false; + if (node[0].getKind() != kind::CONST_BITVECTOR) return false; + if (node[1].getKind() != kind::CONST_BITVECTOR) return false; + return node[0] != node[1]; +} + +template<> +Node RewriteRule::apply(Node node) { + return utils::mkFalse(); +} + +template<> +bool RewriteRule::applies(Node node) { + if (node.getKind() != kind::EQUAL) return false; + return node[0] == node[1]; +} + +template<> +Node RewriteRule::apply(Node node) { + return utils::mkTrue(); +} + +template<> +bool RewriteRule::applies(Node node) { + return (node.getKind() == kind::EQUAL && node[0] < node[1]); +} + +template<> +Node RewriteRule::apply(Node node) { + return node[1].eqNode(node[0]);; +} + +} +} +} diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index cd2efd64f..08245afcb 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -5,8 +5,10 @@ * Author: dejan */ +#include "theory/theory.h" #include "theory/bv/theory_bv_rewriter.h" #include "theory/bv/theory_bv_rewrite_rules.h" +#include "theory/bv/theory_bv_rewrite_rules_core.h" using namespace CVC4; using namespace CVC4::theory; @@ -18,41 +20,43 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { Node result; - if (node.getKind() == kind::CONST_BITVECTOR /* || isLeaf(n)) */) + if (node.getKind() == kind::CONST_BITVECTOR || (node.getKind() != kind::EQUAL && Theory::isLeafOf(node, THEORY_BV))) { result = node; - else { + } else { switch (node.getKind()) { case kind::BITVECTOR_CONCAT: result = LinearRewriteStrategy< // Flatten the top level concatenations - CoreRewriteRules::ConcatFlatten, + RewriteRule, // Merge the adjacent extracts on non-constants - CoreRewriteRules::ConcatExtractMerge, + RewriteRule, // Merge the adjacent extracts on constants - CoreRewriteRules::ConcatConstantMerge, + RewriteRule, // At this point only Extract-Whole could apply, if the result is only one extract // or at some sub-expression if the result is a concatenation. - ApplyRuleToChildren + ApplyRuleToChildren >::apply(node); break; case kind::BITVECTOR_EXTRACT: result = LinearRewriteStrategy< // Extract over a constant gives a constant - CoreRewriteRules::ExtractConstant, + RewriteRule, // Extract over an extract is simplified to one extract - CoreRewriteRules::ExtractExtract, + RewriteRule, // Extract over a concatenation is distributed to the appropriate concatenations - CoreRewriteRules::ExtractConcat, + RewriteRule, // At this point only Extract-Whole could apply - CoreRewriteRules::ExtractWhole + RewriteRule >::apply(node); break; case kind::EQUAL: result = LinearRewriteStrategy< // Two distinct values rewrite to false - CoreRewriteRules::FailEq, + RewriteRule, // If both sides are equal equality is true - CoreRewriteRules::SimplifyEq + RewriteRule, + // Normalize the equalities + RewriteRule >::apply(node); break; default: @@ -68,3 +72,12 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { return RewriteResponse(REWRITE_DONE, result); } +AllRewriteRules* TheoryBVRewriter::s_allRules = NULL; + +void TheoryBVRewriter::init() { + s_allRules = new AllRewriteRules; +} + +void TheoryBVRewriter::shutdown() { + delete s_allRules; +} diff --git a/src/theory/bv/theory_bv_rewriter.h b/src/theory/bv/theory_bv_rewriter.h index 741b9fcbc..437ac49d3 100644 --- a/src/theory/bv/theory_bv_rewriter.h +++ b/src/theory/bv/theory_bv_rewriter.h @@ -7,16 +7,18 @@ #pragma once - - #include "theory/rewriter.h" namespace CVC4 { namespace theory { namespace bv { +class AllRewriteRules; + class TheoryBVRewriter { + static AllRewriteRules* s_allRules; + public: static RewriteResponse postRewrite(TNode node); @@ -25,9 +27,8 @@ public: return RewriteResponse(REWRITE_DONE, node); } - static inline void init() {} - static inline void shutdown() {} - + static void init(); + static void shutdown(); }; }