1 /********************* */
2 /*! \file arith_entail.cpp
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Andres Noetzli
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
12 ** \brief Implementation of arithmetic entailment computation for string terms.
15 #include "theory/strings/arith_entail.h"
17 #include "expr/attribute.h"
18 #include "expr/node_algorithm.h"
19 #include "theory/arith/arith_msum.h"
20 #include "theory/rewriter.h"
21 #include "theory/strings/theory_strings_utils.h"
22 #include "theory/strings/word.h"
23 #include "theory/theory.h"
25 using namespace CVC4::kind
;
// checkEq: both operands are brought to rewritten form (ar, br) before they
// are compared.
// NOTE(review): the comparison/return statement is not visible in this
// excerpt - presumably the result is whether ar and br coincide; confirm.
31 bool ArithEntail::checkEq(Node a
, Node b
)
37 Node ar
= Rewriter::rewrite(a
);
38 Node br
= Rewriter::rewrite(b
);
// check(a, b, strict): reduces the comparison of a against b to a sign check
// on the difference a - b, delegating to check(Node, bool) with the same
// strict flag.
42 bool ArithEntail::check(Node a
, Node b
, bool strict
)
// diff is the (unrewritten) term a - b
48 Node diff
= NodeManager::currentNM()->mkNode(kind::MINUS
, a
, b
);
49 return check(diff
, strict
);
// Tag types and boolean node attributes used by check(Node, bool) to cache
// its result on the rewritten query node.
52 struct StrCheckEntailArithTag
55 struct StrCheckEntailArithComputedTag
58 /** Attribute true for expressions for which check returned true */
59 typedef expr::Attribute
<StrCheckEntailArithTag
, bool> StrCheckEntailArithAttr
;
// Attribute set to true once a result has been stored for an expression; it
// is consulted before the stored result is read back.
60 typedef expr::Attribute
<StrCheckEntailArithComputedTag
, bool>
61 StrCheckEntailArithComputedAttr
;
// check(a, strict): entailment check on the sign of a (non-negative, or
// strictly positive when strict is set). The query is normalized to a
// rewritten node ar, answered from the attribute cache when a result was
// already computed for ar, and otherwise computed and stored on ar.
// NOTE(review): the guarding conditions and the final return are not visible
// in this excerpt.
63 bool ArithEntail::check(Node a
, bool strict
)
// answer directly from the sign of a's Rational constant value
67 return a
.getConst
<Rational
>().sgn() >= (strict
? 1 : 0);
// for strict queries ar is built from a - 1
70 Node ar
= strict
? NodeManager::currentNM()->mkNode(
71 kind::MINUS
, a
, NodeManager::currentNM()->mkConst(Rational(1)))
73 ar
= Rewriter::rewrite(ar
);
// cache lookup: a result was already computed for ar
75 if (ar
.getAttribute(StrCheckEntailArithComputedAttr()))
77 return ar
.getAttribute(StrCheckEntailArithAttr());
80 bool ret
= checkInternal(ar
);
83 // try with approximations
84 ret
= checkApprox(ar
);
// cache the result
87 ar
.setAttribute(StrCheckEntailArithAttr(), ret
);
88 ar
.setAttribute(StrCheckEntailArithComputedAttr(), true);
// checkApprox: tries to show ar >= 0 by under-approximation. The monomial sum
// of ar is computed, candidate approximations are collected for each monomial
// (see getArithApproximations), monomials with a single candidate are folded
// into a "fixed" sum aar, and the remaining monomials are incorporated one at
// a time using a score-based choice among their candidates. Expects ar to be
// in rewritten form (asserted below).
// NOTE(review): several conditions, declarations and return statements of
// this function are not visible in this excerpt.
92 bool ArithEntail::checkApprox(Node ar
)
94 Assert(Rewriter::rewrite(ar
) == ar
);
95 NodeManager
* nm
= NodeManager::currentNM();
96 std::map
<Node
, Node
> msum
;
97 Trace("strings-ent-approx-debug")
98 << "Setup arithmetic approximations for " << ar
<< std::endl
;
99 if (!ArithMSum::getMonomialSum(ar
, msum
))
101 Trace("strings-ent-approx-debug")
102 << "...failed to get monomial sum!" << std::endl
;
105 // for each monomial v*c, mApprox[v] a list of
106 // possibilities for how the term can be soundly approximated, that is,
107 // if mApprox[v] contains av, then v*c > av*c. Notice that if c
108 // is positive, then v > av, otherwise if c is negative, then v < av.
109 // In other words, av is an under-approximation if c is positive, and an
110 // over-approximation if c is negative.
111 bool changed
= false;
112 std::map
<Node
, std::vector
<Node
> > mApprox
;
113 // map from approximations to their monomial sums
114 std::map
<Node
, std::map
<Node
, Node
> > approxMsums
;
115 // aarSum stores each monomial that does not have multiple approximations
116 std::vector
<Node
> aarSum
;
117 for (std::pair
<const Node
, Node
>& m
: msum
)
121 Trace("strings-ent-approx-debug")
122 << "Get approximations " << v
<< "..." << std::endl
;
// a null coefficient stands for 1
125 Node mn
= c
.isNull() ? nm
->mkConst(Rational(1)) : c
;
126 aarSum
.push_back(mn
);
130 // c.isNull() means c = 1
131 bool isOverApprox
= !c
.isNull() && c
.getConst
<Rational
>().sgn() == -1;
132 std::vector
<Node
>& approx
= mApprox
[v
];
// worklist exploration of approximations of v; visited holds the rewritten
// terms that were already expanded
133 std::unordered_set
<Node
, NodeHashFunction
> visited
;
134 std::vector
<Node
> toProcess
;
135 toProcess
.push_back(v
);
138 Node curr
= toProcess
.back();
139 Trace("strings-ent-approx-debug") << " process " << curr
<< std::endl
;
140 curr
= Rewriter::rewrite(curr
);
141 toProcess
.pop_back();
142 if (visited
.find(curr
) == visited
.end())
144 visited
.insert(curr
);
145 std::vector
<Node
> currApprox
;
146 getArithApproximations(curr
, currApprox
, isOverApprox
);
147 if (currApprox
.empty())
149 Trace("strings-ent-approx-debug")
150 << "...approximation: " << curr
<< std::endl
;
151 // no approximations, thus curr is a possibility
152 approx
.push_back(curr
);
157 toProcess
.end(), currApprox
.begin(), currApprox
.end());
160 } while (!toProcess
.empty());
161 Assert(!approx
.empty());
162 // if we have only one approximation, move it to final
163 if (approx
.size() == 1)
165 changed
= v
!= approx
[0];
166 Node mn
= ArithMSum::mkCoeffTerm(c
, approx
[0]);
167 aarSum
.push_back(mn
);
172 // compute monomial sum form for each approximation, used below
173 for (const Node
& aa
: approx
)
175 if (approxMsums
.find(aa
) == approxMsums
.end())
177 CVC4_UNUSED
bool ret
=
178 ArithMSum::getMonomialSum(aa
, approxMsums
[aa
]);
188 // approximations had no effect, return
189 Trace("strings-ent-approx-debug") << "...no approximations" << std::endl
;
192 // get the current "fixed" sum for the abstraction of ar
193 Node aar
= aarSum
.empty()
194 ? nm
->mkConst(Rational(0))
195 : (aarSum
.size() == 1 ? aarSum
[0] : nm
->mkNode(PLUS
, aarSum
));
196 aar
= Rewriter::rewrite(aar
);
197 Trace("strings-ent-approx-debug")
198 << "...processed fixed sum " << aar
<< " with " << mApprox
.size()
199 << " approximated monomials." << std::endl
;
200 // if we have a choice of how to approximate
201 if (!mApprox
.empty())
203 // convert aar back to monomial sum
204 std::map
<Node
, Node
> msumAar
;
205 if (!ArithMSum::getMonomialSum(aar
, msumAar
))
// trace-only output: print the fixed sum and the candidate approximations
209 if (Trace
.isOn("strings-ent-approx"))
211 Trace("strings-ent-approx")
212 << "---- Check arithmetic entailment by under-approximation " << ar
213 << " >= 0" << std::endl
;
214 Trace("strings-ent-approx") << "FIXED:" << std::endl
;
215 ArithMSum::debugPrintMonomialSum(msumAar
, "strings-ent-approx");
216 Trace("strings-ent-approx") << "APPROX:" << std::endl
;
217 for (std::pair
<const Node
, std::vector
<Node
> >& a
: mApprox
)
219 Node c
= msum
[a
.first
];
220 Trace("strings-ent-approx") << " ";
223 Trace("strings-ent-approx") << c
<< " * ";
225 Trace("strings-ent-approx")
226 << a
.second
<< " ...from " << a
.first
<< std::endl
;
228 Trace("strings-ent-approx") << std::endl
;
231 // incorporate monomials one at a time that have a choice of approximations
232 while (!mApprox
.empty())
237 // Look at each approximation, take the one with the best score.
238 // Notice that we are in the process of trying to prove
239 // ( c1*t1 + .. + cn*tn ) + ( approx_1 | ... | approx_m ) >= 0,
240 // where c1*t1 + .. + cn*tn is the "fixed" component of our sum (aar)
241 // and approx_1 ... approx_m are possible approximations. The
242 // intuition here is that we want coefficients c1...cn to be positive.
243 // This is because arithmetic string terms t1...tn (which may be
244 // applications of len, indexof, str.to.int) are never entailed to be
245 // negative. Hence, we add the approx_i that contributes the "most"
246 // towards making all constants c1...cn positive and cancelling negative
247 // monomials in approx_i itself.
248 for (std::pair
<const Node
, std::vector
<Node
> >& nam
: mApprox
)
// cr is the coefficient of the monomial being approximated
250 Node cr
= msum
[nam
.first
];
251 for (const Node
& aa
: nam
.second
)
253 unsigned helpsCancelCount
= 0;
254 unsigned addsObligationCount
= 0;
255 std::map
<Node
, Node
>::iterator it
;
256 // we are processing an approximation cr*( c1*t1 + ... + cn*tn )
257 for (std::pair
<const Node
, Node
>& aam
: approxMsums
[aa
])
259 // Say aar is of the form t + c*ti, and aam is the monomial ci*ti
260 // where ci != 0. We say aam:
261 // (1) helps cancel if c != 0 and c>0 != ci>0
262 // (2) adds obligation if c>=0 and c+ci<0
264 Node ci
= aam
.second
;
// scale ci by cr (a null ci stands for 1, so the product is cr itself)
267 ci
= ci
.isNull() ? cr
268 : Rewriter::rewrite(nm
->mkNode(MULT
, ci
, cr
));
270 Trace("strings-ent-approx-debug") << ci
<< "*" << ti
<< " ";
271 int ciSgn
= ci
.isNull() ? 1 : ci
.getConst
<Rational
>().sgn();
// look up ti in the current fixed sum
272 it
= msumAar
.find(ti
);
273 if (it
!= msumAar
.end())
276 int cSgn
= c
.isNull() ? 1 : c
.getConst
<Rational
>().sgn();
279 addsObligationCount
+= (ciSgn
== -1 ? 1 : 0);
281 else if (cSgn
!= ciSgn
)
284 Rational r1
= c
.isNull() ? one
: c
.getConst
<Rational
>();
285 Rational r2
= ci
.isNull() ? one
: ci
.getConst
<Rational
>();
286 Rational r12
= r1
+ r2
;
289 addsObligationCount
++;
295 addsObligationCount
+= (ciSgn
== -1 ? 1 : 0);
298 Trace("strings-ent-approx-debug")
299 << "counts=" << helpsCancelCount
<< "," << addsObligationCount
300 << " for " << aa
<< " into " << aar
<< std::endl
;
// score: 2 points for adding no obligation, plus 1 point for helping cancel
301 int score
= (addsObligationCount
> 0 ? 0 : 2)
302 + (helpsCancelCount
> 0 ? 1 : 0);
303 // if it's the best, update v and vapprox
304 if (v
.isNull() || score
> maxScore
)
316 Trace("strings-ent-approx")
317 << "- Decide " << v
<< " = " << vapprox
<< std::endl
;
318 // we incorporate v approximated by vapprox into the overall approximation
320 Assert(!v
.isNull() && !vapprox
.isNull());
321 Assert(msum
.find(v
) != msum
.end());
322 Node mn
= ArithMSum::mkCoeffTerm(msum
[v
], vapprox
);
323 aar
= nm
->mkNode(PLUS
, aar
, mn
);
324 // update the msumAar map
325 aar
= Rewriter::rewrite(aar
);
327 if (!ArithMSum::getMonomialSum(aar
, msumAar
))
330 Trace("strings-ent-approx")
331 << "...failed to get monomial sum!" << std::endl
;
334 // we have processed the approximation for v
337 Trace("strings-ent-approx") << "-----------------" << std::endl
;
341 Trace("strings-ent-approx-debug")
342 << "...approximation had no effect" << std::endl
;
343 // this should never happen, but we avoid the infinite loop for sanity here
347 // Check entailment on the approximation of ar.
348 // Notice that this may trigger further reasoning by approximation. For
349 // example, len( replace( x ++ y, substr( x, 0, n ), z ) ) may be
350 // under-approximated as len( x ) + len( y ) - len( substr( x, 0, n ) ) on
351 // this call, where in the recursive call we may over-approximate
352 // len( substr( x, 0, n ) ) as len( x ). In this example, we can infer
353 // that len( replace( x ++ y, substr( x, 0, n ), z ) ) >= len( y ) in two
// steps.
357 Trace("strings-ent-approx")
358 << "*** StrArithApprox: showed " << ar
359 << " >= 0 using under-approximation!" << std::endl
;
360 Trace("strings-ent-approx")
361 << "*** StrArithApprox: under-approximation was " << aar
<< std::endl
;
// getArithApproximations: appends to approx terms that bound a, using
// per-kind rules for monomials, str.len of substr / replace / int.to.str,
// indexof and str.to.int. The isOverApprox flag (used below) selects between
// over- and under-approximations.
// NOTE(review): part of the signature and the conditions that select between
// the over- and under-approximation branches are not visible in this
// excerpt.
367 void ArithEntail::getArithApproximations(Node a
,
368 std::vector
<Node
>& approx
,
371 NodeManager
* nm
= NodeManager::currentNM();
372 // We do not handle PLUS here since this leads to exponential behavior.
373 // Instead, this is managed, e.g. during checkApprox, where
374 // PLUS terms are expanded "on-demand" during the reasoning.
375 Trace("strings-ent-approx-debug")
376 << "Get arith approximations " << a
<< std::endl
;
377 Kind ak
= a
.getKind();
// monomial c*v: approximate v, flipping the direction when c is negative,
// then multiply each result by c
382 if (ArithMSum::getMonomial(a
, c
, v
))
384 bool isNeg
= c
.getConst
<Rational
>().sgn() == -1;
385 getArithApproximations(v
, approx
, isNeg
? !isOverApprox
: isOverApprox
);
386 for (unsigned i
= 0, size
= approx
.size(); i
< size
; i
++)
388 approx
[i
] = nm
->mkNode(MULT
, c
, approx
[i
]);
392 else if (ak
== STRING_LENGTH
)
// dispatch on the kind of the string term whose length is taken
394 Kind aak
= a
[0].getKind();
395 if (aak
== STRING_SUBSTR
)
397 // over,under-approximations for len( substr( x, n, m ) )
398 Node lenx
= nm
->mkNode(STRING_LENGTH
, a
[0][0]);
402 // m >= len( substr( x, n, m ) )
405 approx
.push_back(a
[0][2]);
407 if (check(lenx
, a
[0][1]))
409 // n <= len( x ) implies
410 // len( x ) - n >= len( substr( x, n, m ) )
411 approx
.push_back(nm
->mkNode(MINUS
, lenx
, a
[0][1]));
415 // len( x ) >= len( substr( x, n, m ) )
416 approx
.push_back(lenx
);
421 // 0 <= n and n+m <= len( x ) implies
422 // m <= len( substr( x, n, m ) )
423 Node npm
= nm
->mkNode(PLUS
, a
[0][1], a
[0][2]);
424 if (check(a
[0][1]) && check(lenx
, npm
))
426 approx
.push_back(a
[0][2]);
428 // 0 <= n and n+m >= len( x ) implies
429 // len(x)-n <= len( substr( x, n, m ) )
430 if (check(a
[0][1]) && check(npm
, lenx
))
432 approx
.push_back(nm
->mkNode(MINUS
, lenx
, a
[0][1]));
436 else if (aak
== STRING_STRREPL
)
438 // over,under-approximations for len( replace( x, y, z ) )
439 // notice this is either len( x ) or ( len( x ) + len( z ) - len( y ) )
440 Node lenx
= nm
->mkNode(STRING_LENGTH
, a
[0][0]);
441 Node leny
= nm
->mkNode(STRING_LENGTH
, a
[0][1]);
442 Node lenz
= nm
->mkNode(STRING_LENGTH
, a
[0][2]);
445 if (check(leny
, lenz
))
447 // len( y ) >= len( z ) implies
448 // len( x ) >= len( replace( x, y, z ) )
449 approx
.push_back(lenx
);
453 // len( x ) + len( z ) >= len( replace( x, y, z ) )
454 approx
.push_back(nm
->mkNode(PLUS
, lenx
, lenz
));
459 if (check(lenz
, leny
) || check(lenz
, lenx
))
461 // len( y ) <= len( z ) or len( x ) <= len( z ) implies
462 // len( x ) <= len( replace( x, y, z ) )
463 approx
.push_back(lenx
);
467 // len( x ) - len( y ) <= len( replace( x, y, z ) )
468 approx
.push_back(nm
->mkNode(MINUS
, lenx
, leny
));
472 else if (aak
== STRING_ITOS
)
474 // over,under-approximations for len( int.to.str( x ) )
477 if (check(a
[0][0], false))
479 if (check(a
[0][0], true))
482 // x >= len( int.to.str( x ) )
483 approx
.push_back(a
[0][0]);
488 // x+1 >= len( int.to.str( x ) )
490 nm
->mkNode(PLUS
, nm
->mkConst(Rational(1)), a
[0][0]));
499 // len( int.to.str( x ) ) >= 1
500 approx
.push_back(nm
->mkConst(Rational(1)));
502 // other crazy things are possible here, e.g.
503 // len( int.to.str( len( y ) + 10 ) ) >= 2
507 else if (ak
== STRING_STRIDOF
)
509 // over,under-approximations for indexof( x, y, n )
512 Node lenx
= nm
->mkNode(STRING_LENGTH
, a
[0]);
513 Node leny
= nm
->mkNode(STRING_LENGTH
, a
[1]);
514 if (check(lenx
, leny
))
516 // len( x ) >= len( y ) implies
517 // len( x ) - len( y ) >= indexof( x, y, n )
518 approx
.push_back(nm
->mkNode(MINUS
, lenx
, leny
));
522 // len( x ) >= indexof( x, y, n )
523 approx
.push_back(lenx
);
529 // contains( substr( x, n, len( x ) ), y ) implies
530 // n <= indexof( x, y, n )
531 // ...hard to test, runs risk of non-termination
533 // -1 <= indexof( x, y, n )
534 approx
.push_back(nm
->mkConst(Rational(-1)));
537 else if (ak
== STRING_STOI
)
539 // over,under-approximations for str.to.int( x )
544 // y >= str.to.int( int.to.str( y ) )
548 // -1 <= str.to.int( x )
549 approx
.push_back(nm
->mkConst(Rational(-1)));
552 Trace("strings-ent-approx-debug") << "Return " << approx
.size() << std::endl
;
// checkWithEqAssumption: entailment check for a under an equality
// assumption. Candidate terms (arithmetic variables and str.len terms) are
// collected from the assumption, one that occurs in a is selected, the
// assumption is solved for it, and the solution is substituted into a before
// calling check(a, strict). The assumption must be a rewritten EQUAL node
// (asserted below).
// NOTE(review): the selection of v and the early-exit returns are not visible
// in this excerpt.
555 bool ArithEntail::checkWithEqAssumption(Node assumption
, Node a
, bool strict
)
557 Assert(assumption
.getKind() == kind::EQUAL
);
558 Assert(Rewriter::rewrite(assumption
) == assumption
);
559 Trace("strings-entail") << "checkWithEqAssumption: " << assumption
<< " " << a
560 << ", strict=" << strict
<< std::endl
;
562 // Find candidate variables to compute substitutions for
563 std::unordered_set
<Node
, NodeHashFunction
> candVars
;
564 std::vector
<Node
> toVisit
= {assumption
};
565 while (!toVisit
.empty())
567 Node curr
= toVisit
.back();
// descend only through PLUS, MULT, MINUS and EQUAL
570 if (curr
.getKind() == kind::PLUS
|| curr
.getKind() == kind::MULT
571 || curr
.getKind() == kind::MINUS
|| curr
.getKind() == kind::EQUAL
)
573 for (const auto& currChild
: curr
)
575 toVisit
.push_back(currChild
);
578 else if (curr
.isVar() && Theory::theoryOf(curr
) == THEORY_ARITH
)
580 candVars
.insert(curr
);
582 else if (curr
.getKind() == kind::STRING_LENGTH
)
584 candVars
.insert(curr
);
588 // Check if any of the candidate variables are in a
590 Assert(toVisit
.empty());
591 toVisit
.push_back(a
);
592 while (!toVisit
.empty())
594 Node curr
= toVisit
.back();
597 for (const auto& currChild
: curr
)
599 toVisit
.push_back(currChild
);
602 if (candVars
.find(curr
) != candVars
.end())
611 // No suitable candidate found
615 Node solution
= ArithMSum::solveEqualityFor(assumption
, v
);
616 if (solution
.isNull())
618 // Could not solve for v
621 Trace("strings-entail") << "checkWithEqAssumption: subs " << v
<< " -> "
622 << solution
<< std::endl
;
624 // use capture avoiding substitution
625 a
= expr::substituteCaptureAvoiding(a
, v
, solution
);
626 return check(a
, strict
);
// checkWithAssumption: entailment check on a - b under a single rewritten
// assumption. A non-constant, non-equality assumption is first turned into an
// equality with a slack term; a constant assumption and an equality
// assumption are then handled by check and checkWithEqAssumption
// respectively.
// NOTE(review): most of the signature, the declarations of x, y and res, and
// the return are not visible in this excerpt.
629 bool ArithEntail::checkWithAssumption(Node assumption
,
634 Assert(Rewriter::rewrite(assumption
) == assumption
);
636 NodeManager
* nm
= NodeManager::currentNM();
638 if (!assumption
.isConst() && assumption
.getKind() != kind::EQUAL
)
640 // We rewrite inequality assumptions from x <= y to x + (str.len s) = y
641 // where s is some fresh string variable. We use (str.len s) because
642 // (str.len s) must be non-negative for the equation to hold.
644 if (assumption
.getKind() == kind::GEQ
)
651 // (not (>= s t)) --> (>= (t - 1) s)
652 Assert(assumption
.getKind() == kind::NOT
653 && assumption
[0].getKind() == kind::GEQ
);
654 x
= nm
->mkNode(kind::MINUS
, assumption
[0][1], nm
->mkConst(Rational(1)));
655 y
= assumption
[0][0];
658 Node s
= nm
->mkBoundVar("slackVal", nm
->stringType());
659 Node slen
= nm
->mkNode(kind::STRING_LENGTH
, s
);
// the equality built here is x = y + (str.len s), i.e. x is the larger side
660 assumption
= Rewriter::rewrite(
661 nm
->mkNode(kind::EQUAL
, x
, nm
->mkNode(kind::PLUS
, y
, slen
)));
664 Node diff
= nm
->mkNode(kind::MINUS
, a
, b
);
666 if (assumption
.isConst())
668 bool assumptionBool
= assumption
.getConst
<bool>();
671 res
= check(diff
, strict
);
680 res
= checkWithEqAssumption(assumption
, diff
, strict
);
// checkWithAssumptions: runs checkWithAssumption for each given assumption in
// turn; every assumption must be in rewritten form (asserted below).
// NOTE(review): the rest of the signature and the return statements are not
// visible in this excerpt.
685 bool ArithEntail::checkWithAssumptions(std::vector
<Node
> assumptions
,
690 // TODO: We currently try to show the entailment with each assumption
691 // independently. In the future, we should make better use of multiple
// assumptions.
694 for (const auto& assumption
: assumptions
)
696 Assert(Rewriter::rewrite(assumption
) == assumption
);
698 if (checkWithAssumption(assumption
, a
, b
, strict
))
// getConstantBound: computes a constant lower bound (isLower) or upper bound
// (!isLower) for the rewritten term a. The result is either null or a
// constant node (asserted below). PLUS and MULT terms are bounded by
// recursively bounding their children and rebuilding/rewriting the term.
// NOTE(review): several branch conditions and the return are not visible in
// this excerpt.
707 Node
ArithEntail::getConstantBound(Node a
, bool isLower
)
709 Assert(Rewriter::rewrite(a
) == a
);
715 else if (a
.getKind() == kind::STRING_LENGTH
)
719 ret
= NodeManager::currentNM()->mkConst(Rational(0));
722 else if (a
.getKind() == kind::PLUS
|| a
.getKind() == kind::MULT
)
724 std::vector
<Node
> children
;
726 for (unsigned i
= 0; i
< a
.getNumChildren(); i
++)
// bound of the i-th child, in the same direction
728 Node ac
= getConstantBound(a
[i
], isLower
);
737 if (ac
.getConst
<Rational
>().sgn() == 0)
739 if (a
.getKind() == kind::MULT
)
748 if (a
.getKind() == kind::MULT
)
750 if ((ac
.getConst
<Rational
>().sgn() > 0) != isLower
)
757 children
.push_back(ac
);
763 if (children
.empty())
765 ret
= NodeManager::currentNM()->mkConst(Rational(0));
767 else if (children
.size() == 1)
// rebuild the term over the collected child bounds and rewrite it
773 ret
= NodeManager::currentNM()->mkNode(a
.getKind(), children
);
774 ret
= Rewriter::rewrite(ret
);
778 Trace("strings-rewrite-cbound")
779 << "Constant " << (isLower
? "lower" : "upper") << " bound for " << a
780 << " is " << ret
<< std::endl
;
781 Assert(ret
.isNull() || ret
.isConst());
782 // entailment check should be at least as powerful as computing a lower bound
783 Assert(!isLower
|| ret
.isNull() || ret
.getConst
<Rational
>().sgn() < 0
785 Assert(!isLower
|| ret
.isNull() || ret
.getConst
<Rational
>().sgn() <= 0
// checkInternal: syntactic check of whether the rewritten term a is >= 0.
// Constants are decided by their sign; PLUS and MULT terms are checked by
// recursing into every child.
// NOTE(review): the guarding condition for the constant case and the
// remaining return statements are not visible in this excerpt.
790 bool ArithEntail::checkInternal(Node a
)
792 Assert(Rewriter::rewrite(a
) == a
);
793 // check whether a >= 0
796 return a
.getConst
<Rational
>().sgn() >= 0;
798 else if (a
.getKind() == kind::STRING_LENGTH
)
803 else if (a
.getKind() == kind::PLUS
|| a
.getKind() == kind::MULT
)
805 for (unsigned i
= 0; i
< a
.getNumChildren(); i
++)
807 if (!checkInternal(a
[i
]))
812 // t1 >= 0 ^ ... ^ tn >= 0 => t1 op ... op tn >= 0
// inferZerosInSumGeq: works on the sum y1 + ... + yn of the terms in ys and
// its relation to x. Terms of ys are erased one at a time, the sum is rebuilt
// from the remaining terms, and terms are collected into zeroYs. zeroYs must
// be empty on entry (asserted below); ys is modified in place.
// NOTE(review): the entailment checks, the handling of pos and the return
// statements are not visible in this excerpt.
819 bool ArithEntail::inferZerosInSumGeq(Node x
,
820 std::vector
<Node
>& ys
,
821 std::vector
<Node
>& zeroYs
)
823 Assert(zeroYs
.empty());
825 NodeManager
* nm
= NodeManager::currentNM();
827 // Check if we can show that y1 + ... + yn >= x
828 Node sum
= (ys
.size() > 1) ? nm
->mkNode(PLUS
, ys
) : ys
[0];
834 // Try to remove yi one-by-one and check if we can still show:
836 // y1 + ... + yi-1 + yi+1 + ... + yn >= x
838 // If that's the case, we know that yi can be zero and the inequality still
// holds.
841 while (i
< ys
.size())
844 std::vector
<Node
>::iterator pos
= ys
.erase(ys
.begin() + i
);
847 sum
= nm
->mkNode(PLUS
, ys
);
851 sum
= ys
.size() == 1 ? ys
[0] : nm
->mkConst(Rational(0));
856 zeroYs
.push_back(yi
);
867 } // namespace strings
868 } // namespace theory