From b5914204bd29c3bc3480fdec234882cedaad2c2a Mon Sep 17 00:00:00 2001 From: Kshitij Bansal Date: Thu, 13 Mar 2014 22:17:02 -0400 Subject: [PATCH] constant normal form and rewrite --- src/theory/sets/kinds | 1 + src/theory/sets/theory_sets_rewriter.cpp | 125 ++++++++++++++++++++--- 2 files changed, 111 insertions(+), 15 deletions(-) diff --git a/src/theory/sets/kinds b/src/theory/sets/kinds index e46f3a4f8..a56601b98 100644 --- a/src/theory/sets/kinds +++ b/src/theory/sets/kinds @@ -41,6 +41,7 @@ operator SUBSET 2 "subset" operator MEMBER 2 "set membership" operator SET_SINGLETON 1 "singleton set" +operator FINITE_SET 1: "finite set" typerule UNION ::CVC4::theory::sets::SetUnionTypeRule typerule INTERSECTION ::CVC4::theory::sets::SetIntersectionTypeRule diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index 82b79cbd6..87c6db8f2 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -20,25 +20,44 @@ namespace CVC4 { namespace theory { namespace sets { +typedef std::set Elements; +typedef std::hash_map SettermElementsMap; + bool checkConstantMembership(TNode elementTerm, TNode setTerm) { - switch(setTerm.getKind()) { - case kind::EMPTYSET: + // Assume from pre-rewrite constant sets look like the following: + // (union (setenum bla) (union (setenum bla) ... (union (setenum bla) (setenum bla) ) ... )) + + if(setTerm.getKind() == kind::EMPTYSET) { return false; - case kind::SET_SINGLETON: + } + + if(setTerm.getKind() == kind::SET_SINGLETON) { return elementTerm == setTerm[0]; - case kind::UNION: - return checkConstantMembership(elementTerm, setTerm[0]) || - checkConstantMembership(elementTerm, setTerm[1]); - case kind::INTERSECTION: - return checkConstantMembership(elementTerm, setTerm[0]) && - checkConstantMembership(elementTerm, setTerm[1]); - case kind::SETMINUS: - return checkConstantMembership(elementTerm, setTerm[0]) && - !checkConstantMembership(elementTerm, setTerm[1]); - default: - Unhandled(); } + + Assert(setTerm.getKind() == kind::UNION && setTerm[1].getKind() == kind::SET_SINGLETON, + "kind was %d, term: %s", setTerm.getKind(), setTerm.toString().c_str()); + + return elementTerm == setTerm[1][0] || checkConstantMembership(elementTerm, setTerm[0]); + + // switch(setTerm.getKind()) { + // case kind::EMPTYSET: + // return false; + // case kind::SET_SINGLETON: + // return elementTerm == setTerm[0]; + // case kind::UNION: + // return checkConstantMembership(elementTerm, setTerm[0]) || + // checkConstantMembership(elementTerm, setTerm[1]); + // case kind::INTERSECTION: + // return checkConstantMembership(elementTerm, setTerm[0]) && + // checkConstantMembership(elementTerm, setTerm[1]); + // case kind::SETMINUS: + // return checkConstantMembership(elementTerm, setTerm[0]) && + // !checkConstantMembership(elementTerm, setTerm[1]); + // default: + // Unhandled(); + // } } // static @@ -53,7 +72,8 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { break; // both are constants - bool isMember = checkConstantMembership(node[0], node[1]); + TNode S = preRewrite(node[1]).node; + bool isMember = checkConstantMembership(node[0], S); return RewriteResponse(REWRITE_DONE, nm->mkConst(isMember)); }//kind::MEMBER @@ -145,6 +165,74 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { return RewriteResponse(REWRITE_DONE, node); } +const Elements& collectConstantElements(TNode setterm, SettermElementsMap& settermElementsMap) { + SettermElementsMap::const_iterator it = settermElementsMap.find(setterm); + if(it == settermElementsMap.end() ) { + + Kind k = setterm.getKind(); + unsigned numChildren = setterm.getNumChildren(); + Elements cur; + if(numChildren == 2) { + const Elements& left = collectConstantElements(setterm[0], settermElementsMap); + const Elements& right = collectConstantElements(setterm[1], settermElementsMap); + switch(k) { + case kind::UNION: + if(left.size() >= right.size()) { + cur = left; cur.insert(right.begin(), right.end()); + } else { + cur = right; cur.insert(left.begin(), left.end()); + } + break; + case kind::INTERSECTION: + std::set_intersection(left.begin(), left.end(), right.begin(), right.end(), + std::inserter(cur, cur.begin()) ); + break; + case kind::SETMINUS: + std::set_difference(left.begin(), left.end(), right.begin(), right.end(), + std::inserter(cur, cur.begin()) ); + break; + default: + Unhandled(); + } + } else { + switch(k) { + case kind::EMPTYSET: + /* assign emptyset, which is default */ + break; + case kind::SET_SINGLETON: + Assert(setterm[0].isConst()); + cur.insert(setterm[0]); + break; + default: + Unhandled(); + } + } + + it = settermElementsMap.insert(SettermElementsMap::value_type(setterm, cur)).first; + } + return it->second; +} + +Node elementsToNormalConstant(Elements elements, + TypeNode setType) +{ + NodeManager* nm = NodeManager::currentNM(); + + if(elements.size() == 0) { + return nm->mkConst(EmptySet(nm->toType(setType))); + } else { + + Elements::iterator it = elements.begin(); + Node cur = nm->mkNode(kind::SET_SINGLETON, *it); + while( ++it != elements.end() ) { + cur = nm->mkNode(kind::UNION, cur, + nm->mkNode(kind::SET_SINGLETON, *it)); + } + return cur; + } +} + + // static RewriteResponse TheorySetsRewriter::preRewrite(TNode node) { NodeManager* nm = NodeManager::currentNM(); @@ -154,6 +242,13 @@ RewriteResponse TheorySetsRewriter::preRewrite(TNode node) { return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); // Further optimization, if constants but differing ones + if(node.getType().isSet() && node.isConst()) { + //rewrite set to normal form + SettermElementsMap setTermElementsMap; // cache + const Elements& elements = collectConstantElements(node, setTermElementsMap); + return RewriteResponse(REWRITE_DONE, elementsToNormalConstant(elements, node.getType())); + } + return RewriteResponse(REWRITE_DONE, node); } -- 2.30.2