Merge remote-tracking branch 'origin/master' into segfaultfix
[cvc5.git] / src / theory / sets / theory_sets_rewriter.cpp
1 /********************* */
2 /*! \file theory_sets_rewriter.cpp
3 ** \verbatim
4 ** Original author: Kshitij Bansal
5 ** Major contributors: none
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Sets theory rewriter.
13 **
14 ** Sets theory rewriter.
15 **/
16
17 #include "theory/sets/theory_sets_rewriter.h"
18
19 namespace CVC4 {
20 namespace theory {
21 namespace sets {
22
23 typedef std::set<TNode> Elements;
24 typedef std::hash_map<TNode, Elements, TNodeHashFunction> SettermElementsMap;
25
26 bool checkConstantMembership(TNode elementTerm, TNode setTerm)
27 {
28 // Assume from pre-rewrite constant sets look like the following:
29 // (union (setenum bla) (union (setenum bla) ... (union (setenum bla) (setenum bla) ) ... ))
30
31 if(setTerm.getKind() == kind::EMPTYSET) {
32 return false;
33 }
34
35 if(setTerm.getKind() == kind::SINGLETON) {
36 return elementTerm == setTerm[0];
37 }
38
39 Assert(setTerm.getKind() == kind::UNION && setTerm[1].getKind() == kind::SINGLETON,
40 "kind was %d, term: %s", setTerm.getKind(), setTerm.toString().c_str());
41
42 return elementTerm == setTerm[1][0] || checkConstantMembership(elementTerm, setTerm[0]);
43
44 // switch(setTerm.getKind()) {
45 // case kind::EMPTYSET:
46 // return false;
47 // case kind::SINGLETON:
48 // return elementTerm == setTerm[0];
49 // case kind::UNION:
50 // return checkConstantMembership(elementTerm, setTerm[0]) ||
51 // checkConstantMembership(elementTerm, setTerm[1]);
52 // case kind::INTERSECTION:
53 // return checkConstantMembership(elementTerm, setTerm[0]) &&
54 // checkConstantMembership(elementTerm, setTerm[1]);
55 // case kind::SETMINUS:
56 // return checkConstantMembership(elementTerm, setTerm[0]) &&
57 // !checkConstantMembership(elementTerm, setTerm[1]);
58 // default:
59 // Unhandled();
60 // }
61 }
62
63 // static
64 RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
65 NodeManager* nm = NodeManager::currentNM();
66 Kind kind = node.getKind();
67
68 switch(kind) {
69
70 case kind::MEMBER: {
71 if(node[0].isConst() && node[1].isConst()) {
72 // both are constants
73 TNode S = preRewrite(node[1]).node;
74 bool isMember = checkConstantMembership(node[0], S);
75 return RewriteResponse(REWRITE_DONE, nm->mkConst(isMember));
76 }
77 break;
78 }//kind::MEMBER
79
80 case kind::SUBSET: {
81 // rewrite (A subset-or-equal B) as (A union B = B)
82 TNode A = node[0];
83 TNode B = node[1];
84 return RewriteResponse(REWRITE_AGAIN_FULL,
85 nm->mkNode(kind::EQUAL,
86 nm->mkNode(kind::UNION, A, B),
87 B) );
88 }//kind::SUBSET
89
90 case kind::EQUAL:
91 case kind::IFF: {
92 //rewrite: t = t with true (t term)
93 //rewrite: c = c' with c different from c' false (c, c' constants)
94 //otherwise: sort them
95 if(node[0] == node[1]) {
96 Trace("sets-postrewrite") << "Sets::postRewrite returning true" << std::endl;
97 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
98 }
99 else if (node[0].isConst() && node[1].isConst()) {
100 Trace("sets-postrewrite") << "Sets::postRewrite returning false" << std::endl;
101 return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
102 }
103 else if (node[0] > node[1]) {
104 Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
105 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
106 return RewriteResponse(REWRITE_DONE, newNode);
107 }
108 break;
109 }//kind::IFF
110
111 case kind::SETMINUS: {
112 if(node[0] == node[1]) {
113 Node newNode = nm->mkConst(EmptySet(nm->toType(node[0].getType())));
114 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
115 return RewriteResponse(REWRITE_DONE, newNode);
116 } else if(node[0].getKind() == kind::EMPTYSET ||
117 node[1].getKind() == kind::EMPTYSET) {
118 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node[0] << std::endl;
119 return RewriteResponse(REWRITE_DONE, node[0]);
120 }
121 break;
122 }//kind::INTERSECION
123
124 case kind::INTERSECTION: {
125 if(node[0] == node[1]) {
126 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node[0] << std::endl;
127 return RewriteResponse(REWRITE_DONE, node[0]);
128 } else if(node[0].getKind() == kind::EMPTYSET) {
129 return RewriteResponse(REWRITE_DONE, node[0]);
130 } else if(node[1].getKind() == kind::EMPTYSET) {
131 return RewriteResponse(REWRITE_DONE, node[1]);
132 } else if (node[0] > node[1]) {
133 Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
134 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
135 return RewriteResponse(REWRITE_DONE, newNode);
136 }
137 break;
138 }//kind::INTERSECION
139
140 case kind::UNION: {
141 if(node[0] == node[1]) {
142 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node[0] << std::endl;
143 return RewriteResponse(REWRITE_DONE, node[0]);
144 } else if(node[0].getKind() == kind::EMPTYSET) {
145 return RewriteResponse(REWRITE_DONE, node[1]);
146 } else if(node[1].getKind() == kind::EMPTYSET) {
147 return RewriteResponse(REWRITE_DONE, node[0]);
148 } else if (node[0] > node[1]) {
149 Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
150 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
151 return RewriteResponse(REWRITE_DONE, newNode);
152 }
153 break;
154 }//kind::UNION
155
156 default:
157 break;
158 }//switch(node.getKind())
159
160 // This default implementation
161 return RewriteResponse(REWRITE_DONE, node);
162 }
163
164 const Elements& collectConstantElements(TNode setterm, SettermElementsMap& settermElementsMap) {
165 SettermElementsMap::const_iterator it = settermElementsMap.find(setterm);
166 if(it == settermElementsMap.end() ) {
167
168 Kind k = setterm.getKind();
169 unsigned numChildren = setterm.getNumChildren();
170 Elements cur;
171 if(numChildren == 2) {
172 const Elements& left = collectConstantElements(setterm[0], settermElementsMap);
173 const Elements& right = collectConstantElements(setterm[1], settermElementsMap);
174 switch(k) {
175 case kind::UNION:
176 if(left.size() >= right.size()) {
177 cur = left; cur.insert(right.begin(), right.end());
178 } else {
179 cur = right; cur.insert(left.begin(), left.end());
180 }
181 break;
182 case kind::INTERSECTION:
183 std::set_intersection(left.begin(), left.end(), right.begin(), right.end(),
184 std::inserter(cur, cur.begin()) );
185 break;
186 case kind::SETMINUS:
187 std::set_difference(left.begin(), left.end(), right.begin(), right.end(),
188 std::inserter(cur, cur.begin()) );
189 break;
190 default:
191 Unhandled();
192 }
193 } else {
194 switch(k) {
195 case kind::EMPTYSET:
196 /* assign emptyset, which is default */
197 break;
198 case kind::SINGLETON:
199 Assert(setterm[0].isConst());
200 cur.insert(TheorySetsRewriter::preRewrite(setterm[0]).node);
201 break;
202 default:
203 Unhandled();
204 }
205 }
206 Debug("sets-rewrite-constant") << "[sets-rewrite-constant] "<< setterm << " " << setterm.getId() << std::endl;
207
208 it = settermElementsMap.insert(SettermElementsMap::value_type(setterm, cur)).first;
209 }
210 return it->second;
211 }
212
213 Node elementsToNormalConstant(Elements elements,
214 TypeNode setType)
215 {
216 NodeManager* nm = NodeManager::currentNM();
217
218 if(elements.size() == 0) {
219 return nm->mkConst(EmptySet(nm->toType(setType)));
220 } else {
221
222 Elements::iterator it = elements.begin();
223 Node cur = nm->mkNode(kind::SINGLETON, *it);
224 while( ++it != elements.end() ) {
225 cur = nm->mkNode(kind::UNION, cur,
226 nm->mkNode(kind::SINGLETON, *it));
227 }
228 return cur;
229 }
230 }
231
232
233 // static
234 RewriteResponse TheorySetsRewriter::preRewrite(TNode node) {
235 NodeManager* nm = NodeManager::currentNM();
236
237 // do nothing
238 if(node.getKind() == kind::EQUAL && node[0] == node[1])
239 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
240 // Further optimization, if constants but differing ones
241
242 if(node.getKind() == kind::INSERT) {
243 Node insertedElements = nm->mkNode(kind::SINGLETON, node[0]);
244 size_t setNodeIndex = node.getNumChildren()-1;
245 for(size_t i = 1; i < setNodeIndex; ++i) {
246 insertedElements = nm->mkNode(kind::UNION, insertedElements, nm->mkNode(kind::SINGLETON, node[i]));
247 }
248 return RewriteResponse(REWRITE_AGAIN, nm->mkNode(kind::UNION, insertedElements, node[setNodeIndex]));
249 }//kind::INSERT
250
251 if(node.getType().isSet() && node.isConst()) {
252 //rewrite set to normal form
253 SettermElementsMap setTermElementsMap; // cache
254 const Elements& elements = collectConstantElements(node, setTermElementsMap);
255 RewriteResponse response(REWRITE_DONE, elementsToNormalConstant(elements, node.getType()));
256 Debug("sets-rewrite-constant") << "[sets-rewrite-constant] Rewriting " << node << std::endl
257 << "[sets-rewrite-constant] to " << response.node << std::endl;
258 return response;
259 }
260
261 return RewriteResponse(REWRITE_DONE, node);
262 }
263
264 }/* CVC4::theory::sets namespace */
265 }/* CVC4::theory namespace */
266 }/* CVC4 namespace */