Added path compression and caching for getBaseDecomposition.
authorLiana Hadarean <lianahady@gmail.com>
Tue, 5 Feb 2013 05:49:39 +0000 (00:49 -0500)
committerLiana Hadarean <lianahady@gmail.com>
Tue, 5 Feb 2013 05:49:39 +0000 (00:49 -0500)
src/theory/bv/bv_subtheory_core.cpp
src/theory/bv/bv_subtheory_core.h
src/theory/bv/slicer.cpp
src/theory/bv/slicer.h

index a3290ff7cebe427bd43bb33db27fab282a8ed37b..e31ab2fdf25af9862b9c7a5cd9de10e57f7791dd 100644 (file)
@@ -33,6 +33,7 @@ CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv, Slicer* slicer)
     d_notify(*this),
     d_equalityEngine(d_notify, c, "theory::bv::TheoryBV"),
     d_assertions(c),
+    d_normalFormCache(), 
     d_slicer(slicer),
     d_isCoreTheory(c, true)
 {
@@ -97,6 +98,19 @@ void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) {
   }
 }
 
+Node CoreSolver::getBaseDecomposition(TNode a) {
+  if (d_normalFormCache.find(a) != d_normalFormCache.end()) {
+    return d_normalFormCache[a]; 
+  }
+
+  // otherwise we must compute the normal form
+  std::vector<Node> a_decomp;
+  d_slicer->getBaseDecomposition(a, a_decomp);
+  Node new_a = utils::mkConcat(a_decomp);
+  d_normalFormCache[a] = new_a;
+  return new_a; 
+}
+
 bool CoreSolver::decomposeFact(TNode fact) {
   Debug("bv-slicer") << "CoreSolver::decomposeFact fact=" << fact << endl;  
   // FIXME: are this the right things to assert? 
@@ -107,18 +121,13 @@ bool CoreSolver::decomposeFact(TNode fact) {
   TNode eq = fact.getKind() == kind::NOT? fact[0] : fact; 
 
   TNode a = eq[0];
-  TNode b = eq[1]; 
-  std::vector<Node> a_decomp;
-  std::vector<Node> b_decomp;
-
-  d_slicer->getBaseDecomposition(a, a_decomp);
-  d_slicer->getBaseDecomposition(b, b_decomp);
-
-  Assert (a_decomp.size() == b_decomp.size());
+  TNode b = eq[1];
+  Node new_a = getBaseDecomposition(a);
+  Node new_b = getBaseDecomposition(b); 
+  
+  Assert (utils::getSize(new_a) == utils::getSize(new_b) &&
+          utils::getSize(new_a) == utils::getSize(a)); 
   
-  Node new_a = utils::mkConcat(a_decomp);
-  Node new_b = utils::mkConcat(b_decomp);
-
   NodeManager* nm = NodeManager::currentNM();
   Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a);
   Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b);
@@ -134,10 +143,15 @@ bool CoreSolver::decomposeFact(TNode fact) {
   if (fact.getKind() == kind::EQUAL) {
     // assert the individual equalities as well
     //    a_i == b_i
-    for (unsigned i = 0; i < a_decomp.size(); ++i) {
-      Node eq_i = nm->mkNode(kind::EQUAL, a_decomp[i], b_decomp[i]);
-      ok = assertFact(eq_i, fact);
-      if (!ok) return false; 
+    if (new_a.getKind() == kind::BITVECTOR_CONCAT &&
+        new_b.getKind() == kind::BITVECTOR_CONCAT) {
+      
+      Assert (new_a.getNumChildren() == new_b.getNumChildren()); 
+      for (unsigned i = 0; i < new_a.getNumChildren(); ++i) {
+        Node eq_i = nm->mkNode(kind::EQUAL, new_a[i], new_b[i]);
+        ok = assertFact(eq_i, fact);
+        if (!ok) return false;
+      }
     }
   }
   return true; 
index 20b42d61cb3760706e193360eaf7fc13f0abba9e..38676bfa6b8d2726da8b8a204edcf7f0f00dc501 100644 (file)
@@ -69,11 +69,13 @@ class CoreSolver : public SubtheorySolver {
 
   /** FIXME: for debugging purposes only */
   context::CDList<TNode> d_assertions;
+  __gnu_cxx::hash_map<TNode, Node, TNodeHashFunction> d_normalFormCache; 
   Slicer* d_slicer;
   context::CDO<bool> d_isCoreTheory;
 
   bool assertFact(TNode fact, TNode reason);  
   bool decomposeFact(TNode fact);
+  Node getBaseDecomposition(TNode a); 
 public:
   bool isCoreTheory() {return d_isCoreTheory; }
   CoreSolver(context::Context* c, TheoryBV* bv, Slicer* slicer);
index 87295e8f694f9954dd0648f9551eb77516205797..f41612df315c98431c6b870f6f1904cbd3a1b89f 100644 (file)
@@ -220,10 +220,13 @@ void UnionFind::merge(TermId t1, TermId t2) {
   d_statistics.d_numRepresentatives += -1; 
 }
 
-TermId UnionFind::find(TermId id) const {
+TermId UnionFind::find(TermId id) {
   TermId repr = getRepr(id); 
-  if (repr != UndefinedId)
-    return find(repr);
+  if (repr != UndefinedId) {
+    TermId find_id =  find(repr);
+    setRepr(id, find_id);
+    return find_id; 
+  }
   return id; 
 }
 /** 
index b0929d617d0668efd2b50537a73c2f686db3d2b1..55cecb117f8b1f0b4031cad9be22aa13895d231b 100644 (file)
@@ -212,7 +212,7 @@ public:
   TermId addTerm(Index bitwidth);
   void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2); 
   void merge(TermId t1, TermId t2);
-  TermId find(TermId t1) const 
+  TermId find(TermId t1); 
   void split(TermId term, Index i);
 
   void getNormalForm(const ExtractTerm& term, NormalForm& nf);