From af6ac1f03a841a0261190cb7caa15ff1fa1f798c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Dejan=20Jovanovi=C4=87?= Date: Sun, 20 Mar 2011 01:12:31 +0000 Subject: [PATCH] commit for the version of bitvectors that passes all the unit tests --- src/theory/bv/cd_set_collection.h | 109 +++--- src/theory/bv/equality_engine.h | 299 +++++++--------- src/theory/bv/slice_manager.h | 384 +++++++++++++++++---- src/theory/bv/theory_bv.cpp | 40 ++- src/theory/bv/theory_bv.h | 29 +- src/theory/bv/theory_bv_rewriter.cpp | 6 +- src/theory/bv/theory_bv_utils.h | 88 +++++ src/util/bitvector.h | 2 +- test/regress/regress0/bv/core/Makefile.am | 22 +- test/regress/regress0/bv/core/slice-20.cvc | 13 + 10 files changed, 679 insertions(+), 313 deletions(-) create mode 100644 test/regress/regress0/bv/core/slice-20.cvc diff --git a/src/theory/bv/cd_set_collection.h b/src/theory/bv/cd_set_collection.h index 33648660b..aeb28ab7b 100644 --- a/src/theory/bv/cd_set_collection.h +++ b/src/theory/bv/cd_set_collection.h @@ -123,28 +123,30 @@ public: return newElement(value, null, null, null, false); } - void insert(memory_type& memory, reference_type root, const value_type& value) { + void insert(reference_type& root, const value_type& value) { backtrack(); if (root == null) { - return newSet(value); + root = newSet(value); + return; } // We already have a set, find the spot reference_type parent = null; + reference_type current = root; while (true) { - parent = root; - if (value < d_memory[root].value) { - root = d_memory[root].left; - if (root == null) { - root = newElement(value, null, null, parent, true); - d_memory[parent].left = root; + parent = current; + if (value < d_memory[current].getValue()) { + if (d_memory[current].hasLeft()) { + current = d_memory[current].getLeft(); + } else { + d_memory[current].setLeft(newElement(value, null, null, parent, true)); return; } } else { - Assert(value != d_memory[root].value); - root = d_memory[root].right; - if (root == null) { - root = newElement(value, null, null, parent, false); - d_memory[parent].right = root; + Assert(value != d_memory[root].getValue()); + if (d_memory[current].hasRight()) { + current = d_memory[current].getRight(); + } else { + d_memory[parent].setRight(newElement(value, null, null, parent, false)); return; } } @@ -174,49 +176,55 @@ public: */ const_value_reference prev(reference_type set, const_value_reference value) { backtrack(); - // Get the node of this value - reference_type node_ref = find(set, value); - Assert(node_ref != null); - const tree_entry_type& node = d_memory[node_ref]; - // For a left node, we know that it is smaller than all the parents and the parents other children - // The smaller node must then be the max of the left subtree - if (!node.hasParent() || node.isLeft()) { - return maxElement(node.getLeft()); - } - // For a right node, we know that it is bigger than the parent. But, we also know that the left subtree - // is also bigger than the parent - else { - if (node.hasLeft()) { - return maxElement(node.getLeft()); + + const_value_reference candidate_value; + bool candidate_found = false; + + // Find the biggest node smaleer than value (it must exist) + while (set != null) { + Debug("set_collection") << "BacktrackableSetCollection::getPrev(" << toString(set) << "," << value << ")" << std::endl; + const tree_entry_type& node = d_memory[set]; + if (node.getValue() >= value) { + // If the node is bigger than the value, we need a smaller one + set = node.getLeft(); } else { - Assert(node.hasParent()); - return d_memory[node.getParent()].getValue(); + // The node is smaller than the value + candidate_found = true; + candidate_value = node.getValue(); + // There might be a bigger one + set = node.getRight(); } } + + Assert(candidate_found); + return candidate_value; } const_value_reference next(reference_type set, const_value_reference value) { backtrack(); - // Get the node of this value - reference_type node_ref = find(set, value); - Assert(node_ref != null); - const tree_entry_type& node = d_memory[node_ref]; - // For a right node, we know that it is bigger than all the parents and the parents other children - // The bigger node must then be the min of the right subtree - if (!node.hasParent() || node.isRight()) { - return minElement(node.getRight()); - } - // For a left node, we know that it is smaller than the parent. But, we also know that the right subtree - // is also smaller than the parent - else { - if (node.hasRight()) { - return minElement(node.getRight()); + + const_value_reference candidate_value; + bool candidate_found = false; + + // Find the smallest node bigger than value (it must exist) + while (set != null) { + Debug("set_collection") << "BacktrackableSetCollection::getNext(" << toString(set) << "," << value << ")" << std::endl; + const tree_entry_type& node = d_memory[set]; + if (node.getValue() <= value) { + // If the node is smaller than the value, we need a bigger one + set = node.getRight(); } else { - Assert(node.hasParent()); - return d_memory[node.getParent()].getValue(); + // The node is bigger than the value + candidate_found = true; + candidate_value = node.getValue(); + // There might be a smaller one + set = node.getLeft(); } } - } + + Assert(candidate_found); + return candidate_value; +} /** * Count the number of elements in the given bounds. @@ -262,6 +270,9 @@ public: void getElements(reference_type set, const_value_reference lowerBound, const_value_reference upperBound, std::vector& output) const { Assert(lowerBound <= upperBound); backtrack(); + + Debug("set_collection") << "BacktrackableSetCollection::getElements(" << toString(set) << "," << lowerBound << "," << upperBound << ")" << std::endl; + // Empty set no elements if (set == null) { return; @@ -277,7 +288,7 @@ public: output.push_back(current.getValue()); } // Right child (bigger elements) - if (current.getValue() <= upperBound) { + if (current.getValue() < upperBound) { getElements(current.getRight(), lowerBound, upperBound, output); } } @@ -285,7 +296,7 @@ public: /** * Print the list of elements to the output. */ - void print(std::ostream& out, reference_type set) { + void print(std::ostream& out, reference_type set) const { backtrack(); if (set == null) { return; @@ -305,7 +316,7 @@ public: /** * String representation of a set. */ - std::string toString(reference_type set) { + std::string toString(reference_type set) const { std::stringstream out; print(out, set); return out.str(); diff --git a/src/theory/bv/equality_engine.h b/src/theory/bv/equality_engine.h index 9880539ed..53c44bed0 100644 --- a/src/theory/bv/equality_engine.h +++ b/src/theory/bv/equality_engine.h @@ -28,13 +28,13 @@ #include "context/cdo.h" #include "util/output.h" #include "util/stats.h" +#include "theory/rewriter.h" namespace CVC4 { namespace theory { namespace bv { struct BitSizeTraits { - /** The null id */ static const size_t id_null; // Defined in the cpp file (GCC bug) /** The null trigger id */ @@ -46,13 +46,6 @@ struct BitSizeTraits { static const size_t size_bits = 16; /** Number of bits we use for the trigger id */ static const size_t trigger_id_bits = 24; - - /** Number of bits we use for the function ids */ - static const size_t function_id_bits = 8; - /** Number of bits we use for the function arguments count */ - static const size_t function_arguments_count_bits = 16; - /** Number of bits we use for the index into the arguments memory */ - static const size_t function_arguments_index_bits = 24; }; class EqualityNode { @@ -68,22 +61,18 @@ public: /** The next equality node in this class */ size_t d_nextId : BitSizeTraits::id_bits; - /** Is this node a function application */ - size_t d_isFunction : 1; - public: /** * Creates a new node, which is in a list of it's own. */ EqualityNode(size_t nodeId = BitSizeTraits::id_null) - : d_size(1), d_findId(nodeId), d_nextId(nodeId), d_isFunction(0) {} + : d_size(1), d_findId(nodeId), d_nextId(nodeId) {} /** Initialize the equality node */ - inline void init(size_t nodeId, bool isFunction) { + inline void init(size_t nodeId) { d_size = 1; d_findId = d_nextId = nodeId; - d_isFunction = isFunction; } /** @@ -125,66 +114,11 @@ public: inline void setFind(size_t findId) { d_findId = findId; } }; -/** - * FunctionNode class represents the information related to a function node. It has an id, number of children - * and the - */ -class FunctionNode { - - /** Is the function associative */ - size_t d_isAssociative : 1; - /** The id of the function */ - size_t d_functionId : BitSizeTraits::function_id_bits; - /** Number of children */ - size_t d_argumentsCount : BitSizeTraits::function_arguments_count_bits; - /** Index of the start of the arguments in the children array */ - size_t d_argumentsIndex : BitSizeTraits::function_arguments_index_bits; - -public: - - FunctionNode(size_t functionId = 0, size_t argumentsCount = 0, size_t argumentsIndex = 0, bool associative = false) - : d_isAssociative(associative), d_functionId(functionId), d_argumentsCount(argumentsCount), d_argumentsIndex(argumentsIndex) - {} - - void init(size_t functionId, size_t argumentsCount, size_t argumentsIndex, bool associative) { - d_functionId = functionId; - d_argumentsCount = argumentsCount; - d_argumentsIndex = argumentsIndex; - d_isAssociative = associative; - } - - /** Check if the function is associative */ - bool isAssociative() const { return d_isAssociative; } - - /** Get the function id */ - size_t getFunctionId() const { return d_functionId; } - - /** Get the number of arguments */ - size_t getArgumentsCount() const { return d_argumentsCount; } - - /** Get the infex of the first argument in the arguments memory */ - size_t getArgumentsIndex() const { return d_argumentsIndex; } - -}; - -template +template class EqualityEngine { public: - /** - * Basic information about a function. - */ - struct FunctionInfo { - /** Name of the function */ - std::string name; - /** Is the function associative */ - bool isAssociative; - - FunctionInfo(std::string name, bool isAssociative) - : name(name), isAssociative(isAssociative) {} - }; - /** Statistics about the equality engine instance */ struct Statistics { /** Total number of merges */ @@ -193,8 +127,6 @@ public: IntStat termsCount; /** Number of function terms managed by the system */ IntStat functionTermsCount; - /** Number of distince functions managed by the system */ - IntStat functionsCount; /** Number of times we performed a backtrack */ IntStat backtracksCount; @@ -202,13 +134,11 @@ public: : mergesCount(name + "::mergesCount", 0), termsCount(name + "::termsCount", 0), functionTermsCount(name + "functionTermsCoutn", 0), - functionsCount(name + "::functionsCount", 0), backtracksCount(name + "::backtracksCount", 0) { StatisticsRegistry::registerStat(&mergesCount); StatisticsRegistry::registerStat(&termsCount); StatisticsRegistry::registerStat(&functionTermsCount); - StatisticsRegistry::registerStat(&functionsCount); StatisticsRegistry::registerStat(&backtracksCount); } @@ -216,7 +146,6 @@ public: StatisticsRegistry::unregisterStat(&mergesCount); StatisticsRegistry::unregisterStat(&termsCount); StatisticsRegistry::unregisterStat(&functionTermsCount); - StatisticsRegistry::unregisterStat(&functionsCount); StatisticsRegistry::unregisterStat(&backtracksCount); } }; @@ -238,12 +167,6 @@ private: /** Number of asserted equalities we have so far */ context::CDO d_assertedEqualitiesCount; - /** Map from ids to functional representations */ - std::vector d_functionNodes; - - /** Functions in the system */ - std::vector d_functions; - /** * We keep a list of asserted equalities. Not among original terms, but * among the class representatives. @@ -261,6 +184,8 @@ private: /** The ids of the classes we have merged */ std::vector d_assertedEqualities; + /** The reasons for the equalities */ + /** * An edge in the equality graph. This graph is an undirected graph (both edges added) * containing the actual asserted equalities. @@ -291,6 +216,11 @@ private: */ std::vector d_equalityEdges; + /** + * Reasons for equalities. + */ + std::vector d_equalityReasons; + /** * Map from a node to it's first edge in the equality graph. Edges are added to the front of the * list which makes the insertion/backtracking easy. @@ -298,7 +228,7 @@ private: std::vector d_equalityGraph; /** Add an edge to the equality graph */ - inline void addGraphEdge(size_t t1, size_t t2); + inline void addGraphEdge(size_t t1, size_t t2, Node reason); /** Returns the equality node of the given node */ inline EqualityNode& getEqualityNode(TNode node); @@ -385,11 +315,6 @@ public: */ size_t addTerm(TNode t); - /** - * Adds a term that is an application of a function symbol to the databas. Returns the internal id of the term. - */ - size_t addFunctionApplication(size_t funcionId, const std::vector& arguments); - /** * Check whether the node is already in the database. */ @@ -398,7 +323,7 @@ public: /** * Adds an equality t1 = t2 to the database. Returns false if any of the triggers failed. */ - bool addEquality(TNode t1, TNode t2); + bool addEquality(TNode t1, TNode t2, Node reason); /** * Returns the representative of the term t. @@ -424,23 +349,27 @@ public: size_t addTrigger(TNode t1, TNode t2); /** - * Adds a new function to the equality engine. The funcions are not of fixed arity and no typechecking is performed! - * Associative functions allow for normalization, i.e. f(f(x, y), z) = f(x, f(y, z)) = f(x, y, z). - * @associative should be true if the function is associative and you want this to be handled by the engine + * Normalizes a term by finding the representative. If the representative can be decomposed (using + * UnionFindPreferences) it will try and recursively find the representatives, and substitute. + * Assumptions used in normalization are retruned in the set. */ - inline size_t newFunction(std::string name, bool associative) { - Assert(use_functions); - Assert(!associative || enable_associative); - ++ d_stats.functionsCount; - size_t id = d_functions.size(); - d_functions.push_back(FunctionInfo(name, associative)); - return id; - } + Node normalize(TNode node, std::set& assumptions); + +private: + + /** Hash of normalizations to avioid cycles */ + typedef __gnu_cxx::hash_map normalization_cache; + normalization_cache d_normalizationCache; + + /** + * Same as above, but does cahcing to avoid loops. + */ + Node normalizeWithCache(TNode node, std::set& assumptions); }; -template -size_t EqualityEngine::addTerm(TNode t) { +template +size_t EqualityEngine::addTerm(TNode t) { Debug("equality") << "EqualityEngine::addTerm(" << t << ")" << std::endl; @@ -462,69 +391,35 @@ size_t EqualityEngine -size_t EqualityEngine::addFunctionApplication(size_t functionId, const std::vector& arguments) { - - Debug("equality") << "EqualityEngine::addFunctionApplication(" << d_functions[functionId].name << ":" << arguments.size() << ")" << std::endl; - - ++ d_stats.functionTermsCount; - ++ d_stats.termsCount; - - // Register the new id of the term - size_t newId = d_nodes.size(); - // Add the node to it's position - d_nodes.push_back(Node()); - // Add the trigger list for this node - d_nodeTriggers.push_back(BitSizeTraits::trigger_id_null); - // Add it to the equality graph - d_equalityGraph.push_back(BitSizeTraits::id_null); - // Add the equality node to the nodes - if (d_equalityNodes.size() <= newId) { - d_equalityNodes.resize(newId + 100); - } - d_equalityNodes[newId].init(newId, true); - // Add the function application to the function nodes - if (d_functionNodes.size() <= newId) { - d_functionNodes.resize(newId + 100); - } - // Initialize the function node - size_t argumentsIndex; - d_functionNodes[newId].init(functionId, arguments.size(), argumentsIndex, d_functions[functionId].isAssociative); - + d_equalityNodes[newId].init(newId); // Return the id of the term return newId; - } -template -bool EqualityEngine::hasTerm(TNode t) const { +template +bool EqualityEngine::hasTerm(TNode t) const { return d_nodeIds.find(t) != d_nodeIds.end(); } -template -size_t EqualityEngine::getNodeId(TNode node) const { - Assert(hasTerm(node)); +template +size_t EqualityEngine::getNodeId(TNode node) const { + Assert(hasTerm(node), node.toString().c_str()); return (*d_nodeIds.find(node)).second; } -template -EqualityNode& EqualityEngine::getEqualityNode(TNode t) { +template +EqualityNode& EqualityEngine::getEqualityNode(TNode t) { return getEqualityNode(getNodeId(t)); } -template -EqualityNode& EqualityEngine::getEqualityNode(size_t nodeId) { +template +EqualityNode& EqualityEngine::getEqualityNode(size_t nodeId) { Assert(nodeId < d_equalityNodes.size()); return d_equalityNodes[nodeId]; } -template -bool EqualityEngine::addEquality(TNode t1, TNode t2) { +template +bool EqualityEngine::addEquality(TNode t1, TNode t2, Node reason) { Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << ")" << std::endl; @@ -549,18 +444,20 @@ bool EqualityEngine: Assert(node1.getFind() == t1classId); Assert(node2.getFind() == t2classId); - // Depending on the size, merge them + // Depending on the merge preference (such as size), merge them std::vector triggers; - if (node1.getSize() < node2.getSize()) { + if (UnionFindPreferences::mergePreference(d_nodes[t2classId], node2.getSize(), d_nodes[t1classId], node1.getSize())) { + Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << "): merging " << t1 << " into " << t2 << std::endl; merge(node2, node1, triggers); d_assertedEqualities.push_back(Equality(t2classId, t1classId)); } else { + Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << "): merging " << t2 << " into " << t1 << std::endl; merge(node1, node2, triggers); d_assertedEqualities.push_back(Equality(t1classId, t2classId)); } // Add the actuall equality to the equality graph - addGraphEdge(t1Id, t2Id); + addGraphEdge(t1Id, t2Id, reason); // One more equality added d_assertedEqualitiesCount = d_assertedEqualitiesCount + 1; @@ -577,8 +474,8 @@ bool EqualityEngine: return true; } -template -TNode EqualityEngine::getRepresentative(TNode t) const { +template +TNode EqualityEngine::getRepresentative(TNode t) const { Debug("equality") << "EqualityEngine::getRepresentative(" << t << ")" << std::endl; @@ -593,8 +490,8 @@ TNode EqualityEngine return d_nodes[representativeId]; } -template -bool EqualityEngine::areEqual(TNode t1, TNode t2) const { +template +bool EqualityEngine::areEqual(TNode t1, TNode t2) const { Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl; Assert(hasTerm(t1)); @@ -610,8 +507,8 @@ bool EqualityEngine: return rep1 == rep2; } -template -void EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers) { +template +void EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers) { Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl; @@ -660,8 +557,8 @@ void EqualityEngine: class1.merge(class2); } -template -void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& class2, size_t class2Id) { +template +void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& class2, size_t class2Id) { Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << ")" << std::endl; @@ -692,8 +589,8 @@ void EqualityEngine: } -template -void EqualityEngine::backtrack() { +template +void EqualityEngine::backtrack() { // If we need to backtrack then do it if (d_assertedEqualitiesCount < d_assertedEqualities.size()) { @@ -721,22 +618,24 @@ void EqualityEngine: } d_equalityEdges.resize(2 * d_assertedEqualitiesCount); + d_equalityReasons.resize(d_assertedEqualitiesCount); } } -template -void EqualityEngine::addGraphEdge(size_t t1, size_t t2) { +template +void EqualityEngine::addGraphEdge(size_t t1, size_t t2, Node reason) { Debug("equality") << "EqualityEngine::addGraphEdge(" << d_nodes[t1] << "," << d_nodes[t2] << ")" << std::endl; size_t edge = d_equalityEdges.size(); d_equalityEdges.push_back(EqualityEdge(t2, d_equalityGraph[t1])); d_equalityEdges.push_back(EqualityEdge(t1, d_equalityGraph[t2])); d_equalityGraph[t1] = edge; d_equalityGraph[t2] = edge | 1; + d_equalityReasons.push_back(reason); } -template -void EqualityEngine::getExplanation(TNode t1, TNode t2, std::vector& equalities) const { +template +void EqualityEngine::getExplanation(TNode t1, TNode t2, std::vector& equalities) const { Assert(equalities.empty()); Assert(t1 != t2); Assert(getRepresentative(t1) == getRepresentative(t2)); @@ -784,15 +683,9 @@ void EqualityEngine: // Reconstruct the path do { - // Get the left and right hand side from the edge - size_t firstEdge = (currentEdge >> 1) << 1; - size_t secondEdge = (currentEdge | 1); - TNode lhs = d_nodes[d_equalityEdges[secondEdge].getNodeId()]; - TNode rhs = d_nodes[d_equalityEdges[firstEdge].getNodeId()]; // Add the actual equality to the vector - equalities.push_back(lhs.eqNode(rhs)); - - Debug("equality") << "EqualityEngine::getExplanation(): adding: " << lhs.eqNode(rhs) << std::endl; + equalities.push_back(d_equalityReasons[currentEdge >> 1]); + Debug("equality") << "EqualityEngine::getExplanation(): adding: " << d_equalityReasons[currentEdge >> 1] << std::endl; // Go to the previous currentEdge = bfsQueue[currentIndex].edgeId; @@ -816,8 +709,8 @@ void EqualityEngine: } } -template -size_t EqualityEngine::addTrigger(TNode t1, TNode t2) { +template +size_t EqualityEngine::addTrigger(TNode t1, TNode t2) { Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ")" << std::endl; @@ -850,6 +743,64 @@ size_t EqualityEngine +Node EqualityEngine::normalize(TNode node, std::set& assumptions) { + d_normalizationCache.clear(); + Node result = Rewriter::rewrite(normalizeWithCache(node, assumptions)); + d_normalizationCache.clear(); + return result; +} + + +template +Node EqualityEngine::normalizeWithCache(TNode node, std::set& assumptions) { + + Debug("equality") << "EqualityEngine::normalize(" << node << ")" << push << std::endl; + + normalization_cache::iterator find = d_normalizationCache.find(node); + if (find != d_normalizationCache.end()) { + if (find->second.isNull()) { + // We are in a cycle + return node; + } else { + // Not in a cycle, return it + return find->second; + } + } else { + d_normalizationCache[node] = Node(); + } + + // Get the representative + Node result = hasTerm(node) ? getRepresentative(node) : node; + if (node != result) { + std::vector equalities; + getExplanation(result, node, equalities); + assumptions.insert(equalities.begin(), equalities.end()); + } + + // If asked, substitute the children with their representatives + if (UnionFindPreferences::descend(result)) { + // Make the builder for substitution + NodeBuilder<> builder; + builder << result.getKind(); + kind::MetaKind metaKind = result.getMetaKind(); + if (metaKind == kind::metakind::PARAMETERIZED) { + builder << result.getOperator(); + } + for (unsigned i = 0; i < result.getNumChildren(); ++ i) { + builder << normalizeWithCache(result[i], assumptions); + } + result = builder; + } + + Debug("equality") << "EqualityEngine::normalize(" << node << ") => " << result << pop << std::endl; + + // Cache the result for real now + d_normalizationCache[node] = result; + + return result; +} + } // Namespace bv } // Namespace theory } // Namespace CVC4 diff --git a/src/theory/bv/slice_manager.h b/src/theory/bv/slice_manager.h index 8fc1e0b9d..96a0067dc 100644 --- a/src/theory/bv/slice_manager.h +++ b/src/theory/bv/slice_manager.h @@ -13,6 +13,7 @@ #include "theory/bv/cd_set_collection.h" #include +#include #include namespace CVC4 { @@ -134,9 +135,6 @@ private: /** The equality engine */ EqualityEngine& d_equalityEngine; - /** The id of the concatenation function */ - size_t d_concatFunctionId; - /** The collection of backtrackable sets */ set_collection d_setCollection; @@ -153,27 +151,35 @@ public: SliceManager(TheoryBitvector& theoryBitvector, context::Context* context) : d_theoryBitvector(theoryBitvector), d_equalityEngine(theoryBitvector.getEqualityEngine()), d_setCollection(context) { - // We register the concatentation with the equality engine - d_concatFunctionId = d_equalityEngine.newFunction("bv_concat", true); } - inline size_t getConcatFunctionId() const { return d_concatFunctionId; } - /** - * Adds the equality (lhs = rhs) to the slice manager. This will not add the equalities to the equality manager, - * but will slice the equality according to the current slicing in order to align all the slices. The terms that - * get slices get sent to the theory engine as equalities, i.e if we slice x[10:0] into x[10:5]@x[4:0] equality - * engine gets the assertion x[10:0] = concat(x[10:5], x[4:0]). + * Adds the equality (lhs = rhs) to the slice manager. The equality is first normalized according to the equality + * manager, i.e. each base term is taken from the equality manager, replaced in, and then the whole concatenation + * normalized and sliced wrt the current slicing. The method will not add the equalities to the equality manager, + * but instead will slice the equality according to the current slicing in order to align all the slices. + * + * The terms that get sliced get sent to the theory engine as equalities, i.e if we slice x[10:0] into x[10:5]@x[4:0] + * equality engine gets the assertion x[10:0] = concat(x[10:5], x[4:0]). + * + * input output slicing + * -------------------------------------------------------------------------------------------------------------- + * x@y = y@x x = y, y = x empty + * x[31:0]@x[64:32] = x x = x[31:0]@x[63:32] x:{64,32,0} + * x@y = 0000@x@0000 x = 0000@x[7:4], y = x[3:0]@0000 x:{8,4,0} + * */ - inline void addEquality(TNode lhs, TNode rhs, std::vector& lhsSlices, std::vector& rhsSlices); + inline bool solveEquality(TNode lhs, TNode rhs); private: + inline bool solveEquality(TNode lhs, TNode rhs, const std::set& assumptions); + /** - * Slices up lhs and rhs and returns the slices in lhsSlices and rhsSlices + * Slices up lhs and rhs and returns the slices in lhsSlices and rhsSlices. The slices are not atomic, + * they are sliced in order to make one of lhs or rhs atomic, the other one can be a concatenation. */ - inline void slice(std::vector& lhs, std::vector& rhs, - std::vector& lhsSlices, std::vector& rhsSlices); + inline bool sliceAndSolve(std::vector& lhs, std::vector& rhs, const std::set& assumptions); /** * Returns true if the term is already sliced wrt the current slicing. Note that, for example, even though @@ -184,7 +190,7 @@ private: /** * Slices the term wrt the current slicing. When done, isSliced returns true */ - inline void slice(TNode node, std::vector& sliced); + inline bool slice(TNode node, std::vector& sliced); /** * Returns the base term in the core theory of the given term, i.e. @@ -196,20 +202,87 @@ private: static inline TNode baseTerm(TNode node); /** - * Adds a new slice to the slice set of the given base term. + * Adds a new slice to the slice set of the given term. */ - inline void addSlice(Node baseTerm, unsigned slicePoint); + inline bool addSlice(Node term, unsigned slicePoint); }; template -void SliceManager::addSlice(Node baseTerm, unsigned slicePoint) { +bool SliceManager::addSlice(Node node, unsigned slicePoint) { + Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << ")" << std::endl; + + bool ok = true; + + int low = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractLow(node) : 0; + int high = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractHigh(node) + 1: utils::getSize(node); + slicePoint += low; + + TNode nodeBase = baseTerm(node); + + set_reference sliceSet; + slicing_map::iterator find = d_nodeSlicing.find(nodeBase); + if (find == d_nodeSlicing.end()) { + sliceSet = d_nodeSlicing[nodeBase] = d_setCollection.newSet(slicePoint); + d_setCollection.insert(sliceSet, low); + d_setCollection.insert(sliceSet, high); + } else { + sliceSet = find->second; + } + + // What are the points surrounding the new slice point + int prev = d_setCollection.prev(sliceSet, slicePoint); + int next = d_setCollection.next(sliceSet, slicePoint); + + // Add the slice to the set + d_setCollection.insert(sliceSet, slicePoint); + Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << "): current set " << d_setCollection.toString(sliceSet) << std::endl; + + // Add the terms and the equality to the equality engine + Node t1 = utils::mkExtract(nodeBase, next - 1, slicePoint); + Node t2 = utils::mkExtract(nodeBase, slicePoint - 1, prev); + Node nodeSlice = (next == high && prev == low) ? node : utils::mkExtract(nodeBase, next - 1, prev); + Node concat = utils::mkConcat(t1, t2); + + d_equalityEngine.addTerm(t1); + d_equalityEngine.addTerm(t2); + d_equalityEngine.addTerm(nodeSlice); + d_equalityEngine.addTerm(concat); + + // We are free to add this slice, unless the slice has a representative that's already a concat + TNode nodeSliceRepresentative = d_equalityEngine.getRepresentative(nodeSlice); + if (nodeSliceRepresentative.getKind() != kind::BITVECTOR_CONCAT) { + // Add the slice to the equality engine + ok = d_equalityEngine.addEquality(nodeSlice, concat, utils::mkTrue()); + } else { + // If the representative is a concat, we must solve it + // There is no need do add nodeSlice = concat as we will solve the representative of nodeSlice + std::set assumptions; + std::vector equalities; + d_equalityEngine.getExplanation(nodeSlice, nodeSliceRepresentative, equalities); + assumptions.insert(equalities.begin(), equalities.end()); + ok = solveEquality(nodeSliceRepresentative, concat, assumptions); + } + + Debug("slicing") << "SliceMagager::addSlice(" << node << "," << slicePoint << ") => " << d_setCollection.toString(d_nodeSlicing[nodeBase]) << std::endl; + + return ok; } template -void SliceManager::addEquality(TNode lhs, TNode rhs, std::vector& lhsSlices, std::vector& rhsSlices) { +bool SliceManager::solveEquality(TNode lhs, TNode rhs) { + std::set assumptions; + assumptions.insert(lhs.eqNode(rhs)); + bool ok = solveEquality(lhs, rhs, assumptions); + return ok; +} + +template +bool SliceManager::solveEquality(TNode lhs, TNode rhs, const std::set& assumptions) { + + Debug("slicing") << "SliceMagager::solveEquality(" << lhs << "," << rhs << "," << utils::setToString(assumptions) << ")" << push << std::endl; - Debug("slicing") << "SliceMagager::addEquality(" << lhs << "," << rhs << ")" << std::endl; + bool ok; // The concatenations on the left-hand side (reverse order, first is on top) std::vector lhsTerms; @@ -232,60 +305,213 @@ void SliceManager::addEquality(TNode lhs, TNode rhs, std::vecto } // Slice the individual terms to align them - slice(lhsTerms, rhsTerms, lhsSlices, rhsSlices); + ok = sliceAndSolve(lhsTerms, rhsTerms, assumptions); + + Debug("slicing") << "SliceMagager::solveEquality(" << lhs << "," << rhs << "," << utils::setToString(assumptions) << ")" << pop << std::endl; + + return ok; } + template -void SliceManager::slice(std::vector& lhs, std::vector& rhs, - std::vector& lhsSlices, std::vector& rhsSlices) { +bool SliceManager::sliceAndSolve(std::vector& lhs, std::vector& rhs, const std::set& assumptions) +{ - Debug("slicing") << "SliceManager::slice()" << std::endl; + Debug("slicing") << "SliceManager::sliceAndSolve()" << std::endl; - // Go through the work-list and align + // Go through the work-list, solve and align while (!lhs.empty()) { Assert(!rhs.empty()); + Debug("slicing") << "SliceManager::sliceAndSolve(): lhs " << utils::vectorToString(lhs) << std::endl; + Debug("slicing") << "SliceManager::sliceAndSolve(): rhs " << utils::vectorToString(rhs) << std::endl; + // The terms that we need to slice Node lhsTerm = lhs.back(); Node rhsTerm = rhs.back(); - Debug("slicing") << "slicing: " << lhsTerm << " and " << rhsTerm << std::endl; + + Debug("slicing") << "SliceManager::sliceAndSolve(): " << lhsTerm << " : " << rhsTerm << std::endl; // If the terms are not sliced wrt the current slicing, we have them sliced lhs.pop_back(); if (!isSliced(lhsTerm)) { - slice(lhsTerm, lhs); + if (!slice(lhsTerm, lhs)) return false; + Debug("slicing") << "SliceManager::sliceAndSolve(): lhs sliced" << std::endl; continue; } rhs.pop_back(); if (!isSliced(rhsTerm)) { - slice(rhsTerm, rhs); + if (!slice(rhsTerm, rhs)) return false; + // We also need to put lhs back + lhs.push_back(lhsTerm); + Debug("slicing") << "SliceManager::sliceAndSolve(): rhs sliced" << std::endl; + continue; } + Debug("slicing") << "SliceManager::sliceAndSolve(): both lhs and rhs sliced already" << std::endl; + + // The solving concatenation + std::vector concatTerms; + // If the slices are of the same size we do the additional work - unsigned lhsSize = utils::getSize(lhsTerm); - unsigned rhsSize = utils::getSize(rhsTerm); - if (lhsSize == rhsSize) { - // If they are over the same base terms, we need to do something - TNode lhsBaseTerm = baseTerm(lhsTerm); - TNode rhsBaseTerm = baseTerm(rhsTerm); - if (lhsBaseTerm == rhsBaseTerm) { - // x[i_1:j_1] vs x[i_2:j_2] + int sizeDifference = utils::getSize(lhsTerm) - utils::getSize(rhsTerm); + + // We slice constants immediately + if (sizeDifference > 0 && lhsTerm.getKind() == kind::CONST_BITVECTOR) { + BitVector low = lhsTerm.getConst().extract(utils::getSize(rhsTerm) - 1, 0); + BitVector high = lhsTerm.getConst().extract(utils::getSize(lhsTerm) - 1, utils::getSize(rhsTerm)); + lhs.push_back(utils::mkConst(low)); + lhs.push_back(utils::mkConst(high)); + rhs.push_back(rhsTerm); + continue; + } + if (sizeDifference < 0 && rhsTerm.getKind() == kind::CONST_BITVECTOR) { + BitVector low = rhsTerm.getConst().extract(utils::getSize(lhsTerm) - 1, 0); + BitVector high = rhsTerm.getConst().extract(utils::getSize(rhsTerm) - 1, utils::getSize(lhsTerm)); + rhs.push_back(utils::mkConst(low)); + rhs.push_back(utils::mkConst(high)); + lhs.push_back(lhsTerm); + continue; + } + + enum SolvingFor { + SOLVING_FOR_LHS, + SOLVING_FOR_RHS + } solvingFor = sizeDifference < 0 || lhsTerm.getKind() == kind::CONST_BITVECTOR ? SOLVING_FOR_RHS : SOLVING_FOR_LHS; + + Debug("slicing") << "SliceManager::sliceAndSolve(): " << (solvingFor == SOLVING_FOR_LHS ? "solving for LHS" : "solving for RHS") << std::endl; + + // When we slice in order to align, we might have to reslice the one we are solving for + bool reslice = false; + + switch (solvingFor) { + case SOLVING_FOR_RHS: { + concatTerms.push_back(lhsTerm); + // Maybe we need to add more lhs to make them equal + while (sizeDifference < 0 && !reslice) { + Assert(lhs.size() > 0); + // Get the next part for lhs + lhsTerm = lhs.back(); + lhs.pop_back(); + // Slice if necessary + if (!isSliced(lhsTerm)) { + if (!slice(lhsTerm, lhs)) return false; + continue; + } + // If we go above 0, we need to cut it + if (sizeDifference + (int)utils::getSize(lhsTerm) > 0) { + // Slice it so it fits + addSlice(lhsTerm, (int)utils::getSize(lhsTerm) + sizeDifference); + if (!slice(lhsTerm, lhs)) return false; + if (!isSliced(rhsTerm)) { + if (!slice(rhsTerm, rhs)) return false; + while(!concatTerms.empty()) { + lhs.push_back(concatTerms.back()); + concatTerms.pop_back(); + } + reslice = true; + } + continue; + } + concatTerms.push_back(lhsTerm); + sizeDifference += utils::getSize(lhsTerm); + } + break; + } + case SOLVING_FOR_LHS: { + concatTerms.push_back(rhsTerm); + // Maybe we need to add more rhs to make them equal + while (sizeDifference > 0 && !reslice) { + Assert(rhs.size() > 0); + // Get the next part for lhs + rhsTerm = rhs.back(); + rhs.pop_back(); + // Slice if necessary + if (!isSliced(rhsTerm)) { + if (!slice(rhsTerm, rhs)) return false; + continue; + } + // If we go below 0, we need to cut it + if (sizeDifference - (int)utils::getSize(rhsTerm) < 0) { + // Slice it so it fits + addSlice(rhsTerm, (int)utils::getSize(rhsTerm) - sizeDifference); + if (!slice(rhsTerm, rhs)) return false; + if (!isSliced(lhsTerm)) { + if (!slice(lhsTerm, lhs)) return false; + while(!concatTerms.empty()) { + rhs.push_back(concatTerms.back()); + concatTerms.pop_back(); + } + reslice = true; + } + continue; + } + concatTerms.push_back(rhsTerm); + sizeDifference -= utils::getSize(rhsTerm); + } + break; + } + } + + // If we need to reslice + if (reslice) { + continue; + } + + Assert(sizeDifference == 0); + + Node concat = utils::mkConcat(concatTerms); + Debug("slicing") << "SliceManager::sliceAndSolve(): concatenation " << concat << std::endl; + + // We have them equal size now. If the base term of the one we are solving is solved into a + // non-trivial concatenation already, we have to normalize. A concatenation is non-trivial if + // it is not a direct slicing, i.e it is a concat, and normalize(x) != x + switch (solvingFor) { + case SOLVING_FOR_LHS: { + TNode lhsTermRepresentative = d_equalityEngine.getRepresentative(lhsTerm); + if (lhsTermRepresentative != lhsTerm && + (lhsTermRepresentative.getKind() == kind::BITVECTOR_CONCAT || lhsTermRepresentative.getKind() == kind::CONST_BITVECTOR)) { + // We need to normalize and solve the normalized equations + std::vector explanation; + d_equalityEngine.getExplanation(lhsTerm, lhsTermRepresentative, explanation); + std::set additionalAssumptions(assumptions); + additionalAssumptions.insert(explanation.begin(), explanation.end()); + bool ok = solveEquality(lhsTermRepresentative, concat, additionalAssumptions); + if (!ok) return false; } else { - // x[i_1:j_1] vs y[i_2:j_2] + // We're fine, just add the equality + Debug("slicing") << "SliceManager::sliceAndSolve(): adding " << lhsTerm << " = " << concat << " " << utils::setToString(assumptions) << std::endl; + d_equalityEngine.addTerm(concat); + bool ok = d_equalityEngine.addEquality(lhsTerm, concat, utils::mkConjunction(assumptions)); + if (!ok) return false; } - lhsSlices.push_back(lhsTerm); - rhsSlices.push_back(rhsTerm); - continue; - } else { - // They are not of equal sizes, so we slice one - if (lhsSize < rhsSize) { - // We need to cut a piece of rhs + break; + } + case SOLVING_FOR_RHS: { + TNode rhsTermRepresentative = d_equalityEngine.getRepresentative(rhsTerm); + if (rhsTermRepresentative != rhsTerm && + (rhsTermRepresentative.getKind() == kind::BITVECTOR_CONCAT || rhsTermRepresentative.getKind() == kind::CONST_BITVECTOR)) { + // We need to normalize and solve the normalized equations + std::vector explanation; + d_equalityEngine.getExplanation(rhsTerm, rhsTermRepresentative, explanation); + std::set additionalAssumptions(assumptions); + additionalAssumptions.insert(explanation.begin(), explanation.end()); + bool ok = solveEquality(rhsTermRepresentative, concat, additionalAssumptions); + if (!ok) return false; } else { - // We need to cut a piece of lhs + // We're fine, just add the equality + Debug("slicing") << "SliceManager::sliceAndSolve(): adding " << rhsTerm << " = " << concat << utils::setToString(assumptions) << std::endl; + d_equalityEngine.addTerm(concat); + bool ok = d_equalityEngine.addEquality(rhsTerm, concat, utils::mkConjunction(assumptions)); + if (!ok) return false; } + break; + } } } + + return true; } template @@ -315,9 +541,11 @@ bool SliceManager::isSliced(TNode node) const { if (find == d_nodeSlicing.end()) { result = nodeKind != kind::BITVECTOR_EXTRACT; } else { - // Check whether there is a slice point in [high, low), if there is the term is not sliced. - // Hence, if we look for the upper bound of low, and it is higher than high, it is sliced. - result = d_setCollection.count(find->second, low + 1, high) > 0; + // The term is not sliced if one of the borders is not in the slice set or + // there is a point between the borders + result = + d_setCollection.contains(find->second, low) && d_setCollection.contains(find->second, high + 1) && + (low == high || d_setCollection.count(find->second, low + 1, high) == 0); } } @@ -326,7 +554,7 @@ bool SliceManager::isSliced(TNode node) const { } template -inline void SliceManager::slice(TNode node, std::vector& sliced) { +inline bool SliceManager::slice(TNode node, std::vector& sliced) { Debug("slicing") << "SliceManager::slice(" << node << ")" << std::endl; @@ -335,44 +563,64 @@ inline void SliceManager::slice(TNode node, std::vector& // The indices of the beginning and (one past) end unsigned high = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractHigh(node) + 1 : utils::getSize(node); unsigned low = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractLow(node) : 0; + Debug("slicing") << "SliceManager::slice(" << node << "): low: " << low << std::endl; + Debug("slicing") << "SliceManager::slice(" << node << "): high: " << high << std::endl; // Get the base term TNode nodeBase = baseTerm(node); Assert(nodeBase.getKind() != kind::BITVECTOR_CONCAT); Assert(nodeBase.getKind() != kind::CONST_BITVECTOR); - // Get the base term slice set - set_collection::reference_type nodeSliceSet = d_nodeSlicing[nodeBase]; + // The nodes slice set + set_collection::reference_type nodeSliceSet; + + // Find the current one or construct it + slicing_map::iterator findSliceSet = d_nodeSlicing.find(nodeBase); + if (findSliceSet == d_nodeSlicing.end()) { + nodeSliceSet = d_setCollection.newSet(utils::getSize(nodeBase)); + d_setCollection.insert(nodeSliceSet, 0); + d_nodeSlicing[nodeBase] = nodeSliceSet; + } else { + nodeSliceSet = d_nodeSlicing[nodeBase]; + } + Debug("slicing") << "SliceManager::slice(" << node << "): current: " << d_setCollection.toString(nodeSliceSet) << std::endl; std::vector slicePoints; - d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints); - + if (low + 1 < high) { + d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints); + } + // Go through all the points i_0 <= low < i_1 < ... < i_{n-1} < high <= i_n from the slice set // and generate the slices [i_0:low-1][low:i_1-1] [i_1:i2] ... [i_{n-1}:high-1][high:i_n-1]. They are in reverse order, // as they should be size_t i_0 = low == 0 ? 0 : d_setCollection.prev(nodeSliceSet, low + 1); - size_t i_n = high == utils::getSize(nodeBase) ? high: d_setCollection.next(nodeSliceSet, high); + Debug("slicing") << "SliceManager::slice(" << node << "): i_0: " << i_0 << std::endl; + size_t i_n = high == utils::getSize(nodeBase) ? high: d_setCollection.next(nodeSliceSet, high - 1); + Debug("slicing") << "SliceManager::slice(" << node << "): i_n: " << i_n << std::endl; // Add the new points to the slice set (they might be there already) if (high < i_n) { - std::vector lastTwoSlices; - lastTwoSlices.push_back(utils::mkExtract(nodeBase, i_n-1, high)); - lastTwoSlices.push_back(utils::mkExtract(nodeBase, high-1, slicePoints.back())); - d_equalityEngine.addEquality(utils::mkExtract(nodeBase, i_n-1, slicePoints.back()), utils::mkConcat(lastTwoSlices)); + if (!addSlice(nodeBase, high)) return false; } - - while (!slicePoints.empty()) { + // Construct the actuall slicing + if (slicePoints.size() > 0) { + Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, slicePoints[0] - 1, low) << std::endl; + sliced.push_back(utils::mkExtract(nodeBase, slicePoints[0] - 1, low)); + for (unsigned i = 1; i < slicePoints.size(); ++ i) { + Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, slicePoints[i] - 1, slicePoints[i-1])<< std::endl; + sliced.push_back(utils::mkExtract(nodeBase, slicePoints[i] - 1, slicePoints[i-1])); + } + Debug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, high-1, slicePoints.back()) << std::endl; sliced.push_back(utils::mkExtract(nodeBase, high-1, slicePoints.back())); - high = slicePoints.back(); - slicePoints.pop_back(); + } else { + sliced.push_back(utils::mkExtract(nodeBase, high - 1, low)); } - + // Add the new points to the slice set (they might be there already) if (i_0 < low) { - std::vector firstTwoSlices; - firstTwoSlices.push_back(utils::mkExtract(nodeBase, high-1, low)); - firstTwoSlices.push_back(utils::mkExtract(nodeBase, low-1, i_0)); - d_equalityEngine.addEquality(utils::mkExtract(nodeBase, high-1, i_0), utils::mkConcat(firstTwoSlices)); + if (!addSlice(nodeBase, low)) return false; } + + return true; } template diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index e183a592c..2d823383c 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -35,7 +35,17 @@ void TheoryBV::preRegisterTerm(TNode node) { if (node.getKind() == kind::EQUAL) { d_eqEngine.addTerm(node[0]); + if (node[0].getKind() == kind::BITVECTOR_CONCAT) { + for (unsigned i = 0, i_end = node[0].getNumChildren(); i < i_end; ++ i) { + d_eqEngine.addTerm(node[0][i]); + } + } d_eqEngine.addTerm(node[1]); + if (node[1].getKind() == kind::BITVECTOR_CONCAT) { + for (unsigned i = 0, i_end = node[1].getNumChildren(); i < i_end; ++ i) { + d_eqEngine.addTerm(node[1][i]); + } + } size_t triggerId = d_eqEngine.addTrigger(node[0], node[1]); Assert(triggerId == d_triggers.size()); d_triggers.push_back(node); @@ -57,29 +67,27 @@ void TheoryBV::check(Effort e) { // Do the right stuff switch (assertion.getKind()) { case kind::EQUAL: { - - // Slice the equality - std::vector lhsSlices, rhsSlices; - d_sliceManager.addEquality(assertion[0], assertion[1], lhsSlices, rhsSlices); - Assert(lhsSlices.size() == rhsSlices.size()); - - // Add the equality to the equality engine - for (int i = 0, i_end = lhsSlices.size(); i != i_end; ++ i) { - bool ok = d_eqEngine.addEquality(lhsSlices[i], rhsSlices[i]); - if (!ok) return; - } + // Slice and solve the equality + bool ok = d_sliceManager.solveEquality(assertion[0], assertion[1]); + if (!ok) return; break; } case kind::NOT: { // We need to check this as the equality trigger might have been true when we made it TNode equality = assertion[0]; + // Assumptions + std::set assumptions; + Node lhsNormalized = d_eqEngine.normalize(equality[0], assumptions); + Node rhsNormalized = d_eqEngine.normalize(equality[1], assumptions); + + Debug("bitvector") << "TheoryBV::check(" << e << "): normalizes to " << lhsNormalized << " = " << rhsNormalized << std::endl; + // No need to slice the equality, the whole thing *should* be deduced - if (d_eqEngine.areEqual(equality[0], equality[1])) { - vector assertions; - d_eqEngine.getExplanation(equality[0], equality[1], assertions); - assertions.push_back(assertion); - d_out->conflict(mkAnd(assertions)); + if (lhsNormalized == rhsNormalized) { + Debug("bitvector") << "TheoryBV::check(" << e << "): conflict with " << utils::setToString(assumptions) << std::endl; + assumptions.insert(assertion); + d_out->conflict(mkConjunction(assumptions)); return; } break; diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index ede98004f..fa9762cb7 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -46,7 +46,34 @@ public: } }; - typedef EqualityEngine BvEqualityEngine; + struct BVEqualitySettings { + static inline bool descend(TNode node) { + return node.getKind() == kind::BITVECTOR_CONCAT || node.getKind() == kind::BITVECTOR_EXTRACT; + } + + /** Returns true if node1 has preference to node2 as a representative, otherwise node2 is used */ + static inline bool mergePreference(TNode node1, unsigned node1size, TNode node2, unsigned node2size) { + if (node1.getKind() == kind::CONST_BITVECTOR) { + Assert(node2.getKind() != kind::CONST_BITVECTOR); + return true; + } + if (node2.getKind() == kind::CONST_BITVECTOR) { + Assert(node1.getKind() != kind::CONST_BITVECTOR); + return false; + } + if (node1.getKind() == kind::BITVECTOR_CONCAT) { + Assert(node2.getKind() != kind::BITVECTOR_CONCAT); + return true; + } + if (node2.getKind() == kind::BITVECTOR_CONCAT) { + Assert(node1.getKind() != kind::BITVECTOR_CONCAT); + return false; + } + return node2size < node1size; + } + }; + + typedef EqualityEngine BvEqualityEngine; private: diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index 9b545d25a..9b5c8b0f9 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -51,12 +51,12 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { break; case kind::BITVECTOR_EXTRACT: result = LinearRewriteStrategy< + // Extract over a concatenation is distributed to the appropriate concatenations + RewriteRule, // Extract over a constant gives a constant RewriteRule, - // Extract over an extract is simplified to one extract + // We could get another extract over extract RewriteRule, - // Extract over a concatenation is distributed to the appropriate concatenations - RewriteRule, // At this point only Extract-Whole could apply RewriteRule >::apply(node); diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h index 6e9dbef3e..ad924f8a0 100644 --- a/src/theory/bv/theory_bv_utils.h +++ b/src/theory/bv/theory_bv_utils.h @@ -19,7 +19,9 @@ #pragma once +#include #include +#include #include "expr/node_manager.h" namespace CVC4 { @@ -51,6 +53,10 @@ inline Node mkAnd(std::vector& children) { return NodeManager::currentNM()->mkNode(kind::AND, children); } +inline Node mkAnd(std::vector& children) { + return NodeManager::currentNM()->mkNode(kind::AND, children); +} + inline Node mkExtract(TNode node, unsigned high, unsigned low) { Node extractOp = NodeManager::currentNM()->mkConst(BitVectorExtract(high, low)); std::vector children; @@ -65,10 +71,92 @@ inline Node mkConcat(std::vector& children) { return children[0]; } +inline Node mkConcat(TNode t1, TNode t2) { + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_CONCAT, t1, t2); +} + + inline Node mkConst(const BitVector& value) { return NodeManager::currentNM()->mkConst(value); } +inline void getConjuncts(TNode node, std::set& conjuncts) { + if (node.getKind() != kind::AND) { + conjuncts.insert(node); + } else { + for (unsigned i = 0; i < node.getNumChildren(); ++ i) { + getConjuncts(node[i], conjuncts); + } + } +} + +inline Node mkConjunction(const std::set nodes) { + std::set expandedNodes; + + std::set::const_iterator it = nodes.begin(); + std::set::const_iterator it_end = nodes.end(); + while (it != it_end) { + TNode current = *it; + if (current != mkTrue()) { + Assert(current != mkFalse()); + if (current.getKind() == kind::AND) { + getConjuncts(current, expandedNodes); + } else { + expandedNodes.insert(current); + } + } + ++ it; + } + + Assert(expandedNodes.size() > 0); + if (expandedNodes.size() == 1) { + return *expandedNodes.begin(); + } + + NodeBuilder<> conjunction(kind::AND); + + it = expandedNodes.begin(); + it_end = expandedNodes.end(); + while (it != it_end) { + conjunction << *it; + ++ it; + } + + return conjunction; +} + +// Turn a set into a string +inline std::string setToString(const std::set& nodeSet) { + std::stringstream out; + out << "["; + std::set::const_iterator it = nodeSet.begin(); + std::set::const_iterator it_end = nodeSet.end(); + bool first = true; + while (it != it_end) { + if (!first) { + out << ","; + } + first = false; + out << *it; + ++ it; + } + out << "]"; + return out.str(); +} + +// Turn a vector into a string +inline std::string vectorToString(const std::vector& nodes) { + std::stringstream out; + out << "["; + for (unsigned i = 0; i < nodes.size(); ++ i) { + if (i > 0) { + out << ","; + } + out << nodes[i]; + } + out << "]"; + return out.str(); +} } } diff --git a/src/util/bitvector.h b/src/util/bitvector.h index d1bfafb00..ca69fb506 100644 --- a/src/util/bitvector.h +++ b/src/util/bitvector.h @@ -98,7 +98,7 @@ public: return BitVector(d_size + other.d_size, (d_value * Integer(2).pow(other.d_size)) + other.d_value); } - BitVector extract(unsigned high, unsigned low) { + BitVector extract(unsigned high, unsigned low) const { return BitVector(high - low + 1, (d_value % (Integer(2).pow(high + 1))) / Integer(2).pow(low)); } diff --git a/test/regress/regress0/bv/core/Makefile.am b/test/regress/regress0/bv/core/Makefile.am index 5b8e6d7d3..947dc65ad 100644 --- a/test/regress/regress0/bv/core/Makefile.am +++ b/test/regress/regress0/bv/core/Makefile.am @@ -41,7 +41,27 @@ TESTS = \ equality-00.smt \ equality-01.smt \ equality-02.smt \ - bv_eq_diamond10.smt + bv_eq_diamond10.smt \ + slice-01.smt \ + slice-02.smt \ + slice-03.smt \ + slice-04.smt \ + slice-05.smt \ + slice-06.smt \ + slice-07.smt \ + slice-08.smt \ + slice-09.smt \ + slice-10.smt \ + slice-11.smt \ + slice-12.smt \ + slice-13.smt \ + slice-14.smt \ + slice-15.smt \ + slice-16.smt \ + slice-17.smt \ + slice-18.smt \ + slice-19.smt \ + slice-20.smt EXTRA_DIST = $(TESTS) diff --git a/test/regress/regress0/bv/core/slice-20.cvc b/test/regress/regress0/bv/core/slice-20.cvc new file mode 100644 index 000000000..a211b5f2e --- /dev/null +++ b/test/regress/regress0/bv/core/slice-20.cvc @@ -0,0 +1,13 @@ +x1, y1: BITVECTOR(4); +x2, y2: BITVECTOR(2); +x3, y3: BITVECTOR(1); + +ASSERT(x1 = y1); + +ASSERT(x1 = x2 @ x2); +ASSERT(x2 = x3 @ x3); + +ASSERT(y1 = y2 @ y2); +ASSERT(y2 = y3 @ y3); + +QUERY(x3 = y3); \ No newline at end of file -- 2.30.2