commit for the version of bitvectors that passes all the unit tests
authorDejan Jovanović <dejan.jovanovic@gmail.com>
Sun, 20 Mar 2011 01:12:31 +0000 (01:12 +0000)
committerDejan Jovanović <dejan.jovanovic@gmail.com>
Sun, 20 Mar 2011 01:12:31 +0000 (01:12 +0000)
src/theory/bv/cd_set_collection.h
src/theory/bv/equality_engine.h
src/theory/bv/slice_manager.h
src/theory/bv/theory_bv.cpp
src/theory/bv/theory_bv.h
src/theory/bv/theory_bv_rewriter.cpp
src/theory/bv/theory_bv_utils.h
src/util/bitvector.h
test/regress/regress0/bv/core/Makefile.am
test/regress/regress0/bv/core/slice-20.cvc [new file with mode: 0644]

index 33648660be6aacad1b2426856ba5863a06b94b3c..aeb28ab7bcce6bb6a65df05fa6d1b2e46cade59b 100644 (file)
@@ -123,28 +123,30 @@ public:
     return newElement(value, null, null, null, false);
   }
 
-  void insert(memory_type& memory, reference_type root, const value_type& value) {
+  void insert(reference_type& root, const value_type& value) {
     backtrack();
     if (root == null) {
-      return newSet(value);
+      root = newSet(value);
+      return;
     }
     // We already have a set, find the spot
     reference_type parent = null;
+    reference_type current = root;
     while (true) {
-      parent = root;
-      if (value < d_memory[root].value) {
-        root = d_memory[root].left;
-        if (root == null) {
-          root = newElement(value, null, null, parent, true);
-          d_memory[parent].left = root;
+      parent = current;
+      if (value < d_memory[current].getValue()) {
+        if (d_memory[current].hasLeft()) {
+          current = d_memory[current].getLeft();
+        } else {
+          d_memory[current].setLeft(newElement(value, null, null, parent, true));
           return;
         }
       } else {
-        Assert(value != d_memory[root].value);
-        root = d_memory[root].right;
-        if (root == null) {
-          root = newElement(value, null, null, parent, false);
-          d_memory[parent].right = root;
+        Assert(value != d_memory[root].getValue());
+        if (d_memory[current].hasRight()) {
+          current = d_memory[current].getRight();
+        } else {
+          d_memory[parent].setRight(newElement(value, null, null, parent, false));
           return;
         }
       }
@@ -174,49 +176,55 @@ public:
    */
   const_value_reference prev(reference_type set, const_value_reference value) {
     backtrack();
-    // Get the node of this value
-    reference_type node_ref = find(set, value);
-    Assert(node_ref != null);
-    const tree_entry_type& node = d_memory[node_ref];
-    // For a left node, we know that it is smaller than all the parents and the parents other children
-    // The smaller node must then be the max of the left subtree
-    if (!node.hasParent() || node.isLeft()) {
-      return maxElement(node.getLeft());
-    }
-    // For a right node, we know that it is bigger than the parent. But, we also know that the left subtree
-    // is also bigger than the parent
-    else {
-      if (node.hasLeft()) {
-        return maxElement(node.getLeft());
+
+    const_value_reference candidate_value;
+    bool candidate_found = false;
+
+    // Find the biggest node smaleer than value (it must exist)
+    while (set != null) {
+      Debug("set_collection") << "BacktrackableSetCollection::getPrev(" << toString(set) << "," << value << ")" << std::endl;
+      const tree_entry_type& node = d_memory[set];
+      if (node.getValue() >= value) {
+        // If the node is bigger than the value, we need a smaller one
+        set = node.getLeft();
       } else {
-        Assert(node.hasParent());
-        return d_memory[node.getParent()].getValue();
+        // The node is smaller than the value
+        candidate_found = true;
+        candidate_value = node.getValue();
+        // There might be a bigger one
+        set = node.getRight();
       }
     }
+
+    Assert(candidate_found);
+    return candidate_value;
   }
 
   const_value_reference next(reference_type set, const_value_reference value) {
     backtrack();
-    // Get the node of this value
-    reference_type node_ref = find(set, value);
-    Assert(node_ref != null);
-    const tree_entry_type& node = d_memory[node_ref];
-    // For a right node, we know that it is bigger than all the parents and the parents other children
-    // The bigger node must then be the min of the right subtree
-    if (!node.hasParent() || node.isRight()) {
-      return minElement(node.getRight());
-    }
-    // For a left node, we know that it is smaller than the parent. But, we also know that the right subtree
-    // is also smaller than the parent
-    else {
-      if (node.hasRight()) {
-        return minElement(node.getRight());
+
+    const_value_reference candidate_value;
+    bool candidate_found = false;
+
+    // Find the smallest node bigger than value (it must exist)
+    while (set != null) {
+      Debug("set_collection") << "BacktrackableSetCollection::getNext(" << toString(set) << "," << value << ")" << std::endl;
+      const tree_entry_type& node = d_memory[set];
+      if (node.getValue() <= value) {
+        // If the node is smaller than the value, we need a bigger one
+        set = node.getRight();
       } else {
-        Assert(node.hasParent());
-        return d_memory[node.getParent()].getValue();
+        // The node is bigger than the value
+        candidate_found = true;
+        candidate_value = node.getValue();
+        // There might be a smaller one
+        set = node.getLeft();
       }
     }
-  }
+
+    Assert(candidate_found);
+    return candidate_value;
+}
 
   /**
    * Count the number of elements in the given bounds.
@@ -262,6 +270,9 @@ public:
   void getElements(reference_type set, const_value_reference lowerBound, const_value_reference upperBound, std::vector<value_type>& output) const {
     Assert(lowerBound <= upperBound);
     backtrack();
+
+    Debug("set_collection") << "BacktrackableSetCollection::getElements(" << toString(set) << "," << lowerBound << "," << upperBound << ")" << std::endl;
+
     // Empty set no elements
     if (set == null) {
       return;
@@ -277,7 +288,7 @@ public:
       output.push_back(current.getValue());
     }
     // Right child (bigger elements)
-    if (current.getValue() <= upperBound) {
+    if (current.getValue() < upperBound) {
       getElements(current.getRight(), lowerBound, upperBound, output);
     }
   }
@@ -285,7 +296,7 @@ public:
   /**
    * Print the list of elements to the output.
    */
-  void print(std::ostream& out, reference_type set) {
+  void print(std::ostream& out, reference_type set) const {
     backtrack();
     if (set == null) {
       return;
@@ -305,7 +316,7 @@ public:
   /**
    * String representation of a set.
    */
-  std::string toString(reference_type set) {
+  std::string toString(reference_type set) const {
     std::stringstream out;
     print(out, set);
     return out.str();
index 9880539ed2909aca9b8c71d150ee2c1ccacdcb8c..53c44bed0243bb4e7dbec2a4aaaac2b0dc3b7442 100644 (file)
 #include "context/cdo.h"
 #include "util/output.h"
 #include "util/stats.h"
+#include "theory/rewriter.h"
 
 namespace CVC4 {
 namespace theory {
 namespace bv {
 
 struct BitSizeTraits {
-
   /** The null id */
   static const size_t id_null; // Defined in the cpp file (GCC bug)
   /** The null trigger id */
@@ -46,13 +46,6 @@ struct BitSizeTraits {
   static const size_t size_bits = 16;
   /** Number of bits we use for the trigger id */
   static const size_t trigger_id_bits = 24;
-
-  /** Number of bits we use for the function ids */
-  static const size_t function_id_bits = 8;
-  /** Number of bits we use for the function arguments count */
-  static const size_t function_arguments_count_bits = 16;
-  /** Number of bits we use for the index into the arguments memory */
-  static const size_t function_arguments_index_bits = 24;
 };
 
 class EqualityNode {
@@ -68,22 +61,18 @@ public:
   /** The next equality node in this class */
   size_t d_nextId : BitSizeTraits::id_bits;
 
-  /** Is this node a function application */
-  size_t d_isFunction : 1;
-
 public:
 
   /**
    * Creates a new node, which is in a list of it's own.
    */
   EqualityNode(size_t nodeId = BitSizeTraits::id_null)
-  : d_size(1), d_findId(nodeId), d_nextId(nodeId), d_isFunction(0) {}
+  : d_size(1), d_findId(nodeId), d_nextId(nodeId) {}
 
   /** Initialize the equality node */
-  inline void init(size_t nodeId, bool isFunction) {
+  inline void init(size_t nodeId) {
     d_size = 1;
     d_findId = d_nextId = nodeId;
-    d_isFunction = isFunction;
   }
 
   /**
@@ -125,66 +114,11 @@ public:
   inline void setFind(size_t findId) { d_findId = findId; }
 };
 
-/**
- * FunctionNode class represents the information related to a function node. It has an id, number of children
- * and the
- */
-class FunctionNode {
-
-  /** Is the function associative */
-  size_t d_isAssociative  : 1;
-  /** The id of the function */
-  size_t d_functionId     : BitSizeTraits::function_id_bits;
-  /** Number of children */
-  size_t d_argumentsCount : BitSizeTraits::function_arguments_count_bits;
-  /** Index of the start of the arguments in the children array */
-  size_t d_argumentsIndex : BitSizeTraits::function_arguments_index_bits;
-
-public:
-
-  FunctionNode(size_t functionId = 0, size_t argumentsCount = 0, size_t argumentsIndex = 0, bool associative = false)
-  : d_isAssociative(associative), d_functionId(functionId), d_argumentsCount(argumentsCount), d_argumentsIndex(argumentsIndex)
-  {}
-
-  void init(size_t functionId, size_t argumentsCount, size_t argumentsIndex, bool associative) {
-    d_functionId = functionId;
-    d_argumentsCount = argumentsCount;
-    d_argumentsIndex = argumentsIndex;
-    d_isAssociative = associative;
-  }
-
-  /** Check if the function is associative */
-  bool isAssociative() const { return d_isAssociative; }
-
-  /** Get the function id */
-  size_t getFunctionId() const { return d_functionId; }
-
-  /** Get the number of arguments */
-  size_t getArgumentsCount() const { return d_argumentsCount; }
-
-  /** Get the infex of the first argument in the arguments memory */
-  size_t getArgumentsIndex() const { return d_argumentsIndex; }
-
-};
-
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
 class EqualityEngine {
 
 public:
 
-  /**
-   * Basic information about a function.
-   */
-  struct FunctionInfo {
-    /** Name of the function */
-    std::string name;
-    /** Is the function associative */
-    bool isAssociative;
-
-    FunctionInfo(std::string name, bool isAssociative)
-    : name(name), isAssociative(isAssociative) {}
-  };
-
   /** Statistics about the equality engine instance */
   struct Statistics {
     /** Total number of merges */
@@ -193,8 +127,6 @@ public:
     IntStat termsCount;
     /** Number of function terms managed by the system */
     IntStat functionTermsCount;
-    /** Number of distince functions managed by the system */
-    IntStat functionsCount;
     /** Number of times we performed a backtrack */
     IntStat backtracksCount;
 
@@ -202,13 +134,11 @@ public:
     : mergesCount(name + "::mergesCount", 0),
       termsCount(name + "::termsCount", 0),
       functionTermsCount(name + "functionTermsCoutn", 0),
-      functionsCount(name + "::functionsCount", 0),
       backtracksCount(name + "::backtracksCount", 0)
     {
       StatisticsRegistry::registerStat(&mergesCount);
       StatisticsRegistry::registerStat(&termsCount);
       StatisticsRegistry::registerStat(&functionTermsCount);
-      StatisticsRegistry::registerStat(&functionsCount);
       StatisticsRegistry::registerStat(&backtracksCount);
     }
 
@@ -216,7 +146,6 @@ public:
       StatisticsRegistry::unregisterStat(&mergesCount);
       StatisticsRegistry::unregisterStat(&termsCount);
       StatisticsRegistry::unregisterStat(&functionTermsCount);
-      StatisticsRegistry::unregisterStat(&functionsCount);
       StatisticsRegistry::unregisterStat(&backtracksCount);
     }
   };
@@ -238,12 +167,6 @@ private:
   /** Number of asserted equalities we have so far */
   context::CDO<size_t> d_assertedEqualitiesCount;
 
-  /** Map from ids to functional representations */
-  std::vector<FunctionNode> d_functionNodes;
-
-  /** Functions in the system */
-  std::vector<FunctionInfo> d_functions;
-
   /**
    * We keep a list of asserted equalities. Not among original terms, but
    * among the class representatives.
@@ -261,6 +184,8 @@ private:
   /** The ids of the classes we have merged */
   std::vector<Equality> d_assertedEqualities;
 
+  /** The reasons for the equalities */
+
   /**
    * An edge in the equality graph. This graph is an undirected graph (both edges added)
    * containing the actual asserted equalities.
@@ -291,6 +216,11 @@ private:
    */
   std::vector<EqualityEdge> d_equalityEdges;
 
+  /**
+   * Reasons for equalities.
+   */
+  std::vector<Node> d_equalityReasons;
+
   /**
    * Map from a node to it's first edge in the equality graph. Edges are added to the front of the
    * list which makes the insertion/backtracking easy.
@@ -298,7 +228,7 @@ private:
   std::vector<size_t> d_equalityGraph;
 
   /** Add an edge to the equality graph */
-  inline void addGraphEdge(size_t t1, size_t t2);
+  inline void addGraphEdge(size_t t1, size_t t2, Node reason);
 
   /** Returns the equality node of the given node */
   inline EqualityNode& getEqualityNode(TNode node);
@@ -385,11 +315,6 @@ public:
    */
   size_t addTerm(TNode t);
 
-  /**
-   * Adds a term that is an application of a function symbol to the databas. Returns the internal id of the term.
-   */
-  size_t addFunctionApplication(size_t funcionId, const std::vector<TNode>& arguments);
-
   /**
    * Check whether the node is already in the database.
    */
@@ -398,7 +323,7 @@ public:
   /**
    * Adds an equality t1 = t2 to the database. Returns false if any of the triggers failed.
    */
-  bool addEquality(TNode t1, TNode t2);
+  bool addEquality(TNode t1, TNode t2, Node reason);
 
   /**
    * Returns the representative of the term t.
@@ -424,23 +349,27 @@ public:
   size_t addTrigger(TNode t1, TNode t2);
 
   /**
-   * Adds a new function to the equality engine. The funcions are not of fixed arity and no typechecking is performed!
-   * Associative functions allow for normalization, i.e. f(f(x, y), z) = f(x, f(y, z)) = f(x, y, z).
-   * @associative should be true if the function is associative and you want this to be handled by the engine
+   * Normalizes a term by finding the representative. If the representative can be decomposed (using
+   * UnionFindPreferences) it will try and recursively find the representatives, and substitute.
+   * Assumptions used in normalization are retruned in the set.
    */
-  inline size_t newFunction(std::string name, bool associative) {
-    Assert(use_functions);
-    Assert(!associative || enable_associative);
-    ++ d_stats.functionsCount;
-    size_t id = d_functions.size();
-    d_functions.push_back(FunctionInfo(name, associative));
-    return id;
-  }
+  Node normalize(TNode node, std::set<TNode>& assumptions);
+
+private:
+
+  /** Hash of normalizations to avioid cycles */
+  typedef __gnu_cxx::hash_map<TNode, Node, TNodeHashFunction> normalization_cache;
+  normalization_cache d_normalizationCache;
+
+  /**
+   * Same as above, but does cahcing to avoid loops.
+   */
+  Node normalizeWithCache(TNode node, std::set<TNode>& assumptions);
 
 };
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::addTerm(TNode t) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+size_t EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::addTerm(TNode t) {
 
   Debug("equality") << "EqualityEngine::addTerm(" << t << ")" << std::endl;
 
@@ -462,69 +391,35 @@ size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative
   if (d_equalityNodes.size() <= newId) {
     d_equalityNodes.resize(newId + 100);
   }
-  d_equalityNodes[newId].init(newId, false);
-  // Return the id of the term
-  return newId;
-}
-
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::addFunctionApplication(size_t functionId, const std::vector<TNode>& arguments) {
-
-  Debug("equality") << "EqualityEngine::addFunctionApplication(" << d_functions[functionId].name << ":" << arguments.size() << ")" << std::endl;
-
-  ++ d_stats.functionTermsCount;
-  ++ d_stats.termsCount;
-
-  // Register the new id of the term
-  size_t newId = d_nodes.size();
-  // Add the node to it's position
-  d_nodes.push_back(Node());
-  // Add the trigger list for this node
-  d_nodeTriggers.push_back(BitSizeTraits::trigger_id_null);
-  // Add it to the equality graph
-  d_equalityGraph.push_back(BitSizeTraits::id_null);
-  // Add the equality node to the nodes
-  if (d_equalityNodes.size() <= newId) {
-    d_equalityNodes.resize(newId + 100);
-  }
-  d_equalityNodes[newId].init(newId, true);
-  // Add the function application to the function nodes
-  if (d_functionNodes.size() <= newId) {
-    d_functionNodes.resize(newId + 100);
-  }
-  // Initialize the function node
-  size_t argumentsIndex;
-  d_functionNodes[newId].init(functionId, arguments.size(), argumentsIndex, d_functions[functionId].isAssociative);
-
+  d_equalityNodes[newId].init(newId);
   // Return the id of the term
   return newId;
-
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::hasTerm(TNode t) const {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+bool EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::hasTerm(TNode t) const {
   return d_nodeIds.find(t) != d_nodeIds.end();
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::getNodeId(TNode node) const {
-  Assert(hasTerm(node));
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+size_t EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::getNodeId(TNode node) const {
+  Assert(hasTerm(node), node.toString().c_str());
   return (*d_nodeIds.find(node)).second;
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-EqualityNode& EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::getEqualityNode(TNode t) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+EqualityNode& EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::getEqualityNode(TNode t) {
   return getEqualityNode(getNodeId(t));
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-EqualityNode& EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::getEqualityNode(size_t nodeId) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+EqualityNode& EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::getEqualityNode(size_t nodeId) {
   Assert(nodeId < d_equalityNodes.size());
   return d_equalityNodes[nodeId];
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::addEquality(TNode t1, TNode t2) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+bool EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::addEquality(TNode t1, TNode t2, Node reason) {
 
   Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << ")" << std::endl;
 
@@ -549,18 +444,20 @@ bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
   Assert(node1.getFind() == t1classId);
   Assert(node2.getFind() == t2classId);
 
-  // Depending on the size, merge them
+  // Depending on the merge preference (such as size), merge them
   std::vector<size_t> triggers;
-  if (node1.getSize() < node2.getSize()) {
+  if (UnionFindPreferences::mergePreference(d_nodes[t2classId], node2.getSize(), d_nodes[t1classId], node1.getSize())) {
+    Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << "): merging " << t1 << " into " << t2 << std::endl;
     merge(node2, node1, triggers);
     d_assertedEqualities.push_back(Equality(t2classId, t1classId));
   } else {
+    Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << "): merging " << t2 << " into " << t1 << std::endl;
     merge(node1, node2, triggers);
     d_assertedEqualities.push_back(Equality(t1classId, t2classId));
   }
 
   // Add the actuall equality to the equality graph
-  addGraphEdge(t1Id, t2Id);
+  addGraphEdge(t1Id, t2Id, reason);
 
   // One more equality added
   d_assertedEqualitiesCount = d_assertedEqualitiesCount + 1;
@@ -577,8 +474,8 @@ bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
   return true;
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-TNode EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::getRepresentative(TNode t) const {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+TNode EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::getRepresentative(TNode t) const {
 
   Debug("equality") << "EqualityEngine::getRepresentative(" << t << ")" << std::endl;
 
@@ -593,8 +490,8 @@ TNode EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>
   return d_nodes[representativeId];
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::areEqual(TNode t1, TNode t2) const {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+bool EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::areEqual(TNode t1, TNode t2) const {
   Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl;
 
   Assert(hasTerm(t1));
@@ -610,8 +507,8 @@ bool EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
   return rep1 == rep2;
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::merge(EqualityNode& class1, EqualityNode& class2, std::vector<size_t>& triggers) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+void EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::merge(EqualityNode& class1, EqualityNode& class2, std::vector<size_t>& triggers) {
 
   Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl;
 
@@ -660,8 +557,8 @@ void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
   class1.merge<true>(class2);
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::undoMerge(EqualityNode& class1, EqualityNode& class2, size_t class2Id) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+void EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::undoMerge(EqualityNode& class1, EqualityNode& class2, size_t class2Id) {
 
   Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << ")" << std::endl;
 
@@ -692,8 +589,8 @@ void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
 
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::backtrack() {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+void EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::backtrack() {
 
   // If we need to backtrack then do it
   if (d_assertedEqualitiesCount < d_assertedEqualities.size()) {
@@ -721,22 +618,24 @@ void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
     }
 
     d_equalityEdges.resize(2 * d_assertedEqualitiesCount);
+    d_equalityReasons.resize(d_assertedEqualitiesCount);
   }
 
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::addGraphEdge(size_t t1, size_t t2) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+void EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::addGraphEdge(size_t t1, size_t t2, Node reason) {
   Debug("equality") << "EqualityEngine::addGraphEdge(" << d_nodes[t1] << "," << d_nodes[t2] << ")" << std::endl;
   size_t edge = d_equalityEdges.size();
   d_equalityEdges.push_back(EqualityEdge(t2, d_equalityGraph[t1]));
   d_equalityEdges.push_back(EqualityEdge(t1, d_equalityGraph[t2]));
   d_equalityGraph[t1] = edge;
   d_equalityGraph[t2] = edge | 1;
+  d_equalityReasons.push_back(reason);
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::getExplanation(TNode t1, TNode t2, std::vector<TNode>& equalities) const {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+void EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::getExplanation(TNode t1, TNode t2, std::vector<TNode>& equalities) const {
   Assert(equalities.empty());
   Assert(t1 != t2);
   Assert(getRepresentative(t1) == getRepresentative(t2));
@@ -784,15 +683,9 @@ void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
 
           // Reconstruct the path
           do {
-            // Get the left and right hand side from the edge
-            size_t firstEdge = (currentEdge >> 1) << 1;
-            size_t secondEdge = (currentEdge | 1);
-            TNode lhs = d_nodes[d_equalityEdges[secondEdge].getNodeId()];
-            TNode rhs = d_nodes[d_equalityEdges[firstEdge].getNodeId()];
             // Add the actual equality to the vector
-            equalities.push_back(lhs.eqNode(rhs));
-
-            Debug("equality") << "EqualityEngine::getExplanation(): adding: " << lhs.eqNode(rhs) << std::endl;
+            equalities.push_back(d_equalityReasons[currentEdge >> 1]);
+            Debug("equality") << "EqualityEngine::getExplanation(): adding: " << d_equalityReasons[currentEdge >> 1] << std::endl;
 
             // Go to the previous
             currentEdge = bfsQueue[currentIndex].edgeId;
@@ -816,8 +709,8 @@ void EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>:
   }
 }
 
-template <typename OwnerClass, typename NotifyClass, bool use_functions, bool enable_associative>
-size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative>::addTrigger(TNode t1, TNode t2) {
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+size_t EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::addTrigger(TNode t1, TNode t2) {
 
   Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ")" << std::endl;
 
@@ -850,6 +743,64 @@ size_t EqualityEngine<OwnerClass, NotifyClass, use_functions, enable_associative
   return t1NewTriggerId / 2;
 }
 
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+Node EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::normalize(TNode node, std::set<TNode>& assumptions) {
+  d_normalizationCache.clear();
+  Node result = Rewriter::rewrite(normalizeWithCache(node, assumptions));
+  d_normalizationCache.clear();
+  return result;
+}
+
+
+template <typename OwnerClass, typename NotifyClass, typename UnionFindPreferences>
+Node EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::normalizeWithCache(TNode node, std::set<TNode>& assumptions) {
+
+  Debug("equality") << "EqualityEngine::normalize(" << node << ")" << push << std::endl;
+
+  normalization_cache::iterator find = d_normalizationCache.find(node);
+  if (find != d_normalizationCache.end()) {
+    if (find->second.isNull()) {
+      // We are in a cycle
+      return node;
+    } else {
+      // Not in a cycle, return it
+      return find->second;
+    }
+  } else {
+    d_normalizationCache[node] = Node();
+  }
+
+  // Get the representative
+  Node result = hasTerm(node) ? getRepresentative(node) : node;
+  if (node != result) {
+    std::vector<TNode> equalities;
+    getExplanation(result, node, equalities);
+    assumptions.insert(equalities.begin(), equalities.end());
+  }
+
+  // If asked, substitute the children with their representatives
+  if (UnionFindPreferences::descend(result)) {
+    // Make the builder for substitution
+    NodeBuilder<> builder;
+    builder << result.getKind();
+    kind::MetaKind metaKind = result.getMetaKind();
+    if (metaKind == kind::metakind::PARAMETERIZED) {
+      builder << result.getOperator();
+    }
+    for (unsigned i = 0; i < result.getNumChildren(); ++ i) {
+      builder << normalizeWithCache(result[i], assumptions);
+    }
+    result = builder;
+  }
+
+  Debug("equality") << "EqualityEngine::normalize(" << node << ") => " << result << pop << std::endl;
+
+  // Cache the result for real now
+  d_normalizationCache[node] = result;
+
+  return result;
+}
+
 } // Namespace bv
 } // Namespace theory
 } // Namespace CVC4
index 8fc1e0b9d87d345e2e35b95458d7276dda9c4b10..96a0067dc91a12b255b9beaa994590d16da0acab 100644 (file)
@@ -13,6 +13,7 @@
 #include "theory/bv/cd_set_collection.h"
 
 #include <map>
+#include <set>
 #include <vector>
 
 namespace CVC4 {
@@ -134,9 +135,6 @@ private:
   /** The equality engine */
   EqualityEngine& d_equalityEngine;
 
-  /** The id of the concatenation function */
-  size_t d_concatFunctionId;
-
   /** The collection of backtrackable sets */
   set_collection d_setCollection;
 
@@ -153,27 +151,35 @@ public:
 
   SliceManager(TheoryBitvector& theoryBitvector, context::Context* context)
   : d_theoryBitvector(theoryBitvector), d_equalityEngine(theoryBitvector.getEqualityEngine()), d_setCollection(context) {
-    // We register the concatentation with the equality engine
-    d_concatFunctionId = d_equalityEngine.newFunction("bv_concat", true);
   }
 
-  inline size_t getConcatFunctionId() const { return d_concatFunctionId; }
-
   /**
-   * Adds the equality (lhs = rhs) to the slice manager. This will not add the equalities to the equality manager,
-   * but will slice the equality according to the current slicing in order to align all the slices. The terms that
-   * get slices get sent to the theory engine as equalities, i.e if we slice x[10:0] into x[10:5]@x[4:0] equality
-   * engine gets the assertion x[10:0] = concat(x[10:5], x[4:0]).
+   * Adds the equality (lhs = rhs) to the slice manager. The equality is first normalized according to the equality
+   * manager, i.e. each base term is taken from the equality manager, replaced in, and then the whole concatenation
+   * normalized and sliced wrt the current slicing. The method will not add the equalities to the equality manager,
+   * but instead will slice the equality according to the current slicing in order to align all the slices.
+   *
+   * The terms that get sliced get sent to the theory engine as equalities, i.e if we slice x[10:0] into x[10:5]@x[4:0]
+   * equality engine gets the assertion x[10:0] = concat(x[10:5], x[4:0]).
+   *
+   * input                                 output                            slicing
+   * --------------------------------------------------------------------------------------------------------------
+   * x@y = y@x                             x = y, y = x                      empty
+   * x[31:0]@x[64:32] = x                  x = x[31:0]@x[63:32]              x:{64,32,0}
+   * x@y = 0000@x@0000                     x = 0000@x[7:4], y = x[3:0]@0000  x:{8,4,0}
+   *
    */
-  inline void addEquality(TNode lhs, TNode rhs, std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices);
+  inline bool solveEquality(TNode lhs, TNode rhs);
 
 private:
 
+  inline bool solveEquality(TNode lhs, TNode rhs, const std::set<TNode>& assumptions);
+
   /**
-   * Slices up lhs and rhs and returns the slices in lhsSlices and rhsSlices
+   * Slices up lhs and rhs and returns the slices in lhsSlices and rhsSlices. The slices are not atomic,
+   * they are sliced in order to make one of lhs or rhs atomic, the other one can be a concatenation.
    */
-  inline void slice(std::vector<Node>& lhs, std::vector<Node>& rhs,
-                    std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices);
+  inline bool sliceAndSolve(std::vector<Node>& lhs, std::vector<Node>& rhs, const std::set<TNode>& assumptions);
 
   /**
    * Returns true if the term is already sliced wrt the current slicing. Note that, for example, even though
@@ -184,7 +190,7 @@ private:
   /**
    * Slices the term wrt the current slicing. When done, isSliced returns true
    */
-  inline void slice(TNode node, std::vector<Node>& sliced);
+  inline bool slice(TNode node, std::vector<Node>& sliced);
 
   /**
    * Returns the base term in the core theory of the given term, i.e.
@@ -196,20 +202,87 @@ private:
   static inline TNode baseTerm(TNode node);
 
   /**
-   * Adds a new slice to the slice set of the given base term.
+   * Adds a new slice to the slice set of the given term.
    */
-  inline void addSlice(Node baseTerm, unsigned slicePoint);
+  inline bool addSlice(Node term, unsigned slicePoint);
 };
 
 
 template <class TheoryBitvector>
-void SliceManager<TheoryBitvector>::addSlice(Node baseTerm, unsigned slicePoint) {
+bool SliceManager<TheoryBitvector>::addSlice(Node node, unsigned slicePoint) {
+  Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << ")" << std::endl;
+
+  bool ok = true;
+
+  int low  = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractLow(node) : 0;
+  int high = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractHigh(node) + 1: utils::getSize(node);
+  slicePoint += low;
+
+  TNode nodeBase = baseTerm(node);
+
+  set_reference sliceSet;
+  slicing_map::iterator find = d_nodeSlicing.find(nodeBase);
+  if (find == d_nodeSlicing.end()) {
+    sliceSet = d_nodeSlicing[nodeBase] = d_setCollection.newSet(slicePoint);
+    d_setCollection.insert(sliceSet, low);
+    d_setCollection.insert(sliceSet, high);
+  } else {
+    sliceSet = find->second;
+  }
+
+  // What are the points surrounding the new slice point
+  int prev = d_setCollection.prev(sliceSet, slicePoint);
+  int next = d_setCollection.next(sliceSet, slicePoint);
+
+  // Add the slice to the set
+  d_setCollection.insert(sliceSet, slicePoint);
+  Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << "): current set " << d_setCollection.toString(sliceSet) << std::endl;
+
+  // Add the terms and the equality to the equality engine
+  Node t1 = utils::mkExtract(nodeBase, next - 1, slicePoint);
+  Node t2 = utils::mkExtract(nodeBase, slicePoint - 1, prev);
+  Node nodeSlice = (next == high && prev == low) ? node : utils::mkExtract(nodeBase, next - 1, prev);
+  Node concat = utils::mkConcat(t1, t2);
+
+  d_equalityEngine.addTerm(t1);
+  d_equalityEngine.addTerm(t2);
+  d_equalityEngine.addTerm(nodeSlice);
+  d_equalityEngine.addTerm(concat);
+
+  // We are free to add this slice, unless the slice has a representative that's already a concat
+  TNode nodeSliceRepresentative = d_equalityEngine.getRepresentative(nodeSlice);
+  if (nodeSliceRepresentative.getKind() != kind::BITVECTOR_CONCAT) {
+    // Add the slice to the equality engine
+    ok = d_equalityEngine.addEquality(nodeSlice, concat, utils::mkTrue());
+  } else {
+    // If the representative is a concat, we must solve it
+    // There is no need do add nodeSlice = concat as we will solve the representative of nodeSlice
+    std::set<TNode> assumptions;
+    std::vector<TNode> equalities;
+    d_equalityEngine.getExplanation(nodeSlice, nodeSliceRepresentative, equalities);
+    assumptions.insert(equalities.begin(), equalities.end());
+    ok = solveEquality(nodeSliceRepresentative, concat, assumptions);
+  }
+
+  Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << ") => " << d_setCollection.toString(d_nodeSlicing[nodeBase]) << std::endl;
+
+  return ok;
 }
 
 template <class TheoryBitvector>
-void SliceManager<TheoryBitvector>::addEquality(TNode lhs, TNode rhs, std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices) {
+bool SliceManager<TheoryBitvector>::solveEquality(TNode lhs, TNode rhs) {
+  std::set<TNode> assumptions;
+  assumptions.insert(lhs.eqNode(rhs));
+  bool ok = solveEquality(lhs, rhs, assumptions);
+  return ok;
+}
+
+template <class TheoryBitvector>
+bool SliceManager<TheoryBitvector>::solveEquality(TNode lhs, TNode rhs, const std::set<TNode>& assumptions) {
+
+  Debug("slicing") << "SliceMagager::solveEquality(" << lhs << "," << rhs << "," << utils::setToString(assumptions) << ")" << push << std::endl;
 
-  Debug("slicing") << "SliceMagager::addEquality(" << lhs << "," << rhs << ")" << std::endl;
+  bool ok;
 
   // The concatenations on the left-hand side (reverse order, first is on top)
   std::vector<Node> lhsTerms;
@@ -232,60 +305,213 @@ void SliceManager<TheoryBitvector>::addEquality(TNode lhs, TNode rhs, std::vecto
   }
 
   // Slice the individual terms to align them
-  slice(lhsTerms, rhsTerms, lhsSlices, rhsSlices);
+  ok = sliceAndSolve(lhsTerms, rhsTerms, assumptions);
+
+  Debug("slicing") << "SliceMagager::solveEquality(" << lhs << "," << rhs << "," << utils::setToString(assumptions) << ")" << pop << std::endl;
+
+  return ok;
 }
 
+
 template <class TheoryBitvector>
-void SliceManager<TheoryBitvector>::slice(std::vector<Node>& lhs, std::vector<Node>& rhs,
-                                          std::vector<Node>& lhsSlices, std::vector<Node>& rhsSlices) {
+bool SliceManager<TheoryBitvector>::sliceAndSolve(std::vector<Node>& lhs, std::vector<Node>& rhs, const std::set<TNode>& assumptions)
+{
 
-  Debug("slicing") << "SliceManager::slice()" << std::endl;
+  Debug("slicing") << "SliceManager::sliceAndSolve()" << std::endl;
 
-  // Go through the work-list and align
+  // Go through the work-list, solve and align
   while (!lhs.empty()) {
 
     Assert(!rhs.empty());
 
+    Debug("slicing") << "SliceManager::sliceAndSolve(): lhs " << utils::vectorToString(lhs) << std::endl;
+    Debug("slicing") << "SliceManager::sliceAndSolve(): rhs " << utils::vectorToString(rhs) << std::endl;
+
     // The terms that we need to slice
     Node lhsTerm = lhs.back();
     Node rhsTerm = rhs.back();
-    Debug("slicing") << "slicing: " << lhsTerm << " and " << rhsTerm << std::endl;
+
+    Debug("slicing") << "SliceManager::sliceAndSolve(): " << lhsTerm << " : " << rhsTerm << std::endl;
 
     // If the terms are not sliced wrt the current slicing, we have them sliced
     lhs.pop_back();
     if (!isSliced(lhsTerm)) {
-      slice(lhsTerm, lhs);
+      if (!slice(lhsTerm, lhs)) return false;
+      Debug("slicing") << "SliceManager::sliceAndSolve(): lhs sliced" << std::endl;
       continue;
     }
     rhs.pop_back();
     if (!isSliced(rhsTerm)) {
-      slice(rhsTerm, rhs);
+      if (!slice(rhsTerm, rhs)) return false;
+      // We also need to put lhs back
+      lhs.push_back(lhsTerm);
+      Debug("slicing") << "SliceManager::sliceAndSolve(): rhs sliced" << std::endl;
+      continue;
     }
 
+    Debug("slicing") << "SliceManager::sliceAndSolve(): both lhs and rhs sliced already" << std::endl;
+
+    // The solving concatenation
+    std::vector<Node> concatTerms;
+
     // If the slices are of the same size we do the additional work
-    unsigned lhsSize = utils::getSize(lhsTerm);
-    unsigned rhsSize = utils::getSize(rhsTerm);
-    if (lhsSize == rhsSize) {
-      // If they are over the same base terms, we need to do something
-      TNode lhsBaseTerm = baseTerm(lhsTerm);
-      TNode rhsBaseTerm = baseTerm(rhsTerm);
-      if (lhsBaseTerm == rhsBaseTerm) {
-        // x[i_1:j_1] vs x[i_2:j_2]
+    int sizeDifference = utils::getSize(lhsTerm) - utils::getSize(rhsTerm);
+
+    // We slice constants immediately
+    if (sizeDifference > 0 && lhsTerm.getKind() == kind::CONST_BITVECTOR) {
+      BitVector low  = lhsTerm.getConst<BitVector>().extract(utils::getSize(rhsTerm) - 1, 0);
+      BitVector high = lhsTerm.getConst<BitVector>().extract(utils::getSize(lhsTerm) - 1, utils::getSize(rhsTerm));
+      lhs.push_back(utils::mkConst(low));
+      lhs.push_back(utils::mkConst(high));
+      rhs.push_back(rhsTerm);
+      continue;
+    }
+    if (sizeDifference < 0 && rhsTerm.getKind() == kind::CONST_BITVECTOR) {
+      BitVector low  = rhsTerm.getConst<BitVector>().extract(utils::getSize(lhsTerm) - 1, 0);
+      BitVector high = rhsTerm.getConst<BitVector>().extract(utils::getSize(rhsTerm) - 1, utils::getSize(lhsTerm));
+      rhs.push_back(utils::mkConst(low));
+      rhs.push_back(utils::mkConst(high));
+      lhs.push_back(lhsTerm);
+      continue;
+    }
+
+    enum SolvingFor {
+      SOLVING_FOR_LHS,
+      SOLVING_FOR_RHS
+    } solvingFor = sizeDifference < 0 || lhsTerm.getKind() == kind::CONST_BITVECTOR ? SOLVING_FOR_RHS : SOLVING_FOR_LHS;
+
+    Debug("slicing") << "SliceManager::sliceAndSolve(): " << (solvingFor == SOLVING_FOR_LHS ? "solving for LHS" : "solving for RHS") << std::endl;
+
+    // When we slice in order to align, we might have to reslice the one we are solving for
+    bool reslice = false;
+
+    switch (solvingFor) {
+    case SOLVING_FOR_RHS: {
+      concatTerms.push_back(lhsTerm);
+      // Maybe we need to add more lhs to make them equal
+      while (sizeDifference < 0 && !reslice) {
+        Assert(lhs.size() > 0);
+        // Get the next part for lhs
+        lhsTerm = lhs.back();
+        lhs.pop_back();
+        // Slice if necessary
+        if (!isSliced(lhsTerm)) {
+          if (!slice(lhsTerm, lhs)) return false;
+          continue;
+        }
+        // If we go above 0, we need to cut it
+        if (sizeDifference + (int)utils::getSize(lhsTerm) > 0) {
+          // Slice it so it fits
+          addSlice(lhsTerm, (int)utils::getSize(lhsTerm) + sizeDifference);
+          if (!slice(lhsTerm, lhs)) return false;
+          if (!isSliced(rhsTerm)) {
+            if (!slice(rhsTerm, rhs)) return false;
+            while(!concatTerms.empty()) {
+              lhs.push_back(concatTerms.back());
+              concatTerms.pop_back();
+            }
+            reslice = true;
+          }
+          continue;
+        }
+        concatTerms.push_back(lhsTerm);
+        sizeDifference += utils::getSize(lhsTerm);
+      }
+      break;
+    }
+    case SOLVING_FOR_LHS: {
+      concatTerms.push_back(rhsTerm);
+      // Maybe we need to add more rhs to make them equal
+      while (sizeDifference > 0 && !reslice) {
+        Assert(rhs.size() > 0);
+        // Get the next part for lhs
+        rhsTerm = rhs.back();
+        rhs.pop_back();
+        // Slice if necessary
+        if (!isSliced(rhsTerm)) {
+          if (!slice(rhsTerm, rhs)) return false;
+          continue;
+        }
+        // If we go below 0, we need to cut it
+        if (sizeDifference - (int)utils::getSize(rhsTerm) < 0) {
+          // Slice it so it fits
+          addSlice(rhsTerm, (int)utils::getSize(rhsTerm) - sizeDifference);
+          if (!slice(rhsTerm, rhs)) return false;
+         if (!isSliced(lhsTerm)) {
+            if (!slice(lhsTerm, lhs)) return false;
+            while(!concatTerms.empty()) {
+              rhs.push_back(concatTerms.back());
+              concatTerms.pop_back();
+            }
+            reslice = true;
+          }
+          continue;
+        }
+        concatTerms.push_back(rhsTerm);
+        sizeDifference -= utils::getSize(rhsTerm);
+      }
+      break;
+    }
+    }
+
+    // If we need to reslice
+    if (reslice) {
+       continue;
+    }
+
+    Assert(sizeDifference == 0);
+
+    Node concat = utils::mkConcat(concatTerms);
+    Debug("slicing") << "SliceManager::sliceAndSolve(): concatenation " << concat << std::endl;
+
+    // We have them equal size now. If the base term of the one we are solving is solved into a
+    // non-trivial concatenation already, we have to normalize. A concatenation is non-trivial if
+    // it is not a direct slicing, i.e it is a concat, and normalize(x) != x
+    switch (solvingFor) {
+    case SOLVING_FOR_LHS: {
+      TNode lhsTermRepresentative = d_equalityEngine.getRepresentative(lhsTerm);
+      if (lhsTermRepresentative != lhsTerm &&
+          (lhsTermRepresentative.getKind() == kind::BITVECTOR_CONCAT || lhsTermRepresentative.getKind() == kind::CONST_BITVECTOR)) {
+        // We need to normalize and solve the normalized equations
+        std::vector<TNode> explanation;
+        d_equalityEngine.getExplanation(lhsTerm, lhsTermRepresentative, explanation);
+        std::set<TNode> additionalAssumptions(assumptions);
+        additionalAssumptions.insert(explanation.begin(), explanation.end());
+        bool ok = solveEquality(lhsTermRepresentative, concat, additionalAssumptions);
+        if (!ok) return false;
       } else {
-        // x[i_1:j_1] vs y[i_2:j_2]
+        // We're fine, just add the equality
+        Debug("slicing") << "SliceManager::sliceAndSolve(): adding " << lhsTerm << " = " << concat << " " << utils::setToString(assumptions) << std::endl;
+        d_equalityEngine.addTerm(concat);
+        bool ok = d_equalityEngine.addEquality(lhsTerm, concat, utils::mkConjunction(assumptions));
+        if (!ok) return false;
       }
-      lhsSlices.push_back(lhsTerm);
-      rhsSlices.push_back(rhsTerm);
-      continue;
-    } else {
-      // They are not of equal sizes, so we slice one
-      if (lhsSize < rhsSize) {
-        // We need to cut a piece of rhs
+      break;
+    }
+    case SOLVING_FOR_RHS: {
+      TNode rhsTermRepresentative = d_equalityEngine.getRepresentative(rhsTerm);
+      if (rhsTermRepresentative != rhsTerm &&
+          (rhsTermRepresentative.getKind() == kind::BITVECTOR_CONCAT || rhsTermRepresentative.getKind() == kind::CONST_BITVECTOR)) {
+        // We need to normalize and solve the normalized equations
+        std::vector<TNode> explanation;
+        d_equalityEngine.getExplanation(rhsTerm, rhsTermRepresentative, explanation);
+        std::set<TNode> additionalAssumptions(assumptions);
+        additionalAssumptions.insert(explanation.begin(), explanation.end());
+        bool ok = solveEquality(rhsTermRepresentative, concat, additionalAssumptions);
+        if (!ok) return false;
       } else {
-        // We need to cut a piece of lhs
+        // We're fine, just add the equality
+        Debug("slicing") << "SliceManager::sliceAndSolve(): adding " << rhsTerm << " = " << concat << utils::setToString(assumptions) << std::endl;
+        d_equalityEngine.addTerm(concat);
+        bool ok = d_equalityEngine.addEquality(rhsTerm, concat, utils::mkConjunction(assumptions));
+        if (!ok) return false;
       }
+      break;
+    }
     }
   }
+
+  return true;
 }
 
 template <class TheoryBitvector>
@@ -315,9 +541,11 @@ bool SliceManager<TheoryBitvector>::isSliced(TNode node) const {
     if (find == d_nodeSlicing.end()) {
       result = nodeKind != kind::BITVECTOR_EXTRACT;
     } else {
-      // Check whether there is a slice point in [high, low), if there is the term is not sliced.
-      // Hence, if we look for the upper bound of low, and it is higher than high, it is sliced.
-      result = d_setCollection.count(find->second, low + 1, high) > 0;
+      // The term is not sliced if one of the borders is not in the slice set or
+      // there is a point between the borders      
+      result = 
+       d_setCollection.contains(find->second, low) && d_setCollection.contains(find->second, high + 1) &&
+        (low == high || d_setCollection.count(find->second, low + 1, high) == 0);
     }
   }
 
@@ -326,7 +554,7 @@ bool SliceManager<TheoryBitvector>::isSliced(TNode node) const {
 }
 
 template <class TheoryBitvector>
-inline void SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>& sliced) {
+inline bool SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>& sliced) {
 
   Debug("slicing") << "SliceManager::slice(" << node << ")" << std::endl;
 
@@ -335,44 +563,64 @@ inline void SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>&
   // The indices of the beginning and (one past) end
   unsigned high = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractHigh(node) + 1 : utils::getSize(node);
   unsigned low  = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractLow(node) : 0;
+  Debug("slicing") << "SliceManager::slice(" << node << "): low: " << low << std::endl;
+  Debug("slicing") << "SliceManager::slice(" << node << "): high: " << high << std::endl;
 
   // Get the base term
   TNode nodeBase = baseTerm(node);
   Assert(nodeBase.getKind() != kind::BITVECTOR_CONCAT);
   Assert(nodeBase.getKind() != kind::CONST_BITVECTOR);
 
-  // Get the base term slice set
-  set_collection::reference_type nodeSliceSet = d_nodeSlicing[nodeBase];
+  // The nodes slice set
+  set_collection::reference_type nodeSliceSet;
+
+  // Find the current one or construct it
+  slicing_map::iterator findSliceSet = d_nodeSlicing.find(nodeBase);
+  if (findSliceSet == d_nodeSlicing.end()) {
+    nodeSliceSet = d_setCollection.newSet(utils::getSize(nodeBase));
+    d_setCollection.insert(nodeSliceSet, 0);
+    d_nodeSlicing[nodeBase] = nodeSliceSet;
+  } else {
+    nodeSliceSet = d_nodeSlicing[nodeBase];
+  }
+
   Debug("slicing") << "SliceManager::slice(" << node << "): current: " << d_setCollection.toString(nodeSliceSet) << std::endl;
   std::vector<size_t> slicePoints;
-  d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints);
-
+  if (low + 1 < high) {
+    d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints);
+  }
   // Go through all the points i_0 <= low < i_1 < ... < i_{n-1} < high <= i_n from the slice set
   // and generate the slices [i_0:low-1][low:i_1-1] [i_1:i2] ... [i_{n-1}:high-1][high:i_n-1]. They are in reverse order,
   // as they should be
   size_t i_0 = low == 0 ? 0 : d_setCollection.prev(nodeSliceSet, low + 1);
-  size_t i_n = high == utils::getSize(nodeBase) ? high: d_setCollection.next(nodeSliceSet, high);
+  Debug("slicing") << "SliceManager::slice(" << node << "): i_0: " << i_0 << std::endl;
+  size_t i_n = high == utils::getSize(nodeBase) ? high: d_setCollection.next(nodeSliceSet, high - 1);
+  Debug("slicing") << "SliceManager::slice(" << node << "): i_n: " << i_n << std::endl;
 
   // Add the new points to the slice set (they might be there already)
   if (high < i_n) {
-    std::vector<Node> lastTwoSlices;
-    lastTwoSlices.push_back(utils::mkExtract(nodeBase, i_n-1, high));
-    lastTwoSlices.push_back(utils::mkExtract(nodeBase, high-1, slicePoints.back()));
-    d_equalityEngine.addEquality(utils::mkExtract(nodeBase, i_n-1, slicePoints.back()), utils::mkConcat(lastTwoSlices));
+    if (!addSlice(nodeBase, high)) return false;
   }
-
-  while (!slicePoints.empty()) {
+  // Construct the actuall slicing
+  if (slicePoints.size() > 0) {
+    Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, slicePoints[0] - 1, low) << std::endl;
+    sliced.push_back(utils::mkExtract(nodeBase, slicePoints[0] - 1, low));
+    for (unsigned i = 1; i < slicePoints.size(); ++ i) {
+      Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, slicePoints[i] - 1, slicePoints[i-1])<< std::endl;
+      sliced.push_back(utils::mkExtract(nodeBase, slicePoints[i] - 1, slicePoints[i-1]));
+    }
+    Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, high-1, slicePoints.back()) << std::endl;
     sliced.push_back(utils::mkExtract(nodeBase, high-1, slicePoints.back()));
-    high = slicePoints.back();
-    slicePoints.pop_back();
+  } else {
+    sliced.push_back(utils::mkExtract(nodeBase, high - 1, low));
   }
-
+  // Add the new points to the slice set (they might be there already)
   if (i_0 < low) {
-    std::vector<Node> firstTwoSlices;
-    firstTwoSlices.push_back(utils::mkExtract(nodeBase, high-1, low));
-    firstTwoSlices.push_back(utils::mkExtract(nodeBase, low-1, i_0));
-    d_equalityEngine.addEquality(utils::mkExtract(nodeBase, high-1, i_0), utils::mkConcat(firstTwoSlices));
+    if (!addSlice(nodeBase, low)) return false;
   }
+
+  return true;
 }
 
 template <class TheoryBitvector>
index e183a592c0aa2f292462c26b2e752f15849cf315..2d823383c38720196cbe8fca522dcb69b6a1932c 100644 (file)
@@ -35,7 +35,17 @@ void TheoryBV::preRegisterTerm(TNode node) {
 
   if (node.getKind() == kind::EQUAL) {
     d_eqEngine.addTerm(node[0]);
+    if (node[0].getKind() == kind::BITVECTOR_CONCAT) {
+      for (unsigned i = 0, i_end = node[0].getNumChildren(); i < i_end; ++ i) {
+        d_eqEngine.addTerm(node[0][i]);
+      }
+    }
     d_eqEngine.addTerm(node[1]);
+    if (node[1].getKind() == kind::BITVECTOR_CONCAT) {
+      for (unsigned i = 0, i_end = node[1].getNumChildren(); i < i_end; ++ i) {
+        d_eqEngine.addTerm(node[1][i]);
+      }
+    }
     size_t triggerId = d_eqEngine.addTrigger(node[0], node[1]);
     Assert(triggerId == d_triggers.size());
     d_triggers.push_back(node);
@@ -57,29 +67,27 @@ void TheoryBV::check(Effort e) {
     // Do the right stuff
     switch (assertion.getKind()) {
     case kind::EQUAL: {
-
-      // Slice the equality
-      std::vector<Node> lhsSlices, rhsSlices;
-      d_sliceManager.addEquality(assertion[0], assertion[1], lhsSlices, rhsSlices);
-      Assert(lhsSlices.size() == rhsSlices.size());
-
-      // Add the equality to the equality engine
-      for (int i = 0, i_end = lhsSlices.size(); i != i_end; ++ i) {
-        bool ok = d_eqEngine.addEquality(lhsSlices[i], rhsSlices[i]);
-        if (!ok) return;
-      }
+      // Slice and solve the equality
+      bool ok = d_sliceManager.solveEquality(assertion[0], assertion[1]);
+      if (!ok) return;
       break;
     }
     case kind::NOT: {
       // We need to check this as the equality trigger might have been true when we made it
       TNode equality = assertion[0];
 
+      // Assumptions
+      std::set<TNode> assumptions;
+      Node lhsNormalized = d_eqEngine.normalize(equality[0], assumptions);
+      Node rhsNormalized = d_eqEngine.normalize(equality[1], assumptions);
+
+      Debug("bitvector") << "TheoryBV::check(" << e << "): normalizes to " << lhsNormalized << " = " << rhsNormalized << std::endl;
+      
       // No need to slice the equality, the whole thing *should* be deduced
-      if (d_eqEngine.areEqual(equality[0], equality[1])) {
-        vector<TNode> assertions;
-        d_eqEngine.getExplanation(equality[0], equality[1], assertions);
-        assertions.push_back(assertion);
-        d_out->conflict(mkAnd(assertions));
+      if (lhsNormalized == rhsNormalized) {
+        Debug("bitvector") << "TheoryBV::check(" << e << "): conflict with " << utils::setToString(assumptions) << std::endl;
+        assumptions.insert(assertion);
+        d_out->conflict(mkConjunction(assumptions));
         return;
       }
       break;
index ede98004f11d79c51eee4027e857ea26665994e9..fa9762cb738271c108e52300fda7d9230b2475e8 100644 (file)
@@ -46,7 +46,34 @@ public:
     }
   };
 
-  typedef EqualityEngine<TheoryBV, EqualityNotify, true, true> BvEqualityEngine;
+  struct BVEqualitySettings {
+    static inline bool descend(TNode node) {
+      return node.getKind() == kind::BITVECTOR_CONCAT || node.getKind() == kind::BITVECTOR_EXTRACT;
+    }
+
+    /** Returns true if node1 has preference to node2 as a representative, otherwise node2 is used */
+    static inline bool mergePreference(TNode node1, unsigned node1size, TNode node2, unsigned node2size) {
+      if (node1.getKind() == kind::CONST_BITVECTOR) {
+        Assert(node2.getKind() != kind::CONST_BITVECTOR);
+        return true;
+      }
+      if (node2.getKind() == kind::CONST_BITVECTOR) {
+        Assert(node1.getKind() != kind::CONST_BITVECTOR);
+        return false;
+      }
+      if (node1.getKind() == kind::BITVECTOR_CONCAT) {
+        Assert(node2.getKind() != kind::BITVECTOR_CONCAT);
+        return true;
+      }
+      if (node2.getKind() == kind::BITVECTOR_CONCAT) {
+        Assert(node1.getKind() != kind::BITVECTOR_CONCAT);
+        return false;
+      }
+      return node2size < node1size;
+    }
+  };
+
+  typedef EqualityEngine<TheoryBV, EqualityNotify, BVEqualitySettings> BvEqualityEngine;
 
 private:
 
index 9b545d25a34d41c5d298c317c973941142cffa11..9b5c8b0f958e134d495ba138c787050c4291535a 100644 (file)
@@ -51,12 +51,12 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) {
       break;
     case kind::BITVECTOR_EXTRACT:
       result = LinearRewriteStrategy<
+                  // Extract over a concatenation is distributed to the appropriate concatenations
+                  RewriteRule<ExtractConcat>,
                   // Extract over a constant gives a constant
                   RewriteRule<ExtractConstant>,
-                  // Extract over an extract is simplified to one extract
+                  // We could get another extract over extract
                   RewriteRule<ExtractExtract>,
-                  // Extract over a concatenation is distributed to the appropriate concatenations
-                  RewriteRule<ExtractConcat>,
                   // At this point only Extract-Whole could apply
                   RewriteRule<ExtractWhole>
                 >::apply(node);
index 6e9dbef3e66ccf5d5fed17963b014a41b705d757..ad924f8a0d7745764132e314851747237062e2fa 100644 (file)
@@ -19,7 +19,9 @@
 
 #pragma once 
 
+#include <set>
 #include <vector>
+#include <sstream>
 #include "expr/node_manager.h"
 
 namespace CVC4 {
@@ -51,6 +53,10 @@ inline Node mkAnd(std::vector<TNode>& children) {
   return NodeManager::currentNM()->mkNode(kind::AND, children);
 }
 
+inline Node mkAnd(std::vector<Node>& children) {
+  return NodeManager::currentNM()->mkNode(kind::AND, children);
+}
+
 inline Node mkExtract(TNode node, unsigned high, unsigned low) {
   Node extractOp = NodeManager::currentNM()->mkConst<BitVectorExtract>(BitVectorExtract(high, low));
   std::vector<Node> children;
@@ -65,10 +71,92 @@ inline Node mkConcat(std::vector<Node>& children) {
     return children[0];
 }
 
+inline Node mkConcat(TNode t1, TNode t2) {
+    return NodeManager::currentNM()->mkNode(kind::BITVECTOR_CONCAT, t1, t2);
+}
+
+
 inline Node mkConst(const BitVector& value) {
   return NodeManager::currentNM()->mkConst<BitVector>(value);
 }
 
+inline void getConjuncts(TNode node, std::set<TNode>& conjuncts) {
+  if (node.getKind() != kind::AND) {
+    conjuncts.insert(node);
+  } else {
+    for (unsigned i = 0; i < node.getNumChildren(); ++ i) {
+      getConjuncts(node[i], conjuncts);
+    }
+  }
+}
+
+inline Node mkConjunction(const std::set<TNode> nodes) {
+  std::set<TNode> expandedNodes;
+
+  std::set<TNode>::const_iterator it = nodes.begin();
+  std::set<TNode>::const_iterator it_end = nodes.end();
+  while (it != it_end) {
+    TNode current = *it;
+    if (current != mkTrue()) {
+      Assert(current != mkFalse());
+      if (current.getKind() == kind::AND) {
+        getConjuncts(current, expandedNodes);
+      } else {
+        expandedNodes.insert(current);
+      }
+    }
+    ++ it;
+  }
+
+  Assert(expandedNodes.size() > 0);
+  if (expandedNodes.size() == 1) {
+    return *expandedNodes.begin();
+  }
+
+  NodeBuilder<> conjunction(kind::AND);
+
+  it = expandedNodes.begin();
+  it_end = expandedNodes.end();
+  while (it != it_end) {
+    conjunction << *it;
+    ++ it;
+  }
+
+  return conjunction;
+}
+
+// Turn a set into a string
+inline std::string setToString(const std::set<TNode>& nodeSet) {
+  std::stringstream out;
+  out << "[";
+  std::set<TNode>::const_iterator it = nodeSet.begin();
+  std::set<TNode>::const_iterator it_end = nodeSet.end();
+  bool first = true;
+  while (it != it_end) {
+    if (!first) {
+      out << ",";
+    }
+    first = false;
+    out << *it;
+    ++ it;
+  }
+  out << "]";
+  return out.str();
+}
+
+// Turn a vector into a string
+inline std::string vectorToString(const std::vector<Node>& nodes) {
+  std::stringstream out;
+  out << "[";
+  for (unsigned i = 0; i < nodes.size(); ++ i) {
+    if (i > 0) {
+      out << ",";
+    }
+    out << nodes[i];
+  }
+  out << "]";
+  return out.str();
+}
 
 }
 }
index d1bfafb00f881583532f39337a8ca520ebe10301..ca69fb50658902c53856ce3d65cf90a533dd142a 100644 (file)
@@ -98,7 +98,7 @@ public:
     return BitVector(d_size + other.d_size, (d_value * Integer(2).pow(other.d_size)) + other.d_value);
   }
 
-  BitVector extract(unsigned high, unsigned low) {
+  BitVector extract(unsigned high, unsigned low) const {
     return BitVector(high - low + 1, (d_value % (Integer(2).pow(high + 1))) / Integer(2).pow(low));
   }
 
index 5b8e6d7d308cb53f84776a1597baa4956e587591..947dc65ad044e3b064efebc644c205a5bbf41d64 100644 (file)
@@ -41,7 +41,27 @@ TESTS =      \
        equality-00.smt \
        equality-01.smt \
        equality-02.smt \
-       bv_eq_diamond10.smt
+       bv_eq_diamond10.smt \
+       slice-01.smt \
+       slice-02.smt \
+       slice-03.smt \
+       slice-04.smt \
+       slice-05.smt \
+       slice-06.smt \
+       slice-07.smt \
+       slice-08.smt \
+       slice-09.smt \
+       slice-10.smt \
+       slice-11.smt \
+       slice-12.smt \
+       slice-13.smt \
+       slice-14.smt \
+       slice-15.smt \
+       slice-16.smt \
+       slice-17.smt \
+       slice-18.smt \
+       slice-19.smt \
+       slice-20.smt
        
 EXTRA_DIST = $(TESTS)
 
diff --git a/test/regress/regress0/bv/core/slice-20.cvc b/test/regress/regress0/bv/core/slice-20.cvc
new file mode 100644 (file)
index 0000000..a211b5f
--- /dev/null
@@ -0,0 +1,13 @@
+x1, y1: BITVECTOR(4);
+x2, y2: BITVECTOR(2);
+x3, y3: BITVECTOR(1);
+
+ASSERT(x1 = y1);
+
+ASSERT(x1 = x2 @ x2);
+ASSERT(x2 = x3 @ x3);
+
+ASSERT(y1 = y2 @ y2);
+ASSERT(y2 = y3 @ y3);
+
+QUERY(x3 = y3);
\ No newline at end of file