Merging the unate-propagator branch into the trunk. This is a big update so expect...
authorTim King <taking@cs.nyu.edu>
Tue, 29 Jun 2010 20:53:47 +0000 (20:53 +0000)
committerTim King <taking@cs.nyu.edu>
Tue, 29 Jun 2010 20:53:47 +0000 (20:53 +0000)
22 files changed:
src/prop/cnf_stream.cpp
src/prop/cnf_stream.h
src/prop/minisat/core/Solver.C
src/prop/minisat/core/Solver.h
src/prop/minisat/simp/SimpSolver.C
src/prop/sat.cpp
src/prop/sat.h
src/theory/Makefile.am
src/theory/arith/Makefile.am
src/theory/arith/arith_propagator.cpp [new file with mode: 0644]
src/theory/arith/arith_propagator.h [new file with mode: 0644]
src/theory/arith/ordered_bounds_list.h [new file with mode: 0644]
src/theory/arith/theory_arith.cpp
src/theory/arith/theory_arith.h
src/theory/theory.cpp
src/theory/theory.h
src/theory/theory_engine.h
src/theory/theory_test_utils.h [new file with mode: 0644]
src/theory/uf/theory_uf.cpp
test/unit/Makefile.am
test/unit/theory/theory_arith_white.h [new file with mode: 0644]
test/unit/theory/theory_uf_white.h

index 45f7ab39842fbe9cd72d4da9383733f7e08d6c70..9136a73c307795aa831ef7d8789e2009223d8905 100644 (file)
@@ -102,26 +102,10 @@ Node CnfStream::getNode(const SatLiteral& literal) {
   return node;
 }
 
-SatLiteral CnfStream::getLiteral(TNode node) {
-  TranslationCache::iterator find = d_translationCache.find(node);
-  Assert(find != d_translationCache.end(), "Literal not in the CNF Cache");
-  SatLiteral literal = find->second;
-  Debug("cnf") << "CnfStream::getLiteral(" << node << ") => " << literal << std::endl;
-  return literal;
-}
-
-const CnfStream::NodeCache& CnfStream::getNodeCache() const {
-  return d_nodeCache;
-}
-
-const CnfStream::TranslationCache& CnfStream::getTranslationCache() const {
-  return d_translationCache;
-}
-
-SatLiteral TseitinCnfStream::handleAtom(TNode node) {
+SatLiteral CnfStream::convertAtom(TNode node) {
   Assert(!isCached(node), "atom already mapped!");
 
-  Debug("cnf") << "handleAtom(" << node << ")" << endl;
+  Debug("cnf") << "convertAtom(" << node << ")" << endl;
 
   bool theoryLiteral = node.getKind() != kind::VARIABLE;
   SatLiteral lit = newLiteral(node, theoryLiteral);
@@ -137,6 +121,23 @@ SatLiteral TseitinCnfStream::handleAtom(TNode node) {
   return lit;
 }
 
+SatLiteral CnfStream::getLiteral(TNode node, bool create /* = false */) {
+  TranslationCache::iterator find = d_translationCache.find(node);
+  SatLiteral literal;
+  if(create) {
+    if(find == d_translationCache.end()) {
+      literal = convertAtom(node);
+    } else {
+      literal = find->second;
+    }
+  } else {
+    Assert(find != d_translationCache.end(), "Literal not in the CNF Cache");
+    literal = find->second;
+  }
+  Debug("cnf") << "CnfStream::getLiteral(" << node << ", create = " << create << ") => " << literal << std::endl;
+  return literal;
+}
+
 SatLiteral TseitinCnfStream::handleXor(TNode xorNode) {
   Assert(!isCached(xorNode), "Atom already mapped!");
   Assert(xorNode.getKind() == XOR, "Expecting an XOR expression!");
@@ -366,10 +367,10 @@ SatLiteral TseitinCnfStream::toCNF(TNode node, bool negated) {
     default:
       {
         //TODO make sure this does not contain any boolean substructure
-        nodeLit = handleAtom(node);
+        nodeLit = convertAtom(node);
         //Unreachable();
         //Node atomic = handleNonAtomicNode(node);
-        //return isCached(atomic) ? lookupInCache(atomic) : handleAtom(atomic);
+        //return isCached(atomic) ? lookupInCache(atomic) : convertAtom(atomic);
       }
     }
   }
index abb69f590f54f4dac110aa6677a884f0a37995f4..ba87cf269fe6255577e93bbd4c8e4cad0ec3b42f 100644 (file)
@@ -127,6 +127,16 @@ protected:
    */
   SatLiteral newLiteral(TNode node, bool theoryLiteral = false);
 
+  /**
+   * Constructs a new literal for an atom and returns it.  Calls
+   * newLiteral().
+   *
+   * @param node the node to convert; there should be no boolean
+   * structure in this expression.  Assumed to not be in the
+   * translation cache.
+   */
+  SatLiteral convertAtom(TNode node);
+
 public:
 
   /**
@@ -161,14 +171,25 @@ public:
 
   /**
    * Returns the literal that represents the given node in the SAT CNF
-   * representation. [Presumably there are some constraints on the kind
-   * of node? E.g., it needs to be a boolean? -Chris]
-   *
+   * representation.
+   * @param node [Presumably there are some constraints on the kind of
+   * node? E.g., it needs to be a boolean? -Chris]
+   * @param create Controls whether or not to create a new SAT literal
+   * mapping for Node if it does not exist.  This exists to break
+   * circular dependencies, where an atom is converted and asserted to
+   * the SAT solver, which propagates it immediately since it's a
+   * unit, which can theory-propagate additional literals that don't
+   * yet have a SAT literal mapping.
    */
-  SatLiteral getLiteral(TNode node);
+  SatLiteral getLiteral(TNode node, bool create = false);
+
+  const TranslationCache& getTranslationCache() const {
+    return d_translationCache;
+  }
 
-  const TranslationCache& getTranslationCache() const;
-  const NodeCache& getNodeCache() const;
+  const NodeCache& getNodeCache() const {
+    return d_nodeCache;
+  }
 }; /* class CnfStream */
 
 /**
@@ -178,7 +199,7 @@ public:
  * will be equivalent to each subexpression in the constructed equi-satisfiable
  * formula, then substitute the new literal for the formula, and so on,
  * recursively.
- * 
+ *
  * This implementation does this in a single recursive pass. [??? -Chris]
  */
 class TseitinCnfStream : public CnfStream {
@@ -211,7 +232,6 @@ private:
   //   - returning l
   //
   // handleX( n ) can assume that n is not in d_translationCache
-  SatLiteral handleAtom(TNode node);
   SatLiteral handleNot(TNode node);
   SatLiteral handleXor(TNode node);
   SatLiteral handleImplies(TNode node);
index 8533e191ba7a07db4e2f84c437d3bc6ca0b17134..1667af20d0bcda2e06a4308ebafbf3a2c50ba9b1 100644 (file)
@@ -29,6 +29,28 @@ namespace CVC4 {
 namespace prop {
 namespace minisat {
 
+Clause* Solver::lazy_reason = reinterpret_cast<Clause*>(1);
+
+Clause* Solver::getReason(Lit l)
+{
+    if (reason[var(l)] != lazy_reason) return reason[var(l)];
+    // Get the explanation from the theory
+    SatClause explanation;
+    if (value(l) == l_True) {
+      proxy->explainPropagation(l, explanation);
+      assert(explanation[0] == l);
+    } else {
+      proxy->explainPropagation(~l, explanation);
+      assert(explanation[0] == ~l);
+    }
+    Clause* real_reason = Clause_new(explanation, true);
+    reason[var(l)] = real_reason;
+    // Add it to the database
+    learnts.push(real_reason);
+    attachClause(*real_reason);
+    return real_reason;
+}
+
 Solver::Solver(SatSolver* proxy, context::Context* context) :
 
     // SMT stuff
@@ -122,7 +144,7 @@ bool Solver::addClause(vec<Lit>& ps, ClauseType type)
         assert(type != CLAUSE_LEMMA);
         assert(value(ps[0]) == l_Undef);
         uncheckedEnqueue(ps[0]);
-        return ok = (propagate() == NULL);
+        return ok = (propagate(CHECK_WITHOUTH_PROPAGATION_QUICK) == NULL);
     }else{
         Clause* c = Clause_new(ps, false);
         clauses.push(c);
@@ -282,7 +304,7 @@ void Solver::analyze(Clause* confl, vec<Lit>& out_learnt, int& out_btlevel)
         // Select next clause to look at:
         while (!seen[var(trail[index--])]);
         p     = trail[index+1];
-        confl = reason[var(p)];
+        confl = getReason(p);
         seen[var(p)] = 0;
         pathC--;
 
@@ -299,12 +321,12 @@ void Solver::analyze(Clause* confl, vec<Lit>& out_learnt, int& out_btlevel)
 
         out_learnt.copyTo(analyze_toclear);
         for (i = j = 1; i < out_learnt.size(); i++)
-            if (reason[var(out_learnt[i])] == NULL || !litRedundant(out_learnt[i], abstract_level))
+            if (getReason(out_learnt[i]) == NULL || !litRedundant(out_learnt[i], abstract_level))
                 out_learnt[j++] = out_learnt[i];
     }else{
         out_learnt.copyTo(analyze_toclear);
         for (i = j = 1; i < out_learnt.size(); i++){
-            Clause& c = *reason[var(out_learnt[i])];
+            Clause& c = *getReason(out_learnt[i]);
             for (int k = 1; k < c.size(); k++)
                 if (!seen[var(c[k])] && level[var(c[k])] > 0){
                     out_learnt[j++] = out_learnt[i];
@@ -342,13 +364,13 @@ bool Solver::litRedundant(Lit p, uint32_t abstract_levels)
     analyze_stack.clear(); analyze_stack.push(p);
     int top = analyze_toclear.size();
     while (analyze_stack.size() > 0){
-        assert(reason[var(analyze_stack.last())] != NULL);
+        assert(getReason(analyze_stack.last()) != NULL);
         Clause& c = *reason[var(analyze_stack.last())]; analyze_stack.pop();
 
         for (int i = 1; i < c.size(); i++){
             Lit p  = c[i];
             if (!seen[var(p)] && level[var(p)] > 0){
-                if (reason[var(p)] != NULL && (abstractLevel(var(p)) & abstract_levels) != 0){
+                if (getReason(p) != NULL && (abstractLevel(var(p)) & abstract_levels) != 0){
                     seen[var(p)] = 1;
                     analyze_stack.push(p);
                     analyze_toclear.push(p);
@@ -415,42 +437,74 @@ void Solver::uncheckedEnqueue(Lit p, Clause* from)
     polarity [var(p)] = sign(p);
     trail.push(p);
 
-    if (theory[var(p)]) {
+    if (theory[var(p)] && from != lazy_reason) {
       // Enqueue to the theory
       proxy->enqueueTheoryLiteral(p);
     }
 }
 
 
-Clause* Solver::propagate()
+Clause* Solver::propagate(TheoryCheckType type)
 {
     Clause* confl = NULL;
 
-    while(qhead < trail.size()) {
-      confl = propagateBool();
-      if (confl != NULL) break;
-      confl = propagateTheory();
-      if (confl != NULL) break;
+    // If this is the final check, no need for Boolean propagation and
+    // theory propagation
+    if (type == CHECK_WITHOUTH_PROPAGATION_FINAL) {
+      return theoryCheck(theory::Theory::FULL_EFFORT);
     }
 
+    // The effort we will be using to theory check
+    theory::Theory::Effort effort = type == CHECK_WITHOUTH_PROPAGATION_QUICK ?
+        theory::Theory::QUICK_CHECK : theory::Theory::STANDARD;
+
+    // Keep running until we have checked everything, we
+    // have no conflict and no new literals have been asserted
+    bool new_assertions;
+    do {
+        new_assertions = false;
+        while(qhead < trail.size()) {
+            confl = propagateBool();
+            if (confl != NULL) break;
+            confl = theoryCheck(effort);
+            if (confl != NULL) break;
+        }
+
+        if (confl == NULL && type == CHECK_WITH_PROPAGATION_STANDARD) {
+          new_assertions = propagateTheory();
+          if (!new_assertions) break;
+        }
+    } while (new_assertions);
+
     return confl;
 }
 
+bool Solver::propagateTheory() {
+  std::vector<Lit> propagatedLiterals;
+  proxy->theoryPropagate(propagatedLiterals);
+  const unsigned i_end = propagatedLiterals.size();
+  for (unsigned i = 0; i < i_end; ++ i) {
+    uncheckedEnqueue(propagatedLiterals[i], lazy_reason);
+  }
+  proxy->clearPropagatedLiterals();
+  return propagatedLiterals.size() > 0;
+}
+
 /*_________________________________________________________________________________________________
 |
-|  propagateTheory : [void]  ->  [Clause*]
+|  theoryCheck: [void]  ->  [Clause*]
 |
 |  Description:
-|    Propagates all enqueued theory facts. If a conflict arises, the conflicting clause is returned,
-|    otherwise NULL.
+|    Checks all enqueued theory facts for satisfiability. If a conflict arises, the conflicting
+|    clause is returned, otherwise NULL.
 |
 |    Note: the propagation queue might be NOT empty
 |________________________________________________________________________________________________@*/
-Clause* Solver::propagateTheory()
+Clause* Solver::theoryCheck(theory::Theory::Effort effort)
 {
   Clause* c = NULL;
   SatClause clause;
-  proxy->theoryCheck(clause);
+  proxy->theoryCheck(effort, clause);
   int clause_size = clause.size();
   Assert(clause_size != 1, "Can't handle unit clause explanations");
   if(clause_size > 0) {
@@ -598,7 +652,7 @@ bool Solver::simplify()
 {
     assert(decisionLevel() == 0);
 
-    if (!ok || propagate() != NULL)
+    if (!ok || propagate(CHECK_WITHOUTH_PROPAGATION_QUICK) != NULL)
         return ok = false;
 
     if (nAssigns() == simpDB_assigns || (simpDB_props > 0))
@@ -643,9 +697,9 @@ lbool Solver::search(int nof_conflicts, int nof_learnts)
     starts++;
 
     bool first = true;
-
+    TheoryCheckType check_type = CHECK_WITH_PROPAGATION_STANDARD;
     for (;;){
-        Clause* confl = propagate();
+        Clause* confl = propagate(check_type);
         if (confl != NULL){
             // CONFLICT
             conflicts++; conflictC++;
@@ -671,9 +725,16 @@ lbool Solver::search(int nof_conflicts, int nof_learnts)
             varDecayActivity();
             claDecayActivity();
 
+            // We have a conflict so, we are going back to standard checks
+            check_type = CHECK_WITH_PROPAGATION_STANDARD;
+
         }else{
             // NO CONFLICT
 
+            // If this was a final check, we are satisfiable
+            if (check_type == CHECK_WITHOUTH_PROPAGATION_FINAL)
+              return l_True;
+
             if (nof_conflicts >= 0 && conflictC >= nof_conflicts){
                 // Reached bound on number of conflicts:
                 progress_estimate = progressEstimate();
@@ -709,9 +770,11 @@ lbool Solver::search(int nof_conflicts, int nof_learnts)
                 decisions++;
                 next = pickBranchLit(polarity_mode, random_var_freq);
 
-                if (next == lit_Undef)
-                    // Model found:
-                    return l_True;
+                if (next == lit_Undef) {
+                    // We need to do a full theory check to confirm
+                    check_type = CHECK_WITHOUTH_PROPAGATION_FINAL;
+                    continue;
+                }
             }
 
             // Increase decision level and enqueue 'next'
index 312fe44d52bce2f4f132315b58badca58b4c2ec8..2e44803e9411a313911997e0cefc386e1897ccc6 100644 (file)
@@ -23,6 +23,7 @@ OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWA
 #define __CVC4__PROP__MINISAT__SOLVER_H
 
 #include "context/context.h"
+#include "theory/theory.h"
 
 #include <cstdio>
 #include <cassert>
@@ -161,7 +162,11 @@ protected:
     vec<int>            trail_lim;        // Separator indices for different decision levels in 'trail'.
     vec<Clause*>        lemmas;           // List of lemmas we added (context dependent)
     vec<int>            lemmas_lim;       // Separator indices for different decision levels in 'lemmas'.
-    vec<Clause*>        reason;           // 'reason[var]' is the clause that implied the variables current value, or 'NULL' if none.
+    static Clause*      lazy_reason;      // The mark when we need to ask the theory engine for a reason
+    vec<Clause*>        reason;           // 'reason[var]' is the clause that implied the variables current value, lazy_reason if theory propagated, or 'NULL' if none.
+
+    Clause* getReason(Lit l);             // Returns the reason, or asks the theory for an explanation
+
     vec<int>            level;            // 'level[var]' contains the level at which the assignment was made.
     int                 qhead;            // Head of queue (as index into the trail -- no more explicit propagation queue in MiniSat).
     int                 lhead;            // Head of the lemma stack (for backtracking)
@@ -181,6 +186,15 @@ protected:
     vec<Lit>            analyze_toclear;
     vec<Lit>            add_tmp;
 
+    enum TheoryCheckType {
+      // Quick check, but don't perform theory propagation
+      CHECK_WITHOUTH_PROPAGATION_QUICK,
+      // Check and perform theory propagation
+      CHECK_WITH_PROPAGATION_STANDARD,
+      // The SAT problem is satisfiable, perform a full theory check
+      CHECK_WITHOUTH_PROPAGATION_FINAL
+    };
+
     // Main internal methods:
     //
     void     insertVarOrder   (Var x);                                                 // Insert a variable in the decision order priority queue.
@@ -188,9 +202,10 @@ protected:
     void     newDecisionLevel ();                                                      // Begins a new decision level.
     void     uncheckedEnqueue (Lit p, Clause* from = NULL);                            // Enqueue a literal. Assumes value of literal is undefined.
     bool     enqueue          (Lit p, Clause* from = NULL);                            // Test if fact 'p' contradicts current state, enqueue otherwise.
-    Clause*  propagate        ();                                                      // Perform Boolean and Theory. Returns possibly conflicting clause.
+    Clause*  propagate        (TheoryCheckType type);                                  // Perform Boolean and Theory. Returns possibly conflicting clause.
     Clause*  propagateBool    ();                                                      // Perform Boolean propagation. Returns possibly conflicting clause.
-    Clause*  propagateTheory  ();                                                      // Perform Theory propagation. Returns possibly conflicting clause.
+    bool     propagateTheory  ();                                                      // Perform Theory propagation. Return true if any literals were asserted.
+    Clause*  theoryCheck      (theory::Theory::Effort effort);                         // Perform a theory satisfiability check. Returns possibly conflicting clause.
     void     cancelUntil      (int level);                                             // Backtrack until a certain level.
     void     analyze          (Clause* confl, vec<Lit>& out_learnt, int& out_btlevel); // (bt = backtrack)
     void     analyzeFinal     (Lit p, vec<Lit>& out_conflict);                         // COULD THIS BE IMPLEMENTED BY THE ORDINARIY "analyze" BY SOME REASONABLE GENERALIZATION?
@@ -216,7 +231,7 @@ protected:
 
     // Misc:
     //
-    int      decisionLevel    ()      const; // Gives the current decisionlevel.
+    int      decisionLevel    ()      const; // Gives the current decision level.
     uint32_t abstractLevel    (Var x) const; // Used to represent an abstraction of sets of decision levels.
     double   progressEstimate ()      const; // DELETE THIS ?? IT'S NOT VERY USEFUL ...
 
index 9aad6aea7990d89191ede3bf3ab475ca11ca4eab..00f93402f758fcc52500f13f5139ef92321442a2 100644 (file)
@@ -212,7 +212,7 @@ bool SimpSolver::strengthenClause(Clause& c, Lit l)
         updateElimHeap(var(l));
     }
 
-    return c.size() == 1 ? enqueue(c[0]) && propagate() == NULL : true;
+    return c.size() == 1 ? enqueue(c[0]) && propagate(CHECK_WITHOUTH_PROPAGATION_QUICK) == NULL : true;
 }
 
 
@@ -312,7 +312,7 @@ bool SimpSolver::implied(const vec<Lit>& c)
             uncheckedEnqueue(~c[i]);
         }
 
-    bool result = propagate() != NULL;
+    bool result = propagate(CHECK_WITHOUTH_PROPAGATION_QUICK) != NULL;
     cancelUntil(0);
     return result;
 }
@@ -394,7 +394,7 @@ bool SimpSolver::asymm(Var v, Clause& c)
         else
             l = c[i];
 
-    if (propagate() != NULL){
+    if (propagate(CHECK_WITHOUTH_PROPAGATION_QUICK) != NULL){
         cancelUntil(0);
         asymm_lits++;
         if (!strengthenClause(c, l))
index 207bda4db0562cf501ba8384c3e483151fb3fc9a..a7b536a57f515a3676c68efee994e94c3785f5da 100644 (file)
@@ -26,9 +26,9 @@
 namespace CVC4 {
 namespace prop {
 
-void SatSolver::theoryCheck(SatClause& conflict) {
+void SatSolver::theoryCheck(theory::Theory::Effort effort, SatClause& conflict) {
   // Try theory propagation
-  bool ok = d_theoryEngine->check(theory::Theory::FULL_EFFORT);
+  bool ok = d_theoryEngine->check(effort);
   // If in conflict construct the conflict clause
   if (!ok) {
     // We have a conflict, get it
@@ -47,6 +47,47 @@ void SatSolver::theoryCheck(SatClause& conflict) {
   }
 }
 
+void SatSolver::theoryPropagate(std::vector<SatLiteral>& output) {
+  // Propagate
+  d_theoryEngine->propagate();
+  // Get the propagated literals
+  const std::vector<TNode>& outputNodes = d_theoryEngine->getPropagatedLiterals();
+  // If any literals, make a clause
+  const unsigned i_end = outputNodes.size();
+  for (unsigned i = 0; i < i_end; ++ i) {
+    Debug("prop-explain") << "theoryPropagate() => " << outputNodes[i].toString() << endl;
+    // The second argument ("true") instructs the CNF stream to create
+    // a new literal mapping if it doesn't exist.  This can happen due
+    // to a circular dependence, if a SAT literal "a" is asserted as a
+    // unit to the SAT solver, a round of theory propagation can occur
+    // before all Nodes have SAT variable mappings.
+    SatLiteral l = d_cnfStream->getLiteral(outputNodes[i], true);
+    output.push_back(l);
+  }
+}
+
+void SatSolver::explainPropagation(SatLiteral l, SatClause& explanation) {
+  TNode lNode = d_cnfStream->getNode(l);
+  Debug("prop-explain") << "explainPropagation(" << lNode.toString() << ")" << endl;
+  Node theoryExplanation = d_theoryEngine->getExplanation(lNode);
+  Debug("prop-explain") << "explainPropagation() => " <<  theoryExplanation.toString() << endl;
+  if (lNode.getKind() == kind::AND) {
+    Node::const_iterator it = theoryExplanation.begin();
+    Node::const_iterator it_end = theoryExplanation.end();
+    explanation.push(l);
+    for (; it != it_end; ++ it) {
+      explanation.push(~d_cnfStream->getLiteral(*it));
+    }
+  } else {
+    explanation.push(l);
+    explanation.push(~d_cnfStream->getLiteral(theoryExplanation));
+  }
+}
+
+void SatSolver::clearPropagatedLiterals() {
+  d_theoryEngine->clearPropagatedLiterals();
+}
+
 void SatSolver::enqueueTheoryLiteral(const SatLiteral& l) {
   Node literalNode = d_cnfStream->getNode(l);
   Debug("prop") << "enqueueing theory literal " << l << " " << literalNode << std::endl;
index f64697d7bcaba8da4ca2d0723eb74d2a809cf6db..992d8ecd242781c62df6384b4172fa30857aab46 100644 (file)
@@ -27,6 +27,7 @@
 
 #include "util/options.h"
 #include "util/stats.h"
+#include "theory/theory.h"
 
 #ifdef __CVC4_USE_MINISAT
 
@@ -199,7 +200,13 @@ public:
 
   SatVariable newVar(bool theoryAtom = false);
 
-  void theoryCheck(SatClause& conflict);
+  void theoryCheck(theory::Theory::Effort effort, SatClause& conflict);
+
+  void explainPropagation(SatLiteral l, SatClause& explanation);
+
+  void theoryPropagate(std::vector<SatLiteral>& output);
+
+  void clearPropagatedLiterals();
 
   void enqueueTheoryLiteral(const SatLiteral& l);
 
@@ -229,6 +236,11 @@ inline SatSolver::SatSolver(PropEngine* propEngine, TheoryEngine* theoryEngine,
   // Make minisat reuse the literal values
   d_minisat->polarity_mode = minisat::SimpSolver::polarity_user;
 
+  // No random choices
+  if(debugTagIsOn("no_rnd_decisions")){
+    d_minisat->random_var_freq = 0;
+  }
+
   d_statistics.init(d_minisat);
 }
 
index 7cfc1571b976562902c8423943c2cbb9463f85cc..d0d2f23d7f1c312be13f2fe718ef4abe23898d21 100644 (file)
@@ -9,6 +9,7 @@ libtheory_la_SOURCES = \
        @srcdir@/theoryof_table.h \
        theory_engine.h \
        theory_engine.cpp \
+       theory_test_utils.h \
        theory.h \
        theory.cpp
 
index 83d44e285a2dd2a04b67264d0bf12804607ddcd3..37df73edd010da9ddff9a5aeb49c4544effb4b18 100644 (file)
@@ -15,10 +15,13 @@ libarith_la_SOURCES = \
        delta_rational.cpp \
        partial_model.h \
        partial_model.cpp \
+       ordered_bounds_list.h \
        basic.h \
        normal.h \
        slack.h \
        tableau.h \
+       arith_propagator.h \
+       arith_propagator.cpp \
        theory_arith.h \
        theory_arith.cpp
 
diff --git a/src/theory/arith/arith_propagator.cpp b/src/theory/arith/arith_propagator.cpp
new file mode 100644 (file)
index 0000000..e405750
--- /dev/null
@@ -0,0 +1,347 @@
+#include "theory/arith/arith_propagator.h"
+#include "theory/arith/arith_utilities.h"
+
+#include <list>
+
+using namespace CVC4;
+using namespace CVC4::theory;
+using namespace CVC4::theory::arith;
+using namespace CVC4::theory::arith::propagator;
+
+using namespace CVC4::kind;
+
+using namespace std;
+
+ArithUnatePropagator::ArithUnatePropagator(context::Context* cxt) :
+  d_assertions(cxt), d_pendingAssertions(cxt,0)
+{ }
+
+
+bool acceptedKinds(Kind k){
+  switch(k){
+  case EQUAL:
+  case LEQ:
+  case GEQ:
+    return true;
+  default:
+    return false;
+  }
+}
+
+void ArithUnatePropagator::addAtom(TNode atom){
+  Assert(acceptedKinds(atom.getKind()));
+
+  TNode left  = atom[0];
+  TNode right = atom[1];
+
+  if(!leftIsSetup(left)){
+    setupLefthand(left);
+  }
+
+  switch(atom.getKind()){
+  case EQUAL:
+    {
+      OrderedBoundsList* eqList = left.getAttribute(propagator::PropagatorEqList());
+      Assert(!eqList->contains(atom));
+      eqList->append(atom);
+      break;
+    }
+  case LEQ:
+    {
+      OrderedBoundsList* leqList = left.getAttribute(propagator::PropagatorLeqList());
+      Assert(! leqList->contains(atom));
+      leqList->append(atom);
+      break;
+    }
+    break;
+  case GEQ:
+    {
+      OrderedBoundsList* geqList = left.getAttribute(propagator::PropagatorGeqList());
+      Assert(! geqList->contains(atom));
+      geqList->append(atom);
+      break;
+    }
+  default:
+    Unreachable();
+  }
+}
+bool ArithUnatePropagator::leftIsSetup(TNode left){
+  return left.hasAttribute(propagator::PropagatorEqList());
+}
+
+void ArithUnatePropagator::setupLefthand(TNode left){
+  Assert(!leftIsSetup(left));
+
+  OrderedBoundsList* eqList = new OrderedBoundsList();
+  OrderedBoundsList* geqList = new OrderedBoundsList();
+  OrderedBoundsList* leqList = new OrderedBoundsList();
+
+  left.setAttribute(propagator::PropagatorEqList(), eqList);
+  left.setAttribute(propagator::PropagatorLeqList(), leqList);
+  left.setAttribute(propagator::PropagatorGeqList(), geqList);
+}
+
+void ArithUnatePropagator::assertLiteral(TNode lit){
+
+  if(lit.getKind() == NOT){
+    Assert(!lit[0].getAttribute(propagator::PropagatorMarked()));
+    lit[0].setAttribute(propagator::PropagatorMarked(), true);
+  }else{
+    Assert(!lit.getAttribute(propagator::PropagatorMarked()));
+    lit.setAttribute(propagator::PropagatorMarked(), true);
+  }
+  d_assertions.push_back(lit);
+}
+
+std::vector<Node> ArithUnatePropagator::getImpliedLiterals(){
+  std::vector<Node> impliedButNotAsserted;
+
+  while(d_pendingAssertions < d_assertions.size()){
+    TNode assertion = d_assertions[d_pendingAssertions];
+    d_pendingAssertions = d_pendingAssertions + 1;
+
+    enqueueImpliedLiterals(assertion, impliedButNotAsserted);
+  }
+
+  if(debugTagIsOn("arith::propagator")){
+    for(std::vector<Node>::iterator i = impliedButNotAsserted.begin(),
+          endIter = impliedButNotAsserted.end(); i != endIter; ++i){
+      Node imp = *i;
+      Debug("arith::propagator") << explain(imp) << " (prop)-> " << imp << endl;
+    }
+  }
+
+  return impliedButNotAsserted;
+}
+
+/** This function is effectively a case split. */
+void ArithUnatePropagator::enqueueImpliedLiterals(TNode lit, std::vector<Node>& buffer){
+  switch(lit.getKind()){
+  case EQUAL:
+    enqueueEqualityImplications(lit, buffer);
+    break;
+  case LEQ:
+    enqueueUpperBoundImplications(lit, lit, buffer);
+    break;
+  case GEQ:
+    enqueueLowerBoundImplications(lit, lit, buffer);
+    break;
+  case NOT:
+    {
+      TNode under = lit[0];
+      switch(under.getKind()){
+      case EQUAL:
+        //Do nothing
+        break;;
+      case LEQ:
+        enqueueLowerBoundImplications(under, lit, buffer);
+        break;
+      case GEQ:
+        enqueueUpperBoundImplications(under, lit, buffer);
+        break;
+      default:
+        Unreachable();
+      }
+      break;
+    }
+  default:
+    Unreachable();
+  }
+}
+
+/**
+ * An equality (x = c) has been asserted.
+ * In this case we can propagate everything by comparing against the other constants.
+ */
+void ArithUnatePropagator::enqueueEqualityImplications(TNode orig, std::vector<Node>& buffer){
+  TNode left = orig[0];
+  TNode right = orig[1];
+  const Rational& c = right.getConst<Rational>();
+
+  OrderedBoundsList* eqList = left.getAttribute(propagator::PropagatorEqList());
+  OrderedBoundsList* leqList = left.getAttribute(propagator::PropagatorLeqList());
+  OrderedBoundsList* geqList = left.getAttribute(propagator::PropagatorGeqList());
+
+
+  /* (x = c) /\ (c !=d) => (x != d)  */
+  for(OrderedBoundsList::iterator i = eqList->begin(); i != eqList->end(); ++i){
+    TNode eq = *i;
+    Assert(eq.getKind() == EQUAL);
+    if(!eq.getAttribute(propagator::PropagatorMarked())){ //Note that (x = c) is marked
+      Assert(eq[1].getConst<Rational>() != c);
+
+      eq.setAttribute(propagator::PropagatorMarked(), true);
+
+      Node neq = NodeManager::currentNM()->mkNode(NOT, eq);
+      neq.setAttribute(propagator::PropagatorExplanation(), orig);
+      buffer.push_back(neq);
+    }
+  }
+  for(OrderedBoundsList::iterator i = leqList->begin(); i != leqList->end(); ++i){
+    TNode leq = *i;
+    Assert(leq.getKind() == LEQ);
+    if(!leq.getAttribute(propagator::PropagatorMarked())){
+      leq.setAttribute(propagator::PropagatorMarked(), true);
+      const Rational& d = leq[1].getConst<Rational>();
+      if(c <= d){
+        /* (x = c) /\ (c <= d) => (x <= d)  */
+        leq.setAttribute(propagator::PropagatorExplanation(), orig);
+        buffer.push_back(leq);
+      }else{
+        /* (x = c) /\ (c > d) => (x > d)  */
+        Node gt = NodeManager::currentNM()->mkNode(NOT, leq);
+        gt.setAttribute(propagator::PropagatorExplanation(), orig);
+        buffer.push_back(gt);
+      }
+    }
+  }
+
+  for(OrderedBoundsList::iterator i = geqList->begin(); i != geqList->end(); ++i){
+    TNode geq = *i;
+    Assert(geq.getKind() == GEQ);
+    if(!geq.getAttribute(propagator::PropagatorMarked())){
+      geq.setAttribute(propagator::PropagatorMarked(), true);
+      const Rational& d = geq[1].getConst<Rational>();
+      if(c >= d){
+        /* (x = c) /\ (c >= d) => (x >= d)  */
+        geq.setAttribute(propagator::PropagatorExplanation(), orig);
+        buffer.push_back(geq);
+      }else{
+        /* (x = c) /\ (c >= d) => (x >= d)  */
+        Node lt = NodeManager::currentNM()->mkNode(NOT, geq);
+        lt.setAttribute(propagator::PropagatorExplanation(), orig);
+        buffer.push_back(lt);
+      }
+    }
+  }
+}
+
+void ArithUnatePropagator::enqueueUpperBoundImplications(TNode atom, TNode orig, std::vector<Node>& buffer){
+
+  Assert(atom.getKind() == LEQ || (orig.getKind() == NOT && atom.getKind() == GEQ));
+
+  TNode left = atom[0];
+  TNode right = atom[1];
+  const Rational& c = right.getConst<Rational>();
+
+  OrderedBoundsList* eqList = left.getAttribute(propagator::PropagatorEqList());
+  OrderedBoundsList* leqList = left.getAttribute(propagator::PropagatorLeqList());
+  OrderedBoundsList* geqList = left.getAttribute(propagator::PropagatorGeqList());
+
+
+  //For every node (x <= d), we will restrict ourselves to look at the cases when (d >= c)
+  for(OrderedBoundsList::iterator i = leqList->lower_bound(atom); i != leqList->end(); ++i){
+    TNode leq = *i;
+    Assert(leq.getKind() == LEQ);
+    if(!leq.getAttribute(propagator::PropagatorMarked())){
+      leq.setAttribute(propagator::PropagatorMarked(), true);
+      const Rational& d = leq[1].getConst<Rational>();
+      Assert( c <= d );
+
+      leq.setAttribute(propagator::PropagatorExplanation(), orig);
+      buffer.push_back(leq); // (x<=c) /\ (c <= d) => (x <= d)
+      //Note that if c=d, that at the node is not marked this can only be reached when (x < c)
+      //So we do not have to worry about a circular dependency
+    }else if(leq != atom){
+      break; //No need to examine the rest, this atom implies the rest of the possible propagataions
+    }
+  }
+
+  for(OrderedBoundsList::iterator i = geqList->upper_bound(atom); i != geqList->end(); ++i){
+    TNode geq = *i;
+    Assert(geq.getKind() == GEQ);
+    if(!geq.getAttribute(propagator::PropagatorMarked())){
+      geq.setAttribute(propagator::PropagatorMarked(), true);
+      const Rational& d = geq[1].getConst<Rational>();
+      Assert( c < d );
+
+      Node lt = NodeManager::currentNM()->mkNode(NOT, geq);
+      lt.setAttribute(propagator::PropagatorExplanation(), orig);
+      buffer.push_back(lt); // x<=c /\ d > c => x < d
+    }else{
+      break; //No need to examine this atom implies the rest
+    }
+  }
+
+  for(OrderedBoundsList::iterator i = eqList->upper_bound(atom); i != eqList->end(); ++i){
+    TNode eq = *i;
+    Assert(eq.getKind() == EQUAL);
+    if(!eq.getAttribute(propagator::PropagatorMarked())){
+      eq.setAttribute(propagator::PropagatorMarked(), true);
+      const Rational& d = eq[1].getConst<Rational>();
+      Assert( c < d );
+
+      Node neq = NodeManager::currentNM()->mkNode(NOT, eq);
+      neq.setAttribute(propagator::PropagatorExplanation(), orig);
+      buffer.push_back(neq); // x<=c /\ c < d => x !=  d
+    }
+  }
+}
+
+void ArithUnatePropagator::enqueueLowerBoundImplications(TNode atom, TNode orig, std::vector<Node>& buffer){
+
+  Assert(atom.getKind() == GEQ || (orig.getKind() == NOT && atom.getKind() == LEQ));
+
+  TNode left = atom[0];
+  TNode right = atom[1];
+  const Rational& c = right.getConst<Rational>();
+
+  OrderedBoundsList* eqList = left.getAttribute(propagator::PropagatorEqList());
+  OrderedBoundsList* leqList = left.getAttribute(propagator::PropagatorLeqList());
+  OrderedBoundsList* geqList = left.getAttribute(propagator::PropagatorGeqList());
+
+
+  for(OrderedBoundsList::reverse_iterator i = geqList->reverse_lower_bound(atom);
+      i != geqList->rend(); i++){
+    TNode geq = *i;
+    Assert(geq.getKind() == GEQ);
+    if(!geq.getAttribute(propagator::PropagatorMarked())){
+      geq.setAttribute(propagator::PropagatorMarked(), true);
+      const Rational& d = geq[1].getConst<Rational>();
+      Assert( c >= d );
+
+      geq.setAttribute(propagator::PropagatorExplanation(), orig);
+      buffer.push_back(geq); // x>=c /\ c >= d => x >= d
+    }else if(geq != atom){
+      break; //No need to examine the rest, this atom implies the rest of the possible propagataions
+    }
+  }
+
+  for(OrderedBoundsList::reverse_iterator i = leqList->reverse_upper_bound(atom);
+      i != leqList->rend(); ++i){
+    TNode leq = *i;
+    Assert(leq.getKind() == LEQ);
+    if(!leq.getAttribute(propagator::PropagatorMarked())){
+      leq.setAttribute(propagator::PropagatorMarked(), true);
+      const Rational& d = leq[1].getConst<Rational>();
+      Assert( c > d );
+
+      Node gt = NodeManager::currentNM()->mkNode(NOT, leq);
+      gt.setAttribute(propagator::PropagatorExplanation(), orig);
+      buffer.push_back(gt); // x>=c /\ d < c => x > d
+    }else{
+      break; //No need to examine this atom implies the rest
+    }
+  }
+
+  for(OrderedBoundsList::reverse_iterator i = eqList->reverse_upper_bound(atom);
+      i != eqList->rend(); ++i){
+    TNode eq = *i;
+    Assert(eq.getKind() == EQUAL);
+    if(!eq.getAttribute(propagator::PropagatorMarked())){
+      eq.setAttribute(propagator::PropagatorMarked(), true);
+      const Rational& d = eq[1].getConst<Rational>();
+      Assert( c > d );
+
+      Node neq = NodeManager::currentNM()->mkNode(NOT, eq);
+      neq.setAttribute(propagator::PropagatorExplanation(), orig);
+      buffer.push_back(neq); // x>=c /\ c > d => x !=  d
+    }
+  }
+
+}
+
+Node ArithUnatePropagator::explain(TNode lit){
+  Assert(lit.hasAttribute(propagator::PropagatorExplanation()));
+  return lit.getAttribute(propagator::PropagatorExplanation());
+}
diff --git a/src/theory/arith/arith_propagator.h b/src/theory/arith/arith_propagator.h
new file mode 100644 (file)
index 0000000..a623517
--- /dev/null
@@ -0,0 +1,111 @@
+
+
+#include "cvc4_private.h"
+
+#ifndef __CVC4__THEORY__ARITH__ARITH_PROPAGATOR_H
+#define __CVC4__THEORY__ARITH__ARITH_PROPAGATOR_H
+
+#include "expr/node.h"
+#include "context/cdlist.h"
+#include "context/context.h"
+#include "context/cdo.h"
+#include "theory/arith/ordered_bounds_list.h"
+
+#include <algorithm>
+#include <vector>
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+
+class ArithUnatePropagator {
+private:
+  /** Index of assertions. */
+  context::CDList<Node> d_assertions;
+
+  /** Index of the last assertion in d_assertions to be asserted. */
+  context::CDO<unsigned int> d_pendingAssertions;
+
+public:
+  ArithUnatePropagator(context::Context* cxt);
+
+  /**
+   * Adds a new atom for the propagator to watch.
+   * Atom is assumed to have been rewritten by TheoryArith::rewrite().
+   */
+  void addAtom(TNode atom);
+
+  /**
+   * Informs the propagator that a literal has been asserted to the theory.
+   */
+  void assertLiteral(TNode lit);
+
+
+  /**
+   * returns a vector of literals that are 
+   */
+  std::vector<Node> getImpliedLiterals();
+
+  /** Explains a literal that was asserted in the current context. */
+  Node explain(TNode lit);
+
+private:
+  /** returns true if the left hand side side left has been setup. */
+  bool leftIsSetup(TNode left);
+
+  /**
+   * Sets up a left hand side.
+   * This initializes the attributes PropagatorEqList, PropagatorGeqList, and PropagatorLeqList for left.
+   */
+  void setupLefthand(TNode left);
+
+  /**
+   * Given that the literal lit is now asserted,
+   * enqueue additional entailed assertions in buffer.
+   */
+  void enqueueImpliedLiterals(TNode lit, std::vector<Node>& buffer);
+
+  void enqueueEqualityImplications(TNode original, std::vector<Node>& buffer);
+  void enqueueLowerBoundImplications(TNode atom, TNode original, std::vector<Node>& buffer);
+  /**
+   * Given that the literal original is now asserted, which is either (<= x c) or (not (>= x c)),
+   * enqueue additional entailed assertions in buffer.
+   */
+  void enqueueUpperBoundImplications(TNode atom, TNode original, std::vector<Node>& buffer);
+};
+
+
+
+namespace propagator {
+
+/** Basic memory management wrapper for deleting PropagatorEqList, PropagatorGeqList, and PropagatorLeqList.*/
+struct ListCleanupStrategy{
+  static void cleanup(OrderedBoundsList* l){
+    Debug("arithgc") << "cleaning up  " << l << "\n";
+    delete l;
+  }
+};
+
+
+struct PropagatorEqListID {};
+typedef expr::Attribute<PropagatorEqListID, OrderedBoundsList*, ListCleanupStrategy> PropagatorEqList;
+
+struct PropagatorGeqListID {};
+typedef expr::Attribute<PropagatorGeqListID, OrderedBoundsList*, ListCleanupStrategy> PropagatorGeqList;
+
+struct PropagatorLeqListID {};
+typedef expr::Attribute<PropagatorLeqListID, OrderedBoundsList*, ListCleanupStrategy> PropagatorLeqList;
+
+
+struct PropagatorMarkedID {};
+typedef expr::CDAttribute<PropagatorMarkedID, bool> PropagatorMarked;
+
+struct PropagatorExplanationID {};
+typedef expr::CDAttribute<PropagatorExplanationID, Node> PropagatorExplanation;
+}/* CVC4::theory::arith::propagator */
+
+}/* CVC4::theory::arith namespace */
+}/* CVC4::theory namespace */
+}/* CVC4 namespace */
+
+#endif /* __CVC4__THEORY__ARITH__THEORY_ARITH_H */
diff --git a/src/theory/arith/ordered_bounds_list.h b/src/theory/arith/ordered_bounds_list.h
new file mode 100644 (file)
index 0000000..d21283a
--- /dev/null
@@ -0,0 +1,212 @@
+
+
+#include "cvc4_private.h"
+
+
+#ifndef __CVC4__THEORY__ARITH__ORDERED_BOUNDS_LIST_H
+#define __CVC4__THEORY__ARITH__ORDERED_BOUNDS_LIST_H
+
+#include "expr/node.h"
+#include "util/rational.h"
+#include "expr/kind.h"
+
+#include <vector>
+#include <algorithm>
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+
+struct RightHandRationalLT
+{
+  bool operator()(TNode s1, TNode s2) const
+  {
+    Assert(s1.getNumChildren() >= 2);
+    Assert(s2.getNumChildren() >= 2);
+
+    Assert(s1[1].getKind() == kind::CONST_RATIONAL);
+    Assert(s2[1].getKind() == kind::CONST_RATIONAL);
+
+    TNode rh1 = s1[1];
+    TNode rh2 = s2[1];
+    const Rational& c1 = rh1.getConst<Rational>();
+    const Rational& c2 = rh2.getConst<Rational>();
+    return c1.cmp(c2) < 0;
+  }
+};
+
+struct RightHandRationalGT
+{
+  bool operator()(TNode s1, TNode s2) const
+  {
+    Assert(s1.getNumChildren() >= 2);
+    Assert(s2.getNumChildren() >= 2);
+
+    Assert(s1[1].getKind() == kind::CONST_RATIONAL);
+    Assert(s2[1].getKind() == kind::CONST_RATIONAL);
+
+    TNode rh1 = s1[1];
+    TNode rh2 = s2[1];
+    const Rational& c1 = rh1.getConst<Rational>();
+    const Rational& c2 = rh2.getConst<Rational>();
+    return c1.cmp(c2) > 0;
+  }
+};
+
+/**
+ * An OrderedBoundsList is a lazily sorted vector of Arithmetic constraints.
+ * The intended use is for a list of rewriting arithmetic atoms.
+ * An example of such a list would be [(<= x 5);(= y 78); (>= x 9)].
+ *
+ * Nodes are required to have a CONST_RATIONAL child as their second node.
+ * Nodes are sorted in increasing order according to RightHandRationalLT.
+ *
+ * The lists are lazily sorted in the sense that the list is not sorted until
+ * an operation to access the element is attempted.
+ *
+ * An append() may make the list no longer sorted.
+ * After an append() operation all iterators for the list become invalid.
+ */
+class OrderedBoundsList {
+private:
+  bool d_isSorted;
+  std::vector<Node> d_list;
+
+public:
+  typedef std::vector<Node>::const_iterator iterator;
+  typedef std::vector<Node>::const_reverse_iterator reverse_iterator;
+
+  /**
+   * Constucts a new and empty OrderBoundsList.
+   * The empty list is initially sorted.
+   */
+  OrderedBoundsList() : d_isSorted(true){}
+
+  /**
+   * Appends a node onto the back of the list.
+   * The list may no longer be sorted.
+   */
+  void append(TNode n){
+    Assert(n.getNumChildren() >= 2);
+    Assert(n[1].getKind() == kind::CONST_RATIONAL);
+    d_isSorted = false;
+    d_list.push_back(n);
+  }
+
+  /** returns the size of the list */
+  unsigned int size(){
+    return d_list.size();
+  }
+
+  /** returns the i'th element in the sort list. This may sort the list.*/
+  TNode at(unsigned int idx){
+    sortIfNeeded();
+    return d_list.at(idx);
+  }
+
+  /** returns true if the list is known to be sorted. */
+  bool isSorted() const{
+    return d_isSorted;
+  }
+
+  /** sorts the list. */
+  void sort(){
+    d_isSorted = true;
+    std::sort(d_list.begin(), d_list.end(), RightHandRationalLT());
+  }
+
+  /**
+   * returns an iterator to the list that iterates in ascending order.
+   * This may sort the list.
+   */
+  iterator begin(){
+    sortIfNeeded();
+    return d_list.begin();
+  }
+  /**
+   * returns an iterator to the end of the list when interating in ascending order.
+   */
+  iterator end() const{
+    return d_list.end();
+  }
+
+  /**
+   * returns an iterator to the list that iterates in descending order.
+   * This may sort the list.
+   */
+  reverse_iterator rbegin(){
+    sortIfNeeded();
+    return d_list.rend();
+  }
+  /**
+   * returns an iterator to the end of the list when interating in descending order.
+   */
+  reverse_iterator rend() const{
+    return d_list.rend();
+  }
+
+  /**
+   * returns an iterator to the least strict upper bound of value.
+   * if the list is [(<= x 2);(>= x 80);(< y 70)]
+   * then *upper_bound((< z 70)) == (>= x 80)
+   *
+   * This may sort the list.
+   * see stl::upper_bound for more information.
+   */
+  iterator upper_bound(TNode value){
+    sortIfNeeded();
+    return std::upper_bound(begin(), end(), value, RightHandRationalLT());
+  }
+  /**
+   * returns an iterator to the greatest lower bound of value.
+   * This is bound is not strict.
+   * if the list is [(<= x 2);(>= x 80);(< y 70)]
+   * then *lower_bound((< z 70)) == (< y 70)
+   *
+   * This may sort the list.
+   * see stl::upper_bound for more information.
+   */
+  iterator lower_bound(TNode value){
+    sortIfNeeded();
+    return std::lower_bound(begin(), end(), value, RightHandRationalLT());
+  }
+  /**
+   * see OrderedBoundsList::upper_bound for more information.
+   * The difference is that the iterator goes in descending order.
+   */
+  reverse_iterator reverse_upper_bound(TNode value){
+    sortIfNeeded();
+    return std::upper_bound(rbegin(), rend(), value, RightHandRationalGT());
+  }
+  /**
+   * see OrderedBoundsList::lower_bound for more information.
+   * The difference is that the iterator goes in descending order.
+   */
+  reverse_iterator reverse_lower_bound(TNode value){
+    sortIfNeeded();
+    return std::lower_bound(rbegin(), rend(), value, RightHandRationalGT());
+  }
+
+  /**
+   * This is an O(n) method for searching the array to check if it contains n.
+   */
+  bool contains(TNode n) const {
+    for(std::vector<Node>::const_iterator i = d_list.begin(); i != d_list.end(); ++i){
+      if(*i == n) return true;
+    }
+    return false;
+  }
+private:
+  /** Sorts the list if it is not already sorted.  */
+  void sortIfNeeded(){
+    if(!d_isSorted){
+      sort();
+    }
+  }
+};
+
+}/* CVC4::theory::arith namespace */
+}/* CVC4::theory namespace */
+}/* CVC4 namespace */
+
+#endif /* __CVC4__THEORY__ARITH__ORDERED_BOUNDS_LIST_H */
index b3b7f58be10cf198d954a183976881e37a3c77a0..bd35e07976ec0e14d43c1708884a27131c16a836 100644 (file)
@@ -34,6 +34,7 @@
 #include "theory/arith/basic.h"
 
 #include "theory/arith/arith_rewriter.h"
+#include "theory/arith/arith_propagator.h"
 
 #include "theory/arith/theory_arith.h"
 #include <map>
@@ -55,6 +56,7 @@ TheoryArith::TheoryArith(context::Context* c, OutputChannel& out) :
   d_partialModel(c),
   d_diseq(c),
   d_rewriter(&d_constants),
+  d_propagator(c),
   d_statistics()
 {
   uint64_t ass_id = partial_model::Assignment::getId();
@@ -81,6 +83,15 @@ TheoryArith::Statistics::Statistics():
   StatisticsRegistry::registerStat(&d_statUpdateConflicts);
 }
 
+TheoryArith::Statistics::~Statistics(){
+  StatisticsRegistry::unregisterStat(&d_statPivots);
+  StatisticsRegistry::unregisterStat(&d_statUpdates);
+  StatisticsRegistry::unregisterStat(&d_statAssertUpperConflicts);
+  StatisticsRegistry::unregisterStat(&d_statAssertLowerConflicts);
+  StatisticsRegistry::unregisterStat(&d_statUpdateConflicts);
+}
+
+
 bool isBasicSum(TNode n){
   if(n.getKind() != kind::PLUS) return false;
 
@@ -143,6 +154,8 @@ void TheoryArith::preRegisterTerm(TNode n) {
     Assert(isNormalAtom(n));
 
 
+    d_propagator.addAtom(n);
+
     TNode left  = n[0];
     TNode right = n[1];
     if(left.getKind() == PLUS){
@@ -206,6 +219,10 @@ void TheoryArith::setupVariable(TNode x){
     //lower bound. This is done to strongly enforce the notion that basic
     //variables should not be changed without begin checked.
 
+    //Strictly speaking checking x is unnessecary as it cannot have an upper or
+    //lower bound. This is done to strongly enforce the notion that basic
+    //variables should not be changed without begin checked.
+
   }
   Debug("arithgc") << "setupVariable("<<x<<")"<<std::endl;
 };
@@ -682,41 +699,69 @@ void TheoryArith::check(Effort level){
   Debug("arith") << "TheoryArith::check begun" << std::endl;
 
   while(!done()){
+
     Node original = get();
     Node assertion = simulatePreprocessing(original);
     Debug("arith_assertions") << "arith assertion(" << original
                               << " \\-> " << assertion << ")" << std::endl;
 
+    d_propagator.assertLiteral(original);
     bool conflictDuringAnAssert = assertionCases(original, assertion);
 
+
     if(conflictDuringAnAssert){
-      if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); }
       d_partialModel.revertAssignmentChanges();
-      if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); }
-
       return;
     }
   }
 
-  if(fullEffort(level)){
-    if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); }
+  //TODO This must be done everytime for the time being
+  if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); }
 
-    Node possibleConflict = updateInconsistentVars();
-    if(possibleConflict != Node::null()){
+  Node possibleConflict = updateInconsistentVars();
+  if(possibleConflict != Node::null()){
 
-      d_partialModel.revertAssignmentChanges();
+    d_partialModel.revertAssignmentChanges();
 
-      d_out->conflict(possibleConflict, true);
+    if(debugTagIsOn("arith::print-conflict"))
+      Debug("arith_conflict") << (possibleConflict) << std::endl;
 
-      Debug("arith_conflict") <<"Found a conflict "<< possibleConflict << endl;
-    }else{
-      d_partialModel.commitAssignmentChanges();
-    }
-    if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); }
+    d_out->conflict(possibleConflict);
+
+    Debug("arith_conflict") <<"Found a conflict "<< possibleConflict << endl;
+  }else{
+    d_partialModel.commitAssignmentChanges();
   }
+  if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); }
+
 
   Debug("arith") << "TheoryArith::check end" << std::endl;
 
+  if(debugTagIsOn("arith::print_model")) {
+    Debug("arith::print_model") << "Model:" << endl;
+
+    for (unsigned i = 0; i < d_variables.size(); ++ i) {
+      Debug("arith::print_model") << d_variables[i] << " : " <<
+        d_partialModel.getAssignment(d_variables[i]);
+      if(isBasic(d_variables[i]))
+        Debug("arith::print_model") << " (basic)";
+      Debug("arith::print_model") << endl;
+    }
+  }
+  if(debugTagIsOn("arith::print_assertions")) {
+    Debug("arith::print_assertions") << "Assertions:" << endl;
+    for (unsigned i = 0; i < d_variables.size(); ++ i) {
+      Node x = d_variables[i];
+      if (x.hasAttribute(partial_model::LowerConstraint())) {
+        Node constr = d_partialModel.getLowerConstraint(x);
+        Debug("arith::print_assertions") << constr.toString() << endl;
+      }
+      if (x.hasAttribute(partial_model::UpperConstraint())) {
+        Node constr = d_partialModel.getUpperConstraint(x);
+        Debug("arith::print_assertions") << constr.toString() << endl;
+      }
+    }
+  }
 }
 
 /**
@@ -750,3 +795,23 @@ void TheoryArith::checkTableau(){
     Assert(sum == shouldBe);
   }
 }
+
+
+void TheoryArith::explain(TNode n, Effort e) {
+  Node explanation = d_propagator.explain(n);
+  Debug("arith") << "arith::explain("<<explanation<<")->"
+                 << explanation << endl;
+  d_out->explanation(explanation, true);
+}
+
+void TheoryArith::propagate(Effort e) {
+
+  if(quickCheckOrMore(e)){
+    std::vector<Node> implied = d_propagator.getImpliedLiterals();
+    for(std::vector<Node>::iterator i = implied.begin();
+        i != implied.end();
+        ++i){
+      d_out->propagate(*i);
+    }
+  }
+}
index aff60f651e3955b7b5ae5ae4a7048be94d2e22a7..c76923bee77a9ba0d1f730dbc40cd9fdc9d38585 100644 (file)
@@ -30,6 +30,7 @@
 #include "theory/arith/tableau.h"
 #include "theory/arith/arith_rewriter.h"
 #include "theory/arith/partial_model.h"
+#include "theory/arith/arith_propagator.h"
 
 #include "util/stats.h"
 
@@ -96,6 +97,7 @@ private:
    */
   ArithRewriter d_rewriter;
 
+  ArithUnatePropagator d_propagator;
 
 public:
   TheoryArith(context::Context* c, OutputChannel& out);
@@ -115,8 +117,8 @@ public:
   void registerTerm(TNode n);
 
   void check(Effort e);
-  void propagate(Effort e) { Unimplemented(); }
-  void explain(TNode n, Effort e) { Unimplemented(); }
+  void propagate(Effort e);
+  void explain(TNode n, Effort e);
 
   void shutdown(){ }
 
@@ -242,6 +244,7 @@ private:
     IntStat d_statAssertLowerConflicts, d_statUpdateConflicts;
 
     Statistics();
+    ~Statistics();
   };
 
   Statistics d_statistics;
index e06c9594ccbef68378a932fab1c15817ebfae9c7..5e83d3728df509b635b398bd3d32f68e9c444853 100644 (file)
@@ -103,5 +103,21 @@ Node Theory::get() {
   return fact;
 }
 
+std::ostream& operator<<(std::ostream& os, Theory::Effort level){
+  switch(level){
+  case Theory::MIN_EFFORT:
+    os << "MIN_EFFORT"; break;
+  case Theory::QUICK_CHECK:
+    os << "QUICK_CHECK:"; break;
+  case Theory::STANDARD:
+    os << "STANDARD"; break;
+  case Theory::FULL_EFFORT:
+    os << "FULL_EFFORT"; break;
+  default:
+      Unreachable();
+  }
+  return os;
+}
+
 }/* CVC4::theory namespace */
 }/* CVC4 namespace */
index 1bf6f660c532b6b1dcbf17c6e4d55878667230a3..6f4effe7867f716482704b6077edb230076f36cd 100644 (file)
@@ -331,6 +331,8 @@ protected:
 
 };/* class Theory */
 
+std::ostream& operator<<(std::ostream& os, Theory::Effort level);
+
 }/* CVC4::theory namespace */
 }/* CVC4 namespace */
 
index c2511f4e64102846f354b500c28676f0d7ac4b30..15b406cddd3613e69655dd8d7838317c578e76da 100644 (file)
@@ -64,13 +64,22 @@ class TheoryEngine {
     TheoryEngine* d_engine;
     context::Context* d_context;
     context::CDO<Node> d_conflictNode;
+    context::CDO<Node> d_explanationNode;
+
+    /**
+     * Literals that are propagated by the theory. Note that these are TNodes.
+     * The theory can only propagate nodes that have an assigned literal in the
+     * sat solver and are hence referenced in the SAT solver.
+     */
+    std::vector<TNode> d_propagatedLiterals;
 
   public:
 
     EngineOutputChannel(TheoryEngine* engine, context::Context* context) :
       d_engine(engine),
       d_context(context),
-      d_conflictNode(context) {
+      d_conflictNode(context),
+      d_explanationNode(context){
     }
 
     void conflict(TNode conflictNode, bool safe) throw(theory::Interrupted, AssertionException) {
@@ -82,7 +91,9 @@ class TheoryEngine {
       }
     }
 
-    void propagate(TNode, bool) throw(theory::Interrupted, AssertionException) {
+    void propagate(TNode lit, bool) throw(theory::Interrupted, AssertionException) {
+      d_propagatedLiterals.push_back(lit);
+      ++(d_engine->d_statistics.d_statPropagate);
       ++(d_engine->d_statistics.d_statPropagate);
     }
 
@@ -94,7 +105,9 @@ class TheoryEngine {
       ++(d_engine->d_statistics.d_statAugLemma);
       d_engine->newAugmentingLemma(node);
     }
-    void explanation(TNode, bool) throw(theory::Interrupted, AssertionException) {
+    void explanation(TNode explanationNode, bool) throw(theory::Interrupted, AssertionException) {
+      d_explanationNode = explanationNode;
+      ++(d_engine->d_statistics.d_statExplanatation);
       ++(d_engine->d_statistics.d_statExplanatation);
     }
   };
@@ -302,6 +315,7 @@ public:
   inline bool check(theory::Theory::Effort effort)
   {
     d_theoryOut.d_conflictNode = Node::null();
+    d_theoryOut.d_propagatedLiterals.clear();
     // Do the checking
     try {
       //d_bool.check(effort);
@@ -316,13 +330,23 @@ public:
     return d_theoryOut.d_conflictNode.get().isNull();
   }
 
+  inline const std::vector<TNode>& getPropagatedLiterals() const {
+    return d_theoryOut.d_propagatedLiterals;
+  }
+
+  void clearPropagatedLiterals() {
+    d_theoryOut.d_propagatedLiterals.clear();
+  }
+
   inline void newLemma(TNode node) {
     d_propEngine->assertLemma(node);
   }
+
   inline void newAugmentingLemma(TNode node) {
     Node preprocessed = preprocess(node);
     d_propEngine->assertFormula(preprocessed);
   }
+
   /**
    * Returns the last conflict (if any).
    */
@@ -330,6 +354,21 @@ public:
     return d_theoryOut.d_conflictNode;
   }
 
+  inline void propagate() {
+    d_theoryOut.d_propagatedLiterals.clear();
+    // Do the propagation
+    d_uf.propagate(theory::Theory::FULL_EFFORT);
+    d_arith.propagate(theory::Theory::FULL_EFFORT);
+  }
+
+  inline Node getExplanation(TNode node){
+    d_theoryOut.d_explanationNode = Node::null();
+    theory::Theory* theory =
+              node.getKind() == kind::NOT ? theoryOf(node[0]) : theoryOf(node);
+    theory->explain(node);
+    return d_theoryOut.d_explanationNode;
+  }
+
 private:
   class Statistics {
   public:
@@ -350,6 +389,7 @@ private:
   };
   Statistics d_statistics;
 
+
 };/* class TheoryEngine */
 
 }/* CVC4 namespace */
diff --git a/src/theory/theory_test_utils.h b/src/theory/theory_test_utils.h
new file mode 100644 (file)
index 0000000..dc08788
--- /dev/null
@@ -0,0 +1,81 @@
+
+
+#include "cvc4_public.h"
+
+
+#ifndef __CVC4__THEORY__THEORY_TEST_UTILS_H
+#define __CVC4__THEORY__ITHEORY_TEST_UTILS_H
+
+#include "util/Assert.h"
+#include "expr/node.h"
+#include "theory/output_channel.h"
+#include "theory/interrupted.h"
+
+#include <vector>
+
+namespace CVC4{
+
+namespace theory {
+
+/**
+ * Very basic OutputChannel for testing simple Theory Behaviour.
+ * Stores a call sequence for the output channel
+ */
+enum OutputChannelCallType { CONFLICT, PROPOGATE, AUG_LEMMA, LEMMA, EXPLANATION };
+
+
+class TestOutputChannel : public theory::OutputChannel {
+public:
+  std::vector< pair<enum OutputChannelCallType, Node> > d_callHistory;
+
+  TestOutputChannel() {}
+
+  ~TestOutputChannel() {}
+
+  void safePoint()  throw(Interrupted, AssertionException) {}
+
+  void conflict(TNode n, bool safe = false)  throw(Interrupted, AssertionException) {
+    push(CONFLICT, n);
+  }
+
+  void propagate(TNode n, bool safe = false)  throw(Interrupted, AssertionException) {
+    push(PROPOGATE, n);
+  }
+
+  void lemma(TNode n, bool safe = false) throw(Interrupted, AssertionException) {
+    push(LEMMA, n);
+  }
+  void augmentingLemma(TNode n, bool safe = false) throw(Interrupted, AssertionException){
+    push(AUG_LEMMA, n);
+  }
+  void explanation(TNode n, bool safe = false)  throw(Interrupted, AssertionException) {
+    push(EXPLANATION, n);
+  }
+
+  void clear() {
+    d_callHistory.clear();
+  }
+
+  Node getIthNode(int i) {
+    Node tmp = (d_callHistory[i]).second;
+    return tmp;
+  }
+
+  OutputChannelCallType getIthCallType(int i) {
+    return (d_callHistory[i]).first;
+  }
+
+  unsigned getNumCalls() {
+    return d_callHistory.size();
+  }
+
+private:
+  void push(OutputChannelCallType call, TNode n) {
+    d_callHistory.push_back(make_pair(call,n));
+  }
+};/* class TestOutputChannel */
+
+}/* namespace theory */
+}/* namespace CVC4 */
+
+#endif /* __CVC4__THEORY__THEORY_TEST_UTILS_H */
index d13baf6a944bcec3b97958f93c410b4361e17265..f440c3d0f44d453a5ce56a2c2ae3ae72cdc240c0 100644 (file)
@@ -307,7 +307,7 @@ void TheoryUF::check(Effort level) {
     merge();
   }
 
-  if(fullEffort(level)) {
+  if(standardEffortOrMore(level)) {
     for(CDList<Node>::const_iterator diseqIter = d_disequality.begin();
         diseqIter != d_disequality.end();
         ++diseqIter) {
index 9f8379d5437156fcec5a23f81607eb4d9369ec0f..ddab915bf7af9484c14c66dffc3eeaa11b6d7bfc 100644 (file)
@@ -23,6 +23,7 @@ UNIT_TESTS = \
        context/cdmap_white \
        theory/theory_black \
        theory/theory_uf_white \
+       theory/theory_arith_white \
        util/assert_white \
        util/bitvector_black \
        util/configuration_black \
diff --git a/test/unit/theory/theory_arith_white.h b/test/unit/theory/theory_arith_white.h
new file mode 100644 (file)
index 0000000..fe9cbb3
--- /dev/null
@@ -0,0 +1,312 @@
+
+#include <cxxtest/TestSuite.h>
+
+#include "theory/theory.h"
+#include "theory/arith/theory_arith.h"
+#include "expr/node.h"
+#include "expr/node_manager.h"
+#include "context/context.h"
+#include "util/rational.h"
+
+#include "theory/theory_test_utils.h"
+
+#include <vector>
+
+using namespace CVC4;
+using namespace CVC4::theory;
+using namespace CVC4::theory::arith;
+using namespace CVC4::expr;
+using namespace CVC4::context;
+using namespace CVC4::kind;
+
+using namespace std;
+
+class TheoryArithWhite : public CxxTest::TestSuite {
+
+  Context* d_ctxt;
+  NodeManager* d_nm;
+  NodeManagerScope* d_scope;
+
+  TestOutputChannel d_outputChannel;
+  Theory::Effort d_level;
+
+  TheoryArith* d_arith;
+
+  TypeNode* d_booleanType;
+  TypeNode* d_realType;
+
+  const Rational d_zero;
+  const Rational d_one;
+
+  std::set<Node>* preregistered;
+
+  bool debug;
+
+public:
+
+  TheoryArithWhite() : d_level(Theory::FULL_EFFORT), d_zero(0), d_one(1), debug(false) {}
+
+  void setUp() {
+    d_ctxt = new Context;
+    d_nm = new NodeManager(d_ctxt);
+    d_scope = new NodeManagerScope(d_nm);
+    d_outputChannel.clear();
+    d_arith = new TheoryArith(d_ctxt, d_outputChannel);
+
+    preregistered = new std::set<Node>();
+
+    d_booleanType = new TypeNode(d_nm->booleanType());
+    d_realType = new TypeNode(d_nm->realType());
+
+  }
+
+  void tearDown() {
+    delete d_realType;
+    delete d_booleanType;
+
+    delete preregistered;
+
+    delete d_arith;
+    d_outputChannel.clear();
+    delete d_scope;
+    delete d_nm;
+    delete d_ctxt;
+  }
+
+  Node fakeTheoryEnginePreprocess(TNode inp){
+    Node rewrite = d_arith->rewrite(inp);
+
+    if(debug) cout << rewrite << inp << endl;
+
+    std::list<Node> toPreregister;
+
+    toPreregister.push_back(rewrite);
+    for(std::list<Node>::iterator i = toPreregister.begin(); i != toPreregister.end(); ++i){
+      Node n = *i;
+      preregistered->insert(n);
+
+      for(Node::iterator citer = n.begin(); citer != n.end(); ++citer){
+        Node c = *citer;
+        if(preregistered->find(c) == preregistered->end()){
+          toPreregister.push_back(c);
+        }
+      }
+    }
+    for(std::list<Node>::reverse_iterator i = toPreregister.rbegin(); i != toPreregister.rend(); ++i){
+      Node n = *i;
+      if(debug) cout << n.getId() << " "<< n << endl;
+      d_arith->preRegisterTerm(n);
+    }
+
+    return rewrite;
+  }
+
+  void testAssert() {
+    Node x = d_nm->mkVar(*d_realType);
+    Node c = d_nm->mkConst<Rational>(d_zero);
+
+    Node leq = d_nm->mkNode(LEQ, x, c);
+    Node rLeq = fakeTheoryEnginePreprocess(leq);
+
+    d_arith->assertFact(rLeq);
+
+    d_arith->check(d_level);
+
+    TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 0u);
+  }
+
+  Node simulateSplit(TNode l, TNode r){
+    Node eq = d_nm->mkNode(EQUAL, l, r);
+    Node lt = d_nm->mkNode(LT, l, r);
+    Node gt = d_nm->mkNode(GT, l, r);
+
+    Node dis = d_nm->mkNode(OR, eq, lt, gt);
+    return dis;
+  }
+
+  void testAssertEqualityEagerSplit() {
+    Node x = d_nm->mkVar(*d_realType);
+    Node c = d_nm->mkConst<Rational>(d_zero);
+
+    Node eq = d_nm->mkNode(EQUAL, x, c);
+    Node expectedDisjunct = simulateSplit(x,c);
+
+    Node rEq = fakeTheoryEnginePreprocess(eq);
+
+    d_arith->assertFact(rEq);
+
+    d_arith->check(d_level);
+
+    TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 1u);
+
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), expectedDisjunct);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA);
+
+  }
+  void testLtRewrite() {
+    Node x = d_nm->mkVar(*d_realType);
+    Node c = d_nm->mkConst<Rational>(d_zero);
+
+    Node lt = d_nm->mkNode(LT, x, c);
+    Node geq = d_nm->mkNode(GEQ, x, c);
+    Node expectedRewrite = d_nm->mkNode(NOT, geq);
+
+    Node rewrite = d_arith->rewrite(lt);
+
+    TS_ASSERT_EQUALS(expectedRewrite, rewrite);
+  }
+
+  void testBasicConflict() {
+    Node x = d_nm->mkVar(*d_realType);
+    Node c = d_nm->mkConst<Rational>(d_zero);
+
+    Node eq = d_nm->mkNode(EQUAL, x, c);
+    Node lt = d_nm->mkNode(LT, x, c);
+    Node expectedDisjunct = simulateSplit(x,c);
+
+    Node rEq = fakeTheoryEnginePreprocess(eq);
+    Node rLt = fakeTheoryEnginePreprocess(lt);
+
+    d_arith->assertFact(rEq);
+    d_arith->assertFact(rLt);
+
+
+    d_arith->check(d_level);
+
+    TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), expectedDisjunct);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA);
+
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), CONFLICT);
+
+    Node expectedClonflict = d_nm->mkNode(AND, rEq, rLt);
+
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), expectedClonflict);
+  }
+
+  void testBasicPropagate() {
+    Node x = d_nm->mkVar(*d_realType);
+    Node c = d_nm->mkConst<Rational>(d_zero);
+
+    Node eq = d_nm->mkNode(EQUAL, x, c);
+    Node lt = d_nm->mkNode(LT, x, c);
+    Node expectedDisjunct = simulateSplit(x,c);
+
+    Node rEq = fakeTheoryEnginePreprocess(eq);
+    Node rLt = fakeTheoryEnginePreprocess(lt);
+
+    d_arith->assertFact(rEq);
+
+
+    d_arith->check(d_level);
+    d_arith->propagate(d_level);
+
+    TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA);
+
+
+    Node expectedProp = d_nm->mkNode(GEQ, x, c);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), PROPOGATE);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), expectedProp);
+
+  }
+  void testTPLt1() {
+    Node x = d_nm->mkVar(*d_realType);
+    Node c0 = d_nm->mkConst<Rational>(d_zero);
+    Node c1 = d_nm->mkConst<Rational>(d_one);
+
+    Node leq0 = d_nm->mkNode(LEQ, x, c0);
+    Node leq1 = d_nm->mkNode(LEQ, x, c1);
+    Node lt1 = d_nm->mkNode(LT, x, c1);
+
+    Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
+    Node rLt1 = fakeTheoryEnginePreprocess(lt1);
+    Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
+
+    d_arith->assertFact(rLt1);
+
+
+    d_arith->check(d_level);
+    d_arith->propagate(d_level);
+
+#ifdef CVC4_ASSERTIONS
+    TS_ASSERT_THROWS(  d_arith->explain(rLeq0, d_level), AssertionException );
+    TS_ASSERT_THROWS(  d_arith->explain(rLt1, d_level), AssertionException );
+#endif
+    d_arith->explain(rLeq1, d_level);
+
+    TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPOGATE);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), EXPLANATION);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), leq1);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1);
+  }
+
+
+  void testTPLeq0() {
+    Node x = d_nm->mkVar(*d_realType);
+    Node c0 = d_nm->mkConst<Rational>(d_zero);
+    Node c1 = d_nm->mkConst<Rational>(d_one);
+
+    Node leq0 = d_nm->mkNode(LEQ, x, c0);
+    Node leq1 = d_nm->mkNode(LEQ, x, c1);
+    Node lt1 = d_nm->mkNode(LT, x, c1);
+
+    Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
+    Node rLt1 = fakeTheoryEnginePreprocess(lt1);
+    Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
+
+    d_arith->assertFact(rLeq0);
+
+
+    d_arith->check(d_level);
+    d_arith->propagate(d_level);
+
+
+    d_arith->explain(rLt1, d_level);
+#ifdef CVC4_ASSERTIONS
+    TS_ASSERT_THROWS(  d_arith->explain(rLeq0, d_level), AssertionException );
+#endif
+    d_arith->explain(rLeq1, d_level);
+
+    TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 4u);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPOGATE);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), PROPOGATE);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(2), EXPLANATION);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(3), EXPLANATION);
+
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), rLeq1);
+
+
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(2), rLeq0);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(3), rLeq0);
+  }
+  void testTPLeq1() {
+    Node x = d_nm->mkVar(*d_realType);
+    Node c0 = d_nm->mkConst<Rational>(d_zero);
+    Node c1 = d_nm->mkConst<Rational>(d_one);
+
+    Node leq0 = d_nm->mkNode(LEQ, x, c0);
+    Node leq1 = d_nm->mkNode(LEQ, x, c1);
+    Node lt1 = d_nm->mkNode(LT, x, c1);
+
+    Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
+    Node rLt1 = fakeTheoryEnginePreprocess(lt1);
+    Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
+
+    d_arith->assertFact(rLeq1);
+
+
+    d_arith->check(d_level);
+    d_arith->propagate(d_level);
+
+#ifdef CVC4_ASSERTIONS
+    TS_ASSERT_THROWS(  d_arith->explain(rLeq0, d_level), AssertionException );
+    TS_ASSERT_THROWS(  d_arith->explain(rLeq1, d_level), AssertionException );
+    TS_ASSERT_THROWS(  d_arith->explain(rLt1, d_level), AssertionException );
+#endif
+
+    TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 0u);
+  }
+};
index 50c201606c9aa97cdcf53c9bbce2fc0846e3c918..203d669b757098ed3988172293a6e575c56a6163 100644 (file)
@@ -24,6 +24,8 @@
 #include "expr/node_manager.h"
 #include "context/context.h"
 
+#include "theory/theory_test_utils.h"
+
 #include <vector>
 
 using namespace CVC4;
@@ -34,60 +36,6 @@ using namespace CVC4::context;
 
 using namespace std;
 
-/**
- * Very basic OutputChannel for testing simple Theory Behaviour.
- * Stores a call sequence for the output channel
- */
-enum OutputChannelCallType { CONFLICT, PROPOGATE, LEMMA, EXPLANATION };
-class TestOutputChannel : public OutputChannel {
-private:
-  void push(OutputChannelCallType call, TNode n) {
-    d_callHistory.push_back(make_pair(call,n));
-  }
-public:
-  vector< pair<OutputChannelCallType, Node> > d_callHistory;
-
-  TestOutputChannel() {}
-
-  ~TestOutputChannel() {}
-
-  void safePoint()  throw(Interrupted, AssertionException) {}
-
-  void conflict(TNode n, bool safe = false)  throw(Interrupted, AssertionException) {
-    push(CONFLICT, n);
-  }
-
-  void propagate(TNode n, bool safe = false)  throw(Interrupted, AssertionException) {
-    push(PROPOGATE, n);
-  }
-
-  void lemma(TNode n, bool safe = false) throw(Interrupted, AssertionException) {
-    push(LEMMA, n);
-  }
-  void augmentingLemma(TNode n, bool safe = false) throw(Interrupted, AssertionException){
-    Unreachable();
-  }
-  void explanation(TNode n, bool safe = false)  throw(Interrupted, AssertionException) {
-    push(EXPLANATION, n);
-  }
-
-  void clear() {
-    d_callHistory.clear();
-  }
-
-  Node getIthNode(int i) {
-    Node tmp = (d_callHistory[i]).second;
-    return tmp;
-  }
-
-  OutputChannelCallType getIthCallType(int i) {
-    return (d_callHistory[i]).first;
-  }
-
-  unsigned getNumCalls() {
-    return d_callHistory.size();
-  }
-};
 
 class TheoryUFWhite : public CxxTest::TestSuite {