More progress on array constants.
authorClark Barrett <barrett@cs.nyu.edu>
Wed, 22 Aug 2012 18:04:10 +0000 (18:04 +0000)
committerClark Barrett <barrett@cs.nyu.edu>
Wed, 22 Aug 2012 18:04:10 +0000 (18:04 +0000)
Here's a fun way to give yourself a week-long headache: try to figure out how
to write efficient code to normalize array constants.

It's mostly there now - just need to figure out how to use type enumerators and
update once the new cardinality stuff is in place.

src/theory/arrays/theory_arrays_rewriter.h
src/theory/arrays/theory_arrays_type_rules.h

index ba2baab2c09a7abbf2c5262ce1b5eb939a281c72..50a5ee2d06efdcf259afecbb738b0bc52c371ff7 100644 (file)
@@ -30,6 +30,169 @@ namespace arrays {
 
 class TheoryArraysRewriter {
 
+  static Node normalizeConstant(TNode node) {
+    TNode store = node[0];
+    TNode index = node[1];
+    TNode value = node[2];
+
+    std::vector<TNode> indices;
+    std::vector<TNode> elements;
+
+    // Normal form for nested stores is just ordering by index - but also need
+    // to check if we are writing to default value
+
+    // Go through nested stores looking for where to insert index
+    // Also check whether we are replacing an existing store
+    TNode replacedValue;
+    Integer depth = 1;
+    Integer valCount = 1;
+    while (store.getKind() == kind::STORE) {
+      if (index == store[1]) {
+        replacedValue = store[2];
+        store = store[0];
+        break;
+      }
+      else if (!(index < store[1])) {
+        break;
+      }
+      if (value == store[2]) {
+        valCount += 1;
+      }
+      depth += 1;
+      indices.push_back(store[1]);
+      elements.push_back(store[2]);
+      store = store[0];
+    }
+    Node n = store;
+
+    // Get the default value at the bottom of the nested stores
+    while (store.getKind() == kind::STORE) {
+      if (value == store[2]) {
+        valCount += 1;
+      }
+      depth += 1;
+      store = store[0];
+    }
+    Assert(store.getKind() == kind::STORE_ALL);
+    ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
+    Node defaultValue = Node::fromExpr(storeAll.getExpr());
+    NodeManager* nm = NodeManager::currentNM();
+
+    // Check if we are writing to default value - if so the store
+    // to index can be ignored
+    if (value == defaultValue) {
+      if (replacedValue.isNull()) {
+        // Quick exit - if writing to default value and nothing was
+        // replaced, we can just return node[0]
+        return node[0];
+      }
+    }
+    else {
+      n = nm->mkNode(kind::STORE, n, index, value);
+    }
+
+    // Build the rest of the store after inserting/deleting
+    while (!indices.empty()) {
+      n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
+      indices.pop_back();
+      elements.pop_back();
+    }
+
+    Cardinality indexCard = index.getType().getCardinality();
+    if (indexCard.isInfinite()) {
+      return n;
+    }
+
+    // When index sort is finite, we have to check whether there is any value
+    // that is written to more than the default value.  If so, it must become
+    // the new default value
+
+    TNode mostFrequentValue;
+    Integer mostFrequentValueCount = 0;
+    bool recompute = false;
+    if (node[0].getKind() == kind::STORE) {
+      // TODO: look up most frequent value and count
+      if (!replacedValue.isNull() && mostFrequentValue == replacedValue) {
+        recompute = true;
+      }
+    }
+
+    // Compute the most frequently written value for n
+    if (valCount > mostFrequentValueCount ||
+        (valCount == mostFrequentValueCount && value < mostFrequentValue)) {
+      mostFrequentValue = value;
+      mostFrequentValueCount = valCount;
+      recompute = false;
+    }
+
+    // Need to make sure the default value count is larger, or the same and the default value is expression-order-less-than nextValue
+    int compare;// = indexCard.compare(mostFrequentValueCount + depth);
+    // Assert result of compare is not unknown
+    if (compare > 0 ||
+        (compare == 0 && (defaultValue < mostFrequentValue))) {
+      return n;
+    }
+
+    // Bad case: have to recompute value counts and/or possibly switch out
+    // default value
+    store = n;
+    std::hash_set<TNode, TNodeHashFunction> indexSet;
+    std::hash_map<TNode, unsigned, TNodeHashFunction> elementsMap;
+    std::hash_map<TNode, unsigned, TNodeHashFunction>::iterator it;
+    unsigned count;
+    unsigned max = 0;
+    TNode maxValue;
+    while (store.getKind() == kind::STORE) {
+      indices.push_back(store[1]);
+      indexSet.insert(store[1]);
+      elements.push_back(store[2]);
+      it = elementsMap.find(store[2]);
+      if (it != elementsMap.end()) {
+        (*it).second = (*it).second + 1;
+        count = (*it).second;
+      }
+      else {
+        elementsMap[store[2]] = 1;
+        count = 1;
+      }
+      if (count > max ||
+          (count == max && store[2] < maxValue)) {
+        max = count;
+        maxValue = store[2];
+      }
+      store = store[0];
+    }
+
+    Assert(depth == indices.size());
+    //compare = indexCard.compare(max + depth);
+    // Assert result of compare is not unknown
+    if (compare > 0 ||
+        (compare == 0 && (defaultValue < maxValue))) {
+      Assert(recompute);
+      return n;
+    }
+
+    // Out of luck: have to swap out default value
+    std::vector<Node> newIndices;
+    // Enumerate values from index type into newIndices and sort
+
+    //n = storeAll(maxValue)
+    while (!newIndices.empty() || !indices.empty()) {
+      if (!newIndices.empty() && (indices.empty() || newIndices.back() < indices.back())) {
+        n = nm->mkNode(kind::STORE, n, newIndices.back(), defaultValue);
+        newIndices.pop_back();
+      }
+      else if (newIndices.empty() || indices.back() < newIndices.back()) {
+        if (elements.back() != maxValue) {
+          n = nm->mkNode(kind::STORE, n, indices.back(), elements.back());
+        }
+        indices.pop_back();
+        elements.pop_back();
+      }
+    }
+    return n;
+  }
+
 public:
 
   static RewriteResponse postRewrite(TNode node) {
@@ -37,21 +200,43 @@ public:
     switch (node.getKind()) {
       case kind::SELECT: {
         TNode store = node[0];
-        if (store.getKind() == kind::STORE) {
-          // select(store(a,i,v),j)
-          Node eqRewritten = Rewriter::rewrite(store[1].eqNode(node[1]));
-          if (eqRewritten.getKind() == kind::CONST_BOOLEAN) {
-            bool value = eqRewritten.getConst<bool>();
-            if (value) {
-              // select(store(a,i,v),i) = v
-              return RewriteResponse(REWRITE_DONE, store[2]);
-            }
-            else {
-              // select(store(a,i,v),j) = select(a,j) if i /= j
-              Node newNode = NodeManager::currentNM()->mkNode(kind::SELECT, store[0], node[1]);
-              return RewriteResponse(REWRITE_AGAIN_FULL, newNode);
+        TNode index = node[1];
+        Node n;
+        bool val;
+        while (store.getKind() == kind::STORE) {
+          if (index == store[1]) {
+            val = true;
+          }
+          else if (index.isConst() && store[1].isConst()) {
+            val = false;
+          }
+          else {
+            n = Rewriter::rewrite(store[1].eqNode(index));
+            if (n.getKind() != kind::CONST_BOOLEAN) {
+              break;
             }
+            val = n.getConst<bool>();
+          }
+          if (val) {
+            // select(store(a,i,v),j) = v if i = j
+            Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << store[2] << std::endl;
+            return RewriteResponse(REWRITE_DONE, store[2]);
           }
+          // select(store(a,i,v),j) = select(a,j) if i /= j
+          store = store[0];
+        }
+        if (store.getKind() == kind::STORE_ALL) {
+          // select(store_all(v),i) = v
+          ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
+          n = Node::fromExpr(storeAll.getExpr());
+          Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
+          Assert(n.isConst());
+          return RewriteResponse(REWRITE_DONE, n);
+        }
+        else if (store != node[0]) {
+          n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
+          Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
+          return RewriteResponse(REWRITE_DONE, n);
         }
         break;
       }
@@ -65,9 +250,15 @@ public:
           Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << store << std::endl;
           return RewriteResponse(REWRITE_DONE, store);
         }
+        TNode index = node[1];
+        if (store.isConst() && index.isConst() && value.isConst()) {
+          // normalize constant
+          Node n = normalizeConstant(node);
+          Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl;
+          return RewriteResponse(REWRITE_DONE, n);
+        }
         if (store.getKind() == kind::STORE) {
           // store(store(a,i,v),j,w)
-          TNode index = node[1];
           bool val;
           if (index == store[1]) {
             val = true;
@@ -161,10 +352,10 @@ public:
     switch (node.getKind()) {
       case kind::SELECT: {
         TNode store = node[0];
-        if (store.getKind() == kind::STORE) {
-          // select(store(a,i,v),j)
-          TNode index = node[1];
-          bool val;
+        TNode index = node[1];
+        Node n;
+        bool val;
+        while (store.getKind() == kind::STORE) {
           if (index == store[1]) {
             val = true;
           }
@@ -172,45 +363,32 @@ public:
             val = false;
           }
           else {
-            Node eqRewritten = Rewriter::rewrite(store[1].eqNode(index));
-            if (eqRewritten.getKind() != kind::CONST_BOOLEAN) {
+            n = Rewriter::rewrite(store[1].eqNode(index));
+            if (n.getKind() != kind::CONST_BOOLEAN) {
               break;
             }
-            val = eqRewritten.getConst<bool>();
+            val = n.getConst<bool>();
           }
           if (val) {
-            // select(store(a,i,v),i) = v
+            // select(store(a,i,v),j) = v if i = j
             Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store[2] << std::endl;
-            return RewriteResponse(REWRITE_DONE, store[2]);
-          }
-          else {
-            // select(store(a,i,v),j) = select(a,j) if i /= j
-            store = store[0];
-            Node n;
-            while (store.getKind() == kind::STORE) {
-              if (index == store[1]) {
-                val = true;
-              }
-              else if (index.isConst() && store[1].isConst()) {
-                val = false;
-              }
-              else {
-                n = Rewriter::rewrite(store[1].eqNode(index));
-                if (n.getKind() != kind::CONST_BOOLEAN) {
-                  break;
-                }
-                val = n.getConst<bool>();
-              }
-              if (val) {
-                Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store[2] << std::endl;
-                return RewriteResponse(REWRITE_DONE, store[2]);
-              }
-              store = store[0];
-            }
-            n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
-            Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl;
-            return RewriteResponse(REWRITE_DONE, n);
+            return RewriteResponse(REWRITE_AGAIN, store[2]);
           }
+          // select(store(a,i,v),j) = select(a,j) if i /= j
+          store = store[0];
+        }
+        if (store.getKind() == kind::STORE_ALL) {
+          // select(store_all(v),i) = v
+          ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
+          n = Node::fromExpr(storeAll.getExpr());
+          Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl;
+          Assert(n.isConst());
+          return RewriteResponse(REWRITE_DONE, n);
+        }
+        else if (store != node[0]) {
+          n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index);
+          Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl;
+          return RewriteResponse(REWRITE_DONE, n);
         }
         break;
       }
@@ -222,7 +400,7 @@ public:
             value[0] == store &&
             value[1] == node[1]) {
           Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store << std::endl;
-          return RewriteResponse(REWRITE_DONE, store);
+          return RewriteResponse(REWRITE_AGAIN, store);
         }
         if (store.getKind() == kind::STORE) {
           // store(store(a,i,v),j,w)
@@ -237,8 +415,7 @@ public:
           else {
             Node eqRewritten = Rewriter::rewrite(store[1].eqNode(index));
             if (eqRewritten.getKind() != kind::CONST_BOOLEAN) {
-              Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << node << std::endl;
-              return RewriteResponse(REWRITE_DONE, node);
+              break;
             }
             val = eqRewritten.getConst<bool>();
           }
@@ -255,7 +432,7 @@ public:
       case kind::EQUAL:
       case kind::IFF: {
         if(node[0] == node[1]) {
-          Trace("arrays-prerewrite") << "Arrays::preRewrite returning  true" << std::endl;
+          Trace("arrays-prerewrite") << "Arrays::preRewrite returning true" << std::endl;
           return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
         }
         break;
index 9c678d6cedfe601a114b94b1bbd5eaaf22aae2d8..f6ce84d0f257b25b067d446d84f59c725a6b7fa3 100644 (file)
@@ -81,71 +81,63 @@ struct ArrayStoreTypeRule {
       return false;
     }
 
-    // If store indices are not in order, not in normal form
-    if (store.getKind() == kind::STORE && index < store[1]) {
+    // Normal form for nested stores is just ordering by index but also need to check that we are not writing
+    // to default value
+    if (store.getKind() == kind::STORE && (!(store[1] < index))) {
       return false;
     }
 
-    // Compute the number of nested stores
     Integer depth = 1;
+    Integer valCount = 1;
     while (store.getKind() == kind::STORE) {
-      store = store[0];
       depth += 1;
+      if (store[2] == value) {
+        valCount += 1;
+      }
+      store = store[0];
     }
-
-    // Get the default value in the STORE_ALL object at the bottom of the nested stores
     Assert(store.getKind() == kind::STORE_ALL);
     ArrayStoreAll storeAll = store.getConst<ArrayStoreAll>();
-    TNode defaultValue /* = storeAll.getExpr().getTNode()*/ ;
-
-    // If writing to default value, not in normal form
-    if (defaultValue == value) {
+    Node defaultValue = Node::fromExpr(storeAll.getExpr());
+    if (value == defaultValue) {
       return false;
     }
-    
+
     // Get the cardinality of the index type
     Cardinality indexCard = index.getType().getCardinality();
 
-    // If cardinality is infinite, ok - in normal form
     if (indexCard.isInfinite()) {
       return true;
     }
-    
-    /*
-    Assert(depth <= indexCard);
-
-    // If number of stores is equal to cardinality of index type,
-    // then the default value is overridden at all indices.  Our normal form
-    // requires that the most frequent value is the default value.
-    if (depth == indexCard) {
-       return false;
-    }
 
-    // If the number of stores is less than half of the cardinality, then we
-    // know the default value is the most frequent value, so in normal form
-    if (depth*2 < indexCard) {
-      return true;
-    }
-    Integer defaultCount = indexCard - depth;
+    // When index sort is finite, we have to check whether there is any value
+    // that is written to more than the default value.  If so, it is not in
+    // normal form.
 
-    // Have to compare number of occurrences of value with defaultValue
+    // Get the most frequently written value from n[0]
+    TNode mostFrequentValue;
+    Integer mostFrequentValueCount = 0;
     store = n[0];
-    depth = 1;
-    while (store.getKind() == kind::STORE) {
-      if (store[2] == value) {
-        depth += 1;
-      }
-      store = store[0];
+    if (store.getKind() == kind::STORE) {
+      // TODO: look up most frequent value and count
+    }
+
+    // Compute the most frequently written value for n
+    if (valCount > mostFrequentValueCount ||
+        (valCount == mostFrequentValueCount && value < mostFrequentValue)) {
+      mostFrequentValue = value;
+      mostFrequentValueCount = valCount;
     }
 
-    // If value occurs more frequently than the default value or the same
-    // and is less than defaultValue, then this is not in normal form
-    if (depth > defaultCount ||
-        (depth == defaultCount && value < defaultValue)) {
+    // Need to make sure the default value count is larger, or the same and the default value is expression-order-less-than nextValue
+    int compare;// = indexCard.compare(mostFrequentValueCount + depth);
+    // Assert result of compare is not unknown
+    if (compare < 0 ||
+        (compare == 0 && (!(defaultValue < mostFrequentValue)))) {
       return false;
     }
-    */
 
+    // TODO: store mostFrequentValue and mostFrequentValueCount for this node
     return true;
   }