Assert (! hasChildren(t1) && ! hasChildren(t2));
setRepr(t1, t2);
+ recordOperation(UnionFind::MERGE, t1);
d_representatives.erase(t1);
d_statistics.d_numRepresentatives += -1;
}
TermId repr = getRepr(id);
if (repr != UndefinedId) {
TermId find_id = find(repr);
- setRepr(id, find_id);
+ // setRepr(id, find_id);
return find_id;
}
return id;
TermId bottom_id = addTerm(i);
TermId top_id = addTerm(getBitwidth(id) - i);
setChildren(id, top_id, bottom_id);
+ recordOperation(UnionFind::SPLIT, id);
} else {
Index cut = getCutPoint(id);
split(id, term.low);
}
+void UnionFind::backtrack() {
+ for (int i = d_undoStack.size() -1; i >= d_undoStackIndex; ++i) {
+ Operation op = d_undoStack.back();
+ d_undoStack.pop_back();
+ if (op.op == UnionFind::MERGE) {
+ undoMerge(op.id);
+ } else {
+ Assert (op.op == UnionFind::SPLIT);
+ undoSplit(op.id);
+ }
+ }
+}
+
+void UnionFind::undoMerge(TermId id) {
+ Node& node = getNode(id);
+ Assert (getRepr(id) != id);
+ setRepr(id, id);
+}
+
+void UnionFind::undoSplit(TermId id) {
+ Node& node = getNode(id);
+ Assert (hasChildren(node));
+ setChildren(id, UndefindId, UndefinedId);
+}
+
+void UnionFind::recordOperation(OperationKind op, TermId term) {
+ ++d_undoStackIndex;
+ d_undoStack.push_back(Operation(op, term));
+ Assert (d_undoStack.size() == d_undoStackIndex);
+}
+
/**
* Slicer
*
#include "util/index.h"
#include "expr/node.h"
#include "theory/bv/theory_bv_utils.h"
+#include "context/context.h"
+#include "context/cdhashset.h"
+#include "context/cdo.h"
+
#ifndef __CVC4__THEORY__BV__SLICER_BV_H
#define __CVC4__THEORY__BV__SLICER_BV_H
* UnionFind
*
*/
-typedef __gnu_cxx::hash_set<TermId> TermSet;
+typedef context::CDHashSet<uint32_t> CDTermSet;
typedef std::vector<TermId> Decomposition;
struct ExtractTerm {
};
-class UnionFind {
+class UnionFind : public context::ContextNotifyObj {
class Node {
Index d_bitwidth;
TermId d_ch1, d_ch0;
/// map from TermId to the nodes that represent them
std::vector<Node> d_nodes;
/// a term is in this set if it is its own representative
- TermSet d_representatives;
+ CDTermSet d_representatives;
void getDecomposition(const ExtractTerm& term, Decomposition& decomp);
void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common);
d_nodes[id].setChildren(ch1, ch0);
}
+ /* Backtracking mechanisms */
+
+ enum OperationKind {
+ MERGE,
+ SPLIT
+ };
+
+ struct Operation {
+ OperationKind op;
+ TermId id;
+ Operation(OperationKind o, TermId i)
+ : op(o), id(i) {}
+ };
+
+ std::vector<Operation> d_undoStack;
+ context::CDO<unsigned> d_undoStackIndex;
+
+ void backtrack();
+ void undoMerge(TermId id);
+ void undoSplit(TermId id);
+ void recordOperation(OperationKind op, TermId term);
+
class Statistics {
public:
IntStat d_numNodes;
IntStat d_numMerges;
AverageStat d_avgFindDepth;
ReferenceStat<unsigned> d_numAddedEqualities;
- //IntStat d_numAddedEqualities;
Statistics();
~Statistics();
};
- Statistics d_statistics
-;
-
+ Statistics d_statistics;
public:
- UnionFind()
+ UnionFind(context::Context* ctx)
: d_nodes(),
- d_representatives()
+ d_representatives(ctx),
+ d_undoStack(),
+ d_undoStackIndex(ctx),
+ d_statistics()
{}
TermId addTerm(Index bitwidth);
return d_nodes[id].getBitwidth();
}
std::string debugPrint(TermId id);
+
+ void contextNotifyPop() {
+ backtrack();
+ }
+
friend class Slicer;
};
UnionFind d_unionFind;
ExtractTerm registerTerm(TNode node);
public:
- Slicer()
+ Slicer(context::Context* ctx)
: d_idToNode(),
d_nodeToId(),
d_coreTermCache(),
- d_unionFind()
+ d_unionFind(ctx)
{}
void getBaseDecomposition(TNode node, std::vector<Node>& decomp);
void processEquality(TNode eq);
bool isCoreTerm (TNode node);
+
static void splitEqualities(TNode node, std::vector<Node>& equalities);
static unsigned d_numAddedEqualities;
};