Fixes for non-linear check model (#1974)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 24 May 2018 19:01:44 +0000 (14:01 -0500)
committerGitHub <noreply@github.com>
Thu, 24 May 2018 19:01:44 +0000 (14:01 -0500)
src/theory/arith/nonlinear_extension.cpp

index c9a4c5075f84b61d6433521cb668e27b86757b57..be8f22222598385dc54cc7c1f804fa4e1d6c3499 100644 (file)
@@ -1056,6 +1056,9 @@ void NonlinearExtension::addCheckModelSubstitution(TNode v, TNode s)
 void NonlinearExtension::addCheckModelBound(TNode v, TNode l, TNode u)
 {
   Assert(!hasCheckModelAssignment(v));
+  Assert(l.isConst());
+  Assert(u.isConst());
+  Assert(l.getConst<Rational>() <= u.getConst<Rational>());
   d_check_model_bounds[v] = std::pair<Node, Node>(l, u);
 }
 
@@ -1294,6 +1297,14 @@ bool NonlinearExtension::solveEqualitySimple(Node eq)
           MULT, coeffa, nm->mkNode(r == 0 ? MINUS : PLUS, negb, val));
       approx = Rewriter::rewrite(approx);
       bounds[r][b] = approx;
+      Assert(approx.isConst());
+    }
+    if (bounds[r][0].getConst<Rational>() > bounds[r][1].getConst<Rational>())
+    {
+      // ensure bound is (lower, upper)
+      Node tmp = bounds[r][0];
+      bounds[r][0] = bounds[r][1];
+      bounds[r][1] = tmp;
     }
     Node diff =
         nm->mkNode(MINUS,
@@ -1448,26 +1459,31 @@ bool NonlinearExtension::simpleCheckModelLit(Node lit)
         t = Rewriter::rewrite(t);
         Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic "
                                   << t << "..." << std::endl;
+        Trace("nl-ext-cms-debug") << "    a = " << a << std::endl;
+        Trace("nl-ext-cms-debug") << "    b = " << b << std::endl;
         // find maximal/minimal value on the interval
         Node apex = nm->mkNode(
             DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a));
         apex = Rewriter::rewrite(apex);
         Assert(apex.isConst());
+        // for lower, upper, whether we are greater than the apex
         bool cmp[2];
         Node boundn[2];
         for (unsigned r = 0; r < 2; r++)
         {
           boundn[r] = r == 0 ? bit->second.first : bit->second.second;
-          Node cmpn = nm->mkNode(LT, boundn[r], apex);
+          Node cmpn = nm->mkNode(GT, boundn[r], apex);
           cmpn = Rewriter::rewrite(cmpn);
           Assert(cmpn.isConst());
           cmp[r] = cmpn.getConst<bool>();
         }
         Trace("nl-ext-cms-debug") << "  apex " << apex << std::endl;
         Trace("nl-ext-cms-debug")
-            << "  min " << boundn[0] << ", cmp: " << cmp[0] << std::endl;
+            << "  lower " << boundn[0] << ", cmp: " << cmp[0] << std::endl;
         Trace("nl-ext-cms-debug")
-            << "  max " << boundn[1] << ", cmp: " << cmp[1] << std::endl;
+            << "  upper " << boundn[1] << ", cmp: " << cmp[1] << std::endl;
+        Assert(boundn[0].getConst<Rational>()
+               <= boundn[1].getConst<Rational>());
         Node s;
         qvars.push_back(v);
         if (cmp[0] != cmp[1])
@@ -1497,19 +1513,25 @@ bool NonlinearExtension::simpleCheckModelLit(Node lit)
                 << "  ...both sides of apex, compare " << tcmp << std::endl;
             tcmp = Rewriter::rewrite(tcmp);
             Assert(tcmp.isConst());
-            unsigned bindex_use = tcmp.getConst<bool>() == pol ? 1 : 0;
+            unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 1 : 0;
             Trace("nl-ext-cms-debug")
-                << "  ...set to " << (bindex_use == 1 ? "max" : "min")
+                << "  ...set to " << (bindex_use == 1 ? "upper" : "lower")
                 << std::endl;
             s = boundn[bindex_use];
           }
         }
         else
         {
-          // both to one side
-          unsigned bindex_use = ((asgn == 1) == cmp[0]) == pol ? 0 : 1;
+          // both to one side of the apex
+          // we figure out which bound to use (lower or upper) based on
+          // three factors:
+          // (1) whether a's sign is positive,
+          // (2) whether we are greater than the apex of the parabola,
+          // (3) the polarity of the constraint, i.e. >= or <=.
+          // there are 8 cases of these factors, which we test here.
+          unsigned bindex_use = (((asgn == 1) == cmp[0]) == pol) ? 0 : 1;
           Trace("nl-ext-cms-debug")
-              << "  ...set to " << (bindex_use == 1 ? "max" : "min")
+              << "  ...set to " << (bindex_use == 1 ? "upper" : "lower")
               << std::endl;
           s = boundn[bindex_use];
         }
@@ -1589,6 +1611,9 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
       int choose_index = -1;
       std::vector<Node> ls;
       std::vector<Node> us;
+      // the relevant sign information for variables with odd exponents:
+      //   1: both signs of the interval of this variable are positive,
+      //  -1: both signs of the interval of this variable are negative.
       std::vector<int> signs;
       Trace("nl-ext-cms-debug") << "get sign information..." << std::endl;
       for (unsigned i = 0, size = vars.size(); i < size; i++)
@@ -1613,9 +1638,10 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
           Node u = bit->second.second;
           ls.push_back(l);
           us.push_back(u);
-          int vsign = 1;
+          int vsign = 0;
           if (vcfact % 2 == 1)
           {
+            vsign = 1;
             int lsgn = l.getConst<Rational>().sgn();
             int usgn = u.getConst<Rational>().sgn();
             Trace("nl-ext-cms-debug")
@@ -1658,7 +1684,10 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
         }
       }
       // whether we will try to minimize/maximize (-1/1) the absolute value
-      int minimizeAbs = set_lower == has_neg_factor ? -1 : 1;
+      int setAbs = (set_lower == has_neg_factor) ? 1 : -1;
+      Trace("nl-ext-cms-debug")
+          << "set absolute value to " << (setAbs == 1 ? "maximal" : "minimal")
+          << std::endl;
 
       std::vector<Node> vbs;
       Trace("nl-ext-cms-debug") << "set bounds..." << std::endl;
@@ -1669,6 +1698,10 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
         Node l = ls[i];
         Node u = us[i];
         bool vc_set_lower;
+        int vcsign = signs[i];
+        Trace("nl-ext-cms-debug")
+            << "Bounds for " << vc << " : " << l << ", " << u
+            << ", sign : " << vcsign << ", factor : " << vcfact << std::endl;
         if (l == u)
         {
           // by convention, always say it is lower if they are the same
@@ -1678,15 +1711,31 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
         }
         else
         {
-          if (signs[i] == 0)
+          if (vcfact % 2 == 0)
+          {
+            // minimize or maximize its absolute value
+            Rational la = l.getConst<Rational>().abs();
+            Rational ua = u.getConst<Rational>().abs();
+            if (la == ua)
+            {
+              // by convention, always say it is lower if abs are the same
+              vc_set_lower = true;
+              Trace("nl-ext-cms-debug")
+                  << "..." << vc << " equal abs, set to lower" << std::endl;
+            }
+            else
+            {
+              vc_set_lower = (la > ua) == (setAbs == 1);
+            }
+          }
+          else if (signs[i] == 0)
           {
             // we choose this index to match the overall set_lower
             vc_set_lower = set_lower;
           }
           else
           {
-            // minimize or maximize its absolute value
-            vc_set_lower = (signs[i] == minimizeAbs);
+            vc_set_lower = (signs[i] == setAbs);
           }
           Trace("nl-ext-cms-debug")
               << "..." << vc << " set to " << (vc_set_lower ? "lower" : "upper")
@@ -1704,8 +1753,8 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
               << "  failed due to conflicting bound for " << vc << std::endl;
           return false;
         }
-        // must over/under approximate
-        Node vb = set_lower ? l : u;
+        // must over/under approximate based on vc_set_lower, computed above
+        Node vb = vc_set_lower ? l : u;
         for (unsigned i = 0; i < vcfact; i++)
         {
           vbs.push_back(vb);