From 48101a9e2416a16b645f31936776694a25cf9efd Mon Sep 17 00:00:00 2001 From: =?utf8?q?Dejan=20Jovanovi=C4=87?= Date: Fri, 25 Feb 2011 02:11:14 +0000 Subject: [PATCH] slicing manager is not breaking the old regressions, time to sync --- src/theory/bv/Makefile.am | 3 +- src/theory/bv/cd_set_collection.h | 316 ++++++++++++++++++++++++ src/theory/bv/equality_engine.h | 249 +++++++++++++++---- src/theory/bv/slice_manager.h | 309 ++++++++++++++++++++++- src/theory/bv/theory_bv.cpp | 20 +- src/theory/bv/theory_bv.h | 4 +- src/theory/bv/theory_bv_rewrite_rules.h | 2 +- 7 files changed, 838 insertions(+), 65 deletions(-) create mode 100644 src/theory/bv/cd_set_collection.h diff --git a/src/theory/bv/Makefile.am b/src/theory/bv/Makefile.am index db750575f..08dde4349 100644 --- a/src/theory/bv/Makefile.am +++ b/src/theory/bv/Makefile.am @@ -16,6 +16,7 @@ libbv_la_SOURCES = \ theory_bv_rewriter.cpp \ equality_engine.h \ equality_engine.cpp \ - slice_manager.h + slice_manager.h \ + cd_set_collection.h EXTRA_DIST = kinds diff --git a/src/theory/bv/cd_set_collection.h b/src/theory/bv/cd_set_collection.h new file mode 100644 index 000000000..bd26a3595 --- /dev/null +++ b/src/theory/bv/cd_set_collection.h @@ -0,0 +1,316 @@ +/* + * set_collection.h + * + * Created on: Feb 24, 2011 + * Author: dejan + */ + +#pragma once + +#include +#include "context/cdo.h" + +namespace CVC4 { +namespace context { + +/** + * A class representing a backtrackable set of slice points. The memory should allow indexing with the TreeEntry.left and + * TreeEntry.right. TreeEntry should also provide null for the non-existing reference and a constructor with (value, + * left, right). + */ +template +class BacktrackableSetCollection { + + /** + * This is an interesting C++ question: how to make class applicable to non-by-value elements. If we turn + * below into const value_type& it doesn't work for the bit-slicing as size_t it is passed by value, and moreover + * we are using only a part of the word for the value. Hmmm. + */ + typedef value_type const_value_reference; + + /** Type of the reference */ + typedef typename tree_entry_type::reference_type reference_type; + + /** The null reference */ + static const reference_type null = tree_entry_type::null; + + /** The memory this set collection will use */ + memory_type d_memory; + + /** Backtrackable number of nodes that have been inserted */ + context::CDO d_nodesInserted; + + /** Backtrack */ + void backtrack() { + while (d_nodesInserted < d_memory.size()) { + const tree_entry_type& node = d_memory.back(); + if (node.hasParent()) { + if (node.isLeft()) { + d_memory[node.getParent()].removeLeft(); + } else { + d_memory[node.getParent()].removeRight(); + } + } + d_memory.pop_back(); + } + } + + inline void backtrack() const { + const_cast(this)->backtrack(); + } + + /** + * Create a new set. The set must have at least one element. + */ + reference_type newElement(const value_type& value, reference_type left, reference_type right, reference_type parent, bool isLeft) { + reference_type index = d_memory.size(); + d_memory.push_back(tree_entry_type(value, left, right, value, isLeft)); + d_nodesInserted = d_nodesInserted + 1; + return index; + } + + /** + * Return the reference to the value if it's in the set or null otherwise + */ + reference_type find(reference_type set, const value_type& value) const { + while (set != null) { + const tree_entry_type& node = d_memory[set]; + if (node.getValue() == value) { + return set; + } else if (value < node.getValue()) { + set = node.getLeft(); + } else { + set = node.getRight(); + } + } + return null; + } + + /** + * Returns the maximal value in the set + */ + reference_type max(reference_type set) const { + Assert(set != null); + while (d_memory[set].hasRight()) { + set = d_memory[set].getRight(); + } + return set; + } + + /** + * Returns the minimal value in the set + */ + reference_type min(reference_type set) const { + Assert(set != null); + while (d_memory[set].hasLeft()) { + set = d_memory[set].getLeft(); + } + return set; + } + +public: + + BacktrackableSetCollection(context::Context* context) + : d_nodesInserted(context, 0) {} + + size_t size() const { + backtrack(); + return d_memory.size(); + } + + reference_type newSet(const value_type& value) { + backtrack(); + return newElement(value, null, null, null, false); + } + + void insert(memory_type& memory, reference_type root, const value_type& value) { + backtrack(); + if (root == null) { + return newSet(value); + } + // We already have a set, find the spot + reference_type parent = null; + while (true) { + parent = root; + if (value < d_memory[root].value) { + root = d_memory[root].left; + if (root == null) { + root = newElement(value, null, null, parent, true); + d_memory[parent].left = root; + return; + } + } else { + Assert(value != d_memory[root].value); + root = d_memory[root].right; + if (root == null) { + root = newElement(value, null, null, parent, false); + d_memory[parent].right = root; + return; + } + } + } + } + + /** + * Returns the maximal value in the set + */ + const_value_reference maxElement(reference_type set) const { + Assert(set != null); + backtrack(); + return d_memory[max(set)].getValue(); + } + + /** + * Returns the minimal value in the set + */ + const_value_reference minElement(reference_type set) const { + Assert(set != null); + backtrack(); + return d_memory[min(set)].getValue(); + } + + /** + * Return the previous (smaller) element. + */ + const_value_reference prev(reference_type set, const_value_reference value) { + backtrack(); + // Get the node of this value + reference_type node_ref = find(set, value); + Assert(node_ref != null); + const tree_entry_type& node = d_memory[node_ref]; + // For a left node, we know that it is smaller than all the parents and the parents other children + // The smaller node must then be the max of the left subtree + if (!node.hasParent() || node.isLeft()) { + return maxElement(node.getLeft()); + } + // For a right node, we know that it is bigger than the parent. But, we also know that the left subtree + // is also bigger than the parent + else { + if (node.hasLeft()) { + return maxElement(node.getLeft()); + } else { + Assert(node.hasParent()); + return d_memory[node.getParent()].getValue(); + } + } + } + + const_value_reference next(reference_type set, const_value_reference value) { + backtrack(); + // Get the node of this value + reference_type node_ref = find(set, value); + Assert(node_ref != null); + const tree_entry_type& node = d_memory[node_ref]; + // For a right node, we know that it is bigger than all the parents and the parents other children + // The bigger node must then be the min of the right subtree + if (!node.hasParent() || node.isRight()) { + return minElement(node.getRight()); + } + // For a left node, we know that it is smaller than the parent. But, we also know that the right subtree + // is also smaller than the parent + else { + if (node.hasRight()) { + return minElement(node.getRight()); + } else { + Assert(node.hasParent()); + return d_memory[node.getParent()].getValue(); + } + } + } + + /** + * Count the number of elements in the given bounds. + */ + unsigned count(reference_type set, const_value_reference lowerBound, const_value_reference upperBound) const { + Assert(lowerBound <= upperBound); + backtrack(); + // Empty set no elements + if (set == null) { + return 0; + } + // The counter + unsigned c = 0; + // Current set + const tree_entry_type& current = d_memory[set]; + // Left child (smaller elements) + if (lowerBound < current.getValue()) { + c += count(current.getLeft(), lowerBound, upperBound); + } + // Current element + if (lowerBound <= current.getValue() && current.getValue() <= upperBound) { + ++ c; + } + // Right child (bigger elements) + if (current.getValue() <= upperBound) { + c += count(current.getRight(), lowerBound, upperBound); + } + return c; + } + + /** + * Check for membership. + */ + bool contains(reference_type set, const_value_reference value) const { + backtrack(); + return count(set, value, value) > 0; + } + + /** + * Returns the elements (sorted) in the set between the given lower and upper bound. If include borders is on, + * and the + */ + void getElements(reference_type set, const_value_reference lowerBound, const_value_reference upperBound, std::vector& output) const { + Assert(lowerBound <= upperBound); + backtrack(); + // Empty set no elements + if (set == null) { + return; + } + // Current set + const tree_entry_type& current = d_memory[set]; + // Left child (smaller elements) + if (lowerBound < current.getValue()) { + getElements(current.getLeft(), lowerBound, upperBound, output); + } + // Current element + if (lowerBound <= current.getValue() && current.getValue() <= upperBound) { + output.push_back(current.getValue()); + } + // Right child (bigger elements) + if (current.getValue() <= upperBound) { + getElements(current.getRight(), lowerBound, upperBound, output); + } + } + + /** + * Print the list of elements to the output. + */ + void print(std::ostream& out, reference_type set) { + backtrack(); + if (set == null) { + return; + } + const tree_entry_type& current = d_memory[set]; + if (current.hasLeft()) { + print(out, current.getLeft()); + out << ","; + } + out << current.getValue(); + if (current.hasRight()) { + out << ","; + print(out, current.getRight()); + } + } + + /** + * String representation of a set. + */ + std::string toString(reference_type set) { + stringstream out; + print(out, set); + return out.str(); + } +}; + +} // Namespace context +} // Namespace CVC4s diff --git a/src/theory/bv/equality_engine.h b/src/theory/bv/equality_engine.h index 000e93a7b..9880539ed 100644 --- a/src/theory/bv/equality_engine.h +++ b/src/theory/bv/equality_engine.h @@ -27,6 +27,7 @@ #include "expr/node.h" #include "context/cdo.h" #include "util/output.h" +#include "util/stats.h" namespace CVC4 { namespace theory { @@ -46,6 +47,12 @@ struct BitSizeTraits { /** Number of bits we use for the trigger id */ static const size_t trigger_id_bits = 24; + /** Number of bits we use for the function ids */ + static const size_t function_id_bits = 8; + /** Number of bits we use for the function arguments count */ + static const size_t function_arguments_count_bits = 16; + /** Number of bits we use for the index into the arguments memory */ + static const size_t function_arguments_index_bits = 24; }; class EqualityNode { @@ -61,22 +68,26 @@ public: /** The next equality node in this class */ size_t d_nextId : BitSizeTraits::id_bits; + /** Is this node a function application */ + size_t d_isFunction : 1; + public: /** * Creates a new node, which is in a list of it's own. */ EqualityNode(size_t nodeId = BitSizeTraits::id_null) - : d_size(1), d_findId(nodeId), d_nextId(nodeId) {} + : d_size(1), d_findId(nodeId), d_nextId(nodeId), d_isFunction(0) {} /** Initialize the equality node */ - inline void init(size_t nodeId) { + inline void init(size_t nodeId, bool isFunction) { d_size = 1; d_findId = d_nextId = nodeId; + d_isFunction = isFunction; } /** - * Returns the next node in the class circural list. + * Returns the next node in the class circular list. */ inline size_t getNext() const { return d_nextId; @@ -114,10 +125,104 @@ public: inline void setFind(size_t findId) { d_findId = findId; } }; +/** + * FunctionNode class represents the information related to a function node. It has an id, number of children + * and the + */ +class FunctionNode { + + /** Is the function associative */ + size_t d_isAssociative : 1; + /** The id of the function */ + size_t d_functionId : BitSizeTraits::function_id_bits; + /** Number of children */ + size_t d_argumentsCount : BitSizeTraits::function_arguments_count_bits; + /** Index of the start of the arguments in the children array */ + size_t d_argumentsIndex : BitSizeTraits::function_arguments_index_bits; + +public: + + FunctionNode(size_t functionId = 0, size_t argumentsCount = 0, size_t argumentsIndex = 0, bool associative = false) + : d_isAssociative(associative), d_functionId(functionId), d_argumentsCount(argumentsCount), d_argumentsIndex(argumentsIndex) + {} + + void init(size_t functionId, size_t argumentsCount, size_t argumentsIndex, bool associative) { + d_functionId = functionId; + d_argumentsCount = argumentsCount; + d_argumentsIndex = argumentsIndex; + d_isAssociative = associative; + } + + /** Check if the function is associative */ + bool isAssociative() const { return d_isAssociative; } -template + /** Get the function id */ + size_t getFunctionId() const { return d_functionId; } + + /** Get the number of arguments */ + size_t getArgumentsCount() const { return d_argumentsCount; } + + /** Get the infex of the first argument in the arguments memory */ + size_t getArgumentsIndex() const { return d_argumentsIndex; } + +}; + +template class EqualityEngine { +public: + + /** + * Basic information about a function. + */ + struct FunctionInfo { + /** Name of the function */ + std::string name; + /** Is the function associative */ + bool isAssociative; + + FunctionInfo(std::string name, bool isAssociative) + : name(name), isAssociative(isAssociative) {} + }; + + /** Statistics about the equality engine instance */ + struct Statistics { + /** Total number of merges */ + IntStat mergesCount; + /** Number of terms managed by the system */ + IntStat termsCount; + /** Number of function terms managed by the system */ + IntStat functionTermsCount; + /** Number of distince functions managed by the system */ + IntStat functionsCount; + /** Number of times we performed a backtrack */ + IntStat backtracksCount; + + Statistics(std::string name) + : mergesCount(name + "::mergesCount", 0), + termsCount(name + "::termsCount", 0), + functionTermsCount(name + "functionTermsCoutn", 0), + functionsCount(name + "::functionsCount", 0), + backtracksCount(name + "::backtracksCount", 0) + { + StatisticsRegistry::registerStat(&mergesCount); + StatisticsRegistry::registerStat(&termsCount); + StatisticsRegistry::registerStat(&functionTermsCount); + StatisticsRegistry::registerStat(&functionsCount); + StatisticsRegistry::registerStat(&backtracksCount); + } + + ~Statistics() { + StatisticsRegistry::unregisterStat(&mergesCount); + StatisticsRegistry::unregisterStat(&termsCount); + StatisticsRegistry::unregisterStat(&functionTermsCount); + StatisticsRegistry::unregisterStat(&functionsCount); + StatisticsRegistry::unregisterStat(&backtracksCount); + } + }; + +private: + /** The class to notify when a representative changes for a term */ NotifyClass d_notify; @@ -133,8 +238,11 @@ class EqualityEngine { /** Number of asserted equalities we have so far */ context::CDO d_assertedEqualitiesCount; - /** Number of functions in the system */ - context::CDO d_functionsCount; + /** Map from ids to functional representations */ + std::vector d_functionNodes; + + /** Functions in the system */ + std::vector d_functions; /** * We keep a list of asserted equalities. Not among original terms, but @@ -248,23 +356,26 @@ class EqualityEngine { /** * Trigger lists per node. The begin id changes as we merge, but the end always points to - * the acutal end of the triggers for this node. + * the actual end of the triggers for this node. */ std::vector d_nodeTriggers; /** - * Adds the trigger with triggerId to the begining of the trigger list of the node with id nodeId. + * Adds the trigger with triggerId to the beginning of the trigger list of the node with id nodeId. */ inline void addTriggerToList(size_t nodeId, size_t triggerId); + /** Statistics */ + Statistics d_stats; + public: /** - * Initialize the equalty engine, given the owning class. This will initialize the notifier with + * Initialize the equality engine, given the owning class. This will initialize the notifier with * the owner information. */ - EqualityEngine(OwnerClass& owner, context::Context* context) - : d_notify(owner), d_assertedEqualitiesCount(context, 0), d_functionsCount(context, 0) { + EqualityEngine(OwnerClass& owner, context::Context* context, std::string name) + : d_notify(owner), d_assertedEqualitiesCount(context, 0), d_stats(name) { Debug("equality") << "EqualityEdge::EqualityEdge(): id_null = " << BitSizeTraits::id_null << ", trigger_id_null = " << BitSizeTraits::trigger_id_null << std::endl; } @@ -274,6 +385,11 @@ public: */ size_t addTerm(TNode t); + /** + * Adds a term that is an application of a function symbol to the databas. Returns the internal id of the term. + */ + size_t addFunctionApplication(size_t funcionId, const std::vector& arguments); + /** * Check whether the node is already in the database. */ @@ -308,20 +424,31 @@ public: size_t addTrigger(TNode t1, TNode t2); /** - * Adds a new function to the equality engine. The funcions are not of fixed arity! + * Adds a new function to the equality engine. The funcions are not of fixed arity and no typechecking is performed! + * Associative functions allow for normalization, i.e. f(f(x, y), z) = f(x, f(y, z)) = f(x, y, z). + * @associative should be true if the function is associative and you want this to be handled by the engine */ - inline size_t newFunction() { d_functionsCount = d_functionsCount + 1; return d_functionsCount; } + inline size_t newFunction(std::string name, bool associative) { + Assert(use_functions); + Assert(!associative || enable_associative); + ++ d_stats.functionsCount; + size_t id = d_functions.size(); + d_functions.push_back(FunctionInfo(name, associative)); + return id; + } }; -template -size_t EqualityEngine::addTerm(TNode t) { +template +size_t EqualityEngine::addTerm(TNode t) { Debug("equality") << "EqualityEngine::addTerm(" << t << ")" << std::endl; // If term already added, retrurn it's id if (hasTerm(t)) return getNodeId(t); + ++ d_stats.termsCount; + // Register the new id of the term size_t newId = d_nodes.size(); d_nodeIds[t] = newId; @@ -335,35 +462,69 @@ size_t EqualityEngine::addTerm(TNode t) { if (d_equalityNodes.size() <= newId) { d_equalityNodes.resize(newId + 100); } - d_equalityNodes[newId].init(newId); + d_equalityNodes[newId].init(newId, false); // Return the id of the term return newId; } -template -bool EqualityEngine::hasTerm(TNode t) const { +template +size_t EqualityEngine::addFunctionApplication(size_t functionId, const std::vector& arguments) { + + Debug("equality") << "EqualityEngine::addFunctionApplication(" << d_functions[functionId].name << ":" << arguments.size() << ")" << std::endl; + + ++ d_stats.functionTermsCount; + ++ d_stats.termsCount; + + // Register the new id of the term + size_t newId = d_nodes.size(); + // Add the node to it's position + d_nodes.push_back(Node()); + // Add the trigger list for this node + d_nodeTriggers.push_back(BitSizeTraits::trigger_id_null); + // Add it to the equality graph + d_equalityGraph.push_back(BitSizeTraits::id_null); + // Add the equality node to the nodes + if (d_equalityNodes.size() <= newId) { + d_equalityNodes.resize(newId + 100); + } + d_equalityNodes[newId].init(newId, true); + // Add the function application to the function nodes + if (d_functionNodes.size() <= newId) { + d_functionNodes.resize(newId + 100); + } + // Initialize the function node + size_t argumentsIndex; + d_functionNodes[newId].init(functionId, arguments.size(), argumentsIndex, d_functions[functionId].isAssociative); + + // Return the id of the term + return newId; + +} + +template +bool EqualityEngine::hasTerm(TNode t) const { return d_nodeIds.find(t) != d_nodeIds.end(); } -template -size_t EqualityEngine::getNodeId(TNode node) const { +template +size_t EqualityEngine::getNodeId(TNode node) const { Assert(hasTerm(node)); return (*d_nodeIds.find(node)).second; } -template -EqualityNode& EqualityEngine::getEqualityNode(TNode t) { +template +EqualityNode& EqualityEngine::getEqualityNode(TNode t) { return getEqualityNode(getNodeId(t)); } -template -EqualityNode& EqualityEngine::getEqualityNode(size_t nodeId) { +template +EqualityNode& EqualityEngine::getEqualityNode(size_t nodeId) { Assert(nodeId < d_equalityNodes.size()); return d_equalityNodes[nodeId]; } -template -bool EqualityEngine::addEquality(TNode t1, TNode t2) { +template +bool EqualityEngine::addEquality(TNode t1, TNode t2) { Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << ")" << std::endl; @@ -416,8 +577,8 @@ bool EqualityEngine::addEquality(TNode t1, TNode t2) { return true; } -template -TNode EqualityEngine::getRepresentative(TNode t) const { +template +TNode EqualityEngine::getRepresentative(TNode t) const { Debug("equality") << "EqualityEngine::getRepresentative(" << t << ")" << std::endl; @@ -432,8 +593,8 @@ TNode EqualityEngine::getRepresentative(TNode t) const return d_nodes[representativeId]; } -template -bool EqualityEngine::areEqual(TNode t1, TNode t2) const { +template +bool EqualityEngine::areEqual(TNode t1, TNode t2) const { Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl; Assert(hasTerm(t1)); @@ -449,13 +610,15 @@ bool EqualityEngine::areEqual(TNode t1, TNode t2) const return rep1 == rep2; } -template -void EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers) { +template +void EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers) { Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl; Assert(triggers.empty()); + ++ d_stats.mergesCount; + size_t class1Id = class1.getFind(); size_t class2Id = class2.getFind(); @@ -497,8 +660,8 @@ void EqualityEngine::merge(EqualityNode& class1, Equali class1.merge(class2); } -template -void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& class2, size_t class2Id) { +template +void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& class2, size_t class2Id) { Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << ")" << std::endl; @@ -529,12 +692,14 @@ void EqualityEngine::undoMerge(EqualityNode& class1, Eq } -template -void EqualityEngine::backtrack() { +template +void EqualityEngine::backtrack() { // If we need to backtrack then do it if (d_assertedEqualitiesCount < d_assertedEqualities.size()) { + ++ d_stats.backtracksCount; + Debug("equality") << "EqualityEngine::backtrack(): nodes" << std::endl; for (int i = (int)d_assertedEqualities.size() - 1, i_end = (int)d_assertedEqualitiesCount; i >= i_end; --i) { @@ -560,8 +725,8 @@ void EqualityEngine::backtrack() { } -template -void EqualityEngine::addGraphEdge(size_t t1, size_t t2) { +template +void EqualityEngine::addGraphEdge(size_t t1, size_t t2) { Debug("equality") << "EqualityEngine::addGraphEdge(" << d_nodes[t1] << "," << d_nodes[t2] << ")" << std::endl; size_t edge = d_equalityEdges.size(); d_equalityEdges.push_back(EqualityEdge(t2, d_equalityGraph[t1])); @@ -570,8 +735,8 @@ void EqualityEngine::addGraphEdge(size_t t1, size_t t2) d_equalityGraph[t2] = edge | 1; } -template -void EqualityEngine::getExplanation(TNode t1, TNode t2, std::vector& equalities) const { +template +void EqualityEngine::getExplanation(TNode t1, TNode t2, std::vector& equalities) const { Assert(equalities.empty()); Assert(t1 != t2); Assert(getRepresentative(t1) == getRepresentative(t2)); @@ -651,8 +816,8 @@ void EqualityEngine::getExplanation(TNode t1, TNode t2, } } -template -size_t EqualityEngine::addTrigger(TNode t1, TNode t2) { +template +size_t EqualityEngine::addTrigger(TNode t1, TNode t2) { Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ")" << std::endl; diff --git a/src/theory/bv/slice_manager.h b/src/theory/bv/slice_manager.h index b31f721a0..436ebaec0 100644 --- a/src/theory/bv/slice_manager.h +++ b/src/theory/bv/slice_manager.h @@ -7,15 +7,76 @@ #pragma once +#include "context/cdo.h" #include "theory/bv/theory_bv_utils.h" #include "theory/bv/equality_engine.h" +#include "theory/bv/cd_set_collection.h" +#include #include namespace CVC4 { namespace theory { namespace bv { +/** + * Representation of the slice points in tree. + */ +class slice_point +{ +public: + + /** Number of bits we use for the index of the slice */ + static const size_t s_slice_index_bits = 31; + /** Number of bits we use for the index of the slice_point in the slice memory (reference) */ + static const size_t s_slice_point_reference_bits = 32; + /** The null reference (maximal number in the given bits) */ + static const size_t null = (1llu << s_slice_point_reference_bits) - 1; + + /** Type of the reference for the outside world */ + typedef size_t reference_type; + + /** Type of the value for the outside world */ + typedef size_t value_type; + +private: + + /** The value of the slice point (bit index) */ + size_t d_value : s_slice_index_bits; + /** Is this the left child */ + size_t d_isLeftChild : 1; + /** Reference to the left in the tree */ + size_t d_left : s_slice_point_reference_bits; + /** Reference to the right of the tree */ + size_t d_right : s_slice_point_reference_bits; + /** Reference to the parent */ + size_t d_parent : s_slice_point_reference_bits; + +public: + + slice_point(size_t value, size_t left, size_t right, size_t parent, bool isLeftChild) + : d_value(value), d_isLeftChild(isLeftChild ? 1 : 0), d_left(left), d_right(right), d_parent(parent) {} + + bool isLeft() const { return d_isLeftChild == 1; } + bool isRight() const { return d_isLeftChild == 0; } + + bool hasLeft() const { return d_left != null; } + bool hasRight() const { return d_right != null; } + bool hasParent() const { return d_parent != null; } + + reference_type getLeft() const { return d_left; } + reference_type getRight() const { return d_right; } + reference_type getParent() const { return d_parent; } + + void removeLeft() { Assert(d_left != null); d_left = null; } + void removeRight() { Assert(d_right != null); d_right = null; } + + void setLeft(reference_type left) { Assert(d_left == null && left != null); d_left = left; } + void setRight(reference_type right) { Assert(d_right == null && right != null); d_right = right; } + + value_type getValue() const { return d_value; } +}; + /** * Slice manager should keep the database of slices for the core theory leaf terms, for example * @@ -51,9 +112,22 @@ namespace bv { template class SliceManager { +public: + + /** The references to backtrackable sets */ + typedef slice_point::reference_type set_reference; + + /** The set collection we'll be using */ + typedef context::BacktrackableSetCollection, slice_point, set_reference> set_collection; + + /** The map type from nodes to their references */ + typedef std::map slicing_map; + /** The equality engine theory of bit-vectors is using */ typedef typename TheoryBitvector::BvEqualityEngine EqualityEngine; +private: + /** The theory of bitvectors */ TheoryBitvector& d_theoryBitvector; @@ -63,12 +137,24 @@ class SliceManager { /** The id of the concatenation function */ size_t d_concatFunctionId; + /** The collection of backtrackable sets */ + set_collection d_setCollection; + + /** + * A map from base nodes to slice points. For each node, the slice points are + * 0 = i_1 < i_2 < ... < i_n = size, and the slices are + * x[i_n-1:i_{n-1}]@x[i_{n-1}-1:i_{n-2}]@...@x[i_2-1:i_1] + * Each time we add a slict t = t1@t2@...@tn of a term (or a slice), we also notify the equality engine with an + * extra assertion. Since the equality engine is backtrackable, we will need to backtrack the slices accordingly. + */ + slicing_map d_nodeSlicing; + public: - SliceManager(TheoryBitvector& theoryBitvector) - : d_theoryBitvector(theoryBitvector), d_equalityEngine(theoryBitvector.getEqualityEngine()) { + SliceManager(TheoryBitvector& theoryBitvector, context::Context* context) + : d_theoryBitvector(theoryBitvector), d_equalityEngine(theoryBitvector.getEqualityEngine()), d_setCollection(context) { // We register the concatentation with the equality engine - d_concatFunctionId = d_equalityEngine.newFunction(); + d_concatFunctionId = d_equalityEngine.newFunction("bv_concat", true); } inline size_t getConcatFunctionId() const { return d_concatFunctionId; } @@ -80,14 +166,225 @@ public: * engine gets the assertion x[10:0] = concat(x[10:5], x[4:0]). */ inline void addEquality(TNode lhs, TNode rhs, std::vector& lhsSlices, std::vector& rhsSlices); + +private: + + /** + * Slices up lhs and rhs and returns the slices in lhsSlices and rhsSlices + */ + inline void slice(std::vector& lhs, std::vector& rhs, + std::vector& lhsSlices, std::vector& rhsSlices); + + /** + * Returns true if the term is already sliced wrt the current slicing. Note that, for example, even though + * the slicing is empty, x[i:j] is considered sliced. Sliced means that there is no slice points between i and j. + */ + inline bool isSliced(TNode node) const; + + /** + * Slices the term wrt the current slicing. When done, isSliced returns true + */ + inline void slice(TNode node, std::vector& sliced); + + /** + * Returns the base term in the core theory of the given term, i.e. + * x => x + * x[i:j] => x + * (x + y) => x+y + * (x + y)[i:j] => x+y + */ + static inline TNode baseTerm(TNode node); + + /** + * Adds a new slice to the slice set of the given base term. + */ + inline void addSlice(Node baseTerm, unsigned slicePoint); }; +template +void SliceManager::addSlice(Node baseTerm, unsigned slicePoint) { +} + template void SliceManager::addEquality(TNode lhs, TNode rhs, std::vector& lhsSlices, std::vector& rhsSlices) { - Debug("theory::bv::slicing") << "addEquality(" << lhs << "," << rhs << ")"; - lhsSlices.push_back(lhs); - rhsSlices.push_back(rhs); + + Debug("slicing") << "SliceMagager::addEquality(" << lhs << "," << rhs << ")" << std::endl; + + // The concatenations on the left-hand side (reverse order, first is on top) + std::vector lhsTerms; + if (lhs.getKind() == kind::BITVECTOR_CONCAT) { + for (int i = (int) lhs.getNumChildren() - 1; i >= 0; -- i) { + lhsTerms.push_back(lhs[i]); + } + } else { + lhsTerms.push_back(lhs); + } + + // The concatenations on the right-hand side (reverse order, first is on top) + std::vector rhsTerms; + if (rhs.getKind() == kind::BITVECTOR_CONCAT) { + for (int i = (int) rhs.getNumChildren() - 1; i >= 0; --i) { + rhsTerms.push_back(rhs[i]); + } + } else { + rhsTerms.push_back(rhs); + } + + // Slice the individual terms to align them + slice(lhsTerms, rhsTerms, lhsSlices, rhsSlices); +} + +template +void SliceManager::slice(std::vector& lhs, std::vector& rhs, + std::vector& lhsSlices, std::vector& rhsSlices) { + + Debug("slicing") << "SliceManager::slice()" << std::endl; + + // Go through the work-list and align + while (!lhs.empty()) { + + Assert(!rhs.empty()); + + // The terms that we need to slice + Node lhsTerm = lhs.back(); + Node rhsTerm = rhs.back(); + Debug("slicing") << "slicing: " << lhsTerm << " and " << rhsTerm << std::endl; + + // If the terms are not sliced wrt the current slicing, we have them sliced + lhs.pop_back(); + if (!isSliced(lhsTerm)) { + slice(lhsTerm, lhs); + continue; + } + rhs.pop_back(); + if (!isSliced(rhsTerm)) { + slice(rhsTerm, rhs); + } + + // If the slices are of the same size we do the additional work + unsigned lhsSize = utils::getSize(lhsTerm); + unsigned rhsSize = utils::getSize(rhsTerm); + if (lhsSize == rhsSize) { + // If they are over the same base terms, we need to do something + TNode lhsBaseTerm = baseTerm(lhsTerm); + TNode rhsBaseTerm = baseTerm(rhsTerm); + if (lhsBaseTerm == rhsBaseTerm) { + // x[i_1:j_1] vs x[i_2:j_2] + } else { + // x[i_1:j_1] vs y[i_2:j_2] + } + lhsSlices.push_back(lhsTerm); + rhsSlices.push_back(rhsTerm); + continue; + } else { + // They are not of equal sizes, so we slice one + if (lhsSize < rhsSize) { + // We need to cut a piece of rhs + } else { + // We need to cut a piece of lhs + } + } + } +} + +template +bool SliceManager::isSliced(TNode node) const { + + Debug("slicing") << "SliceManager::isSliced(" << node << ")" << std::endl; + + bool result = false; + + // Constants are always sliced + if (node.getKind() == kind::CONST_BITVECTOR) { + result = true; + } else { + // The indices of the beginning and end + Kind nodeKind = node.getKind(); + unsigned high = nodeKind == kind::BITVECTOR_EXTRACT ? utils::getExtractHigh(node) : utils::getSize(node) - 1; + unsigned low = nodeKind == kind::BITVECTOR_EXTRACT ? utils::getExtractLow(node) : 0; + + // Get the base term + TNode nodeBase = baseTerm(node); + Assert(nodeBase.getKind() != kind::BITVECTOR_CONCAT); + Assert(nodeBase.getKind() != kind::CONST_BITVECTOR); + + // Get the base term slice set + slicing_map::const_iterator find = d_nodeSlicing.find(nodeBase); + // If no slices, it's just a term, so we are done, UNLESS it's an extract + if (find == d_nodeSlicing.end()) { + result = nodeKind != kind::BITVECTOR_EXTRACT; + } else { + // Check whether there is a slice point in [high, low), if there is the term is not sliced. + // Hence, if we look for the upper bound of low, and it is higher than high, it is sliced. + result = d_setCollection.count(find->second, low + 1, high) > 0; + } + } + + Debug("slicing") << "SliceManager::isSliced(" << node << ") => " << (result ? "true" : "false") << std::endl; + return result; +} + +template +inline void SliceManager::slice(TNode node, std::vector& sliced) { + + Debug("slicing") << "SliceManager::slice(" << node << ")" << endl; + + Assert(!isSliced(node)); + + // The indices of the beginning and (one past) end + unsigned high = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractHigh(node) + 1 : utils::getSize(node); + unsigned low = node.getKind() == kind::BITVECTOR_EXTRACT ? utils::getExtractLow(node) : 0; + + // Get the base term + TNode nodeBase = baseTerm(node); + Assert(nodeBase.getKind() != kind::BITVECTOR_CONCAT); + Assert(nodeBase.getKind() != kind::CONST_BITVECTOR); + + // Get the base term slice set + set_collection::reference_type nodeSliceSet = d_nodeSlicing[nodeBase]; + Debug("slicing") << "SliceManager::slice(" << node << "): current: " << d_setCollection.toString(nodeSliceSet) << endl; + std::vector slicePoints; + d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints); + + // Go through all the points i_0 <= low < i_1 < ... < i_{n-1} < high <= i_n from the slice set + // and generate the slices [i_0:low-1][low:i_1-1] [i_1:i2] ... [i_{n-1}:high-1][high:i_n-1]. They are in reverse order, + // as they should be + size_t i_0 = low == 0 ? 0 : d_setCollection.prev(nodeSliceSet, low + 1); + size_t i_n = high == utils::getSize(nodeBase) ? high: d_setCollection.next(nodeSliceSet, high); + + // Add the new points to the slice set (they might be there already) + if (high < i_n) { + std::vector lastTwoSlices; + lastTwoSlices.push_back(utils::mkExtract(nodeBase, i_n-1, high)); + lastTwoSlices.push_back(utils::mkExtract(nodeBase, high-1, slicePoints.back())); + d_equalityEngine.addEquality(utils::mkExtract(nodeBase, i_n-1, slicePoints.back()), utils::mkConcat(lastTwoSlices)); + } + + while (!slicePoints.empty()) { + sliced.push_back(utils::mkExtract(nodeBase, high-1, slicePoints.back())); + high = slicePoints.back(); + slicePoints.pop_back(); + } + + if (i_0 < low) { + std::vector firstTwoSlices; + firstTwoSlices.push_back(utils::mkExtract(nodeBase, high-1, low)); + firstTwoSlices.push_back(utils::mkExtract(nodeBase, low-1, i_0)); + d_equalityEngine.addEquality(utils::mkExtract(nodeBase, high-1, i_0), utils::mkConcat(firstTwoSlices)); + } +} + +template +TNode SliceManager::baseTerm(TNode node) { + if (node.getKind() == kind::BITVECTOR_EXTRACT) { + Assert(node[0].getKind() != kind::BITVECTOR_EXTRACT); + Assert(node[0].getKind() != kind::CONST_BITVECTOR); + return node[0]; + } else { + Assert(node.getKind() != kind::BITVECTOR_CONCAT); + return node; + } } } // Namespace bv diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index cb8b95751..0356e5f27 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -72,19 +72,13 @@ void TheoryBV::check(Effort e) { // We need to check this as the equality trigger might have been true when we made it TNode equality = assertion[0]; - // Slice the equality - std::vector lhsSlices, rhsSlices; - d_sliceManager.addEquality(equality[0], equality[1], lhsSlices, rhsSlices); - Assert(lhsSlices.size() == rhsSlices.size()); - - for (int i = 0, i_end = lhsSlices.size(); i != i_end; ++ i) { - if (d_eqEngine.areEqual(lhsSlices[i], rhsSlices[i])) { - vector assertions; - d_eqEngine.getExplanation(lhsSlices[i], rhsSlices[i], assertions); - assertions.push_back(assertion); - d_out->conflict(mkAnd(assertions)); - return; - } + // No need to slice the equality, the whole thing *should* be deduced + if (d_eqEngine.areEqual(equality[0], equality[1])) { + vector assertions; + d_eqEngine.getExplanation(equality[0], equality[1], assertions); + assertions.push_back(assertion); + d_out->conflict(mkAnd(assertions)); + return; } break; } diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index 14a526e36..ed23bf53f 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -46,7 +46,7 @@ public: } }; - typedef EqualityEngine BvEqualityEngine; + typedef EqualityEngine BvEqualityEngine; private: @@ -69,7 +69,7 @@ private: public: TheoryBV(context::Context* c, OutputChannel& out) : - Theory(THEORY_BV, c, out), d_eqEngine(*this, c), d_sliceManager(*this), d_assertions(c) { + Theory(THEORY_BV, c, out), d_eqEngine(*this, c, "theory::bv::EqualityEngine"), d_sliceManager(*this, c), d_assertions(c) { } BvEqualityEngine& getEqualityEngine() { diff --git a/src/theory/bv/theory_bv_rewrite_rules.h b/src/theory/bv/theory_bv_rewrite_rules.h index 32d0f92a0..5815f2c7f 100644 --- a/src/theory/bv/theory_bv_rewrite_rules.h +++ b/src/theory/bv/theory_bv_rewrite_rules.h @@ -80,7 +80,7 @@ class RewriteRule { /** Constructor */ RuleStatistics() - : d_ruleApplications(getStatName("theory::bv::count"), 0) { + : d_ruleApplications(getStatName("theory::bv::RewriteRules::count"), 0) { StatisticsRegistry::registerStat(&d_ruleApplications); } -- 2.30.2