4195a8a16213854ad5e1815903ef997f01d4e810
[cvc5.git] / src / theory / quantifiers / extended_rewrite.cpp
1 /********************* */
2 /*! \file extended_rewrite.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of extended rewriting techniques
13 **/
14
15 #include "theory/quantifiers/extended_rewrite.h"
16
17 #include "options/quantifiers_options.h"
18 #include "theory/arith/arith_msum.h"
19 #include "theory/bv/theory_bv_utils.h"
20 #include "theory/datatypes/datatypes_rewriter.h"
21 #include "theory/quantifiers/term_util.h"
22 #include "theory/rewriter.h"
23
24 using namespace CVC4::kind;
25 using namespace std;
26
27 namespace CVC4 {
28 namespace theory {
29 namespace quantifiers {
30
31 struct ExtRewriteAttributeId
32 {
33 };
34 typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
35
36 ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr)
37 {
38 }
39 void ExtendedRewriter::setCache(Node n, Node ret)
40 {
41 ExtRewriteAttribute era;
42 n.setAttribute(era, ret);
43 }
44
45 bool ExtendedRewriter::addToChildren(Node nc,
46 std::vector<Node>& children,
47 bool dropDup)
48 {
49 // If the operator is non-additive, do not consider duplicates
50 if (dropDup
51 && std::find(children.begin(), children.end(), nc) != children.end())
52 {
53 return false;
54 }
55 children.push_back(nc);
56 return true;
57 }
58
/**
 * Compute the extended rewrite of n.
 *
 * n is first normalized with Rewriter::rewrite. If options::sygusExtRew() is
 * disabled, that result is returned as is. Otherwise results are cached on the
 * rewritten node through ExtRewriteAttribute (see setCache). The steps are:
 *  1. pre-rewrite: eliminate IMPLIES and XOR, push NOT inward (NNF); when one
 *     applies, its result is extended-rewritten recursively and returned;
 *  2. extended-rewrite every child, flattening children of associative kinds,
 *     dropping duplicate children of non-additive kinds, and sorting children
 *     of commutative kinds (always when d_aggr, else only for <= 5 children);
 *  3. theory-independent post-rewrites (ITE rewriting, Boolean BCP and
 *     equality resolution, equality-chain simplification, ITE pulling);
 *  4. theory-specific post-rewrites (currently arithmetic only);
 *  5. aggressive rewrites, when d_aggr is set.
 * If steps 3-5 produce a new term, that term is itself extended-rewritten.
 */
Node ExtendedRewriter::extendedRewrite(Node n)
{
  n = Rewriter::rewrite(n);
  if (!options::sygusExtRew())
  {
    return n;
  }

  // has it already been computed?
  if (n.hasAttribute(ExtRewriteAttribute()))
  {
    return n.getAttribute(ExtRewriteAttribute());
  }

  Node ret = n;
  NodeManager* nm = NodeManager::currentNM();

  //--------------------pre-rewrite
  Node pre_new_ret;
  if (ret.getKind() == IMPLIES)
  {
    // a => b ---> ~a | b
    pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
    debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
  }
  else if (ret.getKind() == XOR)
  {
    // a xor b ---> ~a = b
    pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
    debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
  }
  else if (ret.getKind() == NOT)
  {
    pre_new_ret = extendedRewriteNnf(ret);
    debugExtendedRewrite(ret, pre_new_ret, "NNF");
  }
  if (!pre_new_ret.isNull())
  {
    // the pre-rewritten term is fully processed by the recursive call, so
    // its result is final for n
    ret = extendedRewrite(pre_new_ret);
    Trace("q-ext-rewrite-debug") << "...ext-pre-rewrite : " << n << " -> "
                                 << pre_new_ret << std::endl;
    setCache(n, ret);
    return ret;
  }
  //--------------------end pre-rewrite

  //--------------------rewrite children
  if (n.getNumChildren() > 0)
  {
    std::vector<Node> children;
    if (n.getMetaKind() == metakind::PARAMETERIZED)
    {
      children.push_back(n.getOperator());
    }
    Kind k = n.getKind();
    bool childChanged = false;
    bool isNonAdditive = TermUtil::isNonAdditive(k);
    bool isAssoc = TermUtil::isAssoc(k);
    for (unsigned i = 0; i < n.getNumChildren(); i++)
    {
      Node nc = extendedRewrite(n[i]);
      childChanged = nc != n[i] || childChanged;
      if (isAssoc && nc.getKind() == n.getKind())
      {
        // flatten a child with the same associative kind
        for (const Node& ncc : nc)
        {
          if (!addToChildren(ncc, children, isNonAdditive))
          {
            // a duplicate was dropped
            childChanged = true;
          }
        }
      }
      else if (!addToChildren(nc, children, isNonAdditive))
      {
        // a duplicate was dropped
        childChanged = true;
      }
    }
    Assert(!children.empty());
    // Some commutative operators have rewriters that are agnostic to order,
    // thus, we sort here.
    if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
    {
      childChanged = true;
      std::sort(children.begin(), children.end());
    }
    if (childChanged)
    {
      if (isNonAdditive && children.size() == 1)
      {
        // we may have subsumed children down to one
        ret = children[0];
      }
      else
      {
        ret = nm->mkNode(k, children);
      }
    }
  }
  ret = Rewriter::rewrite(ret);
  //--------------------end rewrite children

  // now, do extended rewrite
  Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
                               << " (from " << n << ")" << std::endl;
  // each rewrite below leaves new_ret null if it does not apply
  Node new_ret;

  //---------------------- theory-independent post-rewriting
  if (ret.getKind() == ITE)
  {
    new_ret = extendedRewriteIte(ITE, ret);
  }
  else if (ret.getKind() == AND || ret.getKind() == OR)
  {
    // all kinds are legal to substitute over : hence we give the empty map
    std::map<Kind, bool> bcp_kinds;
    new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, ret);
    debugExtendedRewrite(ret, new_ret, "Bool bcp");
    if (new_ret.isNull())
    {
      // equality resolution
      new_ret =
          extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, ret, false);
      debugExtendedRewrite(ret, new_ret, "Bool eq res");
    }
  }
  else if (ret.getKind() == EQUAL)
  {
    new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
    debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
  }
  if (new_ret.isNull() && ret.getKind() != ITE)
  {
    // simple ITE pulling
    new_ret = extendedRewritePullIte(ITE, ret);
  }
  //----------------------end theory-independent post-rewriting

  //----------------------theory-specific post-rewriting
  if (new_ret.isNull())
  {
    // theory rewrites are applied to the atom of a (possibly negated) literal
    Node atom = ret.getKind() == NOT ? ret[0] : ret;
    bool pol = ret.getKind() != NOT;
    TheoryId tid = Theory::theoryOf(atom);
    if (tid == THEORY_ARITH)
    {
      new_ret = extendedRewriteArith(atom, pol);
    }
    // the atom was rewritten: add back the negation of the literal
    if (!pol && !new_ret.isNull())
    {
      new_ret = new_ret.negate();
    }
  }
  //----------------------end theory-specific post-rewriting

  //----------------------aggressive rewrites
  if (new_ret.isNull() && d_aggr)
  {
    new_ret = extendedRewriteAggr(ret);
  }
  //----------------------end aggressive rewrites

  // Cache the intermediate result before recursing on new_ret; presumably
  // this keeps a recursive call that reaches n again from looping.
  setCache(n, ret);
  if (!new_ret.isNull())
  {
    ret = extendedRewrite(new_ret);
  }
  Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
                               << std::endl;
  setCache(n, ret);
  return ret;
}
229
230 Node ExtendedRewriter::extendedRewriteAggr(Node n)
231 {
232 Node new_ret;
233 Trace("q-ext-rewrite-debug2")
234 << "Do aggressive rewrites on " << n << std::endl;
235 bool polarity = n.getKind() != NOT;
236 Node ret_atom = n.getKind() == NOT ? n[0] : n;
237 if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
238 || ret_atom.getKind() == GEQ)
239 {
240 // ITE term removal in polynomials
241 // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 )
242 Trace("q-ext-rewrite-debug2")
243 << "Compute monomial sum " << ret_atom << std::endl;
244 // compute monomial sum
245 std::map<Node, Node> msum;
246 if (ArithMSum::getMonomialSumLit(ret_atom, msum))
247 {
248 for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end();
249 ++itm)
250 {
251 Node v = itm->first;
252 Trace("q-ext-rewrite-debug2")
253 << itm->first << " * " << itm->second << std::endl;
254 if (v.getKind() == ITE)
255 {
256 Node veq;
257 int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind());
258 if (res != 0)
259 {
260 Trace("q-ext-rewrite-debug")
261 << " have ITE relation, solved form : " << veq << std::endl;
262 // try pulling ITE
263 new_ret = extendedRewritePullIte(ITE, veq);
264 if (!new_ret.isNull())
265 {
266 if (!polarity)
267 {
268 new_ret = new_ret.negate();
269 }
270 break;
271 }
272 }
273 else
274 {
275 Trace("q-ext-rewrite-debug")
276 << " failed to isolate " << v << " in " << n << std::endl;
277 }
278 }
279 }
280 }
281 else
282 {
283 Trace("q-ext-rewrite-debug")
284 << " failed to get monomial sum of " << n << std::endl;
285 }
286 }
287 // TODO (#1706) : conditional rewriting, condition merging
288 return new_ret;
289 }
290
291 Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
292 {
293 Assert(n.getKind() == itek);
294 Assert(n[1] != n[2]);
295
296 NodeManager* nm = NodeManager::currentNM();
297
298 Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;
299
300 Node flip_cond;
301 if (n[0].getKind() == NOT)
302 {
303 flip_cond = n[0][0];
304 }
305 else if (n[0].getKind() == OR)
306 {
307 // a | b ---> ~( ~a & ~b )
308 flip_cond = TermUtil::simpleNegate(n[0]);
309 }
310 if (!flip_cond.isNull())
311 {
312 Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
313 // only print debug trace if full=true
314 if (full)
315 {
316 debugExtendedRewrite(n, new_ret, "ITE flip");
317 }
318 return new_ret;
319 }
320
321 // get entailed equalities in the condition
322 std::vector<Node> eq_conds;
323 Kind ck = n[0].getKind();
324 if (ck == EQUAL)
325 {
326 eq_conds.push_back(n[0]);
327 }
328 else if (ck == AND)
329 {
330 for (const Node& cn : n[0])
331 {
332 if (cn.getKind() == EQUAL)
333 {
334 eq_conds.push_back(cn);
335 }
336 }
337 }
338
339 Node new_ret;
340 Node b;
341 Node e;
342 Node t1 = n[1];
343 Node t2 = n[2];
344 std::stringstream ss_reason;
345
346 for (const Node& eq : eq_conds)
347 {
348 // simple invariant ITE
349 for (unsigned i = 0; i <= 1; i++)
350 {
351 // ite( x = y ^ C, y, x ) ---> x
352 // this is subsumed by the rewrites below
353 if (t2 == eq[i] && t1 == eq[1 - i])
354 {
355 new_ret = t2;
356 ss_reason << "ITE simple rev subs";
357 break;
358 }
359 }
360 if (!new_ret.isNull())
361 {
362 break;
363 }
364 }
365
366 if (new_ret.isNull() && d_aggr)
367 {
368 // If x is less than t based on an ordering, then we use { x -> t } as a
369 // substitution to the children of ite( x = t ^ C, s, t ) below.
370 std::vector<Node> vars;
371 std::vector<Node> subs;
372 for (const Node& eq : eq_conds)
373 {
374 inferSubstitution(eq, vars, subs);
375 }
376
377 if (!vars.empty())
378 {
379 // reverse substitution to opposite child
380 // r{ x -> t } = s implies ite( x=t ^ C, s, r ) ---> r
381 Node nn =
382 t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
383 if (nn != t2)
384 {
385 nn = Rewriter::rewrite(nn);
386 if (nn == t1)
387 {
388 new_ret = t2;
389 ss_reason << "ITE rev subs";
390 }
391 }
392
393 // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
394 nn = t1.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
395 if (nn != t1)
396 {
397 // If full=false, then we've duplicated a term u in the children of n.
398 // For example, when ITE pulling, we have n is of the form:
399 // ite( C, f( u, t1 ), f( u, t2 ) )
400 // We must show that at least one copy of u dissappears in this case.
401 nn = Rewriter::rewrite(nn);
402 if (nn == t2)
403 {
404 new_ret = nn;
405 ss_reason << "ITE subs invariant";
406 }
407 else if (full || nn.isConst())
408 {
409 new_ret = nm->mkNode(itek, n[0], nn, t2);
410 ss_reason << "ITE subs";
411 }
412 }
413 }
414 }
415
416 // only print debug trace if full=true
417 if (!new_ret.isNull() && full)
418 {
419 debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
420 }
421
422 return new_ret;
423 }
424
425 Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
426 {
427 NodeManager* nm = NodeManager::currentNM();
428 TypeNode tn = n.getType();
429 std::vector<Node> children;
430 bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
431 if (hasOp)
432 {
433 children.push_back(n.getOperator());
434 }
435 unsigned nchildren = n.getNumChildren();
436 for (unsigned i = 0; i < nchildren; i++)
437 {
438 children.push_back(n[i]);
439 }
440 std::map<unsigned, std::map<unsigned, Node> > ite_c;
441 for (unsigned i = 0; i < nchildren; i++)
442 {
443 if (n[i].getKind() == itek)
444 {
445 unsigned ii = hasOp ? i + 1 : i;
446 for (unsigned j = 0; j < 2; j++)
447 {
448 children[ii] = n[i][j + 1];
449 Node pull = nm->mkNode(n.getKind(), children);
450 Node pullr = Rewriter::rewrite(pull);
451 children[ii] = n[i];
452 ite_c[i][j] = pullr;
453 }
454 if (ite_c[i][0] == ite_c[i][1])
455 {
456 // ITE dual invariance
457 // f( t1..s1..tn ) ---> t and f( t1..s2..tn ) ---> t implies
458 // f( t1..ite( A, s1, s2 )..tn ) ---> t
459 debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
460 return ite_c[i][0];
461 }
462 else if (d_aggr)
463 {
464 for (unsigned j = 0; j < 2; j++)
465 {
466 Node pullr = ite_c[i][j];
467 if (pullr.isConst() || pullr == n[i][j + 1])
468 {
469 // ITE single child elimination
470 // f( t1..s1..tn ) ---> t where t is a constant or s1 itself
471 // implies
472 // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
473 Node new_ret;
474 if (tn.isBoolean())
475 {
476 // remove false/true child immediately
477 bool pol = pullr.getConst<bool>();
478 std::vector<Node> new_children;
479 new_children.push_back((j == 0) == pol ? n[i][0]
480 : n[i][0].negate());
481 new_children.push_back(ite_c[i][1 - j]);
482 new_ret = nm->mkNode(pol ? OR : AND, new_children);
483 debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
484 }
485 else
486 {
487 new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
488 debugExtendedRewrite(n, new_ret, "ITE single elim");
489 }
490 return new_ret;
491 }
492 }
493 }
494 }
495 }
496
497 for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
498 {
499 Node nite = n[ip.first];
500 Assert(nite.getKind() == itek);
501 // now, simply pull the ITE and try ITE rewrites
502 Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
503 pull_ite = Rewriter::rewrite(pull_ite);
504 if (pull_ite.getKind() == ITE)
505 {
506 Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
507 if (!new_pull_ite.isNull())
508 {
509 debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
510 return new_pull_ite;
511 }
512 }
513 else
514 {
515 // A general rewrite could eliminate the ITE by pulling.
516 // An example is:
517 // ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
518 // ite( C, ~~x, ite( C, y, x ) ) --->
519 // x
520 // where ~ is bitvector negation.
521 debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
522 return pull_ite;
523 }
524 }
525
526 return Node::null();
527 }
528
529 Node ExtendedRewriter::extendedRewriteNnf(Node ret)
530 {
531 Assert(ret.getKind() == NOT);
532
533 Kind nk = ret[0].getKind();
534 bool neg_ch = false;
535 bool neg_ch_1 = false;
536 if (nk == AND || nk == OR)
537 {
538 neg_ch = true;
539 nk = nk == AND ? OR : AND;
540 }
541 else if (nk == IMPLIES)
542 {
543 neg_ch = true;
544 neg_ch_1 = true;
545 nk = AND;
546 }
547 else if (nk == ITE)
548 {
549 neg_ch = true;
550 neg_ch_1 = true;
551 }
552 else if (nk == XOR)
553 {
554 nk = EQUAL;
555 }
556 else if (nk == EQUAL && ret[0][0].getType().isBoolean())
557 {
558 neg_ch_1 = true;
559 }
560 else
561 {
562 return Node::null();
563 }
564
565 std::vector<Node> new_children;
566 for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++)
567 {
568 Node c = ret[0][i];
569 c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c;
570 new_children.push_back(c);
571 }
572 return NodeManager::currentNM()->mkNode(nk, new_children);
573 }
574
575 Node ExtendedRewriter::extendedRewriteBcp(
576 Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node ret)
577 {
578 Kind k = ret.getKind();
579 Assert(k == andk || k == ork);
580 Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;
581
582 NodeManager* nm = NodeManager::currentNM();
583
584 TypeNode tn = ret.getType();
585 Node truen = TermUtil::mkTypeMaxValue(tn);
586 Node falsen = TermUtil::mkTypeValue(tn, 0);
587
588 // terms to process
589 std::vector<Node> to_process;
590 for (const Node& cn : ret)
591 {
592 to_process.push_back(cn);
593 }
594 // the processing terms
595 std::vector<Node> clauses;
596 // the terms we have propagated information to
597 std::unordered_set<Node, NodeHashFunction> prop_clauses;
598 // the assignment
599 std::map<Node, Node> assign;
600 std::vector<Node> avars;
601 std::vector<Node> asubs;
602
603 Kind ok = k == andk ? ork : andk;
604 // global polarity : when k=ork, everything is negated
605 bool gpol = k == andk;
606
607 do
608 {
609 // process the current nodes
610 while (!to_process.empty())
611 {
612 std::vector<Node> new_to_process;
613 for (const Node& cn : to_process)
614 {
615 Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
616 Kind cnk = cn.getKind();
617 bool pol = cnk != notk;
618 Node cln = cnk == notk ? cn[0] : cn;
619 Assert(cln.getKind() != notk);
620 if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
621 {
622 // flatten
623 for (const Node& ccln : cln)
624 {
625 Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
626 new_to_process.push_back(lccln);
627 }
628 }
629 else
630 {
631 // add it to the assignment
632 Node val = gpol == pol ? truen : falsen;
633 std::map<Node, Node>::iterator it = assign.find(cln);
634 Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
635 << std::endl;
636 if (it != assign.end())
637 {
638 if (val != it->second)
639 {
640 Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
641 // a conflicting assignment: we are done
642 return gpol ? falsen : truen;
643 }
644 }
645 else
646 {
647 assign[cln] = val;
648 avars.push_back(cln);
649 asubs.push_back(val);
650 }
651
652 // also, treat it as clause if possible
653 if (cln.getNumChildren() > 0
654 & (bcp_kinds.empty()
655 || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
656 {
657 if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
658 && prop_clauses.find(cn) == prop_clauses.end())
659 {
660 Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
661 clauses.push_back(cn);
662 }
663 }
664 }
665 }
666 to_process.clear();
667 to_process.insert(
668 to_process.end(), new_to_process.begin(), new_to_process.end());
669 }
670
671 // apply substitution to all subterms of clauses
672 std::vector<Node> new_clauses;
673 for (const Node& c : clauses)
674 {
675 bool cpol = c.getKind() != notk;
676 Node ca = c.getKind() == notk ? c[0] : c;
677 bool childChanged = false;
678 std::vector<Node> ccs_children;
679 for (const Node& cc : ca)
680 {
681 Node ccs = cc;
682 if (bcp_kinds.empty())
683 {
684 Trace("ext-rew-bcp-debug") << "...do ordinary substitute"
685 << std::endl;
686 ccs = cc.substitute(
687 avars.begin(), avars.end(), asubs.begin(), asubs.end());
688 }
689 else
690 {
691 Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
692 // substitution is only applicable to compatible kinds
693 ccs = partialSubstitute(ccs, assign, bcp_kinds);
694 }
695 childChanged = childChanged || ccs != cc;
696 ccs_children.push_back(ccs);
697 }
698 if (childChanged)
699 {
700 if (ca.getMetaKind() == metakind::PARAMETERIZED)
701 {
702 ccs_children.insert(ccs_children.begin(), ca.getOperator());
703 }
704 Node ccs = nm->mkNode(ca.getKind(), ccs_children);
705 ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
706 Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
707 << std::endl;
708 ccs = Rewriter::rewrite(ccs);
709 Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
710 to_process.push_back(ccs);
711 // store this as a node that propagation touched. This marks c so that
712 // it will not be included in the final construction.
713 prop_clauses.insert(ca);
714 }
715 else
716 {
717 new_clauses.push_back(c);
718 }
719 }
720 clauses.clear();
721 clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
722 } while (!to_process.empty());
723
724 // remake the node
725 if (!prop_clauses.empty())
726 {
727 std::vector<Node> children;
728 for (std::pair<const Node, Node>& l : assign)
729 {
730 Node a = l.first;
731 // if propagation did not touch a
732 if (prop_clauses.find(a) == prop_clauses.end())
733 {
734 Assert(l.second == truen || l.second == falsen);
735 Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
736 children.push_back(ln);
737 }
738 }
739 Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
740 Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
741 return new_ret;
742 }
743
744 return Node::null();
745 }
746
/**
 * Equality resolution over an AND/OR-like node n.
 *
 * For each child literal of kind eqk, an equality eq is derived. It is the
 * literal itself (rebuilt as EQUAL if eqk differs) when the literal is read
 * positively. When the literal must be read negated (gpol == isXor), it is the
 * Boolean equality ~t = s; that form is only built when the literal's
 * arguments are Boolean. If inferSubstitution turns eq into a substitution
 * { x -> t }, the substitution is applied to all other children: fully when
 * bcp_kinds is empty, otherwise via partialSubstitute restricted to bcp_kinds.
 *
 * @param andk the conjunction kind
 * @param ork the disjunction kind
 * @param eqk the equality kind
 * @param notk the negation kind
 * @param bcp_kinds kinds that substitution may descend through (empty: all)
 * @param n the node to simplify, of kind andk or ork
 * @param isXor whether eqk-literals are to be read as XOR rather than equality
 * @return the node rebuilt with substituted children, for the first literal
 *   whose substitution changes some other child, or null if none does
 */
Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
                                            Kind ork,
                                            Kind eqk,
                                            Kind notk,
                                            std::map<Kind, bool>& bcp_kinds,
                                            Node n,
                                            bool isXor)
{
  Assert(n.getKind() == andk || n.getKind() == ork);
  Trace("ext-rew-eqres") << "Eq res: **** INPUT: " << n << std::endl;

  NodeManager* nm = NodeManager::currentNM();
  Kind nk = n.getKind();
  // global polarity: true for a conjunction, false for a disjunction
  bool gpol = (nk == andk);
  for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
  {
    Node lit = n[i];
    if (lit.getKind() == eqk)
    {
      // eq is the equality we are basing a substitution on
      Node eq;
      if (gpol == isXor)
      {
        // can only turn disequality into equality if types are the same
        if (lit[1].getType() == lit.getType())
        {
          // t != s ---> ~t = s
          if (lit[1].getKind() == notk && lit[0].getKind() != notk)
          {
            eq = nm->mkNode(EQUAL, lit[0], TermUtil::mkNegate(notk, lit[1]));
          }
          else
          {
            eq = nm->mkNode(EQUAL, TermUtil::mkNegate(notk, lit[0]), lit[1]);
          }
        }
      }
      else
      {
        eq = eqk == EQUAL ? lit : nm->mkNode(EQUAL, lit[0], lit[1]);
      }
      if (!eq.isNull())
      {
        // see if it corresponds to a substitution
        std::vector<Node> vars;
        std::vector<Node> subs;
        if (inferSubstitution(eq, vars, subs))
        {
          Assert(vars.size() == 1);
          std::vector<Node> children;
          bool childrenChanged = false;
          // apply to all other children
          for (unsigned j = 0; j < nchild; j++)
          {
            Node ccs = n[j];
            if (i != j)
            {
              if (bcp_kinds.empty())
              {
                ccs = ccs.substitute(
                    vars.begin(), vars.end(), subs.begin(), subs.end());
              }
              else
              {
                std::map<Node, Node> assign;
                // vars.size()==subs.size()==1
                assign[vars[0]] = subs[0];
                // substitution is only applicable to compatible kinds
                ccs = partialSubstitute(ccs, assign, bcp_kinds);
              }
              childrenChanged = childrenChanged || n[j] != ccs;
            }
            children.push_back(ccs);
          }
          if (childrenChanged)
          {
            return nm->mkNode(nk, children);
          }
        }
      }
    }
  }

  return Node::null();
}
832
833 Node ExtendedRewriter::extendedRewriteEqChain(
834 Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor)
835 {
836 Assert(ret.getKind() == eqk);
837
838 NodeManager* nm = NodeManager::currentNM();
839
840 TypeNode tn = ret[0].getType();
841
842 // sort/cancelling for Boolean EQUAL/XOR-chains
843 Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;
844
845 // get the children on either side
846 bool gpol = true;
847 std::vector<Node> children;
848 for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++)
849 {
850 Node curr = ret[r];
851 // assume, if necessary, right associative
852 while (curr.getKind() == eqk && curr[0].getType() == tn)
853 {
854 children.push_back(curr[0]);
855 curr = curr[1];
856 }
857 children.push_back(curr);
858 }
859
860 std::map<Node, bool> cstatus;
861 // add children to status
862 for (const Node& c : children)
863 {
864 Node a = c;
865 if (a.getKind() == notk)
866 {
867 gpol = !gpol;
868 a = a[0];
869 }
870 Trace("ext-rew-eqchain") << "...child : " << a << std::endl;
871 std::map<Node, bool>::iterator itc = cstatus.find(a);
872 if (itc == cstatus.end())
873 {
874 cstatus[a] = true;
875 }
876 else
877 {
878 // cancels
879 cstatus.erase(a);
880 if (isXor)
881 {
882 gpol = !gpol;
883 }
884 }
885 }
886 Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;
887
888 if (cstatus.empty())
889 {
890 return TermUtil::mkTypeConst(tn, gpol);
891 }
892
893 children.clear();
894
895 // cancel AND/OR children if possible
896 for (std::pair<const Node, bool>& cp : cstatus)
897 {
898 if (cp.second)
899 {
900 Node c = cp.first;
901 Kind ck = c.getKind();
902 if (ck == andk || ck == ork)
903 {
904 for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++)
905 {
906 Node cl = c[j];
907 Node ca = cl.getKind() == notk ? cl[0] : cl;
908 bool capol = cl.getKind() != notk;
909 // if this already exists as a child of the equality chain
910 std::map<Node, bool>::iterator itc = cstatus.find(ca);
911 if (itc != cstatus.end() && itc->second)
912 {
913 // cancel it
914 cstatus[ca] = false;
915 cstatus[c] = false;
916 // make new child
917 // x = ( y | ~x ) ---> y & x
918 // x = ( y | x ) ---> ~y | x
919 // x = ( y & x ) ---> y | ~x
920 // x = ( y & ~x ) ---> ~y & ~x
921 std::vector<Node> new_children;
922 for (unsigned k = 0, nchild = c.getNumChildren(); k < nchild; k++)
923 {
924 if (j != k)
925 {
926 new_children.push_back(c[k]);
927 }
928 }
929 Node nc[2];
930 nc[0] = c[j];
931 nc[1] = new_children.size() == 1 ? new_children[0]
932 : nm->mkNode(ck, new_children);
933 // negate the proper child
934 unsigned nindex = (ck == andk) == capol ? 0 : 1;
935 nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]);
936 Kind nk = capol ? ork : andk;
937 // store as new child
938 children.push_back(nm->mkNode(nk, nc[0], nc[1]));
939 if (isXor)
940 {
941 gpol = !gpol;
942 }
943 break;
944 }
945 }
946 }
947 }
948 }
949
950 // sorted right associative chain
951 bool has_nvar = false;
952 unsigned nvar_index = 0;
953 for (std::pair<const Node, bool>& cp : cstatus)
954 {
955 if (cp.second)
956 {
957 if (!cp.first.isVar())
958 {
959 has_nvar = true;
960 nvar_index = children.size();
961 }
962 children.push_back(cp.first);
963 }
964 }
965 std::sort(children.begin(), children.end());
966
967 Node new_ret;
968 if (!gpol)
969 {
970 // negate the constant child if it exists
971 unsigned nindex = has_nvar ? nvar_index : 0;
972 children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
973 }
974 new_ret = children.back();
975 unsigned index = children.size() - 1;
976 while (index > 0)
977 {
978 index--;
979 new_ret = nm->mkNode(eqk, children[index], new_ret);
980 }
981 new_ret = Rewriter::rewrite(new_ret);
982 if (new_ret != ret)
983 {
984 return new_ret;
985 }
986 return Node::null();
987 }
988
989 Node ExtendedRewriter::partialSubstitute(Node n,
990 std::map<Node, Node>& assign,
991 std::map<Kind, bool>& rkinds)
992 {
993 std::unordered_map<TNode, Node, TNodeHashFunction> visited;
994 std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
995 std::vector<TNode> visit;
996 TNode cur;
997 visit.push_back(n);
998 do
999 {
1000 cur = visit.back();
1001 visit.pop_back();
1002 it = visited.find(cur);
1003
1004 if (it == visited.end())
1005 {
1006 std::map<Node, Node>::iterator it = assign.find(cur);
1007 if (it != assign.end())
1008 {
1009 visited[cur] = it->second;
1010 }
1011 else
1012 {
1013 // can only recurse on these kinds
1014 Kind k = cur.getKind();
1015 if (rkinds.find(k) != rkinds.end())
1016 {
1017 visited[cur] = Node::null();
1018 visit.push_back(cur);
1019 for (const Node& cn : cur)
1020 {
1021 visit.push_back(cn);
1022 }
1023 }
1024 else
1025 {
1026 visited[cur] = cur;
1027 }
1028 }
1029 }
1030 else if (it->second.isNull())
1031 {
1032 Node ret = cur;
1033 bool childChanged = false;
1034 std::vector<Node> children;
1035 if (cur.getMetaKind() == metakind::PARAMETERIZED)
1036 {
1037 children.push_back(cur.getOperator());
1038 }
1039 for (const Node& cn : cur)
1040 {
1041 it = visited.find(cn);
1042 Assert(it != visited.end());
1043 Assert(!it->second.isNull());
1044 childChanged = childChanged || cn != it->second;
1045 children.push_back(it->second);
1046 }
1047 if (childChanged)
1048 {
1049 ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
1050 }
1051 visited[cur] = ret;
1052 }
1053 } while (!visit.empty());
1054 Assert(visited.find(n) != visited.end());
1055 Assert(!visited.find(n)->second.isNull());
1056 return visited[n];
1057 }
1058
1059 Node ExtendedRewriter::solveEquality(Node n)
1060 {
1061 // TODO (#1706) : implement
1062 Assert(n.getKind() == EQUAL);
1063
1064 return Node::null();
1065 }
1066
1067 bool ExtendedRewriter::inferSubstitution(Node n,
1068 std::vector<Node>& vars,
1069 std::vector<Node>& subs)
1070 {
1071 if (n.getKind() == EQUAL)
1072 {
1073 // see if it can be put into form x = y
1074 Node slv_eq = solveEquality(n);
1075 if (!slv_eq.isNull())
1076 {
1077 n = slv_eq;
1078 }
1079 Node v[2];
1080 for (unsigned i = 0; i < 2; i++)
1081 {
1082 if (n[i].isVar() || n[i].isConst())
1083 {
1084 v[i] = n[i];
1085 }
1086 else if (TermUtil::isNegate(n[i].getKind()) && n[i][0].isVar())
1087 {
1088 v[i] = n[i][0];
1089 }
1090 }
1091 for (unsigned i = 0; i < 2; i++)
1092 {
1093 TNode r1 = v[i];
1094 Node r2 = v[1 - i];
1095 if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
1096 {
1097 r2 = n[1 - i];
1098 if (v[i] != n[i])
1099 {
1100 Assert( TermUtil::isNegate( n[i].getKind() ) );
1101 r2 = TermUtil::mkNegate(n[i].getKind(), r2);
1102 }
1103 // TODO (#1706) : union find
1104 if (std::find(vars.begin(), vars.end(), r1) == vars.end())
1105 {
1106 vars.push_back(r1);
1107 subs.push_back(r2);
1108 return true;
1109 }
1110 }
1111 }
1112 }
1113 return false;
1114 }
1115
1116 Node ExtendedRewriter::extendedRewriteArith(Node ret, bool& pol)
1117 {
1118 Kind k = ret.getKind();
1119 NodeManager* nm = NodeManager::currentNM();
1120 Node new_ret;
1121 if (k == DIVISION || k == INTS_DIVISION || k == INTS_MODULUS)
1122 {
1123 // rewrite as though total
1124 std::vector<Node> children;
1125 bool all_const = true;
1126 for (unsigned i = 0, size = ret.getNumChildren(); i < size; i++)
1127 {
1128 if (ret[i].isConst())
1129 {
1130 children.push_back(ret[i]);
1131 }
1132 else
1133 {
1134 all_const = false;
1135 break;
1136 }
1137 }
1138 if (all_const)
1139 {
1140 Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL
1141 : (ret.getKind() == INTS_DIVISION
1142 ? INTS_DIVISION_TOTAL
1143 : INTS_MODULUS_TOTAL));
1144 new_ret = nm->mkNode(new_k, children);
1145 debugExtendedRewrite(ret, new_ret, "total-interpretation");
1146 }
1147 }
1148 return new_ret;
1149 }
1150
1151 void ExtendedRewriter::debugExtendedRewrite(Node n,
1152 Node ret,
1153 const char* c) const
1154 {
1155 if (Trace.isOn("q-ext-rewrite"))
1156 {
1157 if (!ret.isNull())
1158 {
1159 Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
1160 Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
1161 << " rewrites to " << ret << std::endl;
1162 }
1163 }
1164 }
1165
1166 } /* CVC4::theory::quantifiers namespace */
1167 } /* CVC4::theory namespace */
1168 } /* CVC4 namespace */