1 /********************* */
2 /*! \file theory_sets_rewriter.cpp
4 ** Original author: Kshitij Bansal
5 ** Major contributors: none
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
12 ** \brief Sets theory rewriter.
14 ** Sets theory rewriter.
17 #include "theory/sets/theory_sets_rewriter.h"
// Elements: a std::set of TNode values (the members gathered for one set term).
23 typedef std::set
<TNode
> Elements
;
// SettermElementsMap: hash map from a set term (TNode, hashed with
// TNodeHashFunction) to its Elements. collectConstantElements() below looks
// terms up in it and inserts newly computed results, i.e. it acts as a cache.
24 typedef std::hash_map
<TNode
, Elements
, TNodeHashFunction
> SettermElementsMap
;
// Decides whether elementTerm is a member of the constant set term setTerm by
// walking the structure of setTerm (EMPTYSET, SINGLETON, or UNION whose
// second child is a SINGLETON).
//   elementTerm  the candidate element
//   setTerm      the set term to search; callers in this file pass the
//                result of preRewrite()
26 bool checkConstantMembership(TNode elementTerm
, TNode setTerm
)
28 // Assume from pre-rewrite constant sets look like the following:
29 // (union (setenum bla) (union (setenum bla) ... (union (setenum bla) (setenum bla) ) ... ))
// NOTE(review): the shape sketched above nests the unions on the right, but
// the code below expects the SINGLETON at setTerm[1] and recurses on
// setTerm[0], which matches the left-nested terms produced by
// elementsToNormalConstant(). Confirm and update the sketch.
// Empty set case.
// NOTE(review): the body of this branch is not visible in this listing;
// presumably it returns false -- confirm.
31 if(setTerm
.getKind() == kind::EMPTYSET
) {
// A singleton contains elementTerm exactly when its only child is that node.
35 if(setTerm
.getKind() == kind::SINGLETON
) {
36 return elementTerm
== setTerm
[0];
// Anything else must be a UNION with a SINGLETON as second child; the
// assertion message reports the offending kind and term otherwise.
39 Assert(setTerm
.getKind() == kind::UNION
&& setTerm
[1].getKind() == kind::SINGLETON
,
40 "kind was %d, term: %s", setTerm
.getKind(), setTerm
.toString().c_str());
// Compare against the element of the right-hand singleton first, then
// recurse into the left child of the union.
42 return elementTerm
== setTerm
[1][0] || checkConstantMembership(elementTerm
, setTerm
[0]);
// The commented-out code below is an alternative that dispatches on the kind
// of setTerm and recurses through both children of each binary operator.
44 // switch(setTerm.getKind()) {
45 // case kind::EMPTYSET:
47 // case kind::SINGLETON:
48 // return elementTerm == setTerm[0];
50 // return checkConstantMembership(elementTerm, setTerm[0]) ||
51 // checkConstantMembership(elementTerm, setTerm[1]);
52 // case kind::INTERSECTION:
53 // return checkConstantMembership(elementTerm, setTerm[0]) &&
54 // checkConstantMembership(elementTerm, setTerm[1]);
55 // case kind::SETMINUS:
56 // return checkConstantMembership(elementTerm, setTerm[0]) &&
57 // !checkConstantMembership(elementTerm, setTerm[1]);
// Post-rewrite step for sets terms. Dispatches on the kind of node and applies
// the simplifications documented per branch below; when no branch returns,
// node is handed back unchanged with REWRITE_DONE.
// NOTE(review): the switch header and several case labels are not visible in
// this listing; branches without a visible label are marked below.
64 RewriteResponse
TheorySetsRewriter::postRewrite(TNode node
) {
65 NodeManager
* nm
= NodeManager::currentNM();
66 Kind kind
= node
.getKind();
// Membership with two constant operands is evaluated on the spot.
// NOTE(review): the case label is not visible; presumably kind::MEMBER.
71 if(node
[0].isConst() && node
[1].isConst()) {
// S is the node produced by preRewrite() on the set operand.
// NOTE(review): S is a TNode taken from the .node of a temporary
// RewriteResponse -- confirm the node is kept alive for the call below.
73 TNode S
= preRewrite(node
[1]).node
;
74 bool isMember
= checkConstantMembership(node
[0], S
);
75 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(isMember
));
81 // rewrite (A subset-or-equal B) as (A union B = B)
// NOTE(review): the declarations of A and B and the tail of the mkNode call
// are not visible in this listing. The result is returned with
// REWRITE_AGAIN_FULL rather than REWRITE_DONE.
84 return RewriteResponse(REWRITE_AGAIN_FULL
,
85 nm
->mkNode(kind::EQUAL
,
86 nm
->mkNode(kind::UNION
, A
, B
),
92 //rewrite: t = t with true (t term)
93 //rewrite: c = c' with c different from c' false (c, c' constants)
94 //otherwise: sort them
// Identical operands: the (dis)equality is trivially true.
95 if(node
[0] == node
[1]) {
96 Trace("sets-postrewrite") << "Sets::postRewrite returning true" << std::endl
;
97 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
// Two constants that are not the same node rewrite to false.
// NOTE(review): this presumably relies on constants having a unique normal
// form -- confirm.
99 else if (node
[0].isConst() && node
[1].isConst()) {
100 Trace("sets-postrewrite") << "Sets::postRewrite returning false" << std::endl
;
101 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(false));
// Otherwise order the operands: swap them when node[0] > node[1].
103 else if (node
[0] > node
[1]) {
104 Node newNode
= nm
->mkNode(node
.getKind(), node
[1], node
[0]);
105 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode
<< std::endl
;
106 return RewriteResponse(REWRITE_DONE
, newNode
);
// Set difference.
111 case kind::SETMINUS
: {
// A \ A is the empty set of the operand's type.
112 if(node
[0] == node
[1]) {
113 Node newNode
= nm
->mkConst(EmptySet(nm
->toType(node
[0].getType())));
114 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode
<< std::endl
;
115 return RewriteResponse(REWRITE_DONE
, newNode
);
// If either operand is the empty set the result is the left operand
// (empty \ B is empty, A \ empty is A).
116 } else if(node
[0].getKind() == kind::EMPTYSET
||
117 node
[1].getKind() == kind::EMPTYSET
) {
118 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node
[0] << std::endl
;
119 return RewriteResponse(REWRITE_DONE
, node
[0]);
// Intersection: identical operands collapse to one of them, an empty operand
// makes the result that empty operand, otherwise the operands are ordered.
124 case kind::INTERSECTION
: {
125 if(node
[0] == node
[1]) {
126 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node
[0] << std::endl
;
127 return RewriteResponse(REWRITE_DONE
, node
[0]);
128 } else if(node
[0].getKind() == kind::EMPTYSET
) {
129 return RewriteResponse(REWRITE_DONE
, node
[0]);
130 } else if(node
[1].getKind() == kind::EMPTYSET
) {
131 return RewriteResponse(REWRITE_DONE
, node
[1]);
132 } else if (node
[0] > node
[1]) {
133 Node newNode
= nm
->mkNode(node
.getKind(), node
[1], node
[0]);
134 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode
<< std::endl
;
135 return RewriteResponse(REWRITE_DONE
, newNode
);
// NOTE(review): the case label for this branch is not visible; presumably
// kind::UNION. Identical operands collapse to one of them, an empty operand
// yields the other operand, otherwise the operands are ordered.
141 if(node
[0] == node
[1]) {
142 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node
[0] << std::endl
;
143 return RewriteResponse(REWRITE_DONE
, node
[0]);
144 } else if(node
[0].getKind() == kind::EMPTYSET
) {
145 return RewriteResponse(REWRITE_DONE
, node
[1]);
146 } else if(node
[1].getKind() == kind::EMPTYSET
) {
147 return RewriteResponse(REWRITE_DONE
, node
[0]);
148 } else if (node
[0] > node
[1]) {
149 Node newNode
= nm
->mkNode(node
.getKind(), node
[1], node
[0]);
150 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode
<< std::endl
;
151 return RewriteResponse(REWRITE_DONE
, newNode
);
158 }//switch(node.getKind())
// No rule applied: return the node as it is.
160 // This default implementation
161 return RewriteResponse(REWRITE_DONE
, node
);
// Returns the Elements recorded for setterm in settermElementsMap, computing
// and inserting them first when setterm is not in the map yet. Binary terms
// are handled by recursing into both children and combining the results.
//   setterm             the set term whose elements are wanted
//   settermElementsMap  cache shared across the recursion; updated in place
// NOTE(review): the declaration of `cur`, the switch on `k`, some case labels
// and the final return are not visible in this listing.
164 const Elements
& collectConstantElements(TNode setterm
, SettermElementsMap
& settermElementsMap
) {
// Cache lookup; the work below only happens on a miss.
165 SettermElementsMap::const_iterator it
= settermElementsMap
.find(setterm
);
166 if(it
== settermElementsMap
.end() ) {
168 Kind k
= setterm
.getKind();
169 unsigned numChildren
= setterm
.getNumChildren();
// Binary operator: obtain the element sets of both operands first.
171 if(numChildren
== 2) {
172 const Elements
& left
= collectConstantElements(setterm
[0], settermElementsMap
);
173 const Elements
& right
= collectConstantElements(setterm
[1], settermElementsMap
);
// Merge of both sides: copy the larger set, then insert the smaller one.
// NOTE(review): the case label is not visible; presumably kind::UNION.
176 if(left
.size() >= right
.size()) {
177 cur
= left
; cur
.insert(right
.begin(), right
.end());
179 cur
= right
; cur
.insert(left
.begin(), left
.end());
// Elements present on both sides.
182 case kind::INTERSECTION
:
183 std::set_intersection(left
.begin(), left
.end(), right
.begin(), right
.end(),
184 std::inserter(cur
, cur
.begin()) );
// Elements of left that are not in right.
// NOTE(review): the case label is not visible; presumably kind::SETMINUS.
187 std::set_difference(left
.begin(), left
.end(), right
.begin(), right
.end(),
188 std::inserter(cur
, cur
.begin()) );
196 /* assign emptyset, which is default */
// A singleton contributes its (constant) element, passed through preRewrite().
198 case kind::SINGLETON
:
199 Assert(setterm
[0].isConst());
200 cur
.insert(TheorySetsRewriter::preRewrite(setterm
[0]).node
);
206 Debug("sets-rewrite-constant") << "[sets-rewrite-constant] "<< setterm
<< " " << setterm
.getId() << std::endl
;
// Record the result in the cache and keep the iterator to the stored entry.
208 it
= settermElementsMap
.insert(SettermElementsMap::value_type(setterm
, cur
)).first
;
// Builds a set term from a collection of elements:
//   - no elements: an EmptySet constant of type setType;
//   - otherwise: SINGLETON of the first element, then for every further
//     element (in the iteration order of `elements`) a UNION of the term built
//     so far with a SINGLETON of that element, i.e. a left-nested union chain.
// NOTE(review): the rest of the parameter list and the final return are not
// visible in this listing; `setType` is used below, so it is presumably the
// second parameter -- confirm.
213 Node
elementsToNormalConstant(Elements elements
,
216 NodeManager
* nm
= NodeManager::currentNM();
// Empty input maps to the empty-set constant.
218 if(elements
.size() == 0) {
219 return nm
->mkConst(EmptySet(nm
->toType(setType
)));
// Seed the chain with the first element...
222 Elements::iterator it
= elements
.begin();
223 Node cur
= nm
->mkNode(kind::SINGLETON
, *it
);
// ...and append each remaining element as the right operand of a new UNION.
224 while( ++it
!= elements
.end() ) {
225 cur
= nm
->mkNode(kind::UNION
, cur
,
226 nm
->mkNode(kind::SINGLETON
, *it
));
// Pre-rewrite step for sets terms. Three cases are handled, in this order:
//   1. an EQUAL whose two sides are the same node becomes the constant true;
//   2. an INSERT term is expanded into unions of singletons;
//   3. a constant term of set type is brought into normal form.
// Any other node is returned unchanged with REWRITE_DONE.
234 RewriteResponse
TheorySetsRewriter::preRewrite(TNode node
) {
235 NodeManager
* nm
= NodeManager::currentNM();
238 if(node
.getKind() == kind::EQUAL
&& node
[0] == node
[1])
239 return RewriteResponse(REWRITE_DONE
, nm
->mkConst(true));
240 // Further optimization, if constants but differing ones
// INSERT: the last child is the set being extended, all children before it
// are the inserted elements.
242 if(node
.getKind() == kind::INSERT
) {
243 Node insertedElements
= nm
->mkNode(kind::SINGLETON
, node
[0]);
244 size_t setNodeIndex
= node
.getNumChildren()-1;
// Children 1 .. setNodeIndex-1 are folded into a union of singletons
// (child 0 seeded the chain above).
245 for(size_t i
= 1; i
< setNodeIndex
; ++i
) {
246 insertedElements
= nm
->mkNode(kind::UNION
, insertedElements
, nm
->mkNode(kind::SINGLETON
, node
[i
]));
// Union the inserted elements with the original set; REWRITE_AGAIN is
// returned here rather than REWRITE_DONE.
248 return RewriteResponse(REWRITE_AGAIN
, nm
->mkNode(kind::UNION
, insertedElements
, node
[setNodeIndex
]));
251 if(node
.getType().isSet() && node
.isConst()) {
252 //rewrite set to normal form
// Collect the elements of the constant (using a cache local to this call)
// and rebuild the term from them with elementsToNormalConstant().
// NOTE(review): the statement returning `response` is not visible in this
// listing.
253 SettermElementsMap setTermElementsMap
; // cache
254 const Elements
& elements
= collectConstantElements(node
, setTermElementsMap
);
255 RewriteResponse
response(REWRITE_DONE
, elementsToNormalConstant(elements
, node
.getType()));
256 Debug("sets-rewrite-constant") << "[sets-rewrite-constant] Rewriting " << node
<< std::endl
257 << "[sets-rewrite-constant] to " << response
.node
<< std::endl
;
// Nothing to do for any other node.
261 return RewriteResponse(REWRITE_DONE
, node
);
264 }/* CVC4::theory::sets namespace */
265 }/* CVC4::theory namespace */
266 }/* CVC4 namespace */