Fix for slow array rewrite and minor bug fix in arrays that popped up as a result
authorClark Barrett <barrett@cs.nyu.edu>
Mon, 18 Jun 2012 19:59:56 +0000 (19:59 +0000)
committerClark Barrett <barrett@cs.nyu.edu>
Mon, 18 Jun 2012 19:59:56 +0000 (19:59 +0000)
src/theory/arrays/theory_arrays.cpp
src/theory/arrays/theory_arrays_rewriter.h

index da82e4bc3600a5f02428fd771bfea23e0486e153..460289439ad876a974314fb0f02a666a9ac982be 100644 (file)
@@ -770,7 +770,9 @@ Node TheoryArrays::mkAnd(std::vector<TNode>& conjunctions)
     all.insert(t);
   }
 
-  Assert(all.size() > 0);
+  if (all.size() == 0) {
+    return d_true;
+  }
   if (all.size() == 1) {
     // All the same, or just one
     return *(all.begin());
index c6ef5cd2541f45ea75e94443995a254a0891f726..d59ef736dff640d567cd032d569bfb6e9f068428 100644 (file)
@@ -39,17 +39,49 @@ public:
         TNode store = node[0];
         if (store.getKind() == kind::STORE) {
           // select(store(a,i,v),j)
-          Node eqRewritten = Rewriter::rewrite(store[1].eqNode(node[1]));
-          if (eqRewritten.getKind() == kind::CONST_BOOLEAN) {
-            bool value = eqRewritten.getConst<bool>();
-            if (value) {
+          TNode index = node[1];
+          bool val;
+          if (index == store[1]) {
+            val = true;
+          }
+          else if (index.isConst() && store[1].isConst()) {
+            val = false;
+          }
+          else {
+            Node eqRewritten = Rewriter::rewrite(store[1].eqNode(index));
+            if (eqRewritten.getKind() != kind::CONST_BOOLEAN) {
+              break;
+            }
+            val = eqRewritten.getConst<bool>();
+            if (val) {
               // select(store(a,i,v),i) = v
               return RewriteResponse(REWRITE_DONE, store[2]);
             }
             else {
               // select(store(a,i,v),j) = select(a,j) if i /= j
-              Node newNode = NodeManager::currentNM()->mkNode(kind::SELECT, store[0], node[1]);
-              return RewriteResponse(REWRITE_AGAIN_FULL, newNode);
+              store = store[0];
+              Node n;
+              while (store.getKind() == kind::STORE) {
+                if (index == store[1]) {
+                  val = true;
+                }
+                else if (index.isConst() && store[1].isConst()) {
+                  val = false;
+                }
+                else {
+                  n = Rewriter::rewrite(store[1].eqNode(index));
+                  if (n.getKind() != kind::CONST_BOOLEAN) {
+                    break;
+                  }
+                  val = n.getConst<bool>();
+                }
+                if (val) {
+                  return RewriteResponse(REWRITE_DONE, store[2]);
+                }
+                store = store[0];
+              }
+              n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
+              return RewriteResponse(REWRITE_DONE, n);
             }
           }
         }
@@ -66,22 +98,67 @@ public:
         }
         if (store.getKind() == kind::STORE) {
           // store(store(a,i,v),j,w)
-          Node eqRewritten = Rewriter::rewrite(store[1].eqNode(node[1]));
-          if (eqRewritten.getKind() == kind::CONST_BOOLEAN) {
-            bool val = eqRewritten.getConst<bool>();
-            NodeManager* nm = NodeManager::currentNM();
-            if (val) {
-              // store(store(a,i,v),i,w) = store(a,i,w)
-              Node newNode = nm->mkNode(kind::STORE, store[0], store[1], value);
-              return RewriteResponse(REWRITE_AGAIN_FULL, newNode);
+          TNode index = node[1];
+          bool val;
+          if (index == store[1]) {
+            val = true;
+          }
+          else if (index.isConst() && store[1].isConst()) {
+            val = false;
+          }
+          else {
+            Node eqRewritten = Rewriter::rewrite(store[1].eqNode(index));
+            if (eqRewritten.getKind() != kind::CONST_BOOLEAN) {
+              return RewriteResponse(REWRITE_DONE, node);
             }
-            else if (node[1] < store[1]) {
-              // store(store(a,i,v),j,w) = store(store(a,j,w),i,v)
-              //    IF i != j and j comes before i in the ordering
-              Node newNode = nm->mkNode(kind::STORE, store[0], node[1], value);
-              newNode = nm->mkNode(kind::STORE, newNode, store[1], store[2]);
-              return RewriteResponse(REWRITE_AGAIN_FULL, newNode);
+            val = eqRewritten.getConst<bool>();
+          }
+          NodeManager* nm = NodeManager::currentNM();
+          if (val) {
+            // store(store(a,i,v),i,w) = store(a,i,w)
+            return RewriteResponse(REWRITE_DONE, nm->mkNode(kind::STORE, store[0], index, value));
+          }
+          else if (index < store[1]) {
+            // store(store(a,i,v),j,w) = store(store(a,j,w),i,v)
+            //    IF i != j and j comes before i in the ordering
+            std::vector<TNode> indices;
+            std::vector<TNode> elements;
+            indices.push_back(store[1]);
+            elements.push_back(store[2]);
+            store = store[0];
+            Node n;
+            while (store.getKind() == kind::STORE) {
+              if (index == store[1]) {
+                val = true;
+              }
+              else if (index.isConst() && store[1].isConst()) {
+                val = false;
+              }
+              else {
+                n = Rewriter::rewrite(store[1].eqNode(index));
+                if (n.getKind() != kind::CONST_BOOLEAN) {
+                  break;
+                }
+                val = n.getConst<bool>();
+              }
+              if (val) {
+                store = store[0];
+                break;
+              }
+              else if (!(index < store[1])) {
+                break;
+              }
+              indices.push_back(store[1]);
+              elements.push_back(store[2]);
+              store = store[0];
             }
+            n = nm->mkNode(kind::STORE, store, index, value);
+            while (!indices.empty()) {
+              n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
+              indices.pop_back();
+              elements.pop_back();
+            }
+            return RewriteResponse(REWRITE_DONE, n);
           }
         }
         break;
@@ -105,7 +182,169 @@ public:
   }
 
   static inline RewriteResponse preRewrite(TNode node) {
-    Trace("arrays-prerewrite") << "Arrays::preRewrite " << node << std::endl;
+    Trace("arrays-prerewrite") << "Arrays::preRewrite start " << node << std::endl;
+    switch (node.getKind()) {
+      case kind::SELECT: {
+        TNode store = node[0];
+        if (store.getKind() == kind::STORE) {
+          // select(store(a,i,v),j)
+          TNode index = node[1];
+          bool val;
+          if (index == store[1]) {
+            val = true;
+          }
+          else if (index.isConst() && store[1].isConst()) {
+            val = false;
+          }
+          else {
+            Node eqRewritten = Rewriter::rewrite(store[1].eqNode(index));
+            if (eqRewritten.getKind() != kind::CONST_BOOLEAN) {
+              break;
+            }
+            val = eqRewritten.getConst<bool>();
+            if (val) {
+              // select(store(a,i,v),i) = v
+              return RewriteResponse(REWRITE_DONE, store[2]);
+            }
+            else {
+              // select(store(a,i,v),j) = select(a,j) if i /= j
+              store = store[0];
+              Node n;
+              while (store.getKind() == kind::STORE) {
+                if (index == store[1]) {
+                  val = true;
+                }
+                else if (index.isConst() && store[1].isConst()) {
+                  val = false;
+                }
+                else {
+                  n = Rewriter::rewrite(store[1].eqNode(index));
+                  if (n.getKind() != kind::CONST_BOOLEAN) {
+                    break;
+                  }
+                  val = n.getConst<bool>();
+                }
+                if (val) {
+                  return RewriteResponse(REWRITE_DONE, store[2]);
+                }
+                store = store[0];
+              }
+              n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
+              return RewriteResponse(REWRITE_DONE, n);
+            }
+          }
+        }
+        break;
+      }
+      case kind::STORE: {
+        TNode store = node[0];
+        TNode value = node[2];
+        // store(a,i,select(a,i)) = a
+        if (value.getKind() == kind::SELECT &&
+            value[0] == store &&
+            value[1] == node[1]) {
+          return RewriteResponse(REWRITE_DONE, store);
+        }
+        if (store.getKind() == kind::STORE) {
+          // store(store(a,i,v),j,w)
+          TNode index = node[1];
+          bool val;
+          if (index == store[1]) {
+            val = true;
+          }
+          else if (index.isConst() && store[1].isConst()) {
+            val = false;
+          }
+          else {
+            Node eqRewritten = Rewriter::rewrite(store[1].eqNode(index));
+            if (eqRewritten.getKind() != kind::CONST_BOOLEAN) {
+              return RewriteResponse(REWRITE_DONE, node);
+            }
+            val = eqRewritten.getConst<bool>();
+          }
+          NodeManager* nm = NodeManager::currentNM();
+          if (val) {
+            // store(store(a,i,v),i,w) = store(a,i,w)
+            return RewriteResponse(REWRITE_DONE, nm->mkNode(kind::STORE, store[0], index, value));
+          }
+          else if (index.isConst() && store[1].isConst()) {
+            std::map<TNode, TNode> elements;
+            elements[index] = value;
+            elements[store[1]] = store[2];
+            store = store[0];
+            Node n;
+            while (store.getKind() == kind::STORE) {
+              if (!store[1].isConst()) {
+                break;
+              }
+              if (elements.find(store[1]) != elements.end()) {
+                elements[store[1]] = store[2];
+              }
+              store = store[0];
+            }
+            std::map<TNode, TNode>::iterator it = elements.begin();
+            std::map<TNode, TNode>::iterator iend = elements.end();
+            for (; it != iend; ++it) {
+              n = nm->mkNode(kind::STORE, store, (*it).first, (*it).second);
+            }
+            return RewriteResponse(REWRITE_DONE, n);
+          }
+          else if (index < store[1]) {
+            // store(store(a,i,v),j,w) = store(store(a,j,w),i,v)
+            //    IF i != j and j comes before i in the ordering
+            std::vector<TNode> indices;
+            std::vector<TNode> elements;
+            indices.push_back(store[1]);
+            elements.push_back(store[2]);
+            store = store[0];
+            Node n;
+            while (store.getKind() == kind::STORE) {
+              if (index == store[1]) {
+                val = true;
+              }
+              else if (index.isConst() && store[1].isConst()) {
+                val = false;
+              }
+              else {
+                n = Rewriter::rewrite(store[1].eqNode(index));
+                if (n.getKind() != kind::CONST_BOOLEAN) {
+                  break;
+                }
+                val = n.getConst<bool>();
+              }
+              if (val) {
+                store = store[0];
+                break;
+              }
+              else if (!(index < store[1])) {
+                break;
+              }
+              indices.push_back(store[1]);
+              elements.push_back(store[2]);
+              store = store[0];
+            }
+            n = nm->mkNode(kind::STORE, store, index, value);
+            while (!indices.empty()) {
+              n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
+              indices.pop_back();
+              elements.pop_back();
+            }
+            return RewriteResponse(REWRITE_DONE, n);
+          }
+        }
+        break;
+      }
+      case kind::EQUAL:
+      case kind::IFF: {
+        if(node[0] == node[1]) {
+          return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
+        }
+        break;
+      }
+      default:
+        break;
+    }
+
     return RewriteResponse(REWRITE_DONE, node);
   }