Simplify trigger notifications in equality engine (#4921)
[cvc5.git] / src / theory / arith / congruence_manager.cpp
1 /********************* */
2 /*! \file congruence_manager.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Tim King, Dejan Jovanovic, Paul Meng
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
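 ** \brief Connects arithmetic constraints to an equality engine
 **
 ** Asserts arithmetic facts (watched variables being zero or nonzero,
 ** variables equal to constants) to an equality engine, and turns the
 ** equalities reported back by the engine into arithmetic propagations or
 ** conflicts, together with their explanations.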
16 **/
17
18 #include "theory/arith/congruence_manager.h"
19
20 #include "base/output.h"
21 #include "smt/smt_statistics_registry.h"
22 #include "theory/arith/arith_utilities.h"
23 #include "theory/arith/constraint.h"
24 #include "options/arith_options.h"
25
26 namespace CVC4 {
27 namespace theory {
28 namespace arith {
29
30 ArithCongruenceManager::ArithCongruenceManager(
31 context::Context* c,
32 ConstraintDatabase& cd,
33 SetupLiteralCallBack setup,
34 const ArithVariables& avars,
35 RaiseEqualityEngineConflict raiseConflict)
36 : d_inConflict(c),
37 d_raiseConflict(raiseConflict),
38 d_notify(*this),
39 d_keepAlive(c),
40 d_propagatations(c),
41 d_explanationMap(c),
42 d_constraintDatabase(cd),
43 d_setupLiteral(setup),
44 d_avariables(avars),
45 d_ee(nullptr)
46 {
47 }
48
49 ArithCongruenceManager::~ArithCongruenceManager() {}
50
51 bool ArithCongruenceManager::needsEqualityEngine(EeSetupInfo& esi)
52 {
53 esi.d_notify = &d_notify;
54 esi.d_name = "theory::arith::ArithCongruenceManager";
55 return true;
56 }
57
58 void ArithCongruenceManager::finishInit(eq::EqualityEngine* ee)
59 {
60 Assert(ee != nullptr);
61 d_ee = ee;
62 d_ee->addFunctionKind(kind::NONLINEAR_MULT);
63 d_ee->addFunctionKind(kind::EXPONENTIAL);
64 d_ee->addFunctionKind(kind::SINE);
65 d_ee->addFunctionKind(kind::IAND);
66 }
67
68 ArithCongruenceManager::Statistics::Statistics():
69 d_watchedVariables("theory::arith::congruence::watchedVariables", 0),
70 d_watchedVariableIsZero("theory::arith::congruence::watchedVariableIsZero", 0),
71 d_watchedVariableIsNotZero("theory::arith::congruence::watchedVariableIsNotZero", 0),
72 d_equalsConstantCalls("theory::arith::congruence::equalsConstantCalls", 0),
73 d_propagations("theory::arith::congruence::propagations", 0),
74 d_propagateConstraints("theory::arith::congruence::propagateConstraints", 0),
75 d_conflicts("theory::arith::congruence::conflicts", 0)
76 {
77 smtStatisticsRegistry()->registerStat(&d_watchedVariables);
78 smtStatisticsRegistry()->registerStat(&d_watchedVariableIsZero);
79 smtStatisticsRegistry()->registerStat(&d_watchedVariableIsNotZero);
80 smtStatisticsRegistry()->registerStat(&d_equalsConstantCalls);
81 smtStatisticsRegistry()->registerStat(&d_propagations);
82 smtStatisticsRegistry()->registerStat(&d_propagateConstraints);
83 smtStatisticsRegistry()->registerStat(&d_conflicts);
84 }
85
86 ArithCongruenceManager::Statistics::~Statistics(){
87 smtStatisticsRegistry()->unregisterStat(&d_watchedVariables);
88 smtStatisticsRegistry()->unregisterStat(&d_watchedVariableIsZero);
89 smtStatisticsRegistry()->unregisterStat(&d_watchedVariableIsNotZero);
90 smtStatisticsRegistry()->unregisterStat(&d_equalsConstantCalls);
91 smtStatisticsRegistry()->unregisterStat(&d_propagations);
92 smtStatisticsRegistry()->unregisterStat(&d_propagateConstraints);
93 smtStatisticsRegistry()->unregisterStat(&d_conflicts);
94 }
95
96 ArithCongruenceManager::ArithCongruenceNotify::ArithCongruenceNotify(ArithCongruenceManager& acm)
97 : d_acm(acm)
98 {}
99
100 bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerPredicate(
101 TNode predicate, bool value)
102 {
103 Assert(predicate.getKind() == kind::EQUAL);
104 Debug("arith::congruences")
105 << "ArithCongruenceNotify::eqNotifyTriggerPredicate(" << predicate << ", "
106 << (value ? "true" : "false") << ")" << std::endl;
107 if (value) {
108 return d_acm.propagate(predicate);
109 }
110 return d_acm.propagate(predicate.notNode());
111 }
112
113 bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) {
114 Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ", " << (value ? "true" : "false") << ")" << std::endl;
115 if (value) {
116 return d_acm.propagate(t1.eqNode(t2));
117 } else {
118 return d_acm.propagate(t1.eqNode(t2).notNode());
119 }
120 }
121 void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyConstantTermMerge(TNode t1, TNode t2) {
122 Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl;
123 d_acm.propagate(t1.eqNode(t2));
124 }
125 void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyNewClass(TNode t) {
126 }
127 void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyMerge(TNode t1,
128 TNode t2)
129 {
130 }
131 void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) {
132 }
133
134 void ArithCongruenceManager::raiseConflict(Node conflict){
135 Assert(!inConflict());
136 Debug("arith::conflict") << "difference manager conflict " << conflict << std::endl;
137 d_inConflict.raise();
138 d_raiseConflict.raiseEEConflict(conflict);
139 }
140 bool ArithCongruenceManager::inConflict() const{
141 return d_inConflict.isRaised();
142 }
143
144 bool ArithCongruenceManager::hasMorePropagations() const {
145 return !d_propagatations.empty();
146 }
147 const Node ArithCongruenceManager::getNextPropagation() {
148 Assert(hasMorePropagations());
149 Node prop = d_propagatations.front();
150 d_propagatations.dequeue();
151 return prop;
152 }
153
154 bool ArithCongruenceManager::canExplain(TNode n) const {
155 return d_explanationMap.find(n) != d_explanationMap.end();
156 }
157
158 Node ArithCongruenceManager::externalToInternal(TNode n) const{
159 Assert(canExplain(n));
160 ExplainMap::const_iterator iter = d_explanationMap.find(n);
161 size_t pos = (*iter).second;
162 return d_propagatations[pos];
163 }
164
165 void ArithCongruenceManager::pushBack(TNode n){
166 d_explanationMap.insert(n, d_propagatations.size());
167 d_propagatations.enqueue(n);
168
169 ++(d_statistics.d_propagations);
170 }
171 void ArithCongruenceManager::pushBack(TNode n, TNode r){
172 d_explanationMap.insert(r, d_propagatations.size());
173 d_explanationMap.insert(n, d_propagatations.size());
174 d_propagatations.enqueue(n);
175
176 ++(d_statistics.d_propagations);
177 }
178 void ArithCongruenceManager::pushBack(TNode n, TNode r, TNode w){
179 d_explanationMap.insert(w, d_propagatations.size());
180 d_explanationMap.insert(r, d_propagatations.size());
181 d_explanationMap.insert(n, d_propagatations.size());
182 d_propagatations.enqueue(n);
183
184 ++(d_statistics.d_propagations);
185 }
186
187 void ArithCongruenceManager::watchedVariableIsZero(ConstraintCP lb, ConstraintCP ub){
188 Assert(lb->isLowerBound());
189 Assert(ub->isUpperBound());
190 Assert(lb->getVariable() == ub->getVariable());
191 Assert(lb->getValue().sgn() == 0);
192 Assert(ub->getValue().sgn() == 0);
193
194 ++(d_statistics.d_watchedVariableIsZero);
195
196 ArithVar s = lb->getVariable();
197 Node reason = Constraint::externalExplainByAssertions(lb,ub);
198
199 d_keepAlive.push_back(reason);
200 assertionToEqualityEngine(true, s, reason);
201 }
202
203 void ArithCongruenceManager::watchedVariableIsZero(ConstraintCP eq){
204 Assert(eq->isEquality());
205 Assert(eq->getValue().sgn() == 0);
206
207 ++(d_statistics.d_watchedVariableIsZero);
208
209 ArithVar s = eq->getVariable();
210
211 //Explain for conflict is correct as these proofs are generated
212 //and stored eagerly
213 //These will be safe for propagation later as well
214 Node reason = eq->externalExplainByAssertions();
215
216 d_keepAlive.push_back(reason);
217 assertionToEqualityEngine(true, s, reason);
218 }
219
220 void ArithCongruenceManager::watchedVariableCannotBeZero(ConstraintCP c){
221 ++(d_statistics.d_watchedVariableIsNotZero);
222
223 ArithVar s = c->getVariable();
224
225 //Explain for conflict is correct as these proofs are generated and stored eagerly
226 //These will be safe for propagation later as well
227 Node reason = c->externalExplainByAssertions();
228
229 d_keepAlive.push_back(reason);
230 assertionToEqualityEngine(false, s, reason);
231 }
232
233
234 bool ArithCongruenceManager::propagate(TNode x){
235 Debug("arith::congruenceManager")<< "ArithCongruenceManager::propagate("<<x<<")"<<std::endl;
236 if(inConflict()){
237 return true;
238 }
239
240 Node rewritten = Rewriter::rewrite(x);
241
242 //Need to still propagate this!
243 if(rewritten.getKind() == kind::CONST_BOOLEAN){
244 pushBack(x);
245
246 if(rewritten.getConst<bool>()){
247 return true;
248 }else{
249 ++(d_statistics.d_conflicts);
250
251 Node conf = flattenAnd(explainInternal(x));
252 raiseConflict(conf);
253 Debug("arith::congruenceManager") << "rewritten to false "<<x<<" with explanation "<< conf << std::endl;
254 return false;
255 }
256 }
257
258 Assert(rewritten.getKind() != kind::CONST_BOOLEAN);
259
260 ConstraintP c = d_constraintDatabase.lookup(rewritten);
261 if(c == NullConstraint){
262 //using setup as there may not be a corresponding congruence literal yet
263 d_setupLiteral(rewritten);
264 c = d_constraintDatabase.lookup(rewritten);
265 Assert(c != NullConstraint);
266 }
267
268 Debug("arith::congruenceManager")<< "x is "
269 << c->hasProof() << " "
270 << (x == rewritten) << " "
271 << c->canBePropagated() << " "
272 << c->negationHasProof() << std::endl;
273
274 if(c->negationHasProof()){
275 Node expC = explainInternal(x);
276 ConstraintCP negC = c->getNegation();
277 Node neg = negC->externalExplainByAssertions();
278 Node conf = expC.andNode(neg);
279 Node final = flattenAnd(conf);
280
281 ++(d_statistics.d_conflicts);
282 raiseConflict(final);
283 Debug("arith::congruenceManager") << "congruenceManager found a conflict " << final << std::endl;
284 return false;
285 }
286
287 // Cases for propagation
288 // C : c has a proof
289 // S : x == rewritten
290 // P : c can be propagated
291 //
292 // CSP
293 // 000 : propagate x, and mark C it as being explained
294 // 001 : propagate x, and propagate c after marking it as being explained
295 // 01* : propagate x, mark c but do not propagate c
296 // 10* : propagate x, do not mark c and do not propagate c
297 // 11* : drop the constraint, do not propagate x or c
298
299 if(!c->hasProof() && x != rewritten){
300 if(c->assertedToTheTheory()){
301 pushBack(x, rewritten, c->getWitness());
302 }else{
303 pushBack(x, rewritten);
304 }
305
306 c->setEqualityEngineProof();
307 if(c->canBePropagated() && !c->assertedToTheTheory()){
308
309 ++(d_statistics.d_propagateConstraints);
310 c->propagate();
311 }
312 }else if(!c->hasProof() && x == rewritten){
313 if(c->assertedToTheTheory()){
314 pushBack(x, c->getWitness());
315 }else{
316 pushBack(x);
317 }
318 c->setEqualityEngineProof();
319 }else if(c->hasProof() && x != rewritten){
320 if(c->assertedToTheTheory()){
321 pushBack(x);
322 }else{
323 pushBack(x);
324 }
325 }else{
326 Assert(c->hasProof() && x == rewritten);
327 }
328 return true;
329 }
330
331 void ArithCongruenceManager::explain(TNode literal, std::vector<TNode>& assumptions) {
332 if (literal.getKind() != kind::NOT) {
333 d_ee->explainEquality(literal[0], literal[1], true, assumptions);
334 } else {
335 d_ee->explainEquality(literal[0][0], literal[0][1], false, assumptions);
336 }
337 }
338
339 void ArithCongruenceManager::enqueueIntoNB(const std::set<TNode> s, NodeBuilder<>& nb){
340 std::set<TNode>::const_iterator it = s.begin();
341 std::set<TNode>::const_iterator it_end = s.end();
342 for(; it != it_end; ++it) {
343 nb << *it;
344 }
345 }
346
347 Node ArithCongruenceManager::explainInternal(TNode internal){
348 std::vector<TNode> assumptions;
349 explain(internal, assumptions);
350
351 std::set<TNode> assumptionSet;
352 assumptionSet.insert(assumptions.begin(), assumptions.end());
353
354 if (assumptionSet.size() == 1) {
355 // All the same, or just one
356 return assumptions[0];
357 }else{
358 NodeBuilder<> conjunction(kind::AND);
359 enqueueIntoNB(assumptionSet, conjunction);
360 return conjunction;
361 }
362 }
363
364 Node ArithCongruenceManager::explain(TNode external){
365 Trace("arith-ee") << "Ask for explanation of " << external << std::endl;
366 Node internal = externalToInternal(external);
367 Trace("arith-ee") << "...internal = " << internal << std::endl;
368 return explainInternal(internal);
369 }
370
371 void ArithCongruenceManager::explain(TNode external, NodeBuilder<>& out){
372 Node internal = externalToInternal(external);
373
374 std::vector<TNode> assumptions;
375 explain(internal, assumptions);
376 std::set<TNode> assumptionSet;
377 assumptionSet.insert(assumptions.begin(), assumptions.end());
378
379 enqueueIntoNB(assumptionSet, out);
380 }
381
382 void ArithCongruenceManager::addWatchedPair(ArithVar s, TNode x, TNode y){
383 Assert(!isWatchedVariable(s));
384
385 Debug("arith::congruenceManager")
386 << "addWatchedPair(" << s << ", " << x << ", " << y << ")" << std::endl;
387
388
389 ++(d_statistics.d_watchedVariables);
390
391 d_watchedVariables.add(s);
392
393 Node eq = x.eqNode(y);
394 d_watchedEqualities.set(s, eq);
395 }
396
397 void ArithCongruenceManager::assertionToEqualityEngine(bool isEquality, ArithVar s, TNode reason){
398 Assert(isWatchedVariable(s));
399
400 TNode eq = d_watchedEqualities[s];
401 Assert(eq.getKind() == kind::EQUAL);
402
403 Trace("arith-ee") << "Assert " << eq << ", pol " << isEquality << ", reason " << reason << std::endl;
404 if(isEquality){
405 d_ee->assertEquality(eq, true, reason);
406 }else{
407 d_ee->assertEquality(eq, false, reason);
408 }
409 }
410
411 void ArithCongruenceManager::equalsConstant(ConstraintCP c){
412 Assert(c->isEquality());
413
414 ++(d_statistics.d_equalsConstantCalls);
415 Debug("equalsConstant") << "equals constant " << c << std::endl;
416
417 ArithVar x = c->getVariable();
418 Node xAsNode = d_avariables.asNode(x);
419 Node asRational = mkRationalNode(c->getValue().getNoninfinitesimalPart());
420
421
422 //No guarentee this is in normal form!
423 Node eq = xAsNode.eqNode(asRational);
424 d_keepAlive.push_back(eq);
425
426 Node reason = c->externalExplainByAssertions();
427 d_keepAlive.push_back(reason);
428
429 Trace("arith-ee") << "Assert equalsConstant " << eq << ", reason " << reason << std::endl;
430 d_ee->assertEquality(eq, true, reason);
431 }
432
433 void ArithCongruenceManager::equalsConstant(ConstraintCP lb, ConstraintCP ub){
434 Assert(lb->isLowerBound());
435 Assert(ub->isUpperBound());
436 Assert(lb->getVariable() == ub->getVariable());
437
438 ++(d_statistics.d_equalsConstantCalls);
439 Debug("equalsConstant") << "equals constant " << lb << std::endl
440 << ub << std::endl;
441
442 ArithVar x = lb->getVariable();
443 Node reason = Constraint::externalExplainByAssertions(lb,ub);
444
445 Node xAsNode = d_avariables.asNode(x);
446 Node asRational = mkRationalNode(lb->getValue().getNoninfinitesimalPart());
447
448 //No guarentee this is in normal form!
449 Node eq = xAsNode.eqNode(asRational);
450 d_keepAlive.push_back(eq);
451 d_keepAlive.push_back(reason);
452
453 Trace("arith-ee") << "Assert equalsConstant2 " << eq << ", reason " << reason << std::endl;
454 d_ee->assertEquality(eq, true, reason);
455 }
456
457 void ArithCongruenceManager::addSharedTerm(Node x){
458 d_ee->addTriggerTerm(x, THEORY_ARITH);
459 }
460
461 }/* CVC4::theory::arith namespace */
462 }/* CVC4::theory namespace */
463 }/* CVC4 namespace */