Updating the copyright headers and scripts.
[cvc5.git] / src / theory / shared_terms_database.cpp
1 /********************* */
2 /*! \file shared_terms_database.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Dejan Jovanovic, Morgan Deters, Clark Barrett
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2016 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** [[ Add lengthier description here ]]
13 ** \todo document this file
14 **/
15
16 #include "theory/shared_terms_database.h"
17
18 #include "smt/smt_statistics_registry.h"
19 #include "theory/theory_engine.h"
20
21 using namespace std;
22 using namespace CVC4;
23 using namespace theory;
24
25 SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine, context::Context* context)
26 : ContextNotifyObj(context)
27 , d_statSharedTerms("theory::shared_terms", 0)
28 , d_addedSharedTermsSize(context, 0)
29 , d_termsToTheories(context)
30 , d_alreadyNotifiedMap(context)
31 , d_registeredEqualities(context)
32 , d_EENotify(*this)
33 , d_equalityEngine(d_EENotify, context, "SharedTermsDatabase", true)
34 , d_theoryEngine(theoryEngine)
35 , d_inConflict(context, false)
36 {
37 smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
38 }
39
40 SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException)
41 {
42 smtStatisticsRegistry()->unregisterStat(&d_statSharedTerms);
43 }
44
45 void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
46 d_registeredEqualities.insert(equality);
47 d_equalityEngine.addTriggerEquality(equality);
48 checkForConflict();
49 }
50
51
52 void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
53 Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl;
54
55 std::pair<TNode, TNode> search_pair(atom, term);
56 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
57 if (find == d_termsToTheories.end()) {
58 // First time for this term and this atom
59 d_atomsToTerms[atom].push_back(term);
60 d_addedSharedTerms.push_back(atom);
61 d_addedSharedTermsSize = d_addedSharedTermsSize + 1;
62 d_termsToTheories[search_pair] = theories;
63 } else {
64 Assert(theories != (*find).second);
65 d_termsToTheories[search_pair] = Theory::setUnion(theories, (*find).second);
66 }
67 }
68
69 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::begin(TNode atom) const {
70 Assert(hasSharedTerms(atom));
71 return d_atomsToTerms.find(atom)->second.begin();
72 }
73
74 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::end(TNode atom) const {
75 Assert(hasSharedTerms(atom));
76 return d_atomsToTerms.find(atom)->second.end();
77 }
78
79 bool SharedTermsDatabase::hasSharedTerms(TNode atom) const {
80 return d_atomsToTerms.find(atom) != d_atomsToTerms.end();
81 }
82
83 void SharedTermsDatabase::backtrack() {
84 for (int i = d_addedSharedTerms.size() - 1, i_end = (int)d_addedSharedTermsSize; i >= i_end; -- i) {
85 TNode atom = d_addedSharedTerms[i];
86 shared_terms_list& list = d_atomsToTerms[atom];
87 list.pop_back();
88 if (list.empty()) {
89 d_atomsToTerms.erase(atom);
90 }
91 }
92 d_addedSharedTerms.resize(d_addedSharedTermsSize);
93 }
94
95 Theory::Set SharedTermsDatabase::getTheoriesToNotify(TNode atom, TNode term) const {
96 // Get the theories that share this term from this atom
97 std::pair<TNode, TNode> search_pair(atom, term);
98 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
99 Assert(find != d_termsToTheories.end());
100
101 // Get the theories that were already notified
102 Theory::Set alreadyNotified = 0;
103 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
104 if (theoriesFind != d_alreadyNotifiedMap.end()) {
105 alreadyNotified = (*theoriesFind).second;
106 }
107
108 // Return the ones that haven't been notified yet
109 return Theory::setDifference((*find).second, alreadyNotified);
110 }
111
112
113 Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
114 // Get the theories that were already notified
115 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
116 if (theoriesFind != d_alreadyNotifiedMap.end()) {
117 return (*theoriesFind).second;
118 } else {
119 return 0;
120 }
121 }
122
123 bool SharedTermsDatabase::propagateSharedEquality(TheoryId theory, TNode a, TNode b, bool value)
124 {
125 Debug("shared-terms-database") << "SharedTermsDatabase::newEquality(" << theory << "," << a << "," << b << ", " << (value ? "true" : "false") << ")" << endl;
126
127 if (d_inConflict) {
128 return false;
129 }
130
131 // Propagate away
132 Node equality = a.eqNode(b);
133 if (value) {
134 d_theoryEngine->assertToTheory(equality, equality, theory, THEORY_BUILTIN);
135 } else {
136 d_theoryEngine->assertToTheory(equality.notNode(), equality.notNode(), theory, THEORY_BUILTIN);
137 }
138
139 // As you were
140 return true;
141 }
142
143 void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
144
145 // Find out if there are any new theories that were notified about this term
146 Theory::Set alreadyNotified = 0;
147 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
148 if (theoriesFind != d_alreadyNotifiedMap.end()) {
149 alreadyNotified = (*theoriesFind).second;
150 }
151 Theory::Set newlyNotified = Theory::setDifference(theories, alreadyNotified);
152
153 // If no new theories were notified, we are done
154 if (newlyNotified == 0) {
155 return;
156 }
157
158 Debug("shared-terms-database") << "SharedTermsDatabase::markNotified(" << term << ")" << endl;
159
160 // First update the set of notified theories for this term
161 d_alreadyNotifiedMap[term] = Theory::setUnion(newlyNotified, alreadyNotified);
162
163 // Mark the shared terms in the equality engine
164 theory::TheoryId currentTheory;
165 while ((currentTheory = Theory::setPop(newlyNotified)) != THEORY_LAST) {
166 d_equalityEngine.addTriggerTerm(term, currentTheory);
167 }
168
169 // Check for any conflits
170 checkForConflict();
171 }
172
173 bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
174 if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
175 return d_equalityEngine.areEqual(a,b);
176 } else {
177 Assert(d_equalityEngine.hasTerm(a) || a.isConst());
178 Assert(d_equalityEngine.hasTerm(b) || b.isConst());
179 // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
180 return false;
181 }
182 }
183
184 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
185 if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
186 return d_equalityEngine.areDisequal(a,b,false);
187 } else {
188 Assert(d_equalityEngine.hasTerm(a) || a.isConst());
189 Assert(d_equalityEngine.hasTerm(b) || b.isConst());
190 // one (or both) are in the equality engine
191 return false;
192 }
193 }
194
195 void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
196 {
197 Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
198 // Add it to the equality engine
199 d_equalityEngine.assertEquality(equality, polarity, reason);
200 // Check for conflict
201 checkForConflict();
202 }
203
204 bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
205 if (polarity) {
206 d_theoryEngine->propagate(equality, THEORY_BUILTIN);
207 } else {
208 d_theoryEngine->propagate(equality.notNode(), THEORY_BUILTIN);
209 }
210 return true;
211 }
212
213 static Node mkAnd(const std::vector<TNode>& conjunctions) {
214 Assert(conjunctions.size() > 0);
215
216 std::set<TNode> all;
217 all.insert(conjunctions.begin(), conjunctions.end());
218
219 if (all.size() == 1) {
220 // All the same, or just one
221 return conjunctions[0];
222 }
223
224 NodeBuilder<> conjunction(kind::AND);
225 std::set<TNode>::const_iterator it = all.begin();
226 std::set<TNode>::const_iterator it_end = all.end();
227 while (it != it_end) {
228 conjunction << *it;
229 ++ it;
230 }
231
232 return conjunction;
233 }
234
235 void SharedTermsDatabase::checkForConflict() {
236 if (d_inConflict) {
237 d_inConflict = false;
238 std::vector<TNode> assumptions;
239 d_equalityEngine.explainEquality(d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
240 Node conflict = mkAnd(assumptions);
241 d_theoryEngine->conflict(conflict, THEORY_BUILTIN);
242 d_conflictLHS = d_conflictRHS = Node::null();
243 }
244 }
245
246 bool SharedTermsDatabase::isKnown(TNode literal) const {
247 bool polarity = literal.getKind() != kind::NOT;
248 TNode equality = polarity ? literal : literal[0];
249 if (polarity) {
250 return d_equalityEngine.areEqual(equality[0], equality[1]);
251 } else {
252 return d_equalityEngine.areDisequal(equality[0], equality[1], false);
253 }
254 }
255
256 Node SharedTermsDatabase::explain(TNode literal) const {
257 bool polarity = literal.getKind() != kind::NOT;
258 TNode atom = polarity ? literal : literal[0];
259 Assert(atom.getKind() == kind::EQUAL);
260 std::vector<TNode> assumptions;
261 d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
262 return mkAnd(assumptions);
263 }