* simplifying equality engine interface
authorDejan Jovanović <dejan.jovanovic@gmail.com>
Wed, 9 May 2012 21:25:17 +0000 (21:25 +0000)
committerDejan Jovanović <dejan.jovanovic@gmail.com>
Wed, 9 May 2012 21:25:17 +0000 (21:25 +0000)
* notifications are now through the interface subclass instead of a template
* notifications include constants being merged
* changed contextNotifyObj::notify to contextNotifyObj::contextNotifyPop so it's more descriptive and doesn't clutter methods when subclassed
* sat solver now has explicit methods to make true and false constants
* 0-level literals are removed from explanations of propagations

41 files changed:
src/context/context.cpp
src/context/context.h
src/context/stacking_map.h
src/context/stacking_vector.h
src/prop/bvminisat/bvminisat.cpp
src/prop/bvminisat/bvminisat.h
src/prop/bvminisat/core/Solver.cc
src/prop/bvminisat/core/Solver.h
src/prop/bvminisat/simp/SimpSolver.cc
src/prop/cnf_stream.cpp
src/prop/minisat/core/Solver.cc
src/prop/minisat/core/Solver.h
src/prop/minisat/minisat.cpp
src/prop/minisat/minisat.h
src/prop/minisat/simp/SimpSolver.cc
src/prop/sat_solver.h
src/smt/smt_engine.cpp
src/theory/arith/congruence_manager.cpp
src/theory/arith/congruence_manager.h
src/theory/arith/theory_arith.cpp
src/theory/arrays/theory_arrays.cpp
src/theory/arrays/theory_arrays.h
src/theory/booleans/circuit_propagator.h
src/theory/bv/theory_bv.cpp
src/theory/bv/theory_bv.h
src/theory/datatypes/union_find.cpp
src/theory/datatypes/union_find.h
src/theory/shared_terms_database.cpp
src/theory/shared_terms_database.h
src/theory/substitutions.h
src/theory/theory_engine.cpp
src/theory/theory_engine.h
src/theory/uf/Makefile.am
src/theory/uf/equality_engine.cpp [new file with mode: 0644]
src/theory/uf/equality_engine.h
src/theory/uf/equality_engine_impl.h [deleted file]
src/theory/uf/theory_uf.cpp
src/theory/uf/theory_uf.h
src/util/configuration.cpp
test/unit/context/context_black.h
test/unit/prop/cnf_stream_black.h

index abb1575d4e539a01702c91886710e7699593b89c..da60a5bc4497b90cc5734552ff04e991602d9d97 100644 (file)
@@ -80,7 +80,7 @@ void Context::pop() {
   while(pCNO != NULL) {
     // pre-store the "next" pointer in case pCNO deletes itself on notify()
     ContextNotifyObj* next = pCNO->d_pCNOnext;
-    pCNO->notify();
+    pCNO->contextNotifyPop();
     pCNO = next;
   }
 
@@ -101,7 +101,7 @@ void Context::pop() {
   while(pCNO != NULL) {
     // pre-store the "next" pointer in case pCNO deletes itself on notify()
     ContextNotifyObj* next = pCNO->d_pCNOnext;
-    pCNO->notify();
+    pCNO->contextNotifyPop();
     pCNO = next;
   }
 
index f0dbff72bff373315223b00589beed4b82df5453..165c35c58c171eff94f5bb75d0b604ca83d5e2c4 100644 (file)
@@ -658,6 +658,7 @@ public:
  * the ContextObj objects have been restored).
  */
 class ContextNotifyObj {
+
   /**
    * Context is our friend so that when the Context is deleted, any
    * remaining ContextNotifyObj can be removed from the Context list.
@@ -686,6 +687,15 @@ class ContextNotifyObj {
    */
   ContextNotifyObj**& prev() throw() { return d_ppCNOprev; }
 
+protected:
+
+  /**
+   * This is the method called to notify the object of a pop.  It must be
+   * implemented by the subclass. It is protected since context is out
+   * friend.
+   */
+  virtual void contextNotifyPop() = 0;
+
 public:
 
   /**
@@ -703,12 +713,6 @@ public:
    */
   virtual ~ContextNotifyObj() throw(AssertionException);
 
-  /**
-   * This is the method called to notify the object of a pop.  It must be
-   * implemented by the subclass.
-   */
-  virtual void notify() = 0;
-
 };/* class ContextNotifyObj */
 
 inline void ContextObj::makeCurrent() throw(AssertionException) {
index 2dec1845cd732fbb9901f4bacad00e6ebad6a3ec..ba644596ee042db48b78716898cc1e43b33ab061 100644 (file)
@@ -96,6 +96,14 @@ class StackingMap : context::ContextNotifyObj {
   /** Our current offset in the d_trace stack (context-dependent). */
   context::CDO<size_t> d_offset;
 
+protected:
+
+  /**
+   * Called by the Context when a pop occurs.  Cancels everything to the
+   * current context level.  Overrides ContextNotifyObj::contextNotifyPop().
+   */
+  void contextNotifyPop();
+
 public:
   typedef typename MapType::const_iterator const_iterator;
 
@@ -128,12 +136,6 @@ public:
    */
   void set(ArgType n, const ValueType& newValue);
 
-  /**
-   * Called by the Context when a pop occurs.  Cancels everything to the
-   * current context level.  Overrides ContextNotifyObj::notify().
-   */
-  void notify();
-
 };/* class StackingMap<> */
 
 template <class KeyType, class ValueType, class KeyHash>
@@ -146,7 +148,7 @@ void StackingMap<KeyType, ValueType, KeyHash>::set(ArgType n, const ValueType& n
 }
 
 template <class KeyType, class ValueType, class KeyHash>
-void StackingMap<KeyType, ValueType, KeyHash>::notify() {
+void StackingMap<KeyType, ValueType, KeyHash>::contextNotifyPop() {
   Trace("sm") << "SM cancelling : " << d_offset << " < " << d_trace.size() << " ?" << std::endl;
   while(d_offset < d_trace.size()) {
     std::pair<ArgType, ValueType> p = d_trace.back();
index 9987731d46ce7ff62f5c5370060dc4b75bbe3dea..ed311b952701d8db35ba5af4a77c168174eef31a 100644 (file)
@@ -82,7 +82,7 @@ public:
    * Called by the Context when a pop occurs.  Cancels everything to the
    * current context level.  Overrides ContextNotifyObj::notify().
    */
-  void notify();
+  void contextNotifyPop();
 
 };/* class StackingVector<> */
 
@@ -99,7 +99,7 @@ void StackingVector<T>::set(size_t n, const T& newValue) {
 }
 
 template <class T>
-void StackingVector<T>::notify() {
+void StackingVector<T>::contextNotifyPop() {
   Trace("sv") << "SV cancelling : " << d_offset << " < " << d_trace.size() << " ?" << std::endl;
   while(d_offset < d_trace.size()) {
     std::pair<size_t, T> p = d_trace.back();
index 124fc35f1802dcb49783509b1f80d432ee389f19..4868db6f574d2a8201d4d8ba4140f2fece1d284f 100644 (file)
@@ -73,7 +73,7 @@ SatValue BVMinisatSatSolver::assertAssumption(SatLiteral lit, bool propagate) {
   return toSatLiteralValue(d_minisat->assertAssumption(toMinisatLit(lit), propagate));
 }
 
-void BVMinisatSatSolver::notify() {
+void BVMinisatSatSolver::contextNotifyPop() {
   while (d_assertionsCount > d_assertionsRealCount) {
     popAssumption();
     d_assertionsCount --;
index cd2a2c6b90305ac17f8678ecaadbf08ac40f2e5b..60cdd1c2838206424d4752f9763636e92e4f232d 100644 (file)
@@ -54,6 +54,10 @@ private:
   context::CDO<unsigned> d_assertionsRealCount;
   context::CDO<unsigned> d_lastPropagation;
 
+protected:
+
+  void contextNotifyPop();
+
 public:
 
   BVMinisatSatSolver() :
@@ -70,10 +74,12 @@ public:
 
   SatVariable newVar(bool theoryAtom = false);
 
+  SatVariable trueVar() { return d_minisat->trueVar(); }
+  SatVariable falseVar() { return d_minisat->falseVar(); }
+
   void markUnremovable(SatLiteral lit);
 
   void interrupt();
-  void notify(); 
   
   SatValue solve();
   SatValue solve(long unsigned int&);
index e24fcac1aa64d00b17d742ca50c0bd43de07e0f6..c96b6e4b2c94342db8f827faca38ed74d2a3a1d7 100644 (file)
@@ -119,7 +119,15 @@ Solver::Solver(CVC4::context::Context* c) :
   , propagation_budget (-1)
   , asynch_interrupt   (false)
   , clause_added(false)
-{}
+{
+  // Create the constant variables
+  varTrue = newVar(true, false);
+  varFalse = newVar(false, false);
+
+  // Assert the constants
+  uncheckedEnqueue(mkLit(varTrue, false));
+  uncheckedEnqueue(mkLit(varFalse, true));
+}
 
 
 Solver::~Solver()
index c323bfe2bc8d42896e154f531461c03d5d81c6e7..ae5efd81ef1c550e748879b6af70cfef8fb74251 100644 (file)
@@ -64,6 +64,12 @@ class Solver {
     /** Cvc4 context */
     CVC4::context::Context* c;
 
+    /** True constant */
+    Var varTrue;
+
+    /** False constant */
+    Var varFalse;
+
 public:
 
     // Constructor/Destructor:
@@ -76,6 +82,9 @@ public:
     // Problem specification:
     //
     Var     newVar    (bool polarity = true, bool dvar = true); // Add a new variable with parameters specifying variable mode.
+    Var     trueVar() const { return varTrue; }
+    Var     falseVar() const { return varFalse; }
+
 
     bool    addClause (const vec<Lit>& ps);                     // Add a clause to the solver. 
     bool    addEmptyClause();                                   // Add the empty clause, making the solver contradictory.
index c8ce134102b98fe6bc73cd9da1444a3dd7dea093..59820e9e3499406030c4a786f3ced87ccea96535 100644 (file)
@@ -63,11 +63,25 @@ SimpSolver::SimpSolver(CVC4::context::Context* c) :
   , bwdsub_assigns     (0)
   , n_touched          (0)
 {
-  CVC4::StatisticsRegistry::registerStat(&total_eliminate_time); 
+    CVC4::StatisticsRegistry::registerStat(&total_eliminate_time);
     vec<Lit> dummy(1,lit_Undef);
     ca.extra_clause_field = true; // NOTE: must happen before allocating the dummy clause below.
     bwdsub_tmpunit        = ca.alloc(dummy);
     remove_satisfied      = false;
+
+    // add the initialization for all the internal variables
+    for (int i = frozen.size(); i < vardata.size(); ++ i) {
+      frozen    .push(1);
+      eliminated.push(0);
+      if (use_simplification){
+          n_occ     .push(0);
+          n_occ     .push(0);
+          occurs    .init(i);
+          touched   .push(0);
+          elim_heap .insert(i);
+      }
+    }
+
 }
 
 
index 3a4fa781a922e2bfb609fb1410bf5f12dac0a826..d18ec6e69f4f9be8e4e9c33af029e13ceca35577 100644 (file)
@@ -175,7 +175,15 @@ SatLiteral CnfStream::newLiteral(TNode node, bool theoryLiteral) {
   SatLiteral lit;
   if (!hasLiteral(node)) {
     // If no literal, we'll make one
-    lit = SatLiteral(d_satSolver->newVar(theoryLiteral));
+    if (node.getKind() == kind::CONST_BOOLEAN) {
+      if (node.getConst<bool>()) {
+        lit = SatLiteral(d_satSolver->trueVar());
+      } else {
+        lit = SatLiteral(d_satSolver->falseVar());
+      }
+    } else {
+      lit = SatLiteral(d_satSolver->newVar(theoryLiteral));
+    }
     d_translationCache[node].literal = lit;
     d_translationCache[node.notNode()].literal = ~lit;
   } else {
index 5e1b032a36db036ef4e50477411b622f6adf4892..6ee508eba1d29edabcd03c9528828481d2f81498 100644 (file)
@@ -126,6 +126,14 @@ Solver::Solver(CVC4::prop::TheoryProxy* proxy, CVC4::context::Context* context,
   , asynch_interrupt   (false)
 {
   PROOF(ProofManager::initSatProof(this);)
+
+  // Create the constant variables
+  varTrue = newVar(true, false, false);
+  varFalse = newVar(false, false, false);
+
+  // Assert the constants
+  uncheckedEnqueue(mkLit(varTrue, false));
+  uncheckedEnqueue(mkLit(varFalse, true));
 }
 
 
@@ -190,16 +198,26 @@ CRef Solver::reason(Var x) {
 
     // Compute the assertion level for this clause
     int explLevel = 0;
-    for (int i = 0; i < explanation.size(); ++ i) {
+    int i, j;
+    for (i = 0, j = 0; i < explanation.size(); ++ i) {
       int varLevel = intro_level(var(explanation[i]));
       if (varLevel > explLevel) {
         explLevel = varLevel;
       }
       Assert(value(explanation[i]) != l_Undef);
       Assert(i == 0 || trail_index(var(explanation[0])) > trail_index(var(explanation[i])));
+      // ignore zero level literals
+      if (i == 0 || level(var(explanation[i])) > 0) {
+        explanation[j++] = explanation[i];
+      }
+    }
+    explanation.shrink(i - j);
+    if (j == 1) {
+      // Add not TRUE to the clause
+      explanation.push(mkLit(varTrue, true));
     }
 
-    // Construct the reason (level 0)
+    // Construct the reason
     CRef real_reason = ca.alloc(explLevel, explanation, true);
     vardata[x] = mkVarData(real_reason, level(x), intro_level(x), trail_index(x));
     clauses_removable.push(real_reason);
index cfeb0621158ebb366126b9fde5d311d015bd3eac..e677d7220327900d856b9ac541bb29c0bf24123d 100644 (file)
@@ -65,6 +65,13 @@ protected:
 
   /** The current assertion level (user) */
   int assertionLevel; 
+
+  /** Variable representing true */
+  Var varTrue;
+
+  /** Variable representing false */
+  Var varFalse;
+
 public:
   /** Returns the current user assertion level */
   int getAssertionLevel() const { return assertionLevel; }
@@ -108,6 +115,8 @@ public:
     // Problem specification:
     //
     Var     newVar    (bool polarity = true, bool dvar = true, bool theoryAtom = false); // Add a new variable with parameters specifying variable mode.
+    Var     trueVar() const { return varTrue; }
+    Var     falseVar() const { return varFalse; }
 
     // Less than for literals in a lemma
     struct lemma_lt {
index bed30d6583888b62220ec4470b045c92868671f2..4f2a16670eff43826f59e630f2a562e783da9156 100644 (file)
@@ -121,7 +121,6 @@ SatVariable MinisatSatSolver::newVar(bool theoryAtom) {
   return d_minisat->newVar(true, true, theoryAtom);
 }
 
-
 SatValue MinisatSatSolver::solve(unsigned long& resource) {
   Trace("limit") << "SatSolver::solve(): have limit of " << resource << " conflicts" << std::endl;
   if(resource == 0) {
index 9cf75a12e39989d350261edab539b725bc13896d..19ade8ffab479292ba9bde13e197c9e06d44f8b1 100644 (file)
@@ -56,6 +56,8 @@ public:
   void addClause(SatClause& clause, bool removable);
 
   SatVariable newVar(bool theoryAtom = false);
+  SatVariable trueVar() { return d_minisat->trueVar(); }
+  SatVariable falseVar() { return d_minisat->falseVar(); }
 
   SatValue solve();
   SatValue solve(long unsigned int&);
index 2cacfbcc01c54412397149be865539e7e35c8a8d..8da3856ff7b4b7ea28499c1371beb410b7759dd7 100644 (file)
@@ -67,6 +67,19 @@ SimpSolver::SimpSolver(CVC4::prop::TheoryProxy* proxy, CVC4::context::Context* c
     ca.extra_clause_field = true; // NOTE: must happen before allocating the dummy clause below.
     bwdsub_tmpunit        = ca.alloc(0, dummy);
     remove_satisfied      = false;
+
+    // add the initialization for all the internal variables
+    for (int i = frozen.size(); i < vardata.size(); ++ i) {
+      frozen    .push(1);
+      eliminated.push(0);
+      if (use_simplification){
+          n_occ     .push(0);
+          n_occ     .push(0);
+          occurs    .init(i);
+          touched   .push(0);
+          elim_heap .insert(i);
+      }
+    }
 }
 
 
index 898709c43fcb3577e59060ef88368440303c5c00..2865f2cb5569d0f46ff501938140461e7a60e869 100644 (file)
@@ -46,6 +46,12 @@ public:
   /** Create a new boolean variable in the solver. */
   virtual SatVariable newVar(bool theoryAtom = false) = 0;
  
+  /** Create a new (or return an existing) boolean variable representing the constant true */
+  virtual SatVariable trueVar() = 0;
+
+  /** Create a new (or return an existing) boolean variable representing the constant false */
+  virtual SatVariable falseVar() = 0;
+
   /** Check the satisfiability of the added clauses */
   virtual SatValue solve() = 0;
 
index e636b91429104ed249306e23aa3c01e64957b6ab..2759f5717833c25ed64208ef6852c9a6fdc7058a 100644 (file)
@@ -304,7 +304,6 @@ SmtEngine::SmtEngine(ExprManager* em) throw(AssertionException) :
     setTimeLimit(Options::current()->cumulativeMillisecondLimit, true);
   }
 
-
   d_propEngine->assertFormula(NodeManager::currentNM()->mkConst<bool>(true));
   d_propEngine->assertFormula(NodeManager::currentNM()->mkConst<bool>(false).notNode());
 }
index 201eb08e70454d926c351cb3e345defc2c0f0feb..39468e92816ce672696cfcf406ffe8aef5d250c9 100644 (file)
@@ -1,5 +1,4 @@
 #include "theory/arith/congruence_manager.h"
-#include "theory/uf/equality_engine_impl.h"
 
 #include "theory/arith/constraint.h"
 #include "theory/arith/arith_utilities.h"
@@ -17,8 +16,7 @@ ArithCongruenceManager::ArithCongruenceManager(context::Context* c, ConstraintDa
     d_constraintDatabase(cd),
     d_setupLiteral(setup),
     d_av2Node(av2Node),
-    d_ee(d_notify, c, "theory::arith::ArithCongruenceManager"),
-    d_false(mkBoolNode(false))
+    d_ee(d_notify, c, "theory::arith::ArithCongruenceManager")
 {}
 
 ArithCongruenceManager::Statistics::Statistics():
@@ -113,7 +111,7 @@ bool ArithCongruenceManager::propagate(TNode x){
     }else{
       ++(d_statistics.d_conflicts);
 
-      Node conf = explainInternal(x);
+      Node conf = flattenAnd(explainInternal(x));
       d_conflict.set(conf);
       Debug("arith::congruenceManager") << "rewritten to false "<<x<<" with explanation "<< conf << std::endl;
       return false;
@@ -181,20 +179,11 @@ bool ArithCongruenceManager::propagate(TNode x){
 }
 
 void ArithCongruenceManager::explain(TNode literal, std::vector<TNode>& assumptions) {
-  TNode lhs, rhs;
-  switch (literal.getKind()) {
-  case kind::EQUAL:
-    lhs = literal[0];
-    rhs = literal[1];
-    break;
-  case kind::NOT:
-    lhs = literal[0];
-    rhs = d_false;
-    break;
-  default:
-    Unreachable();
+  if (literal.getKind() != kind::NOT) {
+    d_ee.explainEquality(literal[0], literal[1], true, assumptions);
+  } else {
+    d_ee.explainEquality(literal[0][0], literal[0][1], false, assumptions);
   }
-  d_ee.explainEquality(lhs, rhs, assumptions);
 }
 
 void ArithCongruenceManager::enqueueIntoNB(const std::set<TNode> s, NodeBuilder<>& nb){
@@ -258,13 +247,10 @@ void ArithCongruenceManager::assertionToEqualityEngine(bool isEquality, ArithVar
   TNode eq = d_watchedEqualities[s];
   Assert(eq.getKind() == kind::EQUAL);
 
-  TNode x = eq[0];
-  TNode y = eq[1];
-
   if(isEquality){
-    d_ee.addEquality(x, y, reason);
+    d_ee.assertEquality(eq, true, reason);
   }else{
-    d_ee.addDisequality(x, y, reason);
+    d_ee.assertEquality(eq, false, reason);
   }
 }
 
@@ -286,7 +272,7 @@ void ArithCongruenceManager::equalsConstant(Constraint c){
   Node reason = c->explainForConflict();
   d_keepAlive.push_back(reason);
 
-  d_ee.addEquality(xAsNode, asRational, reason);
+  d_ee.assertEquality(eq, true, reason);
 }
 
 void ArithCongruenceManager::equalsConstant(Constraint lb, Constraint ub){
@@ -310,7 +296,7 @@ void ArithCongruenceManager::equalsConstant(Constraint lb, Constraint ub){
   d_keepAlive.push_back(reason);
 
 
-  d_ee.addEquality(xAsNode, asRational, reason);
+  d_ee.assertEquality(eq, true, reason);
 }
 
 void ArithCongruenceManager::addSharedTerm(Node x){
index a729894986447f5e3973dd4f04f44463fba5befb..18ecbeb9d7757f5e77f89a9862b7ae78499c87a9 100644 (file)
@@ -37,24 +37,43 @@ private:
   ArithVarToNodeMap d_watchedEqualities;
 
 
-  class ArithCongruenceNotify {
+  class ArithCongruenceNotify : public eq::EqualityEngineNotify {
   private:
     ArithCongruenceManager& d_acm;
   public:
     ArithCongruenceNotify(ArithCongruenceManager& acm): d_acm(acm) {}
 
-    bool notify(TNode propagation) {
-      Debug("arith::congruences") << "ArithCongruenceNotify::notify(" << propagation << ")" << std::endl;
-      // Just forward to dm
-      return d_acm.propagate(propagation);
+    bool eqNotifyTriggerEquality(TNode equality, bool value) {
+      Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl;
+      if (value) {
+        return d_acm.propagate(equality);
+      } else {
+        return d_acm.propagate(equality.notNode());
+      }
     }
 
-    void notify(TNode t1, TNode t2) {
-      Debug("arith::congruences") << "ArithCongruenceNotify::notify(" << t1 << ", " << t2 << ")" << std::endl;
-      Node equality = t1.eqNode(t2);
-      d_acm.propagate(equality);
+    bool eqNotifyTriggerPredicate(TNode predicate, bool value) {
+      Unreachable();
     }
-  };
+
+    bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) {
+      Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ", " << (value ? "true" : "false") << ")" << std::endl;
+      if (value) {
+        return d_acm.propagate(t1.eqNode(t2));
+      } else {
+        return d_acm.propagate(t1.eqNode(t2).notNode());
+      }
+    }
+
+    bool eqNotifyConstantTermMerge(TNode t1, TNode t2) {
+      Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl;
+      if (t1.getKind() == kind::CONST_BOOLEAN) {
+        return d_acm.propagate(t1.iffNode(t2));
+      } else {
+        return d_acm.propagate(t1.eqNode(t2));
+      }
+    }
+   };
   ArithCongruenceNotify d_notify;
 
   context::CDList<Node> d_keepAlive;
@@ -75,8 +94,7 @@ private:
 
   const ArithVarNodeMap& d_av2Node;
 
-  theory::uf::EqualityEngine<ArithCongruenceNotify> d_ee;
-  Node d_false;
+  eq::EqualityEngine d_ee;
 
 public:
 
index c7072de72c3f688f2be44e8e0cbd3e1b96fc1884..6bb3821da97c90926e61ff848fe5535ba6099e5a 100644 (file)
@@ -1001,8 +1001,8 @@ Node TheoryArith::assertionCases(TNode assertion){
       if(Debug.isOn("whytheoryenginewhy")){
         debugPrintFacts();
       }
-      Warning() << "arith: Theory engine is sending me both a literal and its negation?"
-                << "BOOOOOOOOOOOOOOOOOOOOOO!!!!"<< endl;
+//      Warning() << "arith: Theory engine is sending me both a literal and its negation?"
+//                << "BOOOOOOOOOOOOOOOOOOOOOO!!!!"<< endl;
     }
     Debug("arith::eq") << constraint << endl;
     Debug("arith::eq") << negation << endl;
index 80bcb47dd2c2d491e978ec2c27ed862b2d185ba4..1dd74f060d168e0bc601306f54c1970eb9a4b48f 100644 (file)
 #include <map>
 #include "theory/rewriter.h"
 #include "expr/command.h"
-#include "theory/uf/equality_engine_impl.h"
-
 
 using namespace std;
 
-
 namespace CVC4 {
 namespace theory {
 namespace arrays {
 
-
 // These are the options that produce the best empirical results on QF_AX benchmarks.
 // eagerLemmas = true
 // eagerIndexSplitting = false
@@ -58,14 +54,12 @@ TheoryArrays::TheoryArrays(context::Context* c, context::UserContext* u, OutputC
   d_numNonLinear("theory::arrays::number of calls to setNonLinear", 0),
   d_numSharedArrayVarSplits("theory::arrays::number of shared array var splits", 0),
   d_checkTimer("theory::arrays::checkTime"),
-  d_ppNotify(),
-  d_ppEqualityEngine(d_ppNotify, u, "theory::arrays::TheoryArraysPP"),
+  d_ppEqualityEngine(u, "theory::arrays::TheoryArraysPP"),
   d_ppFacts(u),
   //  d_ppCache(u),  
   d_literalsToPropagate(c),
   d_literalsToPropagateIndex(c, 0),
-  d_mayEqualNotify(),
-  d_mayEqualEqualityEngine(d_mayEqualNotify, c, "theory::arrays::TheoryArraysMayEqual"),
+  d_mayEqualEqualityEngine(c, "theory::arrays::TheoryArraysMayEqual"),
   d_notify(*this),
   d_equalityEngine(d_notify, c, "theory::arrays::TheoryArrays"),
   d_conflict(c, false),
@@ -91,14 +85,6 @@ TheoryArrays::TheoryArrays(context::Context* c, context::UserContext* u, OutputC
   d_true = NodeManager::currentNM()->mkConst<bool>(true);
   d_false = NodeManager::currentNM()->mkConst<bool>(false);
 
-  d_ppEqualityEngine.addTerm(d_true);
-  d_ppEqualityEngine.addTerm(d_false);
-  d_ppEqualityEngine.addTriggerEquality(d_true, d_false, d_false);
-
-  d_equalityEngine.addTerm(d_true);
-  d_equalityEngine.addTerm(d_false);
-  d_equalityEngine.addTriggerEquality(d_true, d_false, d_false);
-
   // The kinds we are treating as function application in congruence
   d_equalityEngine.addFunctionKind(kind::SELECT);
   if (d_ccStore) {
@@ -281,7 +267,7 @@ Theory::PPAssertStatus TheoryArrays::ppAssert(TNode in, SubstitutionMap& outSubs
     case kind::EQUAL:
     {
       d_ppFacts.push_back(in);
-      d_ppEqualityEngine.addEquality(in[0], in[1], in);
+      d_ppEqualityEngine.assertEquality(in, true, in);
       if (in[0].getMetaKind() == kind::metakind::VARIABLE && !in[1].hasSubterm(in[0])) {
         outSubstitutions.addSubstitution(in[0], in[1]);
         return PP_ASSERT_STATUS_SOLVED;
@@ -299,7 +285,7 @@ Theory::PPAssertStatus TheoryArrays::ppAssert(TNode in, SubstitutionMap& outSubs
              in[0].getKind() == kind::IFF );
       Node a = in[0][0];
       Node b = in[0][1];
-      d_ppEqualityEngine.addDisequality(a, b, in);
+      d_ppEqualityEngine.assertEquality(in[0], false, in);
       break;
     }
     default:
@@ -335,10 +321,8 @@ bool TheoryArrays::propagate(TNode literal)
       Debug("arrays") << spaces(getSatContext()->getLevel()) << "TheoryArrays::propagate(" << literal << ", normalized = " << normalized << ") => conflict" << std::endl;
       std::vector<TNode> assumptions;
       Node negatedLiteral;
-      if (normalized != d_false) {
-        negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode();
-        assumptions.push_back(negatedLiteral);
-      }
+      negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode();
+      assumptions.push_back(negatedLiteral);
       explain(literal, assumptions);
       d_conflictNode = mkAnd(assumptions);
       d_conflict = true;
@@ -357,67 +341,40 @@ bool TheoryArrays::propagate(TNode literal)
 
 
 void TheoryArrays::explain(TNode literal, std::vector<TNode>& assumptions) {
-  TNode lhs, rhs;
-  switch (literal.getKind()) {
-    case kind::EQUAL:
-      lhs = literal[0];
-      rhs = literal[1];
-      break;
-    case kind::SELECT:
-      lhs = literal;
-      rhs = d_true;
-      break;
-    case kind::NOT:
-      if (literal[0].getKind() == kind::EQUAL) {
-        // Disequalities
-        d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions);
-        return;
-      } else {
-        // Predicates
-        lhs = literal[0];
-        rhs = d_false;
-        break;
-      }
-    case kind::CONST_BOOLEAN:
-      // we get to explain true = false, since we set false to be the trigger of this
-      lhs = d_true;
-      rhs = d_false;
-      break;
-    default:
-      Unreachable();
+  // Do the work
+  bool polarity = literal.getKind() != kind::NOT;
+  TNode atom = polarity ? literal : literal[0];
+  if (atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) {
+    d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
+  } else {
+    d_equalityEngine.explainPredicate(atom, polarity, assumptions);
   }
-  d_equalityEngine.explainEquality(lhs, rhs, assumptions);
 }
 
 
-  /**
  * Stores in d_infoMap the following information for each term a of type array:
  *
  *    - all i, such that there exists a term a[i] or a = store(b i v)
  *      (i.e. all indices it is being read atl; store(b i v) is implicitly read at
  *      position i due to the implicit axiom store(b i v)[i] = v )
  *
  *    - all the stores a is congruent to (this information is context dependent)
  *
  *    - all store terms of the form store (a i v) (i.e. in which a appears
  *      directly; this is invariant because no new store terms are created)
  *
  * Note: completeness depends on having pre-register called on all the input
  *       terms before starting to instantiate lemmas.
  */
+/**
+ * Stores in d_infoMap the following information for each term a of type array:
+ *
+ *    - all i, such that there exists a term a[i] or a = store(b i v)
+ *      (i.e. all indices it is being read atl; store(b i v) is implicitly read at
+ *      position i due to the implicit axiom store(b i v)[i] = v )
+ *
+ *    - all the stores a is congruent to (this information is context dependent)
+ *
+ *    - all store terms of the form store (a i v) (i.e. in which a appears
+ *      directly; this is invariant because no new store terms are created)
+ *
+ * Note: completeness depends on having pre-register called on all the input
+ *       terms before starting to instantiate lemmas.
+ */
 void TheoryArrays::preRegisterTerm(TNode node)
 {
   Debug("arrays") << spaces(getSatContext()->getLevel()) << "TheoryArrays::preRegisterTerm(" << node << ")" << std::endl;
 
   switch (node.getKind()) {
   case kind::EQUAL:
-    // Add the terms
-    //    d_equalityEngine.addTerm(node[0]);
-    //    d_equalityEngine.addTerm(node[1]);
-    d_equalityEngine.addTerm(node);
     // Add the trigger for equality
-    d_equalityEngine.addTriggerEquality(node[0], node[1], node);
-    d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode());
+    d_equalityEngine.addTriggerEquality(node);
     break;
   case kind::SELECT: {
     // Reads
@@ -438,7 +395,7 @@ void TheoryArrays::preRegisterTerm(TNode node)
           Assert(!d_equalityEngine.hasTerm(ni));
           preRegisterTerm(ni);
         }
-        d_equalityEngine.addEquality(ni, s[2], d_true);
+        d_equalityEngine.assertEquality(ni.eqNode(s[2]), true, d_true);
         Assert(++it == stores->end());
       }
     }
@@ -447,8 +404,7 @@ void TheoryArrays::preRegisterTerm(TNode node)
     // TODO: remove this or keep it if we allow Boolean elements in arrays.
     if (node.getType().isBoolean()) {
       // Get triggered for both equal and dis-equal
-      d_equalityEngine.addTriggerEquality(node, d_true, node);
-      d_equalityEngine.addTriggerEquality(node, d_false, node.notNode());
+      d_equalityEngine.addTriggerPredicate(node);
     }
 
     d_infoMap.addIndex(node[0], node[1]);
@@ -463,7 +419,7 @@ void TheoryArrays::preRegisterTerm(TNode node)
     //    TNode i = node[1];
     //    TNode v = node[2];
 
-    d_mayEqualEqualityEngine.addEquality(node, a, d_true);
+    d_mayEqualEqualityEngine.assertEquality(node.eqNode(a), true, d_true);
 
     // NodeManager* nm = NodeManager::currentNM();
     // Node ni = nm->mkNode(kind::SELECT, node, i);
@@ -508,10 +464,8 @@ void TheoryArrays::propagate(Effort e)
         Debug("arrays") << spaces(getSatContext()->getLevel()) << "TheoryArrays::propagate(): in conflict, normalized = " << normalized << std::endl;
         Node negatedLiteral;
         std::vector<TNode> assumptions;
-        if (normalized != d_false) {
-          negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode();
-          assumptions.push_back(negatedLiteral);
-        }
+        negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode();
+        assumptions.push_back(negatedLiteral);
         explain(literal, assumptions);
         d_conflictNode = mkAnd(assumptions);
         d_conflict = true;
@@ -727,17 +681,17 @@ void TheoryArrays::check(Effort e) {
     // Do the work
     switch (fact.getKind()) {
       case kind::EQUAL:
-        d_equalityEngine.addEquality(fact[0], fact[1], fact);
+        d_equalityEngine.assertEquality(fact, true, fact);
         break;
       case kind::SELECT:
-        d_equalityEngine.addPredicate(fact, true, fact);
+        d_equalityEngine.assertPredicate(fact, true, fact);
         break;
       case kind::NOT:
         if (fact[0].getKind() == kind::SELECT) {
-          d_equalityEngine.addPredicate(fact[0], false, fact);
+          d_equalityEngine.assertPredicate(fact[0], false, fact);
         } else if (!d_equalityEngine.areDisequal(fact[0][0], fact[0][1])) {
           // Assert the dis-equality
-          d_equalityEngine.addDisequality(fact[0][0], fact[0][1], fact);
+          d_equalityEngine.assertEquality(fact[0], false, fact);
 
           // Apply ArrDiseq Rule if diseq is between arrays
           if(fact[0][0].getType().isArray()) {
@@ -764,7 +718,7 @@ void TheoryArrays::check(Effort e) {
             if (!d_equalityEngine.hasTerm(bk)) {
               preRegisterTerm(bk);
             }
-            d_equalityEngine.addDisequality(ak, bk, fact);
+            d_equalityEngine.assertEquality(ak.eqNode(bk), false, fact);
             Trace("arrays-lem")<<"Arrays::addExtLemma "<< ak << " /= " << bk <<"\n";
             ++d_numExt;
           }
@@ -807,14 +761,11 @@ Node TheoryArrays::mkAnd(std::vector<TNode>& conjunctions)
   for (; i < conjunctions.size(); ++i) {
     t = conjunctions[i];
 
-    // Remove true node - represents axiomatically true assertion
-    if (t == d_true) continue;
-
     // Expand explanation resulting from propagating a ROW lemma
     if (t.getKind() == kind::OR) {
       if ((explained.find(t) == explained.end())) {
         Assert(t[1].getKind() == kind::EQUAL);
-        d_equalityEngine.explainDisequality(t[1][0], t[1][1], conjunctions);
+        d_equalityEngine.explainEquality(t[1][0], t[1][1], false, conjunctions);
         explained.insert(t);
       }
       continue;
@@ -949,7 +900,7 @@ void TheoryArrays::checkRIntro1(TNode a, TNode b)
     Node ni = nm->mkNode(kind::SELECT, s, s[1]);
     Assert(!d_equalityEngine.hasTerm(ni));
     preRegisterTerm(ni);
-    d_equalityEngine.addEquality(ni, s[2], d_true);
+    d_equalityEngine.assertEquality(ni.eqNode(s[2]), true, d_true);
   }
 }
 
@@ -1004,7 +955,7 @@ void TheoryArrays::mergeArrays(TNode a, TNode b)
       }
     }
 
-    d_mayEqualEqualityEngine.addEquality(a, b, d_true);
+    d_mayEqualEqualityEngine.assertEquality(a.eqNode(b), true, d_true);
 
     checkRowLemmas(a,b);
     checkRowLemmas(b,a);
@@ -1186,7 +1137,7 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem)
       if (!bjExists) {
         preRegisterTerm(bj);
       }
-      d_equalityEngine.addEquality(aj, bj, reason);
+      d_equalityEngine.assertEquality(aj.eqNode(bj), true, reason);
       ++d_numProp;
       return;
     }
@@ -1194,7 +1145,7 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem)
       Trace("arrays-lem") << spaces(getSatContext()->getLevel()) <<"Arrays::queueRowLemma: propagating i = j ("<<i<<", "<<j<<")\n";
       Node reason = nm->mkNode(kind::OR, i.eqNode(j), aj.eqNode(bj));
       d_permRef.push_back(reason);
-      d_equalityEngine.addEquality(i, j, reason);
+      d_equalityEngine.assertEquality(i.eqNode(j), true, reason);
       ++d_numProp;
       return;
     }
index d18b3abde8f15d4b1b8b971099a1973a72f2e712..88986ee7a92ef848bae778788b016de035d04eb7 100644 (file)
@@ -133,18 +133,8 @@ class TheoryArrays : public Theory {
 
   private:
 
-  // PPNotifyClass: dummy template class for d_ppEqualityEngine - notifications not used
-  class PPNotifyClass {
-  public:
-    bool notify(TNode propagation) { return true; }
-    void notify(TNode t1, TNode t2) { }
-  };
-
-  /** The notify class for d_ppEqualityEngine */
-  PPNotifyClass d_ppNotify;
-
   /** Equaltity engine */
-  uf::EqualityEngine<PPNotifyClass> d_ppEqualityEngine;
+  eq::EqualityEngine d_ppEqualityEngine;
 
   // List of facts learned by preprocessor - needed for permanent ref for benefit of d_ppEqualityEngine
   context::CDList<Node> d_ppFacts;
@@ -187,17 +177,8 @@ class TheoryArrays : public Theory {
 
   private:
 
-  class MayEqualNotifyClass {
-  public:
-    bool notify(TNode propagation) { return true; }
-    void notify(TNode t1, TNode t2) { }
-  };
-
-  /** The notify class for d_mayEqualEqualityEngine */
-  MayEqualNotifyClass d_mayEqualNotify;
-
   /** Equaltity engine for determining if two arrays might be equal */
-  uf::EqualityEngine<MayEqualNotifyClass> d_mayEqualEqualityEngine;
+  eq::EqualityEngine d_mayEqualEqualityEngine;
 
   public:
 
@@ -238,37 +219,57 @@ class TheoryArrays : public Theory {
   private:
 
   // NotifyClass: template helper class for d_equalityEngine - handles call-back from congruence closure module
-  class NotifyClass {
+  class NotifyClass : public eq::EqualityEngineNotify {
     TheoryArrays& d_arrays;
   public:
     NotifyClass(TheoryArrays& arrays): d_arrays(arrays) {}
 
-    bool notify(TNode propagation) {
-      Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::notify(" << propagation << ")" << std::endl;
+    bool eqNotifyTriggerEquality(TNode equality, bool value) {
+      Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl;
       // Just forward to arrays
-      return d_arrays.propagate(propagation);
+      if (value) {
+        return d_arrays.propagate(equality);
+      } else {
+        return d_arrays.propagate(equality.notNode());
+      }
+    }
+
+    bool eqNotifyTriggerPredicate(TNode predicate, bool value) {
+      Unreachable();
     }
 
-    void notify(TNode t1, TNode t2) {
-      Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl;
-      if (t1.getType().isArray()) {
-        d_arrays.mergeArrays(t1, t2);
-        if (!d_arrays.isShared(t1) || !d_arrays.isShared(t2)) {
-          return;
+    bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) {
+      Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ")" << std::endl;
+      if (value) {
+        if (t1.getType().isArray()) {
+          d_arrays.mergeArrays(t1, t2);
+          if (!d_arrays.isShared(t1) || !d_arrays.isShared(t2)) {
+            return true;
+          }
         }
+        // Propagate equality between shared terms
+        Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2));
+        d_arrays.propagate(equality);
       }
-      // Propagate equality between shared terms
-      Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2));
-      d_arrays.propagate(equality);
+      // TODO: implement negation propagation
+      return true;
     }
-  };
 
+    bool eqNotifyConstantTermMerge(TNode t1, TNode t2) {
+      Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl;
+      if (Theory::theoryOf(t1) == THEORY_BOOL) {
+        return d_arrays.propagate(t1.iffNode(t2));
+      } else {
+        return d_arrays.propagate(t1.eqNode(t2));
+      }
+    }
+  };
 
   /** The notify class for d_equalityEngine */
   NotifyClass d_notify;
 
   /** Equaltity engine */
-  uf::EqualityEngine<NotifyClass> d_equalityEngine;
+  eq::EqualityEngine d_equalityEngine;
 
   // Are we in conflict?
   context::CDO<bool> d_conflict;
index 78221a61754e7cf70a6441165dd0f2206e215dbe..f5e4f4630d3fdfb9abcdb583e5069acb2db1d54f 100644 (file)
@@ -79,17 +79,17 @@ private:
   template <class T>
   class DataClearer : context::ContextNotifyObj {
     T& d_data;
+  protected:
+    void contextNotifyPop() {
+      Trace("circuit-prop") << "CircuitPropagator::DataClearer: clearing data "
+                            << "(size was " << d_data.size() << ")" << std::endl;
+      d_data.clear();
+    }
   public:
     DataClearer(context::Context* context, T& data) :
       context::ContextNotifyObj(context),
       d_data(data) {
     }
-
-    void notify() {
-      Trace("circuit-prop") << "CircuitPropagator::DataClearer: clearing data "
-                            << "(size was " << d_data.size() << ")" << std::endl;
-      d_data.clear();
-    }
   };/* class DataClearer<T> */
 
   /**
index c9d58574ed6563e7e300b66245100fdd12404261..4076a7ee053a03aa1ec642dbde5287df4281dc51 100644 (file)
@@ -21,7 +21,6 @@
 #include "theory/bv/theory_bv_utils.h"
 #include "theory/valuation.h"
 #include "theory/bv/bv_sat.h"
-#include "theory/uf/equality_engine_impl.h"
 
 using namespace CVC4;
 using namespace CVC4::theory;
@@ -52,18 +51,7 @@ TheoryBV::TheoryBV(context::Context* c, context::UserContext* u, OutputChannel&
     d_toBitBlast(c),
     d_propagatedBy(c)
   {
-    d_true = utils::mkTrue();
-    d_false = utils::mkFalse();
-
     if (d_useEqualityEngine) {
-      d_equalityEngine.addTerm(d_true);
-      d_equalityEngine.addTerm(d_false);
-      d_equalityEngine.addTriggerEquality(d_true, d_false, d_false);
-
-      // add disequality between 0 and 1 bits
-      d_equalityEngine.addDisequality(utils::mkConst(BitVector((unsigned)1, (unsigned)0)),
-                                      utils::mkConst(BitVector((unsigned)1, (unsigned)1)),
-                                      d_true);
 
       // The kinds we are treating as function application in congruence
       d_equalityEngine.addFunctionKind(kind::BITVECTOR_CONCAT);
@@ -137,11 +125,8 @@ void TheoryBV::preRegisterTerm(TNode node) {
   if (d_useEqualityEngine) {
     switch (node.getKind()) {
       case kind::EQUAL:
-        // Add the terms
-        d_equalityEngine.addTerm(node);
         // Add the trigger for equality
-        d_equalityEngine.addTriggerEquality(node[0], node[1], node);
-        d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode());
+        d_equalityEngine.addTriggerEquality(node);
         break;
       default:
         d_equalityEngine.addTerm(node);
@@ -185,15 +170,15 @@ void TheoryBV::check(Effort e)
       if (predicate.getKind() == kind::EQUAL) {
         if (negated) {
           // dis-equality
-          d_equalityEngine.addDisequality(predicate[0], predicate[1], fact);
+          d_equalityEngine.assertEquality(predicate, false, fact);
         } else {
           // equality
-          d_equalityEngine.addEquality(predicate[0], predicate[1], fact);
+          d_equalityEngine.assertEquality(predicate, true, fact);
         }
       } else {
         // Adding predicate if the congruence over it is turned on
         if (d_equalityEngine.isFunctionKind(predicate.getKind())) {
-          d_equalityEngine.addPredicate(predicate, !negated, fact);
+          d_equalityEngine.assertPredicate(predicate, !negated, fact);
         }
       }
     }
@@ -279,16 +264,16 @@ void TheoryBV::propagate(Effort e) {
       bool satValue;
       if (!d_valuation.hasSatValue(normalized, satValue) || satValue) {
         // check if we already propagated the negation
-        Node neg_literal = literal.getKind() == kind::NOT ? (Node)literal[0] : mkNot(literal); 
-        if (d_alreadyPropagatedSet.find(neg_literal) != d_alreadyPropagatedSet.end()) {
+        Node negLiteral = literal.getKind() == kind::NOT ? (Node)literal[0] : mkNot(literal);
+        if (d_alreadyPropagatedSet.find(negLiteral) != d_alreadyPropagatedSet.end()) {
           Debug("bitvector") << spaces(getSatContext()->getLevel()) << "TheoryBV::propagate(): in conflict " << literal << " and its negation both propagated \n"; 
           // we are in conflict
           std::vector<TNode> assumptions;
           explain(literal, assumptions);
-          explain(neg_literal, assumptions);
+          explain(negLiteral, assumptions);
           d_conflictNode = mkAnd(assumptions); 
           d_conflict = true;
-          return; 
+          return;
         }
         
         BVDebug("bitvector") << spaces(getSatContext()->getLevel()) << "TheoryBV::propagate(): " << literal << std::endl;
@@ -299,10 +284,8 @@ void TheoryBV::propagate(Effort e) {
         
         Node negatedLiteral;
         std::vector<TNode> assumptions;
-        if (normalized != d_false) {
         negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode();
         assumptions.push_back(negatedLiteral);
-        }
         explain(literal, assumptions);
         d_conflictNode = mkAnd(assumptions);
         d_conflict = true;
@@ -352,8 +335,6 @@ bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory)
   // If propagated already, just skip
   PropagatedMap::const_iterator find = d_propagatedBy.find(literal);
   if (find != d_propagatedBy.end()) {
-    //unsigned theories = (*find).second | (unsigned) subtheory;
-    //d_propagatedBy[literal] = theories;
     return true;
   } else {
     d_propagatedBy[literal] = subtheory;
@@ -362,56 +343,37 @@ bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory)
   // See if the literal has been asserted already
   bool satValue = false;
   bool hasSatValue = d_valuation.hasSatValue(literal, satValue);
-  // If asserted, we might be in conflict
 
+  // If asserted, we might be in conflict
   if (hasSatValue && !satValue) {
-      Debug("bitvector-prop") << spaces(getSatContext()->getLevel()) << "TheoryBV::storePropagation(" << literal << ") => conflict" << std::endl;
-      std::vector<TNode> assumptions;
-      Node negatedLiteral = literal.getKind() == kind::NOT ? (Node) literal[0] : literal.notNode();
-      assumptions.push_back(negatedLiteral);
-      explain(literal, assumptions);
-      d_conflictNode = mkAnd(assumptions);
-      d_conflict = true;
-      return false;
+    Debug("bitvector-prop") << spaces(getSatContext()->getLevel()) << "TheoryBV::storePropagation(" << literal << ") => conflict" << std::endl;
+    std::vector<TNode> assumptions;
+    Node negatedLiteral = literal.getKind() == kind::NOT ? (Node) literal[0] : literal.notNode();
+    assumptions.push_back(negatedLiteral);
+    explain(literal, assumptions);
+    d_conflictNode = mkAnd(assumptions);
+    d_conflict = true;
+    return false;
   }
 
   // Nothing, just enqueue it for propagation and mark it as asserted already
   Debug("bitvector-prop") << spaces(getSatContext()->getLevel()) << "TheoryBV::storePropagation(" << literal << ") => enqueuing for propagation" << std::endl;
   d_literalsToPropagate.push_back(literal);
 
+  // No conflict
   return true;
 }/* TheoryBV::propagate(TNode) */
 
 
 void TheoryBV::explain(TNode literal, std::vector<TNode>& assumptions) {
-
   if (propagatedBy(literal, SUB_EQUALITY)) {
-    TNode lhs, rhs;
-    switch (literal.getKind()) {
-      case kind::EQUAL:
-        lhs = literal[0];
-        rhs = literal[1];
-        break;
-      case kind::NOT:
-        if (literal[0].getKind() == kind::EQUAL) {
-          // Disequalities
-          d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions);
-          return;
-        } else {
-          // Predicates
-          lhs = literal[0];
-          rhs = d_false;
-          break;
-        }
-      case kind::CONST_BOOLEAN:
-        // we get to explain true = false, since we set false to be the trigger of this
-        lhs = d_true;
-        rhs = d_false;
-        break;
-      default:
-        Unreachable();
+    bool polarity = literal.getKind() != kind::NOT;
+    TNode atom = polarity ? literal : literal[0];
+    if (atom.getKind() == kind::EQUAL) {
+      d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
+    } else {
+      d_equalityEngine.explainPredicate(atom, polarity, assumptions);
     }
-    d_equalityEngine.explainEquality(lhs, rhs, assumptions);
   } else {
     Assert(propagatedBy(literal, SUB_BITBLASTER));
     d_bitblaster->explain(literal, assumptions); 
@@ -430,7 +392,9 @@ Node TheoryBV::explain(TNode node) {
     return utils::mkTrue(); 
   }
   // return the explanation
-  return mkAnd(assumptions);
+  Node explanation = mkAnd(assumptions);
+  Debug("bitvector::explain") << "TheoryBV::explain(" << node << ") => " << explanation << std::endl;
+  return explanation;
 }
 
 
index 0ced179ec0fa0c6cb320fd3c94b6afc19c691047..e46d052f887ac7829334cb880988bb0b72772785 100644 (file)
@@ -61,8 +61,6 @@ private:
   
   /** Bitblaster */
   Bitblaster* d_bitblaster; 
-  Node d_true;
-  Node d_false;
     
   /** Context dependent set of atoms we already propagated */
   context::CDHashSet<TNode, TNodeHashFunction> d_alreadyPropagatedSet;
@@ -99,22 +97,44 @@ private:
   
   // Added by Clark
   // NotifyClass: template helper class for d_equalityEngine - handles call-back from congruence closure module
-  class NotifyClass {
+  class NotifyClass : public eq::EqualityEngineNotify {
+
     TheoryBV& d_bv;
+
   public:
+
     NotifyClass(TheoryBV& uf): d_bv(uf) {}
 
-    bool notify(TNode propagation) {
-      Debug("bitvector") << spaces(d_bv.getSatContext()->getLevel()) << "NotifyClass::notify(" << propagation << ")" << std::endl;
-      // Just forward to bv
-      return d_bv.storePropagation(propagation, SUB_EQUALITY);
+    bool eqNotifyTriggerEquality(TNode equality, bool value) {
+      Debug("bitvector") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl;
+      if (value) {
+        return d_bv.storePropagation(equality, SUB_EQUALITY);
+      } else {
+        return d_bv.storePropagation(equality.notNode(), SUB_EQUALITY);
+      }
+    }
+
+    bool eqNotifyTriggerPredicate(TNode predicate, bool value) {
+      Debug("bitvector") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" )<< ")" << std::endl;
+      if (value) {
+        return d_bv.storePropagation(predicate, SUB_EQUALITY);
+      } else {
+       return d_bv.storePropagation(predicate, SUB_EQUALITY);
+      }
+    }
+
+    bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) {
+      Debug("bitvector") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << std::endl;
+      if (value) {
+        return d_bv.storePropagation(t1.eqNode(t2), SUB_EQUALITY);
+      } else {
+        return d_bv.storePropagation(t1.eqNode(t2).notNode(), SUB_EQUALITY);
+      }
     }
 
-    void notify(TNode t1, TNode t2) {
-      Debug("arrays") << spaces(d_bv.getSatContext()->getLevel()) << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl;
-      // Propagate equality between shared terms
-      Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2));
-      d_bv.storePropagation(t1.eqNode(t2), SUB_EQUALITY);
+    bool eqNotifyConstantTermMerge(TNode t1, TNode t2) {
+      Debug("bitvector") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl;
+      return d_bv.storePropagation(t1.eqNode(t2), SUB_EQUALITY);
     }
   };
 
@@ -122,7 +142,7 @@ private:
   NotifyClass d_notify;
 
   /** Equaltity engine */
-  uf::EqualityEngine<NotifyClass> d_equalityEngine;
+  eq::EqualityEngine d_equalityEngine;
 
   // Are we in conflict?
   context::CDO<bool> d_conflict;
index eacc4e7987278463d401c88bb59908b5d602ff49..34706719ea5ceb193703a114a70b5573a02ef35a 100644 (file)
@@ -31,7 +31,7 @@ namespace theory {
 namespace datatypes {
 
 template <class NodeType, class NodeHash>
-void UnionFind<NodeType, NodeHash>::notify() {
+void UnionFind<NodeType, NodeHash>::contextNotifyPop() {
   Trace("datatypesuf") << "datatypesUF cancelling : " << d_offset << " < " << d_trace.size() << " ?" << endl;
   while(d_offset < d_trace.size()) {
     pair<TNode, TNode> p = d_trace.back();
@@ -50,9 +50,9 @@ void UnionFind<NodeType, NodeHash>::notify() {
 // The following declarations allow us to put functions in the .cpp file
 // instead of the header, since we know which instantiations are needed.
 
-template void UnionFind<Node, NodeHashFunction>::notify();
+template void UnionFind<Node, NodeHashFunction>::contextNotifyPop();
 
-template void UnionFind<TNode, TNodeHashFunction>::notify();
+template void UnionFind<TNode, TNodeHashFunction>::contextNotifyPop();
 
 }/* CVC4::theory::datatypes namespace */
 }/* CVC4::theory namespace */
index 51d1d85bc227b59ccd0f61e2ac882a896faa687d..4893c35023490b4b509148b137926160fd6b064e 100644 (file)
@@ -84,13 +84,13 @@ public:
    */
   inline void setCanon(TNode n, TNode newParent);
 
+protected:
 
-public:
   /**
    * Called by the Context when a pop occurs.  Cancels everything to the
-   * current context level.  Overrides ContextNotifyObj::notify().
+   * current context level.  Overrides ContextNotifyObj::contextNotifyPop().
    */
-  void notify();
+  void contextNotifyPop();
 
 };/* class UnionFind<> */
 
index 577e1b957f202574d38b9aeb7b02234cf8632b96..4f5475e976cc8e11cad06f1522c67d1bcee627e0 100644 (file)
@@ -16,7 +16,6 @@
  **/
 
 #include "theory/shared_terms_database.h"
-#include "theory/uf/equality_engine_impl.h"
 
 using namespace CVC4;
 using namespace theory;
@@ -36,15 +35,8 @@ SharedTermsDatabase::SharedTermsDatabase(SharedTermsNotifyClass& notify, context
     d_equalityEngine(d_EENotify, context, "SharedTermsDatabase")
 {
   StatisticsRegistry::registerStat(&d_statSharedTerms);
-  NodeManager* nm = NodeManager::currentNM();
-  d_true = nm->mkConst<bool>(true);
-  d_false = nm->mkConst<bool>(false);
-  d_equalityEngine.addTerm(d_true);
-  d_equalityEngine.addTerm(d_false);
-  d_equalityEngine.addTriggerEquality(d_true, d_false, d_false);
 }
 
-
 SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException)
 {
   StatisticsRegistry::unregisterStat(&d_statSharedTerms);
@@ -53,9 +45,9 @@ SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException)
   }
 }
 
-
 void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
   Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl; 
+
   std::pair<TNode, TNode> search_pair(atom, term);
   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
   if (find == d_termsToTheories.end()) {
@@ -243,23 +235,21 @@ bool SharedTermsDatabase::areDisequal(TNode a, TNode b) {
   return d_equalityEngine.areDisequal(a,b);
 }
 
-
 void SharedTermsDatabase::processSharedLiteral(TNode literal, TNode reason)
 {
   bool negated = literal.getKind() == kind::NOT;
   TNode atom = negated ? literal[0] : literal;
   if (negated) {
     Assert(!d_equalityEngine.areDisequal(atom[0], atom[1]));
-    d_equalityEngine.addDisequality(atom[0], atom[1], reason);
+    d_equalityEngine.assertEquality(atom, false, reason);
     //    !!! need to send this out
   }
   else {
     Assert(!d_equalityEngine.areEqual(atom[0], atom[1]));
-    d_equalityEngine.addEquality(atom[0], atom[1], reason);
+    d_equalityEngine.assertEquality(atom, true, reason);
   }
 }
 
-
 static Node mkAnd(const std::vector<TNode>& conjunctions) {
   Assert(conjunctions.size() > 0);
 
@@ -286,31 +276,12 @@ static Node mkAnd(const std::vector<TNode>& conjunctions) {
 Node SharedTermsDatabase::explain(TNode literal)
 {
   std::vector<TNode> assumptions;
-  explain(literal, assumptions);
-  return mkAnd(assumptions);
-}
-
-
-void SharedTermsDatabase::explain(TNode literal, std::vector<TNode>& assumptions) {
-  TNode lhs, rhs;
-  switch (literal.getKind()) {
-    case kind::EQUAL:
-      lhs = literal[0];
-      rhs = literal[1];
-      break;
-    case kind::NOT:
-      if (literal[0].getKind() == kind::EQUAL) {
-        // Disequalities
-        d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions);
-        return;
-      }
-    case kind::CONST_BOOLEAN:
-      // we get to explain true = false, since we set false to be the trigger of this
-      lhs = d_true;
-      rhs = d_false;
-      break;
-    default:
-      Unreachable();
+  if (literal.getKind() == kind::NOT) {
+    Assert(literal[0].getKind() == kind::EQUAL);
+    d_equalityEngine.explainEquality(literal[0][0], literal[0][1], false, assumptions);
+  } else {
+    Assert(literal.getKind() == kind::EQUAL);
+    d_equalityEngine.explainEquality(literal[0], literal[1], true, assumptions);
   }
-  d_equalityEngine.explainEquality(lhs, rhs, assumptions);
+  return mkAnd(assumptions);
 }
index 6af7fd41fa4161c8fb349d395b71826f1413edfe..403c90ced7e19538947d02ae736c36abcba2354b 100644 (file)
@@ -28,25 +28,23 @@ class SharedTermsDatabase : public context::ContextNotifyObj {
 
 public:
 
-  /** A conainer for a list of shared terms */
+  /** A container for a list of shared terms */
   typedef std::vector<TNode> shared_terms_list;
-  /** The iterator to go rhough the shared terms list */
+
+  /** The iterator to go through the shared terms list */
   typedef shared_terms_list::const_iterator shared_terms_iterator;
 
 private:
 
-  Node d_true;
-
-  Node d_false;
-
   /** The context */
   context::Context* d_context;
   
   /** Some statistics */
   IntStat d_statSharedTerms;
 
-  // Needs to be a map from Nodes as after a backtrack they might not exist 
+  // Needs to be a map from Nodes as after a backtrack they might not exist
   typedef std::hash_map<Node, shared_terms_list, TNodeHashFunction> SharedTermsMap;
+
   /** A map from atoms to a list of shared terms */
   SharedTermsMap d_atomsToTerms;
   
@@ -57,14 +55,17 @@ private:
   context::CDO<unsigned> d_addedSharedTermsSize;
   
   typedef context::CDHashMap<std::pair<Node, TNode>, theory::Theory::Set, TNodePairHashFunction> SharedTermsTheoriesMap;
+
   /** A map from atoms and subterms to the theories that use it */
   SharedTermsTheoriesMap d_termsToTheories;
 
   typedef context::CDHashMap<TNode, theory::Theory::Set, TNodeHashFunction> AlreadyNotifiedMap;
+
   /** Map from term to theories that have already been notified about the shared term */
   AlreadyNotifiedMap d_alreadyNotifiedMap;
 
 public:
+
   /** Class for notifications about new shared term equalities */
   class SharedTermsNotifyClass {
     public:
@@ -74,6 +75,7 @@ public:
   };
 
 private:
+
   // Instance of class to send shared term notifications to
   SharedTermsNotifyClass& d_sharedNotify;
 
@@ -101,21 +103,37 @@ private:
   void backtrack();
 
   // EENotifyClass: template helper class for d_equalityEngine - handles call-backs
-  class EENotifyClass {
+  class EENotifyClass : public theory::eq::EqualityEngineNotify {
     SharedTermsDatabase& d_sharedTerms;
   public:
     EENotifyClass(SharedTermsDatabase& shared): d_sharedTerms(shared) {}
-    bool notify(TNode propagation) { return true; }    // Not used
-    void notify(TNode t1, TNode t2) {
-      d_sharedTerms.mergeSharedTerms(t1, t2);
+    bool eqNotifyTriggerEquality(TNode equality, bool value) {
+      Unreachable();
+      return true;
+    }
+
+    bool eqNotifyTriggerPredicate(TNode predicate, bool value) {
+      Unreachable();
+      return true;
+    }
+
+    bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) {
+      if (value) {
+        d_sharedTerms.mergeSharedTerms(t1, t2);
+      }
+      return true;
+    }
+
+    bool eqNotifyConstantTermMerge(TNode t1, TNode t2) {
+      return true;
     }
   };
 
   /** The notify class for d_equalityEngine */
   EENotifyClass d_EENotify;
 
-  /** Equaltity engine */
-  theory::uf::EqualityEngine<EENotifyClass> d_equalityEngine;
+  /** Equality engine */
+  theory::eq::EqualityEngine d_equalityEngine;
 
   /** Attach a new notify list to an equivalence class representative */
   NotifyList* getNewNotifyList();
@@ -123,9 +141,6 @@ private:
   /** Method called by equalityEngine when a becomes equal to b */
   void mergeSharedTerms(TNode a, TNode b);
 
-  /** Internal explanation method */
-  void explain(TNode literal, std::vector<TNode>& assumptions);
-
 public:
 
   SharedTermsDatabase(SharedTermsNotifyClass& notify, context::Context* context);
@@ -179,10 +194,12 @@ public:
 
   Node explain(TNode literal);
 
+protected:
+
   /**
    * This method gets called on backtracks from the context manager.
    */
-  void notify() {
+  void contextNotifyPop() {
     backtrack();
   }
 };
index 27c1a2b6999dbe60c9d80af28f5acbfa38e3302b..958f502762f0d4f1454fe5fadda306433deb211f 100644 (file)
@@ -73,16 +73,16 @@ private:
   /** Helper class to invalidate cache on user pop */
   class CacheInvalidator : public context::ContextNotifyObj {
     bool& d_cacheInvalidated;
-
+  protected:
+    void contextNotifyPop() {
+      d_cacheInvalidated = true;
+    }
   public:
     CacheInvalidator(context::Context* context, bool& cacheInvalidated) :
       context::ContextNotifyObj(context),
       d_cacheInvalidated(cacheInvalidated) {
     }
 
-    void notify() {
-      d_cacheInvalidated = true;
-    }
   };/* class SubstitutionMap::CacheInvalidator */
 
   /**
index a3aee985d71aa29aee22561269e97bf7444a48fe..c19bdda919b0d6a759ce1cf54feecee7baf2eef4 100644 (file)
@@ -124,6 +124,50 @@ void TheoryEngine::preRegister(TNode preprocessed) {
   // }
 }
 
+void TheoryEngine::printAssertions(const char* tag) {
+  if (Debug.isOn(tag)) {
+    for (TheoryId theoryId = THEORY_FIRST; theoryId < THEORY_LAST; ++theoryId) {
+      Theory* theory = d_theoryTable[theoryId];
+      if (theory && d_logicInfo.isTheoryEnabled(theoryId)) {
+        Debug(tag) << "--------------------------------------------" << std::endl;
+        Debug(tag) << "Assertions of " << theory->getId() << ": " << std::endl;
+        context::CDList<Assertion>::const_iterator it = theory->facts_begin(), it_end = theory->facts_end();
+        for (unsigned i = 0; it != it_end; ++ it, ++i) {
+            if ((*it).isPreregistered) {
+              Debug(tag) << "[" << i << "]: ";
+            } else {
+              Debug(tag) << "(" << i << "): ";
+            }
+            Debug(tag) << (*it).assertion << endl;
+        }
+
+        if (d_logicInfo.isSharingEnabled()) {
+          Debug(tag) << "Shared terms of " << theory->getId() << ": " << std::endl;
+          context::CDList<TNode>::const_iterator it = theory->shared_terms_begin(), it_end = theory->shared_terms_end();
+          for (unsigned i = 0; it != it_end; ++ it, ++i) {
+              Debug(tag) << "[" << i << "]: " << (*it) << endl;
+          }
+        }
+      }
+    }
+
+  }
+}
+
+template<typename T, bool doAssert>
+class scoped_vector_clear {
+  vector<T>& d_v;
+public:
+  scoped_vector_clear(vector<T>& v)
+  : d_v(v) {
+    Assert(!doAssert || d_v.empty());
+  }
+  ~scoped_vector_clear() {
+    d_v.clear();
+  }
+
+};
+
 /**
  * Check all (currently-active) theories for conflicts.
  * @param effort the effort level to use
@@ -143,12 +187,12 @@ void TheoryEngine::check(Theory::Effort effort) {
        } \
     }
 
+  // make sure d_propagatedSharedLiterals is cleared on exit
+  scoped_vector_clear<SharedLiteral, true> clear_shared_literals(d_propagatedSharedLiterals);
+
   // Do the checking
   try {
 
-    // Clear any leftover propagated shared literals
-    d_propagatedSharedLiterals.clear();
-
     // Mark the output channel unused (if this is FULL_EFFORT, and nothing
     // is done by the theories, no additional check will be needed)
     d_outputChannelUsed = false;
@@ -159,32 +203,10 @@ void TheoryEngine::check(Theory::Effort effort) {
     while (true) {
 
       Debug("theory") << "TheoryEngine::check(" << effort << "): running check" << std::endl;
+      Assert(d_propagatedSharedLiterals.empty());
 
       if (Debug.isOn("theory::assertions")) {
-        for (TheoryId theoryId = THEORY_FIRST; theoryId < THEORY_LAST; ++theoryId) {
-          Theory* theory = d_theoryTable[theoryId];
-          if (theory && d_logicInfo.isTheoryEnabled(theoryId)) {
-            Debug("theory::assertions") << "--------------------------------------------" << std::endl;
-            Debug("theory::assertions") << "Assertions of " << theory->getId() << ": " << std::endl;
-            context::CDList<Assertion>::const_iterator it = theory->facts_begin(), it_end = theory->facts_end();
-            for (unsigned i = 0; it != it_end; ++ it, ++i) {
-                if ((*it).isPreregistered) {
-                  Debug("theory::assertions") << "[" << i << "]: ";
-                } else {
-                  Debug("theory::assertions") << "(" << i << "): ";
-                }
-                Debug("theory::assertions") << (*it).assertion << endl;
-            }
-
-            if (d_logicInfo.isSharingEnabled()) {
-              Debug("theory::assertions") << "Shared terms of " << theory->getId() << ": " << std::endl;
-              context::CDList<TNode>::const_iterator it = theory->shared_terms_begin(), it_end = theory->shared_terms_end();
-              for (unsigned i = 0; it != it_end; ++ it, ++i) {
-                  Debug("theory::assertions") << "[" << i << "]: " << (*it) << endl;
-              }
-            }
-          }
-        }
+        printAssertions("theory::assertions");
       }
 
       // Do the checking
@@ -232,9 +254,6 @@ void TheoryEngine::check(Theory::Effort effort) {
       }
     }
 
-    // Clear any leftover propagated shared literals
-    d_propagatedSharedLiterals.clear();
-
     Debug("theory") << "TheoryEngine::check(" << effort << "): done, we are " << (d_inConflict ? "unsat" : "sat") << (d_lemmasAdded ? " with new lemmas" : " with no new lemmas") << std::endl;
 
   } catch(const theory::Interrupted&) {
@@ -243,6 +262,9 @@ void TheoryEngine::check(Theory::Effort effort) {
 }
 
 void TheoryEngine::outputSharedLiterals() {
+
+  scoped_vector_clear<SharedLiteral, false> clear_shared_literals(d_propagatedSharedLiterals);
+
   // Assert all the shared literals
   for (unsigned i = 0; i < d_propagatedSharedLiterals.size(); ++ i) {
     const SharedLiteral& eq = d_propagatedSharedLiterals[i];
@@ -258,8 +280,6 @@ void TheoryEngine::outputSharedLiterals() {
       }
     }
   }
-  // Clear the equalities
-  d_propagatedSharedLiterals.clear();
 }
 
 
@@ -269,7 +289,9 @@ void TheoryEngine::combineTheories() {
 
   TimerStat::CodeTimer combineTheoriesTimer(d_combineTheoriesTime);
 
+  // Care graph we'll be building
   CareGraph careGraph;
+
 #ifdef CVC4_FOR_EACH_THEORY_STATEMENT
 #undef CVC4_FOR_EACH_THEORY_STATEMENT
 #endif
@@ -278,6 +300,7 @@ void TheoryEngine::combineTheories() {
      reinterpret_cast<theory::TheoryTraits<THEORY>::theory_class*>(theoryOf(THEORY))->getCareGraph(careGraph); \
   }
 
+  // Call on each parametric theory to give us its care graph
   CVC4_FOR_EACH_THEORY;
 
   // Now add splitters for the ones we are interested in
@@ -833,6 +856,8 @@ Node TheoryEngine::getExplanation(TNode node) {
   }
   Assert(properExplanation(node, explanation));
 
+  Debug("theory::explain") << "TheoryEngine::getExplanation(" << node << ") => " << explanation << std::endl;
+
   return explanation;
 }
 
index 5c73da1f62e8a32ae2a7605946cae1b80150a208..2871d5559d0fe43cc9b4ae86dbe0cc3ebff19517 100644 (file)
@@ -309,8 +309,10 @@ class TheoryEngine {
   }
 
   struct SharedLiteral {
-    /** The node/theory pair for the assertion */
-    /** THEORY_LAST indicates this is a SAT literal and should be sent to the SAT solver */
+    /**
+     * The node/theory pair for the assertion. THEORY_LAST indicates this is a SAT
+     * literal and should be sent to the SAT solver
+     */
     NodeTheoryPair toAssert;
     /** This is the node that we will use to explain it */
     Node toExplain;
@@ -319,7 +321,7 @@ class TheoryEngine {
     : toAssert(assertion, receivingTheory),
       toExplain(original)
     { }
-  };/* struct SharedLiteral */
+  };
 
   /**
    * Map from nodes to theories.
@@ -728,6 +730,9 @@ private:
   /** Visitor for collecting shared terms */
   SharedTermsVisitor d_sharedTermsVisitor;
 
+  /** Prints the assertions to the debug stream */
+  void printAssertions(const char* tag);
+
 };/* class TheoryEngine */
 
 }/* CVC4 namespace */
index f25e50ec922db62f719080741386c7fe6b640171..9d95eaa224e32940b9275eb6e81fa93d0669d2b2 100644 (file)
@@ -11,7 +11,7 @@ libuf_la_SOURCES = \
        theory_uf_type_rules.h \
        theory_uf_rewriter.h \
        equality_engine.h \
-       equality_engine_impl.h \
+       equality_engine.cpp \
        symmetry_breaker.h \
        symmetry_breaker.cpp
 
diff --git a/src/theory/uf/equality_engine.cpp b/src/theory/uf/equality_engine.cpp
new file mode 100644 (file)
index 0000000..b78015c
--- /dev/null
@@ -0,0 +1,995 @@
+/*********************                                                        */
+/*! \file equality_engine_impl.h
+ ** \verbatim
+ ** Original author: dejan
+ ** Major contributors: none
+ ** Minor contributors (to current version): none
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009, 2010, 2011  The Analysis of Computer Systems Group (ACSys)
+ ** Courant Institute of Mathematical Sciences
+ ** New York University
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief [[ Add one-line brief description here ]]
+ **
+ ** [[ Add lengthier description here ]]
+ ** \todo document this file
+ **/
+
+#include "theory/uf/equality_engine.h"
+
+namespace CVC4 {
+namespace theory {
+namespace eq {
+
+/**
+ * Data used in the BFS search through the equality graph.
+ */
+struct BfsData {
+  // The current node
+  EqualityNodeId nodeId;
+  // The index of the edge we traversed
+  EqualityEdgeId edgeId;
+  // Index in the queue of the previous node. Shouldn't be too much of them, at most the size
+  // of the biggest equivalence class
+  size_t previousIndex;
+
+  BfsData(EqualityNodeId nodeId = null_id, EqualityEdgeId edgeId = null_edge, size_t prev = 0)
+  : nodeId(nodeId), edgeId(edgeId), previousIndex(prev) {}
+};
+
+class ScopedBool {
+  bool& watch;
+  bool oldValue;
+public:
+  ScopedBool(bool& watch, bool newValue)
+  : watch(watch), oldValue(watch) {
+    watch = newValue;
+  }
+  ~ScopedBool() {
+    watch = oldValue;
+  }
+};
+
+EqualityEngineNotifyNone EqualityEngine::s_notifyNone;
+
+void EqualityEngine::init() {
+  Debug("equality") << "EqualityEdge::EqualityEngine(): id_null = " << +null_id << std::endl;
+  Debug("equality") << "EqualityEdge::EqualityEngine(): edge_null = " << +null_edge << std::endl;
+  Debug("equality") << "EqualityEdge::EqualityEngine(): trigger_null = " << +null_trigger << std::endl;
+  d_true = NodeManager::currentNM()->mkConst<bool>(true);
+  d_false = NodeManager::currentNM()->mkConst<bool>(false);
+  addTerm(d_true);
+  addTerm(d_false);
+} 
+
+
+EqualityEngine::EqualityEngine(context::Context* context, std::string name) 
+: ContextNotifyObj(context)
+, d_context(context)
+, d_performNotify(true)
+, d_notify(s_notifyNone)
+, d_applicationLookupsCount(context, 0)
+, d_nodesCount(context, 0)
+, d_assertedEqualitiesCount(context, 0)
+, d_equalityTriggersCount(context, 0)
+, d_individualTriggersSize(context, 0)
+, d_constantRepresentativesSize(context, 0)
+, d_stats(name)
+{
+  init();
+}
+
+EqualityEngine::EqualityEngine(EqualityEngineNotify& notify, context::Context* context, std::string name)
+: ContextNotifyObj(context)
+, d_context(context)
+, d_performNotify(true)
+, d_notify(notify)
+, d_applicationLookupsCount(context, 0)
+, d_nodesCount(context, 0)
+, d_assertedEqualitiesCount(context, 0)
+, d_equalityTriggersCount(context, 0)
+, d_individualTriggersSize(context, 0)
+, d_constantRepresentativesSize(context, 0)
+, d_stats(name)
+{
+  init();
+}
+
+void EqualityEngine::enqueue(const MergeCandidate& candidate) {
+    Debug("equality") << "EqualityEngine::enqueue(" << candidate.toString(*this) << ")" << std::endl;
+    d_propagationQueue.push(candidate);
+}
+
+EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2) {
+  Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ")" << std::endl;
+
+  ++ d_stats.functionTermsCount;
+
+  // Get another id for this
+  EqualityNodeId funId = newNode(original);
+  FunctionApplication funOriginal(t1, t2);
+  // The function application we're creating
+  EqualityNodeId t1ClassId = getEqualityNode(t1).getFind();
+  EqualityNodeId t2ClassId = getEqualityNode(t2).getFind();
+  FunctionApplication funNormalized(t1ClassId, t2ClassId);
+
+  // We add the original version
+  d_applications[funId] = FunctionApplicationPair(funOriginal, funNormalized);
+
+  // Add the lookup data, if it's not already there
+  typename ApplicationIdsMap::iterator find = d_applicationLookup.find(funNormalized);
+  if (find == d_applicationLookup.end()) {
+    // When we backtrack, if the lookup is not there anymore, we'll add it again
+    Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): no lookup, setting up" << std::endl;
+    // Mark the normalization to the lookup
+    storeApplicationLookup(funNormalized, funId);
+  } else {
+    // If it's there, we need to merge these two
+    Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): lookup exists, adding to queue" << std::endl;
+    enqueue(MergeCandidate(funId, find->second, MERGED_THROUGH_CONGRUENCE, TNode::null()));
+    propagate();
+  }
+
+  // Add to the use lists
+  d_equalityNodes[t1ClassId].usedIn(funId, d_useListNodes);
+  d_equalityNodes[t2ClassId].usedIn(funId, d_useListNodes);
+
+  // Return the new id
+  Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ") => " << funId << std::endl;
+
+  return funId;
+}
+
+EqualityNodeId EqualityEngine::newNode(TNode node) {
+
+  Debug("equality") << "EqualityEngine::newNode(" << node << ")" << std::endl;
+
+  ++ d_stats.termsCount;
+
+  // Register the new id of the term
+  EqualityNodeId newId = d_nodes.size();
+  d_nodeIds[node] = newId;
+  // Add the node to it's position
+  d_nodes.push_back(node);
+  // Note if this is an application or not
+  d_applications.push_back(FunctionApplicationPair());
+  // Add the trigger list for this node
+  d_nodeTriggers.push_back(+null_trigger);
+  // Add it to the equality graph
+  d_equalityGraph.push_back(+null_edge);
+  // Mark the no-individual trigger
+  d_nodeIndividualTrigger.push_back(+null_id);
+  // Mark non-constant by default
+  d_constantRepresentative.push_back(node.isConst() ? newId : +null_id);
+  // Add the equality node to the nodes
+  d_equalityNodes.push_back(EqualityNode(newId));
+
+  // Increase the counters
+  d_nodesCount = d_nodesCount + 1;
+
+  Debug("equality") << "EqualityEngine::newNode(" << node << ") => " << newId << std::endl;
+
+  return newId;
+}
+
+void EqualityEngine::addTerm(TNode t) {
+
+  Debug("equality") << "EqualityEngine::addTerm(" << t << ")" << std::endl;
+
+  // If there already, we're done
+  if (hasTerm(t)) {
+    Debug("equality") << "EqualityEngine::addTerm(" << t << "): already there" << std::endl;
+    return;
+  }
+
+  EqualityNodeId result;
+
+  // If a function application we go in
+  if (t.getNumChildren() > 0 && d_congruenceKinds[t.getKind()])
+  {
+    // Add the operator
+    TNode tOp = t.getOperator();
+    addTerm(tOp);
+    // Add all the children and Curryfy
+    result = getNodeId(tOp);
+    for (unsigned i = 0; i < t.getNumChildren(); ++ i) {
+      // Add the child
+      addTerm(t[i]);
+      // Add the application
+      result = newApplicationNode(t, result, getNodeId(t[i]));
+    }
+  } else {
+    // Otherwise we just create the new id
+    result = newNode(t);
+  }
+
+  Debug("equality") << "EqualityEngine::addTerm(" << t << ") => " << result << std::endl;
+}
+
+bool EqualityEngine::hasTerm(TNode t) const {
+  return d_nodeIds.find(t) != d_nodeIds.end();
+}
+
+EqualityNodeId EqualityEngine::getNodeId(TNode node) const {
+  Assert(hasTerm(node), node.toString().c_str());
+  return (*d_nodeIds.find(node)).second;
+}
+
+EqualityNode& EqualityEngine::getEqualityNode(TNode t) {
+  return getEqualityNode(getNodeId(t));
+}
+
+EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) {
+  Assert(nodeId < d_equalityNodes.size());
+  return d_equalityNodes[nodeId];
+}
+
+const EqualityNode& EqualityEngine::getEqualityNode(TNode t) const {
+  return getEqualityNode(getNodeId(t));
+}
+
+const EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) const {
+  Assert(nodeId < d_equalityNodes.size());
+  return d_equalityNodes[nodeId];
+}
+
+void EqualityEngine::assertEqualityInternal(TNode t1, TNode t2, TNode reason) {
+
+  Debug("equality") << "EqualityEngine::addEqualityInternal(" << t1 << "," << t2 << ")" << std::endl;
+
+  // Add the terms if they are not already in the database
+  addTerm(t1);
+  addTerm(t2);
+
+  // Add to the queue and propagate
+  EqualityNodeId t1Id = getNodeId(t1);
+  EqualityNodeId t2Id = getNodeId(t2);
+  enqueue(MergeCandidate(t1Id, t2Id, MERGED_THROUGH_EQUALITY, reason));
+
+  propagate();
+}
+
+void EqualityEngine::assertPredicate(TNode t, bool polarity, TNode reason) {
+  Debug("equality") << "EqualityEngine::addPredicate(" << t << "," << (polarity ? "true" : "false") << ")" << std::endl;
+  Assert(t.getKind() != kind::EQUAL, "Use assertEquality instead");
+  assertEqualityInternal(t, polarity ? d_true : d_false, reason);
+}
+
+void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason) {
+  Debug("equality") << "EqualityEngine::addEquality(" << eq << "," << (polarity ? "true" : "false") << std::endl;
+  if (polarity) {
+    // Add equality between terms
+    assertEqualityInternal(eq[0], eq[1], reason);
+    // Add eq = true for dis-equality propagation
+    assertEqualityInternal(eq, d_true, reason);
+  } else {
+    assertEqualityInternal(eq, d_false, reason);
+    Node eqSymm = eq[1].eqNode(eq[0]);
+    assertEqualityInternal(eqSymm, d_false, reason);
+  }
+}
+
+TNode EqualityEngine::getRepresentative(TNode t) const {
+  Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ")" << std::endl;
+  Assert(hasTerm(t));
+  EqualityNodeId representativeId = getEqualityNode(t).getFind();
+  Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ") => " << d_nodes[representativeId] << std::endl;
+  return d_nodes[representativeId];
+}
+
+bool EqualityEngine::areEqual(TNode t1, TNode t2) const {
+  Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl;
+
+  Assert(hasTerm(t1));
+  Assert(hasTerm(t2));
+
+  // Both following commands are semantically const
+  EqualityNodeId rep1 = getEqualityNode(t1).getFind();
+  EqualityNodeId rep2 = getEqualityNode(t2).getFind();
+
+  Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ") => " << (rep1 == rep2 ? "true" : "false") << std::endl;
+
+  return rep1 == rep2;
+}
+
+bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector<TriggerId>& triggers) {
+
+  Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl;
+
+  Assert(triggers.empty());
+
+  ++ d_stats.mergesCount;
+
+  EqualityNodeId class1Id = class1.getFind();
+  EqualityNodeId class2Id = class2.getFind();
+
+  // Update class2 representative information
+  Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating class " << class2Id << std::endl;
+  EqualityNodeId currentId = class2Id;
+  do {
+    // Get the current node
+    EqualityNode& currentNode = getEqualityNode(currentId);
+
+    // Update it's find to class1 id
+    Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << currentId << "->" << class1Id << std::endl;
+    currentNode.setFind(class1Id);
+
+    // Go through the triggers and inform if necessary
+    TriggerId currentTrigger = d_nodeTriggers[currentId];
+    while (currentTrigger != null_trigger) {
+      Trigger& trigger = d_equalityTriggers[currentTrigger];
+      Trigger& otherTrigger = d_equalityTriggers[currentTrigger ^ 1];
+
+      // If the two are not already in the same class
+      if (otherTrigger.classId != trigger.classId) {
+        trigger.classId = class1Id;
+        // If they became the same, call the trigger
+        if (otherTrigger.classId == class1Id) {
+          // Id of the real trigger is half the internal one
+          triggers.push_back(currentTrigger);
+        }
+      }
+
+      // Go to the next trigger
+      currentTrigger = trigger.nextTrigger;
+    }
+
+    // Move to the next node
+    currentId = currentNode.getNext();
+
+  } while (currentId != class2Id);
+
+
+  // Update class2 table lookup and information
+  Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of " << class2Id << std::endl;
+  do {
+    // Get the current node
+    EqualityNode& currentNode = getEqualityNode(currentId);    
+    Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of node " << currentId << std::endl;
+    // Go through the uselist and check for congruences
+    UseListNodeId currentUseId = currentNode.getUseList();
+    while (currentUseId != null_uselist_id) {
+      // Get the node of the use list
+      UseListNode& useNode = d_useListNodes[currentUseId];
+      // Get the function application
+      EqualityNodeId funId = useNode.getApplicationId();
+      Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << currentId << " in " << d_nodes[funId] << std::endl;
+      const FunctionApplication& fun = d_applications[useNode.getApplicationId()].normalized;
+      // Check if there is an application with find arguments
+      EqualityNodeId aNormalized = getEqualityNode(fun.a).getFind();
+      EqualityNodeId bNormalized = getEqualityNode(fun.b).getFind();
+      FunctionApplication funNormalized(aNormalized, bNormalized);
+      typename ApplicationIdsMap::iterator find = d_applicationLookup.find(funNormalized);
+      if (find != d_applicationLookup.end()) {
+        // Applications fun and the funNormalized can be merged due to congruence
+        if (getEqualityNode(funId).getFind() != getEqualityNode(find->second).getFind()) {
+          enqueue(MergeCandidate(funId, find->second, MERGED_THROUGH_CONGRUENCE, TNode::null()));
+        }
+      } else {
+        // There is no representative, so we can add one, we remove this when backtracking
+        storeApplicationLookup(funNormalized, funId);
+      }
+      // Go to the next one in the use list
+      currentUseId = useNode.getNext();
+    }
+
+    // Move to the next node
+    currentId = currentNode.getNext();
+  } while (currentId != class2Id);
+
+  // Now merge the lists
+  class1.merge<true>(class2);
+
+  // Check for constants
+  EqualityNodeId class2constId = d_constantRepresentative[class2Id];
+  if (class2constId != +null_id) {
+    EqualityNodeId class1constId = d_constantRepresentative[class1Id];
+    if (class1constId != +null_id) {
+      if (d_performNotify) {
+        TNode const1 = d_nodes[class1constId];
+        TNode const2 = d_nodes[class2constId];
+        if (!d_notify.eqNotifyConstantTermMerge(const1, const2)) {
+          return false;
+       } 
+      }
+    } else {
+      // If the class we're merging in is constant, mark the representative as constant
+      d_constantRepresentative[class1Id] = d_constantRepresentative[class2Id];
+      d_constantRepresentatives.push_back(class1Id);
+      d_constantRepresentativesSize = d_constantRepresentativesSize + 1;  
+    }
+  }
+
+  // Notify the trigger term merges
+  EqualityNodeId class2triggerId = d_nodeIndividualTrigger[class2Id];
+  if (class2triggerId != +null_id) {
+    EqualityNodeId class1triggerId = d_nodeIndividualTrigger[class1Id];
+    if (class1triggerId == +null_id) {
+      // If class1 is not an individual trigger, but class2 is, mark it
+      d_nodeIndividualTrigger[class1Id] = class2triggerId;
+      // Add it to the list for backtracking
+      d_individualTriggers.push_back(class1Id);
+      d_individualTriggersSize = d_individualTriggersSize + 1;  
+    } else {
+      // Notify when done
+      if (d_performNotify) {
+        if (!d_notify.eqNotifyTriggerTermEquality(d_nodes[class1triggerId], d_nodes[class2triggerId], true)) {
+          return false;
+        }
+      }
+    }  
+  }
+
+  // Everything fine
+  return true;
+}
+
+void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id) {
+
+  Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << ")" << std::endl;
+
+  // Now unmerge the lists (same as merge)
+  class1.merge<false>(class2);
+
+  // Update class2 representative information
+  EqualityNodeId currentId = class2Id;
+  Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << "): undoing representative info" << std::endl;
+  do {
+    // Get the current node
+    EqualityNode& currentNode = getEqualityNode(currentId);
+
+    // Update it's find to class1 id
+    currentNode.setFind(class2Id);
+
+    // Go through the trigger list (if any) and undo the class
+    TriggerId currentTrigger = d_nodeTriggers[currentId];
+    while (currentTrigger != null_trigger) {
+      Trigger& trigger = d_equalityTriggers[currentTrigger];
+      trigger.classId = class2Id;
+      currentTrigger = trigger.nextTrigger;
+    }
+
+    // Move to the next node
+    currentId = currentNode.getNext();
+
+  } while (currentId != class2Id);
+
+}
+
+void EqualityEngine::backtrack() {
+
+  Debug("equality::backtrack") << "backtracking" << std::endl;
+
+  // If we need to backtrack then do it
+  if (d_assertedEqualitiesCount < d_assertedEqualities.size()) {
+
+    // Clear the propagation queue
+    while (!d_propagationQueue.empty()) {
+      d_propagationQueue.pop();
+    }
+
+    Debug("equality") << "EqualityEngine::backtrack(): nodes" << std::endl;
+
+    for (int i = (int)d_assertedEqualities.size() - 1, i_end = (int)d_assertedEqualitiesCount; i >= i_end; --i) {
+      // Get the ids of the merged classes
+      Equality& eq = d_assertedEqualities[i];
+      // Undo the merge
+      undoMerge(d_equalityNodes[eq.lhs], d_equalityNodes[eq.rhs], eq.rhs);
+    }
+
+    d_assertedEqualities.resize(d_assertedEqualitiesCount);
+
+    Debug("equality") << "EqualityEngine::backtrack(): edges" << std::endl;
+
+    for (int i = (int)d_equalityEdges.size() - 2, i_end = (int)(2*d_assertedEqualitiesCount); i >= i_end; i -= 2) {
+      EqualityEdge& edge1 = d_equalityEdges[i];
+      EqualityEdge& edge2 = d_equalityEdges[i | 1];
+      d_equalityGraph[edge2.getNodeId()] = edge1.getNext();
+      d_equalityGraph[edge1.getNodeId()] = edge2.getNext();
+    }
+
+    d_equalityEdges.resize(2 * d_assertedEqualitiesCount);
+  }
+
+  if (d_individualTriggers.size() > d_individualTriggersSize) {
+    // Unset the individual triggers
+    for (int i = d_individualTriggers.size() - 1, i_end = d_individualTriggersSize; i >= i_end; -- i) {
+      d_nodeIndividualTrigger[d_individualTriggers[i]] = +null_id;
+    }
+    d_individualTriggers.resize(d_individualTriggersSize);
+  }
+  
+  if (d_constantRepresentatives.size() > d_constantRepresentativesSize) {
+    // Unset the constant representatives
+    for (int i = d_constantRepresentatives.size() - 1, i_end = d_constantRepresentativesSize; i >= i_end; -- i) {
+      d_constantRepresentative[d_constantRepresentatives[i]] = +null_id;
+    }
+    d_constantRepresentatives.resize(d_constantRepresentativesSize);
+  }
+
+  if (d_equalityTriggers.size() > d_equalityTriggersCount) {
+    // Unlink the triggers from the lists
+    for (int i = d_equalityTriggers.size() - 1, i_end = d_equalityTriggersCount; i >= i_end; -- i) {
+      const Trigger& trigger = d_equalityTriggers[i];
+      d_nodeTriggers[trigger.classId] = trigger.nextTrigger;
+    }
+    // Get rid of the triggers 
+    d_equalityTriggers.resize(d_equalityTriggersCount);
+    d_equalityTriggersOriginal.resize(d_equalityTriggersCount);
+  }
+
+  if (d_applicationLookups.size() > d_applicationLookupsCount) {
+    for (int i = d_applicationLookups.size() - 1, i_end = (int) d_applicationLookupsCount; i >= i_end; -- i) {
+      d_applicationLookup.erase(d_applicationLookups[i]);
+    }
+    d_applicationLookups.resize(d_applicationLookupsCount);
+  }
+
+  if (d_nodes.size() > d_nodesCount) {
+    // Go down the nodes, check the application nodes and remove them from use-lists
+    for(int i = d_nodes.size() - 1, i_end = (int)d_nodesCount; i >= i_end; -- i) {
+      // Remove from the node -> id map
+      Debug("equality") << "EqualityEngine::backtrack(): removing node " << d_nodes[i] << std::endl;
+      d_nodeIds.erase(d_nodes[i]);
+
+      const FunctionApplication& app = d_applications[i].normalized;
+      if (app.isApplication()) {
+        // Remove b from use-list
+        getEqualityNode(app.b).removeTopFromUseList(d_useListNodes);
+        // Remove a from use-list
+        getEqualityNode(app.a).removeTopFromUseList(d_useListNodes);
+      }
+    }
+
+    // Now get rid of the nodes and the rest
+    d_nodes.resize(d_nodesCount);
+    d_applications.resize(d_nodesCount);
+    d_nodeTriggers.resize(d_nodesCount);
+    d_nodeIndividualTrigger.resize(d_nodesCount);
+    d_constantRepresentative.resize(d_nodesCount);
+    d_equalityGraph.resize(d_nodesCount);
+    d_equalityNodes.resize(d_nodesCount);
+  }
+}
+
+void EqualityEngine::addGraphEdge(EqualityNodeId t1, EqualityNodeId t2, MergeReasonType type, TNode reason) {
+  Debug("equality") << "EqualityEngine::addGraphEdge(" << d_nodes[t1] << "," << d_nodes[t2] << "," << reason << ")" << std::endl;
+  EqualityEdgeId edge = d_equalityEdges.size();
+  d_equalityEdges.push_back(EqualityEdge(t2, d_equalityGraph[t1], type, reason));
+  d_equalityEdges.push_back(EqualityEdge(t1, d_equalityGraph[t2], type, reason));
+  d_equalityGraph[t1] = edge;
+  d_equalityGraph[t2] = edge | 1;
+
+  if (Debug.isOn("equality::internal")) {
+    debugPrintGraph();
+  }
+}
+
+std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const {
+  std::stringstream out;
+  bool first = true;
+  if (edgeId == null_edge) {
+    out << "null";
+  } else {
+    while (edgeId != null_edge) {
+      const EqualityEdge& edge = d_equalityEdges[edgeId];
+      if (!first) out << ",";
+      out << d_nodes[edge.getNodeId()];
+      edgeId = edge.getNext();
+      first = false;
+    }
+  }
+  return out.str();
+}
+
+void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, std::vector<TNode>& equalities) {
+  Debug("equality") << "EqualityEngine::explainEquality(" << t1 << "," << t2 << ")" << std::endl;
+
+  // Don't notify during this check
+  ScopedBool turnOffNotify(d_performNotify, false);
+
+  // Add the terms (they might not be there)
+  addTerm(t1);
+  addTerm(t2);
+
+  if (polarity) {
+    // Get the explanation
+    EqualityNodeId t1Id = getNodeId(t1);
+    EqualityNodeId t2Id = getNodeId(t2);
+    getExplanation(t1Id, t2Id, equalities);
+  } else {
+    // Add the equality
+    Node equality = t1.eqNode(t2);
+    addTerm(equality);
+
+    // Get the explanation
+    EqualityNodeId equalityId = getNodeId(equality);
+    EqualityNodeId falseId = getNodeId(d_false);
+    getExplanation(equalityId, falseId, equalities);
+  }
+}
+
+void EqualityEngine::explainPredicate(TNode p, bool polarity, std::vector<TNode>& assertions) {
+  Debug("equality") << "EqualityEngine::explainEquality(" << p << ")" << std::endl;
+
+  // Don't notify during this check
+  ScopedBool turnOffNotify(d_performNotify, false);
+
+  // Add the terms
+  addTerm(p);
+
+  // Get the explanation
+  getExplanation(getNodeId(p), getNodeId(polarity ? d_true : d_false), assertions);
+}
+
+void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, std::vector<TNode>& equalities) const {
+
+  Debug("equality") << "EqualityEngine::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl;
+
+  Assert(getEqualityNode(t1Id).getFind() == getEqualityNode(t2Id).getFind());
+
+  // If the nodes are the same, we're done
+  if (t1Id == t2Id) return;
+
+  if (Debug.isOn("equality::internal")) {
+    debugPrintGraph();
+  }
+
+  // Queue for the BFS containing nodes
+  std::vector<BfsData> bfsQueue;
+
+  // Find a path from t1 to t2 in the graph (BFS)
+  bfsQueue.push_back(BfsData(t1Id, null_id, 0));
+  size_t currentIndex = 0;
+  while (true) {
+    // There should always be a path, and every node can be visited only once (tree)
+    Assert(currentIndex < bfsQueue.size());
+
+    // The next node to visit
+    BfsData current = bfsQueue[currentIndex];
+    EqualityNodeId currentNode = current.nodeId;
+
+    Debug("equality") << "EqualityEngine::getExplanation(): currentNode =  " << d_nodes[currentNode] << std::endl;
+
+    // Go through the equality edges of this node
+    EqualityEdgeId currentEdge = d_equalityGraph[currentNode];
+    Debug("equality") << "EqualityEngine::getExplanation(): edgesId =  " << currentEdge << std::endl;
+    Debug("equality") << "EqualityEngine::getExplanation(): edges =  " << edgesToString(currentEdge) << std::endl;
+
+    while (currentEdge != null_edge) {
+      // Get the edge
+      const EqualityEdge& edge = d_equalityEdges[currentEdge];
+
+      // If not just the backwards edge
+      if ((currentEdge | 1u) != (current.edgeId | 1u)) {
+
+        Debug("equality") << "EqualityEngine::getExplanation(): currentEdge = (" << d_nodes[currentNode] << "," << d_nodes[edge.getNodeId()] << ")" << std::endl;
+
+        // Did we find the path
+        if (edge.getNodeId() == t2Id) {
+
+          Debug("equality") << "EqualityEngine::getExplanation(): path found: " << std::endl;
+
+          // Reconstruct the path
+          do {
+            // The current node
+            currentNode = bfsQueue[currentIndex].nodeId;
+            EqualityNodeId edgeNode = d_equalityEdges[currentEdge].getNodeId();
+            MergeReasonType reasonType = d_equalityEdges[currentEdge].getReasonType();
+
+            Debug("equality") << "EqualityEngine::getExplanation(): currentEdge = " << currentEdge << ", currentNode = " << currentNode << std::endl;
+
+            // Add the actual equality to the vector
+            if (reasonType == MERGED_THROUGH_CONGRUENCE) {
+              // f(x1, x2) == f(y1, y2) because x1 = y1 and x2 = y2
+              Debug("equality") << "EqualityEngine::getExplanation(): due to congruence, going deeper" << std::endl;
+              const FunctionApplication& f1 = d_applications[currentNode].original;
+              const FunctionApplication& f2 = d_applications[edgeNode].original;
+              Debug("equality") << push;
+              getExplanation(f1.a, f2.a, equalities);
+              getExplanation(f1.b, f2.b, equalities);
+              Debug("equality") << pop;
+            } else {
+              // Construct the equality
+              Debug("equality") << "EqualityEngine::getExplanation(): adding: " << d_equalityEdges[currentEdge].getReason() << std::endl;
+              equalities.push_back(d_equalityEdges[currentEdge].getReason());
+            }
+
+            // Go to the previous
+            currentEdge = bfsQueue[currentIndex].edgeId;
+            currentIndex = bfsQueue[currentIndex].previousIndex;
+          } while (currentEdge != null_id);
+
+          // Done
+          return;
+        }
+
+        // Push to the visitation queue if it's not the backward edge
+        bfsQueue.push_back(BfsData(edge.getNodeId(), currentEdge, currentIndex));
+      }
+
+      // Go to the next edge
+      currentEdge = edge.getNext();
+    }
+
+    // Go to the next node to visit
+    ++ currentIndex;
+  }
+}
+
+void EqualityEngine::addTriggerEquality(TNode eq) {
+  Assert(eq.getKind() == kind::EQUAL);
+  // Add the terms
+  addTerm(eq);
+  // Positive trigger
+  addTriggerEqualityInternal(eq[0], eq[1], eq, true);
+  // Negative trigger
+  addTriggerEqualityInternal(eq, d_false, eq, false);
+}
+
+void EqualityEngine::addTriggerPredicate(TNode predicate) {
+  Assert(predicate.getKind() != kind::NOT && predicate.getKind() != kind::EQUAL);
+  Assert(d_congruenceKinds.tst(predicate.getKind()), "No point in adding non-congruence predicates");
+  // Add the term
+  addTerm(predicate);
+  // Positive trigger
+  addTriggerEqualityInternal(predicate, d_true, predicate, true);
+  // Negative trigger
+  addTriggerEqualityInternal(predicate, d_false, predicate, false);
+}
+
+void EqualityEngine::addTriggerEqualityInternal(TNode t1, TNode t2, TNode trigger, bool polarity) {
+
+  Debug("equality") << "EqualityEngine::addTrigger(" << t1 << ", " << t2 << ", " << trigger << ")" << std::endl;
+
+  Assert(hasTerm(t1));
+  Assert(hasTerm(t2));
+
+  // Get the information about t1
+  EqualityNodeId t1Id = getNodeId(t1);
+  EqualityNodeId t1classId = getEqualityNode(t1Id).getFind();
+  TriggerId t1TriggerId = d_nodeTriggers[t1classId];
+
+  // Get the information about t2
+  EqualityNodeId t2Id = getNodeId(t2);
+  EqualityNodeId t2classId = getEqualityNode(t2Id).getFind();
+  TriggerId t2TriggerId = d_nodeTriggers[t2classId];
+
+  Debug("equality") << "EqualityEngine::addTrigger(" << trigger << "): " << t1Id << " (" << t1classId << ") = " << t2Id << " (" << t2classId << ")" << std::endl;
+
+  // Create the triggers
+  TriggerId t1NewTriggerId = d_equalityTriggers.size();
+  TriggerId t2NewTriggerId = t1NewTriggerId | 1;
+  d_equalityTriggers.push_back(Trigger(t1classId, t1TriggerId));
+  d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity));
+  d_equalityTriggers.push_back(Trigger(t2classId, t2TriggerId));
+  d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity));
+
+  // Update the counters
+  d_equalityTriggersCount = d_equalityTriggersCount + 2;
+
+  // Add the trigger to the trigger graph
+  d_nodeTriggers[t1classId] = t1NewTriggerId;
+  d_nodeTriggers[t2classId] = t2NewTriggerId;
+
+  // If the trigger is already on, we propagate
+  if (t1classId == t2classId) {
+    Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << "): triggered at setup time" << std::endl;
+    if (d_performNotify) {
+      d_notify.eqNotifyTriggerEquality(trigger, polarity); // Don't care about the return value
+    }
+  }
+
+  if (Debug.isOn("equality::internal")) {
+    debugPrintGraph();
+  }
+
+  Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ") => (" << t1NewTriggerId << ", " << t2NewTriggerId << ")" << std::endl;
+}
+
+void EqualityEngine::propagate() {
+
+  Debug("equality") << "EqualityEngine::propagate()" << std::endl;
+
+  bool done = false;
+  while (!d_propagationQueue.empty()) {
+
+    // The current merge candidate
+    const MergeCandidate current = d_propagationQueue.front();
+    d_propagationQueue.pop();
+
+    if (done) {
+      // If we're done, just empty the queue
+      continue;
+    }
+
+    // Get the representatives
+    EqualityNodeId t1classId = getEqualityNode(current.t1Id).getFind();
+    EqualityNodeId t2classId = getEqualityNode(current.t2Id).getFind();
+
+    // If already the same, we're done
+    if (t1classId == t2classId) {
+      continue;
+    }
+
+    // Get the nodes of the representatives
+    EqualityNode& node1 = getEqualityNode(t1classId);
+    EqualityNode& node2 = getEqualityNode(t2classId);
+
+    Assert(node1.getFind() == t1classId);
+    Assert(node2.getFind() == t2classId);
+
+    // Add the actual equality to the equality graph
+    addGraphEdge(current.t1Id, current.t2Id, current.type, current.reason);
+
+    // One more equality added
+    d_assertedEqualitiesCount = d_assertedEqualitiesCount + 1;
+
+    // Depending on the merge preference (such as size), merge them
+    std::vector<TriggerId> triggers;
+    if (node2.getSize() > node1.getSize()) {
+      Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t1Id]<< " into " << d_nodes[current.t2Id] << std::endl;
+      d_assertedEqualities.push_back(Equality(t2classId, t1classId));
+      done = !merge(node2, node1, triggers);
+    } else {
+      Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t2Id] << " into " << d_nodes[current.t1Id] << std::endl;
+      d_assertedEqualities.push_back(Equality(t1classId, t2classId));
+      done = !merge(node1, node2, triggers);
+    }
+
+    // Notify the triggers
+    if (d_performNotify && !done) {
+      for (size_t trigger_i = 0, trigger_end = triggers.size(); trigger_i < trigger_end && !done; ++ trigger_i) {
+        const TriggerInfo& triggerInfo = d_equalityTriggersOriginal[triggers[trigger_i]];
+        // Notify the trigger and exit if it fails
+        if (triggerInfo.trigger.getKind() == kind::EQUAL) {
+          done = !d_notify.eqNotifyTriggerEquality(triggerInfo.trigger, triggerInfo.polarity);
+        } else {
+          done = !d_notify.eqNotifyTriggerPredicate(triggerInfo.trigger, triggerInfo.polarity);
+        }
+      }
+    }
+  }
+}
+
+void EqualityEngine::debugPrintGraph() const {
+  for (EqualityNodeId nodeId = 0; nodeId < d_nodes.size(); ++ nodeId) {
+
+    Debug("equality::graph") << d_nodes[nodeId] << " " << nodeId << "(" << getEqualityNode(nodeId).getFind() << "):";
+
+    EqualityEdgeId edgeId = d_equalityGraph[nodeId];
+    while (edgeId != null_edge) {
+      const EqualityEdge& edge = d_equalityEdges[edgeId];
+      Debug("equality::graph") << " " << d_nodes[edge.getNodeId()] << ":" << edge.getReason();
+      edgeId = edge.getNext();
+    }
+
+    Debug("equality::graph") << std::endl;
+  }
+}
+
+bool EqualityEngine::areEqual(TNode t1, TNode t2)
+{
+  // Don't notify during this check
+  ScopedBool turnOffNotify(d_performNotify, false);
+
+  // Add the terms
+  addTerm(t1);
+  addTerm(t2);
+  bool equal = getEqualityNode(t1).getFind() == getEqualityNode(t2).getFind();
+
+  // Return whether the two terms are equal
+  return equal;
+}
+
+bool EqualityEngine::areDisequal(TNode t1, TNode t2)
+{
+  // Don't notify during this check
+  ScopedBool turnOffNotify(d_performNotify, false);
+
+  // Add the terms
+  addTerm(t1);
+  addTerm(t2);
+
+  // Check (t1 = t2) = false
+  // No need to check the symmetric version: we can only deduce a disequality from an existing
+  // diseqality, and each of those is asserted in the symmetric version also
+  Node equality = t1.eqNode(t2);
+  addTerm(equality);
+  if (getEqualityNode(equality).getFind() == getEqualityNode(d_false).getFind()) {
+    return true;
+  }
+
+  // Return whether the terms are disequal
+  return false;
+}
+
+size_t EqualityEngine::getSize(TNode t)
+{
+  // Add the term
+  addTerm(t);
+  return getEqualityNode(getEqualityNode(t).getFind()).getSize();
+}
+
+void EqualityEngine::addTriggerTerm(TNode t)
+{
+  Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ")" << std::endl;
+
+  // Add the term if it's not already there
+  addTerm(t);
+
+  // Get the node id
+  EqualityNodeId eqNodeId = getNodeId(t);
+  EqualityNode& eqNode = getEqualityNode(eqNodeId);
+  EqualityNodeId classId = eqNode.getFind();
+
+  if (d_nodeIndividualTrigger[classId] != +null_id) {  
+    // No need to keep it, just propagate the existing individual triggers
+    if (d_performNotify) {
+      d_notify.eqNotifyTriggerTermEquality(t, d_nodes[d_nodeIndividualTrigger[classId]], true);
+    }
+  } else {
+    // Add it to the list for backtracking
+    d_individualTriggers.push_back(classId);
+    d_individualTriggersSize = d_individualTriggersSize + 1; 
+    // Mark the class id as a trigger
+    d_nodeIndividualTrigger[classId] = eqNodeId;
+  }
+}
+
+bool EqualityEngine::isTriggerTerm(TNode t) const {
+  if (!hasTerm(t)) return false;
+  EqualityNodeId classId = getEqualityNode(t).getFind();
+  return d_nodeIndividualTrigger[classId] != +null_id;
+}
+
+
+TNode EqualityEngine::getTriggerTermRepresentative(TNode t) const {
+  Assert(isTriggerTerm(t));
+  EqualityNodeId classId = getEqualityNode(t).getFind();
+  return d_nodes[d_nodeIndividualTrigger[classId]];
+}
+
+void EqualityEngine::storeApplicationLookup(FunctionApplication& funNormalized, EqualityNodeId funId) {
+  Assert(d_applicationLookup.find(funNormalized) == d_applicationLookup.end());
+  d_applicationLookup[funNormalized] = funId;
+  d_applicationLookups.push_back(funNormalized);
+  d_applicationLookupsCount = d_applicationLookupsCount + 1;
+  Debug("equality::backtrack") << "d_applicationLookupsCount = " << d_applicationLookupsCount << std::endl;
+  Debug("equality::backtrack") << "d_applicationLookups.size() = " << d_applicationLookups.size() << std::endl;
+  Assert(d_applicationLookupsCount == d_applicationLookups.size());
+}
+
+void EqualityEngine::getUseListTerms(TNode t, std::set<TNode>& output) {
+  if (hasTerm(t)) {
+    // Get the equivalence class
+    EqualityNodeId classId = getEqualityNode(t).getFind();
+    // Go through the equivalence class and get where t is used in
+    EqualityNodeId currentId = classId;
+    do {
+      // Get the current node
+      EqualityNode& currentNode = getEqualityNode(currentId);
+      // Go through the use-list
+      UseListNodeId currentUseId = currentNode.getUseList();
+      while (currentUseId != null_uselist_id) {
+        // Get the node of the use list
+        UseListNode& useNode = d_useListNodes[currentUseId];
+        // Get the function application
+        EqualityNodeId funId = useNode.getApplicationId();
+        output.insert(d_nodes[funId]);
+        // Go to the next one in the use list
+        currentUseId = useNode.getNext();
+      }
+      // Move to the next node
+      currentId = currentNode.getNext();
+    } while (currentId != classId);
+  }
+}
+
+} // Namespace uf
+} // Namespace theory
+} // Namespace CVC4
+
index dccd5ba5672591f3a17db685701b672ad00bb28c..f9c10d1b6cc5d72b40059336dc3bb515b18eeb6e 100644 (file)
@@ -35,7 +35,7 @@
 
 namespace CVC4 {
 namespace theory {
-namespace uf {
+namespace eq {
 
 /** Id of the node */
 typedef size_t EqualityNodeId;
@@ -213,9 +213,74 @@ public:
   }
 };
 
-template <typename NotifyClass>
+/**
+ * Interface for equality engine notifications. All the notifications
+ * are safe as TNodes, but not necessarily for negations.
+ */
+class EqualityEngineNotify {
+
+  friend class EqualityEngine;
+
+public:
+
+  virtual ~EqualityEngineNotify() {};
+
+  /**
+   * Notifies about a trigger equality that became true or false.
+   *
+   * @param eq the equality that became true or false
+   * @param value the value of the equality
+   */
+  virtual bool eqNotifyTriggerEquality(TNode equality, bool value) = 0;
+
+  /**
+   * Notifies about a trigger predicate that became true or false.
+   *
+   * @param predicate the trigger predicate that bacame true or false
+   * @param value the value of the predicate
+   */
+  virtual bool eqNotifyTriggerPredicate(TNode predicate, bool value) = 0;
+
+  /**
+   * Notifies about the merge of two trigger terms.
+   *
+   * @param t1 a term marked as trigger
+   * @param t2 a term marked as trigger
+   * @param value true if equal, false if dis-equal
+   */
+  virtual bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) = 0;
+
+  /**
+   * Notifies about the merge of two constant terms.
+   *
+   * @param t1 a constant term
+   * @param t2 a constnat term
+   */
+  virtual bool eqNotifyConstantTermMerge(TNode t1, TNode t2) = 0;
+};
+
+/**
+ * Implementation of the notification interface that ignores all the
+ * notifications.
+ */
+class EqualityEngineNotifyNone : public EqualityEngineNotify {
+public:
+  bool eqNotifyTriggerEquality(TNode equality, bool value) { return true; }
+  bool eqNotifyTriggerPredicate(TNode predicate, bool value) { return true; }
+  bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { return true; }
+  bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { return true; }
+};
+
+
+/**
+ * Class for keeping an incremental congurence closure over a set of terms. It provides
+ * notifications via an EqualityEngineNotify object.
+ */
 class EqualityEngine : public context::ContextNotifyObj {
 
+  /** Default implementation of the notification object */
+  static EqualityEngineNotifyNone s_notifyNone;
+
 public:
 
   /** Statistics about the equality engine instance */
@@ -226,21 +291,26 @@ public:
     IntStat termsCount;
     /** Number of function terms managed by the system */
     IntStat functionTermsCount;
+    /** Number of constant terms managed by the system */
+    IntStat constantTermsCount;
 
     Statistics(std::string name)
     : mergesCount(name + "::mergesCount", 0),
       termsCount(name + "::termsCount", 0),
-      functionTermsCount(name + "::functionTermsCount", 0)
+      functionTermsCount(name + "::functionTermsCount", 0),
+      constantTermsCount(name + "::constantTermsCount", 0)
     {
       StatisticsRegistry::registerStat(&mergesCount);
       StatisticsRegistry::registerStat(&termsCount);
       StatisticsRegistry::registerStat(&functionTermsCount);
+      StatisticsRegistry::registerStat(&constantTermsCount);
     }
 
     ~Statistics() {
       StatisticsRegistry::unregisterStat(&mergesCount);
       StatisticsRegistry::unregisterStat(&termsCount);
       StatisticsRegistry::unregisterStat(&functionTermsCount);
+      StatisticsRegistry::unregisterStat(&constantTermsCount);
     }
   };
 
@@ -282,7 +352,7 @@ private:
   bool d_performNotify;
 
   /** The class to notify when a representative changes for a term */
-  NotifyClass d_notify;
+  EqualityEngineNotify& d_notify;
 
   /** The map of kinds to be treated as function applications */
   KindMap d_congruenceKinds;
@@ -428,8 +498,11 @@ private:
   /** Returns the id of the node */
   EqualityNodeId getNodeId(TNode node) const;
 
-  /** Merge the class2 into class1 */
-  void merge(EqualityNode& class1, EqualityNode& class2, std::vector<TriggerId>& triggers);
+  /**
+   * Merge the class2 into class1
+   * @return true if ok, false if to break out
+   */
+  bool merge(EqualityNode& class1, EqualityNode& class2, std::vector<TriggerId>& triggers);
 
   /** Undo the mereg of class2 into class1 */
   void undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id);
@@ -437,29 +510,13 @@ private:
   /** Backtrack the information if necessary */
   void backtrack();
 
-  /**
-   * Data used in the BFS search through the equality graph.
-   */
-  struct BfsData {
-    // The current node
-    EqualityNodeId nodeId;
-    // The index of the edge we traversed
-    EqualityEdgeId edgeId;
-    // Index in the queue of the previous node. Shouldn't be too much of them, at most the size
-    // of the biggest equivalence class
-    size_t previousIndex;
-
-    BfsData(EqualityNodeId nodeId = null_id, EqualityEdgeId edgeId = null_edge, size_t prev = 0)
-    : nodeId(nodeId), edgeId(edgeId), previousIndex(prev) {}
-  };
-
   /**
    * Trigger that will be updated
    */
   struct Trigger {
     /** The current class id of the LHS of the trigger */
     EqualityNodeId classId;
-    /** Next trigger for class */
+    /** Next trigger for class */
     TriggerId nextTrigger;
 
     Trigger(EqualityNodeId classId = null_id, TriggerId nextTrigger = null_trigger)
@@ -473,10 +530,20 @@ private:
    */
   std::vector<Trigger> d_equalityTriggers;
 
+  struct TriggerInfo {
+    /** The trigger itself */
+    Node trigger;
+    /** Polarity of the trigger */
+    bool polarity;
+    TriggerInfo() {}
+    TriggerInfo(Node trigger, bool polarity)
+    : trigger(trigger), polarity(polarity) {}
+  };
+
   /**
    * Vector of original equalities of the triggers.
    */
-  std::vector<Node> d_equalityTriggersOriginal;
+  std::vector<TriggerInfo> d_equalityTriggersOriginal;
 
   /**
    * Context dependent count of triggers
@@ -504,6 +571,19 @@ private:
    */
   std::vector<EqualityNodeId> d_nodeIndividualTrigger;
 
+  /**
+   * Map from ids to the id of the constant that is the representative.
+   */
+  std::vector<EqualityNodeId> d_constantRepresentative;
+
+  /**
+   * Size of the constant representatives list.
+   */
+  context::CDO<unsigned> d_constantRepresentativesSize;
+  
+  /** The list of representatives that became constant. */ 
+  std::vector<EqualityNodeId> d_constantRepresentatives;
+
   /**
    * Adds the trigger with triggerId to the beginning of the trigger list of the node with id nodeId.
    */
@@ -516,7 +596,7 @@ private:
   EqualityNodeId newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2);
 
   /** Add a new node to the database */
-  EqualityNodeId newNode(TNode t, bool isApplication);
+  EqualityNodeId newNode(TNode t);
 
   struct MergeCandidate {
     EqualityNodeId t1Id, t2Id;
@@ -561,44 +641,41 @@ private:
   /**
    * Adds an equality of terms t1 and t2 to the database.
    */
-  void addEqualityInternal(TNode t1, TNode t2, TNode reason);
+  void assertEqualityInternal(TNode t1, TNode t2, TNode reason);
 
-public:
+  /**
+   * Adds a trigger equality to the database with the trigger node and polarity for notification.
+   */
+  void addTriggerEqualityInternal(TNode t1, TNode t2, TNode trigger, bool polarity);
 
   /**
-   * Initialize the equality engine, given the owning class. This will initialize the notifier with
-   * the owner information.
-   */
-  EqualityEngine(NotifyClass& notify, context::Context* context, std::string name)
-  : ContextNotifyObj(context),
-    d_context(context),
-    d_performNotify(true),
-    d_notify(notify),
-    d_applicationLookupsCount(context, 0),
-    d_nodesCount(context, 0),
-    d_assertedEqualitiesCount(context, 0),
-    d_equalityTriggersCount(context, 0),
-    d_individualTriggersSize(context, 0),
-    d_stats(name)
-  {
-    Debug("equality") << "EqualityEdge::EqualityEngine(): id_null = " << +null_id << std::endl;
-    Debug("equality") << "EqualityEdge::EqualityEngine(): edge_null = " << +null_edge << std::endl;
-    Debug("equality") << "EqualityEdge::EqualityEngine(): trigger_null = " << +null_trigger << std::endl;
-    d_true = NodeManager::currentNM()->mkConst<bool>(true);
-    d_false = NodeManager::currentNM()->mkConst<bool>(false);
+   * This method gets called on backtracks from the context manager.
+   */
+  void contextNotifyPop() {
+    backtrack();
   }
 
   /**
-   * Just a destructor.
+   * Constructor initialization stuff.
    */
-  virtual ~EqualityEngine() throw(AssertionException) {}
+  void init();
+
+public:
 
   /**
-   * This method gets called on backtracks from the context manager.
+   * Initialize the equality engine, given the notification class. 
    */
-  void notify() {
-    backtrack();
-  }
+  EqualityEngine(EqualityEngineNotify& notify, context::Context* context, std::string name);
+
+  /**
+   * Initialize the equality engine with no notification class. 
+   */
+  EqualityEngine(context::Context* context, std::string name);
+
+  /**
+   * Just a destructor.
+   */
+  virtual ~EqualityEngine() throw(AssertionException) {}
 
   /**
    * Adds a term to the term database.
@@ -629,77 +706,91 @@ public:
   bool hasTerm(TNode t) const;
 
   /**
-   * Adds aa predicate t with given polarity
+   * Adds a predicate p with given polarity. The predicate asserted
+   * should be in the coungruence closure kinds (otherwise it's 
+   * useless.
+   *
+   * @param p the (non-negated) predicate
+   * @param polarity true if asserting the predicate, false if 
+   *                 asserting the negated predicate
+   * @param the reason to keep for building explanations
    */
-  void addPredicate(TNode t, bool polarity, TNode reason);
+  void assertPredicate(TNode p, bool polarity, TNode reason);
 
   /**
-   * Adds an equality t1 = t2 to the database.
+   * Adds an equality eq with the given polarity to the database.
+   *
+   * @param eq the (non-negated) equality
+   * @param polarity true if asserting the equality, false if 
+   *                 asserting the negated equality
+   * @param the reason to keep for building explanations
    */
-  void addEquality(TNode t1, TNode t2, TNode reason);
+  void assertEquality(TNode eq, bool polarity, TNode reason);
 
   /**
-   * Adds an dis-equality t1 != t2 to the database.
-   */
-  void addDisequality(TNode t1, TNode t2, TNode reason);
-
-  /**
-   * Returns the representative of the term t.
+   * Returns the current representative of the term t.
    */
   TNode getRepresentative(TNode t) const;
 
   /**
-   * Add all the terms where the given term appears in (directly or implicitly).
+   * Add all the terms where the given term appears as a first child 
+   * (directly or implicitly).
    */
   void getUseListTerms(TNode t, std::set<TNode>& output);
 
   /**
-   * Returns true if the two nodes are in the same class.
+   * Returns true if the two nodes are in the same equivalence class.
    */
   bool areEqual(TNode t1, TNode t2) const;
 
   /**
-   * Get an explanation of the equality t1 = t2. Returns the asserted equalities that
-   * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere
-   * else. 
+   * Get an explanation of the equality t1 = t2 begin true of false. 
+   * Returns the reasons (added when asserting) that imply it
+   * in the assertions vector.
    */
-  void explainEquality(TNode t1, TNode t2, std::vector<TNode>& equalities);
+  void explainEquality(TNode t1, TNode t2, bool polarity, std::vector<TNode>& assertions);
 
   /**
-   * Get an explanation of the equality t1 = t2. Returns the asserted equalities that
-   * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere
-   * else. 
+   * Get an explanation of the predicate being true or false. 
+   * Returns the reasons (added when asserting) that imply imply it
+   * in the assertions vector.
    */
-  void explainDisequality(TNode t1, TNode t2, std::vector<TNode>& equalities);
+  void explainPredicate(TNode p, bool polarity, std::vector<TNode>& assertions);
 
   /**
-   * Add term to the trigger terms. The notify class will get notified when two 
-   * trigger terms become equal. Thihs will only happen on trigger term 
-   * representatives.
+   * Add term to the trigger terms. The notify class will get notified 
+   * when two trigger terms become equal or dis-equal. The notification
+   * will not happen on all the terms, but only on the ones that are 
+   * represent the class.
    */
   void addTriggerTerm(TNode t);
 
   /**
-   * Returns true if t is a trigger term or equal to some other trigger term.
+   * Returns true if t is a trigger term or in the same equivalence 
+   * class as some other trigger term.
    */
   bool isTriggerTerm(TNode t) const;
 
   /**
-   * Returns the representative trigger term (isTriggerTerm(t)) should be true.
+   * Returns the representative trigger term of the given term.
+   *
+   * @param t the term to check where isTriggerTerm(t) should be true
    */
   TNode getTriggerTermRepresentative(TNode t) const;
 
   /**
-   * Adds a notify trigger for equality t1 = t2, i.e. when t1 = t2 the notify will be called with
-   * trigger.
+   * Adds a notify trigger for equality. When equality becomes true eqNotifyTriggerEquality
+   * will be called with value = true, and when equality becomes false eqNotifyTriggerEquality
+   * will be called with value = false.
    */
-  void addTriggerEquality(TNode t1, TNode t2, TNode trigger);
+  void addTriggerEquality(TNode equality);
 
   /**
-   * Adds a notify trigger for dis-equality t1 != t2, i.e. when t1 != t2 the notify will be called with
-   * trigger.
+   * Adds a notify trigger for the predicate p. When the predicate becomes true
+   * eqNotifyTriggerPredicate will be called with value = true, and when equality becomes false
+   * eqNotifyTriggerPredicate will be called with value = false.
    */
-  void addTriggerDisequality(TNode t1, TNode t2, TNode trigger);
+  void addTriggerPredicate(TNode predicate);
 
   /**
    * Check whether the two terms are equal.
@@ -712,7 +803,7 @@ public:
   bool areDisequal(TNode t1, TNode t2);
 
   /**
-   * Return the number of nodes in the equivalence class contianing t
+   * Return the number of nodes in the equivalence class containing t
    * Adds t if not already there.
    */
   size_t getSize(TNode t);
diff --git a/src/theory/uf/equality_engine_impl.h b/src/theory/uf/equality_engine_impl.h
deleted file mode 100644 (file)
index be12e5f..0000000
+++ /dev/null
@@ -1,947 +0,0 @@
-/*********************                                                        */
-/*! \file equality_engine_impl.h
- ** \verbatim
- ** Original author: dejan
- ** Major contributors: none
- ** Minor contributors (to current version): none
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010, 2011  The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
- **
- ** \brief [[ Add one-line brief description here ]]
- **
- ** [[ Add lengthier description here ]]
- ** \todo document this file
- **/
-
-#include "cvc4_private.h"
-
-#pragma once
-
-#include "theory/uf/equality_engine.h"
-
-namespace CVC4 {
-namespace theory {
-namespace uf {
-
-class ScopedBool {
-  bool& watch;
-  bool oldValue;
-public:
-  ScopedBool(bool& watch, bool newValue)
-  : watch(watch), oldValue(watch) {
-    watch = newValue;
-  }
-  ~ScopedBool() {
-    watch = oldValue;
-  }
-};
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::enqueue(const MergeCandidate& candidate) {
-    Debug("equality") << "EqualityEngine::enqueue(" << candidate.toString(*this) << ")" << std::endl;
-    d_propagationQueue.push(candidate);
-}
-
-template <typename NotifyClass>
-EqualityNodeId EqualityEngine<NotifyClass>::newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2) {
-  Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ")" << std::endl;
-
-  ++ d_stats.functionTermsCount;
-
-  // Get another id for this
-  EqualityNodeId funId = newNode(original, true);
-  FunctionApplication funOriginal(t1, t2);
-  // The function application we're creating
-  EqualityNodeId t1ClassId = getEqualityNode(t1).getFind();
-  EqualityNodeId t2ClassId = getEqualityNode(t2).getFind();
-  FunctionApplication funNormalized(t1ClassId, t2ClassId);
-
-  // We add the original version
-  d_applications[funId] = FunctionApplicationPair(funOriginal, funNormalized);
-
-  // Add the lookup data, if it's not already there
-  typename ApplicationIdsMap::iterator find = d_applicationLookup.find(funNormalized);
-  if (find == d_applicationLookup.end()) {
-    // When we backtrack, if the lookup is not there anymore, we'll add it again
-    Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): no lookup, setting up" << std::endl;
-    // Mark the normalization to the lookup
-    storeApplicationLookup(funNormalized, funId);
-  } else {
-    // If it's there, we need to merge these two
-    Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): lookup exists, adding to queue" << std::endl;
-    enqueue(MergeCandidate(funId, find->second, MERGED_THROUGH_CONGRUENCE, TNode::null()));
-    propagate();
-  }
-
-  // Add to the use lists
-  d_equalityNodes[t1ClassId].usedIn(funId, d_useListNodes);
-  d_equalityNodes[t2ClassId].usedIn(funId, d_useListNodes);
-
-  // Return the new id
-  Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ") => " << funId << std::endl;
-
-  return funId;
-}
-
-template <typename NotifyClass>
-EqualityNodeId EqualityEngine<NotifyClass>::newNode(TNode node, bool isApplication) {
-
-  Debug("equality") << "EqualityEngine::newNode(" << node << ", " << (isApplication ? "function" : "regular") << ")" << std::endl;
-
-  ++ d_stats.termsCount;
-
-  // Register the new id of the term
-  EqualityNodeId newId = d_nodes.size();
-  d_nodeIds[node] = newId;
-  // Add the node to it's position
-  d_nodes.push_back(node);
-  // Note if this is an application or not
-  d_applications.push_back(FunctionApplicationPair());
-  // Add the trigger list for this node
-  d_nodeTriggers.push_back(+null_trigger);
-  // Add it to the equality graph
-  d_equalityGraph.push_back(+null_edge);
-  // Mark the no-individual trigger
-  d_nodeIndividualTrigger.push_back(+null_id);
-  // Add the equality node to the nodes
-  d_equalityNodes.push_back(EqualityNode(newId));
-
-  // Increase the counters
-  d_nodesCount = d_nodesCount + 1;
-
-  Debug("equality") << "EqualityEngine::newNode(" << node << ", " << (isApplication ? "function" : "regular") << ") => " << newId << std::endl;
-
-  return newId;
-}
-
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::addTerm(TNode t) {
-
-  Debug("equality") << "EqualityEngine::addTerm(" << t << ")" << std::endl;
-
-  // If there already, we're done
-  if (hasTerm(t)) {
-    Debug("equality") << "EqualityEngine::addTerm(" << t << "): already there" << std::endl;
-    return;
-  }
-
-  EqualityNodeId result;
-
-  // If a function application we go in
-  if (t.getNumChildren() > 0 && d_congruenceKinds[t.getKind()])
-  {
-    // Add the operator
-    TNode tOp = t.getOperator();
-    addTerm(tOp);
-    // Add all the children and Curryfy
-    result = getNodeId(tOp);
-    for (unsigned i = 0; i < t.getNumChildren(); ++ i) {
-      // Add the child
-      addTerm(t[i]);
-      // Add the application
-      result = newApplicationNode(t, result, getNodeId(t[i]));
-    }
-  } else {
-    // Otherwise we just create the new id
-    result = newNode(t, false);
-  }
-
-  Debug("equality") << "EqualityEngine::addTerm(" << t << ") => " << result << std::endl;
-}
-
-template <typename NotifyClass>
-bool EqualityEngine<NotifyClass>::hasTerm(TNode t) const {
-  return d_nodeIds.find(t) != d_nodeIds.end();
-}
-
-template <typename NotifyClass>
-EqualityNodeId EqualityEngine<NotifyClass>::getNodeId(TNode node) const {
-  Assert(hasTerm(node), node.toString().c_str());
-  return (*d_nodeIds.find(node)).second;
-}
-
-template <typename NotifyClass>
-EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(TNode t) {
-  return getEqualityNode(getNodeId(t));
-}
-
-template <typename NotifyClass>
-EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(EqualityNodeId nodeId) {
-  Assert(nodeId < d_equalityNodes.size());
-  return d_equalityNodes[nodeId];
-}
-
-template <typename NotifyClass>
-const EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(TNode t) const {
-  return getEqualityNode(getNodeId(t));
-}
-
-template <typename NotifyClass>
-const EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(EqualityNodeId nodeId) const {
-  Assert(nodeId < d_equalityNodes.size());
-  return d_equalityNodes[nodeId];
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::addEqualityInternal(TNode t1, TNode t2, TNode reason) {
-
-  Debug("equality") << "EqualityEngine::addEqualityInternal(" << t1 << "," << t2 << ")" << std::endl;
-
-  // Add the terms if they are not already in the database
-  addTerm(t1);
-  addTerm(t2);
-
-  // Add to the queue and propagate
-  EqualityNodeId t1Id = getNodeId(t1);
-  EqualityNodeId t2Id = getNodeId(t2);
-  enqueue(MergeCandidate(t1Id, t2Id, MERGED_THROUGH_EQUALITY, reason));
-
-  propagate();
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::addPredicate(TNode t, bool polarity, TNode reason) {
-
-  Debug("equality") << "EqualityEngine::addPredicate(" << t << "," << (polarity ? "true" : "false") << ")" << std::endl;
-
-  addEqualityInternal(t, polarity ? d_true : d_false, reason);
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::addEquality(TNode t1, TNode t2, TNode reason) {
-
-  Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << ")" << std::endl;
-
-  addEqualityInternal(t1, t2, reason);
-
-  Node equality = t1.eqNode(t2);
-  addEqualityInternal(equality, d_true, reason);
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::addDisequality(TNode t1, TNode t2, TNode reason) {
-
-  Debug("equality") << "EqualityEngine::addDisequality(" << t1 << "," << t2 << ")" << std::endl;
-
-  Node equality1 = t1.eqNode(t2);
-  addEqualityInternal(equality1, d_false, reason);
-  Node equality2 = t2.eqNode(t1);
-  addEqualityInternal(equality2, d_false, reason);
-}
-
-
-template <typename NotifyClass>
-TNode EqualityEngine<NotifyClass>::getRepresentative(TNode t) const {
-
-  Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ")" << std::endl;
-
-  Assert(hasTerm(t));
-
-  // Both following commands are semantically const
-  EqualityNodeId representativeId = getEqualityNode(t).getFind();
-
-  Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ") => " << d_nodes[representativeId] << std::endl;
-
-  return d_nodes[representativeId];
-}
-
-template <typename NotifyClass>
-bool EqualityEngine<NotifyClass>::areEqual(TNode t1, TNode t2) const {
-  Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl;
-
-  Assert(hasTerm(t1));
-  Assert(hasTerm(t2));
-
-  // Both following commands are semantically const
-  EqualityNodeId rep1 = getEqualityNode(t1).getFind();
-  EqualityNodeId rep2 = getEqualityNode(t2).getFind();
-
-  Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ") => " << (rep1 == rep2 ? "true" : "false") << std::endl;
-
-  return rep1 == rep2;
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::merge(EqualityNode& class1, EqualityNode& class2, std::vector<TriggerId>& triggers) {
-
-  Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl;
-
-  Assert(triggers.empty());
-
-  ++ d_stats.mergesCount;
-
-  EqualityNodeId class1Id = class1.getFind();
-  EqualityNodeId class2Id = class2.getFind();
-
-  // Update class2 representative information
-  Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating class " << class2Id << std::endl;
-  EqualityNodeId currentId = class2Id;
-  do {
-    // Get the current node
-    EqualityNode& currentNode = getEqualityNode(currentId);
-
-    // Update it's find to class1 id
-    Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << currentId << "->" << class1Id << std::endl;
-    currentNode.setFind(class1Id);
-
-    // Go through the triggers and inform if necessary
-    TriggerId currentTrigger = d_nodeTriggers[currentId];
-    while (currentTrigger != null_trigger) {
-      Trigger& trigger = d_equalityTriggers[currentTrigger];
-      Trigger& otherTrigger = d_equalityTriggers[currentTrigger ^ 1];
-
-      // If the two are not already in the same class
-      if (otherTrigger.classId != trigger.classId) {
-        trigger.classId = class1Id;
-        // If they became the same, call the trigger
-        if (otherTrigger.classId == class1Id) {
-          // Id of the real trigger is half the internal one
-          triggers.push_back(currentTrigger);
-        }
-      }
-
-      // Go to the next trigger
-      currentTrigger = trigger.nextTrigger;
-    }
-
-    // Move to the next node
-    currentId = currentNode.getNext();
-
-  } while (currentId != class2Id);
-
-
-  // Update class2 table lookup and information
-  Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of " << class2Id << std::endl;
-  do {
-    // Get the current node
-    EqualityNode& currentNode = getEqualityNode(currentId);    
-    Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of node " << currentId << std::endl;
-    // Go through the uselist and check for congruences
-    UseListNodeId currentUseId = currentNode.getUseList();
-    while (currentUseId != null_uselist_id) {
-      // Get the node of the use list
-      UseListNode& useNode = d_useListNodes[currentUseId];
-      // Get the function application
-      EqualityNodeId funId = useNode.getApplicationId();
-      Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << currentId << " in " << d_nodes[funId] << std::endl;
-      const FunctionApplication& fun = d_applications[useNode.getApplicationId()].normalized;
-      // Check if there is an application with find arguments
-      EqualityNodeId aNormalized = getEqualityNode(fun.a).getFind();
-      EqualityNodeId bNormalized = getEqualityNode(fun.b).getFind();
-      FunctionApplication funNormalized(aNormalized, bNormalized);
-      typename ApplicationIdsMap::iterator find = d_applicationLookup.find(funNormalized);
-      if (find != d_applicationLookup.end()) {
-        // Applications fun and the funNormalized can be merged due to congruence
-        if (getEqualityNode(funId).getFind() != getEqualityNode(find->second).getFind()) {
-          enqueue(MergeCandidate(funId, find->second, MERGED_THROUGH_CONGRUENCE, TNode::null()));
-        }
-      } else {
-        // There is no representative, so we can add one, we remove this when backtracking
-        storeApplicationLookup(funNormalized, funId);
-      }
-      // Go to the next one in the use list
-      currentUseId = useNode.getNext();
-    }
-
-    // Move to the next node
-    currentId = currentNode.getNext();
-  } while (currentId != class2Id);
-
-  // Now merge the lists
-  class1.merge<true>(class2);
-
-  // Notfiy the triggers
-  EqualityNodeId class1triggerId = d_nodeIndividualTrigger[class1Id];
-  EqualityNodeId class2triggerId = d_nodeIndividualTrigger[class2Id];
-  if (class2triggerId != +null_id) {
-    if (class1triggerId == +null_id) {
-      // If class1 is not an individual trigger, but class2 is, mark it
-      d_nodeIndividualTrigger[class1Id] = class2triggerId;
-      // Add it to the list for backtracking
-      d_individualTriggers.push_back(class1Id);
-      d_individualTriggersSize = d_individualTriggersSize + 1;  
-    } else {
-      // Notify when done
-      if (d_performNotify) {
-        d_notify.notify(d_nodes[class1triggerId], d_nodes[class2triggerId]); 
-      }
-    }  
-  }
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id) {
-
-  Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << ")" << std::endl;
-
-  // Now unmerge the lists (same as merge)
-  class1.merge<false>(class2);
-
-  // Update class2 representative information
-  EqualityNodeId currentId = class2Id;
-  Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << "): undoing representative info" << std::endl;
-  do {
-    // Get the current node
-    EqualityNode& currentNode = getEqualityNode(currentId);
-
-    // Update it's find to class1 id
-    currentNode.setFind(class2Id);
-
-    // Go through the trigger list (if any) and undo the class
-    TriggerId currentTrigger = d_nodeTriggers[currentId];
-    while (currentTrigger != null_trigger) {
-      Trigger& trigger = d_equalityTriggers[currentTrigger];
-      trigger.classId = class2Id;
-      currentTrigger = trigger.nextTrigger;
-    }
-
-    // Move to the next node
-    currentId = currentNode.getNext();
-
-  } while (currentId != class2Id);
-
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::backtrack() {
-
-  Debug("equality::backtrack") << "backtracking" << std::endl;
-
-  // If we need to backtrack then do it
-  if (d_assertedEqualitiesCount < d_assertedEqualities.size()) {
-
-    // Clear the propagation queue
-    while (!d_propagationQueue.empty()) {
-      d_propagationQueue.pop();
-    }
-
-    Debug("equality") << "EqualityEngine::backtrack(): nodes" << std::endl;
-
-    for (int i = (int)d_assertedEqualities.size() - 1, i_end = (int)d_assertedEqualitiesCount; i >= i_end; --i) {
-      // Get the ids of the merged classes
-      Equality& eq = d_assertedEqualities[i];
-      // Undo the merge
-      undoMerge(d_equalityNodes[eq.lhs], d_equalityNodes[eq.rhs], eq.rhs);
-    }
-
-    d_assertedEqualities.resize(d_assertedEqualitiesCount);
-
-    Debug("equality") << "EqualityEngine::backtrack(): edges" << std::endl;
-
-    for (int i = (int)d_equalityEdges.size() - 2, i_end = (int)(2*d_assertedEqualitiesCount); i >= i_end; i -= 2) {
-      EqualityEdge& edge1 = d_equalityEdges[i];
-      EqualityEdge& edge2 = d_equalityEdges[i | 1];
-      d_equalityGraph[edge2.getNodeId()] = edge1.getNext();
-      d_equalityGraph[edge1.getNodeId()] = edge2.getNext();
-    }
-
-    d_equalityEdges.resize(2 * d_assertedEqualitiesCount);
-  }
-
-  if (d_individualTriggers.size() > d_individualTriggersSize) {
-    // Unset the individual triggers
-    for (int i = d_individualTriggers.size() - 1, i_end = d_individualTriggersSize; i >= i_end; -- i) {
-      d_nodeIndividualTrigger[d_individualTriggers[i]] = +null_id;
-    }
-    d_individualTriggers.resize(d_individualTriggersSize);
-  }
-  
-  if (d_equalityTriggers.size() > d_equalityTriggersCount) {
-    // Unlink the triggers from the lists
-    for (int i = d_equalityTriggers.size() - 1, i_end = d_equalityTriggersCount; i >= i_end; -- i) {
-      const Trigger& trigger = d_equalityTriggers[i];
-      d_nodeTriggers[trigger.classId] = trigger.nextTrigger;
-    }
-    // Get rid of the triggers 
-    d_equalityTriggers.resize(d_equalityTriggersCount);
-    d_equalityTriggersOriginal.resize(d_equalityTriggersCount);
-  }
-
-  if (d_applicationLookups.size() > d_applicationLookupsCount) {
-    for (int i = d_applicationLookups.size() - 1, i_end = (int) d_applicationLookupsCount; i >= i_end; -- i) {
-      d_applicationLookup.erase(d_applicationLookups[i]);
-    }
-    d_applicationLookups.resize(d_applicationLookupsCount);
-  }
-
-  if (d_nodes.size() > d_nodesCount) {
-    // Go down the nodes, check the application nodes and remove them from use-lists
-    for(int i = d_nodes.size() - 1, i_end = (int)d_nodesCount; i >= i_end; -- i) {
-      // Remove from the node -> id map
-      Debug("equality") << "EqualityEngine::backtrack(): removing node " << d_nodes[i] << std::endl;
-      d_nodeIds.erase(d_nodes[i]);
-
-      const FunctionApplication& app = d_applications[i].normalized;
-      if (app.isApplication()) {
-        // Remove b from use-list
-        getEqualityNode(app.b).removeTopFromUseList(d_useListNodes);
-        // Remove a from use-list
-        getEqualityNode(app.a).removeTopFromUseList(d_useListNodes);
-      }
-    }
-
-    // Now get rid of the nodes and the rest
-    d_nodes.resize(d_nodesCount);
-    d_applications.resize(d_nodesCount);
-    d_nodeTriggers.resize(d_nodesCount);
-    d_nodeIndividualTrigger.resize(d_nodesCount);
-    d_equalityGraph.resize(d_nodesCount);
-    d_equalityNodes.resize(d_nodesCount);
-  }
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::addGraphEdge(EqualityNodeId t1, EqualityNodeId t2, MergeReasonType type, TNode reason) {
-  Debug("equality") << "EqualityEngine::addGraphEdge(" << d_nodes[t1] << "," << d_nodes[t2] << "," << reason << ")" << std::endl;
-  EqualityEdgeId edge = d_equalityEdges.size();
-  d_equalityEdges.push_back(EqualityEdge(t2, d_equalityGraph[t1], type, reason));
-  d_equalityEdges.push_back(EqualityEdge(t1, d_equalityGraph[t2], type, reason));
-  d_equalityGraph[t1] = edge;
-  d_equalityGraph[t2] = edge | 1;
-
-  if (Debug.isOn("equality::internal")) {
-    debugPrintGraph();
-  }
-}
-
-template <typename NotifyClass>
-std::string EqualityEngine<NotifyClass>::edgesToString(EqualityEdgeId edgeId) const {
-  std::stringstream out;
-  bool first = true;
-  if (edgeId == null_edge) {
-    out << "null";
-  } else {
-    while (edgeId != null_edge) {
-      const EqualityEdge& edge = d_equalityEdges[edgeId];
-      if (!first) out << ",";
-      out << d_nodes[edge.getNodeId()];
-      edgeId = edge.getNext();
-      first = false;
-    }
-  }
-  return out.str();
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::explainEquality(TNode t1, TNode t2, std::vector<TNode>& equalities) {
-  Debug("equality") << "EqualityEngine::explainEquality(" << t1 << "," << t2 << ")" << std::endl;
-
-  // Don't notify during this check
-  ScopedBool turnOfNotify(d_performNotify, false);
-
-  // Add the terms (they might not be there)
-  addTerm(t1);
-  addTerm(t2);
-
-  Assert(getRepresentative(t1) == getRepresentative(t2),
-         "Cannot explain an equality, because the two terms are not equal!\n"
-         "The representative of %s\n"
-         "                   is %s\n"
-         "The representative of %s\n"
-         "                   is %s",
-         t1.toString().c_str(), getRepresentative(t1).toString().c_str(),
-         t2.toString().c_str(), getRepresentative(t2).toString().c_str());
-
-  // Get the explanation
-  EqualityNodeId t1Id = getNodeId(t1);
-  EqualityNodeId t2Id = getNodeId(t2);
-  getExplanation(t1Id, t2Id, equalities);
-
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::explainDisequality(TNode t1, TNode t2, std::vector<TNode>& equalities) {
-  Debug("equality") << "EqualityEngine::explainDisequality(" << t1 << "," << t2 << ")" << std::endl;
-
-  // Don't notify during this check
-  ScopedBool turnOfNotify(d_performNotify, false);
-
-  // Add the terms
-  addTerm(t1);
-  addTerm(t2);
-
-  // Add the equality
-  Node equality = t1.eqNode(t2);
-  addTerm(equality);
-
-  Assert(getRepresentative(equality) == getRepresentative(d_false),
-         "Cannot explain the dis-equality, because the two terms are not dis-equal!\n"
-         "The representative of %s\n"
-         "                   is %s\n"
-         "The representative of %s\n"
-         "                   is %s",
-         equality.toString().c_str(), getRepresentative(equality).toString().c_str(),
-         d_false.toString().c_str(), getRepresentative(d_false).toString().c_str());
-
-  // Get the explanation 
-  EqualityNodeId equalityId = getNodeId(equality);
-  EqualityNodeId falseId = getNodeId(d_false);
-  getExplanation(equalityId, falseId, equalities);
-
-}
-
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, std::vector<TNode>& equalities) const {
-
-  Debug("equality") << "EqualityEngine::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl;
-
-  // If the nodes are the same, we're done
-  if (t1Id == t2Id) return;
-
-  if (Debug.isOn("equality::internal")) {
-    debugPrintGraph();
-  }
-
-  // Queue for the BFS containing nodes
-  std::vector<BfsData> bfsQueue;
-
-  // Find a path from t1 to t2 in the graph (BFS)
-  bfsQueue.push_back(BfsData(t1Id, null_id, 0));
-  size_t currentIndex = 0;
-  while (true) {
-    // There should always be a path, and every node can be visited only once (tree)
-    Assert(currentIndex < bfsQueue.size());
-
-    // The next node to visit
-    BfsData current = bfsQueue[currentIndex];
-    EqualityNodeId currentNode = current.nodeId;
-
-    Debug("equality") << "EqualityEngine::getExplanation(): currentNode =  " << d_nodes[currentNode] << std::endl;
-
-    // Go through the equality edges of this node
-    EqualityEdgeId currentEdge = d_equalityGraph[currentNode];
-    Debug("equality") << "EqualityEngine::getExplanation(): edgesId =  " << currentEdge << std::endl;
-    Debug("equality") << "EqualityEngine::getExplanation(): edges =  " << edgesToString(currentEdge) << std::endl;
-
-    while (currentEdge != null_edge) {
-      // Get the edge
-      const EqualityEdge& edge = d_equalityEdges[currentEdge];
-
-      // If not just the backwards edge
-      if ((currentEdge | 1u) != (current.edgeId | 1u)) {
-
-        Debug("equality") << "EqualityEngine::getExplanation(): currentEdge = (" << d_nodes[currentNode] << "," << d_nodes[edge.getNodeId()] << ")" << std::endl;
-
-        // Did we find the path
-        if (edge.getNodeId() == t2Id) {
-
-          Debug("equality") << "EqualityEngine::getExplanation(): path found: " << std::endl;
-
-          // Reconstruct the path
-          do {
-            // The current node
-            currentNode = bfsQueue[currentIndex].nodeId;
-            EqualityNodeId edgeNode = d_equalityEdges[currentEdge].getNodeId();
-            MergeReasonType reasonType = d_equalityEdges[currentEdge].getReasonType();
-
-            Debug("equality") << "EqualityEngine::getExplanation(): currentEdge = " << currentEdge << ", currentNode = " << currentNode << std::endl;
-
-            // Add the actual equality to the vector
-            if (reasonType == MERGED_THROUGH_CONGRUENCE) {
-              // f(x1, x2) == f(y1, y2) because x1 = y1 and x2 = y2
-              Debug("equality") << "EqualityEngine::getExplanation(): due to congruence, going deeper" << std::endl;
-              const FunctionApplication& f1 = d_applications[currentNode].original;
-              const FunctionApplication& f2 = d_applications[edgeNode].original;
-              Debug("equality") << push;
-              getExplanation(f1.a, f2.a, equalities);
-              getExplanation(f1.b, f2.b, equalities);
-              Debug("equality") << pop;
-            } else {
-              // Construct the equality
-              Debug("equality") << "EqualityEngine::getExplanation(): adding: " << d_equalityEdges[currentEdge].getReason() << std::endl;
-              equalities.push_back(d_equalityEdges[currentEdge].getReason());
-            }
-
-            // Go to the previous
-            currentEdge = bfsQueue[currentIndex].edgeId;
-            currentIndex = bfsQueue[currentIndex].previousIndex;
-          } while (currentEdge != null_id);
-
-          // Done
-          return;
-        }
-
-        // Push to the visitation queue if it's not the backward edge
-        bfsQueue.push_back(BfsData(edge.getNodeId(), currentEdge, currentIndex));
-      }
-
-      // Go to the next edge
-      currentEdge = edge.getNext();
-    }
-
-    // Go to the next node to visit
-    ++ currentIndex;
-  }
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::addTriggerDisequality(TNode t1, TNode t2, TNode trigger) {
-  Node equality = t1.eqNode(t2);
-  addTerm(equality);
-  addTriggerEquality(equality, d_false, trigger);
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::addTriggerEquality(TNode t1, TNode t2, TNode trigger) {
-
-  Debug("equality") << "EqualityEngine::addTrigger(" << t1 << ", " << t2 << ", " << trigger << ")" << std::endl;
-
-  Assert(hasTerm(t1));
-  Assert(hasTerm(t2));
-
-  // Get the information about t1
-  EqualityNodeId t1Id = getNodeId(t1);
-  EqualityNodeId t1classId = getEqualityNode(t1Id).getFind();
-  TriggerId t1TriggerId = d_nodeTriggers[t1classId];
-
-  // Get the information about t2
-  EqualityNodeId t2Id = getNodeId(t2);
-  EqualityNodeId t2classId = getEqualityNode(t2Id).getFind();
-  TriggerId t2TriggerId = d_nodeTriggers[t2classId];
-
-  Debug("equality") << "EqualityEngine::addTrigger(" << trigger << "): " << t1Id << " (" << t1classId << ") = " << t2Id << " (" << t2classId << ")" << std::endl;
-
-  // Create the triggers
-  TriggerId t1NewTriggerId = d_equalityTriggers.size();
-  TriggerId t2NewTriggerId = t1NewTriggerId | 1;
-  d_equalityTriggers.push_back(Trigger(t1classId, t1TriggerId));
-  d_equalityTriggersOriginal.push_back(trigger);
-  d_equalityTriggers.push_back(Trigger(t2classId, t2TriggerId));
-  d_equalityTriggersOriginal.push_back(trigger);
-
-  // Update the counters
-  d_equalityTriggersCount = d_equalityTriggersCount + 2;
-
-  // Add the trigger to the trigger graph
-  d_nodeTriggers[t1classId] = t1NewTriggerId;
-  d_nodeTriggers[t2classId] = t2NewTriggerId;
-
-  // If the trigger is already on, we propagate
-  if (t1classId == t2classId) {
-    Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << "): triggered at setup time" << std::endl;
-    if (d_performNotify) {
-      d_notify.notify(trigger); // Don't care about the return value
-    }
-  }
-
-  if (Debug.isOn("equality::internal")) {
-    debugPrintGraph();
-  }
-
-  Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ") => (" << t1NewTriggerId << ", " << t2NewTriggerId << ")" << std::endl;
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::propagate() {
-
-  Debug("equality") << "EqualityEngine::propagate()" << std::endl;
-
-  bool done = false;
-  while (!d_propagationQueue.empty()) {
-
-    // The current merge candidate
-    const MergeCandidate current = d_propagationQueue.front();
-    d_propagationQueue.pop();
-
-    if (done) {
-      // If we're done, just empty the queue
-      continue;
-    }
-
-    // Get the representatives
-    EqualityNodeId t1classId = getEqualityNode(current.t1Id).getFind();
-    EqualityNodeId t2classId = getEqualityNode(current.t2Id).getFind();
-
-    // If already the same, we're done
-    if (t1classId == t2classId) {
-      continue;
-    }
-
-    // Get the nodes of the representatives
-    EqualityNode& node1 = getEqualityNode(t1classId);
-    EqualityNode& node2 = getEqualityNode(t2classId);
-
-    Assert(node1.getFind() == t1classId);
-    Assert(node2.getFind() == t2classId);
-
-    // Add the actual equality to the equality graph
-    addGraphEdge(current.t1Id, current.t2Id, current.type, current.reason);
-
-    // One more equality added
-    d_assertedEqualitiesCount = d_assertedEqualitiesCount + 1;
-
-    // Depending on the merge preference (such as size), merge them
-    std::vector<TriggerId> triggers;
-    if (node2.getSize() > node1.getSize()) {
-      Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t1Id]<< " into " << d_nodes[current.t2Id] << std::endl;
-      d_assertedEqualities.push_back(Equality(t2classId, t1classId));
-      merge(node2, node1, triggers);
-    } else {
-      Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t2Id] << " into " << d_nodes[current.t1Id] << std::endl;
-      d_assertedEqualities.push_back(Equality(t1classId, t2classId));
-      merge(node1, node2, triggers);
-    }
-
-    // Notify the triggers
-    if (d_performNotify) {
-      for (size_t trigger = 0, trigger_end = triggers.size(); trigger < trigger_end && !done; ++ trigger) {
-        // Notify the trigger and exit if it fails
-        done = !d_notify.notify(d_equalityTriggersOriginal[triggers[trigger]]);
-      }
-    }
-  }
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::debugPrintGraph() const {
-  for (EqualityNodeId nodeId = 0; nodeId < d_nodes.size(); ++ nodeId) {
-
-    Debug("equality::graph") << d_nodes[nodeId] << " " << nodeId << "(" << getEqualityNode(nodeId).getFind() << "):";
-
-    EqualityEdgeId edgeId = d_equalityGraph[nodeId];
-    while (edgeId != null_edge) {
-      const EqualityEdge& edge = d_equalityEdges[edgeId];
-      Debug("equality::graph") << " " << d_nodes[edge.getNodeId()] << ":" << edge.getReason();
-      edgeId = edge.getNext();
-    }
-
-    Debug("equality::graph") << std::endl;
-  }
-}
-
-template <typename NotifyClass>
-bool EqualityEngine<NotifyClass>::areEqual(TNode t1, TNode t2)
-{
-  // Don't notify during this check
-  ScopedBool turnOfNotify(d_performNotify, false);
-
-  // Add the terms
-  addTerm(t1);
-  addTerm(t2);
-  bool equal = getEqualityNode(t1).getFind() == getEqualityNode(t2).getFind();
-
-  // Return whether the two terms are equal
-  return equal;
-}
-
-template <typename NotifyClass>
-bool EqualityEngine<NotifyClass>::areDisequal(TNode t1, TNode t2)
-{
-  // Don't notify during this check
-  ScopedBool turnOfNotify(d_performNotify, false);
-
-  // Add the terms
-  addTerm(t1);
-  addTerm(t2);
-
-  // Check (t1 = t2) = false
-  Node equality = t1.eqNode(t2);
-  addTerm(equality);
-  if (getEqualityNode(equality).getFind() == getEqualityNode(d_false).getFind()) {
-    return true;
-  }
-
-  // Return whether the terms are disequal
-  return false;
-}
-
-template <typename NotifyClass>
-size_t EqualityEngine<NotifyClass>::getSize(TNode t)
-{
-  // Add the term
-  addTerm(t);
-  return getEqualityNode(getEqualityNode(t).getFind()).getSize();
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::addTriggerTerm(TNode t) 
-{
-  Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ")" << std::endl;
-
-  // Add the term if it's not already there
-  addTerm(t);
-
-  // Get the node id
-  EqualityNodeId eqNodeId = getNodeId(t);
-  EqualityNode& eqNode = getEqualityNode(eqNodeId);
-  EqualityNodeId classId = eqNode.getFind();
-
-  if (d_nodeIndividualTrigger[classId] != +null_id) {  
-    // No need to keep it, just propagate the existing individual triggers
-    if (d_performNotify) {
-      d_notify.notify(t, d_nodes[d_nodeIndividualTrigger[classId]]); 
-    }
-  } else {
-    // Add it to the list for backtracking
-    d_individualTriggers.push_back(classId);
-    d_individualTriggersSize = d_individualTriggersSize + 1; 
-    // Mark the class id as a trigger
-    d_nodeIndividualTrigger[classId] = eqNodeId;
-  }
-}
-
-template <typename NotifyClass>
-bool EqualityEngine<NotifyClass>::isTriggerTerm(TNode t) const {
-  if (!hasTerm(t)) return false;
-  EqualityNodeId classId = getEqualityNode(t).getFind();
-  return d_nodeIndividualTrigger[classId] != +null_id;
-}
-
-
-template <typename NotifyClass>
-TNode EqualityEngine<NotifyClass>::getTriggerTermRepresentative(TNode t) const {
-  Assert(isTriggerTerm(t));
-  EqualityNodeId classId = getEqualityNode(t).getFind();
-  return d_nodes[d_nodeIndividualTrigger[classId]];
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::storeApplicationLookup(FunctionApplication& funNormalized, EqualityNodeId funId) {
-  Assert(d_applicationLookup.find(funNormalized) == d_applicationLookup.end());
-  d_applicationLookup[funNormalized] = funId;
-  d_applicationLookups.push_back(funNormalized);
-  d_applicationLookupsCount = d_applicationLookupsCount + 1;
-  Debug("equality::backtrack") << "d_applicationLookupsCount = " << d_applicationLookupsCount << std::endl;
-  Debug("equality::backtrack") << "d_applicationLookups.size() = " << d_applicationLookups.size() << std::endl;
-  Assert(d_applicationLookupsCount == d_applicationLookups.size());
-}
-
-template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::getUseListTerms(TNode t, std::set<TNode>& output) {
-  if (hasTerm(t)) {
-    // Get the equivalence class
-    EqualityNodeId classId = getEqualityNode(t).getFind();
-    // Go through the equivalence class and get where t is used in
-    EqualityNodeId currentId = classId;
-    do {
-      // Get the current node
-      EqualityNode& currentNode = getEqualityNode(currentId);
-      // Go through the use-list
-      UseListNodeId currentUseId = currentNode.getUseList();
-      while (currentUseId != null_uselist_id) {
-        // Get the node of the use list
-        UseListNode& useNode = d_useListNodes[currentUseId];
-        // Get the function application
-        EqualityNodeId funId = useNode.getApplicationId();
-        output.insert(d_nodes[funId]);
-        // Go to the next one in the use list
-        currentUseId = useNode.getNext();
-      }
-      // Move to the next node
-      currentId = currentNode.getNext();
-    } while (currentId != classId);
-  }
-}
-
-} // Namespace uf
-} // Namespace theory
-} // Namespace CVC4
-
index ec28dad7569f269d2bdc9afada73186891f40993..cddd01b071202bb6000b54f159fe753c3b5fe568 100644 (file)
  **/
 
 #include "theory/uf/theory_uf.h"
-#include "theory/uf/equality_engine_impl.h"
 
 using namespace std;
-
-namespace CVC4 {
-namespace theory {
-namespace uf {
+using namespace CVC4;
+using namespace CVC4::theory;
+using namespace CVC4::theory::uf;
 
 /** Constructs a new instance of TheoryUF w.r.t. the provided context.*/
 TheoryUF::TheoryUF(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo) :
@@ -40,12 +38,6 @@ TheoryUF::TheoryUF(context::Context* c, context::UserContext* u, OutputChannel&
   d_equalityEngine.addFunctionKind(kind::APPLY_UF);
   d_equalityEngine.addFunctionKind(kind::EQUAL);
 
-  // The boolean constants
-  d_true = NodeManager::currentNM()->mkConst<bool>(true);
-  d_false = NodeManager::currentNM()->mkConst<bool>(false);
-  d_equalityEngine.addTerm(d_true);
-  d_equalityEngine.addTerm(d_false);
-  d_equalityEngine.addTriggerEquality(d_true, d_false, d_false);
 }/* TheoryUF::TheoryUF() */
 
 static Node mkAnd(const std::vector<TNode>& conjunctions) {
@@ -91,23 +83,12 @@ void TheoryUF::check(Effort level) {
     }
 
     // Do the work
-    switch (fact.getKind()) {
-    case kind::EQUAL:
-      d_equalityEngine.addEquality(fact[0], fact[1], fact);
-      break;
-    case kind::APPLY_UF:
-      d_equalityEngine.addPredicate(fact, true, fact);
-      break;
-    case kind::NOT:
-      if (fact[0].getKind() == kind::APPLY_UF) {
-        d_equalityEngine.addPredicate(fact[0], false, fact);
-      } else {
-        // Assert the dis-equality
-        d_equalityEngine.addDisequality(fact[0][0], fact[0][1], fact);
-      }
-      break;
-    default:
-      Unreachable();
+    bool polarity = fact.getKind() != kind::NOT;
+    TNode atom = polarity ? fact : fact[0];
+    if (atom.getKind() == kind::EQUAL) {
+      d_equalityEngine.assertEquality(atom, polarity, fact);
+    } else {
+      d_equalityEngine.assertPredicate(atom, polarity, fact);
     }
   }
 
@@ -139,10 +120,8 @@ void TheoryUF::propagate(Effort level) {
         Debug("uf") << "TheoryUF::propagate(): in conflict, normalized = " << normalized << std::endl;
         Node negatedLiteral;
         std::vector<TNode> assumptions;
-        if (normalized != d_false) {
-          negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode();
-          assumptions.push_back(negatedLiteral);
-        }
+        negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode();
+        assumptions.push_back(negatedLiteral);
         explain(literal, assumptions);
         d_conflictNode = mkAnd(assumptions);
         d_conflict = true;
@@ -157,21 +136,17 @@ void TheoryUF::preRegisterTerm(TNode node) {
 
   switch (node.getKind()) {
   case kind::EQUAL:
-    // Add the terms
-    d_equalityEngine.addTerm(node[0]);
-    d_equalityEngine.addTerm(node[1]);
     // Add the trigger for equality
-    d_equalityEngine.addTriggerEquality(node[0], node[1], node);
-    d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode());
+    d_equalityEngine.addTriggerEquality(node);
     break;
   case kind::APPLY_UF:
-    // Function applications/predicates
-    d_equalityEngine.addTerm(node);
     // Maybe it's a predicate
     if (node.getType().isBoolean()) {
       // Get triggered for both equal and dis-equal
-      d_equalityEngine.addTriggerEquality(node, d_true, node);
-      d_equalityEngine.addTriggerEquality(node, d_false, node.notNode());
+      d_equalityEngine.addTriggerPredicate(node);
+    } else {
+      // Function applications/predicates
+      d_equalityEngine.addTerm(node);
     }
     // Remember the function and predicate terms
     d_functionsTerms.push_back(node);
@@ -194,26 +169,20 @@ bool TheoryUF::propagate(TNode literal) {
 
   // See if the literal has been asserted already
   Node normalized = Rewriter::rewrite(literal);
-  bool satValue = false;
-  bool isAsserted = normalized == d_false || d_valuation.hasSatValue(normalized, satValue);
 
-  // If asserted, we're done or in conflict
-  if (isAsserted) {
-    if (!satValue) {
+  // If asserted and is false, we're done or in conflict
+  // Note that even trivial equalities have a SAT value (i.e. 1 = 2 -> false)
+  bool satValue = false;
+  if (d_valuation.hasSatValue(normalized, satValue) && !satValue) {
       Debug("uf") << "TheoryUF::propagate(" << literal << ", normalized = " << normalized << ") => conflict" << std::endl;
       std::vector<TNode> assumptions;
       Node negatedLiteral;
-      if (normalized != d_false) {
-        negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode();
-        assumptions.push_back(negatedLiteral);
-      }
+      negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode();
+      assumptions.push_back(negatedLiteral);
       explain(literal, assumptions);
       d_conflictNode = mkAnd(assumptions);
       d_conflict = true;
       return false;
-    }
-    // Propagate even if already known in SAT - could be a new equation between shared terms
-    // (terms that weren't shared when the literal was asserted!)
   }
 
   // Nothing, just enqueue it for propagation and mark it as asserted already
@@ -224,36 +193,14 @@ bool TheoryUF::propagate(TNode literal) {
 }/* TheoryUF::propagate(TNode) */
 
 void TheoryUF::explain(TNode literal, std::vector<TNode>& assumptions) {
-  TNode lhs, rhs;
-  switch (literal.getKind()) {
-    case kind::EQUAL:
-      lhs = literal[0];
-      rhs = literal[1];
-      break;
-    case kind::APPLY_UF:
-      lhs = literal;
-      rhs = d_true;
-      break;
-    case kind::NOT:
-      if (literal[0].getKind() == kind::EQUAL) {
-        // Disequalities
-        d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions);
-        return;
-      } else {
-        // Predicates
-        lhs = literal[0];
-        rhs = d_false;
-        break;
-      }
-    case kind::CONST_BOOLEAN:
-      // we get to explain true = false, since we set false to be the trigger of this
-      lhs = d_true;
-      rhs = d_false;
-      break;
-    default:
-      Unreachable();
+  // Do the work
+  bool polarity = literal.getKind() != kind::NOT;
+  TNode atom = polarity ? literal : literal[0];
+  if (atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) {
+    d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
+  } else {
+    d_equalityEngine.explainPredicate(atom, polarity, assumptions);
   }
-  d_equalityEngine.explainEquality(lhs, rhs, assumptions);
 }
 
 Node TheoryUF::explain(TNode literal) {
@@ -508,7 +455,3 @@ void TheoryUF::computeCareGraph() {
     }
   }
 }/* TheoryUF::computeCareGraph() */
-
-}/* CVC4::theory::uf namespace */
-}/* CVC4::theory namespace */
-}/* CVC4 namespace */
index 6956390f524f34947c3c389fbd70cb607881fd50..9017928b77a1f9ed783a639a210475429d3f33aa 100644 (file)
@@ -39,21 +39,46 @@ namespace uf {
 class TheoryUF : public Theory {
 public:
 
-  class NotifyClass {
+  class NotifyClass : public eq::EqualityEngineNotify {
     TheoryUF& d_uf;
   public:
     NotifyClass(TheoryUF& uf): d_uf(uf) {}
 
-    bool notify(TNode propagation) {
-      Debug("uf") << "NotifyClass::notify(" << propagation << ")" << std::endl;
-      // Just forward to uf
-      return d_uf.propagate(propagation);
+    bool eqNotifyTriggerEquality(TNode equality, bool value) {
+      Debug("uf") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl;
+      if (value) {
+        return d_uf.propagate(equality);
+      } else {
+        // We use only literal triggers so taking not is safe
+        return d_uf.propagate(equality.notNode());
+      }
     }
-    
-    void notify(TNode t1, TNode t2) {
-      Debug("uf") << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl;
-      Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2));
-      d_uf.propagate(equality);
+
+    bool eqNotifyTriggerPredicate(TNode predicate, bool value) {
+      Debug("uf") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" )<< ")" << std::endl;
+      if (value) {
+        return d_uf.propagate(predicate);
+      } else {
+       return d_uf.propagate(predicate.notNode());
+      }
+    }
+
+    bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) {
+      Debug("uf") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << std::endl;
+      if (value) {
+        return d_uf.propagate(t1.eqNode(t2));
+      } else {
+        return d_uf.propagate(t1.eqNode(t2).notNode());
+      }
+    }
+
+    bool eqNotifyConstantTermMerge(TNode t1, TNode t2) {
+      Debug("uf") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl;
+      if (Theory::theoryOf(t1) == THEORY_BOOL) {
+        return d_uf.propagate(t1.iffNode(t2));
+      } else {
+        return d_uf.propagate(t1.eqNode(t2));
+      }
     }
   };
 
@@ -63,7 +88,7 @@ private:
   NotifyClass d_notify;
 
   /** Equaltity engine */
-  EqualityEngine<NotifyClass> d_equalityEngine;
+  eq::EqualityEngine d_equalityEngine;
 
   /** Are we in conflict */
   context::CDO<bool> d_conflict;
@@ -72,7 +97,8 @@ private:
   Node d_conflictNode;
 
   /**
-   * Should be called to propagate the literal. 
+   * Should be called to propagate the literal. We use a node here 
+   * since some of the propagated literals are not kept anywhere. 
    */
   bool propagate(TNode literal);
 
@@ -90,12 +116,6 @@ private:
   /** All the function terms that the theory has seen */
   context::CDList<TNode> d_functionsTerms;
 
-  /** True node for predicates = true */
-  Node d_true;
-
-  /** True node for predicates = false */
-  Node d_false;
-
   /** Symmetry analyzer */
   SymmetryBreaker d_symb;
 
index 66b0a2f909702003fb7e6ced8265af54d61c7e47..6f01d6cf455e5cdad72a3966a4284c8480d80915 100644 (file)
@@ -162,7 +162,7 @@ bool Configuration::isDebugTag(char const *tag){
       return true;
     }
   }
-#endif * CVC4_DEBUG */
+#endif /* CVC4_DEBUG */
   return false;
 }
 
index 33863e8488b2d8f7a2a384e1681c2dc439954f67..1a50d0637de4c0b06997cd64cb6a3293e0e7bf8d 100644 (file)
@@ -37,7 +37,7 @@ struct MyContextNotifyObj : public ContextNotifyObj {
     nCalls(0) {
   }
 
-  void notify() {
+  void contextNotifyPop() {
     ++nCalls;
   }
 };
index 63ba95b5744ec5475ef77ce0723360422d05b031..c24104acc582eff0eb53fd64a7737e2002c2a3a9 100644 (file)
@@ -59,6 +59,14 @@ public:
     return d_nextVar++;
   }
 
+  SatVariable trueVar() {
+    return d_nextVar++;
+  }
+
+  SatVariable falseVar() {
+    return d_nextVar++;
+  }
+
   void addClause(SatClause& c, bool lemma) {
     d_addClauseCalled = true;
   }