Improvements to simple transcendental function check model. (#1823)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 30 Apr 2018 01:45:37 +0000 (20:45 -0500)
committerGitHub <noreply@github.com>
Mon, 30 Apr 2018 01:45:37 +0000 (20:45 -0500)
src/theory/arith/nonlinear_extension.cpp
src/theory/arith/nonlinear_extension.h

index e81dcc8811f3de8e0251d2b275df86e30424db95..01e6e2ff6a93ad99869d397e7e7499f82f257856 100644 (file)
@@ -885,8 +885,11 @@ bool NonlinearExtension::checkModelTf(const std::vector<Node>& assertions)
 
 bool NonlinearExtension::simpleCheckModelTfLit(Node lit)
 {
-  Trace("nl-ext-tf-check-model-simple") << "simple check-model for " << lit
-                                        << "..." << std::endl;
+  Trace("nl-ext-cms") << "simple check-model for " << lit << "..." << std::endl;
+  if (lit.isConst() && lit.getConst<bool>())
+  {
+    return true;
+  }
   NodeManager* nm = NodeManager::currentNM();
   bool pol = lit.getKind() != kind::NOT;
   Node atom = lit.getKind() == kind::NOT ? lit[0] : lit;
@@ -909,31 +912,165 @@ bool NonlinearExtension::simpleCheckModelTfLit(Node lit)
         }
         else
         {
-          std::map<Node, std::pair<Node, Node> >::iterator bit =
-              d_tf_check_model_bounds.find(v);
-          if (bit != d_tf_check_model_bounds.end())
+          Trace("nl-ext-cms-debug") << "--- monomial : " << v << std::endl;
+          // --- whether we should set a lower bound for this monomial
+          bool set_lower =
+              (m.second.isNull() || m.second.getConst<Rational>().sgn() == 1)
+              == pol;
+          Trace("nl-ext-cms-debug")
+              << "set bound to " << (set_lower ? "lower" : "upper")
+              << std::endl;
+
+          // --- Collect variables and factors in v
+          std::vector<Node> vars;
+          std::vector<unsigned> factors;
+          if (v.getKind() == NONLINEAR_MULT)
           {
-            bool set_lower =
-                (m.second.isNull() || m.second.getConst<Rational>().sgn() == 1)
-                == pol;
-            std::map<Node, bool>::iterator itsb = set_bound.find(v);
-            if (itsb != set_bound.end() && itsb->second != set_lower)
+            unsigned last_start = 0;
+            for (unsigned i = 0, nchildren = v.getNumChildren(); i < nchildren;
+                 i++)
             {
-              Trace("nl-ext-tf-check-model-simple")
-                  << "  failed due to conflicting bound for " << v << std::endl;
-              return false;
+              // are we at the end?
+              if (i + 1 == nchildren || v[i + 1] != v[i])
+              {
+                unsigned vfact = 1 + (i - last_start);
+                last_start = (i + 1);
+                vars.push_back(v[i]);
+                factors.push_back(vfact);
+              }
             }
-            set_bound[v] = set_lower;
-            // must over/under approximate
-            Node vbound = set_lower ? bit->second.first : bit->second.second;
-            sum_bound.push_back(ArithMSum::mkCoeffTerm(m.second, vbound));
           }
           else
           {
-            Trace("nl-ext-tf-check-model-simple")
-                << "  failed due to unknown bound for " << v << std::endl;
-            return false;
+            vars.push_back(v);
+            factors.push_back(1);
+          }
+
+          // --- Get the lower and upper bounds and sign information.
+          // Whether we have an (odd) number of negative factors in vars, apart
+          // from the variable at choose_index.
+          bool has_neg_factor = false;
+          int choose_index = -1;
+          std::vector<Node> ls;
+          std::vector<Node> us;
+          std::vector<int> signs;
+          Trace("nl-ext-cms-debug") << "get sign information..." << std::endl;
+          for (unsigned i = 0, size = vars.size(); i < size; i++)
+          {
+            Node vc = vars[i];
+            unsigned vcfact = factors[i];
+            if (Trace.isOn("nl-ext-cms-debug"))
+            {
+              Trace("nl-ext-cms-debug") << "* " << vc;
+              if (vcfact > 1)
+              {
+                Trace("nl-ext-cms-debug") << "^" << vcfact;
+              }
+              Trace("nl-ext-cms-debug") << " ";
+            }
+            std::map<Node, std::pair<Node, Node> >::iterator bit =
+                d_tf_check_model_bounds.find(vc);
+            if (bit != d_tf_check_model_bounds.end())
+            {
+              Node l = bit->second.first;
+              Node u = bit->second.second;
+              ls.push_back(l);
+              us.push_back(u);
+              int vsign = 1;
+              if (vcfact % 2 == 1)
+              {
+                int lsgn = l.getConst<Rational>().sgn();
+                int usgn = u.getConst<Rational>().sgn();
+                Trace("nl-ext-cms-debug")
+                    << "bound_sign(" << lsgn << "," << usgn << ") ";
+                if (lsgn == -1)
+                {
+                  if (usgn < 1)
+                  {
+                    // must have a negative factor
+                    has_neg_factor = !has_neg_factor;
+                    vsign = -1;
+                  }
+                  else if (choose_index == -1)
+                  {
+                    // set the choose index to this
+                    choose_index = i;
+                    vsign = 0;
+                  }
+                  else
+                  {
+                    // ambiguous, can't determine the bound
+                    return false;
+                  }
+                }
+              }
+              Trace("nl-ext-cms-debug") << " -> " << vsign << std::endl;
+              signs.push_back(vsign);
+            }
+            else
+            {
+              Trace("nl-ext-cms-debug") << std::endl;
+              Trace("nl-ext-cms")
+                  << "  failed due to unknown bound for " << vc << std::endl;
+              return false;
+            }
+          }
+          // whether we will try to minimize/maximize (-1/1) the absolute value
+          int minimizeAbs = set_lower == has_neg_factor ? -1 : 1;
+
+          std::vector<Node> vbs;
+          Trace("nl-ext-cms-debug") << "set bounds..." << std::endl;
+          for (unsigned i = 0, size = vars.size(); i < size; i++)
+          {
+            Node vc = vars[i];
+            unsigned vcfact = factors[i];
+            Node l = ls[i];
+            Node u = us[i];
+            bool vc_set_lower;
+            if (l == u)
+            {
+              // by convention, always say it is lower if they are the same
+              vc_set_lower = true;
+              Trace("nl-ext-cms-debug")
+                  << "..." << vc << " equal bound, set to lower" << std::endl;
+            }
+            else
+            {
+              if (signs[i] == 0)
+              {
+                // we choose this index to match the overall set_lower
+                vc_set_lower = set_lower;
+              }
+              else
+              {
+                // minimize or maximize its absolute value
+                vc_set_lower = (signs[i] == minimizeAbs);
+              }
+              Trace("nl-ext-cms-debug")
+                  << "..." << vc << " set to "
+                  << (vc_set_lower ? "lower" : "upper") << std::endl;
+            }
+            // check whether this is a conflicting bound
+            std::map<Node, bool>::iterator itsb = set_bound.find(vc);
+            if (itsb == set_bound.end())
+            {
+              set_bound[vc] = vc_set_lower;
+            }
+            else if (itsb->second != vc_set_lower)
+            {
+              Trace("nl-ext-cms") << "  failed due to conflicting bound for "
+                                  << vc << std::endl;
+              return false;
+            }
+            // must over/under approximate
+            Node vb = set_lower ? l : u;
+            for (unsigned i = 0; i < vcfact; i++)
+            {
+              vbs.push_back(vb);
+            }
           }
+          Node vbound = vbs.size() == 1 ? vbs[0] : nm->mkNode(MULT, vbs);
+          sum_bound.push_back(ArithMSum::mkCoeffTerm(m.second, vbound));
         }
       }
       Node bound;
@@ -954,12 +1091,10 @@ bool NonlinearExtension::simpleCheckModelTfLit(Node lit)
       {
         comp = comp.negate();
       }
-      Trace("nl-ext-tf-check-model-simple") << "  comparison is : " << comp
-                                            << std::endl;
+      Trace("nl-ext-cms") << "  comparison is : " << comp << std::endl;
       comp = Rewriter::rewrite(comp);
       Assert(comp.isConst());
-      Trace("nl-ext-tf-check-model-simple") << "  returned : " << comp
-                                            << std::endl;
+      Trace("nl-ext-cms") << "  returned : " << comp << std::endl;
       return comp == d_true;
     }
   }
@@ -982,10 +1117,12 @@ bool NonlinearExtension::simpleCheckModelTfLit(Node lit)
         return success;
       }
     }
+    // both checks passed and polarity is true, or both checks failed and
+    // polarity is false
+    return pol;
   }
 
-  Trace("nl-ext-tf-check-model-simple") << "  failed due to unknown literal."
-                                        << std::endl;
+  Trace("nl-ext-cms") << "  failed due to unknown literal." << std::endl;
   return false;
 }
 
index c7e6b2b2af6147f79a3bb04b5712e6ee64600b46..6985f69ddc7f18f41fb709a105cee49f4696aae8 100644 (file)
@@ -281,6 +281,7 @@ class NonlinearExtension {
    *   2.0*sin( 1 ) > 1.5
    *   -1.0*sin( 1 ) < -0.79
    *   -1.0*sin( 1 ) > -0.91
+   *   sin( 1 )*sin( 1 ) + sin( 1 ) > 0.0
    * It will return false for literals like:
    *   sin( 1 ) > 0.85
    * It will also return false for literals like: