(proof-new) Make shared solver proof producing (#5169)
[cvc5.git] / src / theory / shared_terms_database.cpp
1 /********************* */
2 /*! \file shared_terms_database.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Dejan Jovanovic, Andrew Reynolds, Morgan Deters
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
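** Implementation of the shared terms database. It records, per atom, which
** terms are shared and by which theories, tracks which theories have already
** been notified of each shared term, and uses an equality engine over the
** shared terms to assert and propagate (dis)equalities and to report
** conflicts to the theory engine.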
14 **/
15
16 #include "theory/shared_terms_database.h"
17
18 #include "smt/smt_statistics_registry.h"
19 #include "theory/theory_engine.h"
20
21 using namespace std;
22 using namespace CVC4::theory;
23
24 namespace CVC4 {
25
26 SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
27 context::Context* context,
28 context::UserContext* userContext,
29 ProofNodeManager* pnm)
30 : ContextNotifyObj(context),
31 d_statSharedTerms("theory::shared_terms", 0),
32 d_addedSharedTermsSize(context, 0),
33 d_termsToTheories(context),
34 d_alreadyNotifiedMap(context),
35 d_registeredEqualities(context),
36 d_EENotify(*this),
37 d_theoryEngine(theoryEngine),
38 d_inConflict(context, false),
39 d_conflictPolarity(),
40 d_satContext(context),
41 d_userContext(userContext),
42 d_equalityEngine(nullptr),
43 d_pfee(nullptr),
44 d_pnm(pnm)
45 {
46 smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
47 }
48
49 SharedTermsDatabase::~SharedTermsDatabase()
50 {
51 smtStatisticsRegistry()->unregisterStat(&d_statSharedTerms);
52 }
53
54 void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee)
55 {
56 Assert(ee != nullptr);
57 d_equalityEngine = ee;
58 // if proofs are enabled, make the proof equality engine
59 if (d_pnm != nullptr)
60 {
61 d_pfee.reset(
62 new eq::ProofEqEngine(d_satContext, d_userContext, *ee, d_pnm));
63 }
64 }
65
66 bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi)
67 {
68 esi.d_notify = &d_EENotify;
69 esi.d_name = "SharedTermsDatabase";
70 return true;
71 }
72
73 void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
74 Assert(d_equalityEngine != nullptr);
75 d_registeredEqualities.insert(equality);
76 d_equalityEngine->addTriggerPredicate(equality);
77 checkForConflict();
78 }
79
80 void SharedTermsDatabase::addSharedTerm(TNode atom,
81 TNode term,
82 TheoryIdSet theories)
83 {
84 Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", "
85 << term << ", " << TheoryIdSetUtil::setToString(theories)
86 << ")" << std::endl;
87
88 std::pair<TNode, TNode> search_pair(atom, term);
89 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
90 if (find == d_termsToTheories.end()) {
91 // First time for this term and this atom
92 d_atomsToTerms[atom].push_back(term);
93 d_addedSharedTerms.push_back(atom);
94 d_addedSharedTermsSize = d_addedSharedTermsSize + 1;
95 d_termsToTheories[search_pair] = theories;
96 } else {
97 Assert(theories != (*find).second);
98 d_termsToTheories[search_pair] =
99 TheoryIdSetUtil::setUnion(theories, (*find).second);
100 }
101 }
102
103 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::begin(TNode atom) const {
104 Assert(hasSharedTerms(atom));
105 return d_atomsToTerms.find(atom)->second.begin();
106 }
107
108 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::end(TNode atom) const {
109 Assert(hasSharedTerms(atom));
110 return d_atomsToTerms.find(atom)->second.end();
111 }
112
113 bool SharedTermsDatabase::hasSharedTerms(TNode atom) const {
114 return d_atomsToTerms.find(atom) != d_atomsToTerms.end();
115 }
116
117 void SharedTermsDatabase::backtrack() {
118 for (int i = d_addedSharedTerms.size() - 1, i_end = (int)d_addedSharedTermsSize; i >= i_end; -- i) {
119 TNode atom = d_addedSharedTerms[i];
120 shared_terms_list& list = d_atomsToTerms[atom];
121 list.pop_back();
122 if (list.empty()) {
123 d_atomsToTerms.erase(atom);
124 }
125 }
126 d_addedSharedTerms.resize(d_addedSharedTermsSize);
127 }
128
129 TheoryIdSet SharedTermsDatabase::getTheoriesToNotify(TNode atom,
130 TNode term) const
131 {
132 // Get the theories that share this term from this atom
133 std::pair<TNode, TNode> search_pair(atom, term);
134 SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
135 Assert(find != d_termsToTheories.end());
136
137 // Get the theories that were already notified
138 TheoryIdSet alreadyNotified = 0;
139 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
140 if (theoriesFind != d_alreadyNotifiedMap.end()) {
141 alreadyNotified = (*theoriesFind).second;
142 }
143
144 // Return the ones that haven't been notified yet
145 return TheoryIdSetUtil::setDifference((*find).second, alreadyNotified);
146 }
147
148 TheoryIdSet SharedTermsDatabase::getNotifiedTheories(TNode term) const
149 {
150 // Get the theories that were already notified
151 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
152 if (theoriesFind != d_alreadyNotifiedMap.end()) {
153 return (*theoriesFind).second;
154 } else {
155 return 0;
156 }
157 }
158
159 bool SharedTermsDatabase::propagateSharedEquality(TheoryId theory, TNode a, TNode b, bool value)
160 {
161 Debug("shared-terms-database") << "SharedTermsDatabase::newEquality(" << theory << "," << a << "," << b << ", " << (value ? "true" : "false") << ")" << endl;
162
163 if (d_inConflict) {
164 return false;
165 }
166
167 // Propagate away
168 Node equality = a.eqNode(b);
169 if (value) {
170 d_theoryEngine->assertToTheory(equality, equality, theory, THEORY_BUILTIN);
171 } else {
172 d_theoryEngine->assertToTheory(equality.notNode(), equality.notNode(), theory, THEORY_BUILTIN);
173 }
174
175 // As you were
176 return true;
177 }
178
179 void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories)
180 {
181 // Find out if there are any new theories that were notified about this term
182 TheoryIdSet alreadyNotified = 0;
183 AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
184 if (theoriesFind != d_alreadyNotifiedMap.end()) {
185 alreadyNotified = (*theoriesFind).second;
186 }
187 TheoryIdSet newlyNotified =
188 TheoryIdSetUtil::setDifference(theories, alreadyNotified);
189
190 // If no new theories were notified, we are done
191 if (newlyNotified == 0) {
192 return;
193 }
194
195 Debug("shared-terms-database") << "SharedTermsDatabase::markNotified(" << term << ")" << endl;
196
197 // First update the set of notified theories for this term
198 d_alreadyNotifiedMap[term] =
199 TheoryIdSetUtil::setUnion(newlyNotified, alreadyNotified);
200
201 if (d_equalityEngine == nullptr)
202 {
203 // if we are not assigned an equality engine, there is nothing to do
204 return;
205 }
206
207 // Mark the shared terms in the equality engine
208 theory::TheoryId currentTheory;
209 while ((currentTheory = TheoryIdSetUtil::setPop(newlyNotified))
210 != THEORY_LAST)
211 {
212 d_equalityEngine->addTriggerTerm(term, currentTheory);
213 }
214
215 // Check for any conflits
216 checkForConflict();
217 }
218
219 bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
220 Assert(d_equalityEngine != nullptr);
221 if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
222 {
223 return d_equalityEngine->areEqual(a, b);
224 } else {
225 Assert(d_equalityEngine->hasTerm(a) || a.isConst());
226 Assert(d_equalityEngine->hasTerm(b) || b.isConst());
227 // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
228 return false;
229 }
230 }
231
232 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
233 Assert(d_equalityEngine != nullptr);
234 if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
235 {
236 return d_equalityEngine->areDisequal(a, b, false);
237 } else {
238 Assert(d_equalityEngine->hasTerm(a) || a.isConst());
239 Assert(d_equalityEngine->hasTerm(b) || b.isConst());
240 // one (or both) are in the equality engine
241 return false;
242 }
243 }
244
245 theory::eq::EqualityEngine* SharedTermsDatabase::getEqualityEngine()
246 {
247 return d_equalityEngine;
248 }
249
250 void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
251 {
252 Assert(d_equalityEngine != nullptr);
253 Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
254 // Add it to the equality engine
255 d_equalityEngine->assertEquality(equality, polarity, reason);
256 // Check for conflict
257 checkForConflict();
258 }
259
260 bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
261 if (polarity) {
262 d_theoryEngine->propagate(equality, THEORY_BUILTIN);
263 } else {
264 d_theoryEngine->propagate(equality.notNode(), THEORY_BUILTIN);
265 }
266 return true;
267 }
268
269 void SharedTermsDatabase::checkForConflict()
270 {
271 if (!d_inConflict)
272 {
273 return;
274 }
275 d_inConflict = false;
276 TrustNode trnc;
277 if (d_pfee != nullptr)
278 {
279 Node conflict = d_conflictLHS.eqNode(d_conflictRHS);
280 conflict = d_conflictPolarity ? conflict : conflict.notNode();
281 trnc = d_pfee->assertConflict(conflict);
282 }
283 else
284 {
285 // standard explain
286 std::vector<TNode> assumptions;
287 d_equalityEngine->explainEquality(
288 d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
289 Node conflictNode = NodeManager::currentNM()->mkAnd(assumptions);
290 trnc = TrustNode::mkTrustConflict(conflictNode, nullptr);
291 }
292 d_theoryEngine->conflict(trnc, THEORY_BUILTIN);
293 d_conflictLHS = d_conflictRHS = Node::null();
294 }
295
296 bool SharedTermsDatabase::isKnown(TNode literal) const {
297 Assert(d_equalityEngine != nullptr);
298 bool polarity = literal.getKind() != kind::NOT;
299 TNode equality = polarity ? literal : literal[0];
300 if (polarity) {
301 return d_equalityEngine->areEqual(equality[0], equality[1]);
302 } else {
303 return d_equalityEngine->areDisequal(equality[0], equality[1], false);
304 }
305 }
306
307 theory::TrustNode SharedTermsDatabase::explain(TNode literal) const
308 {
309 if (d_pfee != nullptr)
310 {
311 // use the proof equality engine if it exists
312 return d_pfee->explain(literal);
313 }
314 // otherwise, explain without proofs
315 Node exp = d_equalityEngine->mkExplainLit(literal);
316 // no proof generator
317 return TrustNode::mkTrustPropExp(literal, exp, nullptr);
318 }
319
320 } /* namespace CVC4 */