From d02e1cb3eb74380495aa3ff9e57fd04e4411aa55 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Thu, 24 May 2018 14:01:44 -0500 Subject: [PATCH] Fixes for non-linear check model (#1974) --- src/theory/arith/nonlinear_extension.cpp | 79 +++++++++++++++++++----- 1 file changed, 64 insertions(+), 15 deletions(-) diff --git a/src/theory/arith/nonlinear_extension.cpp b/src/theory/arith/nonlinear_extension.cpp index c9a4c5075..be8f22222 100644 --- a/src/theory/arith/nonlinear_extension.cpp +++ b/src/theory/arith/nonlinear_extension.cpp @@ -1056,6 +1056,9 @@ void NonlinearExtension::addCheckModelSubstitution(TNode v, TNode s) void NonlinearExtension::addCheckModelBound(TNode v, TNode l, TNode u) { Assert(!hasCheckModelAssignment(v)); + Assert(l.isConst()); + Assert(u.isConst()); + Assert(l.getConst() <= u.getConst()); d_check_model_bounds[v] = std::pair(l, u); } @@ -1294,6 +1297,14 @@ bool NonlinearExtension::solveEqualitySimple(Node eq) MULT, coeffa, nm->mkNode(r == 0 ? MINUS : PLUS, negb, val)); approx = Rewriter::rewrite(approx); bounds[r][b] = approx; + Assert(approx.isConst()); + } + if (bounds[r][0].getConst() > bounds[r][1].getConst()) + { + // ensure bound is (lower, upper) + Node tmp = bounds[r][0]; + bounds[r][0] = bounds[r][1]; + bounds[r][1] = tmp; } Node diff = nm->mkNode(MINUS, @@ -1448,26 +1459,31 @@ bool NonlinearExtension::simpleCheckModelLit(Node lit) t = Rewriter::rewrite(t); Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic " << t << "..." << std::endl; + Trace("nl-ext-cms-debug") << " a = " << a << std::endl; + Trace("nl-ext-cms-debug") << " b = " << b << std::endl; // find maximal/minimal value on the interval Node apex = nm->mkNode( DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a)); apex = Rewriter::rewrite(apex); Assert(apex.isConst()); + // for lower, upper, whether we are greater than the apex bool cmp[2]; Node boundn[2]; for (unsigned r = 0; r < 2; r++) { boundn[r] = r == 0 ? bit->second.first : bit->second.second; - Node cmpn = nm->mkNode(LT, boundn[r], apex); + Node cmpn = nm->mkNode(GT, boundn[r], apex); cmpn = Rewriter::rewrite(cmpn); Assert(cmpn.isConst()); cmp[r] = cmpn.getConst(); } Trace("nl-ext-cms-debug") << " apex " << apex << std::endl; Trace("nl-ext-cms-debug") - << " min " << boundn[0] << ", cmp: " << cmp[0] << std::endl; + << " lower " << boundn[0] << ", cmp: " << cmp[0] << std::endl; Trace("nl-ext-cms-debug") - << " max " << boundn[1] << ", cmp: " << cmp[1] << std::endl; + << " upper " << boundn[1] << ", cmp: " << cmp[1] << std::endl; + Assert(boundn[0].getConst() + <= boundn[1].getConst()); Node s; qvars.push_back(v); if (cmp[0] != cmp[1]) @@ -1497,19 +1513,25 @@ bool NonlinearExtension::simpleCheckModelLit(Node lit) << " ...both sides of apex, compare " << tcmp << std::endl; tcmp = Rewriter::rewrite(tcmp); Assert(tcmp.isConst()); - unsigned bindex_use = tcmp.getConst() == pol ? 1 : 0; + unsigned bindex_use = (tcmp.getConst() == pol) ? 1 : 0; Trace("nl-ext-cms-debug") - << " ...set to " << (bindex_use == 1 ? "max" : "min") + << " ...set to " << (bindex_use == 1 ? "upper" : "lower") << std::endl; s = boundn[bindex_use]; } } else { - // both to one side - unsigned bindex_use = ((asgn == 1) == cmp[0]) == pol ? 0 : 1; + // both to one side of the apex + // we figure out which bound to use (lower or upper) based on + // three factors: + // (1) whether a's sign is positive, + // (2) whether we are greater than the apex of the parabola, + // (3) the polarity of the constraint, i.e. >= or <=. + // there are 8 cases of these factors, which we test here. + unsigned bindex_use = (((asgn == 1) == cmp[0]) == pol) ? 0 : 1; Trace("nl-ext-cms-debug") - << " ...set to " << (bindex_use == 1 ? "max" : "min") + << " ...set to " << (bindex_use == 1 ? "upper" : "lower") << std::endl; s = boundn[bindex_use]; } @@ -1589,6 +1611,9 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map& msum, int choose_index = -1; std::vector ls; std::vector us; + // the relevant sign information for variables with odd exponents: + // 1: both signs of the interval of this variable are positive, + // -1: both signs of the interval of this variable are negative. std::vector signs; Trace("nl-ext-cms-debug") << "get sign information..." << std::endl; for (unsigned i = 0, size = vars.size(); i < size; i++) @@ -1613,9 +1638,10 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map& msum, Node u = bit->second.second; ls.push_back(l); us.push_back(u); - int vsign = 1; + int vsign = 0; if (vcfact % 2 == 1) { + vsign = 1; int lsgn = l.getConst().sgn(); int usgn = u.getConst().sgn(); Trace("nl-ext-cms-debug") @@ -1658,7 +1684,10 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map& msum, } } // whether we will try to minimize/maximize (-1/1) the absolute value - int minimizeAbs = set_lower == has_neg_factor ? -1 : 1; + int setAbs = (set_lower == has_neg_factor) ? 1 : -1; + Trace("nl-ext-cms-debug") + << "set absolute value to " << (setAbs == 1 ? "maximal" : "minimal") + << std::endl; std::vector vbs; Trace("nl-ext-cms-debug") << "set bounds..." << std::endl; @@ -1669,6 +1698,10 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map& msum, Node l = ls[i]; Node u = us[i]; bool vc_set_lower; + int vcsign = signs[i]; + Trace("nl-ext-cms-debug") + << "Bounds for " << vc << " : " << l << ", " << u + << ", sign : " << vcsign << ", factor : " << vcfact << std::endl; if (l == u) { // by convention, always say it is lower if they are the same @@ -1678,15 +1711,31 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map& msum, } else { - if (signs[i] == 0) + if (vcfact % 2 == 0) + { + // minimize or maximize its absolute value + Rational la = l.getConst().abs(); + Rational ua = u.getConst().abs(); + if (la == ua) + { + // by convention, always say it is lower if abs are the same + vc_set_lower = true; + Trace("nl-ext-cms-debug") + << "..." << vc << " equal abs, set to lower" << std::endl; + } + else + { + vc_set_lower = (la > ua) == (setAbs == 1); + } + } + else if (signs[i] == 0) { // we choose this index to match the overall set_lower vc_set_lower = set_lower; } else { - // minimize or maximize its absolute value - vc_set_lower = (signs[i] == minimizeAbs); + vc_set_lower = (signs[i] == setAbs); } Trace("nl-ext-cms-debug") << "..." << vc << " set to " << (vc_set_lower ? "lower" : "upper") @@ -1704,8 +1753,8 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map& msum, << " failed due to conflicting bound for " << vc << std::endl; return false; } - // must over/under approximate - Node vb = set_lower ? l : u; + // must over/under approximate based on vc_set_lower, computed above + Node vb = vc_set_lower ? l : u; for (unsigned i = 0; i < vcfact; i++) { vbs.push_back(vb); -- 2.30.2