62bdf310bec42acffebd0a37934c8fbe91c2654d
1 /********************* */
4 ** Top contributors (to current version):
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
12 ** \brief Model object for the non-linear extension class
15 #include "theory/arith/nl_model.h"
17 #include "expr/node_algorithm.h"
18 #include "theory/arith/arith_msum.h"
19 #include "theory/arith/arith_utilities.h"
20 #include "theory/rewriter.h"
22 using namespace CVC4::kind
;
28 NlModel::NlModel(context::Context
* c
) : d_used_approx(false)
30 d_true
= NodeManager::currentNM()->mkConst(true);
31 d_false
= NodeManager::currentNM()->mkConst(false);
32 d_zero
= NodeManager::currentNM()->mkConst(Rational(0));
33 d_one
= NodeManager::currentNM()->mkConst(Rational(1));
34 d_two
= NodeManager::currentNM()->mkConst(Rational(2));
37 NlModel::~NlModel() {}
39 void NlModel::reset(TheoryModel
* m
)
46 void NlModel::resetCheck()
48 d_used_approx
= false;
49 d_check_model_solved
.clear();
50 d_check_model_bounds
.clear();
51 d_check_model_vars
.clear();
52 d_check_model_subs
.clear();
55 Node
NlModel::computeConcreteModelValue(Node n
)
57 return computeModelValue(n
, true);
60 Node
NlModel::computeAbstractModelValue(Node n
)
62 return computeModelValue(n
, false);
65 Node
NlModel::computeModelValue(Node n
, bool isConcrete
)
67 unsigned index
= isConcrete
? 0 : 1;
68 std::map
<Node
, Node
>::iterator it
= d_mv
[index
].find(n
);
69 if (it
!= d_mv
[index
].end())
73 Trace("nl-ext-mv-debug") << "computeModelValue " << n
<< ", index=" << index
81 && (n
.getKind() == NONLINEAR_MULT
82 || isTranscendentalKind(n
.getKind())))
86 // use model value for abstraction
87 ret
= getRepresentative(n
);
91 // abstraction does not exist, use model value
92 ret
= getValueInternal(n
);
95 else if (n
.getNumChildren() == 0)
97 if (n
.getKind() == PI
)
99 // we are interested in the exact value of PI, which cannot be computed.
100 // hence, we return PI itself when asked for the concrete value.
105 ret
= getValueInternal(n
);
110 // otherwise, compute true value
111 std::vector
<Node
> children
;
112 if (n
.getMetaKind() == metakind::PARAMETERIZED
)
114 children
.push_back(n
.getOperator());
116 for (unsigned i
= 0; i
< n
.getNumChildren(); i
++)
118 Node mc
= computeModelValue(n
[i
], isConcrete
);
119 children
.push_back(mc
);
121 ret
= NodeManager::currentNM()->mkNode(n
.getKind(), children
);
122 if (n
.getKind() == APPLY_UF
)
124 ret
= getValueInternal(ret
);
128 ret
= Rewriter::rewrite(ret
);
131 Trace("nl-ext-mv-debug") << "computed " << (index
== 0 ? "M" : "M_A") << "["
132 << n
<< "] = " << ret
<< std::endl
;
133 d_mv
[index
][n
] = ret
;
137 Node
NlModel::getValueInternal(Node n
) const
139 return d_model
->getValue(n
);
142 bool NlModel::hasTerm(Node n
) const
144 return d_model
->hasTerm(n
);
147 Node
NlModel::getRepresentative(Node n
) const
149 return d_model
->getRepresentative(n
);
152 int NlModel::compare(Node i
, Node j
, bool isConcrete
, bool isAbsolute
)
154 Node ci
= computeModelValue(i
, isConcrete
);
155 Node cj
= computeModelValue(j
, isConcrete
);
160 return compareValue(ci
, cj
, isAbsolute
);
164 return cj
.isConst() ? -1 : 0;
167 int NlModel::compareValue(Node i
, Node j
, bool isAbsolute
) const
169 Assert(i
.isConst() && j
.isConst());
175 else if (!isAbsolute
)
177 ret
= i
.getConst
<Rational
>() < j
.getConst
<Rational
>() ? 1 : -1;
181 ret
= (i
.getConst
<Rational
>().abs() == j
.getConst
<Rational
>().abs()
183 : (i
.getConst
<Rational
>().abs() < j
.getConst
<Rational
>().abs()
190 bool NlModel::checkModel(const std::vector
<Node
>& assertions
,
191 const std::vector
<Node
>& false_asserts
,
193 std::vector
<Node
>& lemmas
,
194 std::vector
<Node
>& gs
)
196 Trace("nl-ext-cm-debug") << " solve for equalities..." << std::endl
;
197 for (const Node
& atom
: false_asserts
)
199 // see if it corresponds to a univariate polynomial equation of degree two
200 if (atom
.getKind() == EQUAL
)
202 if (!solveEqualitySimple(atom
, d
, lemmas
))
204 // no chance we will satisfy this equality
205 Trace("nl-ext-cm") << "...check-model : failed to solve equality : "
206 << atom
<< std::endl
;
211 // all remaining variables are constrained to their exact model values
212 Trace("nl-ext-cm-debug") << " set exact bounds for remaining variables..."
214 std::unordered_set
<TNode
, TNodeHashFunction
> visited
;
215 std::vector
<TNode
> visit
;
217 for (const Node
& a
: assertions
)
224 if (visited
.find(cur
) == visited
.end())
227 if (cur
.getType().isReal() && !cur
.isConst())
229 Kind k
= cur
.getKind();
230 if (k
!= MULT
&& k
!= PLUS
&& k
!= NONLINEAR_MULT
231 && !isTranscendentalKind(k
))
233 // if we have not set an approximate bound for it
234 if (!hasCheckModelAssignment(cur
))
236 // set its exact model value in the substitution
237 Node curv
= computeConcreteModelValue(cur
);
239 << "check-model-bound : exact : " << cur
<< " = ";
240 printRationalApprox("nl-ext-cm", curv
);
241 Trace("nl-ext-cm") << std::endl
;
242 bool ret
= addCheckModelSubstitution(cur
, curv
);
247 for (const Node
& cn
: cur
)
252 } while (!visit
.empty());
255 Trace("nl-ext-cm-debug") << " check assertions..." << std::endl
;
256 std::vector
<Node
> check_assertions
;
257 for (const Node
& a
: assertions
)
259 if (d_check_model_solved
.find(a
) == d_check_model_solved
.end())
262 // apply the substitution to a
263 if (!d_check_model_vars
.empty())
265 av
= av
.substitute(d_check_model_vars
.begin(),
266 d_check_model_vars
.end(),
267 d_check_model_subs
.begin(),
268 d_check_model_subs
.end());
269 av
= Rewriter::rewrite(av
);
271 // simple check literal
272 if (!simpleCheckModelLit(av
))
274 Trace("nl-ext-cm") << "...check-model : assertion failed : " << a
276 check_assertions
.push_back(av
);
277 Trace("nl-ext-cm-debug")
278 << "...check-model : failed assertion, value : " << av
<< std::endl
;
283 if (!check_assertions
.empty())
285 Trace("nl-ext-cm") << "...simple check failed." << std::endl
;
286 // TODO (#1450) check model for general case
289 Trace("nl-ext-cm") << "...simple check succeeded!" << std::endl
;
291 // must assert and re-check if produce models is true
292 if (options::produceModels())
294 NodeManager
* nm
= NodeManager::currentNM();
295 // model guard whose semantics is "the model we constructed holds"
296 Node mg
= nm
->mkSkolem("model", nm
->booleanType());
298 // assert the constructed model as assertions
299 for (const std::pair
<const Node
, std::pair
<Node
, Node
> > cb
:
300 d_check_model_bounds
)
302 Node l
= cb
.second
.first
;
303 Node u
= cb
.second
.second
;
305 Node pred
= nm
->mkNode(AND
, nm
->mkNode(GEQ
, v
, l
), nm
->mkNode(GEQ
, u
, v
));
306 pred
= nm
->mkNode(OR
, mg
.negate(), pred
);
307 lemmas
.push_back(pred
);
313 bool NlModel::addCheckModelSubstitution(TNode v
, TNode s
)
315 // should not substitute the same variable twice
316 Trace("nl-ext-model") << "* check model substitution : " << v
<< " -> " << s
318 // should not set exact bound more than once
319 if (std::find(d_check_model_vars
.begin(), d_check_model_vars
.end(), v
)
320 != d_check_model_vars
.end())
322 Trace("nl-ext-model") << "...ERROR: already has value." << std::endl
;
323 // this should never happen since substitutions should be applied eagerly
327 // if we previously had an approximate bound, the exact bound should be in its
329 std::map
<Node
, std::pair
<Node
, Node
> >::iterator itb
=
330 d_check_model_bounds
.find(v
);
331 if (itb
!= d_check_model_bounds
.end())
333 if (s
.getConst
<Rational
>() >= itb
->second
.first
.getConst
<Rational
>()
334 || s
.getConst
<Rational
>() <= itb
->second
.second
.getConst
<Rational
>())
336 Trace("nl-ext-model")
337 << "...ERROR: already has bound which is out of range." << std::endl
;
341 for (unsigned i
= 0, size
= d_check_model_subs
.size(); i
< size
; i
++)
343 Node ms
= d_check_model_subs
[i
];
344 Node mss
= ms
.substitute(v
, s
);
347 mss
= Rewriter::rewrite(mss
);
349 d_check_model_subs
[i
] = mss
;
351 d_check_model_vars
.push_back(v
);
352 d_check_model_subs
.push_back(s
);
356 bool NlModel::addCheckModelBound(TNode v
, TNode l
, TNode u
)
358 Trace("nl-ext-model") << "* check model bound : " << v
<< " -> [" << l
<< " "
359 << u
<< "]" << std::endl
;
362 // bound is exact, can add as substitution
363 return addCheckModelSubstitution(v
, l
);
365 // should not set a bound for a value that is exact
366 if (std::find(d_check_model_vars
.begin(), d_check_model_vars
.end(), v
)
367 != d_check_model_vars
.end())
369 Trace("nl-ext-model")
370 << "...ERROR: setting bound for variable that already has exact value."
377 Assert(l
.getConst
<Rational
>() <= u
.getConst
<Rational
>());
378 d_check_model_bounds
[v
] = std::pair
<Node
, Node
>(l
, u
);
379 if (Trace
.isOn("nl-ext-cm"))
381 Trace("nl-ext-cm") << "check-model-bound : approximate : ";
382 printRationalApprox("nl-ext-cm", l
);
383 Trace("nl-ext-cm") << " <= " << v
<< " <= ";
384 printRationalApprox("nl-ext-cm", u
);
385 Trace("nl-ext-cm") << std::endl
;
390 bool NlModel::hasCheckModelAssignment(Node v
) const
392 if (d_check_model_bounds
.find(v
) != d_check_model_bounds
.end())
396 return std::find(d_check_model_vars
.begin(), d_check_model_vars
.end(), v
)
397 != d_check_model_vars
.end();
400 void NlModel::setUsedApproximate() { d_used_approx
= true; }
402 bool NlModel::usedApproximate() const { return d_used_approx
; }
404 bool NlModel::solveEqualitySimple(Node eq
,
406 std::vector
<Node
>& lemmas
)
409 if (!d_check_model_vars
.empty())
411 seq
= eq
.substitute(d_check_model_vars
.begin(),
412 d_check_model_vars
.end(),
413 d_check_model_subs
.begin(),
414 d_check_model_subs
.end());
415 seq
= Rewriter::rewrite(seq
);
418 if (seq
.getConst
<bool>())
420 d_check_model_solved
[eq
] = Node::null();
426 Trace("nl-ext-cms") << "simple solve equality " << seq
<< "..." << std::endl
;
427 Assert(seq
.getKind() == EQUAL
);
428 std::map
<Node
, Node
> msum
;
429 if (!ArithMSum::getMonomialSumLit(seq
, msum
))
431 Trace("nl-ext-cms") << "...fail, could not determine monomial sum."
435 bool is_valid
= true;
436 // the variable we will solve a quadratic equation for
441 NodeManager
* nm
= NodeManager::currentNM();
442 // the list of variables that occur as a monomial in msum, and whose value
443 // is so far unconstrained in the model.
444 std::unordered_set
<Node
, NodeHashFunction
> unc_vars
;
445 // the list of variables that occur as a factor in a monomial, and whose
446 // value is so far unconstrained in the model.
447 std::unordered_set
<Node
, NodeHashFunction
> unc_vars_factor
;
448 for (std::pair
<const Node
, Node
>& m
: msum
)
451 Node coeff
= m
.second
.isNull() ? d_one
: m
.second
;
456 else if (v
.getKind() == NONLINEAR_MULT
)
458 if (v
.getNumChildren() == 2 && v
[0].isVar() && v
[0] == v
[1]
459 && (var
.isNull() || var
== v
[0]))
461 // may solve quadratic
468 Trace("nl-ext-cms-debug")
469 << "...invalid due to non-linear monomial " << v
<< std::endl
;
470 // may wish to set an exact bound for a factor and repeat
471 for (const Node
& vc
: v
)
473 unc_vars_factor
.insert(vc
);
477 else if (!v
.isVar() || (!var
.isNull() && var
!= v
))
479 Trace("nl-ext-cms-debug")
480 << "...invalid due to factor " << v
<< std::endl
;
481 // cannot solve multivariate
485 // if b is non-zero, then var is also an unconstrained variable
488 unc_vars
.insert(var
);
489 unc_vars_factor
.insert(var
);
492 // if v is unconstrained, we may turn this equality into a substitution
494 unc_vars_factor
.insert(v
);
498 // set the variable to solve for
505 // see if we can solve for a variable?
506 for (const Node
& uv
: unc_vars
)
508 Trace("nl-ext-cm-debug") << "check subs var : " << uv
<< std::endl
;
509 // cannot already have a bound
510 if (uv
.isVar() && !hasCheckModelAssignment(uv
))
514 if (ArithMSum::isolate(uv
, msum
, veqc
, slv
, EQUAL
) != 0)
516 Assert(!slv
.isNull());
517 // currently do not support substitution-with-coefficients
518 if (veqc
.isNull() && !expr::hasSubterm(slv
, uv
))
521 << "check-model-subs : " << uv
<< " -> " << slv
<< std::endl
;
522 bool ret
= addCheckModelSubstitution(uv
, slv
);
525 Trace("nl-ext-cms") << "...success, model substitution " << uv
526 << " -> " << slv
<< std::endl
;
527 d_check_model_solved
[eq
] = uv
;
534 // see if we can assign a variable to a constant
535 for (const Node
& uvf
: unc_vars_factor
)
537 Trace("nl-ext-cm-debug") << "check set var : " << uvf
<< std::endl
;
538 // cannot already have a bound
539 if (uvf
.isVar() && !hasCheckModelAssignment(uvf
))
541 Node uvfv
= computeConcreteModelValue(uvf
);
542 Trace("nl-ext-cm") << "check-model-bound : exact : " << uvf
<< " = ";
543 printRationalApprox("nl-ext-cm", uvfv
);
544 Trace("nl-ext-cm") << std::endl
;
545 bool ret
= addCheckModelSubstitution(uvf
, uvfv
);
547 return ret
? solveEqualitySimple(eq
, d
, lemmas
) : false;
550 Trace("nl-ext-cms") << "...fail due to constrained invalid terms."
554 else if (var
.isNull() || var
.getType().isInteger())
556 // cannot solve quadratic equations for integer variables
557 Trace("nl-ext-cms") << "...fail due to variable to solve for." << std::endl
;
561 // we are linear, it is simple
566 Trace("nl-ext-cms") << "...fail due to zero a/b." << std::endl
;
570 Node val
= nm
->mkConst(-c
.getConst
<Rational
>() / b
.getConst
<Rational
>());
571 Trace("nl-ext-cm") << "check-model-bound : exact : " << var
<< " = ";
572 printRationalApprox("nl-ext-cm", val
);
573 Trace("nl-ext-cm") << std::endl
;
574 bool ret
= addCheckModelSubstitution(var
, val
);
577 Trace("nl-ext-cms") << "...success, solved linear." << std::endl
;
578 d_check_model_solved
[eq
] = var
;
582 Trace("nl-ext-quad") << "Solve quadratic : " << seq
<< std::endl
;
583 Trace("nl-ext-quad") << " a : " << a
<< std::endl
;
584 Trace("nl-ext-quad") << " b : " << b
<< std::endl
;
585 Trace("nl-ext-quad") << " c : " << c
<< std::endl
;
586 Node two_a
= nm
->mkNode(MULT
, d_two
, a
);
587 two_a
= Rewriter::rewrite(two_a
);
588 Node sqrt_val
= nm
->mkNode(
589 MINUS
, nm
->mkNode(MULT
, b
, b
), nm
->mkNode(MULT
, d_two
, two_a
, c
));
590 sqrt_val
= Rewriter::rewrite(sqrt_val
);
591 Trace("nl-ext-quad") << "Will approximate sqrt " << sqrt_val
<< std::endl
;
592 Assert(sqrt_val
.isConst());
593 // if it is negative, then we are in conflict
594 if (sqrt_val
.getConst
<Rational
>().sgn() == -1)
596 Node conf
= seq
.negate();
597 Trace("nl-ext-lemma") << "NlModel::Lemma : quadratic no root : " << conf
599 lemmas
.push_back(conf
);
600 Trace("nl-ext-cms") << "...fail due to negative discriminant." << std::endl
;
603 if (hasCheckModelAssignment(var
))
605 Trace("nl-ext-cms") << "...fail due to bounds on variable to solve for."
607 // two quadratic equations for same variable, give up
610 // approximate the square root of sqrt_val
612 if (!getApproximateSqrt(sqrt_val
, l
, u
, 15 + d
))
614 Trace("nl-ext-cms") << "...fail, could not approximate sqrt." << std::endl
;
617 d_used_approx
= true;
618 Trace("nl-ext-quad") << "...got " << l
<< " <= sqrt(" << sqrt_val
619 << ") <= " << u
<< std::endl
;
620 Node negb
= nm
->mkConst(-b
.getConst
<Rational
>());
621 Node coeffa
= nm
->mkConst(Rational(1) / two_a
.getConst
<Rational
>());
622 // two possible bound regions
625 Node m_var
= computeConcreteModelValue(var
);
626 Assert(m_var
.isConst());
627 for (unsigned r
= 0; r
< 2; r
++)
629 for (unsigned b
= 0; b
< 2; b
++)
631 Node val
= b
== 0 ? l
: u
;
632 // (-b +- approx_sqrt( b^2 - 4ac ))/2a
633 Node approx
= nm
->mkNode(
634 MULT
, coeffa
, nm
->mkNode(r
== 0 ? MINUS
: PLUS
, negb
, val
));
635 approx
= Rewriter::rewrite(approx
);
636 bounds
[r
][b
] = approx
;
637 Assert(approx
.isConst());
639 if (bounds
[r
][0].getConst
<Rational
>() > bounds
[r
][1].getConst
<Rational
>())
641 // ensure bound is (lower, upper)
642 Node tmp
= bounds
[r
][0];
643 bounds
[r
][0] = bounds
[r
][1];
650 nm
->mkConst(Rational(1) / Rational(2)),
651 nm
->mkNode(PLUS
, bounds
[r
][0], bounds
[r
][1])));
652 Trace("nl-ext-cm-debug") << "Bound option #" << r
<< " : ";
653 printRationalApprox("nl-ext-cm-debug", bounds
[r
][0]);
654 Trace("nl-ext-cm-debug") << "...";
655 printRationalApprox("nl-ext-cm-debug", bounds
[r
][1]);
656 Trace("nl-ext-cm-debug") << std::endl
;
657 diff
= Rewriter::rewrite(diff
);
658 Assert(diff
.isConst());
659 diff
= nm
->mkConst(diff
.getConst
<Rational
>().abs());
660 diff_bound
[r
] = diff
;
661 Trace("nl-ext-cm-debug") << "...diff from model value (";
662 printRationalApprox("nl-ext-cm-debug", m_var
);
663 Trace("nl-ext-cm-debug") << ") is ";
664 printRationalApprox("nl-ext-cm-debug", diff_bound
[r
]);
665 Trace("nl-ext-cm-debug") << std::endl
;
667 // take the one that var is closer to in the model
668 Node cmp
= nm
->mkNode(GEQ
, diff_bound
[0], diff_bound
[1]);
669 cmp
= Rewriter::rewrite(cmp
);
670 Assert(cmp
.isConst());
671 unsigned r_use_index
= cmp
== d_true
? 1 : 0;
672 Trace("nl-ext-cm") << "check-model-bound : approximate (sqrt) : ";
673 printRationalApprox("nl-ext-cm", bounds
[r_use_index
][0]);
674 Trace("nl-ext-cm") << " <= " << var
<< " <= ";
675 printRationalApprox("nl-ext-cm", bounds
[r_use_index
][1]);
676 Trace("nl-ext-cm") << std::endl
;
678 addCheckModelBound(var
, bounds
[r_use_index
][0], bounds
[r_use_index
][1]);
681 d_check_model_solved
[eq
] = var
;
682 Trace("nl-ext-cms") << "...success, solved quadratic." << std::endl
;
687 bool NlModel::simpleCheckModelLit(Node lit
)
689 Trace("nl-ext-cms") << "*** Simple check-model lit for " << lit
<< "..."
693 Trace("nl-ext-cms") << " return constant." << std::endl
;
694 return lit
.getConst
<bool>();
696 NodeManager
* nm
= NodeManager::currentNM();
697 bool pol
= lit
.getKind() != kind::NOT
;
698 Node atom
= lit
.getKind() == kind::NOT
? lit
[0] : lit
;
700 if (atom
.getKind() == EQUAL
)
702 // x = a is ( x >= a ^ x <= a )
703 for (unsigned i
= 0; i
< 2; i
++)
705 Node lit
= nm
->mkNode(GEQ
, atom
[i
], atom
[1 - i
]);
710 lit
= Rewriter::rewrite(lit
);
711 bool success
= simpleCheckModelLit(lit
);
714 // false != true -> one conjunct of equality is false, we fail
715 // true != false -> one disjunct of disequality is true, we succeed
719 // both checks passed and polarity is true, or both checks failed and
723 else if (atom
.getKind() != GEQ
)
725 Trace("nl-ext-cms") << " failed due to unknown literal." << std::endl
;
728 // get the monomial sum
729 std::map
<Node
, Node
> msum
;
730 if (!ArithMSum::getMonomialSumLit(atom
, msum
))
732 Trace("nl-ext-cms") << " failed due to get msum." << std::endl
;
735 // simple interval analysis
736 if (simpleCheckModelMsum(msum
, pol
))
740 // can also try reasoning about univariate quadratic equations
741 Trace("nl-ext-cms-debug")
742 << "* Try univariate quadratic analysis..." << std::endl
;
743 std::vector
<Node
> vs_invalid
;
744 std::unordered_set
<Node
, NodeHashFunction
> vs
;
745 std::map
<Node
, Node
> v_a
;
746 std::map
<Node
, Node
> v_b
;
747 // get coefficients...
748 for (std::pair
<const Node
, Node
>& m
: msum
)
755 v_b
[v
] = m
.second
.isNull() ? d_one
: m
.second
;
758 else if (v
.getKind() == NONLINEAR_MULT
&& v
.getNumChildren() == 2
759 && v
[0] == v
[1] && v
[0].isVar())
761 v_a
[v
[0]] = m
.second
.isNull() ? d_one
: m
.second
;
766 vs_invalid
.push_back(v
);
770 // solve the valid variables...
771 Node invalid_vsum
= vs_invalid
.empty() ? d_zero
772 : (vs_invalid
.size() == 1
774 : nm
->mkNode(PLUS
, vs_invalid
));
775 // substitution to try
776 std::vector
<Node
> qvars
;
777 std::vector
<Node
> qsubs
;
778 for (const Node
& v
: vs
)
780 // is it a valid variable?
781 std::map
<Node
, std::pair
<Node
, Node
> >::iterator bit
=
782 d_check_model_bounds
.find(v
);
783 if (!expr::hasSubterm(invalid_vsum
, v
) && bit
!= d_check_model_bounds
.end())
785 std::map
<Node
, Node
>::iterator it
= v_a
.find(v
);
790 int asgn
= a
.getConst
<Rational
>().sgn();
792 Node t
= nm
->mkNode(MULT
, a
, v
, v
);
798 t
= nm
->mkNode(PLUS
, t
, nm
->mkNode(MULT
, b
, v
));
800 t
= Rewriter::rewrite(t
);
801 Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic "
802 << t
<< "..." << std::endl
;
803 Trace("nl-ext-cms-debug") << " a = " << a
<< std::endl
;
804 Trace("nl-ext-cms-debug") << " b = " << b
<< std::endl
;
805 // find maximal/minimal value on the interval
806 Node apex
= nm
->mkNode(
807 DIVISION
, nm
->mkNode(UMINUS
, b
), nm
->mkNode(MULT
, d_two
, a
));
808 apex
= Rewriter::rewrite(apex
);
809 Assert(apex
.isConst());
810 // for lower, upper, whether we are greater than the apex
813 for (unsigned r
= 0; r
< 2; r
++)
815 boundn
[r
] = r
== 0 ? bit
->second
.first
: bit
->second
.second
;
816 Node cmpn
= nm
->mkNode(GT
, boundn
[r
], apex
);
817 cmpn
= Rewriter::rewrite(cmpn
);
818 Assert(cmpn
.isConst());
819 cmp
[r
] = cmpn
.getConst
<bool>();
821 Trace("nl-ext-cms-debug") << " apex " << apex
<< std::endl
;
822 Trace("nl-ext-cms-debug")
823 << " lower " << boundn
[0] << ", cmp: " << cmp
[0] << std::endl
;
824 Trace("nl-ext-cms-debug")
825 << " upper " << boundn
[1] << ", cmp: " << cmp
[1] << std::endl
;
826 Assert(boundn
[0].getConst
<Rational
>()
827 <= boundn
[1].getConst
<Rational
>());
830 if (cmp
[0] != cmp
[1])
832 Assert(!cmp
[0] && cmp
[1]);
833 // does the sign match the bound?
834 if ((asgn
== 1) == pol
)
836 // the apex is the max/min value
838 Trace("nl-ext-cms-debug") << " ...set to apex." << std::endl
;
842 // it is one of the endpoints, plug in and compare
844 for (unsigned r
= 0; r
< 2; r
++)
846 qsubs
.push_back(boundn
[r
]);
847 Node ts
= t
.substitute(
848 qvars
.begin(), qvars
.end(), qsubs
.begin(), qsubs
.end());
849 tcmpn
[r
] = Rewriter::rewrite(ts
);
852 Node tcmp
= nm
->mkNode(LT
, tcmpn
[0], tcmpn
[1]);
853 Trace("nl-ext-cms-debug")
854 << " ...both sides of apex, compare " << tcmp
<< std::endl
;
855 tcmp
= Rewriter::rewrite(tcmp
);
856 Assert(tcmp
.isConst());
857 unsigned bindex_use
= (tcmp
.getConst
<bool>() == pol
) ? 1 : 0;
858 Trace("nl-ext-cms-debug")
859 << " ...set to " << (bindex_use
== 1 ? "upper" : "lower")
861 s
= boundn
[bindex_use
];
866 // both to one side of the apex
867 // we figure out which bound to use (lower or upper) based on
869 // (1) whether a's sign is positive,
870 // (2) whether we are greater than the apex of the parabola,
871 // (3) the polarity of the constraint, i.e. >= or <=.
872 // there are 8 cases of these factors, which we test here.
873 unsigned bindex_use
= (((asgn
== 1) == cmp
[0]) == pol
) ? 0 : 1;
874 Trace("nl-ext-cms-debug")
875 << " ...set to " << (bindex_use
== 1 ? "upper" : "lower")
877 s
= boundn
[bindex_use
];
881 Trace("nl-ext-cms") << "* set bound based on quadratic : " << v
882 << " -> " << s
<< std::endl
;
888 Assert(qvars
.size() == qsubs
.size());
890 lit
.substitute(qvars
.begin(), qvars
.end(), qsubs
.begin(), qsubs
.end());
891 slit
= Rewriter::rewrite(slit
);
892 return simpleCheckModelLit(slit
);
897 bool NlModel::simpleCheckModelMsum(const std::map
<Node
, Node
>& msum
, bool pol
)
899 Trace("nl-ext-cms-debug") << "* Try simple interval analysis..." << std::endl
;
900 NodeManager
* nm
= NodeManager::currentNM();
901 // map from transcendental functions to whether they were set to lower
903 bool simpleSuccess
= true;
904 std::map
<Node
, bool> set_bound
;
905 std::vector
<Node
> sum_bound
;
906 for (const std::pair
<const Node
, Node
>& m
: msum
)
911 sum_bound
.push_back(m
.second
.isNull() ? d_one
: m
.second
);
915 Trace("nl-ext-cms-debug") << "- monomial : " << v
<< std::endl
;
916 // --- whether we should set a lower bound for this monomial
918 (m
.second
.isNull() || m
.second
.getConst
<Rational
>().sgn() == 1)
920 Trace("nl-ext-cms-debug")
921 << "set bound to " << (set_lower
? "lower" : "upper") << std::endl
;
923 // --- Collect variables and factors in v
924 std::vector
<Node
> vars
;
925 std::vector
<unsigned> factors
;
926 if (v
.getKind() == NONLINEAR_MULT
)
928 unsigned last_start
= 0;
929 for (unsigned i
= 0, nchildren
= v
.getNumChildren(); i
< nchildren
; i
++)
931 // are we at the end?
932 if (i
+ 1 == nchildren
|| v
[i
+ 1] != v
[i
])
934 unsigned vfact
= 1 + (i
- last_start
);
935 last_start
= (i
+ 1);
936 vars
.push_back(v
[i
]);
937 factors
.push_back(vfact
);
944 factors
.push_back(1);
947 // --- Get the lower and upper bounds and sign information.
948 // Whether we have an (odd) number of negative factors in vars, apart
949 // from the variable at choose_index.
950 bool has_neg_factor
= false;
951 int choose_index
= -1;
952 std::vector
<Node
> ls
;
953 std::vector
<Node
> us
;
954 // the relevant sign information for variables with odd exponents:
955 // 1: both signs of the interval of this variable are positive,
956 // -1: both signs of the interval of this variable are negative.
957 std::vector
<int> signs
;
958 Trace("nl-ext-cms-debug") << "get sign information..." << std::endl
;
959 for (unsigned i
= 0, size
= vars
.size(); i
< size
; i
++)
962 unsigned vcfact
= factors
[i
];
963 if (Trace
.isOn("nl-ext-cms-debug"))
965 Trace("nl-ext-cms-debug") << "-- " << vc
;
968 Trace("nl-ext-cms-debug") << "^" << vcfact
;
970 Trace("nl-ext-cms-debug") << " ";
972 std::map
<Node
, std::pair
<Node
, Node
> >::iterator bit
=
973 d_check_model_bounds
.find(vc
);
974 // if there is a model bound for this term
975 if (bit
!= d_check_model_bounds
.end())
977 Node l
= bit
->second
.first
;
978 Node u
= bit
->second
.second
;
985 int lsgn
= l
.getConst
<Rational
>().sgn();
986 int usgn
= u
.getConst
<Rational
>().sgn();
987 Trace("nl-ext-cms-debug")
988 << "bound_sign(" << lsgn
<< "," << usgn
<< ") ";
993 // must have a negative factor
994 has_neg_factor
= !has_neg_factor
;
997 else if (choose_index
== -1)
999 // set the choose index to this
1005 // ambiguous, can't determine the bound
1007 << " failed due to ambiguious monomial." << std::endl
;
1012 Trace("nl-ext-cms-debug") << " -> " << vsign
<< std::endl
;
1013 signs
.push_back(vsign
);
1017 Trace("nl-ext-cms-debug") << std::endl
;
1019 << " failed due to unknown bound for " << vc
<< std::endl
;
1020 // should either assign a model bound or eliminate the variable
1026 // whether we will try to minimize/maximize (-1/1) the absolute value
1027 int setAbs
= (set_lower
== has_neg_factor
) ? 1 : -1;
1028 Trace("nl-ext-cms-debug")
1029 << "set absolute value to " << (setAbs
== 1 ? "maximal" : "minimal")
1032 std::vector
<Node
> vbs
;
1033 Trace("nl-ext-cms-debug") << "set bounds..." << std::endl
;
1034 for (unsigned i
= 0, size
= vars
.size(); i
< size
; i
++)
1037 unsigned vcfact
= factors
[i
];
1041 int vcsign
= signs
[i
];
1042 Trace("nl-ext-cms-debug")
1043 << "Bounds for " << vc
<< " : " << l
<< ", " << u
1044 << ", sign : " << vcsign
<< ", factor : " << vcfact
<< std::endl
;
1047 // by convention, always say it is lower if they are the same
1048 vc_set_lower
= true;
1049 Trace("nl-ext-cms-debug")
1050 << "..." << vc
<< " equal bound, set to lower" << std::endl
;
1054 if (vcfact
% 2 == 0)
1056 // minimize or maximize its absolute value
1057 Rational la
= l
.getConst
<Rational
>().abs();
1058 Rational ua
= u
.getConst
<Rational
>().abs();
1061 // by convention, always say it is lower if abs are the same
1062 vc_set_lower
= true;
1063 Trace("nl-ext-cms-debug")
1064 << "..." << vc
<< " equal abs, set to lower" << std::endl
;
1068 vc_set_lower
= (la
> ua
) == (setAbs
== 1);
1071 else if (signs
[i
] == 0)
1073 // we choose this index to match the overall set_lower
1074 vc_set_lower
= set_lower
;
1078 vc_set_lower
= (signs
[i
] != setAbs
);
1080 Trace("nl-ext-cms-debug")
1081 << "..." << vc
<< " set to " << (vc_set_lower
? "lower" : "upper")
1084 // check whether this is a conflicting bound
1085 std::map
<Node
, bool>::iterator itsb
= set_bound
.find(vc
);
1086 if (itsb
== set_bound
.end())
1088 set_bound
[vc
] = vc_set_lower
;
1090 else if (itsb
->second
!= vc_set_lower
)
1093 << " failed due to conflicting bound for " << vc
<< std::endl
;
1096 // must over/under approximate based on vc_set_lower, computed above
1097 Node vb
= vc_set_lower
? l
: u
;
1098 for (unsigned i
= 0; i
< vcfact
; i
++)
1107 Node vbound
= vbs
.size() == 1 ? vbs
[0] : nm
->mkNode(MULT
, vbs
);
1108 sum_bound
.push_back(ArithMSum::mkCoeffTerm(m
.second
, vbound
));
1111 // if the exact bound was computed via simple analysis above
1114 if (sum_bound
.size() > 1)
1116 bound
= nm
->mkNode(kind::PLUS
, sum_bound
);
1118 else if (sum_bound
.size() == 1)
1120 bound
= sum_bound
[0];
1126 // make the comparison
1127 Node comp
= nm
->mkNode(kind::GEQ
, bound
, d_zero
);
1130 comp
= comp
.negate();
1132 Trace("nl-ext-cms") << " comparison is : " << comp
<< std::endl
;
1133 comp
= Rewriter::rewrite(comp
);
1134 Assert(comp
.isConst());
1135 Trace("nl-ext-cms") << " returned : " << comp
<< std::endl
;
1136 return comp
== d_true
;
1139 bool NlModel::isRefineableTfFun(Node tf
)
1141 Assert(tf
.getKind() == SINE
|| tf
.getKind() == EXPONENTIAL
);
1142 if (tf
.getKind() == SINE
)
1144 // we do not consider e.g. sin( -1*x ), since considering sin( x ) will
1145 // have the same effect. We also do not consider sin(x+y) since this is
1146 // handled by introducing a fresh variable (see the map d_tr_base in
1147 // NonlinearExtension).
1154 Node c
= computeAbstractModelValue(tf
[0]);
1155 Assert(c
.isConst());
1156 int csign
= c
.getConst
<Rational
>().sgn();
1164 bool NlModel::getApproximateSqrt(Node c
, Node
& l
, Node
& u
, unsigned iter
) const
1166 Assert(c
.isConst());
1167 if (c
== d_one
|| c
== d_zero
)
1173 Rational rc
= c
.getConst
<Rational
>();
1175 Rational rl
= rc
< Rational(1) ? rc
: Rational(1);
1176 Rational ru
= rc
< Rational(1) ? Rational(1) : rc
;
1178 Rational half
= Rational(1) / Rational(2);
1179 while (count
< iter
)
1181 Rational curr
= half
* (rl
+ ru
);
1182 Rational curr_sq
= curr
* curr
;
1189 else if (curr_sq
< rc
)
1200 NodeManager
* nm
= NodeManager::currentNM();
1201 l
= nm
->mkConst(rl
);
1202 u
= nm
->mkConst(ru
);
1206 void NlModel::printModelValue(const char* c
, Node n
, unsigned prec
) const
1210 Trace(c
) << " " << n
<< " -> ";
1211 for (int i
= 1; i
>= 0; --i
)
1213 std::map
<Node
, Node
>::const_iterator it
= d_mv
[i
].find(n
);
1214 Assert(it
!= d_mv
[i
].end());
1215 if (it
->second
.isConst())
1217 printRationalApprox(c
, it
->second
, prec
);
1223 Trace(c
) << (i
== 1 ? " [actual: " : " ]");
1225 Trace(c
) << std::endl
;
1229 void NlModel::getModelValueRepair(std::map
<Node
, Node
>& arithModel
,
1230 std::map
<Node
, Node
>& approximations
)
1232 // Record the approximations we used. This code calls the
1233 // recordApproximation method of the model, which overrides the model
1234 // values for variables that we solved for, using techniques specific to
1236 Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl
;
1237 NodeManager
* nm
= NodeManager::currentNM();
1238 for (const std::pair
<const Node
, std::pair
<Node
, Node
> >& cb
:
1239 d_check_model_bounds
)
1241 Node l
= cb
.second
.first
;
1242 Node u
= cb
.second
.second
;
1247 Node pred
= nm
->mkNode(AND
, nm
->mkNode(GEQ
, v
, l
), nm
->mkNode(GEQ
, u
, v
));
1248 approximations
[v
] = pred
;
1249 Trace("nl-model") << v
<< " approximated as " << pred
<< std::endl
;
1255 Trace("nl-model") << v
<< " exact approximation is " << l
<< std::endl
;
1258 // Also record the exact values we used. An exact value can be seen as a
1259 // special kind approximation of the form (choice x. x = exact_value).
1260 // Notice that the above term gets rewritten such that the choice function
1262 for (size_t i
= 0, num
= d_check_model_vars
.size(); i
< num
; i
++)
1264 Node v
= d_check_model_vars
[i
];
1265 Node s
= d_check_model_subs
[i
];
1268 Trace("nl-model") << v
<< " solved is " << s
<< std::endl
;
1272 } // namespace arith
1273 } // namespace theory