Merge pull request #22 from kbansal/sets-model
[cvc5.git] / src / theory / sets / theory_sets_rewriter.cpp
1 /********************* */
2 /*! \file theory_sets_rewriter.cpp
3 ** \verbatim
4 ** Original author: Kshitij Bansal
5 ** Major contributors: none
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2013-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Sets theory rewriter.
13 **
14 ** Sets theory rewriter.
15 **/
16
17 #include "theory/sets/theory_sets_rewriter.h"
18
19 namespace CVC4 {
20 namespace theory {
21 namespace sets {
22
// The constant elements of a constant set term.  An ordered std::set is used,
// so iterating over it visits the elements in a fixed (sorted) order.
typedef std::set<TNode> Elements;
// Memo table used while collecting elements: set term -> its elements.
typedef std::hash_map<TNode, Elements, TNodeHashFunction> SettermElementsMap;
25
26 bool checkConstantMembership(TNode elementTerm, TNode setTerm)
27 {
28 // Assume from pre-rewrite constant sets look like the following:
29 // (union (setenum bla) (union (setenum bla) ... (union (setenum bla) (setenum bla) ) ... ))
30
31 if(setTerm.getKind() == kind::EMPTYSET) {
32 return false;
33 }
34
35 if(setTerm.getKind() == kind::SET_SINGLETON) {
36 return elementTerm == setTerm[0];
37 }
38
39 Assert(setTerm.getKind() == kind::UNION && setTerm[1].getKind() == kind::SET_SINGLETON,
40 "kind was %d, term: %s", setTerm.getKind(), setTerm.toString().c_str());
41
42 return elementTerm == setTerm[1][0] || checkConstantMembership(elementTerm, setTerm[0]);
43
44 // switch(setTerm.getKind()) {
45 // case kind::EMPTYSET:
46 // return false;
47 // case kind::SET_SINGLETON:
48 // return elementTerm == setTerm[0];
49 // case kind::UNION:
50 // return checkConstantMembership(elementTerm, setTerm[0]) ||
51 // checkConstantMembership(elementTerm, setTerm[1]);
52 // case kind::INTERSECTION:
53 // return checkConstantMembership(elementTerm, setTerm[0]) &&
54 // checkConstantMembership(elementTerm, setTerm[1]);
55 // case kind::SETMINUS:
56 // return checkConstantMembership(elementTerm, setTerm[0]) &&
57 // !checkConstantMembership(elementTerm, setTerm[1]);
58 // default:
59 // Unhandled();
60 // }
61 }
62
63 // static
64 RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
65 NodeManager* nm = NodeManager::currentNM();
66 Kind kind = node.getKind();
67
68 switch(kind) {
69
70 case kind::MEMBER: {
71 if(!node[0].isConst() || !node[1].isConst())
72 break;
73
74 // both are constants
75 TNode S = preRewrite(node[1]).node;
76 bool isMember = checkConstantMembership(node[0], S);
77 return RewriteResponse(REWRITE_DONE, nm->mkConst(isMember));
78 }//kind::MEMBER
79
80 case kind::SUBSET: {
81 // rewrite (A subset-or-equal B) as (A union B = B)
82 TNode A = node[0];
83 TNode B = node[1];
84 return RewriteResponse(REWRITE_AGAIN_FULL,
85 nm->mkNode(kind::EQUAL,
86 nm->mkNode(kind::UNION, A, B),
87 B) );
88 }//kind::SUBSET
89
90 case kind::EQUAL:
91 case kind::IFF: {
92 //rewrite: t = t with true (t term)
93 //rewrite: c = c' with c different from c' false (c, c' constants)
94 //otherwise: sort them
95 if(node[0] == node[1]) {
96 Trace("sets-postrewrite") << "Sets::postRewrite returning true" << std::endl;
97 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
98 }
99 else if (node[0].isConst() && node[1].isConst()) {
100 Trace("sets-postrewrite") << "Sets::postRewrite returning false" << std::endl;
101 return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
102 }
103 else if (node[0] > node[1]) {
104 Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
105 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
106 return RewriteResponse(REWRITE_DONE, newNode);
107 }
108 break;
109 }//kind::IFF
110
111 case kind::SETMINUS: {
112 if(node[0] == node[1]) {
113 Node newNode = nm->mkConst(EmptySet(nm->toType(node[0].getType())));
114 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
115 return RewriteResponse(REWRITE_DONE, newNode);
116 } else if(node[0].getKind() == kind::EMPTYSET ||
117 node[1].getKind() == kind::EMPTYSET) {
118 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node[0] << std::endl;
119 return RewriteResponse(REWRITE_DONE, node[0]);
120 } else if (node[0] > node[1]) {
121 Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
122 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
123 return RewriteResponse(REWRITE_DONE, newNode);
124 }
125 break;
126 }//kind::INTERSECION
127
128 case kind::INTERSECTION: {
129 if(node[0] == node[1]) {
130 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node[0] << std::endl;
131 return RewriteResponse(REWRITE_DONE, node[0]);
132 } else if(node[0].getKind() == kind::EMPTYSET) {
133 return RewriteResponse(REWRITE_DONE, node[0]);
134 } else if(node[1].getKind() == kind::EMPTYSET) {
135 return RewriteResponse(REWRITE_DONE, node[1]);
136 } else if (node[0] > node[1]) {
137 Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
138 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
139 return RewriteResponse(REWRITE_DONE, newNode);
140 }
141 break;
142 }//kind::INTERSECION
143
144 case kind::UNION: {
145 if(node[0] == node[1]) {
146 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node[0] << std::endl;
147 return RewriteResponse(REWRITE_DONE, node[0]);
148 } else if(node[0].getKind() == kind::EMPTYSET) {
149 return RewriteResponse(REWRITE_DONE, node[1]);
150 } else if(node[1].getKind() == kind::EMPTYSET) {
151 return RewriteResponse(REWRITE_DONE, node[0]);
152 } else if (node[0] > node[1]) {
153 Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
154 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
155 return RewriteResponse(REWRITE_DONE, newNode);
156 }
157 break;
158 }//kind::UNION
159
160 default:
161 break;
162 }//switch(node.getKind())
163
164 // This default implementation
165 return RewriteResponse(REWRITE_DONE, node);
166 }
167
168 const Elements& collectConstantElements(TNode setterm, SettermElementsMap& settermElementsMap) {
169 SettermElementsMap::const_iterator it = settermElementsMap.find(setterm);
170 if(it == settermElementsMap.end() ) {
171
172 Kind k = setterm.getKind();
173 unsigned numChildren = setterm.getNumChildren();
174 Elements cur;
175 if(numChildren == 2) {
176 const Elements& left = collectConstantElements(setterm[0], settermElementsMap);
177 const Elements& right = collectConstantElements(setterm[1], settermElementsMap);
178 switch(k) {
179 case kind::UNION:
180 if(left.size() >= right.size()) {
181 cur = left; cur.insert(right.begin(), right.end());
182 } else {
183 cur = right; cur.insert(left.begin(), left.end());
184 }
185 break;
186 case kind::INTERSECTION:
187 std::set_intersection(left.begin(), left.end(), right.begin(), right.end(),
188 std::inserter(cur, cur.begin()) );
189 break;
190 case kind::SETMINUS:
191 std::set_difference(left.begin(), left.end(), right.begin(), right.end(),
192 std::inserter(cur, cur.begin()) );
193 break;
194 default:
195 Unhandled();
196 }
197 } else {
198 switch(k) {
199 case kind::EMPTYSET:
200 /* assign emptyset, which is default */
201 break;
202 case kind::SET_SINGLETON:
203 Assert(setterm[0].isConst());
204 cur.insert(setterm[0]);
205 break;
206 default:
207 Unhandled();
208 }
209 }
210
211 it = settermElementsMap.insert(SettermElementsMap::value_type(setterm, cur)).first;
212 }
213 return it->second;
214 }
215
216 Node elementsToNormalConstant(Elements elements,
217 TypeNode setType)
218 {
219 NodeManager* nm = NodeManager::currentNM();
220
221 if(elements.size() == 0) {
222 return nm->mkConst(EmptySet(nm->toType(setType)));
223 } else {
224
225 Elements::iterator it = elements.begin();
226 Node cur = nm->mkNode(kind::SET_SINGLETON, *it);
227 while( ++it != elements.end() ) {
228 cur = nm->mkNode(kind::UNION, cur,
229 nm->mkNode(kind::SET_SINGLETON, *it));
230 }
231 return cur;
232 }
233 }
234
235
236 // static
237 RewriteResponse TheorySetsRewriter::preRewrite(TNode node) {
238 NodeManager* nm = NodeManager::currentNM();
239
240 // do nothing
241 if(node.getKind() == kind::EQUAL && node[0] == node[1])
242 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
243 // Further optimization, if constants but differing ones
244
245 if(node.getType().isSet() && node.isConst()) {
246 //rewrite set to normal form
247 SettermElementsMap setTermElementsMap; // cache
248 const Elements& elements = collectConstantElements(node, setTermElementsMap);
249 return RewriteResponse(REWRITE_DONE, elementsToNormalConstant(elements, node.getType()));
250 }
251
252 return RewriteResponse(REWRITE_DONE, node);
253 }
254
255 }/* CVC4::theory::sets namespace */
256 }/* CVC4::theory namespace */
257 }/* CVC4 namespace */