Merge pull request #142 from timothy-king/nlAlgMerge
[cvc5.git] / src / theory / arith / congruence_manager.cpp
1 /********************* */
2 /*! \file congruence_manager.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Tim King, Dejan Jovanovic, Morgan Deters, Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2016 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
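 ** \brief Congruence manager for arithmetic: bridges arithmetic constraints
 ** and an equality engine.
 **
 ** Implements ArithCongruenceManager, which asserts the equalities and
 ** disequalities learned for watched arithmetic variables (and equalities to
 ** constants) into its own equality engine, turns literals entailed by that
 ** engine into arithmetic propagations or conflicts, and explains those
 ** literals on request. When enabled, it also runs a normalization-based
 ** equality inference module.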
16 **/
17
18 #include "theory/arith/congruence_manager.h"
19
20 #include "base/output.h"
21 #include "smt/smt_statistics_registry.h"
22 #include "theory/arith/arith_utilities.h"
23 #include "theory/arith/constraint.h"
24 #include "theory/quantifiers/equality_infer.h"
25 #include "options/arith_options.h"
26
27 namespace CVC4 {
28 namespace theory {
29 namespace arith {
30
31 ArithCongruenceManager::ArithCongruenceManager(context::Context* c, ConstraintDatabase& cd, SetupLiteralCallBack setup, const ArithVariables& avars, RaiseEqualityEngineConflict raiseConflict)
32 : d_inConflict(c),
33 d_raiseConflict(raiseConflict),
34 d_notify(*this),
35 d_eq_infer(NULL),
36 d_eqi_counter(0,c),
37 d_keepAlive(c),
38 d_propagatations(c),
39 d_explanationMap(c),
40 d_constraintDatabase(cd),
41 d_setupLiteral(setup),
42 d_avariables(avars),
43 d_ee(d_notify, c, "theory::arith::ArithCongruenceManager", true)
44 {
45 d_ee.addFunctionKind(kind::NONLINEAR_MULT);
46 //module to infer additional equalities based on normalization
47 if( options::sNormInferEq() ){
48 d_eq_infer = new quantifiers::EqualityInference(c, true);
49 d_true = NodeManager::currentNM()->mkConst( true );
50 }
51 }
52
53 ArithCongruenceManager::~ArithCongruenceManager() {
54 if( d_eq_infer ){
55 delete d_eq_infer;
56 }
57 }
58
59 ArithCongruenceManager::Statistics::Statistics():
60 d_watchedVariables("theory::arith::congruence::watchedVariables", 0),
61 d_watchedVariableIsZero("theory::arith::congruence::watchedVariableIsZero", 0),
62 d_watchedVariableIsNotZero("theory::arith::congruence::watchedVariableIsNotZero", 0),
63 d_equalsConstantCalls("theory::arith::congruence::equalsConstantCalls", 0),
64 d_propagations("theory::arith::congruence::propagations", 0),
65 d_propagateConstraints("theory::arith::congruence::propagateConstraints", 0),
66 d_conflicts("theory::arith::congruence::conflicts", 0)
67 {
68 smtStatisticsRegistry()->registerStat(&d_watchedVariables);
69 smtStatisticsRegistry()->registerStat(&d_watchedVariableIsZero);
70 smtStatisticsRegistry()->registerStat(&d_watchedVariableIsNotZero);
71 smtStatisticsRegistry()->registerStat(&d_equalsConstantCalls);
72 smtStatisticsRegistry()->registerStat(&d_propagations);
73 smtStatisticsRegistry()->registerStat(&d_propagateConstraints);
74 smtStatisticsRegistry()->registerStat(&d_conflicts);
75 }
76
77 ArithCongruenceManager::Statistics::~Statistics(){
78 smtStatisticsRegistry()->unregisterStat(&d_watchedVariables);
79 smtStatisticsRegistry()->unregisterStat(&d_watchedVariableIsZero);
80 smtStatisticsRegistry()->unregisterStat(&d_watchedVariableIsNotZero);
81 smtStatisticsRegistry()->unregisterStat(&d_equalsConstantCalls);
82 smtStatisticsRegistry()->unregisterStat(&d_propagations);
83 smtStatisticsRegistry()->unregisterStat(&d_propagateConstraints);
84 smtStatisticsRegistry()->unregisterStat(&d_conflicts);
85 }
86
87 ArithCongruenceManager::ArithCongruenceNotify::ArithCongruenceNotify(ArithCongruenceManager& acm)
88 : d_acm(acm)
89 {}
90
91 bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerEquality(TNode equality, bool value) {
92 Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl;
93 if (value) {
94 return d_acm.propagate(equality);
95 } else {
96 return d_acm.propagate(equality.notNode());
97 }
98 }
99 bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerPredicate(TNode predicate, bool value) {
100 Unreachable();
101 }
102
103 bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) {
104 Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ", " << (value ? "true" : "false") << ")" << std::endl;
105 if (value) {
106 return d_acm.propagate(t1.eqNode(t2));
107 } else {
108 return d_acm.propagate(t1.eqNode(t2).notNode());
109 }
110 }
111 void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyConstantTermMerge(TNode t1, TNode t2) {
112 Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl;
113 d_acm.propagate(t1.eqNode(t2));
114 }
115 void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyNewClass(TNode t) {
116 d_acm.eqNotifyNewClass(t);
117 }
118 void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyPreMerge(TNode t1, TNode t2) {
119 }
120 void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyPostMerge(TNode t1, TNode t2) {
121 d_acm.eqNotifyPostMerge(t1,t2);
122 }
123 void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) {
124 }
125
126 void ArithCongruenceManager::raiseConflict(Node conflict){
127 Assert(!inConflict());
128 Debug("arith::conflict") << "difference manager conflict " << conflict << std::endl;
129 d_inConflict.raise();
130 d_raiseConflict.raiseEEConflict(conflict);
131 }
132 bool ArithCongruenceManager::inConflict() const{
133 return d_inConflict.isRaised();
134 }
135
136 bool ArithCongruenceManager::hasMorePropagations() const {
137 return !d_propagatations.empty();
138 }
139 const Node ArithCongruenceManager::getNextPropagation() {
140 Assert(hasMorePropagations());
141 Node prop = d_propagatations.front();
142 d_propagatations.dequeue();
143 return prop;
144 }
145
146 bool ArithCongruenceManager::canExplain(TNode n) const {
147 return d_explanationMap.find(n) != d_explanationMap.end();
148 }
149
150 void ArithCongruenceManager::setMasterEqualityEngine(eq::EqualityEngine* eq) {
151 d_ee.setMasterEqualityEngine(eq);
152 }
153
154 Node ArithCongruenceManager::externalToInternal(TNode n) const{
155 Assert(canExplain(n));
156 ExplainMap::const_iterator iter = d_explanationMap.find(n);
157 size_t pos = (*iter).second;
158 return d_propagatations[pos];
159 }
160
161 void ArithCongruenceManager::pushBack(TNode n){
162 d_explanationMap.insert(n, d_propagatations.size());
163 d_propagatations.enqueue(n);
164
165 ++(d_statistics.d_propagations);
166 }
167 void ArithCongruenceManager::pushBack(TNode n, TNode r){
168 d_explanationMap.insert(r, d_propagatations.size());
169 d_explanationMap.insert(n, d_propagatations.size());
170 d_propagatations.enqueue(n);
171
172 ++(d_statistics.d_propagations);
173 }
174 void ArithCongruenceManager::pushBack(TNode n, TNode r, TNode w){
175 d_explanationMap.insert(w, d_propagatations.size());
176 d_explanationMap.insert(r, d_propagatations.size());
177 d_explanationMap.insert(n, d_propagatations.size());
178 d_propagatations.enqueue(n);
179
180 ++(d_statistics.d_propagations);
181 }
182
183 void ArithCongruenceManager::watchedVariableIsZero(ConstraintCP lb, ConstraintCP ub){
184 Assert(lb->isLowerBound());
185 Assert(ub->isUpperBound());
186 Assert(lb->getVariable() == ub->getVariable());
187 Assert(lb->getValue().sgn() == 0);
188 Assert(ub->getValue().sgn() == 0);
189
190 ++(d_statistics.d_watchedVariableIsZero);
191
192 ArithVar s = lb->getVariable();
193 Node reason = Constraint::externalExplainByAssertions(lb,ub);
194
195 d_keepAlive.push_back(reason);
196 assertionToEqualityEngine(true, s, reason);
197 }
198
199 void ArithCongruenceManager::watchedVariableIsZero(ConstraintCP eq){
200 Assert(eq->isEquality());
201 Assert(eq->getValue().sgn() == 0);
202
203 ++(d_statistics.d_watchedVariableIsZero);
204
205 ArithVar s = eq->getVariable();
206
207 //Explain for conflict is correct as these proofs are generated
208 //and stored eagerly
209 //These will be safe for propagation later as well
210 Node reason = eq->externalExplainByAssertions();
211
212 d_keepAlive.push_back(reason);
213 assertionToEqualityEngine(true, s, reason);
214 }
215
216 void ArithCongruenceManager::watchedVariableCannotBeZero(ConstraintCP c){
217 ++(d_statistics.d_watchedVariableIsNotZero);
218
219 ArithVar s = c->getVariable();
220
221 //Explain for conflict is correct as these proofs are generated and stored eagerly
222 //These will be safe for propagation later as well
223 Node reason = c->externalExplainByAssertions();
224
225 d_keepAlive.push_back(reason);
226 assertionToEqualityEngine(false, s, reason);
227 }
228
229
/**
 * Handles a literal x (an equality or a negated equality, as produced by the
 * notify callbacks) that the equality engine has entailed, and turns it into
 * arithmetic-side propagations.
 *
 * Returns false iff this call raised a conflict. That happens either when x
 * rewrites to false, or when the negation of x's arithmetic constraint
 * already has a proof. Returns true otherwise, including when the manager
 * was already in conflict on entry, in which case x is ignored.
 */
bool ArithCongruenceManager::propagate(TNode x){
  Debug("arith::congruenceManager")<< "ArithCongruenceManager::propagate("<<x<<")"<<std::endl;
  if(inConflict()){
    return true;
  }

  Node rewritten = Rewriter::rewrite(x);

  //Need to still propagate this!
  // x rewrote to a Boolean constant. x is queued either way. If the constant
  // is false, the equality engine's explanation of x is a conflict.
  if(rewritten.getKind() == kind::CONST_BOOLEAN){
    pushBack(x);

    if(rewritten.getConst<bool>()){
      return true;
    }else{
      ++(d_statistics.d_conflicts);

      Node conf = flattenAnd(explainInternal(x));
      raiseConflict(conf);
      Debug("arith::congruenceManager") << "rewritten to false "<<x<<" with explanation "<< conf << std::endl;
      return false;
    }
  }

  Assert(rewritten.getKind() != kind::CONST_BOOLEAN);

  // Find the arithmetic constraint for the rewritten literal.
  ConstraintP c = d_constraintDatabase.lookup(rewritten);
  if(c == NullConstraint){
    //using setup as there may not be a corresponding congruence literal yet
    d_setupLiteral(rewritten);
    c = d_constraintDatabase.lookup(rewritten);
    Assert(c != NullConstraint);
  }

  Debug("arith::congruenceManager")<< "x is "
                                   << c->hasProof() << " "
                                   << (x == rewritten) << " "
                                   << c->canBePropagated() << " "
                                   << c->negationHasProof() << std::endl;

  // The negation of c is already proven. The equality engine's explanation
  // of x, conjoined with the assertion-level explanation of the negation,
  // is a conflict.
  if(c->negationHasProof()){
    Node expC = explainInternal(x);
    ConstraintCP negC = c->getNegation();
    Node neg = negC->externalExplainByAssertions();
    Node conf = expC.andNode(neg);
    Node final = flattenAnd(conf);

    ++(d_statistics.d_conflicts);
    raiseConflict(final);
    Debug("arith::congruenceManager") << "congruenceManager found a conflict " << final << std::endl;
    return false;
  }

  // Cases for propagation
  // C : c has a proof
  // S : x == rewritten
  // P : c can be propagated
  //
  // CSP
  // 000 : propagate x, and mark c as being explained
  // 001 : propagate x, and propagate c after marking it as being explained
  // 01* : propagate x, mark c but do not propagate c
  // 10* : propagate x, do not mark c and do not propagate c
  // 11* : drop the constraint, do not propagate x or c
  //
  // "Propagate x" means one of the pushBack overloads. They queue x and also
  // map rewritten (when it differs from x) and c's witness (when c was
  // asserted to the theory) to x's queue entry for later explanation.
  // "Mark c" means c->setEqualityEngineProof(). c itself is only propagated
  // when it has not been asserted to the theory.

  if(!c->hasProof() && x != rewritten){
    if(c->assertedToTheTheory()){
      pushBack(x, rewritten, c->getWitness());
    }else{
      pushBack(x, rewritten);
    }

    c->setEqualityEngineProof();
    if(c->canBePropagated() && !c->assertedToTheTheory()){

      ++(d_statistics.d_propagateConstraints);
      c->propagate();
    }
  }else if(!c->hasProof() && x == rewritten){
    if(c->assertedToTheTheory()){
      pushBack(x, c->getWitness());
    }else{
      pushBack(x);
    }
    c->setEqualityEngineProof();
  }else if(c->hasProof() && x != rewritten){
    // NOTE(review): both arms are identical. Possibly one was meant to also
    // record the witness, as the other cases do -- confirm intent.
    if(c->assertedToTheTheory()){
      pushBack(x);
    }else{
      pushBack(x);
    }
  }else{
    Assert(c->hasProof() && x == rewritten);
  }
  return true;
}
326
327 void ArithCongruenceManager::explain(TNode literal, std::vector<TNode>& assumptions) {
328 if (literal.getKind() != kind::NOT) {
329 d_ee.explainEquality(literal[0], literal[1], true, assumptions);
330 } else {
331 d_ee.explainEquality(literal[0][0], literal[0][1], false, assumptions);
332 }
333 }
334
335 void ArithCongruenceManager::enqueueIntoNB(const std::set<TNode> s, NodeBuilder<>& nb){
336 std::set<TNode>::const_iterator it = s.begin();
337 std::set<TNode>::const_iterator it_end = s.end();
338 for(; it != it_end; ++it) {
339 nb << *it;
340 }
341 }
342
343 Node ArithCongruenceManager::explainInternal(TNode internal){
344 std::vector<TNode> assumptions;
345 explain(internal, assumptions);
346
347 std::set<TNode> assumptionSet;
348 assumptionSet.insert(assumptions.begin(), assumptions.end());
349
350 if (assumptionSet.size() == 1) {
351 // All the same, or just one
352 return assumptions[0];
353 }else{
354 NodeBuilder<> conjunction(kind::AND);
355 enqueueIntoNB(assumptionSet, conjunction);
356 return conjunction;
357 }
358 }
359
360 void ArithCongruenceManager::eqNotifyNewClass(TNode t) {
361 if( d_eq_infer ){
362 d_eq_infer->eqNotifyNewClass(t);
363 fixpointInfer();
364 }
365 }
366 void ArithCongruenceManager::eqNotifyPostMerge(TNode t1, TNode t2) {
367 if( d_eq_infer ){
368 d_eq_infer->eqNotifyMerge(t1, t2);
369 fixpointInfer();
370 }
371 }
372
373 Node ArithCongruenceManager::explain(TNode external){
374 Trace("arith-ee") << "Ask for explanation of " << external << std::endl;
375 Node internal = externalToInternal(external);
376 Trace("arith-ee") << "...internal = " << internal << std::endl;
377 return explainInternal(internal);
378 }
379
380 void ArithCongruenceManager::explain(TNode external, NodeBuilder<>& out){
381 Node internal = externalToInternal(external);
382
383 std::vector<TNode> assumptions;
384 explain(internal, assumptions);
385 std::set<TNode> assumptionSet;
386 assumptionSet.insert(assumptions.begin(), assumptions.end());
387
388 enqueueIntoNB(assumptionSet, out);
389 }
390
391 void ArithCongruenceManager::addWatchedPair(ArithVar s, TNode x, TNode y){
392 Assert(!isWatchedVariable(s));
393
394 Debug("arith::congruenceManager")
395 << "addWatchedPair(" << s << ", " << x << ", " << y << ")" << std::endl;
396
397
398 ++(d_statistics.d_watchedVariables);
399
400 d_watchedVariables.add(s);
401
402 Node eq = x.eqNode(y);
403 d_watchedEqualities.set(s, eq);
404 }
405
406 void ArithCongruenceManager::assertionToEqualityEngine(bool isEquality, ArithVar s, TNode reason){
407 Assert(isWatchedVariable(s));
408
409 TNode eq = d_watchedEqualities[s];
410 Assert(eq.getKind() == kind::EQUAL);
411
412 Trace("arith-ee") << "Assert " << eq << ", pol " << isEquality << ", reason " << reason << std::endl;
413 if(isEquality){
414 d_ee.assertEquality(eq, true, reason);
415 }else{
416 d_ee.assertEquality(eq, false, reason);
417 }
418 }
419
420 void ArithCongruenceManager::equalsConstant(ConstraintCP c){
421 Assert(c->isEquality());
422
423 ++(d_statistics.d_equalsConstantCalls);
424 Debug("equalsConstant") << "equals constant " << c << std::endl;
425
426 ArithVar x = c->getVariable();
427 Node xAsNode = d_avariables.asNode(x);
428 Node asRational = mkRationalNode(c->getValue().getNoninfinitesimalPart());
429
430
431 //No guarentee this is in normal form!
432 Node eq = xAsNode.eqNode(asRational);
433 d_keepAlive.push_back(eq);
434
435 Node reason = c->externalExplainByAssertions();
436 d_keepAlive.push_back(reason);
437
438 Trace("arith-ee") << "Assert equalsConstant " << eq << ", reason " << reason << std::endl;
439 d_ee.assertEquality(eq, true, reason);
440 }
441
442 void ArithCongruenceManager::equalsConstant(ConstraintCP lb, ConstraintCP ub){
443 Assert(lb->isLowerBound());
444 Assert(ub->isUpperBound());
445 Assert(lb->getVariable() == ub->getVariable());
446
447 ++(d_statistics.d_equalsConstantCalls);
448 Debug("equalsConstant") << "equals constant " << lb << std::endl
449 << ub << std::endl;
450
451 ArithVar x = lb->getVariable();
452 Node reason = Constraint::externalExplainByAssertions(lb,ub);
453
454 Node xAsNode = d_avariables.asNode(x);
455 Node asRational = mkRationalNode(lb->getValue().getNoninfinitesimalPart());
456
457 //No guarentee this is in normal form!
458 Node eq = xAsNode.eqNode(asRational);
459 d_keepAlive.push_back(eq);
460 d_keepAlive.push_back(reason);
461
462 Trace("arith-ee") << "Assert equalsConstant2 " << eq << ", reason " << reason << std::endl;
463 d_ee.assertEquality(eq, true, reason);
464 }
465
466 void ArithCongruenceManager::addSharedTerm(Node x){
467 d_ee.addTriggerTerm(x, THEORY_ARITH);
468 }
469
/**
 * Drains the pending merges of the normalization-based equality inference
 * module into the equality engine. That module (d_eq_infer) is only created
 * when the sNormInferEq option is set; otherwise this does nothing.
 *
 * Pending merges are processed in order, starting at d_eqi_counter (a
 * context-dependent counter), until none remain or a conflict has been
 * raised. Each merge (eq[0] = eq[1]) that the equality engine does not
 * already consider equal is asserted as true. Its justification is the
 * inference's explanation "regressed" into d_ee's own assumptions.
 *
 * Returns true iff the manager is in conflict on exit.
 */
bool ArithCongruenceManager::fixpointInfer() {
  if( d_eq_infer ){
    while(! inConflict() && d_eqi_counter.get()<d_eq_infer->getNumPendingMerges() ) {
      Trace("snorm-infer-eq-debug") << "Processing " << d_eqi_counter.get() << " / " << d_eq_infer->getNumPendingMerges() << std::endl;
      Node eq = d_eq_infer->getPendingMerge( d_eqi_counter.get() );
      Trace("snorm-infer-eq") << "ArithCongruenceManager : Infer by normalization : " << eq << std::endl;
      if( !d_ee.areEqual( eq[0], eq[1] ) ){
        Node eq_exp = d_eq_infer->getPendingMergeExplanation( d_eqi_counter.get() );
        Trace("snorm-infer-eq") << " explanation : " << eq_exp << std::endl;
        //regress explanation
        // Replace each literal of the inference's explanation (a conjunction,
        // a single equality, or true) with the assumptions d_ee uses to
        // explain it. NOTE(review): explain() treats any non-NOT literal as
        // an equality, so AND children are presumably equalities or negated
        // equalities -- confirm.
        std::vector<TNode> assumptions;
        if( eq_exp.getKind()==kind::AND ){
          for( unsigned i=0; i<eq_exp.getNumChildren(); i++ ){
            explain( eq_exp[i], assumptions );
          }
        }else if( eq_exp.getKind()==kind::EQUAL ){
          explain( eq_exp, assumptions );
        }else{
          //eq_exp should be true
          Assert( eq_exp==d_true );
        }
        // Conjoin the distinct assumptions: true if there are none, the
        // single assumption if there is one, otherwise an AND of them.
        Node req_exp;
        if( assumptions.empty() ){
          req_exp = d_true;
        }else{
          std::set<TNode> assumptionSet;
          assumptionSet.insert(assumptions.begin(), assumptions.end());
          if( assumptionSet.size()==1 ){
            req_exp = assumptions[0];
          }else{
            NodeBuilder<> conjunction(kind::AND);
            enqueueIntoNB(assumptionSet, conjunction);
            req_exp = conjunction;
          }
        }
        Trace("snorm-infer-eq") << " regressed explanation : " << req_exp << std::endl;
        d_ee.assertEquality( eq, true, req_exp );
        // Presumably kept alive because the equality engine does not own
        // reason nodes -- confirm.
        d_keepAlive.push_back( req_exp );
      }else{
        Trace("snorm-infer-eq") << "...already equal." << std::endl;
      }
      // Advance past this merge whether or not it had to be asserted.
      d_eqi_counter = d_eqi_counter.get() + 1;
    }
  }
  return inConflict();
}
516
517 }/* CVC4::theory::arith namespace */
518 }/* CVC4::theory namespace */
519 }/* CVC4 namespace */