Minor cleanup related to notifications (#4898)
[cvc5.git] / src / theory / bv / bv_subtheory_core.cpp
1 /********************* */
2 /*! \file bv_subtheory_core.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Liana Hadarean, Aina Niemetz, Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Core solver of the bit-vector theory.
13 **
14 ** Equality-based core sub-solver (CoreSolver) of the bit-vector theory.
15 **/
16
17 #include "theory/bv/bv_subtheory_core.h"
18
19 #include "options/bv_options.h"
20 #include "options/smt_options.h"
21 #include "smt/smt_statistics_registry.h"
22 #include "theory/bv/slicer.h"
23 #include "theory/bv/theory_bv.h"
24 #include "theory/bv/theory_bv_utils.h"
25 #include "theory/ext_theory.h"
26 #include "theory/theory_model.h"
27
28 using namespace std;
29 using namespace CVC4;
30 using namespace CVC4::context;
31 using namespace CVC4::theory;
32 using namespace CVC4::theory::bv;
33 using namespace CVC4::theory::bv::utils;
34
35 CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt)
36 : SubtheorySolver(c, bv),
37 d_notify(*this),
38 d_equalityEngine(d_notify, c, "theory::bv::ee", true),
39 d_slicer(new Slicer()),
40 d_isComplete(c, true),
41 d_lemmaThreshold(16),
42 d_useSlicer(false),
43 d_preregisterCalled(false),
44 d_checkCalled(false),
45 d_extTheory(extt),
46 d_reasons(c)
47 {
48 // The kinds we are treating as function application in congruence
49 d_equalityEngine.addFunctionKind(kind::BITVECTOR_CONCAT, true);
50 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_AND);
51 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_OR);
52 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XOR);
53 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOT);
54 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NAND);
55 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOR);
56 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XNOR);
57 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_COMP);
58 d_equalityEngine.addFunctionKind(kind::BITVECTOR_MULT, true);
59 d_equalityEngine.addFunctionKind(kind::BITVECTOR_PLUS, true);
60 d_equalityEngine.addFunctionKind(kind::BITVECTOR_EXTRACT, true);
61 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SUB);
62 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NEG);
63 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UDIV);
64 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UREM);
65 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SDIV);
66 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SREM);
67 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SMOD);
68 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SHL);
69 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_LSHR);
70 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ASHR);
71 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULT);
72 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULE);
73 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGT);
74 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGE);
75 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLT);
76 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLE);
77 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGT);
78 // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGE);
79 d_equalityEngine.addFunctionKind(kind::BITVECTOR_TO_NAT);
80 d_equalityEngine.addFunctionKind(kind::INT_TO_BITVECTOR);
81 }
82
83 CoreSolver::~CoreSolver() {}
84
85 void CoreSolver::setMasterEqualityEngine(eq::EqualityEngine* eq) {
86 d_equalityEngine.setMasterEqualityEngine(eq);
87 }
88
89 void CoreSolver::enableSlicer() {
90 AlwaysAssert(!d_preregisterCalled);
91 d_useSlicer = true;
92 d_statistics.d_slicerEnabled.setData(true);
93 }
94
95 void CoreSolver::preRegister(TNode node) {
96 d_preregisterCalled = true;
97 if (node.getKind() == kind::EQUAL) {
98 d_equalityEngine.addTriggerEquality(node);
99 if (d_useSlicer) {
100 d_slicer->processEquality(node);
101 AlwaysAssert(!d_checkCalled);
102 }
103 } else {
104 d_equalityEngine.addTerm(node);
105 // Register with the extended theory, for context-dependent simplification.
106 // Notice we do this for registered terms but not internally generated
107 // equivalence classes. The two should roughly cooincide. Since ExtTheory is
108 // being used as a heuristic, it is good enough to be registered here.
109 d_extTheory->registerTerm(node);
110 }
111 }
112
113
114 void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) {
115 bool polarity = literal.getKind() != kind::NOT;
116 TNode atom = polarity ? literal : literal[0];
117 if (atom.getKind() == kind::EQUAL) {
118 d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
119 } else {
120 d_equalityEngine.explainPredicate(atom, polarity, assumptions);
121 }
122 }
123
124 Node CoreSolver::getBaseDecomposition(TNode a) {
125 std::vector<Node> a_decomp;
126 d_slicer->getBaseDecomposition(a, a_decomp);
127 Node new_a = utils::mkConcat(a_decomp);
128 Debug("bv-slicer") << "CoreSolver::getBaseDecomposition " << a <<" => " << new_a << "\n";
129 return new_a;
130 }
131
132 bool CoreSolver::decomposeFact(TNode fact) {
133 Debug("bv-slicer") << "CoreSolver::decomposeFact fact=" << fact << endl;
134 // FIXME: are this the right things to assert?
135 // assert decompositions since the equality engine does not know the semantics of
136 // concat:
137 // a == a_1 concat ... concat a_k
138 // b == b_1 concat ... concat b_k
139 TNode eq = fact.getKind() == kind::NOT? fact[0] : fact;
140
141 TNode a = eq[0];
142 TNode b = eq[1];
143 Node new_a = getBaseDecomposition(a);
144 Node new_b = getBaseDecomposition(b);
145
146 Assert(utils::getSize(new_a) == utils::getSize(new_b)
147 && utils::getSize(new_a) == utils::getSize(a));
148
149 NodeManager* nm = NodeManager::currentNM();
150 Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a);
151 Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b);
152
153 bool ok = true;
154 ok = assertFactToEqualityEngine(a_eq_new_a, utils::mkTrue());
155 if (!ok) return false;
156 ok = assertFactToEqualityEngine(b_eq_new_b, utils::mkTrue());
157 if (!ok) return false;
158 ok = assertFactToEqualityEngine(fact, fact);
159 if (!ok) return false;
160
161 if (fact.getKind() == kind::EQUAL) {
162 // assert the individual equalities as well
163 // a_i == b_i
164 if (new_a.getKind() == kind::BITVECTOR_CONCAT &&
165 new_b.getKind() == kind::BITVECTOR_CONCAT) {
166 Assert(new_a.getNumChildren() == new_b.getNumChildren());
167 for (unsigned i = 0; i < new_a.getNumChildren(); ++i) {
168 Node eq_i = nm->mkNode(kind::EQUAL, new_a[i], new_b[i]);
169 ok = assertFactToEqualityEngine(eq_i, fact);
170 if (!ok) return false;
171 }
172 }
173 }
174 return true;
175 }
176
177 bool CoreSolver::check(Theory::Effort e) {
178 Trace("bitvector::core") << "CoreSolver::check \n";
179
180 d_bv->spendResource(ResourceManager::Resource::TheoryCheckStep);
181
182 d_checkCalled = true;
183 Assert(!d_bv->inConflict());
184 ++(d_statistics.d_numCallstoCheck);
185 bool ok = true;
186 std::vector<Node> core_eqs;
187 TNodeBoolMap seen;
188 // slicer does not deal with cardinality constraints yet
189 if (d_useSlicer) {
190 d_isComplete = false;
191 }
192 while (! done()) {
193 TNode fact = get();
194 if (d_isComplete && !isCompleteForTerm(fact, seen)) {
195 d_isComplete = false;
196 }
197
198 // only reason about equalities
199 if (fact.getKind() == kind::EQUAL || (fact.getKind() == kind::NOT && fact[0].getKind() == kind::EQUAL)) {
200 if (d_useSlicer) {
201 ok = decomposeFact(fact);
202 } else {
203 ok = assertFactToEqualityEngine(fact, fact);
204 }
205 } else {
206 ok = assertFactToEqualityEngine(fact, fact);
207 }
208 if (!ok)
209 return false;
210 }
211
212 if (Theory::fullEffort(e) && isComplete()) {
213 buildModel();
214 }
215
216 return true;
217 }
218
219 void CoreSolver::buildModel()
220 {
221 Debug("bv-core") << "CoreSolver::buildModel() \n";
222 NodeManager* nm = NodeManager::currentNM();
223 d_modelValues.clear();
224 TNodeSet constants;
225 TNodeSet constants_in_eq_engine;
226 // collect constants in equality engine
227 eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(&d_equalityEngine);
228 while (!eqcs_i.isFinished())
229 {
230 TNode repr = *eqcs_i;
231 if (repr.getKind() == kind::CONST_BITVECTOR)
232 {
233 // must check if it's just the constant
234 eq::EqClassIterator it(repr, &d_equalityEngine);
235 if (!(++it).isFinished() || true)
236 {
237 constants.insert(repr);
238 constants_in_eq_engine.insert(repr);
239 }
240 }
241 ++eqcs_i;
242 }
243
244 // build repr to value map
245
246 eqcs_i = eq::EqClassesIterator(&d_equalityEngine);
247 while (!eqcs_i.isFinished())
248 {
249 TNode repr = *eqcs_i;
250 ++eqcs_i;
251
252 if (!repr.isVar() && repr.getKind() != kind::CONST_BITVECTOR
253 && !d_bv->isSharedTerm(repr))
254 {
255 continue;
256 }
257
258 TypeNode type = repr.getType();
259 if (type.isBitVector() && repr.getKind() != kind::CONST_BITVECTOR)
260 {
261 Debug("bv-core-model") << " processing " << repr << "\n";
262 // we need to assign a value for it
263 TypeEnumerator te(type);
264 Node val;
265 do
266 {
267 val = *te;
268 ++te;
269 // Debug("bv-core-model") << " trying value " << val << "\n";
270 // Debug("bv-core-model") << " is in set? " << constants.count(val) <<
271 // "\n"; Debug("bv-core-model") << " enumerator done? " <<
272 // te.isFinished() << "\n";
273 } while (constants.count(val) != 0 && !(te.isFinished()));
274
275 if (te.isFinished() && constants.count(val) != 0)
276 {
277 // if we cannot enumerate anymore values we just return the lemma
278 // stating that at least two of the representatives are equal.
279 std::vector<TNode> representatives;
280 representatives.push_back(repr);
281
282 for (TNodeSet::const_iterator it = constants_in_eq_engine.begin();
283 it != constants_in_eq_engine.end();
284 ++it)
285 {
286 TNode constant = *it;
287 if (utils::getSize(constant) == utils::getSize(repr))
288 {
289 representatives.push_back(constant);
290 }
291 }
292 for (ModelValue::const_iterator it = d_modelValues.begin();
293 it != d_modelValues.end();
294 ++it)
295 {
296 representatives.push_back(it->first);
297 }
298 std::vector<Node> equalities;
299 for (unsigned i = 0; i < representatives.size(); ++i)
300 {
301 for (unsigned j = i + 1; j < representatives.size(); ++j)
302 {
303 TNode a = representatives[i];
304 TNode b = representatives[j];
305 if (a.getKind() == kind::CONST_BITVECTOR
306 && b.getKind() == kind::CONST_BITVECTOR)
307 {
308 Assert(a != b);
309 continue;
310 }
311 if (utils::getSize(a) == utils::getSize(b))
312 {
313 equalities.push_back(nm->mkNode(kind::EQUAL, a, b));
314 }
315 }
316 }
317 // better off letting the SAT solver split on values
318 if (equalities.size() > d_lemmaThreshold)
319 {
320 d_isComplete = false;
321 return;
322 }
323
324 if (equalities.size() == 0)
325 {
326 Debug("bv-core") << " lemma: true (no equalities)" << std::endl;
327 }
328 else
329 {
330 Node lemma = utils::mkOr(equalities);
331 d_bv->lemma(lemma);
332 Debug("bv-core") << " lemma: " << lemma << std::endl;
333 }
334 return;
335 }
336
337 Debug("bv-core-model") << " " << repr << " => " << val << "\n";
338 constants.insert(val);
339 d_modelValues[repr] = val;
340 }
341 }
342 }
343
344 bool CoreSolver::assertFactToEqualityEngine(TNode fact, TNode reason) {
345 // Notify the equality engine
346 if (!d_bv->inConflict() && (!d_bv->wasPropagatedBySubtheory(fact) || d_bv->getPropagatingSubtheory(fact) != SUB_CORE)) {
347 Debug("bv-slicer-eq") << "CoreSolver::assertFactToEqualityEngine fact=" << fact << endl;
348 // Debug("bv-slicer-eq") << " reason=" << reason << endl;
349 bool negated = fact.getKind() == kind::NOT;
350 TNode predicate = negated ? fact[0] : fact;
351 if (predicate.getKind() == kind::EQUAL) {
352 if (negated) {
353 // dis-equality
354 d_equalityEngine.assertEquality(predicate, false, reason);
355 } else {
356 // equality
357 d_equalityEngine.assertEquality(predicate, true, reason);
358 }
359 } else {
360 // Adding predicate if the congruence over it is turned on
361 if (d_equalityEngine.isFunctionKind(predicate.getKind())) {
362 d_equalityEngine.assertPredicate(predicate, !negated, reason);
363 }
364 }
365 }
366
367 // checking for a conflict
368 if (d_bv->inConflict()) {
369 return false;
370 }
371 return true;
372 }
373
374 bool CoreSolver::NotifyClass::eqNotifyTriggerEquality(TNode equality, bool value) {
375 Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl;
376 if (value) {
377 return d_solver.storePropagation(equality);
378 } else {
379 return d_solver.storePropagation(equality.notNode());
380 }
381 }
382
383 bool CoreSolver::NotifyClass::eqNotifyTriggerPredicate(TNode predicate, bool value) {
384 Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" ) << ")" << std::endl;
385 if (value) {
386 return d_solver.storePropagation(predicate);
387 } else {
388 return d_solver.storePropagation(predicate.notNode());
389 }
390 }
391
392 bool CoreSolver::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) {
393 Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << ")" << std::endl;
394 if (value) {
395 return d_solver.storePropagation(t1.eqNode(t2));
396 } else {
397 return d_solver.storePropagation(t1.eqNode(t2).notNode());
398 }
399 }
400
401 void CoreSolver::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) {
402 d_solver.conflict(t1, t2);
403 }
404
405 bool CoreSolver::storePropagation(TNode literal) {
406 return d_bv->storePropagation(literal, SUB_CORE);
407 }
408
409 void CoreSolver::conflict(TNode a, TNode b) {
410 std::vector<TNode> assumptions;
411 d_equalityEngine.explainEquality(a, b, true, assumptions);
412 Node conflict = flattenAnd(assumptions);
413 d_bv->setConflict(conflict);
414 }
415
416 bool CoreSolver::isCompleteForTerm(TNode term, TNodeBoolMap& seen) {
417 if (d_useSlicer)
418 return utils::isCoreTerm(term, seen);
419
420 return utils::isEqualityTerm(term, seen);
421 }
422
423 bool CoreSolver::collectModelInfo(TheoryModel* m, bool fullModel)
424 {
425 if (d_useSlicer) {
426 Unreachable();
427 }
428 if (Debug.isOn("bitvector-model")) {
429 context::CDQueue<Node>::const_iterator it = d_assertionQueue.begin();
430 for (; it!= d_assertionQueue.end(); ++it) {
431 Debug("bitvector-model") << "CoreSolver::collectModelInfo (assert "
432 << *it << ")\n";
433 }
434 }
435 set<Node> termSet;
436 d_bv->computeRelevantTerms(termSet);
437 if (!m->assertEqualityEngine(&d_equalityEngine, &termSet))
438 {
439 return false;
440 }
441 if (isComplete()) {
442 Debug("bitvector-model") << "CoreSolver::collectModelInfo complete.";
443 for (ModelValue::const_iterator it = d_modelValues.begin(); it != d_modelValues.end(); ++it) {
444 Node a = it->first;
445 Node b = it->second;
446 Debug("bitvector-model") << "CoreSolver::collectModelInfo modelValues "
447 << a << " => " << b <<")\n";
448 if (!m->assertEquality(a, b, true))
449 {
450 return false;
451 }
452 }
453 }
454 return true;
455 }
456
457 Node CoreSolver::getModelValue(TNode var) {
458 Debug("bitvector-model") << "CoreSolver::getModelValue (" << var <<")";
459 Assert(isComplete());
460 TNode repr = d_equalityEngine.getRepresentative(var);
461 Node result = Node();
462 if (repr.getKind() == kind::CONST_BITVECTOR) {
463 result = repr;
464 } else if (d_modelValues.find(repr) == d_modelValues.end()) {
465 // it may be a shared term that never gets asserted
466 // result is just Null
467 Assert(d_bv->isSharedTerm(var));
468 } else {
469 result = d_modelValues[repr];
470 }
471 Debug("bitvector-model") << " => " << result <<"\n";
472 return result;
473 }
474
475 CoreSolver::Statistics::Statistics()
476 : d_numCallstoCheck("theory::bv::CoreSolver::NumCallsToCheck", 0)
477 , d_slicerEnabled("theory::bv::CoreSolver::SlicerEnabled", false)
478 {
479 smtStatisticsRegistry()->registerStat(&d_numCallstoCheck);
480 smtStatisticsRegistry()->registerStat(&d_slicerEnabled);
481 }
482 CoreSolver::Statistics::~Statistics() {
483 smtStatisticsRegistry()->unregisterStat(&d_numCallstoCheck);
484 smtStatisticsRegistry()->unregisterStat(&d_slicerEnabled);
485 }