Updating equality manager to handle tagged trigger terms. Notifications are pushed...
[cvc5.git] / src / theory / shared_terms_database.cpp
1 /********************* */
2 /*! \file shared_terms_database.cpp
3 ** \verbatim
4 ** Original author: dejan
5 ** Major contributors: none
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
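14 ** Implementation of the shared terms database: records which theories share each (atom, term) pair and reports equalities between shared terms through the notify object.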
15 ** \todo document this file
16 **/
17
18 #include "theory/shared_terms_database.h"
19
20 using namespace std;
21 using namespace CVC4;
22 using namespace theory;
23
24 SharedTermsDatabase::SharedTermsDatabase(SharedTermsNotifyClass& notify, context::Context* context)
25 : ContextNotifyObj(context),
26 d_context(context),
27 d_statSharedTerms("theory::shared_terms", 0),
28 d_addedSharedTermsSize(context, 0),
29 d_termsToTheories(context),
30 d_alreadyNotifiedMap(context),
31 d_sharedNotify(notify),
32 d_termToNotifyList(context),
33 d_allocatedNLSize(0),
34 d_allocatedNLNext(context, 0),
35 d_EENotify(*this),
36 d_equalityEngine(d_EENotify, context, "SharedTermsDatabase")
37 {
38 StatisticsRegistry::registerStat(&d_statSharedTerms);
39 }
40
41 SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException)
42 {
43 StatisticsRegistry::unregisterStat(&d_statSharedTerms);
44 for (unsigned i = 0; i < d_allocatedNLSize; ++i) {
45 d_allocatedNL[i]->deleteSelf();
46 }
47 }
48
49 void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
50 Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl;
51
52 std::pair<TNode, TNode> search_pair(atom, term);
53 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
54 if (find == d_termsToTheories.end()) {
55 // First time for this term and this atom
56 d_atomsToTerms[atom].push_back(term);
57 d_addedSharedTerms.push_back(atom);
58 d_addedSharedTermsSize = d_addedSharedTermsSize + 1;
59 d_termsToTheories[search_pair] = theories;
60 if (!d_equalityEngine.hasTerm(term)) {
61 d_equalityEngine.addTriggerTerm(term, THEORY_UF);
62 }
63 } else {
64 Assert(theories != (*find).second);
65 d_termsToTheories[search_pair] = Theory::setUnion(theories, (*find).second);
66 }
67 }
68
69 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::begin(TNode atom) const {
70 Assert(hasSharedTerms(atom));
71 return d_atomsToTerms.find(atom)->second.begin();
72 }
73
74 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::end(TNode atom) const {
75 Assert(hasSharedTerms(atom));
76 return d_atomsToTerms.find(atom)->second.end();
77 }
78
79 bool SharedTermsDatabase::hasSharedTerms(TNode atom) const {
80 return d_atomsToTerms.find(atom) != d_atomsToTerms.end();
81 }
82
83 void SharedTermsDatabase::backtrack() {
84 for (int i = d_addedSharedTerms.size() - 1, i_end = (int)d_addedSharedTermsSize; i >= i_end; -- i) {
85 TNode atom = d_addedSharedTerms[i];
86 shared_terms_list& list = d_atomsToTerms[atom];
87 list.pop_back();
88 if (list.empty()) {
89 d_atomsToTerms.erase(atom);
90 }
91 }
92 d_addedSharedTerms.resize(d_addedSharedTermsSize);
93 }
94
95 Theory::Set SharedTermsDatabase::getTheoriesToNotify(TNode atom, TNode term) const {
96 // Get the theories that share this term from this atom
97 std::pair<TNode, TNode> search_pair(atom, term);
98 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
99 Assert(find != d_termsToTheories.end());
100
101 // Get the theories that were already notified
102 Theory::Set alreadyNotified = 0;
103 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
104 if (theoriesFind != d_alreadyNotifiedMap.end()) {
105 alreadyNotified = (*theoriesFind).second;
106 }
107
108 // Return the ones that haven't been notified yet
109 return Theory::setDifference((*find).second, alreadyNotified);
110 }
111
112
113 Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
114 // Get the theories that were already notified
115 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
116 if (theoriesFind != d_alreadyNotifiedMap.end()) {
117 return (*theoriesFind).second;
118 } else {
119 return 0;
120 }
121 }
122
123
124 SharedTermsDatabase::NotifyList* SharedTermsDatabase::getNewNotifyList()
125 {
126 NotifyList* retval;
127 if (d_allocatedNLSize == d_allocatedNLNext) {
128 retval = new (true) NotifyList(d_context);
129 d_allocatedNL.push_back(retval);
130 d_allocatedNLNext = ++d_allocatedNLSize;
131 }
132 else {
133 retval = d_allocatedNL[d_allocatedNLNext];
134 d_allocatedNLNext = d_allocatedNLNext + 1;
135 }
136 Assert(retval->empty());
137 return retval;
138 }
139
140
141 void SharedTermsDatabase::mergeSharedTerms(TNode a, TNode b)
142 {
143 // Note: a is the new representative
144 Debug("shared-terms-database") << "SharedTermsDatabase::mergeSharedTerms(" << a << "," << b << ")" << endl;
145
146 NotifyList* pnlLeft = NULL;
147 NotifyList* pnlRight = NULL;
148
149 TermToNotifyList::iterator it = d_termToNotifyList.find(a);
150 if (it == d_termToNotifyList.end()) {
151 pnlLeft = getNewNotifyList();
152 d_termToNotifyList[a] = pnlLeft;
153 }
154 else {
155 pnlLeft = (*it).second;
156 }
157 it = d_termToNotifyList.find(b);
158 if (it != d_termToNotifyList.end()) {
159 pnlRight = (*it).second;
160 }
161
162 // Get theories interested in EC for lhs
163 Theory::Set lhsSet = getNotifiedTheories(a);
164 Theory::Set rhsSet = getNotifiedTheories(b);
165 NotifyList::iterator nit;
166 TNode left, right;
167
168 for (TheoryId currentTheory = THEORY_FIRST; currentTheory != THEORY_LAST; ++ currentTheory) {
169
170 if (Theory::setContains(currentTheory, rhsSet)) {
171 right = b;
172 }
173 else if (pnlRight != NULL &&
174 ((nit = pnlRight->find(currentTheory)) != pnlRight->end())) {
175 right = (*nit).second;
176 }
177 else {
178 // no match for right: continue
179 continue;
180 }
181
182 if (Theory::setContains(currentTheory, lhsSet)) {
183 left = a;
184 }
185 else if ((nit = pnlLeft->find(currentTheory)) != pnlLeft->end()) {
186 left = (*nit).second;
187 }
188 else {
189 // no match for left: insert right into left
190 (*pnlLeft)[currentTheory] = right;
191 continue;
192 }
193
194 // New shared equality: notify the client
195
196 // TODO: add propagation of disequalities?
197
198 assertEq(left.eqNode(right), currentTheory);
199 }
200
201 }
202
203
204 void SharedTermsDatabase::assertEq(TNode equality, TheoryId theory)
205 {
206 Debug("shared-terms-database") << "SharedTermsDatabase::assertEq(" << equality << ") to theory " << theory << endl;
207 Node normalized = Rewriter::rewriteEquality(theory, equality);
208 if (normalized.getKind() != kind::CONST_BOOLEAN || !normalized.getConst<bool>()) {
209 // Notify client
210 d_sharedNotify.notify(normalized, equality, theory);
211 }
212 }
213
214
215 // term was just part of an assertion that makes it shared for theories
216 // Let's mark that the set theories has now been notified
217 // In addition, we make sure the equivalence class containing term knows a
218 // representative for each theory in theories.
219 // Finally, if the EC already knows a rep for a theory that was just notified, we
220 // have to tell the theory that these two terms are equal.
221 void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
222
223 // Find out if there are any new theories that were notified about this term
224 Theory::Set alreadyNotified = 0;
225 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
226 if (theoriesFind != d_alreadyNotifiedMap.end()) {
227 alreadyNotified = (*theoriesFind).second;
228 }
229 Theory::Set newlyNotified = Theory::setDifference(theories, alreadyNotified);
230
231 // If no new theories were notified, we are done
232 if (newlyNotified == 0) {
233 return;
234 }
235
236 Debug("shared-terms-database") << "SharedTermsDatabase::markNotified(" << term << ")" << endl;
237
238 // First update the set of notified theories for this term
239 d_alreadyNotifiedMap[term] = Theory::setUnion(newlyNotified, alreadyNotified);
240
241 // Now get the representative of the equivalence class and find out which theories it represents
242 TNode rep = d_equalityEngine.getRepresentative(term);
243 if (rep != term) {
244 alreadyNotified = 0;
245 theoriesFind = d_alreadyNotifiedMap.find(rep);
246 if (theoriesFind != d_alreadyNotifiedMap.end()) {
247 alreadyNotified = (*theoriesFind).second;
248 }
249 }
250
251 // For each theory that is newly notified
252 for (TheoryId theory = THEORY_FIRST; theory != THEORY_LAST; ++ theory) {
253 if (Theory::setContains(theory, newlyNotified)) {
254
255 Debug("shared-terms-database") << "SharedTermsDatabase::markNotified: processing theory " << theory << endl;
256
257 if (Theory::setContains(theory, alreadyNotified)) {
258 // rep represents this theory already, need to assert that term = rep
259 Assert(rep != term);
260 assertEq(rep.eqNode(term), theory);
261 }
262 else {
263 // Get the list of terms representing theories for this EC
264 TermToNotifyList::iterator it = d_termToNotifyList.find(rep);
265 if (it == d_termToNotifyList.end()) {
266 // No need to do anything - no list associated with this EC
267 Assert(term == rep);
268 }
269 else {
270 NotifyList* pnl = (*it).second;
271 Assert(pnl != NULL);
272
273 // Check if this theory is already represented
274 NotifyList::iterator nit = pnl->find(theory);
275 if (nit != pnl->end()) {
276 // Already have a representative for this theory, assert term equal to it
277 assertEq((*nit).second.eqNode(term), theory);
278 }
279 else {
280 // if term == rep, no need to do anything, as term will represent the theory via alreadyNotifiedMap
281 if (term != rep) {
282 // No term in this EC represents this theory, so add term as a new representative
283 Debug("shared-terms-database") << "SharedTermsDatabase::markNotified: adding " << term << " to representative " << rep << " for theory " << theory << endl;
284 (*pnl)[theory] = term;
285 }
286 }
287 }
288 }
289 }
290 }
291 }
292
293
294 bool SharedTermsDatabase::areEqual(TNode a, TNode b) {
295 return d_equalityEngine.areEqual(a,b);
296 }
297
298
299 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) {
300 return d_equalityEngine.areDisequal(a,b);
301 }
302
303 void SharedTermsDatabase::processSharedLiteral(TNode literal, TNode reason)
304 {
305 bool negated = literal.getKind() == kind::NOT;
306 TNode atom = negated ? literal[0] : literal;
307 if (negated) {
308 Assert(!d_equalityEngine.areDisequal(atom[0], atom[1]));
309 d_equalityEngine.assertEquality(atom, false, reason);
310 // !!! need to send this out
311 }
312 else {
313 Assert(!d_equalityEngine.areEqual(atom[0], atom[1]));
314 d_equalityEngine.assertEquality(atom, true, reason);
315 }
316 }
317
318 static Node mkAnd(const std::vector<TNode>& conjunctions) {
319 Assert(conjunctions.size() > 0);
320
321 std::set<TNode> all;
322 all.insert(conjunctions.begin(), conjunctions.end());
323
324 if (all.size() == 1) {
325 // All the same, or just one
326 return conjunctions[0];
327 }
328
329 NodeBuilder<> conjunction(kind::AND);
330 std::set<TNode>::const_iterator it = all.begin();
331 std::set<TNode>::const_iterator it_end = all.end();
332 while (it != it_end) {
333 conjunction << *it;
334 ++ it;
335 }
336
337 return conjunction;
338 }/* mkAnd() */
339
340
341 Node SharedTermsDatabase::explain(TNode literal)
342 {
343 std::vector<TNode> assumptions;
344 if (literal.getKind() == kind::NOT) {
345 Assert(literal[0].getKind() == kind::EQUAL);
346 d_equalityEngine.explainEquality(literal[0][0], literal[0][1], false, assumptions);
347 } else {
348 Assert(literal.getKind() == kind::EQUAL);
349 d_equalityEngine.explainEquality(literal[0], literal[1], true, assumptions);
350 }
351 return mkAnd(assumptions);
352 }