62bdf310bec42acffebd0a37934c8fbe91c2654d
[cvc5.git] / src / theory / arith / nl_model.cpp
1 /********************* */
2 /*! \file nl_model.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Model object for the non-linear extension class
13 **/
14
15 #include "theory/arith/nl_model.h"
16
17 #include "expr/node_algorithm.h"
18 #include "theory/arith/arith_msum.h"
19 #include "theory/arith/arith_utilities.h"
20 #include "theory/rewriter.h"
21
22 using namespace CVC4::kind;
23
24 namespace CVC4 {
25 namespace theory {
26 namespace arith {
27
28 NlModel::NlModel(context::Context* c) : d_used_approx(false)
29 {
30 d_true = NodeManager::currentNM()->mkConst(true);
31 d_false = NodeManager::currentNM()->mkConst(false);
32 d_zero = NodeManager::currentNM()->mkConst(Rational(0));
33 d_one = NodeManager::currentNM()->mkConst(Rational(1));
34 d_two = NodeManager::currentNM()->mkConst(Rational(2));
35 }
36
37 NlModel::~NlModel() {}
38
39 void NlModel::reset(TheoryModel* m)
40 {
41 d_model = m;
42 d_mv[0].clear();
43 d_mv[1].clear();
44 }
45
46 void NlModel::resetCheck()
47 {
48 d_used_approx = false;
49 d_check_model_solved.clear();
50 d_check_model_bounds.clear();
51 d_check_model_vars.clear();
52 d_check_model_subs.clear();
53 }
54
55 Node NlModel::computeConcreteModelValue(Node n)
56 {
57 return computeModelValue(n, true);
58 }
59
60 Node NlModel::computeAbstractModelValue(Node n)
61 {
62 return computeModelValue(n, false);
63 }
64
65 Node NlModel::computeModelValue(Node n, bool isConcrete)
66 {
67 unsigned index = isConcrete ? 0 : 1;
68 std::map<Node, Node>::iterator it = d_mv[index].find(n);
69 if (it != d_mv[index].end())
70 {
71 return it->second;
72 }
73 Trace("nl-ext-mv-debug") << "computeModelValue " << n << ", index=" << index
74 << std::endl;
75 Node ret;
76 if (n.isConst())
77 {
78 ret = n;
79 }
80 else if (index == 1
81 && (n.getKind() == NONLINEAR_MULT
82 || isTranscendentalKind(n.getKind())))
83 {
84 if (hasTerm(n))
85 {
86 // use model value for abstraction
87 ret = getRepresentative(n);
88 }
89 else
90 {
91 // abstraction does not exist, use model value
92 ret = getValueInternal(n);
93 }
94 }
95 else if (n.getNumChildren() == 0)
96 {
97 if (n.getKind() == PI)
98 {
99 // we are interested in the exact value of PI, which cannot be computed.
100 // hence, we return PI itself when asked for the concrete value.
101 ret = n;
102 }
103 else
104 {
105 ret = getValueInternal(n);
106 }
107 }
108 else
109 {
110 // otherwise, compute true value
111 std::vector<Node> children;
112 if (n.getMetaKind() == metakind::PARAMETERIZED)
113 {
114 children.push_back(n.getOperator());
115 }
116 for (unsigned i = 0; i < n.getNumChildren(); i++)
117 {
118 Node mc = computeModelValue(n[i], isConcrete);
119 children.push_back(mc);
120 }
121 ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
122 if (n.getKind() == APPLY_UF)
123 {
124 ret = getValueInternal(ret);
125 }
126 else
127 {
128 ret = Rewriter::rewrite(ret);
129 }
130 }
131 Trace("nl-ext-mv-debug") << "computed " << (index == 0 ? "M" : "M_A") << "["
132 << n << "] = " << ret << std::endl;
133 d_mv[index][n] = ret;
134 return ret;
135 }
136
137 Node NlModel::getValueInternal(Node n) const
138 {
139 return d_model->getValue(n);
140 }
141
142 bool NlModel::hasTerm(Node n) const
143 {
144 return d_model->hasTerm(n);
145 }
146
147 Node NlModel::getRepresentative(Node n) const
148 {
149 return d_model->getRepresentative(n);
150 }
151
152 int NlModel::compare(Node i, Node j, bool isConcrete, bool isAbsolute)
153 {
154 Node ci = computeModelValue(i, isConcrete);
155 Node cj = computeModelValue(j, isConcrete);
156 if (ci.isConst())
157 {
158 if (cj.isConst())
159 {
160 return compareValue(ci, cj, isAbsolute);
161 }
162 return 1;
163 }
164 return cj.isConst() ? -1 : 0;
165 }
166
167 int NlModel::compareValue(Node i, Node j, bool isAbsolute) const
168 {
169 Assert(i.isConst() && j.isConst());
170 int ret;
171 if (i == j)
172 {
173 ret = 0;
174 }
175 else if (!isAbsolute)
176 {
177 ret = i.getConst<Rational>() < j.getConst<Rational>() ? 1 : -1;
178 }
179 else
180 {
181 ret = (i.getConst<Rational>().abs() == j.getConst<Rational>().abs()
182 ? 0
183 : (i.getConst<Rational>().abs() < j.getConst<Rational>().abs()
184 ? 1
185 : -1));
186 }
187 return ret;
188 }
189
190 bool NlModel::checkModel(const std::vector<Node>& assertions,
191 const std::vector<Node>& false_asserts,
192 unsigned d,
193 std::vector<Node>& lemmas,
194 std::vector<Node>& gs)
195 {
196 Trace("nl-ext-cm-debug") << " solve for equalities..." << std::endl;
197 for (const Node& atom : false_asserts)
198 {
199 // see if it corresponds to a univariate polynomial equation of degree two
200 if (atom.getKind() == EQUAL)
201 {
202 if (!solveEqualitySimple(atom, d, lemmas))
203 {
204 // no chance we will satisfy this equality
205 Trace("nl-ext-cm") << "...check-model : failed to solve equality : "
206 << atom << std::endl;
207 }
208 }
209 }
210
211 // all remaining variables are constrained to their exact model values
212 Trace("nl-ext-cm-debug") << " set exact bounds for remaining variables..."
213 << std::endl;
214 std::unordered_set<TNode, TNodeHashFunction> visited;
215 std::vector<TNode> visit;
216 TNode cur;
217 for (const Node& a : assertions)
218 {
219 visit.push_back(a);
220 do
221 {
222 cur = visit.back();
223 visit.pop_back();
224 if (visited.find(cur) == visited.end())
225 {
226 visited.insert(cur);
227 if (cur.getType().isReal() && !cur.isConst())
228 {
229 Kind k = cur.getKind();
230 if (k != MULT && k != PLUS && k != NONLINEAR_MULT
231 && !isTranscendentalKind(k))
232 {
233 // if we have not set an approximate bound for it
234 if (!hasCheckModelAssignment(cur))
235 {
236 // set its exact model value in the substitution
237 Node curv = computeConcreteModelValue(cur);
238 Trace("nl-ext-cm")
239 << "check-model-bound : exact : " << cur << " = ";
240 printRationalApprox("nl-ext-cm", curv);
241 Trace("nl-ext-cm") << std::endl;
242 bool ret = addCheckModelSubstitution(cur, curv);
243 AlwaysAssert(ret);
244 }
245 }
246 }
247 for (const Node& cn : cur)
248 {
249 visit.push_back(cn);
250 }
251 }
252 } while (!visit.empty());
253 }
254
255 Trace("nl-ext-cm-debug") << " check assertions..." << std::endl;
256 std::vector<Node> check_assertions;
257 for (const Node& a : assertions)
258 {
259 if (d_check_model_solved.find(a) == d_check_model_solved.end())
260 {
261 Node av = a;
262 // apply the substitution to a
263 if (!d_check_model_vars.empty())
264 {
265 av = av.substitute(d_check_model_vars.begin(),
266 d_check_model_vars.end(),
267 d_check_model_subs.begin(),
268 d_check_model_subs.end());
269 av = Rewriter::rewrite(av);
270 }
271 // simple check literal
272 if (!simpleCheckModelLit(av))
273 {
274 Trace("nl-ext-cm") << "...check-model : assertion failed : " << a
275 << std::endl;
276 check_assertions.push_back(av);
277 Trace("nl-ext-cm-debug")
278 << "...check-model : failed assertion, value : " << av << std::endl;
279 }
280 }
281 }
282
283 if (!check_assertions.empty())
284 {
285 Trace("nl-ext-cm") << "...simple check failed." << std::endl;
286 // TODO (#1450) check model for general case
287 return false;
288 }
289 Trace("nl-ext-cm") << "...simple check succeeded!" << std::endl;
290
291 // must assert and re-check if produce models is true
292 if (options::produceModels())
293 {
294 NodeManager* nm = NodeManager::currentNM();
295 // model guard whose semantics is "the model we constructed holds"
296 Node mg = nm->mkSkolem("model", nm->booleanType());
297 gs.push_back(mg);
298 // assert the constructed model as assertions
299 for (const std::pair<const Node, std::pair<Node, Node> > cb :
300 d_check_model_bounds)
301 {
302 Node l = cb.second.first;
303 Node u = cb.second.second;
304 Node v = cb.first;
305 Node pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v));
306 pred = nm->mkNode(OR, mg.negate(), pred);
307 lemmas.push_back(pred);
308 }
309 }
310 return true;
311 }
312
313 bool NlModel::addCheckModelSubstitution(TNode v, TNode s)
314 {
315 // should not substitute the same variable twice
316 Trace("nl-ext-model") << "* check model substitution : " << v << " -> " << s
317 << std::endl;
318 // should not set exact bound more than once
319 if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
320 != d_check_model_vars.end())
321 {
322 Trace("nl-ext-model") << "...ERROR: already has value." << std::endl;
323 // this should never happen since substitutions should be applied eagerly
324 Assert(false);
325 return false;
326 }
327 // if we previously had an approximate bound, the exact bound should be in its
328 // range
329 std::map<Node, std::pair<Node, Node> >::iterator itb =
330 d_check_model_bounds.find(v);
331 if (itb != d_check_model_bounds.end())
332 {
333 if (s.getConst<Rational>() >= itb->second.first.getConst<Rational>()
334 || s.getConst<Rational>() <= itb->second.second.getConst<Rational>())
335 {
336 Trace("nl-ext-model")
337 << "...ERROR: already has bound which is out of range." << std::endl;
338 return false;
339 }
340 }
341 for (unsigned i = 0, size = d_check_model_subs.size(); i < size; i++)
342 {
343 Node ms = d_check_model_subs[i];
344 Node mss = ms.substitute(v, s);
345 if (mss != ms)
346 {
347 mss = Rewriter::rewrite(mss);
348 }
349 d_check_model_subs[i] = mss;
350 }
351 d_check_model_vars.push_back(v);
352 d_check_model_subs.push_back(s);
353 return true;
354 }
355
356 bool NlModel::addCheckModelBound(TNode v, TNode l, TNode u)
357 {
358 Trace("nl-ext-model") << "* check model bound : " << v << " -> [" << l << " "
359 << u << "]" << std::endl;
360 if (l == u)
361 {
362 // bound is exact, can add as substitution
363 return addCheckModelSubstitution(v, l);
364 }
365 // should not set a bound for a value that is exact
366 if (std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
367 != d_check_model_vars.end())
368 {
369 Trace("nl-ext-model")
370 << "...ERROR: setting bound for variable that already has exact value."
371 << std::endl;
372 Assert(false);
373 return false;
374 }
375 Assert(l.isConst());
376 Assert(u.isConst());
377 Assert(l.getConst<Rational>() <= u.getConst<Rational>());
378 d_check_model_bounds[v] = std::pair<Node, Node>(l, u);
379 if (Trace.isOn("nl-ext-cm"))
380 {
381 Trace("nl-ext-cm") << "check-model-bound : approximate : ";
382 printRationalApprox("nl-ext-cm", l);
383 Trace("nl-ext-cm") << " <= " << v << " <= ";
384 printRationalApprox("nl-ext-cm", u);
385 Trace("nl-ext-cm") << std::endl;
386 }
387 return true;
388 }
389
390 bool NlModel::hasCheckModelAssignment(Node v) const
391 {
392 if (d_check_model_bounds.find(v) != d_check_model_bounds.end())
393 {
394 return true;
395 }
396 return std::find(d_check_model_vars.begin(), d_check_model_vars.end(), v)
397 != d_check_model_vars.end();
398 }
399
400 void NlModel::setUsedApproximate() { d_used_approx = true; }
401
402 bool NlModel::usedApproximate() const { return d_used_approx; }
403
404 bool NlModel::solveEqualitySimple(Node eq,
405 unsigned d,
406 std::vector<Node>& lemmas)
407 {
408 Node seq = eq;
409 if (!d_check_model_vars.empty())
410 {
411 seq = eq.substitute(d_check_model_vars.begin(),
412 d_check_model_vars.end(),
413 d_check_model_subs.begin(),
414 d_check_model_subs.end());
415 seq = Rewriter::rewrite(seq);
416 if (seq.isConst())
417 {
418 if (seq.getConst<bool>())
419 {
420 d_check_model_solved[eq] = Node::null();
421 return true;
422 }
423 return false;
424 }
425 }
426 Trace("nl-ext-cms") << "simple solve equality " << seq << "..." << std::endl;
427 Assert(seq.getKind() == EQUAL);
428 std::map<Node, Node> msum;
429 if (!ArithMSum::getMonomialSumLit(seq, msum))
430 {
431 Trace("nl-ext-cms") << "...fail, could not determine monomial sum."
432 << std::endl;
433 return false;
434 }
435 bool is_valid = true;
436 // the variable we will solve a quadratic equation for
437 Node var;
438 Node a = d_zero;
439 Node b = d_zero;
440 Node c = d_zero;
441 NodeManager* nm = NodeManager::currentNM();
442 // the list of variables that occur as a monomial in msum, and whose value
443 // is so far unconstrained in the model.
444 std::unordered_set<Node, NodeHashFunction> unc_vars;
445 // the list of variables that occur as a factor in a monomial, and whose
446 // value is so far unconstrained in the model.
447 std::unordered_set<Node, NodeHashFunction> unc_vars_factor;
448 for (std::pair<const Node, Node>& m : msum)
449 {
450 Node v = m.first;
451 Node coeff = m.second.isNull() ? d_one : m.second;
452 if (v.isNull())
453 {
454 c = coeff;
455 }
456 else if (v.getKind() == NONLINEAR_MULT)
457 {
458 if (v.getNumChildren() == 2 && v[0].isVar() && v[0] == v[1]
459 && (var.isNull() || var == v[0]))
460 {
461 // may solve quadratic
462 a = coeff;
463 var = v[0];
464 }
465 else
466 {
467 is_valid = false;
468 Trace("nl-ext-cms-debug")
469 << "...invalid due to non-linear monomial " << v << std::endl;
470 // may wish to set an exact bound for a factor and repeat
471 for (const Node& vc : v)
472 {
473 unc_vars_factor.insert(vc);
474 }
475 }
476 }
477 else if (!v.isVar() || (!var.isNull() && var != v))
478 {
479 Trace("nl-ext-cms-debug")
480 << "...invalid due to factor " << v << std::endl;
481 // cannot solve multivariate
482 if (is_valid)
483 {
484 is_valid = false;
485 // if b is non-zero, then var is also an unconstrained variable
486 if (b != d_zero)
487 {
488 unc_vars.insert(var);
489 unc_vars_factor.insert(var);
490 }
491 }
492 // if v is unconstrained, we may turn this equality into a substitution
493 unc_vars.insert(v);
494 unc_vars_factor.insert(v);
495 }
496 else
497 {
498 // set the variable to solve for
499 b = coeff;
500 var = v;
501 }
502 }
503 if (!is_valid)
504 {
505 // see if we can solve for a variable?
506 for (const Node& uv : unc_vars)
507 {
508 Trace("nl-ext-cm-debug") << "check subs var : " << uv << std::endl;
509 // cannot already have a bound
510 if (uv.isVar() && !hasCheckModelAssignment(uv))
511 {
512 Node slv;
513 Node veqc;
514 if (ArithMSum::isolate(uv, msum, veqc, slv, EQUAL) != 0)
515 {
516 Assert(!slv.isNull());
517 // currently do not support substitution-with-coefficients
518 if (veqc.isNull() && !expr::hasSubterm(slv, uv))
519 {
520 Trace("nl-ext-cm")
521 << "check-model-subs : " << uv << " -> " << slv << std::endl;
522 bool ret = addCheckModelSubstitution(uv, slv);
523 if (ret)
524 {
525 Trace("nl-ext-cms") << "...success, model substitution " << uv
526 << " -> " << slv << std::endl;
527 d_check_model_solved[eq] = uv;
528 }
529 return ret;
530 }
531 }
532 }
533 }
534 // see if we can assign a variable to a constant
535 for (const Node& uvf : unc_vars_factor)
536 {
537 Trace("nl-ext-cm-debug") << "check set var : " << uvf << std::endl;
538 // cannot already have a bound
539 if (uvf.isVar() && !hasCheckModelAssignment(uvf))
540 {
541 Node uvfv = computeConcreteModelValue(uvf);
542 Trace("nl-ext-cm") << "check-model-bound : exact : " << uvf << " = ";
543 printRationalApprox("nl-ext-cm", uvfv);
544 Trace("nl-ext-cm") << std::endl;
545 bool ret = addCheckModelSubstitution(uvf, uvfv);
546 // recurse
547 return ret ? solveEqualitySimple(eq, d, lemmas) : false;
548 }
549 }
550 Trace("nl-ext-cms") << "...fail due to constrained invalid terms."
551 << std::endl;
552 return false;
553 }
554 else if (var.isNull() || var.getType().isInteger())
555 {
556 // cannot solve quadratic equations for integer variables
557 Trace("nl-ext-cms") << "...fail due to variable to solve for." << std::endl;
558 return false;
559 }
560
561 // we are linear, it is simple
562 if (a == d_zero)
563 {
564 if (b == d_zero)
565 {
566 Trace("nl-ext-cms") << "...fail due to zero a/b." << std::endl;
567 Assert(false);
568 return false;
569 }
570 Node val = nm->mkConst(-c.getConst<Rational>() / b.getConst<Rational>());
571 Trace("nl-ext-cm") << "check-model-bound : exact : " << var << " = ";
572 printRationalApprox("nl-ext-cm", val);
573 Trace("nl-ext-cm") << std::endl;
574 bool ret = addCheckModelSubstitution(var, val);
575 if (ret)
576 {
577 Trace("nl-ext-cms") << "...success, solved linear." << std::endl;
578 d_check_model_solved[eq] = var;
579 }
580 return ret;
581 }
582 Trace("nl-ext-quad") << "Solve quadratic : " << seq << std::endl;
583 Trace("nl-ext-quad") << " a : " << a << std::endl;
584 Trace("nl-ext-quad") << " b : " << b << std::endl;
585 Trace("nl-ext-quad") << " c : " << c << std::endl;
586 Node two_a = nm->mkNode(MULT, d_two, a);
587 two_a = Rewriter::rewrite(two_a);
588 Node sqrt_val = nm->mkNode(
589 MINUS, nm->mkNode(MULT, b, b), nm->mkNode(MULT, d_two, two_a, c));
590 sqrt_val = Rewriter::rewrite(sqrt_val);
591 Trace("nl-ext-quad") << "Will approximate sqrt " << sqrt_val << std::endl;
592 Assert(sqrt_val.isConst());
593 // if it is negative, then we are in conflict
594 if (sqrt_val.getConst<Rational>().sgn() == -1)
595 {
596 Node conf = seq.negate();
597 Trace("nl-ext-lemma") << "NlModel::Lemma : quadratic no root : " << conf
598 << std::endl;
599 lemmas.push_back(conf);
600 Trace("nl-ext-cms") << "...fail due to negative discriminant." << std::endl;
601 return false;
602 }
603 if (hasCheckModelAssignment(var))
604 {
605 Trace("nl-ext-cms") << "...fail due to bounds on variable to solve for."
606 << std::endl;
607 // two quadratic equations for same variable, give up
608 return false;
609 }
610 // approximate the square root of sqrt_val
611 Node l, u;
612 if (!getApproximateSqrt(sqrt_val, l, u, 15 + d))
613 {
614 Trace("nl-ext-cms") << "...fail, could not approximate sqrt." << std::endl;
615 return false;
616 }
617 d_used_approx = true;
618 Trace("nl-ext-quad") << "...got " << l << " <= sqrt(" << sqrt_val
619 << ") <= " << u << std::endl;
620 Node negb = nm->mkConst(-b.getConst<Rational>());
621 Node coeffa = nm->mkConst(Rational(1) / two_a.getConst<Rational>());
622 // two possible bound regions
623 Node bounds[2][2];
624 Node diff_bound[2];
625 Node m_var = computeConcreteModelValue(var);
626 Assert(m_var.isConst());
627 for (unsigned r = 0; r < 2; r++)
628 {
629 for (unsigned b = 0; b < 2; b++)
630 {
631 Node val = b == 0 ? l : u;
632 // (-b +- approx_sqrt( b^2 - 4ac ))/2a
633 Node approx = nm->mkNode(
634 MULT, coeffa, nm->mkNode(r == 0 ? MINUS : PLUS, negb, val));
635 approx = Rewriter::rewrite(approx);
636 bounds[r][b] = approx;
637 Assert(approx.isConst());
638 }
639 if (bounds[r][0].getConst<Rational>() > bounds[r][1].getConst<Rational>())
640 {
641 // ensure bound is (lower, upper)
642 Node tmp = bounds[r][0];
643 bounds[r][0] = bounds[r][1];
644 bounds[r][1] = tmp;
645 }
646 Node diff =
647 nm->mkNode(MINUS,
648 m_var,
649 nm->mkNode(MULT,
650 nm->mkConst(Rational(1) / Rational(2)),
651 nm->mkNode(PLUS, bounds[r][0], bounds[r][1])));
652 Trace("nl-ext-cm-debug") << "Bound option #" << r << " : ";
653 printRationalApprox("nl-ext-cm-debug", bounds[r][0]);
654 Trace("nl-ext-cm-debug") << "...";
655 printRationalApprox("nl-ext-cm-debug", bounds[r][1]);
656 Trace("nl-ext-cm-debug") << std::endl;
657 diff = Rewriter::rewrite(diff);
658 Assert(diff.isConst());
659 diff = nm->mkConst(diff.getConst<Rational>().abs());
660 diff_bound[r] = diff;
661 Trace("nl-ext-cm-debug") << "...diff from model value (";
662 printRationalApprox("nl-ext-cm-debug", m_var);
663 Trace("nl-ext-cm-debug") << ") is ";
664 printRationalApprox("nl-ext-cm-debug", diff_bound[r]);
665 Trace("nl-ext-cm-debug") << std::endl;
666 }
667 // take the one that var is closer to in the model
668 Node cmp = nm->mkNode(GEQ, diff_bound[0], diff_bound[1]);
669 cmp = Rewriter::rewrite(cmp);
670 Assert(cmp.isConst());
671 unsigned r_use_index = cmp == d_true ? 1 : 0;
672 Trace("nl-ext-cm") << "check-model-bound : approximate (sqrt) : ";
673 printRationalApprox("nl-ext-cm", bounds[r_use_index][0]);
674 Trace("nl-ext-cm") << " <= " << var << " <= ";
675 printRationalApprox("nl-ext-cm", bounds[r_use_index][1]);
676 Trace("nl-ext-cm") << std::endl;
677 bool ret =
678 addCheckModelBound(var, bounds[r_use_index][0], bounds[r_use_index][1]);
679 if (ret)
680 {
681 d_check_model_solved[eq] = var;
682 Trace("nl-ext-cms") << "...success, solved quadratic." << std::endl;
683 }
684 return ret;
685 }
686
687 bool NlModel::simpleCheckModelLit(Node lit)
688 {
689 Trace("nl-ext-cms") << "*** Simple check-model lit for " << lit << "..."
690 << std::endl;
691 if (lit.isConst())
692 {
693 Trace("nl-ext-cms") << " return constant." << std::endl;
694 return lit.getConst<bool>();
695 }
696 NodeManager* nm = NodeManager::currentNM();
697 bool pol = lit.getKind() != kind::NOT;
698 Node atom = lit.getKind() == kind::NOT ? lit[0] : lit;
699
700 if (atom.getKind() == EQUAL)
701 {
702 // x = a is ( x >= a ^ x <= a )
703 for (unsigned i = 0; i < 2; i++)
704 {
705 Node lit = nm->mkNode(GEQ, atom[i], atom[1 - i]);
706 if (!pol)
707 {
708 lit = lit.negate();
709 }
710 lit = Rewriter::rewrite(lit);
711 bool success = simpleCheckModelLit(lit);
712 if (success != pol)
713 {
714 // false != true -> one conjunct of equality is false, we fail
715 // true != false -> one disjunct of disequality is true, we succeed
716 return success;
717 }
718 }
719 // both checks passed and polarity is true, or both checks failed and
720 // polarity is false
721 return pol;
722 }
723 else if (atom.getKind() != GEQ)
724 {
725 Trace("nl-ext-cms") << " failed due to unknown literal." << std::endl;
726 return false;
727 }
728 // get the monomial sum
729 std::map<Node, Node> msum;
730 if (!ArithMSum::getMonomialSumLit(atom, msum))
731 {
732 Trace("nl-ext-cms") << " failed due to get msum." << std::endl;
733 return false;
734 }
735 // simple interval analysis
736 if (simpleCheckModelMsum(msum, pol))
737 {
738 return true;
739 }
740 // can also try reasoning about univariate quadratic equations
741 Trace("nl-ext-cms-debug")
742 << "* Try univariate quadratic analysis..." << std::endl;
743 std::vector<Node> vs_invalid;
744 std::unordered_set<Node, NodeHashFunction> vs;
745 std::map<Node, Node> v_a;
746 std::map<Node, Node> v_b;
747 // get coefficients...
748 for (std::pair<const Node, Node>& m : msum)
749 {
750 Node v = m.first;
751 if (!v.isNull())
752 {
753 if (v.isVar())
754 {
755 v_b[v] = m.second.isNull() ? d_one : m.second;
756 vs.insert(v);
757 }
758 else if (v.getKind() == NONLINEAR_MULT && v.getNumChildren() == 2
759 && v[0] == v[1] && v[0].isVar())
760 {
761 v_a[v[0]] = m.second.isNull() ? d_one : m.second;
762 vs.insert(v[0]);
763 }
764 else
765 {
766 vs_invalid.push_back(v);
767 }
768 }
769 }
770 // solve the valid variables...
771 Node invalid_vsum = vs_invalid.empty() ? d_zero
772 : (vs_invalid.size() == 1
773 ? vs_invalid[0]
774 : nm->mkNode(PLUS, vs_invalid));
775 // substitution to try
776 std::vector<Node> qvars;
777 std::vector<Node> qsubs;
778 for (const Node& v : vs)
779 {
780 // is it a valid variable?
781 std::map<Node, std::pair<Node, Node> >::iterator bit =
782 d_check_model_bounds.find(v);
783 if (!expr::hasSubterm(invalid_vsum, v) && bit != d_check_model_bounds.end())
784 {
785 std::map<Node, Node>::iterator it = v_a.find(v);
786 if (it != v_a.end())
787 {
788 Node a = it->second;
789 Assert(a.isConst());
790 int asgn = a.getConst<Rational>().sgn();
791 Assert(asgn != 0);
792 Node t = nm->mkNode(MULT, a, v, v);
793 Node b = d_zero;
794 it = v_b.find(v);
795 if (it != v_b.end())
796 {
797 b = it->second;
798 t = nm->mkNode(PLUS, t, nm->mkNode(MULT, b, v));
799 }
800 t = Rewriter::rewrite(t);
801 Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic "
802 << t << "..." << std::endl;
803 Trace("nl-ext-cms-debug") << " a = " << a << std::endl;
804 Trace("nl-ext-cms-debug") << " b = " << b << std::endl;
805 // find maximal/minimal value on the interval
806 Node apex = nm->mkNode(
807 DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a));
808 apex = Rewriter::rewrite(apex);
809 Assert(apex.isConst());
810 // for lower, upper, whether we are greater than the apex
811 bool cmp[2];
812 Node boundn[2];
813 for (unsigned r = 0; r < 2; r++)
814 {
815 boundn[r] = r == 0 ? bit->second.first : bit->second.second;
816 Node cmpn = nm->mkNode(GT, boundn[r], apex);
817 cmpn = Rewriter::rewrite(cmpn);
818 Assert(cmpn.isConst());
819 cmp[r] = cmpn.getConst<bool>();
820 }
821 Trace("nl-ext-cms-debug") << " apex " << apex << std::endl;
822 Trace("nl-ext-cms-debug")
823 << " lower " << boundn[0] << ", cmp: " << cmp[0] << std::endl;
824 Trace("nl-ext-cms-debug")
825 << " upper " << boundn[1] << ", cmp: " << cmp[1] << std::endl;
826 Assert(boundn[0].getConst<Rational>()
827 <= boundn[1].getConst<Rational>());
828 Node s;
829 qvars.push_back(v);
830 if (cmp[0] != cmp[1])
831 {
832 Assert(!cmp[0] && cmp[1]);
833 // does the sign match the bound?
834 if ((asgn == 1) == pol)
835 {
836 // the apex is the max/min value
837 s = apex;
838 Trace("nl-ext-cms-debug") << " ...set to apex." << std::endl;
839 }
840 else
841 {
842 // it is one of the endpoints, plug in and compare
843 Node tcmpn[2];
844 for (unsigned r = 0; r < 2; r++)
845 {
846 qsubs.push_back(boundn[r]);
847 Node ts = t.substitute(
848 qvars.begin(), qvars.end(), qsubs.begin(), qsubs.end());
849 tcmpn[r] = Rewriter::rewrite(ts);
850 qsubs.pop_back();
851 }
852 Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]);
853 Trace("nl-ext-cms-debug")
854 << " ...both sides of apex, compare " << tcmp << std::endl;
855 tcmp = Rewriter::rewrite(tcmp);
856 Assert(tcmp.isConst());
857 unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 1 : 0;
858 Trace("nl-ext-cms-debug")
859 << " ...set to " << (bindex_use == 1 ? "upper" : "lower")
860 << std::endl;
861 s = boundn[bindex_use];
862 }
863 }
864 else
865 {
866 // both to one side of the apex
867 // we figure out which bound to use (lower or upper) based on
868 // three factors:
869 // (1) whether a's sign is positive,
870 // (2) whether we are greater than the apex of the parabola,
871 // (3) the polarity of the constraint, i.e. >= or <=.
872 // there are 8 cases of these factors, which we test here.
873 unsigned bindex_use = (((asgn == 1) == cmp[0]) == pol) ? 0 : 1;
874 Trace("nl-ext-cms-debug")
875 << " ...set to " << (bindex_use == 1 ? "upper" : "lower")
876 << std::endl;
877 s = boundn[bindex_use];
878 }
879 Assert(!s.isNull());
880 qsubs.push_back(s);
881 Trace("nl-ext-cms") << "* set bound based on quadratic : " << v
882 << " -> " << s << std::endl;
883 }
884 }
885 }
886 if (!qvars.empty())
887 {
888 Assert(qvars.size() == qsubs.size());
889 Node slit =
890 lit.substitute(qvars.begin(), qvars.end(), qsubs.begin(), qsubs.end());
891 slit = Rewriter::rewrite(slit);
892 return simpleCheckModelLit(slit);
893 }
894 return false;
895 }
896
897 bool NlModel::simpleCheckModelMsum(const std::map<Node, Node>& msum, bool pol)
898 {
899 Trace("nl-ext-cms-debug") << "* Try simple interval analysis..." << std::endl;
900 NodeManager* nm = NodeManager::currentNM();
901 // map from transcendental functions to whether they were set to lower
902 // bound
903 bool simpleSuccess = true;
904 std::map<Node, bool> set_bound;
905 std::vector<Node> sum_bound;
906 for (const std::pair<const Node, Node>& m : msum)
907 {
908 Node v = m.first;
909 if (v.isNull())
910 {
911 sum_bound.push_back(m.second.isNull() ? d_one : m.second);
912 }
913 else
914 {
915 Trace("nl-ext-cms-debug") << "- monomial : " << v << std::endl;
916 // --- whether we should set a lower bound for this monomial
917 bool set_lower =
918 (m.second.isNull() || m.second.getConst<Rational>().sgn() == 1)
919 == pol;
920 Trace("nl-ext-cms-debug")
921 << "set bound to " << (set_lower ? "lower" : "upper") << std::endl;
922
923 // --- Collect variables and factors in v
924 std::vector<Node> vars;
925 std::vector<unsigned> factors;
926 if (v.getKind() == NONLINEAR_MULT)
927 {
928 unsigned last_start = 0;
929 for (unsigned i = 0, nchildren = v.getNumChildren(); i < nchildren; i++)
930 {
931 // are we at the end?
932 if (i + 1 == nchildren || v[i + 1] != v[i])
933 {
934 unsigned vfact = 1 + (i - last_start);
935 last_start = (i + 1);
936 vars.push_back(v[i]);
937 factors.push_back(vfact);
938 }
939 }
940 }
941 else
942 {
943 vars.push_back(v);
944 factors.push_back(1);
945 }
946
947 // --- Get the lower and upper bounds and sign information.
948 // Whether we have an (odd) number of negative factors in vars, apart
949 // from the variable at choose_index.
950 bool has_neg_factor = false;
951 int choose_index = -1;
952 std::vector<Node> ls;
953 std::vector<Node> us;
954 // the relevant sign information for variables with odd exponents:
955 // 1: both signs of the interval of this variable are positive,
956 // -1: both signs of the interval of this variable are negative.
957 std::vector<int> signs;
958 Trace("nl-ext-cms-debug") << "get sign information..." << std::endl;
959 for (unsigned i = 0, size = vars.size(); i < size; i++)
960 {
961 Node vc = vars[i];
962 unsigned vcfact = factors[i];
963 if (Trace.isOn("nl-ext-cms-debug"))
964 {
965 Trace("nl-ext-cms-debug") << "-- " << vc;
966 if (vcfact > 1)
967 {
968 Trace("nl-ext-cms-debug") << "^" << vcfact;
969 }
970 Trace("nl-ext-cms-debug") << " ";
971 }
972 std::map<Node, std::pair<Node, Node> >::iterator bit =
973 d_check_model_bounds.find(vc);
974 // if there is a model bound for this term
975 if (bit != d_check_model_bounds.end())
976 {
977 Node l = bit->second.first;
978 Node u = bit->second.second;
979 ls.push_back(l);
980 us.push_back(u);
981 int vsign = 0;
982 if (vcfact % 2 == 1)
983 {
984 vsign = 1;
985 int lsgn = l.getConst<Rational>().sgn();
986 int usgn = u.getConst<Rational>().sgn();
987 Trace("nl-ext-cms-debug")
988 << "bound_sign(" << lsgn << "," << usgn << ") ";
989 if (lsgn == -1)
990 {
991 if (usgn < 1)
992 {
993 // must have a negative factor
994 has_neg_factor = !has_neg_factor;
995 vsign = -1;
996 }
997 else if (choose_index == -1)
998 {
999 // set the choose index to this
1000 choose_index = i;
1001 vsign = 0;
1002 }
1003 else
1004 {
1005 // ambiguous, can't determine the bound
1006 Trace("nl-ext-cms")
1007 << " failed due to ambiguious monomial." << std::endl;
1008 return false;
1009 }
1010 }
1011 }
1012 Trace("nl-ext-cms-debug") << " -> " << vsign << std::endl;
1013 signs.push_back(vsign);
1014 }
1015 else
1016 {
1017 Trace("nl-ext-cms-debug") << std::endl;
1018 Trace("nl-ext-cms")
1019 << " failed due to unknown bound for " << vc << std::endl;
1020 // should either assign a model bound or eliminate the variable
1021 // via substitution
1022 Assert(false);
1023 return false;
1024 }
1025 }
1026 // whether we will try to minimize/maximize (-1/1) the absolute value
1027 int setAbs = (set_lower == has_neg_factor) ? 1 : -1;
1028 Trace("nl-ext-cms-debug")
1029 << "set absolute value to " << (setAbs == 1 ? "maximal" : "minimal")
1030 << std::endl;
1031
1032 std::vector<Node> vbs;
1033 Trace("nl-ext-cms-debug") << "set bounds..." << std::endl;
1034 for (unsigned i = 0, size = vars.size(); i < size; i++)
1035 {
1036 Node vc = vars[i];
1037 unsigned vcfact = factors[i];
1038 Node l = ls[i];
1039 Node u = us[i];
1040 bool vc_set_lower;
1041 int vcsign = signs[i];
1042 Trace("nl-ext-cms-debug")
1043 << "Bounds for " << vc << " : " << l << ", " << u
1044 << ", sign : " << vcsign << ", factor : " << vcfact << std::endl;
1045 if (l == u)
1046 {
1047 // by convention, always say it is lower if they are the same
1048 vc_set_lower = true;
1049 Trace("nl-ext-cms-debug")
1050 << "..." << vc << " equal bound, set to lower" << std::endl;
1051 }
1052 else
1053 {
1054 if (vcfact % 2 == 0)
1055 {
1056 // minimize or maximize its absolute value
1057 Rational la = l.getConst<Rational>().abs();
1058 Rational ua = u.getConst<Rational>().abs();
1059 if (la == ua)
1060 {
1061 // by convention, always say it is lower if abs are the same
1062 vc_set_lower = true;
1063 Trace("nl-ext-cms-debug")
1064 << "..." << vc << " equal abs, set to lower" << std::endl;
1065 }
1066 else
1067 {
1068 vc_set_lower = (la > ua) == (setAbs == 1);
1069 }
1070 }
1071 else if (signs[i] == 0)
1072 {
1073 // we choose this index to match the overall set_lower
1074 vc_set_lower = set_lower;
1075 }
1076 else
1077 {
1078 vc_set_lower = (signs[i] != setAbs);
1079 }
1080 Trace("nl-ext-cms-debug")
1081 << "..." << vc << " set to " << (vc_set_lower ? "lower" : "upper")
1082 << std::endl;
1083 }
1084 // check whether this is a conflicting bound
1085 std::map<Node, bool>::iterator itsb = set_bound.find(vc);
1086 if (itsb == set_bound.end())
1087 {
1088 set_bound[vc] = vc_set_lower;
1089 }
1090 else if (itsb->second != vc_set_lower)
1091 {
1092 Trace("nl-ext-cms")
1093 << " failed due to conflicting bound for " << vc << std::endl;
1094 return false;
1095 }
1096 // must over/under approximate based on vc_set_lower, computed above
1097 Node vb = vc_set_lower ? l : u;
1098 for (unsigned i = 0; i < vcfact; i++)
1099 {
1100 vbs.push_back(vb);
1101 }
1102 }
1103 if (!simpleSuccess)
1104 {
1105 break;
1106 }
1107 Node vbound = vbs.size() == 1 ? vbs[0] : nm->mkNode(MULT, vbs);
1108 sum_bound.push_back(ArithMSum::mkCoeffTerm(m.second, vbound));
1109 }
1110 }
1111 // if the exact bound was computed via simple analysis above
1112 // make the bound
1113 Node bound;
1114 if (sum_bound.size() > 1)
1115 {
1116 bound = nm->mkNode(kind::PLUS, sum_bound);
1117 }
1118 else if (sum_bound.size() == 1)
1119 {
1120 bound = sum_bound[0];
1121 }
1122 else
1123 {
1124 bound = d_zero;
1125 }
1126 // make the comparison
1127 Node comp = nm->mkNode(kind::GEQ, bound, d_zero);
1128 if (!pol)
1129 {
1130 comp = comp.negate();
1131 }
1132 Trace("nl-ext-cms") << " comparison is : " << comp << std::endl;
1133 comp = Rewriter::rewrite(comp);
1134 Assert(comp.isConst());
1135 Trace("nl-ext-cms") << " returned : " << comp << std::endl;
1136 return comp == d_true;
1137 }
1138
1139 bool NlModel::isRefineableTfFun(Node tf)
1140 {
1141 Assert(tf.getKind() == SINE || tf.getKind() == EXPONENTIAL);
1142 if (tf.getKind() == SINE)
1143 {
1144 // we do not consider e.g. sin( -1*x ), since considering sin( x ) will
1145 // have the same effect. We also do not consider sin(x+y) since this is
1146 // handled by introducing a fresh variable (see the map d_tr_base in
1147 // NonlinearExtension).
1148 if (!tf[0].isVar())
1149 {
1150 return false;
1151 }
1152 }
1153 // Figure 3 : c
1154 Node c = computeAbstractModelValue(tf[0]);
1155 Assert(c.isConst());
1156 int csign = c.getConst<Rational>().sgn();
1157 if (csign == 0)
1158 {
1159 return false;
1160 }
1161 return true;
1162 }
1163
1164 bool NlModel::getApproximateSqrt(Node c, Node& l, Node& u, unsigned iter) const
1165 {
1166 Assert(c.isConst());
1167 if (c == d_one || c == d_zero)
1168 {
1169 l = c;
1170 u = c;
1171 return true;
1172 }
1173 Rational rc = c.getConst<Rational>();
1174
1175 Rational rl = rc < Rational(1) ? rc : Rational(1);
1176 Rational ru = rc < Rational(1) ? Rational(1) : rc;
1177 unsigned count = 0;
1178 Rational half = Rational(1) / Rational(2);
1179 while (count < iter)
1180 {
1181 Rational curr = half * (rl + ru);
1182 Rational curr_sq = curr * curr;
1183 if (curr_sq == rc)
1184 {
1185 rl = curr_sq;
1186 ru = curr_sq;
1187 break;
1188 }
1189 else if (curr_sq < rc)
1190 {
1191 rl = curr;
1192 }
1193 else
1194 {
1195 ru = curr;
1196 }
1197 count++;
1198 }
1199
1200 NodeManager* nm = NodeManager::currentNM();
1201 l = nm->mkConst(rl);
1202 u = nm->mkConst(ru);
1203 return true;
1204 }
1205
1206 void NlModel::printModelValue(const char* c, Node n, unsigned prec) const
1207 {
1208 if (Trace.isOn(c))
1209 {
1210 Trace(c) << " " << n << " -> ";
1211 for (int i = 1; i >= 0; --i)
1212 {
1213 std::map<Node, Node>::const_iterator it = d_mv[i].find(n);
1214 Assert(it != d_mv[i].end());
1215 if (it->second.isConst())
1216 {
1217 printRationalApprox(c, it->second, prec);
1218 }
1219 else
1220 {
1221 Trace(c) << "?";
1222 }
1223 Trace(c) << (i == 1 ? " [actual: " : " ]");
1224 }
1225 Trace(c) << std::endl;
1226 }
1227 }
1228
1229 void NlModel::getModelValueRepair(std::map<Node, Node>& arithModel,
1230 std::map<Node, Node>& approximations)
1231 {
1232 // Record the approximations we used. This code calls the
1233 // recordApproximation method of the model, which overrides the model
1234 // values for variables that we solved for, using techniques specific to
1235 // this class.
1236 Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl;
1237 NodeManager* nm = NodeManager::currentNM();
1238 for (const std::pair<const Node, std::pair<Node, Node> >& cb :
1239 d_check_model_bounds)
1240 {
1241 Node l = cb.second.first;
1242 Node u = cb.second.second;
1243 Node pred;
1244 Node v = cb.first;
1245 if (l != u)
1246 {
1247 Node pred = nm->mkNode(AND, nm->mkNode(GEQ, v, l), nm->mkNode(GEQ, u, v));
1248 approximations[v] = pred;
1249 Trace("nl-model") << v << " approximated as " << pred << std::endl;
1250 }
1251 else
1252 {
1253 // overwrite
1254 arithModel[v] = l;
1255 Trace("nl-model") << v << " exact approximation is " << l << std::endl;
1256 }
1257 }
1258 // Also record the exact values we used. An exact value can be seen as a
1259 // special kind approximation of the form (choice x. x = exact_value).
1260 // Notice that the above term gets rewritten such that the choice function
1261 // is eliminated.
1262 for (size_t i = 0, num = d_check_model_vars.size(); i < num; i++)
1263 {
1264 Node v = d_check_model_vars[i];
1265 Node s = d_check_model_subs[i];
1266 // overwrite
1267 arithModel[v] = s;
1268 Trace("nl-model") << v << " solved is " << s << std::endl;
1269 }
1270 }
1271
1272 } // namespace arith
1273 } // namespace theory
1274 } // namespace CVC4