Merge pull request #38 from mdeters/uf-kinds
[cvc5.git] / src / theory / shared_terms_database.cpp
1 /********************* */
2 /*! \file shared_terms_database.cpp
3 ** \verbatim
4 ** Original author: Dejan Jovanovic
5 ** Major contributors: Morgan Deters
6 ** Minor contributors (to current version): Andrew Reynolds, Clark Barrett
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2013 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** [[ Add lengthier description here ]]
13 ** \todo document this file
14 **/
15
16
17 #include "theory/shared_terms_database.h"
18 #include "theory/theory_engine.h"
19
20 using namespace std;
21 using namespace CVC4;
22 using namespace theory;
23
24 SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine, context::Context* context)
25 : ContextNotifyObj(context)
26 , d_statSharedTerms("theory::shared_terms", 0)
27 , d_addedSharedTermsSize(context, 0)
28 , d_termsToTheories(context)
29 , d_alreadyNotifiedMap(context)
30 , d_registeredEqualities(context)
31 , d_EENotify(*this)
32 , d_equalityEngine(d_EENotify, context, "SharedTermsDatabase")
33 , d_theoryEngine(theoryEngine)
34 , d_inConflict(context, false)
35 {
36 StatisticsRegistry::registerStat(&d_statSharedTerms);
37 }
38
39 SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException)
40 {
41 StatisticsRegistry::unregisterStat(&d_statSharedTerms);
42 }
43
44 void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
45 d_registeredEqualities.insert(equality);
46 d_equalityEngine.addTriggerEquality(equality);
47 checkForConflict();
48 }
49
50
51 void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
52 Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl;
53
54 std::pair<TNode, TNode> search_pair(atom, term);
55 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
56 if (find == d_termsToTheories.end()) {
57 // First time for this term and this atom
58 d_atomsToTerms[atom].push_back(term);
59 d_addedSharedTerms.push_back(atom);
60 d_addedSharedTermsSize = d_addedSharedTermsSize + 1;
61 d_termsToTheories[search_pair] = theories;
62 } else {
63 Assert(theories != (*find).second);
64 d_termsToTheories[search_pair] = Theory::setUnion(theories, (*find).second);
65 }
66 }
67
68 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::begin(TNode atom) const {
69 Assert(hasSharedTerms(atom));
70 return d_atomsToTerms.find(atom)->second.begin();
71 }
72
73 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::end(TNode atom) const {
74 Assert(hasSharedTerms(atom));
75 return d_atomsToTerms.find(atom)->second.end();
76 }
77
78 bool SharedTermsDatabase::hasSharedTerms(TNode atom) const {
79 return d_atomsToTerms.find(atom) != d_atomsToTerms.end();
80 }
81
82 void SharedTermsDatabase::backtrack() {
83 for (int i = d_addedSharedTerms.size() - 1, i_end = (int)d_addedSharedTermsSize; i >= i_end; -- i) {
84 TNode atom = d_addedSharedTerms[i];
85 shared_terms_list& list = d_atomsToTerms[atom];
86 list.pop_back();
87 if (list.empty()) {
88 d_atomsToTerms.erase(atom);
89 }
90 }
91 d_addedSharedTerms.resize(d_addedSharedTermsSize);
92 }
93
94 Theory::Set SharedTermsDatabase::getTheoriesToNotify(TNode atom, TNode term) const {
95 // Get the theories that share this term from this atom
96 std::pair<TNode, TNode> search_pair(atom, term);
97 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
98 Assert(find != d_termsToTheories.end());
99
100 // Get the theories that were already notified
101 Theory::Set alreadyNotified = 0;
102 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
103 if (theoriesFind != d_alreadyNotifiedMap.end()) {
104 alreadyNotified = (*theoriesFind).second;
105 }
106
107 // Return the ones that haven't been notified yet
108 return Theory::setDifference((*find).second, alreadyNotified);
109 }
110
111
112 Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
113 // Get the theories that were already notified
114 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
115 if (theoriesFind != d_alreadyNotifiedMap.end()) {
116 return (*theoriesFind).second;
117 } else {
118 return 0;
119 }
120 }
121
122 bool SharedTermsDatabase::propagateSharedEquality(TheoryId theory, TNode a, TNode b, bool value)
123 {
124 Debug("shared-terms-database") << "SharedTermsDatabase::newEquality(" << theory << a << "," << b << ", " << (value ? "true" : "false") << ")" << endl;
125
126 if (d_inConflict) {
127 return false;
128 }
129
130 // Propagate away
131 Node equality = a.eqNode(b);
132 if (value) {
133 d_theoryEngine->assertToTheory(equality, equality, theory, THEORY_BUILTIN);
134 } else {
135 d_theoryEngine->assertToTheory(equality.notNode(), equality.notNode(), theory, THEORY_BUILTIN);
136 }
137
138 // As you were
139 return true;
140 }
141
142 void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
143
144 // Find out if there are any new theories that were notified about this term
145 Theory::Set alreadyNotified = 0;
146 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
147 if (theoriesFind != d_alreadyNotifiedMap.end()) {
148 alreadyNotified = (*theoriesFind).second;
149 }
150 Theory::Set newlyNotified = Theory::setDifference(theories, alreadyNotified);
151
152 // If no new theories were notified, we are done
153 if (newlyNotified == 0) {
154 return;
155 }
156
157 Debug("shared-terms-database") << "SharedTermsDatabase::markNotified(" << term << ")" << endl;
158
159 // First update the set of notified theories for this term
160 d_alreadyNotifiedMap[term] = Theory::setUnion(newlyNotified, alreadyNotified);
161
162 // Mark the shared terms in the equality engine
163 theory::TheoryId currentTheory;
164 while ((currentTheory = Theory::setPop(newlyNotified)) != THEORY_LAST) {
165 d_equalityEngine.addTriggerTerm(term, currentTheory);
166 }
167
168 // Check for any conflits
169 checkForConflict();
170 }
171
172 bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
173 if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
174 return d_equalityEngine.areEqual(a,b);
175 } else {
176 Assert(d_equalityEngine.hasTerm(a) || a.isConst());
177 Assert(d_equalityEngine.hasTerm(b) || b.isConst());
178 // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
179 return false;
180 }
181 }
182
183 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
184 if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
185 return d_equalityEngine.areDisequal(a,b,false);
186 } else {
187 Assert(d_equalityEngine.hasTerm(a) || a.isConst());
188 Assert(d_equalityEngine.hasTerm(b) || b.isConst());
189 // one (or both) are in the equality engine
190 return false;
191 }
192 }
193
194 void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
195 {
196 Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
197 // Add it to the equality engine
198 d_equalityEngine.assertEquality(equality, polarity, reason);
199 // Check for conflict
200 checkForConflict();
201 }
202
203 bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
204 if (polarity) {
205 d_theoryEngine->propagate(equality, THEORY_BUILTIN);
206 } else {
207 d_theoryEngine->propagate(equality.notNode(), THEORY_BUILTIN);
208 }
209 return true;
210 }
211
212 static Node mkAnd(const std::vector<TNode>& conjunctions) {
213 Assert(conjunctions.size() > 0);
214
215 std::set<TNode> all;
216 all.insert(conjunctions.begin(), conjunctions.end());
217
218 if (all.size() == 1) {
219 // All the same, or just one
220 return conjunctions[0];
221 }
222
223 NodeBuilder<> conjunction(kind::AND);
224 std::set<TNode>::const_iterator it = all.begin();
225 std::set<TNode>::const_iterator it_end = all.end();
226 while (it != it_end) {
227 conjunction << *it;
228 ++ it;
229 }
230
231 return conjunction;
232 }
233
234 void SharedTermsDatabase::checkForConflict() {
235 if (d_inConflict) {
236 d_inConflict = false;
237 std::vector<TNode> assumptions;
238 d_equalityEngine.explainEquality(d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
239 Node conflict = mkAnd(assumptions);
240 d_theoryEngine->conflict(conflict, THEORY_BUILTIN);
241 d_conflictLHS = d_conflictRHS = Node::null();
242 }
243 }
244
245 bool SharedTermsDatabase::isKnown(TNode literal) const {
246 bool polarity = literal.getKind() != kind::NOT;
247 TNode equality = polarity ? literal : literal[0];
248 if (polarity) {
249 return d_equalityEngine.areEqual(equality[0], equality[1]);
250 } else {
251 return d_equalityEngine.areDisequal(equality[0], equality[1], false);
252 }
253 }
254
255 Node SharedTermsDatabase::explain(TNode literal) const {
256 bool polarity = literal.getKind() != kind::NOT;
257 TNode atom = polarity ? literal : literal[0];
258 Assert(atom.getKind() == kind::EQUAL);
259 std::vector<TNode> assumptions;
260 d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
261 return mkAnd(assumptions);
262 }