Clark's work on array theory - can now solve all QF_AX problems
authorClark Barrett <barrett@cs.nyu.edu>
Mon, 11 Jul 2011 19:53:44 +0000 (19:53 +0000)
committerClark Barrett <barrett@cs.nyu.edu>
Mon, 11 Jul 2011 19:53:44 +0000 (19:53 +0000)
src/prop/cnf_stream.h
src/prop/prop_engine.cpp
src/prop/prop_engine.h
src/theory/arrays/Makefile.am
src/theory/arrays/theory_arrays.cpp
src/theory/arrays/theory_arrays.h
src/theory/arrays/theory_arrays_rewriter.h
src/theory/booleans/theory_bool_rewriter.cpp
src/theory/valuation.cpp
src/theory/valuation.h
src/util/ntuple.h

index ef75e635bd53b6c6755a7433f001d31642ce4250..e53b46d9be7f52a49799896501eda23caa8adde0 100644 (file)
@@ -146,14 +146,6 @@ protected:
    */
   bool isTranslated(TNode node) const;
 
-  /**
-   * Returns true if the node has an assigned literal (it might not be translated).
-   * Caches the pair of the node and the literal corresponding to the
-   * translation.
-   * @param node the node
-   */
-  bool hasLiteral(TNode node) const;
-
   /**
    * Acquires a new variable from the SAT solver to represent the node
    * and inserts the necessary data it into the mapping tables.
@@ -207,6 +199,14 @@ public:
    */
   TNode getNode(const SatLiteral& literal);
 
+  /**
+   * Returns true if the node has an assigned literal (it might not be translated).
+   * Caches the pair of the node and the literal corresponding to the
+   * translation.
+   * @param node the node
+   */
+  bool hasLiteral(TNode node) const;
+
   /**
    * Returns the literal that represents the given node in the SAT CNF
    * representation.
index 4c9b66020fa487ad73524c5b86f95319b1127a1f..3aa014782ebedd93cbebed7b73c289c821ce18e5 100644 (file)
@@ -170,6 +170,10 @@ Node PropEngine::getValue(TNode node) {
   }
 }
 
+bool PropEngine::isSatLiteral(TNode node) {
+  return d_cnfStream->hasLiteral(node);
+}
+
 bool PropEngine::hasValue(TNode node, bool& value) {
   Assert(node.getType().isBoolean());
   SatLiteral lit = d_cnfStream->getLiteral(node);
index f44ad16f709f0d34756fd9b4b9d2ed0c921e922a..f6e66bef1e27fe7a366cff08b40626b858e5cfe1 100644 (file)
@@ -114,6 +114,10 @@ public:
    */
   Node getValue(TNode node);
 
+  /*
+   * Return true if node has an associated SAT literal
+   */
+  bool isSatLiteral(TNode node);
   /**
    * Check if the node has a value and return it if yes.
    */
index 1e070cdaf1c764988800c1718fd0a923436f095c..3dde70145291f4efbf5df12d3e0844398c1c8d15 100644 (file)
@@ -13,6 +13,8 @@ libarrays_la_SOURCES = \
        union_find.h \
        union_find.cpp \
        array_info.h \
-       array_info.cpp
+       array_info.cpp \
+       static_fact_manager.h \
+       static_fact_manager.cpp
 
 EXTRA_DIST = kinds
index 37c49b341b23b5575c4315831f0713deaf53471b..dab78c17a547459177a34c81f8fe7d99f34bcf2b 100644 (file)
@@ -21,6 +21,7 @@
 #include "theory/valuation.h"
 #include "expr/kind.h"
 #include <map>
+#include "theory/rewriter.h"
 
 using namespace std;
 using namespace CVC4;
@@ -184,6 +185,208 @@ Node TheoryArrays::getValue(TNode n) {
   }
 }
 
+Theory::SolveStatus TheoryArrays::solve(TNode in, SubstitutionMap& outSubstitutions) {
+  switch(in.getKind()) {
+    case kind::EQUAL:
+    {
+      d_staticFactManager.addEq(in);
+      if (in[0].getMetaKind() == kind::metakind::VARIABLE && !in[1].hasSubterm(in[0])) {
+        outSubstitutions.addSubstitution(in[0], in[1]);
+        return SOLVE_STATUS_SOLVED;
+      }
+      if (in[1].getMetaKind() == kind::metakind::VARIABLE && !in[0].hasSubterm(in[1])) {
+        outSubstitutions.addSubstitution(in[1], in[0]);
+        return SOLVE_STATUS_SOLVED;
+      }
+      break;
+    }
+    case kind::NOT:
+    {
+      Assert(in[0].getKind() == kind::EQUAL ||
+             in[0].getKind() == kind::IFF );
+      Node a = in[0][0];
+      Node b = in[0][1];
+      d_staticFactManager.addDiseq(in[0]);
+      break;
+    }
+    default:
+      break;
+  }
+  return SOLVE_STATUS_UNSOLVED;
+}
+
+Node TheoryArrays::preprocessTerm(TNode term) {
+  switch (term.getKind()) {
+    case kind::SELECT: {
+      // select(store(a,i,v),j) = select(a,j)
+      //    IF i != j
+      if (term[0].getKind() == kind::STORE &&
+          d_staticFactManager.areDiseq(term[0][1], term[1])) {
+        return NodeBuilder<2>(kind::SELECT) << term[0][0] << term[1];
+      }
+      break;
+    }
+    case kind::STORE: {
+      // store(store(a,i,v),j,w) = store(store(a,j,w),i,v)
+      //    IF i != j and j comes before i in the ordering
+      if (term[0].getKind() == kind::STORE &&
+          (term[1] < term[0][1]) &&
+          d_staticFactManager.areDiseq(term[1], term[0][1])) {
+        Node inner = NodeBuilder<3>(kind::STORE) << term[0][0] << term[1] << term[2];
+        Node outer = NodeBuilder<3>(kind::STORE) << inner << term[0][1] << term[0][2];
+        return outer;
+      }
+      break;
+    }
+    case kind::EQUAL: {
+      if (term[0].getKind() == kind::STORE ||
+          term[1].getKind() == kind::STORE) {
+        TNode left = term[0];
+        TNode right = term[1];
+        int leftWrites = 0, rightWrites = 0;
+
+        // Count nested writes
+        TNode e1 = left;
+        while (e1.getKind() == kind::STORE) {
+          ++leftWrites;
+          e1 = e1[0];
+        }
+
+        TNode e2 = right;
+        while (e2.getKind() == kind::STORE) {
+          ++rightWrites;
+          e2 = e2[0];
+        }
+
+        if (rightWrites > leftWrites) {
+          TNode tmp = left;
+          left = right;
+          right = tmp;
+          int tmpWrites = leftWrites;
+          leftWrites = rightWrites;
+          rightWrites = tmpWrites;
+        }
+
+        NodeManager* nm = NodeManager::currentNM();
+        if (rightWrites == 0) {
+          if (e1 == e2) {
+            // write(store, index_0, v_0, index_1, v_1, ..., index_n, v_n) = store IFF
+            //
+            // read(store, index_n) = v_n &
+            // index_{n-1} != index_n -> read(store, index_{n-1}) = v_{n-1} &
+            // (index_{n-2} != index_{n-1} & index_{n-2} != index_n) -> read(store, index_{n-2}) = v_{n-2} &
+            // ...
+            // (index_1 != index_2 & ... & index_1 != index_n) -> read(store, index_1) = v_1
+            // (index_0 != index_1 & index_0 != index_2 & ... & index_0 != index_n) -> read(store, index_0) = v_0
+            TNode write_i, write_j, index_i, index_j;
+            Node conc;
+            NodeBuilder<> result(kind::AND);
+            int i, j;
+            write_i = left;
+            for (i = leftWrites-1; i >= 0; --i) {
+              index_i = write_i[1];
+
+              // build: [index_i /= index_n && index_i /= index_(n-1) &&
+              //         ... && index_i /= index_(i+1)] -> read(store, index_i) = v_i
+              write_j = left;
+              {
+                NodeBuilder<> hyp(kind::AND);
+                for (j = leftWrites - 1; j > i; --j) {
+                  index_j = write_j[1];
+                  if (d_staticFactManager.areDiseq(index_i, index_j)) {
+                    continue;
+                  }
+                  Node hyp2(index_i.getType() == nm->booleanType()? 
+                            index_i.iffNode(index_j) : index_i.eqNode(index_j));
+                  hyp << hyp2.notNode();
+                  write_j = write_j[0];
+                }
+
+                Node r1 = nm->mkNode(kind::SELECT, e1, index_i);
+                conc = (r1.getType() == nm->booleanType())? 
+                  r1.iffNode(write_i[2]) : r1.eqNode(write_i[2]);
+                if (hyp.getNumChildren() != 0) {
+                  if (hyp.getNumChildren() == 1) {
+                    conc = hyp.getChild(0).impNode(conc);
+                  }
+                  else {
+                    r1 = hyp;
+                    conc = r1.impNode(conc);
+                  }
+                }
+
+                // And into result
+                result << conc;
+
+                // Prepare for next iteration
+                write_i = write_i[0];
+              }
+            }
+            Assert(result.getNumChildren() > 0);
+            if (result.getNumChildren() == 1) {
+              return result.getChild(0);
+            }
+            return result;
+          }
+          break;
+        }
+        else {
+          // store(...) = store(a,i,v) ==>
+          // store(store(...),i,select(a,i)) = a && select(store(...),i)=v
+          Node l = left;
+          Node tmp;
+          NodeBuilder<> nb(kind::AND);
+          while (right.getKind() == STORE) {
+            tmp = nm->mkNode(kind::SELECT, l, right[1]);
+            nb << tmp.eqNode(right[2]);
+            tmp = nm->mkNode(kind::SELECT, right[0], right[1]);
+            l = nm->mkNode(kind::STORE, l, right[1], tmp);
+            right = right[0];
+          }
+          nb << l.eqNode(right);
+          return nb;
+        }
+      }
+      break;
+    }
+    default:
+      break;
+  }
+  return term;
+}
+
+Node TheoryArrays::recursivePreprocessTerm(TNode term) {
+  unsigned nc = term.getNumChildren();
+  if (nc == 0 ||
+      (theoryOf(term) != theory::THEORY_ARRAY &&
+       term.getType() != NodeManager::currentNM()->booleanType())) {
+    return term;
+  }
+  NodeMap::iterator find = d_ppCache.find(term);
+  if (find != d_ppCache.end()) {
+    return (*find).second;
+  }
+  NodeBuilder<> newNode(term.getKind());
+  unsigned i;
+  for (i = 0; i < nc; ++i) {
+    newNode << recursivePreprocessTerm(term[i]);
+  }
+  Node newTerm = Rewriter::rewrite(newNode);
+  Node newTerm2 = preprocessTerm(newTerm);
+  if (newTerm != newTerm2) {
+    newTerm = recursivePreprocessTerm(Rewriter::rewrite(newTerm2));
+  }
+  d_ppCache[term] = newTerm;
+  return newTerm;
+}
+
+Node TheoryArrays::preprocess(TNode atom) {
+  if (d_donePreregister) return atom;
+  Assert(atom.getKind() == kind::EQUAL);
+  return recursivePreprocessTerm(atom);
+}
+
+
 void TheoryArrays::merge(TNode a, TNode b) {
   Assert(d_conflict.isNull());
 
@@ -508,7 +711,48 @@ bool TheoryArrays::isRedundantInContext(TNode a, TNode b, TNode i, TNode j) {
     checkRowForIndex(j,b); // why am i doing this?
     checkRowForIndex(i,a);
     return true;
+  }
+  Node literal1 = Rewriter::rewrite(i.eqNode(j));
+  bool hasValue1, satValue1;
+  Node ff = nm->mkConst<bool>(false);
+  Node tt = nm->mkConst<bool>(true);
+  if (literal1 == ff) {
+    hasValue1 = true;
+    satValue1 = false;
+  }
+  else if (literal1 == tt) {
+    hasValue1 = true;
+    satValue1 = true;
+  }
+  else hasValue1 = (d_valuation.isSatLiteral(literal1) && d_valuation.hasSatValue(literal1, satValue1));
+  if (hasValue1) {
+    if (satValue1) return true;
+    Node literal2 = Rewriter::rewrite(aj.eqNode(bj));
+    bool hasValue2, satValue2;
+    if (literal2 == ff) {
+      hasValue2 = true;
+      satValue2 = false;
     }
+    else if (literal2 == tt) {
+      hasValue2 = true;
+      satValue2 = true;
+    }
+    else hasValue2 = (d_valuation.isSatLiteral(literal2) && d_valuation.hasSatValue(literal2, satValue2));
+    if (hasValue2) {
+      if (satValue2) return true;
+      // conflict
+      Assert(!satValue1 && !satValue2);
+      Assert(literal1.getKind() == kind::EQUAL && literal2.getKind() == kind::EQUAL);
+      NodeBuilder<2> nb(kind::AND);
+      literal1 = areDisequal(literal1[0],literal1[1]);
+      literal2 = areDisequal(literal2[0],literal2[1]);
+      Assert(!literal1.isNull() && !literal2.isNull());
+      nb << literal1.notNode() << literal2.notNode();
+      literal1 = nb;
+      d_out->conflict(literal1, false);
+      return true;
+    }
+  }
   if(alreadyAddedRow(a,b,i,j)) {
    // Debug("arrays-lem")<<"isRedundantInContext already added "<<a<<" "<<b<<" "<<i<<" "<<j<<"\n";
     return true;
index bc1f670ba087161ce207488fc7ae39e3aea1331a..f4cccfec549221d7f21621bdf08ee3cd1c2ffff1 100644 (file)
@@ -29,6 +29,7 @@
 #include "util/ntuple.h"
 #include "util/stats.h"
 #include "util/backtrackable.h"
+#include "theory/arrays/static_fact_manager.h"
 
 #include <iostream>
 #include <map>
@@ -113,6 +114,18 @@ private:
   CongruenceClosure<CongruenceChannel, CONGRUENCE_OPERATORS_1
                                  (kind::SELECT)> d_cc;
 
+  /**
+   * (Temporary) fact manager for preprocessing - eventually handle this with
+   * something more standard (like congruence closure module)
+   */
+  StaticFactManager d_staticFactManager;
+
+  /**
+   * Cache for proprocessing of atoms.
+   */
+  typedef std::hash_map<Node, Node, NodeHashFunction> NodeMap;
+  NodeMap d_ppCache;
+
   /**
    * Union find for storing the equalities.
    */
@@ -347,6 +360,8 @@ private:
 
   bool d_donePreregister;
 
+  Node preprocessTerm(TNode term);
+  Node recursivePreprocessTerm(TNode term);
 
 public:
   TheoryArrays(context::Context* c, OutputChannel& out, Valuation valuation);
@@ -464,6 +479,8 @@ public:
   void explain(TNode n);
 
   Node getValue(TNode n);
+  SolveStatus solve(TNode in, SubstitutionMap& outSubstitutions);
+  Node preprocess(TNode atom);
   void shutdown() { }
   std::string identify() const { return std::string("TheoryArrays"); }
 
index c37cbe68c34273189b57af47ca3023106e51356a..059b7ce8b22f301606a6aeb60773a4d93aa7ab26 100644 (file)
@@ -34,51 +34,51 @@ public:
 
   static RewriteResponse postRewrite(TNode node) {
     Debug("arrays-postrewrite") << "Arrays::postRewrite start " << node << std::endl;
-    if(node.getKind() == kind::EQUAL || node.getKind() == kind::IFF) {
-      if(node[0] == node[1]) {
-        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
-      }
-      // checks for RoW axiom: (select ( store a i v) i) = v and rewrites it
-      // to true
-      if(node[0].getKind()==kind::SELECT) {
-        TNode a = node[0][0];
-        TNode j = node[0][1];
-        if(a.getKind()==kind::STORE) {
-          TNode b = a[0];
-          TNode i = a[1];
-          TNode v = a[2];
-          if(v == node[1] && i == j) {
-            Debug("arrays-postrewrite") << "Arrays::postRewrite true" << std::endl;
-            return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
-          }
+    switch (node.getKind()) {
+      case kind::SELECT: {
+        // select(store(a,i,v),i) = v
+        TNode store = node[0];
+        if (store.getKind() == kind::STORE &&
+            store[1] == node[1]) {
+          return RewriteResponse(REWRITE_DONE, store[2]);
         }
+        break;
       }
-
-      if (node[0] > node[1]) {
-        Node newNode = NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]);
-        // If we've switched theories, we need to rewrite again (TODO: THIS IS HACK, once theories accept eq, change)
-        if (Theory::theoryOf(newNode[0]) != Theory::theoryOf(newNode[1])) {
+      case kind::STORE: {
+        TNode store = node[0];
+        TNode value = node[2];
+        // store(a,i,select(a,i)) = a
+        if (value.getKind() == kind::SELECT &&
+            value[0] == store &&
+            value[1] == node[1]) {
+          return RewriteResponse(REWRITE_DONE, store);
+        }
+        // store(store(a,i,v),i,w) = store(a,i,w)
+        if (store.getKind() == kind::STORE &&
+            store[1] == node[1]) {
+          Node newNode = NodeManager::currentNM()->mkNode(kind::STORE, store[0], store[1], value);
           return RewriteResponse(REWRITE_AGAIN_FULL, newNode);
-        } else {
-          return RewriteResponse(REWRITE_DONE, newNode);
         }
+        break;
       }
-    }
-    // FIXME: would it be better to move in preRewrite?
-    // if yes don't need the above case
-    if (node.getKind()==kind::SELECT) {
-      // we are rewriting (select (store a i v) i) to v
-      TNode a = node[0];
-      TNode i = node[1];
-      if(a.getKind() == kind::STORE) {
-        TNode b = a[0];
-        TNode j = a[1];
-        TNode v = a[2];
-        if(i==j) {
-          Debug("arrays-postrewrite") << "Arrays::postrewrite to " << v << std::endl;
-          return RewriteResponse(REWRITE_AGAIN_FULL, v);
+      case kind::EQUAL:
+      case kind::IFF: {
+        if(node[0] == node[1]) {
+          return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
+        }
+        if (node[0] > node[1]) {
+          Node newNode = NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]);
+          // If we've switched theories, we need to rewrite again (TODO: THIS IS HACK, once theories accept eq, change)
+          if (Theory::theoryOf(newNode[0]) != Theory::theoryOf(newNode[1])) {
+            return RewriteResponse(REWRITE_AGAIN_FULL, newNode);
+          } else {
+            return RewriteResponse(REWRITE_DONE, newNode);
+          }
         }
+        break;
       }
+      default:
+        break;
     }
 
     return RewriteResponse(REWRITE_DONE, node);
index 18aa71667e30da0678bd0121aa04144016f5726d..d2693268fc20c6c7a2a2ba15feed1f3682e70445 100644 (file)
@@ -42,14 +42,51 @@ RewriteResponse TheoryBoolRewriter::preRewrite(TNode n) {
       if (n[0] == ff) return RewriteResponse(REWRITE_AGAIN, n[1]);
       if (n[1] == ff) return RewriteResponse(REWRITE_AGAIN, n[0]);
     }
+    else {
+      bool done = true;
+      TNode::iterator i = n.begin(), iend = n.end();
+      for(; i != iend; ++i) {
+        if (*i == tt) return RewriteResponse(REWRITE_DONE, tt);
+        if (*i == ff) done = false;
+      }
+      if (!done) {
+        NodeBuilder<> nb(kind::OR);
+        for(i = n.begin(); i != iend; ++i) {
+          if (*i != ff) nb << *i;
+        }
+        if (nb.getNumChildren() == 0) return RewriteResponse(REWRITE_DONE, ff);
+        if (nb.getNumChildren() == 1) return RewriteResponse(REWRITE_AGAIN, nb.getChild(0));
+        return RewriteResponse(REWRITE_AGAIN, nb.constructNode());
+      }
+    }
     break;
   }
   case kind::AND: {
+    //TODO: Why REWRITE_AGAIN here?
     if (n.getNumChildren() == 2) {
       if (n[0] == ff || n[1] == ff) return RewriteResponse(REWRITE_DONE, ff);
       if (n[0] == tt) return RewriteResponse(REWRITE_AGAIN, n[1]);
       if (n[1] == tt) return RewriteResponse(REWRITE_AGAIN, n[0]);
     }
+    else {
+      bool done = true;
+      TNode::iterator i = n.begin(), iend = n.end();
+      for(; i != iend; ++i) {
+        if (*i == ff) return RewriteResponse(REWRITE_DONE, ff);
+        if (*i == tt) done = false;
+      }
+      if (!done) {
+        NodeBuilder<> nb(kind::AND);
+        for(i = n.begin(); i != iend; ++i) {
+          if (*i != tt) {
+            nb << *i;
+          }
+        }
+        if (nb.getNumChildren() == 0) return RewriteResponse(REWRITE_DONE, tt);
+        if (nb.getNumChildren() == 1) return RewriteResponse(REWRITE_AGAIN, nb.getChild(0));
+        return RewriteResponse(REWRITE_AGAIN, nb.constructNode());
+      }
+    }
     break;
   }
   case kind::IMPLIES: {
index 0aefd7f2157a11f58623ab723f0f96498a22cda2..5002c8a59b5fb8e71790fb18fea50b7dcca45729 100644 (file)
@@ -27,6 +27,10 @@ Node Valuation::getValue(TNode n) const {
   return d_engine->getValue(n);
 }
 
+bool Valuation::isSatLiteral(TNode n) const {
+  return d_engine->getPropEngine()->isSatLiteral(n);
+}
+
 bool Valuation::hasSatValue(TNode n, bool& value) const {
   return d_engine->getPropEngine()->hasValue(n, value);
 }
index ea6772ce8d000412901edf8d4be8d252c976e08b..58615f481ec2854ef7ef6b64619b1150d7ea0448 100644 (file)
@@ -41,6 +41,11 @@ public:
 
   Node getValue(TNode n) const;
 
+  /*
+   * Return true if n has an associated SAT literal
+   */
+  bool isSatLiteral(TNode n) const;
+
   /**
    * Get the current SAT assignment to the node n.
    *
index a3b0dfdf4d51fc6f248f8b6bd7407a7b6d0801e1..4c9a033a1e4095c840f0780e49a0956f809d285e 100644 (file)
@@ -45,12 +45,9 @@ public:
   T2 second;
   T3 third;
   T4 fourth;
-  quad(const T1& t1, const T2& t2, const T3& t3, const T4& t4) {
-    first = t1;
-    second = t2;
-    third = t3;
-    fourth = t4;
-  }
+  quad(const T1& t1, const T2& t2, const T3& t3, const T4& t4)
+    : first(t1), second(t2), third(t3), fourth(t4)
+  { }
 };/* class quad<> */
 
 template <class T1, class T2, class T3, class T4>