Bv2int fail on demand
[cvc5.git] / src / preprocessing / passes / miplib_trick.cpp
1 /********************* */
2 /*! \file miplib_trick.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Mathias Preiner, Tim King, Andres Noetzli
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief The MIPLIB trick preprocessing pass
13 **
14 **/
15
16 #include "preprocessing/passes/miplib_trick.h"
17
18 #include <vector>
19
20 #include "expr/node_self_iterator.h"
21 #include "options/arith_options.h"
22 #include "smt/smt_statistics_registry.h"
23 #include "smt_util/boolean_simplification.h"
24 #include "theory/booleans/circuit_propagator.h"
25 #include "theory/theory_model.h"
26
27 namespace CVC4 {
28 namespace preprocessing {
29 namespace passes {
30
31 using namespace CVC4::theory;
32
33 namespace {
34
35 /**
36 * Remove conjuncts in toRemove from conjunction n. Return # of removed
37 * conjuncts.
38 */
39 size_t removeFromConjunction(Node& n,
40 const std::unordered_set<unsigned long>& toRemove)
41 {
42 Assert(n.getKind() == kind::AND);
43 Node trueNode = NodeManager::currentNM()->mkConst(true);
44 size_t removals = 0;
45 for (Node::iterator j = n.begin(); j != n.end(); ++j)
46 {
47 size_t subremovals = 0;
48 Node sub = *j;
49 if (toRemove.find(sub.getId()) != toRemove.end()
50 || (sub.getKind() == kind::AND
51 && (subremovals = removeFromConjunction(sub, toRemove)) > 0))
52 {
53 NodeBuilder<> b(kind::AND);
54 b.append(n.begin(), j);
55 if (subremovals > 0)
56 {
57 removals += subremovals;
58 b << sub;
59 }
60 else
61 {
62 ++removals;
63 }
64 for (++j; j != n.end(); ++j)
65 {
66 if (toRemove.find((*j).getId()) != toRemove.end())
67 {
68 ++removals;
69 }
70 else if ((*j).getKind() == kind::AND)
71 {
72 sub = *j;
73 if ((subremovals = removeFromConjunction(sub, toRemove)) > 0)
74 {
75 removals += subremovals;
76 b << sub;
77 }
78 else
79 {
80 b << *j;
81 }
82 }
83 else
84 {
85 b << *j;
86 }
87 }
88 if (b.getNumChildren() == 0)
89 {
90 n = trueNode;
91 b.clear();
92 }
93 else if (b.getNumChildren() == 1)
94 {
95 n = b[0];
96 b.clear();
97 }
98 else
99 {
100 n = b;
101 }
102 n = Rewriter::rewrite(n);
103 return removals;
104 }
105 }
106
107 Assert(removals == 0);
108 return 0;
109 }
110
111 /**
112 * Trace nodes back to their assertions using CircuitPropagator's
113 * BackEdgesMap.
114 */
115 void traceBackToAssertions(booleans::CircuitPropagator* propagator,
116 const std::vector<Node>& nodes,
117 std::vector<TNode>& assertions)
118 {
119 const booleans::CircuitPropagator::BackEdgesMap& backEdges =
120 propagator->getBackEdges();
121 for (vector<Node>::const_iterator i = nodes.begin(); i != nodes.end(); ++i)
122 {
123 booleans::CircuitPropagator::BackEdgesMap::const_iterator j =
124 backEdges.find(*i);
125 // term must appear in map, otherwise how did we get here?!
126 Assert(j != backEdges.end());
127 // if term maps to empty, that means it's a top-level assertion
128 if (!(*j).second.empty())
129 {
130 traceBackToAssertions(propagator, (*j).second, assertions);
131 }
132 else
133 {
134 assertions.push_back(*i);
135 }
136 }
137 }
138
139 } // namespace
140
141 MipLibTrick::MipLibTrick(PreprocessingPassContext* preprocContext)
142 : PreprocessingPass(preprocContext, "miplib-trick")
143 {
144 if (!options::incrementalSolving())
145 {
146 NodeManager::currentNM()->subscribeEvents(this);
147 }
148 }
149
150 MipLibTrick::~MipLibTrick()
151 {
152 if (!options::incrementalSolving())
153 {
154 NodeManager::currentNM()->unsubscribeEvents(this);
155 }
156 }
157
158 void MipLibTrick::nmNotifyNewVar(TNode n, uint32_t flags)
159 {
160 if (n.getType().isBoolean())
161 {
162 d_boolVars.push_back(n);
163 }
164 }
165
166 void MipLibTrick::nmNotifyNewSkolem(TNode n,
167 const std::string& comment,
168 uint32_t flags)
169 {
170 if (n.getType().isBoolean())
171 {
172 d_boolVars.push_back(n);
173 }
174 }
175
176 PreprocessingPassResult MipLibTrick::applyInternal(
177 AssertionPipeline* assertionsToPreprocess)
178 {
179 Assert(assertionsToPreprocess->getRealAssertionsEnd()
180 == assertionsToPreprocess->size());
181 Assert(!options::incrementalSolving());
182
183 context::Context fakeContext;
184 TheoryEngine* te = d_preprocContext->getTheoryEngine();
185 booleans::CircuitPropagator* propagator =
186 d_preprocContext->getCircuitPropagator();
187 const booleans::CircuitPropagator::BackEdgesMap& backEdges =
188 propagator->getBackEdges();
189 unordered_set<unsigned long> removeAssertions;
190 SubstitutionMap& top_level_substs =
191 d_preprocContext->getTopLevelSubstitutions();
192
193 NodeManager* nm = NodeManager::currentNM();
194 Node zero = nm->mkConst(Rational(0)), one = nm->mkConst(Rational(1));
195 Node trueNode = nm->mkConst(true);
196
197 unordered_map<TNode, Node, TNodeHashFunction> intVars;
198 for (TNode v0 : d_boolVars)
199 {
200 if (propagator->isAssigned(v0))
201 {
202 Debug("miplib") << "ineligible: " << v0 << " because assigned "
203 << propagator->getAssignment(v0) << endl;
204 continue;
205 }
206
207 vector<TNode> assertions;
208 booleans::CircuitPropagator::BackEdgesMap::const_iterator j0 =
209 backEdges.find(v0);
210 // if not in back edges map, the bool var is unconstrained, showing up in no
211 // assertions. if maps to an empty vector, that means the bool var was
212 // asserted itself.
213 if (j0 != backEdges.end())
214 {
215 if (!(*j0).second.empty())
216 {
217 traceBackToAssertions(propagator, (*j0).second, assertions);
218 }
219 else
220 {
221 assertions.push_back(v0);
222 }
223 }
224 Debug("miplib") << "for " << v0 << endl;
225 bool eligible = true;
226 map<pair<Node, Node>, uint64_t> marks;
227 map<pair<Node, Node>, vector<Rational> > coef;
228 map<pair<Node, Node>, vector<Rational> > checks;
229 map<pair<Node, Node>, vector<TNode> > asserts;
230 for (vector<TNode>::const_iterator j1 = assertions.begin();
231 j1 != assertions.end();
232 ++j1)
233 {
234 Debug("miplib") << " found: " << *j1 << endl;
235 if ((*j1).getKind() != kind::IMPLIES)
236 {
237 eligible = false;
238 Debug("miplib") << " -- INELIGIBLE -- (not =>)" << endl;
239 break;
240 }
241 Node conj = BooleanSimplification::simplify((*j1)[0]);
242 if (conj.getKind() == kind::AND && conj.getNumChildren() > 6)
243 {
244 eligible = false;
245 Debug("miplib") << " -- INELIGIBLE -- (N-ary /\\ too big)" << endl;
246 break;
247 }
248 if (conj.getKind() != kind::AND && !conj.isVar()
249 && !(conj.getKind() == kind::NOT && conj[0].isVar()))
250 {
251 eligible = false;
252 Debug("miplib") << " -- INELIGIBLE -- (not /\\ or literal)" << endl;
253 break;
254 }
255 if ((*j1)[1].getKind() != kind::EQUAL
256 || !(((*j1)[1][0].isVar()
257 && (*j1)[1][1].getKind() == kind::CONST_RATIONAL)
258 || ((*j1)[1][0].getKind() == kind::CONST_RATIONAL
259 && (*j1)[1][1].isVar())))
260 {
261 eligible = false;
262 Debug("miplib") << " -- INELIGIBLE -- (=> (and X X) X)" << endl;
263 break;
264 }
265 if (conj.getKind() == kind::AND)
266 {
267 vector<Node> posv;
268 bool found_x = false;
269 map<TNode, bool> neg;
270 for (Node::iterator ii = conj.begin(); ii != conj.end(); ++ii)
271 {
272 if ((*ii).isVar())
273 {
274 posv.push_back(*ii);
275 neg[*ii] = false;
276 found_x = found_x || v0 == *ii;
277 }
278 else if ((*ii).getKind() == kind::NOT && (*ii)[0].isVar())
279 {
280 posv.push_back((*ii)[0]);
281 neg[(*ii)[0]] = true;
282 found_x = found_x || v0 == (*ii)[0];
283 }
284 else
285 {
286 eligible = false;
287 Debug("miplib")
288 << " -- INELIGIBLE -- (non-var: " << *ii << ")" << endl;
289 break;
290 }
291 if (propagator->isAssigned(posv.back()))
292 {
293 eligible = false;
294 Debug("miplib") << " -- INELIGIBLE -- (" << posv.back()
295 << " asserted)" << endl;
296 break;
297 }
298 }
299 if (!eligible)
300 {
301 break;
302 }
303 if (!found_x)
304 {
305 eligible = false;
306 Debug("miplib") << " --INELIGIBLE -- (couldn't find " << v0
307 << " in conjunction)" << endl;
308 break;
309 }
310 sort(posv.begin(), posv.end());
311 const Node pos = NodeManager::currentNM()->mkNode(kind::AND, posv);
312 const TNode var = ((*j1)[1][0].getKind() == kind::CONST_RATIONAL)
313 ? (*j1)[1][1]
314 : (*j1)[1][0];
315 const pair<Node, Node> pos_var(pos, var);
316 const Rational& constant =
317 ((*j1)[1][0].getKind() == kind::CONST_RATIONAL)
318 ? (*j1)[1][0].getConst<Rational>()
319 : (*j1)[1][1].getConst<Rational>();
320 uint64_t mark = 0;
321 unsigned countneg = 0, thepos = 0;
322 for (unsigned ii = 0; ii < pos.getNumChildren(); ++ii)
323 {
324 if (neg[pos[ii]])
325 {
326 ++countneg;
327 }
328 else
329 {
330 thepos = ii;
331 mark |= (0x1 << ii);
332 }
333 }
334 if ((marks[pos_var] & (1lu << mark)) != 0)
335 {
336 eligible = false;
337 Debug("miplib") << " -- INELIGIBLE -- (remarked)" << endl;
338 break;
339 }
340 Debug("miplib") << "mark is " << mark << " -- " << (1lu << mark)
341 << endl;
342 marks[pos_var] |= (1lu << mark);
343 Debug("miplib") << "marks[" << pos << "," << var << "] now "
344 << marks[pos_var] << endl;
345 if (countneg == pos.getNumChildren())
346 {
347 if (constant != 0)
348 {
349 eligible = false;
350 Debug("miplib") << " -- INELIGIBLE -- (nonzero constant)" << endl;
351 break;
352 }
353 }
354 else if (countneg == pos.getNumChildren() - 1)
355 {
356 Assert(coef[pos_var].size() <= 6 && thepos < 6);
357 if (coef[pos_var].size() <= thepos)
358 {
359 coef[pos_var].resize(thepos + 1);
360 }
361 coef[pos_var][thepos] = constant;
362 }
363 else
364 {
365 if (checks[pos_var].size() <= mark)
366 {
367 checks[pos_var].resize(mark + 1);
368 }
369 checks[pos_var][mark] = constant;
370 }
371 asserts[pos_var].push_back(*j1);
372 }
373 else
374 {
375 TNode x = conj;
376 if (x != v0 && x != (v0).notNode())
377 {
378 eligible = false;
379 Debug("miplib")
380 << " -- INELIGIBLE -- (x not present where I expect it)" << endl;
381 break;
382 }
383 const bool xneg = (x.getKind() == kind::NOT);
384 x = xneg ? x[0] : x;
385 Debug("miplib") << " x:" << x << " " << xneg << endl;
386 const TNode var = ((*j1)[1][0].getKind() == kind::CONST_RATIONAL)
387 ? (*j1)[1][1]
388 : (*j1)[1][0];
389 const pair<Node, Node> x_var(x, var);
390 const Rational& constant =
391 ((*j1)[1][0].getKind() == kind::CONST_RATIONAL)
392 ? (*j1)[1][0].getConst<Rational>()
393 : (*j1)[1][1].getConst<Rational>();
394 unsigned mark = (xneg ? 0 : 1);
395 if ((marks[x_var] & (1u << mark)) != 0)
396 {
397 eligible = false;
398 Debug("miplib") << " -- INELIGIBLE -- (remarked)" << endl;
399 break;
400 }
401 marks[x_var] |= (1u << mark);
402 if (xneg)
403 {
404 if (constant != 0)
405 {
406 eligible = false;
407 Debug("miplib") << " -- INELIGIBLE -- (nonzero constant)" << endl;
408 break;
409 }
410 }
411 else
412 {
413 Assert(coef[x_var].size() <= 6);
414 coef[x_var].resize(6);
415 coef[x_var][0] = constant;
416 }
417 asserts[x_var].push_back(*j1);
418 }
419 }
420 if (eligible)
421 {
422 for (map<pair<Node, Node>, uint64_t>::const_iterator j = marks.begin();
423 j != marks.end();
424 ++j)
425 {
426 const TNode pos = (*j).first.first;
427 const TNode var = (*j).first.second;
428 const pair<Node, Node>& pos_var = (*j).first;
429 const uint64_t mark = (*j).second;
430 const unsigned numVars =
431 pos.getKind() == kind::AND ? pos.getNumChildren() : 1;
432 uint64_t expected = (uint64_t(1) << (1 << numVars)) - 1;
433 expected = (expected == 0) ? -1 : expected; // fix for overflow
434 Debug("miplib") << "[" << pos << "] => " << hex << mark << " expect "
435 << expected << dec << endl;
436 Assert(pos.getKind() == kind::AND || pos.isVar());
437 if (mark != expected)
438 {
439 Debug("miplib") << " -- INELIGIBLE " << pos
440 << " -- (insufficiently marked, got " << mark
441 << " for " << numVars << " vars, expected "
442 << expected << endl;
443 }
444 else
445 {
446 if (mark != 3)
447 { // exclude single-var case; nothing to check there
448 uint64_t sz = (uint64_t(1) << checks[pos_var].size()) - 1;
449 sz = (sz == 0) ? -1 : sz; // fix for overflow
450 Assert(sz == mark) << "expected size " << sz << " == mark " << mark;
451 for (size_t k = 0; k < checks[pos_var].size(); ++k)
452 {
453 if ((k & (k - 1)) != 0)
454 {
455 Rational sum = 0;
456 Debug("miplib") << k << " => " << checks[pos_var][k] << endl;
457 for (size_t v1 = 1, kk = k; kk != 0; ++v1, kk >>= 1)
458 {
459 if ((kk & 0x1) == 1)
460 {
461 Assert(pos.getKind() == kind::AND);
462 Debug("miplib")
463 << "var " << v1 << " : " << pos[v1 - 1]
464 << " coef:" << coef[pos_var][v1 - 1] << endl;
465 sum += coef[pos_var][v1 - 1];
466 }
467 }
468 Debug("miplib") << "checkSum is " << sum << " input says "
469 << checks[pos_var][k] << endl;
470 if (sum != checks[pos_var][k])
471 {
472 eligible = false;
473 Debug("miplib") << " -- INELIGIBLE " << pos
474 << " -- (nonlinear combination)" << endl;
475 break;
476 }
477 }
478 else
479 {
480 Assert(checks[pos_var][k] == 0)
481 << "checks[(" << pos << "," << var << ")][" << k
482 << "] should be 0, but it's "
483 << checks[pos_var]
484 [k]; // we never set for single-positive-var
485 }
486 }
487 }
488 if (!eligible)
489 {
490 eligible = true; // next is still eligible
491 continue;
492 }
493
494 Debug("miplib") << " -- ELIGIBLE " << v0 << " , " << pos << " --"
495 << endl;
496 vector<Node> newVars;
497 expr::NodeSelfIterator ii, iiend;
498 if (pos.getKind() == kind::AND)
499 {
500 ii = pos.begin();
501 iiend = pos.end();
502 }
503 else
504 {
505 ii = expr::NodeSelfIterator::self(pos);
506 iiend = expr::NodeSelfIterator::selfEnd(pos);
507 }
508 for (; ii != iiend; ++ii)
509 {
510 Node& varRef = intVars[*ii];
511 if (varRef.isNull())
512 {
513 stringstream ss;
514 ss << "mipvar_" << *ii;
515 Node newVar = nm->mkSkolem(
516 ss.str(),
517 nm->integerType(),
518 "a variable introduced due to scrubbing a miplib encoding",
519 NodeManager::SKOLEM_EXACT_NAME);
520 Node geq = Rewriter::rewrite(nm->mkNode(kind::GEQ, newVar, zero));
521 Node leq = Rewriter::rewrite(nm->mkNode(kind::LEQ, newVar, one));
522
523 Node n = Rewriter::rewrite(geq.andNode(leq));
524 assertionsToPreprocess->push_back(n);
525 PROOF(ProofManager::currentPM()->addDependence(n, Node::null()));
526
527 SubstitutionMap nullMap(&fakeContext);
528 Theory::PPAssertStatus status CVC4_UNUSED; // just for assertions
529 status = te->solve(geq, nullMap);
530 Assert(status == Theory::PP_ASSERT_STATUS_UNSOLVED)
531 << "unexpected solution from arith's ppAssert()";
532 Assert(nullMap.empty())
533 << "unexpected substitution from arith's ppAssert()";
534 status = te->solve(leq, nullMap);
535 Assert(status == Theory::PP_ASSERT_STATUS_UNSOLVED)
536 << "unexpected solution from arith's ppAssert()";
537 Assert(nullMap.empty())
538 << "unexpected substitution from arith's ppAssert()";
539 te->getModel()->addSubstitution(*ii, newVar.eqNode(one));
540 newVars.push_back(newVar);
541 varRef = newVar;
542 }
543 else
544 {
545 newVars.push_back(varRef);
546 }
547 d_preprocContext->enableIntegers();
548 }
549 Node sum;
550 if (pos.getKind() == kind::AND)
551 {
552 NodeBuilder<> sumb(kind::PLUS);
553 for (size_t jj = 0; jj < pos.getNumChildren(); ++jj)
554 {
555 sumb << nm->mkNode(
556 kind::MULT, nm->mkConst(coef[pos_var][jj]), newVars[jj]);
557 }
558 sum = sumb;
559 }
560 else
561 {
562 sum = nm->mkNode(
563 kind::MULT, nm->mkConst(coef[pos_var][0]), newVars[0]);
564 }
565 Debug("miplib") << "vars[] " << var << endl
566 << " eq " << Rewriter::rewrite(sum) << endl;
567 Node newAssertion = var.eqNode(Rewriter::rewrite(sum));
568 if (top_level_substs.hasSubstitution(newAssertion[0]))
569 {
570 // Warning() << "RE-SUBSTITUTION " << newAssertion[0] << endl;
571 // Warning() << "REPLACE " << newAssertion[1] << endl;
572 // Warning() << "ORIG " <<
573 // top_level_substs.getSubstitution(newAssertion[0]) << endl;
574 Assert(top_level_substs.getSubstitution(newAssertion[0])
575 == newAssertion[1]);
576 }
577 else if (pos.getNumChildren() <= options::arithMLTrickSubstitutions())
578 {
579 top_level_substs.addSubstitution(newAssertion[0], newAssertion[1]);
580 Debug("miplib") << "addSubs: " << newAssertion[0] << " to "
581 << newAssertion[1] << endl;
582 }
583 else
584 {
585 Debug("miplib")
586 << "skipSubs: " << newAssertion[0] << " to " << newAssertion[1]
587 << " (threshold is " << options::arithMLTrickSubstitutions()
588 << ")" << endl;
589 }
590 newAssertion = Rewriter::rewrite(newAssertion);
591 Debug("miplib") << " " << newAssertion << endl;
592
593 assertionsToPreprocess->push_back(newAssertion);
594 PROOF(ProofManager::currentPM()->addDependence(newAssertion,
595 Node::null()));
596
597 Debug("miplib") << " assertions to remove: " << endl;
598 for (vector<TNode>::const_iterator k = asserts[pos_var].begin(),
599 k_end = asserts[pos_var].end();
600 k != k_end;
601 ++k)
602 {
603 Debug("miplib") << " " << *k << endl;
604 removeAssertions.insert((*k).getId());
605 }
606 }
607 }
608 }
609 }
610 if (!removeAssertions.empty())
611 {
612 Debug("miplib") << " scrubbing miplib encoding..." << endl;
613 for (size_t i = 0, size = assertionsToPreprocess->getRealAssertionsEnd();
614 i < size;
615 ++i)
616 {
617 Node assertion = (*assertionsToPreprocess)[i];
618 if (removeAssertions.find(assertion.getId()) != removeAssertions.end())
619 {
620 Debug("miplib") << " - removing " << assertion << endl;
621 assertionsToPreprocess->replace(i, trueNode);
622 ++d_statistics.d_numMiplibAssertionsRemoved;
623 }
624 else if (assertion.getKind() == kind::AND)
625 {
626 size_t removals = removeFromConjunction(assertion, removeAssertions);
627 if (removals > 0)
628 {
629 Debug("miplib") << " - reduced " << assertion << endl;
630 Debug("miplib") << " - by " << removals << " conjuncts" << endl;
631 d_statistics.d_numMiplibAssertionsRemoved += removals;
632 }
633 }
634 Debug("miplib") << "had: " << assertion[i] << endl;
635 assertionsToPreprocess->replace(
636 i, Rewriter::rewrite(top_level_substs.apply(assertion)));
637 Debug("miplib") << "now: " << assertion << endl;
638 }
639 }
640 else
641 {
642 Debug("miplib") << " miplib pass found nothing." << endl;
643 }
644 assertionsToPreprocess->updateRealAssertionsEnd();
645 return PreprocessingPassResult::NO_CONFLICT;
646 }
647
648 MipLibTrick::Statistics::Statistics()
649 : d_numMiplibAssertionsRemoved(
650 "preprocessing::passes::MipLibTrick::numMiplibAssertionsRemoved", 0)
651 {
652 smtStatisticsRegistry()->registerStat(&d_numMiplibAssertionsRemoved);
653 }
654
655 MipLibTrick::Statistics::~Statistics()
656 {
657 smtStatisticsRegistry()->unregisterStat(&d_numMiplibAssertionsRemoved);
658 }
659
660
661 } // namespace passes
662 } // namespace preprocessing
663 } // namespace CVC4