constant normal form and rewrite
authorKshitij Bansal <kshitij@cs.nyu.edu>
Fri, 14 Mar 2014 02:17:02 +0000 (22:17 -0400)
committerKshitij Bansal <kshitij@cs.nyu.edu>
Thu, 20 Mar 2014 21:18:58 +0000 (17:18 -0400)
src/theory/sets/kinds
src/theory/sets/theory_sets_rewriter.cpp

index e46f3a4f81be3b2aa9956ca0f037ea849f8e5718..a56601b98c7d22f64f4b8237ba6dd8e93a0f8673 100644 (file)
@@ -41,6 +41,7 @@ operator SUBSET 2 "subset"
 operator MEMBER 2 "set membership"
 
 operator SET_SINGLETON 1 "singleton set"
+operator FINITE_SET 1: "finite set"
 
 typerule UNION ::CVC4::theory::sets::SetUnionTypeRule
 typerule INTERSECTION ::CVC4::theory::sets::SetIntersectionTypeRule
index 82b79cbd656ed1554a0ab213b2d3956c984adbbb..87c6db8f2463cc8f938d857359234d30d1997dd5 100644 (file)
@@ -20,25 +20,44 @@ namespace CVC4 {
 namespace theory {
 namespace sets {
 
+typedef std::set<TNode> Elements;
+typedef std::hash_map<TNode, Elements, TNodeHashFunction> SettermElementsMap;
+
 bool checkConstantMembership(TNode elementTerm, TNode setTerm)
 {
-  switch(setTerm.getKind()) {
-  case kind::EMPTYSET:
+  // Assume from pre-rewrite constant sets look like the following:
+  // (union (setenum bla) (union (setenum bla) ... (union (setenum bla) (setenum bla) ) ... ))
+
+  if(setTerm.getKind() == kind::EMPTYSET) {
     return false;
-  case kind::SET_SINGLETON:
+  }
+
+  if(setTerm.getKind() == kind::SET_SINGLETON) {
     return elementTerm == setTerm[0];
-  case kind::UNION:
-    return checkConstantMembership(elementTerm, setTerm[0]) ||
-      checkConstantMembership(elementTerm, setTerm[1]);
-  case kind::INTERSECTION:
-    return checkConstantMembership(elementTerm, setTerm[0]) &&
-      checkConstantMembership(elementTerm, setTerm[1]);
-  case kind::SETMINUS:
-    return checkConstantMembership(elementTerm, setTerm[0]) &&
-      !checkConstantMembership(elementTerm, setTerm[1]);
-  default:
-    Unhandled();
   }
+
+  Assert(setTerm.getKind() == kind::UNION && setTerm[1].getKind() == kind::SET_SINGLETON,
+         "kind was %d, term: %s", setTerm.getKind(), setTerm.toString().c_str());
+
+  return elementTerm == setTerm[1][0] || checkConstantMembership(elementTerm, setTerm[0]);
+
+  // switch(setTerm.getKind()) {
+  // case kind::EMPTYSET:
+  //   return false;
+  // case kind::SET_SINGLETON:
+  //   return elementTerm == setTerm[0];
+  // case kind::UNION:
+  //   return checkConstantMembership(elementTerm, setTerm[0]) ||
+  //     checkConstantMembership(elementTerm, setTerm[1]);
+  // case kind::INTERSECTION:
+  //   return checkConstantMembership(elementTerm, setTerm[0]) &&
+  //     checkConstantMembership(elementTerm, setTerm[1]);
+  // case kind::SETMINUS:
+  //   return checkConstantMembership(elementTerm, setTerm[0]) &&
+  //     !checkConstantMembership(elementTerm, setTerm[1]);
+  // default:
+  //   Unhandled();
+  // }
 }
 
 // static
@@ -53,7 +72,8 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
       break;
 
     // both are constants
-    bool isMember = checkConstantMembership(node[0], node[1]);
+    TNode S = preRewrite(node[1]).node;
+    bool isMember = checkConstantMembership(node[0], S);
     return RewriteResponse(REWRITE_DONE, nm->mkConst(isMember));
   }//kind::MEMBER
 
@@ -145,6 +165,74 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
   return RewriteResponse(REWRITE_DONE, node);
 }
 
+const Elements& collectConstantElements(TNode setterm, SettermElementsMap& settermElementsMap) {
+  SettermElementsMap::const_iterator it = settermElementsMap.find(setterm);
+  if(it == settermElementsMap.end() ) {
+
+    Kind k = setterm.getKind();
+    unsigned numChildren = setterm.getNumChildren();
+    Elements cur;
+    if(numChildren == 2) {
+      const Elements& left = collectConstantElements(setterm[0], settermElementsMap);
+      const Elements& right = collectConstantElements(setterm[1], settermElementsMap);
+      switch(k) {
+      case kind::UNION:
+        if(left.size() >= right.size()) {
+          cur = left; cur.insert(right.begin(), right.end());
+        } else {
+          cur = right; cur.insert(left.begin(), left.end());
+        }
+        break;
+      case kind::INTERSECTION:
+        std::set_intersection(left.begin(), left.end(), right.begin(), right.end(),
+                              std::inserter(cur, cur.begin()) );
+        break;
+      case kind::SETMINUS:
+        std::set_difference(left.begin(), left.end(), right.begin(), right.end(),
+                            std::inserter(cur, cur.begin()) );
+        break;
+      default:
+        Unhandled();
+      }
+    } else {
+      switch(k) {
+      case kind::EMPTYSET:
+        /* assign emptyset, which is default */
+        break;
+      case kind::SET_SINGLETON:
+        Assert(setterm[0].isConst());
+        cur.insert(setterm[0]);
+        break;
+      default:
+        Unhandled();
+      }
+    }
+
+    it = settermElementsMap.insert(SettermElementsMap::value_type(setterm, cur)).first;
+  }
+  return it->second;
+}
+
+Node elementsToNormalConstant(Elements elements,
+                              TypeNode setType)
+{
+  NodeManager* nm = NodeManager::currentNM();
+
+  if(elements.size() == 0) {
+    return nm->mkConst(EmptySet(nm->toType(setType)));
+  } else {
+
+    Elements::iterator it = elements.begin();
+    Node cur = nm->mkNode(kind::SET_SINGLETON, *it);
+    while( ++it != elements.end() ) {
+      cur = nm->mkNode(kind::UNION, cur,
+                       nm->mkNode(kind::SET_SINGLETON, *it));
+    }
+    return cur;
+  }
+}
+
+
 // static
 RewriteResponse TheorySetsRewriter::preRewrite(TNode node) {
   NodeManager* nm = NodeManager::currentNM();
@@ -154,6 +242,13 @@ RewriteResponse TheorySetsRewriter::preRewrite(TNode node) {
     return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
   // Further optimization, if constants but differing ones
 
+  if(node.getType().isSet() && node.isConst()) {
+    //rewrite set to normal form
+    SettermElementsMap setTermElementsMap;   // cache
+    const Elements& elements = collectConstantElements(node, setTermElementsMap);
+    return RewriteResponse(REWRITE_DONE, elementsToNormalConstant(elements, node.getType()));
+  }
+
   return RewriteResponse(REWRITE_DONE, node);
 }