fixed some explanation problems for the core theory; still slow
authorlianah <lianahady@gmail.com>
Sat, 23 Mar 2013 17:40:29 +0000 (13:40 -0400)
committerlianah <lianahady@gmail.com>
Sat, 23 Mar 2013 17:40:29 +0000 (13:40 -0400)
src/theory/bv/bv_subtheory_core.cpp
src/theory/bv/bv_subtheory_core.h
src/theory/bv/slicer.cpp
src/theory/bv/slicer.h
src/theory/bv/theory_bv_utils.h

index 2af0e47b8d0be4e4744fdd77ee37f4974d91d186..6f5fd4119d7539211920b64cc924ddc7fd9498d7 100644 (file)
@@ -102,7 +102,7 @@ void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) {
   }
 }
 
-Node CoreSolver::getBaseDecomposition(TNode a, std::vector<Node>& explanation) {
+Node CoreSolver::getBaseDecomposition(TNode a, std::vector<TNode>& explanation) {
   std::vector<Node> a_decomp;
   d_slicer->getBaseDecomposition(a, a_decomp, explanation);
   Node new_a = utils::mkConcat(a_decomp);
@@ -122,28 +122,35 @@ bool CoreSolver::decomposeFact(TNode fact) {
     TNode b = fact[1];
 
     d_slicer->processEquality(fact); 
-    std::vector<Node> explanation; 
-    Node new_a = getBaseDecomposition(a, explanation);
-    Node new_b = getBaseDecomposition(b, explanation);
+    std::vector<TNode> explanation_a; 
+    Node new_a = getBaseDecomposition(a, explanation_a);
+    Node reason_a = mkAnd(explanation_a);
+    d_reasons.insert(reason_a);
+    
+    std::vector<TNode> explanation_b; 
+    Node new_b = getBaseDecomposition(b, explanation_b);
+    Node reason_b = mkAnd(explanation_b);
+    d_reasons.insert(reason_b);
 
+    std::vector<Node> explanation; 
     explanation.push_back(fact);
+    explanation.insert(explanation.end(), explanation_a.begin(), explanation_a.end());
+    explanation.insert(explanation.end(), explanation_b.begin(), explanation_b.end());
+    
     Node reason = utils::mkAnd(explanation); 
     d_reasons.insert(reason);
     
     Assert (utils::getSize(new_a) == utils::getSize(new_b) &&
             utils::getSize(new_a) == utils::getSize(a)); 
-    // FIXME: do we still need to assert these? 
+
     NodeManager* nm = NodeManager::currentNM();
     Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a);
     Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b);
 
-    d_reasons.insert(a_eq_new_a);
-    d_reasons.insert(b_eq_new_b); 
-    
     bool ok = true; 
-    ok = assertFactToEqualityEngine(a_eq_new_a, utils::mkTrue());
+    ok = assertFactToEqualityEngine(a_eq_new_a, reason_a);
     if (!ok) return false; 
-    ok = assertFactToEqualityEngine(b_eq_new_b, utils::mkTrue());
+    ok = assertFactToEqualityEngine(b_eq_new_b, reason_a);
     if (!ok) return false; 
     // assert the individual equalities as well
     //    a_i == b_i
@@ -152,6 +159,7 @@ bool CoreSolver::decomposeFact(TNode fact) {
       Assert (new_a.getNumChildren() == new_b.getNumChildren()); 
       for (unsigned i = 0; i < new_a.getNumChildren(); ++i) {
         Node eq_i = nm->mkNode(kind::EQUAL, new_a[i], new_b[i]);
+        // this reason is not very precise!!
         ok = assertFactToEqualityEngine(eq_i, reason);
         d_reasons.insert(eq_i); 
         if (!ok) return false;
@@ -164,15 +172,16 @@ bool CoreSolver::decomposeFact(TNode fact) {
     d_slicer->processEquality(fact[0]);
     TNode a = fact[0][0];
     TNode b = fact[0][1];
-    std::vector<Node> explanation_a; 
+    std::vector<TNode> explanation_a; 
     Node new_a = getBaseDecomposition(a, explanation_a);
     Node reason_a = explanation_a.empty()? mkTrue() : mkAnd(explanation_a);
     assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, a, new_a), reason_a);
 
-    std::vector<Node> explanation_b; 
+    std::vector<TNode> explanation_b; 
     Node new_b = getBaseDecomposition(b, explanation_b);
     Node reason_b = explanation_b.empty()? mkTrue() : mkAnd(explanation_b);
     assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, b, new_b), reason_b);
+
     d_reasons.insert(reason_a);
     d_reasons.insert(reason_b); 
   }
@@ -279,13 +288,16 @@ void CoreSolver::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) {
 bool CoreSolver::storePropagation(TNode literal) {
   return d_bv->storePropagation(literal, SUB_CORE);
 }
-  
+
 void CoreSolver::conflict(TNode a, TNode b) {
   std::vector<TNode> assumptions;
   d_equalityEngine.explainEquality(a, b, true, assumptions);
-  d_bv->setConflict(mkAnd(assumptions));
+  Node conflict = flattenAnd(assumptions);
+  d_bv->setConflict(conflict);
 }
 
+
+
 void CoreSolver::collectModelInfo(TheoryModel* m) {
   if (Debug.isOn("bitvector-model")) {
     context::CDQueue<Node>::const_iterator it = d_assertionQueue.begin();
index 4f2d7a27969a301c6222cda9c9a5ad60e025806d..868f3754f370583e3a1d084f3a4bf773977a2b78 100644 (file)
@@ -67,7 +67,7 @@ class CoreSolver : public SubtheorySolver {
   context::CDHashSet<Node, NodeHashFunction> d_reasons; 
   bool assertFactToEqualityEngine(TNode fact, TNode reason);  
   bool decomposeFact(TNode fact);
-  Node getBaseDecomposition(TNode a, std::vector<Node>& explanation);
+  Node getBaseDecomposition(TNode a, std::vector<TNode>& explanation);
 public: 
   CoreSolver(context::Context* c, TheoryBV* bv);
   ~CoreSolver();
index 5d376ea50e33433d3b6ed90b16ecacc81345a740..b24702635cfff705f5550e3615406db3de2d1ea2 100644 (file)
@@ -156,11 +156,11 @@ std::string NormalForm::debugPrint(const UnionFind& uf) const {
   return os.str(); 
 }
 /**
- * UnionFind::Node
+ * UnionFind::EqualityNode
  * 
  */
 
-std::string UnionFind::Node::debugPrint() const {
+std::string UnionFind::EqualityNode::debugPrint() const {
   ostringstream os;
   os << "Repr " << d_edge.repr << " ["<< d_bitwidth << "] ";
   os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl; 
@@ -172,41 +172,80 @@ std::string UnionFind::Node::debugPrint() const {
  * UnionFind
  * 
  */
-TermId UnionFind::addNode(Index bitwidth) {
+
+TermId UnionFind::registerTopLevelTerm(Index bitwidth) {
+  TermId id = mkEqualityNode(bitwidth);
+  d_topLevelIds.insert(id);
+  return id; 
+}
+
+TermId UnionFind::mkEqualityNode(Index bitwidth) {
   Assert (bitwidth > 0); 
-  Node node(bitwidth);
-  d_nodes.push_back(node);
+  EqualityNode node(bitwidth);
+  d_equalityNodes.push_back(node);
   
   ++(d_statistics.d_numNodes);
   
-  TermId id = d_nodes.size() - 1; 
+  TermId id = d_equalityNodes.size() - 1; 
   //  d_representatives.insert(id);
   ++(d_statistics.d_numRepresentatives); 
   Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl;
   return id; 
 }
-
-TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) {
-  if (isExtractTerm(topLevel)) {
-    ExtractTerm top = getExtractTerm(topLevel);
-    Index top_high = top.high;
-    Index top_low = top.low;
-    Assert (top_high - top_low + 1 > high);
-    high += top_low; 
-    low += top_low;
-    topLevel = top.id; 
+/** 
+ * Create an extract term making sure there are no nested extracts. 
+ * 
+ * @param id 
+ * @param high 
+ * @param low 
+ * 
+ * @return 
+ */
+ExtractTerm UnionFind::mkExtractTerm(TermId id, Index high, Index low) {
+  if (d_topLevelIds.find(id) != d_topLevelIds.end()) {
+    return ExtractTerm(id, high, low); 
   }
-  ExtractTerm extract(topLevel, high, low);
+  Assert (isExtractTerm(id)); 
+  ExtractTerm top = getExtractTerm(id);
+  Assert (d_topLevelIds.find(top.id) != d_topLevelIds.end());
+  
+  Index top_high = top.high;
+  Index top_low = top.low;
+  Assert (top_high - top_low + 1 > high);
+  high += top_low; 
+  low += top_low;
+  id = top.id; 
+  return ExtractTerm(id, high, low); 
+}
+
+/** 
+ * Associate the given extract term with the given id. 
+ * 
+ * @param id 
+ * @param extract 
+ */
+void UnionFind::storeExtractTerm(TermId id, const ExtractTerm& extract) {
   if (d_extractToId.find(extract) != d_extractToId.end()) {
-    return d_extractToId[extract]; 
+    Assert (d_extractToId[extract] == id); 
+    return; 
   }
-
-  Assert (high >= low); 
-  
-  TermId id = addNode(high - low + 1); 
+  Debug("bv-slicer") << "UnionFind::storeExtract " << extract.debugPrint() << " => id" << id << "\n";  
   d_idToExtract[id] = extract;
   d_extractToId[extract] = id; 
-  return id; 
+ }
+
+TermId UnionFind::addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low) {
+  ExtractTerm extract(id, high, low);
+  if (d_extractToId.find(extract) != d_extractToId.end()) {
+    // if the extract already exists we don't need to make a new node
+    TermId extract_id = d_extractToId[extract];
+    Assert (extract_id < d_equalityNodes.size());
+    return extract_id; 
+  }
+  // otherwise make an equality node for it and store the extract
+  TermId node_id = mkEqualityNode(bitwidth);
+  storeExtractTerm(node_id, extract);
+  return node_id; 
 }
 
 /** 
@@ -215,7 +254,10 @@ TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) {
  * @param t1 
  * @param t2 
  */
-void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason) {
+void UnionFind::unionTerms(TermId id1, TermId id2, TermId reason) {
+  const ExtractTerm& t1 = getExtractTerm(id1);
+  const ExtractTerm& t2 = getExtractTerm(id2);
+  
   Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n"
                      << "                      " << t2.debugPrint() << "\n"
                      << " with reason " << reason << endl;
@@ -294,7 +336,7 @@ TermId UnionFind::findWithExplanation(TermId id, std::vector<ExplanationId>& exp
 void UnionFind::split(TermId id, Index i) {
   Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl;
   id = find(id); 
-  Debug("bv-slicer-uf") << "   node: " << d_nodes[id].debugPrint() << endl;
+  Debug("bv-slicer-uf") << "   node: " << d_equalityNodes[id].debugPrint() << endl;
 
   if (i == 0 || i == getBitwidth(id)) {
     // nothing to do 
@@ -303,9 +345,15 @@ void UnionFind::split(TermId id, Index i) {
 
   Assert (i < getBitwidth(id));
   if (!hasChildren(id)) {
-    // first time we split this term 
-    TermId bottom_id = addExtract(id, i - 1, 0);
-    TermId top_id = addExtract(id, getBitwidth(id) - 1, i);
+    // first time we split this term
+    ExtractTerm bottom_extract = mkExtractTerm(id, i-1, 0);
+    ExtractTerm top_extract = mkExtractTerm(id, getBitwidth(id) - 1, i);
+    
+    TermId bottom_id = extractHasId(bottom_extract)? getExtractId(bottom_extract) : mkEqualityNode(i); 
+    TermId top_id = extractHasId(top_extract)? getExtractId(top_extract) : mkEqualityNode(getBitwidth(id) - i);
+    storeExtractTerm(bottom_id, bottom_extract);
+    storeExtractTerm(top_id, top_extract); 
+    
     setChildren(id, top_id, bottom_id);
     recordOperation(UnionFind::SPLIT, id);
     
@@ -471,7 +519,10 @@ void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposit
 
 }
 
-void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) {
+void UnionFind::alignSlicings(TermId id1, TermId id2) {
+  const ExtractTerm& term1 = getExtractTerm(id1);
+  const ExtractTerm& term2 = getExtractTerm(id2);
+  
   Debug("bv-slicer") << "UnionFind::alignSlicings " << term1.debugPrint() << endl;
   Debug("bv-slicer") << "                         " << term2.debugPrint() << endl;
   NormalForm nf1(term1.getBitwidth());
@@ -519,15 +570,18 @@ void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2
     }
   } while (changed); 
 }
+
+
 /** 
  * Given an extract term a[i:j] makes sure a is sliced
  * at indices i and j. 
  * 
  * @param term 
  */
-void UnionFind::ensureSlicing(const ExtractTerm& term) {
+void UnionFind::ensureSlicing(TermId t) {
+  ExtractTerm term = getExtractTerm(t); 
   //Debug("bv-slicer") << "Slicer::ensureSlicing " << term.debugPrint() << endl;
-  TermId id = find(term.id);
+  TermId id = term.id; 
   split(id, term.high + 1);
   split(id, term.low);
 }
@@ -576,30 +630,69 @@ void UnionFind::getBase(TermId id, Base& base, Index offset) {
   getBase(id0, base, offset); 
 }
 
+/// getter methods for the internal nodes
+TermId UnionFind::getRepr(TermId id)  const {
+  Assert (id < d_equalityNodes.size());
+  return d_equalityNodes[id].getRepr(); 
+}
+ExplanationId UnionFind::getReason(TermId id) const {
+  Assert (id < d_equalityNodes.size());
+  return d_equalityNodes[id].getReason(); 
+}
+TermId UnionFind::getChild(TermId id, Index i) const {
+  Assert (id < d_equalityNodes.size());
+  return d_equalityNodes[id].getChild(i); 
+}
+Index UnionFind::getCutPoint(TermId id) const {
+  return getBitwidth(getChild(id, 0)); 
+}
+bool UnionFind::hasChildren(TermId id) const {
+  Assert (id < d_equalityNodes.size());
+  return d_equalityNodes[id].hasChildren(); 
+}
+  
+/// setter methods for the internal nodes
+void UnionFind::setRepr(TermId id, TermId new_repr, ExplanationId reason) {
+  Assert (id < d_equalityNodes.size());
+  d_equalityNodes[id].setRepr(new_repr, reason); 
+}
+void UnionFind::setChildren(TermId id, TermId ch1, TermId ch0) {
+  Assert ((ch1 == UndefinedId && ch0 == UndefinedId) ||
+          (id < d_equalityNodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0)));
+  d_equalityNodes[id].setChildren(ch1, ch0); 
+}
+
 
 /**
  * Slicer
  * 
  */
 
-ExtractTerm Slicer::registerTerm(TNode node) {
-  Index low = 0, high = utils::getSize(node) - 1; 
-  TNode n = node; 
+TermId Slicer::registerTerm(TNode node) {
   if (node.getKind() == kind::BITVECTOR_EXTRACT) {
-    n = node[0];
-    high = utils::getExtractHigh(node);
-    low = utils::getExtractLow(node);
-  }
-  if (d_nodeToId.find(n) == d_nodeToId.end()) {
-    TermId id = d_unionFind.addNode(utils::getSize(n)); 
-    d_nodeToId[n] = id;
-    d_idToNode[id] = n; 
+    TNode n = node[0];
+    TermId top_id = registerTopLevelTerm(n);
+    Index high = utils::getExtractHigh(node);
+    Index low = utils::getExtractLow(node);
+    TermId id = d_unionFind.addEqualityNode(utils::getSize(node), top_id, high, low); 
+    return id; 
+  }
+  TermId id = registerTopLevelTerm(node);
+  return id; 
+}
+
+TermId Slicer::registerTopLevelTerm(TNode node) {
+  Assert (node.getKind() != kind::BITVECTOR_EXTRACT ||
+          node.getKind() != kind::BITVECTOR_CONCAT);
+  
+  if (d_nodeToId.find(node) == d_nodeToId.end()) {
+    TermId id = d_unionFind.registerTopLevelTerm(utils::getSize(node)); 
+    d_idToNode[id] = node; 
+    d_nodeToId[node] = id;
+    Debug("bv-slicer") << "Slicer::registerTopLevelTerm " << node << " => id" << id << endl;
+    return id; 
   }
-  TermId id = d_nodeToId[n];
-  d_unionFind.addExtract(id, high, low);
-  ExtractTerm res(id, high, low); 
-  Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl;
-  return res; 
+  return d_nodeToId[node]; 
 }
 
 void Slicer::processEquality(TNode eq) {
@@ -609,42 +702,38 @@ void Slicer::processEquality(TNode eq) {
   Assert (eq.getKind() == kind::EQUAL);
   TNode a = eq[0];
   TNode b = eq[1];
-  ExtractTerm a_ex= registerTerm(a);
-  ExtractTerm b_ex= registerTerm(b);
+  TermId a_id = registerTerm(a);
+  TermId b_id = registerTerm(b);
   
-  d_unionFind.ensureSlicing(a_ex);
-  d_unionFind.ensureSlicing(b_ex);
+  d_unionFind.ensureSlicing(a_id);
+  d_unionFind.ensureSlicing(b_id);
   
-  d_unionFind.alignSlicings(a_ex, b_ex);
+  d_unionFind.alignSlicings(a_id, b_id);
 
-  Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl;
-  Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl;
-  Debug("bv-slicer") << "Slicer::processEquality done. " << endl;
+  // Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl;
+  // Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl;
+  // Debug("bv-slicer") << "Slicer::processEquality done. " << endl;
 }
 
 void Slicer::assertEquality(TNode eq) {
   Assert (eq.getKind() == kind::EQUAL);
-  ExtractTerm a = registerTerm(eq[0]);
-  ExtractTerm b = registerTerm(eq[1]);
+  TermId a = registerTerm(eq[0]);
+  TermId b = registerTerm(eq[1]);
   ExplanationId reason = getExplanationId(eq); 
   d_unionFind.unionTerms(a, b, reason); 
 }
 
-TermId Slicer::getId(TNode node) const {
-  __gnu_cxx::hash_map<Node, TermId, NodeHashFunction >::const_iterator it = d_nodeToId.find(node);
-  Assert (it != d_nodeToId.end());
-  return it->second; 
-}
 
 void Slicer::registerEquality(TNode eq) {
   if (d_explanationToId.find(eq) == d_explanationToId.end()) {
     ExplanationId id = d_explanations.size(); 
     d_explanations.push_back(eq);
-    d_explanationToId[eq] = id; 
+    d_explanationToId[eq] = id;
+    Debug("bv-slicer-explanation") << "Slicer::registerEquality " << eq << " => id"<< id << "\n";  
   }
 }
 
-void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation) {
+void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation) {
   Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl;
   
   Index high = utils::getSize(node) - 1;
@@ -672,13 +761,18 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::ve
     Node current = getNode(nf.decomp[i]); 
     decomp.push_back(current); 
   }
-  
-
-  Debug("bv-slicer") << "as [";
-  for (unsigned i = 0; i < decomp.size(); ++i) {
-    Debug("bv-slicer") << decomp[i] <<" "; 
+  if (Debug.isOn("bv-slicer-explanation")) {
+    Debug("bv-slicer-explanation") << "Slicer::getBaseDecomposition for " << node << "\n"
+                                   << "as ";
+    for (unsigned i = 0; i < decomp.size(); ++i) {
+      Debug("bv-slicer-explanation") << decomp[i] <<" " ; 
+    }
+    Debug("bv-slicer-explanation") << "\n Explanation : \n";
+    for (unsigned i = 0; i < explanation.size(); ++i) {
+      Debug("bv-slicer-explanation") << "   " << explanation[i] << "\n"; 
+    }
+    
   }
-  Debug("bv-slicer") << "]" << endl;
 
 }
 
@@ -754,6 +848,10 @@ void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) {
 
 
 ExtractTerm UnionFind::getExtractTerm(TermId id) const {
+  if (d_topLevelIds.find(id) != d_topLevelIds.end()) {
+    // if it's a top level term so we don't have an extract stored for it
+    return ExtractTerm(id, getBitwidth(id) - 1, 0); 
+  }
   Assert (isExtractTerm(id));
   
   return (d_idToExtract.find(id))->second; 
@@ -763,19 +861,21 @@ bool UnionFind::isExtractTerm(TermId id) const {
   return d_idToExtract.find(id) != d_idToExtract.end();
 }
 
-bool Slicer::hasNode(TermId id) const {
+bool Slicer::isTopLevelNode(TermId id) const {
   return d_idToNode.find(id) != d_idToNode.end(); 
 }
 
 Node Slicer::getNode(TermId id) const {
-  if (hasNode(id)) {
+  if (isTopLevelNode(id)) {
     return d_idToNode.find(id)->second; 
   }
-  // otherwise must be an extract 
   Assert (d_unionFind.isExtractTerm(id)); 
-  ExtractTerm extract = d_unionFind.getExtractTerm(id);
-  Assert (hasNode(extract.id)); 
+  const ExtractTerm& extract = d_unionFind.getExtractTerm(id);
+  Assert (isTopLevelNode(extract.id)); 
   TNode node = d_idToNode.find(extract.id)->second;
+  if (extract.high == utils::getSize(node) -1 && extract.low == 0) {
+    return node; 
+  }
   Node ex = utils::mkExtract(node, extract.high, extract.low);
   return ex; 
 }
index ab2d5e88f8648c7ed146dc303d7898348c3048e4..c46ef99edb57dce885769e628c7d0c4edab486de 100644 (file)
@@ -161,13 +161,13 @@ class UnionFind : public context::ContextNotifyObj {
     {}
   }; 
     
-  class Node {
+  class EqualityNode {
     Index d_bitwidth;  
     TermId d_ch1, d_ch0; // the ids of the two children if they exist
     ReprEdge d_edge;     // points to the representative and stores the explanation
     
   public:
-    Node(Index b)
+    EqualityNode(Index b)
   : d_bitwidth(b),
     d_ch1(UndefinedId),
     d_ch0(UndefinedId), 
@@ -189,54 +189,36 @@ class UnionFind : public context::ContextNotifyObj {
       d_edge.reason = reason; 
     }
     void setChildren(TermId ch1, TermId ch0) {
-      // Assert (d_repr == UndefinedId && !hasChildren());
       d_ch1 = ch1;
       d_ch0 = ch0; 
     }
     std::string debugPrint() const;
   };
+
+  // the equality nodes in the union find
+  std::vector<EqualityNode> d_equalityNodes;
+  
+  /// getter methods for the internal nodes
+  TermId getRepr(TermId id)  const;
+  ExplanationId getReason(TermId id) const;
+  TermId getChild(TermId id, Index i) const;
+  Index getCutPoint(TermId id) const;
+  bool hasChildren(TermId id) const;
   
-  /// map from TermId to the nodes that represent them 
-  std::vector<Node> d_nodes;
+  /// setter methods for the internal nodes
+  void setRepr(TermId id, TermId new_repr, ExplanationId reason);
+  void setChildren(TermId id, TermId ch1, TermId ch0); 
+
+  // the mappings between ExtractTerms and ids
   __gnu_cxx::hash_map<TermId, ExtractTerm, __gnu_cxx::hash<TermId> > d_idToExtract;
   __gnu_cxx::hash_map<ExtractTerm, TermId, ExtractTermHashFunction > d_extractToId;
+
+  __gnu_cxx::hash_set<TermId> d_topLevelIds;
   
   void getDecomposition(const ExtractTerm& term, Decomposition& decomp);
   void getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, std::vector<ExplanationId>& explanation);
   void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common);
-  /// getter methods for the internal nodes
-  TermId getRepr(TermId id)  const {
-    Assert (id < d_nodes.size());
-    return d_nodes[id].getRepr(); 
-  }
-  ExplanationId getReason(TermId id) const {
-    Assert (id < d_nodes.size());
-    return d_nodes[id].getReason(); 
-  }
-  TermId getChild(TermId id, Index i) const {
-    Assert (id < d_nodes.size());
-    return d_nodes[id].getChild(i); 
-  }
-  Index getCutPoint(TermId id) const {
-    return getBitwidth(getChild(id, 0)); 
-  }
-  bool hasChildren(TermId id) const {
-    Assert (id < d_nodes.size());
-    return d_nodes[id].hasChildren(); 
-  }
-  // TermId getTopLevel(TermId id) const;
   
-  /// setter methods for the internal nodes
-  void setRepr(TermId id, TermId new_repr, ExplanationId reason) {
-    Assert (id < d_nodes.size());
-    d_nodes[id].setRepr(new_repr, reason); 
-  }
-  void setChildren(TermId id, TermId ch1, TermId ch0) {
-    Assert ((ch1 == UndefinedId && ch0 == UndefinedId) ||
-            (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0)));
-    d_nodes[id].setChildren(ch1, ch0); 
-  }
-
   /* Backtracking mechanisms */
 
   enum OperationKind {
@@ -271,36 +253,44 @@ class UnionFind : public context::ContextNotifyObj {
     ~Statistics();
   };
   Statistics d_statistics;
-  Slicer* d_slicer; 
+  Slicer* d_slicer;
+  TermId d_termIdCount; 
+
+  TermId mkEqualityNode(Index bitwidth);
+  ExtractTerm mkExtractTerm(TermId id, Index high, Index low); 
+  void storeExtractTerm(Index id, const ExtractTerm& term);
+  ExtractTerm getExtractTerm(TermId id) const;
+  bool extractHasId(const ExtractTerm& ex) const { return d_extractToId.find(ex) != d_extractToId.end(); }
+  TermId getExtractId(const ExtractTerm& ex) const {Assert (extractHasId(ex)); return d_extractToId.find(ex)->second; }
+  bool isExtractTerm(TermId id) const; 
 public:
   UnionFind(context::Context* ctx, Slicer* slicer)
     : ContextNotifyObj(ctx), 
-      d_nodes(),
+      d_equalityNodes(),
       d_idToExtract(),
-      d_extractToId(), 
+      d_extractToId(),
+      d_topLevelIds(),
       d_undoStack(),
       d_undoStackIndex(ctx),
       d_statistics(),
-      d_slicer(slicer)
+      d_slicer(slicer),
+      d_termIdCount(0)
   {}
 
-  TermId addNode(Index bitwidth);
-  TermId addExtract(Index topLevel, Index high, Index low);
-  ExtractTerm getExtractTerm(TermId id) const;
-  bool isExtractTerm(TermId id) const; 
-  
-  void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason); 
+  TermId addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low);
+  TermId registerTopLevelTerm(Index bitwidth);  
+  void unionTerms(TermId id1, TermId id2, TermId reason); 
   void merge(TermId t1, TermId t2, TermId reason);
   TermId find(TermId t1);
   TermId findWithExplanation(TermId id, std::vector<ExplanationId>& explanation); 
   void split(TermId term, Index i);
   void getNormalForm(const ExtractTerm& term, NormalForm& nf);
   void getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, std::vector<ExplanationId>& explanation); 
-  void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2);
-  void ensureSlicing(const ExtractTerm& term);
+  void alignSlicings(TermId id1, TermId id2);
+  void ensureSlicing(TermId id);
   Index getBitwidth(TermId id) const {
-    Assert (id < d_nodes.size());
-    return d_nodes[id].getBitwidth(); 
+    Assert (id < d_equalityNodes.size());
+    return d_equalityNodes[id].getBitwidth(); 
   }
   void getBase(TermId id, Base& base, Index offset); 
   std::string debugPrint(TermId id);
@@ -314,17 +304,19 @@ public:
 class CoreSolver; 
 
 class Slicer {
-  __gnu_cxx::hash_map<TermId, TNode, __gnu_cxx::hash<TermId> > d_idToNode;
-  __gnu_cxx::hash_map<Node, TermId, NodeHashFunction> d_nodeToId;
-  __gnu_cxx::hash_map<Node, bool, NodeHashFunction> d_coreTermCache;
-  __gnu_cxx::hash_map<Node, ExplanationId, NodeHashFunction> d_explanationToId;
-  std::vector<Node> d_explanations; 
+  __gnu_cxx::hash_map<TermId, TNode> d_idToNode; 
+  __gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction> d_nodeToId;
+  __gnu_cxx::hash_map<TNode, bool, TNodeHashFunction> d_coreTermCache;
+  __gnu_cxx::hash_map<TNode, ExplanationId, NodeHashFunction> d_explanationToId;
+  std::vector<TNode> d_explanations; 
   UnionFind d_unionFind;
 
   context::CDQueue<Node> d_newSplits;
   context::CDO<unsigned>  d_newSplitsIndex;
   CoreSolver* d_coreSolver;
-  TermId d_termIdCount; 
+  TermId registerTopLevelTerm(TNode node); 
+  bool isTopLevelNode(TermId id) const; 
+  TermId registerTerm(TNode node);
 public:
   Slicer(context::Context* ctx, CoreSolver* coreSolver)
     : d_idToNode(),
@@ -338,16 +330,15 @@ public:
       d_coreSolver(coreSolver)
   {}
   
-  void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<Node>& explanation);
+  void getBaseDecomposition(TNode node, std::vector<Node>& decomp, std::vector<TNode>& explanation);
   void registerEquality(TNode eq);
-  ExtractTerm registerTerm(TNode node);
+
   void processEquality(TNode eq);
   void assertEquality(TNode eq);
   bool isCoreTerm (TNode node);
 
   bool hasNode(TermId id) const; 
   Node  getNode(TermId id) const;
-  TermId getId(TNode node) const;
 
   bool hasExplanation(ExplanationId id) const;
   TNode getExplanation(ExplanationId id) const;
index e5a7bbb840e278e69c0a8806a01aa9a1498d27a5..98bc8041d7c071c1f7b5a6330fa9abeecdf0c018 100644 (file)
@@ -69,28 +69,6 @@ inline Node mkVar(unsigned size) {
   return nm->mkSkolem("bv_$$", nm->mkBitVectorType(size), "is a variable created by the theory of bitvectors"); 
 }
 
-inline Node mkAnd(std::vector<TNode>& children) {
-  std::set<TNode> distinctChildren;
-  distinctChildren.insert(children.begin(), children.end());
-  
-  if (distinctChildren.size() == 0) {
-    return mkTrue();
-  }
-  
-  if (distinctChildren.size() == 1) {
-    return *children.begin();
-  }
-  
-  NodeBuilder<> conjunction(kind::AND);
-  std::set<TNode>::const_iterator it = distinctChildren.begin();
-  std::set<TNode>::const_iterator it_end = distinctChildren.end();
-  while (it != it_end) {
-    conjunction << *it;
-    ++ it;
-  }
-
-  return conjunction;
-}
 
 inline Node mkSortedNode(Kind kind, std::vector<Node>& children) {
   Assert (kind == kind::BITVECTOR_AND ||
@@ -155,14 +133,6 @@ inline Node mkXor(TNode node1, TNode node2) {
 }
 
 
-inline Node mkAnd(std::vector<Node>& children) {
-  if(children.size() > 1) {
-    return NodeManager::currentNM()->mkNode(kind::AND, children);
-  } else {
-    return children[0];
-  }
-}
-
 inline Node mkExtract(TNode node, unsigned high, unsigned low) {
   Node extractOp = NodeManager::currentNM()->mkConst<BitVectorExtract>(BitVectorExtract(high, low));
   std::vector<Node> children;
@@ -268,7 +238,6 @@ inline Node mkConjunction(const std::set<TNode> nodes) {
   return conjunction;
 }
 
-
 inline unsigned isPow2Const(TNode node) {
   if (node.getKind() != kind::CONST_BITVECTOR) {
     return false; 
@@ -278,6 +247,83 @@ inline unsigned isPow2Const(TNode node) {
   return bv.isPow2(); 
 }
 
+typedef __gnu_cxx::hash_set<TNode, TNodeHashFunction> TNodeSet;
+
+inline Node mkAnd(const std::vector<TNode>& conjunctions) {
+  std::set<TNode> all;
+  all.insert(conjunctions.begin(), conjunctions.end());
+
+  if (all.size() == 0) {
+    return mkTrue(); 
+  }
+  
+  if (all.size() == 1) {
+    // All the same, or just one
+    return conjunctions[0];
+  }
+  
+
+  NodeBuilder<> conjunction(kind::AND);
+  std::set<TNode>::const_iterator it = all.begin();
+  std::set<TNode>::const_iterator it_end = all.end();
+  while (it != it_end) {
+    conjunction << *it;
+    ++ it;
+  }
+
+  return conjunction;
+}/* mkAnd() */
+
+inline Node mkAnd(const std::vector<Node>& conjunctions) {
+  std::set<TNode> all;
+  all.insert(conjunctions.begin(), conjunctions.end());
+
+  if (all.size() == 0) {
+    return mkTrue(); 
+  }
+  
+  if (all.size() == 1) {
+    // All the same, or just one
+    return conjunctions[0];
+  }
+  
+
+  NodeBuilder<> conjunction(kind::AND);
+  std::set<TNode>::const_iterator it = all.begin();
+  std::set<TNode>::const_iterator it_end = all.end();
+  while (it != it_end) {
+    conjunction << *it;
+    ++ it;
+  }
+
+  return conjunction;
+}/* mkAnd() */
+
+
+
+inline Node flattenAnd(std::vector<TNode>& queue) {
+  TNodeSet nodes;
+  while(!queue.empty()) {
+    TNode current = queue.back();
+    queue.pop_back();
+    if (current.getKind() ==  kind::AND) {
+      for (unsigned i = 0; i < current.getNumChildren(); ++i) {
+        if (nodes.count(current[i]) == 0) {
+          queue.push_back(current[i]);
+        }
+      }
+    } else {
+      nodes.insert(current); 
+    }
+  }
+  std::vector<TNode> children; 
+  for (TNodeSet::const_iterator it = nodes.begin(); it!= nodes.end(); ++it) {
+    children.push_back(*it); 
+  }
+  return mkAnd(children); 
+}
+
+
 // neeed a better name, this is not technically a ground term 
 inline bool isBVGroundTerm(TNode node) {
   if (node.getNumChildren() == 0) {
@@ -356,27 +402,7 @@ inline Node mkConjunction(const std::vector<TNode>& nodes) {
 }
 
 
-inline Node mkAnd(const std::vector<TNode>& conjunctions) {
-  Assert(conjunctions.size() > 0);
-
-  std::set<TNode> all;
-  all.insert(conjunctions.begin(), conjunctions.end());
 
-  if (all.size() == 1) {
-    // All the same, or just one
-    return conjunctions[0];
-  }
-
-  NodeBuilder<> conjunction(kind::AND);
-  std::set<TNode>::const_iterator it = all.begin();
-  std::set<TNode>::const_iterator it_end = all.end();
-  while (it != it_end) {
-    conjunction << *it;
-    ++ it;
-  }
-
-  return conjunction;
-}/* mkAnd() */
 
 
 // Turn a set into a string