updates for the rewriter, added some statistics
authorDejan Jovanović <dejan.jovanovic@gmail.com>
Wed, 16 Feb 2011 01:26:26 +0000 (01:26 +0000)
committerDejan Jovanović <dejan.jovanovic@gmail.com>
Wed, 16 Feb 2011 01:26:26 +0000 (01:26 +0000)
src/theory/bv/Makefile.am
src/theory/bv/theory_bv_rewrite_rules.cpp [deleted file]
src/theory/bv/theory_bv_rewrite_rules.h
src/theory/bv/theory_bv_rewrite_rules_core.h [new file with mode: 0644]
src/theory/bv/theory_bv_rewriter.cpp
src/theory/bv/theory_bv_rewriter.h

index 3e84f482ce780982ba1f955ef52238d5bfcf84dc..fdace42b48b7e96cb94c96559bd130f2971f50d6 100644 (file)
@@ -10,7 +10,7 @@ libbv_la_SOURCES = \
        theory_bv.cpp \
        theory_bv_utils.h \
        theory_bv_rewrite_rules.h \
-       theory_bv_rewrite_rules.cpp \
+       theory_bv_rewrite_rules_core.h \
        theory_bv_type_rules.h \
        theory_bv_rewriter.h \
        theory_bv_rewriter.cpp \
diff --git a/src/theory/bv/theory_bv_rewrite_rules.cpp b/src/theory/bv/theory_bv_rewrite_rules.cpp
deleted file mode 100644 (file)
index d2fb621..0000000
+++ /dev/null
@@ -1,305 +0,0 @@
-/*********************                                                        */
-/*! \file theory_bv_rewrite_rules.cpp
- ** \verbatim
- ** Original author: dejan
- ** Major contributors: none
- ** Minor contributors (to current version): none
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010  The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
- **
- ** \brief [[ Add one-line brief description here ]]
- **
- ** [[ Add lengthier description here ]]
- ** \todo document this file
- **/
-
-#include <vector>
-#include "expr/node_builder.h"
-#include "theory_bv_rewrite_rules.h"
-#include "theory_bv_utils.h"
-
-using namespace std;
-using namespace CVC4;
-using namespace CVC4::theory;
-using namespace CVC4::theory::bv;
-using namespace CVC4::theory::bv::utils;
-
-bool CoreRewriteRules::ConcatFlatten::applies(Node node) {
-  return (node.getKind() == kind::BITVECTOR_CONCAT);
-}
-
-Node CoreRewriteRules::ConcatFlatten::apply(Node node) {
-  Assert(applies(node));
-
-  Debug("bitvector") << "ConcatFlatten(" << node << ")" << endl;
-
-  NodeBuilder<> result(kind::BITVECTOR_CONCAT);
-  vector<Node> processing_stack;
-  processing_stack.push_back(node);
-  while (!processing_stack.empty()) {
-    Node current = processing_stack.back();
-    processing_stack.pop_back();
-    if (current.getKind() == kind::BITVECTOR_CONCAT) {
-      for (int i = current.getNumChildren() - 1; i >= 0; i--)
-        processing_stack.push_back(current[i]);
-    } else {
-      result << current;
-    }
-  }
-
-  Node resultNode = result;
-  Debug("bitvector") << "ConcatFlatten(" << node << ") => " << resultNode << endl;
-
-  return resultNode;
-}
-
-bool CoreRewriteRules::ConcatExtractMerge::applies(Node node) {
-  return (node.getKind() == kind::BITVECTOR_CONCAT);
-}
-
-Node CoreRewriteRules::ConcatExtractMerge::apply(Node node) {
-  Assert(applies(node));
-
-  Debug("bitvector") << "ConcatExtractMerge(" << node << ")" << endl;
-
-  vector<Node> mergedExtracts;
-
-  Node current = node[0];
-  bool mergeStarted = false;
-  unsigned currentHigh = 0;
-  unsigned currentLow  = 0;
-
-  for(size_t i = 1, end = node.getNumChildren(); i < end; ++ i) {
-    // The next candidate for merging
-    Node next = node[i];
-    // If the current is not an extract we just go to the next
-    if (current.getKind() != kind::BITVECTOR_EXTRACT) {
-      mergedExtracts.push_back(current);
-      current = next;
-      continue;
-    }
-    // If it is an extract and the first one, get the extract parameters
-    else if (!mergeStarted) {
-      currentHigh = getExtractHigh(current);
-      currentLow = getExtractLow(current);
-    }
-
-    // If the next one can be merged, try to merge
-    bool merged = false;
-    if (next.getKind() == kind::BITVECTOR_EXTRACT && current[0] == next[0]) {
-      //x[i : j] @ x[j − 1 : k] -> c x[i : k]
-      unsigned nextHigh = getExtractHigh(next);
-      unsigned nextLow  = getExtractLow(next);
-      if(nextHigh + 1 == currentLow) {
-        currentLow = nextLow;
-        mergeStarted = true;
-        merged = true;
-      }
-    }
-    // If we haven't merged anything, add the previous merge and continue with the next
-    if (!merged) {
-      if (!mergeStarted) mergedExtracts.push_back(current);
-      else mergedExtracts.push_back(mkExtract(current[0], currentHigh, currentLow));
-      current = next;
-      mergeStarted = false;
-    }
-  }
-
-  // Add the last child
-  if (!mergeStarted) mergedExtracts.push_back(current);
-  else mergedExtracts.push_back(mkExtract(current[0], currentHigh, currentLow));
-
-  // Create the result
-  Node result = mkConcat(mergedExtracts);
-
-  Debug("bitvector") << "ConcatExtractMerge(" << node << ") =>" << result << endl;
-
-  // Return the result
-  return result;
-}
-
-bool CoreRewriteRules::ConcatConstantMerge::applies(Node node) {
-  return node.getKind() == kind::BITVECTOR_CONCAT;
-}
-
-Node CoreRewriteRules::ConcatConstantMerge::apply(Node node) {
-  Assert(applies(node));
-
-  Debug("bitvector") << "ConcatConstantMerge(" << node << ")" << endl;
-
-  vector<Node> mergedConstants;
-  for (unsigned i = 0, end = node.getNumChildren(); i < end;) {
-    if (node[i].getKind() != kind::CONST_BITVECTOR) {
-      // If not a constant, just add it
-      mergedConstants.push_back(node[i]);
-      ++i;
-    } else {
-      // Find the longest sequence of constants
-      unsigned j = i + 1;
-      while (j < end) {
-        if (node[j].getKind() != kind::CONST_BITVECTOR) {
-          break;
-        } else {
-          ++ j;
-        }
-      }
-      // Append all the constants
-      BitVector current = node[i].getConst<BitVector>();
-      for (unsigned k = i + 1; k < j; ++ k) {
-        current = current.concat(node[k].getConst<BitVector>());
-      }
-      // Add the new merged constant
-      mergedConstants.push_back(mkConst(current));
-      i = j + 1;
-    }
-  }
-
-  Node result = mkConcat(mergedConstants);
-
-  Debug("bitvector") << "ConcatConstantMerge(" << node << ") => " << result << endl;
-
-  return result;
-}
-
-bool CoreRewriteRules::ExtractWhole::applies(Node node) {
-  if (node.getKind() != kind::BITVECTOR_EXTRACT) return false;
-  unsigned length = getSize(node[0]);
-  unsigned extractHigh = getExtractHigh(node);
-  if (extractHigh != length - 1) return false;
-  unsigned extractLow  = getExtractLow(node);
-  if (extractLow != 0) return false;
-  return true;
-}
-
-Node CoreRewriteRules::ExtractWhole::apply(Node node) {
-  Assert(applies(node));
-
-  Debug("bitvector") << "ExtractWhole(" << node << ")" << endl;
-  Debug("bitvector") << "ExtractWhole(" << node << ") => " << node[0] << endl;
-
-  return node[0];
-}
-
-bool CoreRewriteRules::ExtractConstant::applies(Node node) {
-  if (node.getKind() != kind::BITVECTOR_EXTRACT) return false;
-  if (node[0].getKind() != kind::CONST_BITVECTOR) return false;
-  return true;
-}
-
-Node CoreRewriteRules::ExtractConstant::apply(Node node) {
-  Assert(applies(node));
-
-  Debug("bitvector") << "ExtractConstant(" << node << ")" << endl;
-
-  Node child = node[0];
-  BitVector childValue = child.getConst<BitVector>();
-
-  Node result = mkConst(childValue.extract(getExtractHigh(node), getExtractLow(node)));
-
-  Debug("bitvector") << "ExtractConstant(" << node << ") => " << result << endl;
-
-  return result;
-}
-
-bool CoreRewriteRules::ExtractConcat::applies(Node node) {
-  if (node.getKind() != kind::BITVECTOR_EXTRACT) return false;
-  if (node[0].getKind() != kind::BITVECTOR_CONCAT) return false;
-  return true;
-}
-
-Node CoreRewriteRules::ExtractConcat::apply(Node node) {
-  Assert(applies(node));
-
-  Debug("bitvector") << "ExtractConcat(" << node << ")" << endl;
-
-  int extract_high = getExtractHigh(node);
-  int extract_low = getExtractLow(node);
-
-  vector<Node> resultChildren;
-
-  Node concat = node[0];
-  for (int i = concat.getNumChildren() - 1; i >= 0 && extract_high >= 0; i--) {
-    Node concatChild = concat[i];
-    int concatChildSize = getSize(concatChild);
-    if (extract_low < concatChildSize) {
-      int extract_start = extract_low < 0 ? 0 : extract_low;
-      int extract_end = extract_high < concatChildSize ? extract_high : concatChildSize - 1;
-      resultChildren.push_back(mkExtract(concatChild, extract_end, extract_start));
-    }
-    extract_low -= concatChildSize;
-    extract_high -= concatChildSize;
-  }
-
-  std::reverse(resultChildren.begin(), resultChildren.end());
-
-  Node result = mkConcat(resultChildren);
-
-  Debug("bitvector") << "ExtractConcat(" << node << ") => " << result << endl;
-
-  return result;
-}
-
-bool CoreRewriteRules::ExtractExtract::applies(Node node) {
-  if (node.getKind() != kind::BITVECTOR_EXTRACT) return false;
-  if (node[0].getKind() != kind::BITVECTOR_EXTRACT) return false;
-  return true;
-}
-
-Node CoreRewriteRules::ExtractExtract::apply(Node node) {
-  Assert(applies(node));
-
-  Debug("bitvector") << "ExtractExtract(" << node << ")" << endl;
-
-  // x[i:j][k:l] ~>  x[k+j:l+j]
-  Node child = node[0];
-  unsigned k = getExtractHigh(node);
-  unsigned l = getExtractLow(node);
-  unsigned j = getExtractLow(child);
-
-  Node result = mkExtract(child[0], k + j, l + j);
-
-  Debug("bitvector") << "ExtractExtract(" << node << ") => " << result << endl;
-
-  return result;
-}
-
-bool CoreRewriteRules::FailEq::applies(Node node) {
-  if (node.getKind() != kind::EQUAL) return false;
-  if (node[0].getKind() != kind::CONST_BITVECTOR) return false;
-  if (node[1].getKind() != kind::CONST_BITVECTOR) return false;
-  return node[0] != node[1];
-}
-
-Node CoreRewriteRules::FailEq::apply(Node node) {
-  Assert(applies(node));
-
-  Debug("bitvector") << "FailEq(" << node << ")" << endl;
-
-  Node result = mkFalse();
-
-  Debug("bitvector") << "FailEq(" << node << ") => " << result << endl;
-
-  return result;
-}
-
-bool CoreRewriteRules::SimplifyEq::applies(Node node) {
-  if (node.getKind() != kind::EQUAL) return false;
-  return node[0] == node[1];
-}
-
-Node CoreRewriteRules::SimplifyEq::apply(Node node) {
-  Assert(applies(node));
-
-  Debug("bitvector") << "FailEq(" << node << ")" << endl;
-
-  Node result = mkTrue();
-
-  Debug("bitvector") << "FailEq(" << node << ") => " << result << endl;
-
-  return result;
-}
-
index eba8f917c8af004949bc6a3e7d539a2b6e8d512d..32d0f92a03d2a01838798fec7e3208f2b4a8f76d 100644 (file)
 #include "cvc4_private.h"
 #include "theory/theory.h"
 #include "context/context.h"
+#include "util/stats.h"
+#include <sstream>
 
 namespace CVC4 {
 namespace theory {
 namespace bv {
 
-struct CoreRewriteRules {
+enum RewriteRuleId {
+  EmptyRule,
+  ConcatFlatten,
+  ConcatExtractMerge,
+  ConcatConstantMerge,
+  ExtractExtract,
+  ExtractWhole,
+  ExtractConcat,
+  ExtractConstant,
+  FailEq,
+  SimplifyEq,
+  ReflexivityEq,
+};
 
-  struct EmptyRule {
-    static inline Node apply(Node node) { return node; }
-    static inline bool applies(Node node) { return false; }
-  };
+inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) {
+  switch (ruleId) {
+  case EmptyRule:           out << "EmptyRule";           return out;
+  case ConcatFlatten:       out << "ConcatFlatten";       return out;
+  case ConcatExtractMerge:  out << "ConcatExtractMerge";  return out;
+  case ConcatConstantMerge: out << "ConcatConstantMerge"; return out;
+  case ExtractExtract:      out << "ExtractExtract";      return out;
+  case ExtractWhole:        out << "ExtractWhole";        return out;
+  case ExtractConcat:       out << "ExtractConcat";       return out;
+  case ExtractConstant:     out << "ExtractConstant";     return out;
+  case FailEq:              out << "FailEq";              return out;
+  case SimplifyEq:          out << "SimplifyEq";          return out;
+  case ReflexivityEq:       out << "ReflexivityEq";       return out;
+  default:
+    Unreachable();
+  }
+};
 
-  struct ConcatFlatten {
-    static Node apply(Node node);
-    static bool applies(Node node);
-  };
+template <RewriteRuleId rule>
+class RewriteRule {
 
-  struct ConcatExtractMerge {
-    static Node apply(Node node);
-    static bool applies(Node node);
-  };
+  class RuleStatistics {
 
-  struct ConcatConstantMerge {
-    static Node apply(Node node);
-    static bool applies(Node node);
-  };
+    /** The name of the rule prefixed with the prefix */
+    static std::string getStatName(const char* prefix) {
+      std::stringstream statName;
+      statName << prefix << rule;
+      return statName.str();
+    }
 
-  struct ExtractExtract {
-    static Node apply(Node node);
-    static bool applies(Node node);
-  };
+  public:
 
-  struct ExtractWhole {
-    static Node apply(Node node);
-    static bool applies(Node node);
-  };
+    /** Number of applications of this rule */
+    IntStat d_ruleApplications;
 
-  struct ExtractConcat {
-    static Node apply(Node node);
-    static bool applies(Node node);
-  };
+    /** Constructor */
+    RuleStatistics()
+    : d_ruleApplications(getStatName("theory::bv::count"), 0) {
+      StatisticsRegistry::registerStat(&d_ruleApplications);
+    }
 
-  struct ExtractConstant {
-    static Node apply(Node node);
-    static bool applies(Node node);
+    /** Destructor */
+    ~RuleStatistics() {
+      StatisticsRegistry::unregisterStat(&d_ruleApplications);
+    }
   };
 
-  struct FailEq {
-    static Node apply(Node node);
-    static bool applies(Node node);
-  };
+  /** Statistics about the rule */
+  static RuleStatistics* s_statictics;
 
-  struct SimplifyEq {
-    static Node apply(Node node);
-    static bool applies(Node node);
-  };
+  /** Actually apply the rewrite rule */
+  static inline Node apply(Node node) {
+    Unreachable();
+  }
+
+public:
 
+  RewriteRule() {
+    if (s_statictics == NULL) {
+      s_statictics = new RuleStatistics();
+    }
+  }
+
+  ~RewriteRule() {
+    delete s_statictics;
+    s_statictics = NULL;
+  }
+
+  static inline bool applies(Node node) {
+    Unreachable();
+  }
+
+  template<bool checkApplies>
+  static inline Node run(Node node) {
+    if (!checkApplies || applies(node)) {
+      Debug("theory::bv::rewrite") << "RewriteRule<" << rule << ">(" << node << ")" << std::endl;
+      Assert(checkApplies || applies(node));
+      ++ s_statictics->d_ruleApplications;
+      Node result = apply(node);
+      Debug("theory::bv::rewrite") << "RewriteRule<" << rule << ">(" << node << ") => " << result << std::endl;
+      return result;
+    } else {
+      return node;
+    }
+  }
+};
+
+template<RewriteRuleId rule>
+typename RewriteRule<rule>::RuleStatistics* RewriteRule<rule>::s_statictics = NULL;
+
+/** Have to list all the rewrite rules to get the statistics out */
+struct AllRewriteRules {
+  RewriteRule<EmptyRule>            rule00;
+  RewriteRule<ConcatFlatten>        rule01;
+  RewriteRule<ConcatExtractMerge>   rule02;
+  RewriteRule<ConcatConstantMerge>  rule03;
+  RewriteRule<ExtractExtract>       rule04;
+  RewriteRule<ExtractWhole>         rule05;
+  RewriteRule<ExtractConcat>        rule06;
+  RewriteRule<ExtractConstant>      rule07;
+  RewriteRule<FailEq>               rule08;
+  RewriteRule<SimplifyEq>           rule09;
+  RewriteRule<ReflexivityEq>        rule10;
 };
 
-template<Kind kind, typename Rule>
+template<>
+bool RewriteRule<EmptyRule>::applies(Node node) {
+  return false;
+}
+
+template<>
+Node RewriteRule<EmptyRule>::apply(Node node) {
+  Unreachable();
+  return node;
+}
+
+template<Kind kind, RewriteRuleId rule>
 struct ApplyRuleToChildren {
 
   static Node apply(Node node) {
     if (node.getKind() != kind) {
-      if (Rule::applies(node)) return Rule::apply(node);
-      else return node;
+      return RewriteRule<rule>::template run<true>(node);
     }
     NodeBuilder<> result(kind);
     for (unsigned i = 0, end = node.getNumChildren(); i < end; ++ i) {
-      if (Rule::applies(node[i])) result << Rule::apply(node[i]);
-      else result << node[i];
+      result << RewriteRule<rule>::template run<true>(node[i]);
     }
     return result;
   }
 
   static bool applies(Node node) {
     if (node.getKind() == kind) return true;
-    return Rule::applies(node);
+    return RewriteRule<rule>::applies(node);
   }
 
+  template <bool checkApplies>
+  static Node run(Node node) {
+    if (!checkApplies || applies(node)) {
+      return apply(node);
+    } else {
+      return node;
+    }
+  }
 };
 
-
 template <
   typename R1,
   typename R2,
-  typename R3 = CoreRewriteRules::EmptyRule,
-  typename R4 = CoreRewriteRules::EmptyRule,
-  typename R5 = CoreRewriteRules::EmptyRule,
-  typename R6 = CoreRewriteRules::EmptyRule,
-  typename R7 = CoreRewriteRules::EmptyRule
+  typename R3 = RewriteRule<EmptyRule>,
+  typename R4 = RewriteRule<EmptyRule>,
+  typename R5 = RewriteRule<EmptyRule>,
+  typename R6 = RewriteRule<EmptyRule>,
+  typename R7 = RewriteRule<EmptyRule>,
+  typename R8 = RewriteRule<EmptyRule>
   >
 struct LinearRewriteStrategy {
   static Node apply(Node node) {
     Node current = node;
-    if (R1::applies(current)) current = R1::apply(current);
-    if (R2::applies(current)) current = R2::apply(current);
-    if (R3::applies(current)) current = R3::apply(current);
-    if (R4::applies(current)) current = R4::apply(current);
-    if (R5::applies(current)) current = R5::apply(current);
-    if (R6::applies(current)) current = R6::apply(current);
-    if (R7::applies(current)) current = R7::apply(current);
+    if (R1::applies(current)) current = R1::template run<false>(current);
+    if (R2::applies(current)) current = R2::template run<false>(current);
+    if (R3::applies(current)) current = R3::template run<false>(current);
+    if (R4::applies(current)) current = R4::template run<false>(current);
+    if (R5::applies(current)) current = R5::template run<false>(current);
+    if (R6::applies(current)) current = R6::template run<false>(current);
+    if (R7::applies(current)) current = R7::template run<false>(current);
+    if (R8::applies(current)) current = R8::template run<false>(current);
     return current;
   }
 };
diff --git a/src/theory/bv/theory_bv_rewrite_rules_core.h b/src/theory/bv/theory_bv_rewrite_rules_core.h
new file mode 100644 (file)
index 0000000..e75f537
--- /dev/null
@@ -0,0 +1,265 @@
+/*********************                                                        */
+/*! \file theory_bv_rewrite_rules_core.h
+ ** \verbatim
+ ** Original author: dejan
+ ** Major contributors: none
+ ** Minor contributors (to current version): none
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009, 2010  The Analysis of Computer Systems Group (ACSys)
+ ** Courant Institute of Mathematical Sciences
+ ** New York University
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief [[ Add one-line brief description here ]]
+ **
+ ** [[ Add lengthier description here ]]
+ ** \todo document this file
+ **/
+
+#pragma once
+
+#include "theory/bv/theory_bv_rewrite_rules.h"
+#include "theory/bv/theory_bv_utils.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bv {
+
+template<>
+bool RewriteRule<ConcatFlatten>::applies(Node node) {
+  return (node.getKind() == kind::BITVECTOR_CONCAT);
+}
+
+template<>
+Node RewriteRule<ConcatFlatten>::apply(Node node) {
+  NodeBuilder<> result(kind::BITVECTOR_CONCAT);
+  vector<Node> processing_stack;
+  processing_stack.push_back(node);
+  while (!processing_stack.empty()) {
+    Node current = processing_stack.back();
+    processing_stack.pop_back();
+    if (current.getKind() == kind::BITVECTOR_CONCAT) {
+      for (int i = current.getNumChildren() - 1; i >= 0; i--)
+        processing_stack.push_back(current[i]);
+    } else {
+      result << current;
+    }
+  }
+  Node resultNode = result;
+  return resultNode;
+}
+
+template<>
+bool RewriteRule<ConcatExtractMerge>::applies(Node node) {
+  return (node.getKind() == kind::BITVECTOR_CONCAT);
+}
+
+template<>
+Node RewriteRule<ConcatExtractMerge>::apply(Node node) {
+  vector<Node> mergedExtracts;
+
+  Node current = node[0];
+  bool mergeStarted = false;
+  unsigned currentHigh = 0;
+  unsigned currentLow  = 0;
+
+  for(size_t i = 1, end = node.getNumChildren(); i < end; ++ i) {
+    // The next candidate for merging
+    Node next = node[i];
+    // If the current is not an extract we just go to the next
+    if (current.getKind() != kind::BITVECTOR_EXTRACT) {
+      mergedExtracts.push_back(current);
+      current = next;
+      continue;
+    }
+    // If it is an extract and the first one, get the extract parameters
+    else if (!mergeStarted) {
+      currentHigh = utils::getExtractHigh(current);
+      currentLow = utils::getExtractLow(current);
+    }
+
+    // If the next one can be merged, try to merge
+    bool merged = false;
+    if (next.getKind() == kind::BITVECTOR_EXTRACT && current[0] == next[0]) {
+      //x[i : j] @ x[j − 1 : k] -> c x[i : k]
+      unsigned nextHigh = utils::getExtractHigh(next);
+      unsigned nextLow  = utils::getExtractLow(next);
+      if(nextHigh + 1 == currentLow) {
+        currentLow = nextLow;
+        mergeStarted = true;
+        merged = true;
+      }
+    }
+    // If we haven't merged anything, add the previous merge and continue with the next
+    if (!merged) {
+      if (!mergeStarted) mergedExtracts.push_back(current);
+      else mergedExtracts.push_back(utils::mkExtract(current[0], currentHigh, currentLow));
+      current = next;
+      mergeStarted = false;
+    }
+  }
+
+  // Add the last child
+  if (!mergeStarted) mergedExtracts.push_back(current);
+  else mergedExtracts.push_back(utils::mkExtract(current[0], currentHigh, currentLow));
+
+  // Return the result
+  return utils::mkConcat(mergedExtracts);
+}
+
+template<>
+bool RewriteRule<ConcatConstantMerge>::applies(Node node) {
+  return node.getKind() == kind::BITVECTOR_CONCAT;
+}
+
+template<>
+Node RewriteRule<ConcatConstantMerge>::apply(Node node) {
+  vector<Node> mergedConstants;
+  for (unsigned i = 0, end = node.getNumChildren(); i < end;) {
+    if (node[i].getKind() != kind::CONST_BITVECTOR) {
+      // If not a constant, just add it
+      mergedConstants.push_back(node[i]);
+      ++i;
+    } else {
+      // Find the longest sequence of constants
+      unsigned j = i + 1;
+      while (j < end) {
+        if (node[j].getKind() != kind::CONST_BITVECTOR) {
+          break;
+        } else {
+          ++ j;
+        }
+      }
+      // Append all the constants
+      BitVector current = node[i].getConst<BitVector>();
+      for (unsigned k = i + 1; k < j; ++ k) {
+        current = current.concat(node[k].getConst<BitVector>());
+      }
+      // Add the new merged constant
+      mergedConstants.push_back(utils::mkConst(current));
+      i = j + 1;
+    }
+  }
+
+  return utils::mkConcat(mergedConstants);
+}
+
+template<>
+bool RewriteRule<ExtractWhole>::applies(Node node) {
+  if (node.getKind() != kind::BITVECTOR_EXTRACT) return false;
+  unsigned length = utils::getSize(node[0]);
+  unsigned extractHigh = utils::getExtractHigh(node);
+  if (extractHigh != length - 1) return false;
+  unsigned extractLow  = utils::getExtractLow(node);
+  if (extractLow != 0) return false;
+  return true;
+}
+
+template<>
+Node RewriteRule<ExtractWhole>::apply(Node node) {
+  return node[0];
+}
+
+template<>
+bool RewriteRule<ExtractConstant>::applies(Node node) {
+  if (node.getKind() != kind::BITVECTOR_EXTRACT) return false;
+  if (node[0].getKind() != kind::CONST_BITVECTOR) return false;
+  return true;
+}
+
+template<>
+Node RewriteRule<ExtractConstant>::apply(Node node) {
+  Node child = node[0];
+  BitVector childValue = child.getConst<BitVector>();
+  return utils::mkConst(childValue.extract(utils::getExtractHigh(node), utils::getExtractLow(node)));
+}
+
+template<>
+bool RewriteRule<ExtractConcat>::applies(Node node) {
+  if (node.getKind() != kind::BITVECTOR_EXTRACT) return false;
+  if (node[0].getKind() != kind::BITVECTOR_CONCAT) return false;
+  return true;
+}
+
+template<>
+Node RewriteRule<ExtractConcat>::apply(Node node) {
+  int extract_high = utils::getExtractHigh(node);
+  int extract_low = utils::getExtractLow(node);
+
+  vector<Node> resultChildren;
+
+  Node concat = node[0];
+  for (int i = concat.getNumChildren() - 1; i >= 0 && extract_high >= 0; i--) {
+    Node concatChild = concat[i];
+    int concatChildSize = utils::getSize(concatChild);
+    if (extract_low < concatChildSize) {
+      int extract_start = extract_low < 0 ? 0 : extract_low;
+      int extract_end = extract_high < concatChildSize ? extract_high : concatChildSize - 1;
+      resultChildren.push_back(utils::mkExtract(concatChild, extract_end, extract_start));
+    }
+    extract_low -= concatChildSize;
+    extract_high -= concatChildSize;
+  }
+
+  std::reverse(resultChildren.begin(), resultChildren.end());
+
+  return utils::mkConcat(resultChildren);
+}
+
+template<>
+bool RewriteRule<ExtractExtract>::applies(Node node) {
+  if (node.getKind() != kind::BITVECTOR_EXTRACT) return false;
+  if (node[0].getKind() != kind::BITVECTOR_EXTRACT) return false;
+  return true;
+}
+
+template<>
+Node RewriteRule<ExtractExtract>::apply(Node node) {
+  // x[i:j][k:l] ~>  x[k+j:l+j]
+  Node child = node[0];
+  unsigned k = utils::getExtractHigh(node);
+  unsigned l = utils::getExtractLow(node);
+  unsigned j = utils::getExtractLow(child);
+
+  Node result = utils::mkExtract(child[0], k + j, l + j);
+  return result;
+}
+
+template<>
+bool RewriteRule<FailEq>::applies(Node node) {
+  if (node.getKind() != kind::EQUAL) return false;
+  if (node[0].getKind() != kind::CONST_BITVECTOR) return false;
+  if (node[1].getKind() != kind::CONST_BITVECTOR) return false;
+  return node[0] != node[1];
+}
+
+template<>
+Node RewriteRule<FailEq>::apply(Node node) {
+    return utils::mkFalse();
+}
+
+template<>
+bool RewriteRule<SimplifyEq>::applies(Node node) {
+  if (node.getKind() != kind::EQUAL) return false;
+  return node[0] == node[1];
+}
+
+template<>
+Node RewriteRule<SimplifyEq>::apply(Node node) {
+  return utils::mkTrue();
+}
+
+template<>
+bool RewriteRule<ReflexivityEq>::applies(Node node) {
+  return (node.getKind() == kind::EQUAL && node[0] < node[1]);
+}
+
+template<>
+Node RewriteRule<ReflexivityEq>::apply(Node node) {
+  return node[1].eqNode(node[0]);;
+}
+
+}
+}
+}
index cd2efd64f10899253332b50017186f10b2b4788c..08245afcbd474bcac2e062186fc3c57860e86c29 100644 (file)
@@ -5,8 +5,10 @@
  *      Author: dejan
  */
 
+#include "theory/theory.h"
 #include "theory/bv/theory_bv_rewriter.h"
 #include "theory/bv/theory_bv_rewrite_rules.h"
+#include "theory/bv/theory_bv_rewrite_rules_core.h"
 
 using namespace CVC4;
 using namespace CVC4::theory;
@@ -18,41 +20,43 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) {
 
   Node result;
 
-  if (node.getKind() == kind::CONST_BITVECTOR /* || isLeaf(n)) */)
+  if (node.getKind() == kind::CONST_BITVECTOR || (node.getKind() != kind::EQUAL && Theory::isLeafOf(node, THEORY_BV))) {
     result = node;
-  else {
+  else {
     switch (node.getKind()) {
     case kind::BITVECTOR_CONCAT:
       result = LinearRewriteStrategy<
                   // Flatten the top level concatenations
-                  CoreRewriteRules::ConcatFlatten,
+                  RewriteRule<ConcatFlatten>,
                   // Merge the adjacent extracts on non-constants
-                  CoreRewriteRules::ConcatExtractMerge,
+                  RewriteRule<ConcatExtractMerge>,
                   // Merge the adjacent extracts on constants
-                  CoreRewriteRules::ConcatConstantMerge,
+                  RewriteRule<ConcatConstantMerge>,
                   // At this point only Extract-Whole could apply, if the result is only one extract
                   // or at some sub-expression if the result is a concatenation.
-                  ApplyRuleToChildren<kind::BITVECTOR_CONCAT, CoreRewriteRules::ExtractWhole>
+                  ApplyRuleToChildren<kind::BITVECTOR_CONCAT, ExtractWhole>
                >::apply(node);
       break;
     case kind::BITVECTOR_EXTRACT:
       result = LinearRewriteStrategy<
                   // Extract over a constant gives a constant
-                  CoreRewriteRules::ExtractConstant,
+                  RewriteRule<ExtractConstant>,
                   // Extract over an extract is simplified to one extract
-                  CoreRewriteRules::ExtractExtract,
+                  RewriteRule<ExtractExtract>,
                   // Extract over a concatenation is distributed to the appropriate concatenations
-                  CoreRewriteRules::ExtractConcat,
+                  RewriteRule<ExtractConcat>,
                   // At this point only Extract-Whole could apply
-                  CoreRewriteRules::ExtractWhole
+                  RewriteRule<ExtractWhole>
                 >::apply(node);
       break;
     case kind::EQUAL:
       result = LinearRewriteStrategy<
                   // Two distinct values rewrite to false
-                  CoreRewriteRules::FailEq,
+                  RewriteRule<FailEq>,
                   // If both sides are equal equality is true
-                  CoreRewriteRules::SimplifyEq
+                  RewriteRule<SimplifyEq>,
+                  // Normalize the equalities
+                  RewriteRule<ReflexivityEq>
                 >::apply(node);
       break;
     default:
@@ -68,3 +72,12 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) {
   return RewriteResponse(REWRITE_DONE, result);
 }
 
+AllRewriteRules* TheoryBVRewriter::s_allRules = NULL;
+
+void TheoryBVRewriter::init() {
+  s_allRules = new AllRewriteRules;
+}
+
+void TheoryBVRewriter::shutdown() {
+  delete s_allRules;
+}
index 741b9fcbc09c4f37b6a98dac7912b5d72fc5efe8..437ac49d30b4f49b83c31898447865e04fe91e2f 100644 (file)
@@ -7,16 +7,18 @@
 
 #pragma once
 
-
-
 #include "theory/rewriter.h"
 
 namespace CVC4 {
 namespace theory {
 namespace bv {
 
+class AllRewriteRules;
+
 class TheoryBVRewriter {
 
+  static AllRewriteRules* s_allRules;
+
 public:
 
   static RewriteResponse postRewrite(TNode node);
@@ -25,9 +27,8 @@ public:
     return RewriteResponse(REWRITE_DONE, node);
   }
 
-  static inline void init() {}
-  static inline void shutdown() {}
-
+  static void init();
+  static void shutdown();
 };
 
 }