Cleanup some includes (#5847)
[cvc5.git] / src / preprocessing / passes / miplib_trick.cpp
1 /********************* */
2 /*! \file miplib_trick.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Mathias Preiner, Andrew Reynolds, Morgan Deters
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief The MIPLIB trick preprocessing pass
13 **
14 **/
15
16 #include "preprocessing/passes/miplib_trick.h"
17
18 #include <vector>
19
20 #include "expr/node_self_iterator.h"
21 #include "options/arith_options.h"
22 #include "options/smt_options.h"
23 #include "smt/smt_statistics_registry.h"
24 #include "smt_util/boolean_simplification.h"
25 #include "theory/booleans/circuit_propagator.h"
26 #include "theory/theory_model.h"
27 #include "theory/trust_substitutions.h"
28
29 namespace CVC4 {
30 namespace preprocessing {
31 namespace passes {
32
33 using namespace CVC4::theory;
34
35 namespace {
36
37 /**
38 * Remove conjuncts in toRemove from conjunction n. Return # of removed
39 * conjuncts.
40 */
41 size_t removeFromConjunction(Node& n,
42 const std::unordered_set<unsigned long>& toRemove)
43 {
44 Assert(n.getKind() == kind::AND);
45 Node trueNode = NodeManager::currentNM()->mkConst(true);
46 size_t removals = 0;
47 for (Node::iterator j = n.begin(); j != n.end(); ++j)
48 {
49 size_t subremovals = 0;
50 Node sub = *j;
51 if (toRemove.find(sub.getId()) != toRemove.end()
52 || (sub.getKind() == kind::AND
53 && (subremovals = removeFromConjunction(sub, toRemove)) > 0))
54 {
55 NodeBuilder<> b(kind::AND);
56 b.append(n.begin(), j);
57 if (subremovals > 0)
58 {
59 removals += subremovals;
60 b << sub;
61 }
62 else
63 {
64 ++removals;
65 }
66 for (++j; j != n.end(); ++j)
67 {
68 if (toRemove.find((*j).getId()) != toRemove.end())
69 {
70 ++removals;
71 }
72 else if ((*j).getKind() == kind::AND)
73 {
74 sub = *j;
75 if ((subremovals = removeFromConjunction(sub, toRemove)) > 0)
76 {
77 removals += subremovals;
78 b << sub;
79 }
80 else
81 {
82 b << *j;
83 }
84 }
85 else
86 {
87 b << *j;
88 }
89 }
90 if (b.getNumChildren() == 0)
91 {
92 n = trueNode;
93 b.clear();
94 }
95 else if (b.getNumChildren() == 1)
96 {
97 n = b[0];
98 b.clear();
99 }
100 else
101 {
102 n = b;
103 }
104 n = Rewriter::rewrite(n);
105 return removals;
106 }
107 }
108
109 Assert(removals == 0);
110 return 0;
111 }
112
113 /**
114 * Trace nodes back to their assertions using CircuitPropagator's
115 * BackEdgesMap.
116 */
117 void traceBackToAssertions(booleans::CircuitPropagator* propagator,
118 const std::vector<Node>& nodes,
119 std::vector<TNode>& assertions)
120 {
121 const booleans::CircuitPropagator::BackEdgesMap& backEdges =
122 propagator->getBackEdges();
123 for (vector<Node>::const_iterator i = nodes.begin(); i != nodes.end(); ++i)
124 {
125 booleans::CircuitPropagator::BackEdgesMap::const_iterator j =
126 backEdges.find(*i);
127 // term must appear in map, otherwise how did we get here?!
128 Assert(j != backEdges.end());
129 // if term maps to empty, that means it's a top-level assertion
130 if (!(*j).second.empty())
131 {
132 traceBackToAssertions(propagator, (*j).second, assertions);
133 }
134 else
135 {
136 assertions.push_back(*i);
137 }
138 }
139 }
140
141 } // namespace
142
143 MipLibTrick::MipLibTrick(PreprocessingPassContext* preprocContext)
144 : PreprocessingPass(preprocContext, "miplib-trick")
145 {
146 if (!options::incrementalSolving())
147 {
148 NodeManager::currentNM()->subscribeEvents(this);
149 }
150 }
151
152 MipLibTrick::~MipLibTrick()
153 {
154 if (!options::incrementalSolving())
155 {
156 NodeManager::currentNM()->unsubscribeEvents(this);
157 }
158 }
159
160 void MipLibTrick::nmNotifyNewVar(TNode n)
161 {
162 if (n.getType().isBoolean())
163 {
164 d_boolVars.push_back(n);
165 }
166 }
167
168 void MipLibTrick::nmNotifyNewSkolem(TNode n,
169 const std::string& comment,
170 uint32_t flags)
171 {
172 if (n.getType().isBoolean())
173 {
174 d_boolVars.push_back(n);
175 }
176 }
177
178 PreprocessingPassResult MipLibTrick::applyInternal(
179 AssertionPipeline* assertionsToPreprocess)
180 {
181 Assert(assertionsToPreprocess->getRealAssertionsEnd()
182 == assertionsToPreprocess->size());
183 Assert(!options::incrementalSolving());
184 Assert(!options::unsatCores());
185
186 context::Context fakeContext;
187 TheoryEngine* te = d_preprocContext->getTheoryEngine();
188 booleans::CircuitPropagator* propagator =
189 d_preprocContext->getCircuitPropagator();
190 const booleans::CircuitPropagator::BackEdgesMap& backEdges =
191 propagator->getBackEdges();
192 unordered_set<unsigned long> removeAssertions;
193
194 theory::TrustSubstitutionMap& tlsm =
195 d_preprocContext->getTopLevelSubstitutions();
196 SubstitutionMap& top_level_substs = tlsm.get();
197
198 NodeManager* nm = NodeManager::currentNM();
199 Node zero = nm->mkConst(Rational(0)), one = nm->mkConst(Rational(1));
200 Node trueNode = nm->mkConst(true);
201
202 unordered_map<TNode, Node, TNodeHashFunction> intVars;
203 for (TNode v0 : d_boolVars)
204 {
205 if (propagator->isAssigned(v0))
206 {
207 Debug("miplib") << "ineligible: " << v0 << " because assigned "
208 << propagator->getAssignment(v0) << endl;
209 continue;
210 }
211
212 vector<TNode> assertions;
213 booleans::CircuitPropagator::BackEdgesMap::const_iterator j0 =
214 backEdges.find(v0);
215 // if not in back edges map, the bool var is unconstrained, showing up in no
216 // assertions. if maps to an empty vector, that means the bool var was
217 // asserted itself.
218 if (j0 != backEdges.end())
219 {
220 if (!(*j0).second.empty())
221 {
222 traceBackToAssertions(propagator, (*j0).second, assertions);
223 }
224 else
225 {
226 assertions.push_back(v0);
227 }
228 }
229 Debug("miplib") << "for " << v0 << endl;
230 bool eligible = true;
231 map<pair<Node, Node>, uint64_t> marks;
232 map<pair<Node, Node>, vector<Rational> > coef;
233 map<pair<Node, Node>, vector<Rational> > checks;
234 map<pair<Node, Node>, vector<TNode> > asserts;
235 for (vector<TNode>::const_iterator j1 = assertions.begin();
236 j1 != assertions.end();
237 ++j1)
238 {
239 Debug("miplib") << " found: " << *j1 << endl;
240 if ((*j1).getKind() != kind::IMPLIES)
241 {
242 eligible = false;
243 Debug("miplib") << " -- INELIGIBLE -- (not =>)" << endl;
244 break;
245 }
246 Node conj = BooleanSimplification::simplify((*j1)[0]);
247 if (conj.getKind() == kind::AND && conj.getNumChildren() > 6)
248 {
249 eligible = false;
250 Debug("miplib") << " -- INELIGIBLE -- (N-ary /\\ too big)" << endl;
251 break;
252 }
253 if (conj.getKind() != kind::AND && !conj.isVar()
254 && !(conj.getKind() == kind::NOT && conj[0].isVar()))
255 {
256 eligible = false;
257 Debug("miplib") << " -- INELIGIBLE -- (not /\\ or literal)" << endl;
258 break;
259 }
260 if ((*j1)[1].getKind() != kind::EQUAL
261 || !(((*j1)[1][0].isVar()
262 && (*j1)[1][1].getKind() == kind::CONST_RATIONAL)
263 || ((*j1)[1][0].getKind() == kind::CONST_RATIONAL
264 && (*j1)[1][1].isVar())))
265 {
266 eligible = false;
267 Debug("miplib") << " -- INELIGIBLE -- (=> (and X X) X)" << endl;
268 break;
269 }
270 if (conj.getKind() == kind::AND)
271 {
272 vector<Node> posv;
273 bool found_x = false;
274 map<TNode, bool> neg;
275 for (Node::iterator ii = conj.begin(); ii != conj.end(); ++ii)
276 {
277 if ((*ii).isVar())
278 {
279 posv.push_back(*ii);
280 neg[*ii] = false;
281 found_x = found_x || v0 == *ii;
282 }
283 else if ((*ii).getKind() == kind::NOT && (*ii)[0].isVar())
284 {
285 posv.push_back((*ii)[0]);
286 neg[(*ii)[0]] = true;
287 found_x = found_x || v0 == (*ii)[0];
288 }
289 else
290 {
291 eligible = false;
292 Debug("miplib")
293 << " -- INELIGIBLE -- (non-var: " << *ii << ")" << endl;
294 break;
295 }
296 if (propagator->isAssigned(posv.back()))
297 {
298 eligible = false;
299 Debug("miplib") << " -- INELIGIBLE -- (" << posv.back()
300 << " asserted)" << endl;
301 break;
302 }
303 }
304 if (!eligible)
305 {
306 break;
307 }
308 if (!found_x)
309 {
310 eligible = false;
311 Debug("miplib") << " --INELIGIBLE -- (couldn't find " << v0
312 << " in conjunction)" << endl;
313 break;
314 }
315 sort(posv.begin(), posv.end());
316 const Node pos = NodeManager::currentNM()->mkNode(kind::AND, posv);
317 const TNode var = ((*j1)[1][0].getKind() == kind::CONST_RATIONAL)
318 ? (*j1)[1][1]
319 : (*j1)[1][0];
320 const pair<Node, Node> pos_var(pos, var);
321 const Rational& constant =
322 ((*j1)[1][0].getKind() == kind::CONST_RATIONAL)
323 ? (*j1)[1][0].getConst<Rational>()
324 : (*j1)[1][1].getConst<Rational>();
325 uint64_t mark = 0;
326 unsigned countneg = 0, thepos = 0;
327 for (unsigned ii = 0; ii < pos.getNumChildren(); ++ii)
328 {
329 if (neg[pos[ii]])
330 {
331 ++countneg;
332 }
333 else
334 {
335 thepos = ii;
336 mark |= (0x1 << ii);
337 }
338 }
339 if ((marks[pos_var] & (1lu << mark)) != 0)
340 {
341 eligible = false;
342 Debug("miplib") << " -- INELIGIBLE -- (remarked)" << endl;
343 break;
344 }
345 Debug("miplib") << "mark is " << mark << " -- " << (1lu << mark)
346 << endl;
347 marks[pos_var] |= (1lu << mark);
348 Debug("miplib") << "marks[" << pos << "," << var << "] now "
349 << marks[pos_var] << endl;
350 if (countneg == pos.getNumChildren())
351 {
352 if (constant != 0)
353 {
354 eligible = false;
355 Debug("miplib") << " -- INELIGIBLE -- (nonzero constant)" << endl;
356 break;
357 }
358 }
359 else if (countneg == pos.getNumChildren() - 1)
360 {
361 Assert(coef[pos_var].size() <= 6 && thepos < 6);
362 if (coef[pos_var].size() <= thepos)
363 {
364 coef[pos_var].resize(thepos + 1);
365 }
366 coef[pos_var][thepos] = constant;
367 }
368 else
369 {
370 if (checks[pos_var].size() <= mark)
371 {
372 checks[pos_var].resize(mark + 1);
373 }
374 checks[pos_var][mark] = constant;
375 }
376 asserts[pos_var].push_back(*j1);
377 }
378 else
379 {
380 TNode x = conj;
381 if (x != v0 && x != (v0).notNode())
382 {
383 eligible = false;
384 Debug("miplib")
385 << " -- INELIGIBLE -- (x not present where I expect it)" << endl;
386 break;
387 }
388 const bool xneg = (x.getKind() == kind::NOT);
389 x = xneg ? x[0] : x;
390 Debug("miplib") << " x:" << x << " " << xneg << endl;
391 const TNode var = ((*j1)[1][0].getKind() == kind::CONST_RATIONAL)
392 ? (*j1)[1][1]
393 : (*j1)[1][0];
394 const pair<Node, Node> x_var(x, var);
395 const Rational& constant =
396 ((*j1)[1][0].getKind() == kind::CONST_RATIONAL)
397 ? (*j1)[1][0].getConst<Rational>()
398 : (*j1)[1][1].getConst<Rational>();
399 unsigned mark = (xneg ? 0 : 1);
400 if ((marks[x_var] & (1u << mark)) != 0)
401 {
402 eligible = false;
403 Debug("miplib") << " -- INELIGIBLE -- (remarked)" << endl;
404 break;
405 }
406 marks[x_var] |= (1u << mark);
407 if (xneg)
408 {
409 if (constant != 0)
410 {
411 eligible = false;
412 Debug("miplib") << " -- INELIGIBLE -- (nonzero constant)" << endl;
413 break;
414 }
415 }
416 else
417 {
418 Assert(coef[x_var].size() <= 6);
419 coef[x_var].resize(6);
420 coef[x_var][0] = constant;
421 }
422 asserts[x_var].push_back(*j1);
423 }
424 }
425 if (eligible)
426 {
427 for (map<pair<Node, Node>, uint64_t>::const_iterator j = marks.begin();
428 j != marks.end();
429 ++j)
430 {
431 const TNode pos = (*j).first.first;
432 const TNode var = (*j).first.second;
433 const pair<Node, Node>& pos_var = (*j).first;
434 const uint64_t mark = (*j).second;
435 const unsigned numVars =
436 pos.getKind() == kind::AND ? pos.getNumChildren() : 1;
437 uint64_t expected = (uint64_t(1) << (1 << numVars)) - 1;
438 expected = (expected == 0) ? -1 : expected; // fix for overflow
439 Debug("miplib") << "[" << pos << "] => " << hex << mark << " expect "
440 << expected << dec << endl;
441 Assert(pos.getKind() == kind::AND || pos.isVar());
442 if (mark != expected)
443 {
444 Debug("miplib") << " -- INELIGIBLE " << pos
445 << " -- (insufficiently marked, got " << mark
446 << " for " << numVars << " vars, expected "
447 << expected << endl;
448 }
449 else
450 {
451 if (mark != 3)
452 { // exclude single-var case; nothing to check there
453 uint64_t sz = (uint64_t(1) << checks[pos_var].size()) - 1;
454 sz = (sz == 0) ? -1 : sz; // fix for overflow
455 Assert(sz == mark) << "expected size " << sz << " == mark " << mark;
456 for (size_t k = 0; k < checks[pos_var].size(); ++k)
457 {
458 if ((k & (k - 1)) != 0)
459 {
460 Rational sum = 0;
461 Debug("miplib") << k << " => " << checks[pos_var][k] << endl;
462 for (size_t v1 = 1, kk = k; kk != 0; ++v1, kk >>= 1)
463 {
464 if ((kk & 0x1) == 1)
465 {
466 Assert(pos.getKind() == kind::AND);
467 Debug("miplib")
468 << "var " << v1 << " : " << pos[v1 - 1]
469 << " coef:" << coef[pos_var][v1 - 1] << endl;
470 sum += coef[pos_var][v1 - 1];
471 }
472 }
473 Debug("miplib") << "checkSum is " << sum << " input says "
474 << checks[pos_var][k] << endl;
475 if (sum != checks[pos_var][k])
476 {
477 eligible = false;
478 Debug("miplib") << " -- INELIGIBLE " << pos
479 << " -- (nonlinear combination)" << endl;
480 break;
481 }
482 }
483 else
484 {
485 Assert(checks[pos_var][k] == 0)
486 << "checks[(" << pos << "," << var << ")][" << k
487 << "] should be 0, but it's "
488 << checks[pos_var]
489 [k]; // we never set for single-positive-var
490 }
491 }
492 }
493 if (!eligible)
494 {
495 eligible = true; // next is still eligible
496 continue;
497 }
498
499 Debug("miplib") << " -- ELIGIBLE " << v0 << " , " << pos << " --"
500 << endl;
501 vector<Node> newVars;
502 expr::NodeSelfIterator ii, iiend;
503 if (pos.getKind() == kind::AND)
504 {
505 ii = pos.begin();
506 iiend = pos.end();
507 }
508 else
509 {
510 ii = expr::NodeSelfIterator::self(pos);
511 iiend = expr::NodeSelfIterator::selfEnd(pos);
512 }
513 for (; ii != iiend; ++ii)
514 {
515 Node& varRef = intVars[*ii];
516 if (varRef.isNull())
517 {
518 stringstream ss;
519 ss << "mipvar_" << *ii;
520 Node newVar = nm->mkSkolem(
521 ss.str(),
522 nm->integerType(),
523 "a variable introduced due to scrubbing a miplib encoding",
524 NodeManager::SKOLEM_EXACT_NAME);
525 Node geq = Rewriter::rewrite(nm->mkNode(kind::GEQ, newVar, zero));
526 Node leq = Rewriter::rewrite(nm->mkNode(kind::LEQ, newVar, one));
527 TrustNode tgeq = TrustNode::mkTrustLemma(geq, nullptr);
528 TrustNode tleq = TrustNode::mkTrustLemma(leq, nullptr);
529
530 Node n = Rewriter::rewrite(geq.andNode(leq));
531 assertionsToPreprocess->push_back(n);
532 TrustSubstitutionMap tnullMap(&fakeContext, nullptr);
533 CVC4_UNUSED SubstitutionMap& nullMap = tnullMap.get();
534 Theory::PPAssertStatus status CVC4_UNUSED; // just for assertions
535 status = te->solve(tgeq, tnullMap);
536 Assert(status == Theory::PP_ASSERT_STATUS_UNSOLVED)
537 << "unexpected solution from arith's ppAssert()";
538 Assert(nullMap.empty())
539 << "unexpected substitution from arith's ppAssert()";
540 status = te->solve(tleq, tnullMap);
541 Assert(status == Theory::PP_ASSERT_STATUS_UNSOLVED)
542 << "unexpected solution from arith's ppAssert()";
543 Assert(nullMap.empty())
544 << "unexpected substitution from arith's ppAssert()";
545 te->getModel()->addSubstitution(*ii, newVar.eqNode(one));
546 newVars.push_back(newVar);
547 varRef = newVar;
548 }
549 else
550 {
551 newVars.push_back(varRef);
552 }
553 d_preprocContext->enableIntegers();
554 }
555 Node sum;
556 if (pos.getKind() == kind::AND)
557 {
558 NodeBuilder<> sumb(kind::PLUS);
559 for (size_t jj = 0; jj < pos.getNumChildren(); ++jj)
560 {
561 sumb << nm->mkNode(
562 kind::MULT, nm->mkConst(coef[pos_var][jj]), newVars[jj]);
563 }
564 sum = sumb;
565 }
566 else
567 {
568 sum = nm->mkNode(
569 kind::MULT, nm->mkConst(coef[pos_var][0]), newVars[0]);
570 }
571 Debug("miplib") << "vars[] " << var << endl
572 << " eq " << Rewriter::rewrite(sum) << endl;
573 Node newAssertion = var.eqNode(Rewriter::rewrite(sum));
574 if (top_level_substs.hasSubstitution(newAssertion[0]))
575 {
576 // Warning() << "RE-SUBSTITUTION " << newAssertion[0] << endl;
577 // Warning() << "REPLACE " << newAssertion[1] << endl;
578 // Warning() << "ORIG " <<
579 // top_level_substs.getSubstitution(newAssertion[0]) << endl;
580 Assert(top_level_substs.getSubstitution(newAssertion[0])
581 == newAssertion[1]);
582 }
583 else if (pos.getNumChildren() <= options::arithMLTrickSubstitutions())
584 {
585 top_level_substs.addSubstitution(newAssertion[0], newAssertion[1]);
586 Debug("miplib") << "addSubs: " << newAssertion[0] << " to "
587 << newAssertion[1] << endl;
588 }
589 else
590 {
591 Debug("miplib")
592 << "skipSubs: " << newAssertion[0] << " to " << newAssertion[1]
593 << " (threshold is " << options::arithMLTrickSubstitutions()
594 << ")" << endl;
595 }
596 newAssertion = Rewriter::rewrite(newAssertion);
597 Debug("miplib") << " " << newAssertion << endl;
598
599 assertionsToPreprocess->push_back(newAssertion);
600 Debug("miplib") << " assertions to remove: " << endl;
601 for (vector<TNode>::const_iterator k = asserts[pos_var].begin(),
602 k_end = asserts[pos_var].end();
603 k != k_end;
604 ++k)
605 {
606 Debug("miplib") << " " << *k << endl;
607 removeAssertions.insert((*k).getId());
608 }
609 }
610 }
611 }
612 }
613 if (!removeAssertions.empty())
614 {
615 Debug("miplib") << " scrubbing miplib encoding..." << endl;
616 for (size_t i = 0, size = assertionsToPreprocess->getRealAssertionsEnd();
617 i < size;
618 ++i)
619 {
620 Node assertion = (*assertionsToPreprocess)[i];
621 if (removeAssertions.find(assertion.getId()) != removeAssertions.end())
622 {
623 Debug("miplib") << " - removing " << assertion << endl;
624 assertionsToPreprocess->replace(i, trueNode);
625 ++d_statistics.d_numMiplibAssertionsRemoved;
626 }
627 else if (assertion.getKind() == kind::AND)
628 {
629 size_t removals = removeFromConjunction(assertion, removeAssertions);
630 if (removals > 0)
631 {
632 Debug("miplib") << " - reduced " << assertion << endl;
633 Debug("miplib") << " - by " << removals << " conjuncts" << endl;
634 d_statistics.d_numMiplibAssertionsRemoved += removals;
635 }
636 }
637 Debug("miplib") << "had: " << assertion[i] << endl;
638 assertionsToPreprocess->replace(
639 i, Rewriter::rewrite(top_level_substs.apply(assertion)));
640 Debug("miplib") << "now: " << assertion << endl;
641 }
642 }
643 else
644 {
645 Debug("miplib") << " miplib pass found nothing." << endl;
646 }
647 assertionsToPreprocess->updateRealAssertionsEnd();
648 return PreprocessingPassResult::NO_CONFLICT;
649 }
650
651 MipLibTrick::Statistics::Statistics()
652 : d_numMiplibAssertionsRemoved(
653 "preprocessing::passes::MipLibTrick::numMiplibAssertionsRemoved", 0)
654 {
655 smtStatisticsRegistry()->registerStat(&d_numMiplibAssertionsRemoved);
656 }
657
658 MipLibTrick::Statistics::~Statistics()
659 {
660 smtStatisticsRegistry()->unregisterStat(&d_numMiplibAssertionsRemoved);
661 }
662
663
664 } // namespace passes
665 } // namespace preprocessing
666 } // namespace CVC4