* simplifying equality engine interface
[cvc5.git] / src / theory / shared_terms_database.cpp
1 /********************* */
2 /*! \file shared_terms_database.cpp
3 ** \verbatim
4 ** Original author: dejan
5 ** Major contributors: none
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
14 ** [[ Add lengthier description here ]]
15 ** \todo document this file
16 **/
17
18 #include "theory/shared_terms_database.h"
19
20 using namespace CVC4;
21 using namespace theory;
22
23 SharedTermsDatabase::SharedTermsDatabase(SharedTermsNotifyClass& notify, context::Context* context)
24 : ContextNotifyObj(context),
25 d_context(context),
26 d_statSharedTerms("theory::shared_terms", 0),
27 d_addedSharedTermsSize(context, 0),
28 d_termsToTheories(context),
29 d_alreadyNotifiedMap(context),
30 d_sharedNotify(notify),
31 d_termToNotifyList(context),
32 d_allocatedNLSize(0),
33 d_allocatedNLNext(context, 0),
34 d_EENotify(*this),
35 d_equalityEngine(d_EENotify, context, "SharedTermsDatabase")
36 {
37 StatisticsRegistry::registerStat(&d_statSharedTerms);
38 }
39
40 SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException)
41 {
42 StatisticsRegistry::unregisterStat(&d_statSharedTerms);
43 for (unsigned i = 0; i < d_allocatedNLSize; ++i) {
44 d_allocatedNL[i]->deleteSelf();
45 }
46 }
47
48 void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
49 Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl;
50
51 std::pair<TNode, TNode> search_pair(atom, term);
52 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
53 if (find == d_termsToTheories.end()) {
54 // First time for this term and this atom
55 d_atomsToTerms[atom].push_back(term);
56 d_addedSharedTerms.push_back(atom);
57 d_addedSharedTermsSize = d_addedSharedTermsSize + 1;
58 d_termsToTheories[search_pair] = theories;
59 if (!d_equalityEngine.hasTerm(term)) {
60 d_equalityEngine.addTriggerTerm(term);
61 }
62 } else {
63 Assert(theories != (*find).second);
64 d_termsToTheories[search_pair] = Theory::setUnion(theories, (*find).second);
65 }
66 }
67
68 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::begin(TNode atom) const {
69 Assert(hasSharedTerms(atom));
70 return d_atomsToTerms.find(atom)->second.begin();
71 }
72
73 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::end(TNode atom) const {
74 Assert(hasSharedTerms(atom));
75 return d_atomsToTerms.find(atom)->second.end();
76 }
77
78 bool SharedTermsDatabase::hasSharedTerms(TNode atom) const {
79 return d_atomsToTerms.find(atom) != d_atomsToTerms.end();
80 }
81
82 void SharedTermsDatabase::backtrack() {
83 for (int i = d_addedSharedTerms.size() - 1, i_end = (int)d_addedSharedTermsSize; i >= i_end; -- i) {
84 TNode atom = d_addedSharedTerms[i];
85 shared_terms_list& list = d_atomsToTerms[atom];
86 list.pop_back();
87 if (list.empty()) {
88 d_atomsToTerms.erase(atom);
89 }
90 }
91 d_addedSharedTerms.resize(d_addedSharedTermsSize);
92 }
93
94 Theory::Set SharedTermsDatabase::getTheoriesToNotify(TNode atom, TNode term) const {
95 // Get the theories that share this term from this atom
96 std::pair<TNode, TNode> search_pair(atom, term);
97 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
98 Assert(find != d_termsToTheories.end());
99
100 // Get the theories that were already notified
101 Theory::Set alreadyNotified = 0;
102 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
103 if (theoriesFind != d_alreadyNotifiedMap.end()) {
104 alreadyNotified = (*theoriesFind).second;
105 }
106
107 // Return the ones that haven't been notified yet
108 return Theory::setDifference((*find).second, alreadyNotified);
109 }
110
111
112 Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
113 // Get the theories that were already notified
114 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
115 if (theoriesFind != d_alreadyNotifiedMap.end()) {
116 return (*theoriesFind).second;
117 } else {
118 return 0;
119 }
120 }
121
122
123 SharedTermsDatabase::NotifyList* SharedTermsDatabase::getNewNotifyList()
124 {
125 NotifyList* retval;
126 if (d_allocatedNLSize == d_allocatedNLNext) {
127 retval = new (true) NotifyList(d_context);
128 d_allocatedNL.push_back(retval);
129 d_allocatedNLNext = ++d_allocatedNLSize;
130 }
131 else {
132 retval = d_allocatedNL[d_allocatedNLNext];
133 d_allocatedNLNext = d_allocatedNLNext + 1;
134 }
135 Assert(retval->empty());
136 return retval;
137 }
138
139
140 void SharedTermsDatabase::mergeSharedTerms(TNode a, TNode b)
141 {
142 // Note: a is the new representative
143
144 NotifyList* pnlLeft = NULL;
145 NotifyList* pnlRight = NULL;
146
147 TermToNotifyList::iterator it = d_termToNotifyList.find(a);
148 if (it == d_termToNotifyList.end()) {
149 pnlLeft = getNewNotifyList();
150 d_termToNotifyList[a] = pnlLeft;
151 }
152 else {
153 pnlLeft = (*it).second;
154 }
155 it = d_termToNotifyList.find(b);
156 if (it != d_termToNotifyList.end()) {
157 pnlRight = (*it).second;
158 }
159
160 // Get theories interested in EC for lhs
161 Theory::Set lhsSet = getNotifiedTheories(a);
162 Theory::Set rhsSet = getNotifiedTheories(b);
163 NotifyList::iterator nit;
164 TNode left, right;
165
166 for (TheoryId currentTheory = THEORY_FIRST; currentTheory != THEORY_LAST; ++ currentTheory) {
167
168 if (Theory::setContains(currentTheory, rhsSet)) {
169 right = b;
170 }
171 else if (pnlRight != NULL &&
172 ((nit = pnlRight->end()) != pnlRight->end())) {
173 right = (*nit).second;
174 }
175 else {
176 // no match for right: continue
177 continue;
178 }
179
180 if (Theory::setContains(currentTheory, lhsSet)) {
181 left = a;
182 }
183 else if ((nit = pnlLeft->find(currentTheory)) != pnlLeft->end()) {
184 left = (*nit).second;
185 }
186 else {
187 // no match for left: insert right into left
188 (*pnlLeft)[currentTheory] = right;
189 continue;
190 }
191
192 // New shared equality: notify the client
193
194 // TODO: add propagation of disequalities?
195
196 // Normalize the equality
197 Node equality = left.eqNode(right);
198 Node normalized = Rewriter::rewriteEquality(currentTheory, equality);
199 if (normalized.getKind() != kind::CONST_BOOLEAN || !normalized.getConst<bool>()) {
200 // Notify client
201 d_sharedNotify.notify(normalized, equality, currentTheory);
202 }
203 }
204
205 }
206
207
208 void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
209 Theory::Set alreadyNotified = d_alreadyNotifiedMap[term];
210 Theory::Set newlyNotified = Theory::setDifference(theories, alreadyNotified);
211 if (newlyNotified != 0) {
212 TNode rep = d_equalityEngine.getRepresentative(term);
213 if (rep != term) {
214 TermToNotifyList::iterator it = d_termToNotifyList.find(rep);
215 Assert(it != d_termToNotifyList.end());
216 NotifyList* pnl = (*it).second;
217 for (TheoryId theory = THEORY_FIRST; theory != THEORY_LAST; ++ theory) {
218 if (Theory::setContains(theory, newlyNotified) &&
219 pnl->find(theory) == pnl->end()) {
220 (*pnl)[theory] = term;
221 }
222 }
223 }
224 }
225 d_alreadyNotifiedMap[term] = Theory::setUnion(newlyNotified, alreadyNotified);
226 }
227
228
229 bool SharedTermsDatabase::areEqual(TNode a, TNode b) {
230 return d_equalityEngine.areEqual(a,b);
231 }
232
233
234 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) {
235 return d_equalityEngine.areDisequal(a,b);
236 }
237
238 void SharedTermsDatabase::processSharedLiteral(TNode literal, TNode reason)
239 {
240 bool negated = literal.getKind() == kind::NOT;
241 TNode atom = negated ? literal[0] : literal;
242 if (negated) {
243 Assert(!d_equalityEngine.areDisequal(atom[0], atom[1]));
244 d_equalityEngine.assertEquality(atom, false, reason);
245 // !!! need to send this out
246 }
247 else {
248 Assert(!d_equalityEngine.areEqual(atom[0], atom[1]));
249 d_equalityEngine.assertEquality(atom, true, reason);
250 }
251 }
252
253 static Node mkAnd(const std::vector<TNode>& conjunctions) {
254 Assert(conjunctions.size() > 0);
255
256 std::set<TNode> all;
257 all.insert(conjunctions.begin(), conjunctions.end());
258
259 if (all.size() == 1) {
260 // All the same, or just one
261 return conjunctions[0];
262 }
263
264 NodeBuilder<> conjunction(kind::AND);
265 std::set<TNode>::const_iterator it = all.begin();
266 std::set<TNode>::const_iterator it_end = all.end();
267 while (it != it_end) {
268 conjunction << *it;
269 ++ it;
270 }
271
272 return conjunction;
273 }/* mkAnd() */
274
275
276 Node SharedTermsDatabase::explain(TNode literal)
277 {
278 std::vector<TNode> assumptions;
279 if (literal.getKind() == kind::NOT) {
280 Assert(literal[0].getKind() == kind::EQUAL);
281 d_equalityEngine.explainEquality(literal[0][0], literal[0][1], false, assumptions);
282 } else {
283 Assert(literal.getKind() == kind::EQUAL);
284 d_equalityEngine.explainEquality(literal[0], literal[1], true, assumptions);
285 }
286 return mkAnd(assumptions);
287 }