fixed more equality stuff
authorlianah <lianahady@gmail.com>
Thu, 21 Mar 2013 23:25:33 +0000 (19:25 -0400)
committerlianah <lianahady@gmail.com>
Thu, 21 Mar 2013 23:25:33 +0000 (19:25 -0400)
src/theory/bv/bv_subtheory_core.cpp
src/theory/bv/bv_subtheory_core.h
src/theory/bv/slicer.cpp
src/theory/bv/slicer.h

index d7dab10f985c41f4a247a4c605d323292b76cd6b..2af0e47b8d0be4e4744fdd77ee37f4974d91d186 100644 (file)
@@ -72,6 +72,9 @@ CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv)
   }
 }
 
+CoreSolver::~CoreSolver() {
+  delete d_slicer; 
+}
 void CoreSolver::setMasterEqualityEngine(eq::EqualityEngine* eq) {
   d_equalityEngine.setMasterEqualityEngine(eq);
 }
@@ -99,10 +102,11 @@ void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) {
   }
 }
 
-Node CoreSolver::getBaseDecomposition(TNode a, std::vector<TNode>& explanation) {
+Node CoreSolver::getBaseDecomposition(TNode a, std::vector<Node>& explanation) {
   std::vector<Node> a_decomp;
   d_slicer->getBaseDecomposition(a, a_decomp, explanation);
   Node new_a = utils::mkConcat(a_decomp);
+  Debug("bv-slicer") << "CoreSolver::getBaseDecomposition " << a <<" => " << new_a << "\n"; 
   return new_a; 
 }
 
@@ -118,7 +122,7 @@ bool CoreSolver::decomposeFact(TNode fact) {
     TNode b = fact[1];
 
     d_slicer->processEquality(fact); 
-    std::vector<TNode> explanation; 
+    std::vector<Node> explanation; 
     Node new_a = getBaseDecomposition(a, explanation);
     Node new_b = getBaseDecomposition(b, explanation);
 
@@ -157,10 +161,20 @@ bool CoreSolver::decomposeFact(TNode fact) {
     d_slicer->assertEquality(fact); 
   } else {
     // still need to register the terms
+    d_slicer->processEquality(fact[0]);
     TNode a = fact[0][0];
     TNode b = fact[0][1];
-    d_slicer->registerTerm(a);
-    d_slicer->registerTerm(b); 
+    std::vector<Node> explanation_a; 
+    Node new_a = getBaseDecomposition(a, explanation_a);
+    Node reason_a = explanation_a.empty()? mkTrue() : mkAnd(explanation_a);
+    assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, a, new_a), reason_a);
+
+    std::vector<Node> explanation_b; 
+    Node new_b = getBaseDecomposition(b, explanation_b);
+    Node reason_b = explanation_b.empty()? mkTrue() : mkAnd(explanation_b);
+    assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, b, new_b), reason_b);
+    d_reasons.insert(reason_a);
+    d_reasons.insert(reason_b); 
   }
   // finally assert the actual fact to the equality engine
   return assertFactToEqualityEngine(fact, fact);
index f37cf5bf32c2d7ee82f733c1e3838b0ac1999938..4f2d7a27969a301c6222cda9c9a5ad60e025806d 100644 (file)
@@ -67,9 +67,10 @@ class CoreSolver : public SubtheorySolver {
   context::CDHashSet<Node, NodeHashFunction> d_reasons; 
   bool assertFactToEqualityEngine(TNode fact, TNode reason);  
   bool decomposeFact(TNode fact);
-  Node getBaseDecomposition(TNode a, std::vector<TNode>& explanation);
+  Node getBaseDecomposition(TNode a, std::vector<Node>& explanation);
 public: 
   CoreSolver(context::Context* c, TheoryBV* bv);
+  ~CoreSolver();
   bool  isCoreTheory() { return d_isCoreTheory; }
   void  setMasterEqualityEngine(eq::EqualityEngine* eq);
   void  preRegister(TNode node);
@@ -91,6 +92,7 @@ public:
     return EQUALITY_UNKNOWN;
   }
   bool hasTerm(TNode node) const { return d_equalityEngine.hasTerm(node); }
+  void addTermToEqualityEngine(TNode node) { d_equalityEngine.addTerm(node); }
 };
 
 
index 437be9bf4cdb0cabdce40dfe026e572148e83273..5d376ea50e33433d3b6ed90b16ecacc81345a740 100644 (file)
@@ -41,8 +41,11 @@ Base::Base(uint32_t size)
 
   
 void Base::sliceAt(Index index) {
+  if (index == d_size)
+    return; 
+  Assert(index < d_size); 
   Index vector_index = index / 32;
-  Assert (vector_index < d_size); 
+  Assert (vector_index < d_repr.size()); 
   Index int_index = index % 32;
   uint32_t bit_mask = utils::pow2(int_index); 
   d_repr[vector_index] = d_repr[vector_index] | bit_mask; 
@@ -184,6 +187,15 @@ TermId UnionFind::addNode(Index bitwidth) {
 }
 
 TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) {
+  if (isExtractTerm(topLevel)) {
+    ExtractTerm top = getExtractTerm(topLevel);
+    Index top_high = top.high;
+    Index top_low = top.low;
+    Assert (top_high - top_low + 1 > high);
+    high += top_low; 
+    low += top_low;
+    topLevel = top.id; 
+  }
   ExtractTerm extract(topLevel, high, low);
   if (d_extractToId.find(extract) != d_extractToId.end()) {
     return d_extractToId[extract]; 
@@ -292,13 +304,13 @@ void UnionFind::split(TermId id, Index i) {
   Assert (i < getBitwidth(id));
   if (!hasChildren(id)) {
     // first time we split this term 
-    TermId bottom_id = addExtract(getTopLevel(id), i - 1, 0);
-    TermId top_id = addExtract(getTopLevel(id), getBitwidth(id) - 1, i);
+    TermId bottom_id = addExtract(id, i - 1, 0);
+    TermId top_id = addExtract(id, getBitwidth(id) - 1, i);
     setChildren(id, top_id, bottom_id);
     recordOperation(UnionFind::SPLIT, id);
     
     if (d_slicer->termInEqualityEngine(id)) {
-      d_slicer->enqueueSplit(id, i); 
+      d_slicer->enqueueSplit(id, i, top_id, bottom_id); 
     }
   } else {
     Index cut = getCutPoint(id); 
@@ -310,13 +322,13 @@ void UnionFind::split(TermId id, Index i) {
   ++(d_statistics.d_numSplits);
 }
 
-TermId UnionFind::getTopLevel(TermId id) const {
-  __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> >::const_iterator it = d_idToExtract.find(id); 
-  if (it != d_idToExtract.end()) {
-    return (*it).second.id; 
-  }
-  return id; 
-}
+// TermId UnionFind::getTopLevel(TermId id) const {
+//   __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> >::const_iterator it = d_idToExtract.find(id); 
+//   if (it != d_idToExtract.end()) {
+//     return (*it).second.id; 
+//   }
+//   return id; 
+// }
 
 void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) {
   nf.clear(); 
@@ -576,7 +588,7 @@ ExtractTerm Slicer::registerTerm(TNode node) {
   if (node.getKind() == kind::BITVECTOR_EXTRACT) {
     n = node[0];
     high = utils::getExtractHigh(node);
-    low = utils::getExtractLow(node); 
+    low = utils::getExtractLow(node);
   }
   if (d_nodeToId.find(n) == d_nodeToId.end()) {
     TermId id = d_unionFind.addNode(utils::getSize(n)); 
@@ -584,6 +596,7 @@ ExtractTerm Slicer::registerTerm(TNode node) {
     d_idToNode[id] = n; 
   }
   TermId id = d_nodeToId[n];
+  d_unionFind.addExtract(id, high, low);
   ExtractTerm res(id, high, low); 
   Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl;
   return res; 
@@ -631,7 +644,7 @@ void Slicer::registerEquality(TNode eq) {
   }
 }
 
-void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation) {
+void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation) {
   Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl;
   
   Index high = utils::getSize(node) - 1;
@@ -655,16 +668,8 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::ve
     explanation.push_back(exp); 
   }
   
-  // construct actual extract nodes
-  Index size = utils::getSize(node);
-  Index current_low = size - 1;
-  Index current_high = size - 1;
-  
   for (int i = nf.decomp.size() - 1; i>=0 ; --i) {
-    Index current_size = d_unionFind.getBitwidth(nf.decomp[i]);
-    current_low = current_low - current_size; 
-    Node current = Rewriter::rewrite(utils::mkExtract(node, current_high, current_low+1));
-    current_high -= current_size;
+    Node current = getNode(nf.decomp[i]); 
     decomp.push_back(current); 
   }
   
@@ -763,17 +768,16 @@ bool Slicer::hasNode(TermId id) const {
 }
 
 Node Slicer::getNode(TermId id) const {
-  // if it was an extract
-  if (d_unionFind.isExtractTerm(id)) {
-    ExtractTerm extract = d_unionFind.getExtractTerm(id);
-    Assert (hasNode(extract.id)); 
-    TNode node = d_idToNode.find(extract.id)->second;
-    Node ex = utils::mkExtract(node, extract.high, extract.low);
-    return ex; 
+  if (hasNode(id)) {
+    return d_idToNode.find(id)->second; 
   }
-  // otherwise must be a top-level term 
-  Assert (hasNode(id));
-  return (d_idToNode.find(id))->second; 
+  // otherwise must be an extract 
+  Assert (d_unionFind.isExtractTerm(id)); 
+  ExtractTerm extract = d_unionFind.getExtractTerm(id);
+  Assert (hasNode(extract.id)); 
+  TNode node = d_idToNode.find(extract.id)->second;
+  Node ex = utils::mkExtract(node, extract.high, extract.low);
+  return ex; 
 }
 
 bool Slicer::termInEqualityEngine(TermId id) {
@@ -781,13 +785,18 @@ bool Slicer::termInEqualityEngine(TermId id) {
   return d_coreSolver->hasTerm(node); 
 }
 
-void Slicer::enqueueSplit(TermId id, Index i) {
+void Slicer::enqueueSplit(TermId id, Index i, TermId top_id, TermId bottom_id) {
   Node node = getNode(id);
   Node bottom = Rewriter::rewrite(utils::mkExtract(node, i -1 , 0));
   Node top = Rewriter::rewrite(utils::mkExtract(node, utils::getSize(node) - 1, i));
+  // must add terms to equality engine so we get notified when they get split more
+  d_coreSolver->addTermToEqualityEngine(bottom);
+  d_coreSolver->addTermToEqualityEngine(top);
+
   Node eq = utils::mkNode(kind::EQUAL, node, utils::mkConcat(top, bottom));
   d_newSplits.push_back(eq);
-  Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl; 
+  Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl;
+  Debug("bv-slicer") << "                     " << id << "=" << top_id << " " << bottom_id << endl; 
 }
 
 void Slicer::getNewSplits(std::vector<Node>& splits) {
index f63cf7284c0b5b5e6a3ac02420aa8e19132d7e26..ab2d5e88f8648c7ed146dc303d7898348c3048e4 100644 (file)
@@ -224,7 +224,7 @@ class UnionFind : public context::ContextNotifyObj {
     Assert (id < d_nodes.size());
     return d_nodes[id].hasChildren(); 
   }
-  TermId getTopLevel(TermId id) const;
+  // TermId getTopLevel(TermId id) const;
   
   /// setter methods for the internal nodes
   void setRepr(TermId id, TermId new_repr, ExplanationId reason) {
@@ -338,7 +338,7 @@ public:
       d_coreSolver(coreSolver)
   {}
   
-  void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation);
+  void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation);
   void registerEquality(TNode eq);
   ExtractTerm registerTerm(TNode node);
   void processEquality(TNode eq);
@@ -354,7 +354,7 @@ public:
   ExplanationId getExplanationId(TNode reason) const;
   
   bool termInEqualityEngine(TermId id); 
-  void enqueueSplit(TermId id, Index i);
+  void enqueueSplit(TermId id, Index i, TermId top, TermId bottom);
   void getNewSplits(std::vector<Node>& splits);
   static void splitEqualities(TNode node, std::vector<Node>& equalities);
   static unsigned d_numAddedEqualities;