Properly set up equality engine for BV bitblast solver. (#5905)
[cvc5.git] / src / theory / bv / bv_subtheory_core.cpp
1 /********************* */
2 /*! \file bv_subtheory_core.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Liana Hadarean, Aina Niemetz
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
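** \brief Core bit-vector subtheory solver.
**
** Core bit-vector subtheory solver: reasons about (dis)equalities over
** bit-vector terms via congruence closure in an equality engine, and handles
** the bv2nat / int2bv extended functions.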
15 **/
16
17 #include "theory/bv/bv_subtheory_core.h"
18
19 #include "options/bv_options.h"
20 #include "options/smt_options.h"
21 #include "smt/smt_statistics_registry.h"
22 #include "theory/bv/bv_solver_lazy.h"
23 #include "theory/bv/theory_bv_utils.h"
24 #include "theory/ext_theory.h"
25 #include "theory/theory_model.h"
26
27 using namespace std;
28 using namespace CVC4;
29 using namespace CVC4::context;
30 using namespace CVC4::theory;
31 using namespace CVC4::theory::bv;
32 using namespace CVC4::theory::bv::utils;
33
34 bool CoreSolverExtTheoryCallback::getCurrentSubstitution(
35 int effort,
36 const std::vector<Node>& vars,
37 std::vector<Node>& subs,
38 std::map<Node, std::vector<Node> >& exp)
39 {
40 if (d_equalityEngine == nullptr)
41 {
42 return false;
43 }
44 // get the constant equivalence classes
45 bool retVal = false;
46 for (const Node& n : vars)
47 {
48 if (d_equalityEngine->hasTerm(n))
49 {
50 Node nr = d_equalityEngine->getRepresentative(n);
51 if (nr.isConst())
52 {
53 subs.push_back(nr);
54 exp[n].push_back(n.eqNode(nr));
55 retVal = true;
56 }
57 else
58 {
59 subs.push_back(n);
60 }
61 }
62 else
63 {
64 subs.push_back(n);
65 }
66 }
67 // return true if the substitution is non-trivial
68 return retVal;
69 }
70
71 bool CoreSolverExtTheoryCallback::getReduction(int effort,
72 Node n,
73 Node& nr,
74 bool& satDep)
75 {
76 Trace("bv-ext") << "TheoryBV::checkExt : non-reduced : " << n << std::endl;
77 if (n.getKind() == kind::BITVECTOR_TO_NAT)
78 {
79 nr = utils::eliminateBv2Nat(n);
80 satDep = false;
81 return true;
82 }
83 else if (n.getKind() == kind::INT_TO_BITVECTOR)
84 {
85 nr = utils::eliminateInt2Bv(n);
86 satDep = false;
87 return true;
88 }
89 return false;
90 }
91
92 CoreSolver::CoreSolver(context::Context* c, BVSolverLazy* bv)
93 : SubtheorySolver(c, bv),
94 d_notify(*this),
95 d_isComplete(c, true),
96 d_lemmaThreshold(16),
97 d_preregisterCalled(false),
98 d_checkCalled(false),
99 d_bv(bv),
100 d_extTheoryCb(),
101 d_extTheory(new ExtTheory(d_extTheoryCb,
102 bv->d_bv.getSatContext(),
103 bv->d_bv.getUserContext(),
104 bv->d_bv.getOutputChannel())),
105 d_reasons(c),
106 d_needsLastCallCheck(false),
107 d_extf_range_infer(bv->d_bv.getUserContext()),
108 d_extf_collapse_infer(bv->d_bv.getUserContext())
109 {
110 d_extTheory->addFunctionKind(kind::BITVECTOR_TO_NAT);
111 d_extTheory->addFunctionKind(kind::INT_TO_BITVECTOR);
112 }
113
114 CoreSolver::~CoreSolver() {}
115
116 bool CoreSolver::needsEqualityEngine(EeSetupInfo& esi)
117 {
118 esi.d_notify = &d_notify;
119 esi.d_name = "theory::bv::ee";
120 return true;
121 }
122
123 void CoreSolver::finishInit()
124 {
125 // use the parent's equality engine, which may be the one we allocated above
126 d_equalityEngine = d_bv->d_bv.getEqualityEngine();
127
128 // The kinds we are treating as function application in congruence
129 d_equalityEngine->addFunctionKind(kind::BITVECTOR_CONCAT, true);
130 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_AND);
131 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_OR);
132 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_XOR);
133 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NOT);
134 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NAND);
135 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NOR);
136 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_XNOR);
137 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_COMP);
138 d_equalityEngine->addFunctionKind(kind::BITVECTOR_MULT, true);
139 d_equalityEngine->addFunctionKind(kind::BITVECTOR_PLUS, true);
140 d_equalityEngine->addFunctionKind(kind::BITVECTOR_EXTRACT, true);
141 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SUB);
142 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NEG);
143 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UDIV);
144 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UREM);
145 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SDIV);
146 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SREM);
147 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SMOD);
148 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SHL);
149 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_LSHR);
150 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_ASHR);
151 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_ULT);
152 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_ULE);
153 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UGT);
154 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UGE);
155 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SLT);
156 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SLE);
157 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SGT);
158 // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SGE);
159 d_equalityEngine->addFunctionKind(kind::BITVECTOR_TO_NAT);
160 d_equalityEngine->addFunctionKind(kind::INT_TO_BITVECTOR);
161 }
162
163 void CoreSolver::preRegister(TNode node) {
164 d_preregisterCalled = true;
165 if (node.getKind() == kind::EQUAL) {
166 d_equalityEngine->addTriggerPredicate(node);
167 } else {
168 d_equalityEngine->addTerm(node);
169 // Register with the extended theory, for context-dependent simplification.
170 // Notice we do this for registered terms but not internally generated
171 // equivalence classes. The two should roughly cooincide. Since ExtTheory is
172 // being used as a heuristic, it is good enough to be registered here.
173 d_extTheory->registerTerm(node);
174 }
175 }
176
177
178 void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) {
179 bool polarity = literal.getKind() != kind::NOT;
180 TNode atom = polarity ? literal : literal[0];
181 if (atom.getKind() == kind::EQUAL) {
182 d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions);
183 } else {
184 d_equalityEngine->explainPredicate(atom, polarity, assumptions);
185 }
186 }
187
188 bool CoreSolver::check(Theory::Effort e) {
189 Trace("bitvector::core") << "CoreSolver::check \n";
190
191 d_bv->d_im.spendResource(ResourceManager::Resource::TheoryCheckStep);
192
193 d_checkCalled = true;
194 Assert(!d_bv->inConflict());
195 ++(d_statistics.d_numCallstoCheck);
196 bool ok = true;
197 std::vector<Node> core_eqs;
198 TNodeBoolMap seen;
199 while (! done()) {
200 TNode fact = get();
201 if (d_isComplete && !isCompleteForTerm(fact, seen)) {
202 d_isComplete = false;
203 }
204
205 // only reason about equalities
206 if (fact.getKind() == kind::EQUAL || (fact.getKind() == kind::NOT && fact[0].getKind() == kind::EQUAL)) {
207 ok = assertFactToEqualityEngine(fact, fact);
208 } else {
209 ok = assertFactToEqualityEngine(fact, fact);
210 }
211 if (!ok)
212 return false;
213 }
214
215 if (Theory::fullEffort(e) && isComplete()) {
216 buildModel();
217 }
218
219 return true;
220 }
221
222 void CoreSolver::buildModel()
223 {
224 Debug("bv-core") << "CoreSolver::buildModel() \n";
225 NodeManager* nm = NodeManager::currentNM();
226 d_modelValues.clear();
227 TNodeSet constants;
228 TNodeSet constants_in_eq_engine;
229 // collect constants in equality engine
230 eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine);
231 while (!eqcs_i.isFinished())
232 {
233 TNode repr = *eqcs_i;
234 if (repr.getKind() == kind::CONST_BITVECTOR)
235 {
236 // must check if it's just the constant
237 eq::EqClassIterator it(repr, d_equalityEngine);
238 if (!(++it).isFinished() || true)
239 {
240 constants.insert(repr);
241 constants_in_eq_engine.insert(repr);
242 }
243 }
244 ++eqcs_i;
245 }
246
247 // build repr to value map
248
249 eqcs_i = eq::EqClassesIterator(d_equalityEngine);
250 while (!eqcs_i.isFinished())
251 {
252 TNode repr = *eqcs_i;
253 ++eqcs_i;
254
255 if (!repr.isVar() && repr.getKind() != kind::CONST_BITVECTOR
256 && !d_bv->isSharedTerm(repr))
257 {
258 continue;
259 }
260
261 TypeNode type = repr.getType();
262 if (type.isBitVector() && repr.getKind() != kind::CONST_BITVECTOR)
263 {
264 Debug("bv-core-model") << " processing " << repr << "\n";
265 // we need to assign a value for it
266 TypeEnumerator te(type);
267 Node val;
268 do
269 {
270 val = *te;
271 ++te;
272 // Debug("bv-core-model") << " trying value " << val << "\n";
273 // Debug("bv-core-model") << " is in set? " << constants.count(val) <<
274 // "\n"; Debug("bv-core-model") << " enumerator done? " <<
275 // te.isFinished() << "\n";
276 } while (constants.count(val) != 0 && !(te.isFinished()));
277
278 if (te.isFinished() && constants.count(val) != 0)
279 {
280 // if we cannot enumerate anymore values we just return the lemma
281 // stating that at least two of the representatives are equal.
282 std::vector<TNode> representatives;
283 representatives.push_back(repr);
284
285 for (TNodeSet::const_iterator it = constants_in_eq_engine.begin();
286 it != constants_in_eq_engine.end();
287 ++it)
288 {
289 TNode constant = *it;
290 if (utils::getSize(constant) == utils::getSize(repr))
291 {
292 representatives.push_back(constant);
293 }
294 }
295 for (ModelValue::const_iterator it = d_modelValues.begin();
296 it != d_modelValues.end();
297 ++it)
298 {
299 representatives.push_back(it->first);
300 }
301 std::vector<Node> equalities;
302 for (unsigned i = 0; i < representatives.size(); ++i)
303 {
304 for (unsigned j = i + 1; j < representatives.size(); ++j)
305 {
306 TNode a = representatives[i];
307 TNode b = representatives[j];
308 if (a.getKind() == kind::CONST_BITVECTOR
309 && b.getKind() == kind::CONST_BITVECTOR)
310 {
311 Assert(a != b);
312 continue;
313 }
314 if (utils::getSize(a) == utils::getSize(b))
315 {
316 equalities.push_back(nm->mkNode(kind::EQUAL, a, b));
317 }
318 }
319 }
320 // better off letting the SAT solver split on values
321 if (equalities.size() > d_lemmaThreshold)
322 {
323 d_isComplete = false;
324 return;
325 }
326
327 if (equalities.size() == 0)
328 {
329 Debug("bv-core") << " lemma: true (no equalities)" << std::endl;
330 }
331 else
332 {
333 Node lemma = utils::mkOr(equalities);
334 d_bv->lemma(lemma);
335 Debug("bv-core") << " lemma: " << lemma << std::endl;
336 }
337 return;
338 }
339
340 Debug("bv-core-model") << " " << repr << " => " << val << "\n";
341 constants.insert(val);
342 d_modelValues[repr] = val;
343 }
344 }
345 }
346
347 bool CoreSolver::assertFactToEqualityEngine(TNode fact, TNode reason) {
348 // Notify the equality engine
349 if (!d_bv->inConflict()
350 && (!d_bv->wasPropagatedBySubtheory(fact)
351 || d_bv->getPropagatingSubtheory(fact) != SUB_CORE))
352 {
353 Debug("bv-slicer-eq") << "CoreSolver::assertFactToEqualityEngine fact=" << fact << endl;
354 // Debug("bv-slicer-eq") << " reason=" << reason << endl;
355 bool negated = fact.getKind() == kind::NOT;
356 TNode predicate = negated ? fact[0] : fact;
357 if (predicate.getKind() == kind::EQUAL) {
358 if (negated) {
359 // dis-equality
360 d_equalityEngine->assertEquality(predicate, false, reason);
361 } else {
362 // equality
363 d_equalityEngine->assertEquality(predicate, true, reason);
364 }
365 } else {
366 // Adding predicate if the congruence over it is turned on
367 if (d_equalityEngine->isFunctionKind(predicate.getKind()))
368 {
369 d_equalityEngine->assertPredicate(predicate, !negated, reason);
370 }
371 }
372 }
373
374 // checking for a conflict
375 if (d_bv->inConflict())
376 {
377 return false;
378 }
379 return true;
380 }
381
382 bool CoreSolver::NotifyClass::eqNotifyTriggerPredicate(TNode predicate, bool value) {
383 Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" ) << ")" << std::endl;
384 if (value) {
385 return d_solver.storePropagation(predicate);
386 }
387 return d_solver.storePropagation(predicate.notNode());
388 }
389
390 bool CoreSolver::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) {
391 Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << ")" << std::endl;
392 if (value) {
393 return d_solver.storePropagation(t1.eqNode(t2));
394 } else {
395 return d_solver.storePropagation(t1.eqNode(t2).notNode());
396 }
397 }
398
399 void CoreSolver::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) {
400 d_solver.conflict(t1, t2);
401 }
402
403 bool CoreSolver::storePropagation(TNode literal) {
404 return d_bv->storePropagation(literal, SUB_CORE);
405 }
406
407 void CoreSolver::conflict(TNode a, TNode b) {
408 std::vector<TNode> assumptions;
409 d_equalityEngine->explainEquality(a, b, true, assumptions);
410 Node conflict = flattenAnd(assumptions);
411 d_bv->setConflict(conflict);
412 }
413
414 bool CoreSolver::isCompleteForTerm(TNode term, TNodeBoolMap& seen) {
415 return utils::isEqualityTerm(term, seen);
416 }
417
418 bool CoreSolver::collectModelValues(TheoryModel* m,
419 const std::set<Node>& termSet)
420 {
421 if (Debug.isOn("bitvector-model")) {
422 context::CDQueue<Node>::const_iterator it = d_assertionQueue.begin();
423 for (; it!= d_assertionQueue.end(); ++it) {
424 Debug("bitvector-model")
425 << "CoreSolver::collectModelValues (assert " << *it << ")\n";
426 }
427 }
428 if (isComplete()) {
429 Debug("bitvector-model") << "CoreSolver::collectModelValues complete.";
430 for (ModelValue::const_iterator it = d_modelValues.begin(); it != d_modelValues.end(); ++it) {
431 Node a = it->first;
432 Node b = it->second;
433 Debug("bitvector-model") << "CoreSolver::collectModelValues modelValues "
434 << a << " => " << b << ")\n";
435 if (!m->assertEquality(a, b, true))
436 {
437 return false;
438 }
439 }
440 }
441 return true;
442 }
443
444 Node CoreSolver::getModelValue(TNode var) {
445 Debug("bitvector-model") << "CoreSolver::getModelValue (" << var <<")";
446 Assert(isComplete());
447 TNode repr = d_equalityEngine->getRepresentative(var);
448 Node result = Node();
449 if (repr.getKind() == kind::CONST_BITVECTOR) {
450 result = repr;
451 } else if (d_modelValues.find(repr) == d_modelValues.end()) {
452 // it may be a shared term that never gets asserted
453 // result is just Null
454 Assert(d_bv->isSharedTerm(var));
455 } else {
456 result = d_modelValues[repr];
457 }
458 Debug("bitvector-model") << " => " << result <<"\n";
459 return result;
460 }
461
462 EqualityStatus CoreSolver::getEqualityStatus(TNode a, TNode b)
463 {
464 if (d_equalityEngine->areEqual(a, b))
465 {
466 // The terms are implied to be equal
467 return EQUALITY_TRUE;
468 }
469 if (d_equalityEngine->areDisequal(a, b, false))
470 {
471 // The terms are implied to be dis-equal
472 return EQUALITY_FALSE;
473 }
474 return EQUALITY_UNKNOWN;
475 }
476
477 bool CoreSolver::hasTerm(TNode node) const
478 {
479 return d_equalityEngine->hasTerm(node);
480 }
481 void CoreSolver::addTermToEqualityEngine(TNode node)
482 {
483 d_equalityEngine->addTerm(node);
484 }
485
486 CoreSolver::Statistics::Statistics()
487 : d_numCallstoCheck("theory::bv::CoreSolver::NumCallsToCheck", 0)
488 {
489 smtStatisticsRegistry()->registerStat(&d_numCallstoCheck);
490 }
491 CoreSolver::Statistics::~Statistics() {
492 smtStatisticsRegistry()->unregisterStat(&d_numCallstoCheck);
493 }
494
495 void CoreSolver::checkExtf(Theory::Effort e)
496 {
497 if (e == Theory::EFFORT_LAST_CALL)
498 {
499 std::vector<Node> nred = d_extTheory->getActive();
500 doExtfReductions(nred);
501 }
502 Assert(e == Theory::EFFORT_FULL);
503 // do inferences (adds external lemmas) TODO: this can be improved to add
504 // internal inferences
505 std::vector<Node> nred;
506 if (d_extTheory->doInferences(0, nred))
507 {
508 return;
509 }
510 d_needsLastCallCheck = false;
511 if (!nred.empty())
512 {
513 // other inferences involving bv2nat, int2bv
514 if (options::bvAlgExtf())
515 {
516 if (doExtfInferences(nred))
517 {
518 return;
519 }
520 }
521 if (!options::bvLazyReduceExtf())
522 {
523 if (doExtfReductions(nred))
524 {
525 return;
526 }
527 }
528 else
529 {
530 d_needsLastCallCheck = true;
531 }
532 }
533 }
534
535 bool CoreSolver::needsCheckLastEffort() const { return d_needsLastCallCheck; }
536
537 bool CoreSolver::doExtfInferences(std::vector<Node>& terms)
538 {
539 NodeManager* nm = NodeManager::currentNM();
540 bool sentLemma = false;
541 eq::EqualityEngine* ee = d_equalityEngine;
542 std::map<Node, Node> op_map;
543 for (unsigned j = 0; j < terms.size(); j++)
544 {
545 TNode n = terms[j];
546 Assert(n.getKind() == kind::BITVECTOR_TO_NAT
547 || n.getKind() == kind::INT_TO_BITVECTOR);
548 if (n.getKind() == kind::BITVECTOR_TO_NAT)
549 {
550 // range lemmas
551 if (d_extf_range_infer.find(n) == d_extf_range_infer.end())
552 {
553 d_extf_range_infer.insert(n);
554 unsigned bvs = n[0].getType().getBitVectorSize();
555 Node min = nm->mkConst(Rational(0));
556 Node max = nm->mkConst(Rational(Integer(1).multiplyByPow2(bvs)));
557 Node lem = nm->mkNode(kind::AND,
558 nm->mkNode(kind::GEQ, n, min),
559 nm->mkNode(kind::LT, n, max));
560 Trace("bv-extf-lemma")
561 << "BV extf lemma (range) : " << lem << std::endl;
562 d_bv->d_im.lemma(lem, InferenceId::UNKNOWN);
563 sentLemma = true;
564 }
565 }
566 Node r = (ee && ee->hasTerm(n[0])) ? ee->getRepresentative(n[0]) : n[0];
567 op_map[r] = n;
568 }
569 for (unsigned j = 0; j < terms.size(); j++)
570 {
571 TNode n = terms[j];
572 Node r = (ee && ee->hasTerm(n[0])) ? ee->getRepresentative(n) : n;
573 std::map<Node, Node>::iterator it = op_map.find(r);
574 if (it != op_map.end())
575 {
576 Node parent = it->second;
577 // Node cterm = parent[0]==n ? parent : nm->mkNode( parent.getOperator(),
578 // n );
579 Node cterm = parent[0].eqNode(n);
580 Trace("bv-extf-lemma-debug")
581 << "BV extf collapse based on : " << cterm << std::endl;
582 if (d_extf_collapse_infer.find(cterm) == d_extf_collapse_infer.end())
583 {
584 d_extf_collapse_infer.insert(cterm);
585
586 Node t = n[0];
587 if (t.getType() == parent.getType())
588 {
589 if (n.getKind() == kind::INT_TO_BITVECTOR)
590 {
591 Assert(t.getType().isInteger());
592 // congruent modulo 2^( bv width )
593 unsigned bvs = n.getType().getBitVectorSize();
594 Node coeff = nm->mkConst(Rational(Integer(1).multiplyByPow2(bvs)));
595 Node k = nm->mkSkolem(
596 "int_bv_cong", t.getType(), "for int2bv/bv2nat congruence");
597 t = nm->mkNode(kind::PLUS, t, nm->mkNode(kind::MULT, coeff, k));
598 }
599 Node lem = parent.eqNode(t);
600
601 if (parent[0] != n)
602 {
603 Assert(ee->areEqual(parent[0], n));
604 lem = nm->mkNode(kind::IMPLIES, parent[0].eqNode(n), lem);
605 }
606 // this handles inferences of the form, e.g.:
607 // ((_ int2bv w) (bv2nat x)) == x (if x is bit-width w)
608 // (bv2nat ((_ int2bv w) x)) == x + k*2^w for some k
609 Trace("bv-extf-lemma")
610 << "BV extf lemma (collapse) : " << lem << std::endl;
611 d_bv->d_im.lemma(lem, InferenceId::UNKNOWN);
612 sentLemma = true;
613 }
614 }
615 Trace("bv-extf-lemma-debug")
616 << "BV extf f collapse based on : " << cterm << std::endl;
617 }
618 }
619 return sentLemma;
620 }
621
622 bool CoreSolver::doExtfReductions(std::vector<Node>& terms)
623 {
624 std::vector<Node> nredr;
625 if (d_extTheory->doReductions(0, terms, nredr))
626 {
627 return true;
628 }
629 Assert(nredr.empty());
630 return false;
631 }