Set Constant's normal form and other short fixes
[cvc5.git] / src / theory / sets / theory_sets_rewriter.cpp
1 /********************* */
2 /*! \file theory_sets_rewriter.cpp
3 ** \verbatim
4 ** Original author: Kshitij Bansal
5 ** Major contributors: none
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Sets theory rewriter.
13 **
14 ** Sets theory rewriter.
15 **/
16
17 #include "theory/sets/theory_sets_rewriter.h"
18 #include "theory/sets/normal_form.h"
19
20 namespace CVC4 {
21 namespace theory {
22 namespace sets {
23
// A collection of term nodes kept in a std::set (i.e. ordered, duplicate-free).
// NOTE(review): neither this typedef nor SettermElementsMap below is referenced
// anywhere else in this file -- confirm they are unused and consider removing.
typedef std::set<TNode> Elements;
// Hash map from a TNode to an Elements collection. The name suggests it maps a
// set term to its elements -- unverified here, since nothing in this file uses it.
// NOTE(review): std::hash_map is not standard C++; presumably it is made
// available by a project hash header -- confirm.
typedef std::hash_map<TNode, Elements, TNodeHashFunction> SettermElementsMap;
26
27 bool checkConstantMembership(TNode elementTerm, TNode setTerm)
28 {
29 if(setTerm.getKind() == kind::EMPTYSET) {
30 return false;
31 }
32
33 if(setTerm.getKind() == kind::SINGLETON) {
34 return elementTerm == setTerm[0];
35 }
36
37 Assert(setTerm.getKind() == kind::UNION && setTerm[1].getKind() == kind::SINGLETON,
38 "kind was %d, term: %s", setTerm.getKind(), setTerm.toString().c_str());
39
40 return
41 elementTerm == setTerm[1][0] ||
42 checkConstantMembership(elementTerm, setTerm[0]);
43 }
44
45 // static
46 RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
47 NodeManager* nm = NodeManager::currentNM();
48 Kind kind = node.getKind();
49
50
51 if(node.isConst()) {
52 // Dare you touch the const and mangle it to something else.
53 return RewriteResponse(REWRITE_DONE, node);
54 }
55
56 switch(kind) {
57
58 case kind::MEMBER: {
59 if(node[0].isConst() && node[1].isConst()) {
60 // both are constants
61 TNode S = preRewrite(node[1]).node;
62 bool isMember = checkConstantMembership(node[0], S);
63 return RewriteResponse(REWRITE_DONE, nm->mkConst(isMember));
64 }
65 break;
66 }//kind::MEMBER
67
68 case kind::SUBSET: {
69 Assert(false, "TheorySets::postRrewrite(): Subset is handled in preRewrite.");
70
71 // but in off-chance we do end up here, let us do our best
72
73 // rewrite (A subset-or-equal B) as (A union B = B)
74 TNode A = node[0];
75 TNode B = node[1];
76 return RewriteResponse(REWRITE_AGAIN_FULL,
77 nm->mkNode(kind::EQUAL,
78 nm->mkNode(kind::UNION, A, B),
79 B) );
80 }//kind::SUBSET
81
82 case kind::EQUAL:
83 case kind::IFF: {
84 //rewrite: t = t with true (t term)
85 //rewrite: c = c' with c different from c' false (c, c' constants)
86 //otherwise: sort them
87 if(node[0] == node[1]) {
88 Trace("sets-postrewrite") << "Sets::postRewrite returning true" << std::endl;
89 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
90 }
91 else if (node[0].isConst() && node[1].isConst()) {
92 Trace("sets-postrewrite") << "Sets::postRewrite returning false" << std::endl;
93 return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
94 }
95 else if (node[0] > node[1]) {
96 Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
97 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
98 return RewriteResponse(REWRITE_DONE, newNode);
99 }
100 break;
101 }//kind::IFF
102
103 case kind::SETMINUS: {
104 if(node[0] == node[1]) {
105 Node newNode = nm->mkConst(EmptySet(nm->toType(node[0].getType())));
106 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
107 return RewriteResponse(REWRITE_DONE, newNode);
108 } else if(node[0].getKind() == kind::EMPTYSET ||
109 node[1].getKind() == kind::EMPTYSET) {
110 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node[0] << std::endl;
111 return RewriteResponse(REWRITE_DONE, node[0]);
112 } else if(node[0].isConst() && node[1].isConst()) {
113 std::set<Node> left = NormalForm::getElementsFromNormalConstant(node[0]);
114 std::set<Node> right = NormalForm::getElementsFromNormalConstant(node[1]);
115 std::set<Node> newSet;
116 std::set_difference(left.begin(), left.end(), right.begin(), right.end(),
117 std::inserter(newSet, newSet.begin()));
118 Node newNode = NormalForm::elementsToSet(newSet, node.getType());
119 Assert(newNode.isConst());
120 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
121 return RewriteResponse(REWRITE_DONE, newNode);
122 }
123 break;
124 }//kind::INTERSECION
125
126 case kind::INTERSECTION: {
127 if(node[0] == node[1]) {
128 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node[0] << std::endl;
129 return RewriteResponse(REWRITE_DONE, node[0]);
130 } else if(node[0].getKind() == kind::EMPTYSET) {
131 return RewriteResponse(REWRITE_DONE, node[0]);
132 } else if(node[1].getKind() == kind::EMPTYSET) {
133 return RewriteResponse(REWRITE_DONE, node[1]);
134 } else if(node[0].isConst() && node[1].isConst()) {
135 std::set<Node> left = NormalForm::getElementsFromNormalConstant(node[0]);
136 std::set<Node> right = NormalForm::getElementsFromNormalConstant(node[1]);
137 std::set<Node> newSet;
138 std::set_intersection(left.begin(), left.end(), right.begin(), right.end(),
139 std::inserter(newSet, newSet.begin()));
140 Node newNode = NormalForm::elementsToSet(newSet, node.getType());
141 Assert(newNode.isConst());
142 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
143 return RewriteResponse(REWRITE_DONE, newNode);
144 } else if (node[0] > node[1]) {
145 Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
146 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
147 return RewriteResponse(REWRITE_DONE, newNode);
148 }
149 break;
150 }//kind::INTERSECION
151
152 case kind::UNION: {
153 // NOTE: case where it is CONST is taken care of at the top
154 if(node[0] == node[1]) {
155 Trace("sets-postrewrite") << "Sets::postRewrite returning " << node[0] << std::endl;
156 return RewriteResponse(REWRITE_DONE, node[0]);
157 } else if(node[0].getKind() == kind::EMPTYSET) {
158 return RewriteResponse(REWRITE_DONE, node[1]);
159 } else if(node[1].getKind() == kind::EMPTYSET) {
160 return RewriteResponse(REWRITE_DONE, node[0]);
161 } else if(node[0].isConst() && node[1].isConst()) {
162 std::set<Node> left = NormalForm::getElementsFromNormalConstant(node[0]);
163 std::set<Node> right = NormalForm::getElementsFromNormalConstant(node[1]);
164 std::set<Node> newSet;
165 std::set_union(left.begin(), left.end(), right.begin(), right.end(),
166 std::inserter(newSet, newSet.begin()));
167 Node newNode = NormalForm::elementsToSet(newSet, node.getType());
168 Assert(newNode.isConst());
169 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
170 return RewriteResponse(REWRITE_DONE, newNode);
171 } else if (node[0] > node[1]) {
172 Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
173 Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
174 return RewriteResponse(REWRITE_DONE, newNode);
175 }
176 break;
177 }//kind::UNION
178
179 default:
180 break;
181 }//switch(node.getKind())
182
183 // This default implementation
184 return RewriteResponse(REWRITE_DONE, node);
185 }
186
187
188 // static
189 RewriteResponse TheorySetsRewriter::preRewrite(TNode node) {
190 NodeManager* nm = NodeManager::currentNM();
191
192 if(node.getKind() == kind::EQUAL) {
193
194 if(node[0] == node[1]) {
195 return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
196 }
197
198 }//kind::EQUAL
199 else if(node.getKind() == kind::INSERT) {
200
201 Node insertedElements = nm->mkNode(kind::SINGLETON, node[0]);
202 size_t setNodeIndex = node.getNumChildren()-1;
203 for(size_t i = 1; i < setNodeIndex; ++i) {
204 insertedElements = nm->mkNode(kind::UNION,
205 insertedElements,
206 nm->mkNode(kind::SINGLETON, node[i]));
207 }
208 return RewriteResponse(REWRITE_AGAIN,
209 nm->mkNode(kind::UNION,
210 insertedElements,
211 node[setNodeIndex]));
212
213 }//kind::INSERT
214 else if(node.getKind() == kind::SUBSET) {
215
216 // rewrite (A subset-or-equal B) as (A union B = B)
217 return RewriteResponse(REWRITE_AGAIN,
218 nm->mkNode(kind::EQUAL,
219 nm->mkNode(kind::UNION, node[0], node[1]),
220 node[1]) );
221
222 }//kind::SUBSET
223
224 return RewriteResponse(REWRITE_DONE, node);
225 }
226
227 }/* CVC4::theory::sets namespace */
228 }/* CVC4::theory namespace */
229 }/* CVC4 namespace */