Fix more misuses of arithmetic subtypes (#8601)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 12 Apr 2022 13:58:18 +0000 (08:58 -0500)
committerGitHub <noreply@github.com>
Tue, 12 Apr 2022 13:58:18 +0000 (13:58 +0000)
src/theory/arith/arith_utilities.cpp
src/theory/arith/arith_utilities.h
src/theory/arith/nl/ext/proof_checker.cpp
src/theory/arith/nl/ext/split_zero_check.cpp
src/theory/arith/nl/ext/tangent_plane_check.cpp
src/theory/arith/nl/nl_model.cpp
src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp
src/theory/quantifiers/cegqi/ceg_instantiator.cpp

index d6bec5b9aee61ca2b6c6d5a8516913874561f083..f20964a57eb6fbbd7bf5dc43823c09b9a102e23d 100644 (file)
@@ -110,6 +110,11 @@ bool isZero(const Node& n)
   return n.isConst() && n.getConst<Rational>().sgn() == 0;
 }
 
+Node mkOne(const TypeNode& tn, bool isNeg)
+{
+  return NodeManager::currentNM()->mkConstRealOrInt(tn, isNeg ? -1 : 1);
+}
+
 bool isTranscendentalKind(Kind k)
 {
   // many operators are eliminated during rewriting
index c1412257c2351827e5a4402d39964bb91672c876..0f85468a28b4e1dfc73cd43dd3c457a3ae053a08 100644 (file)
@@ -228,6 +228,9 @@ Node mkZero(const TypeNode& tn);
 /** Is n (integer or real) zero? */
 bool isZero(const Node& n);
 
+/** Make one of the given type, maybe negated */
+Node mkOne(const TypeNode& tn, bool isNeg = false);
+
 // Returns an node that is the identity of a select few kinds.
 inline Node getIdentityType(const TypeNode& tn, Kind k)
 {
index 51ccd5576ce4251de7cf23c0b6d1c0fd8c8f38a9..cae590492813a0f2b5180f05b07ecf3ef5f50e76 100644 (file)
@@ -37,11 +37,6 @@ Node ExtProofRuleChecker::checkInternal(PfRule id,
                                         const std::vector<Node>& args)
 {
   NodeManager* nm = NodeManager::currentNM();
-  auto zero = nm->mkConst<Rational>(CONST_RATIONAL, 0);
-  auto one = nm->mkConst<Rational>(CONST_RATIONAL, 1);
-  auto mone = nm->mkConst<Rational>(CONST_RATIONAL, -1);
-  auto pi = nm->mkNullaryOperator(nm->realType(), Kind::PI);
-  auto mpi = nm->mkNode(Kind::MULT, mone, pi);
   Trace("nl-ext-checker") << "Checking " << id << std::endl;
   for (const auto& c : children)
   {
@@ -103,6 +98,7 @@ Node ExtProofRuleChecker::checkInternal(PfRule id,
         }
       }
     }
+    Node zero = nm->mkConstRealOrInt(mon.getType(), Rational(0));
     switch (sign)
     {
       case -1:
index 5c97e1a18a4978052c6636357a769be99ef0defb..869b0e0d65db57c6d4fdd91db7e54eadd02046cb 100644 (file)
@@ -17,7 +17,7 @@
 
 #include "expr/node.h"
 #include "proof/proof.h"
-#include "theory/arith/arith_msum.h"
+#include "theory/arith/arith_utilities.h"
 #include "theory/arith/inference_manager.h"
 #include "theory/arith/nl/ext/ext_state.h"
 #include "theory/arith/nl/nl_model.h"
@@ -29,7 +29,7 @@ namespace arith {
 namespace nl {
 
 SplitZeroCheck::SplitZeroCheck(Env& env, ExtState* data)
-    : EnvObj(env), d_data(data), d_zero_split(d_data->d_env.getUserContext())
+    : EnvObj(env), d_data(data), d_zero_split(userContext())
 {
 }
 
@@ -40,7 +40,7 @@ void SplitZeroCheck::check()
     Node v = d_data->d_ms_vars[i];
     if (d_zero_split.insert(v))
     {
-      Node eq = rewrite(v.eqNode(d_data->d_zero));
+      Node eq = rewrite(v.eqNode(mkZero(v.getType())));
       Node lem = eq.orNode(eq.negate());
       CDProof* proof = nullptr;
       if (d_data->isProofEnabled())
index 8f23cb957075d10de977738162ca4ac5fe38d5f4..add09b67d78e05675b8a655d069e30b3712e07f7 100644 (file)
@@ -18,6 +18,7 @@
 #include "expr/node.h"
 #include "proof/proof.h"
 #include "theory/arith/arith_msum.h"
+#include "theory/arith/arith_utilities.h"
 #include "theory/arith/inference_manager.h"
 #include "theory/arith/nl/ext/ext_state.h"
 #include "theory/arith/nl/nl_model.h"
@@ -63,7 +64,8 @@ void TangentPlaneCheck::check(bool asWaitingLemmas)
     for (unsigned j = 0; j < it->second.size(); j++)
     {
       Node tc = it->second[j];
-      if (tc != d_data->d_one)
+      Node one = mkOne(tc.getType());
+      if (tc != one)
       {
         Node tc_diff = d_data->d_mdb.getContainsDiffNl(tc, t);
         Assert(!tc_diff.isNull());
index e13294f10faf1af4dc8b10a8b141d9f47e425b0c..4c97801f0f16447f0008cbc7017cf60a391a23c5 100644 (file)
@@ -519,8 +519,7 @@ bool NlModel::solveEqualitySimple(Node eq,
     Assert(false);
     return false;
   }
-  Node val = nm->mkConst(CONST_RATIONAL,
-                         -c.getConst<Rational>() / b.getConst<Rational>());
+  Node val = nm->mkConstReal(-c.getConst<Rational>() / b.getConst<Rational>());
   if (TraceIsOn("nl-ext-cm"))
   {
     Trace("nl-ext-cm") << "check-model-bound : exact : " << var << " = ";
@@ -1086,8 +1085,9 @@ Node NlModel::getValueInternal(TNode n)
   // to mapping from the linear solver. This ensures that if the nonlinear
   // solver assumes that n = 0, then this assumption is recorded in the overall
   // model.
-  d_arithVal[n] = d_zero;
-  return d_zero;
+  Node zero = mkZero(n.getType());
+  d_arithVal[n] = zero;
+  return zero;
 }
 
 bool NlModel::hasAssignment(Node v) const
index d881fda5d938683341680f92462b523035b83a59..33b5f08f2fca8f0e88b2b8546baabb23a9e66413 100644 (file)
@@ -421,7 +421,7 @@ bool ArithInstantiator::processAssertions(CegInstantiator* ci,
                 MULT,
                 nm->mkConstReal(Rational(1)
                                 / d_mbp_coeff[rr][j].getConst<Rational>()),
-                value[t]);
+                nm->mkNode(TO_REAL, value[t]));
             value[t] = rewrite(value[t]);
           }
           // check if new best, if we have not already set it.
@@ -430,15 +430,20 @@ bool ArithInstantiator::processAssertions(CegInstantiator* ci,
             Assert(!value[t].isNull() && !best_bound_value[t].isNull());
             if (value[t] != best_bound_value[t])
             {
-              Kind k = rr == 0 ? GEQ : LEQ;
-              Node cmp_bound = nm->mkNode(k, value[t], best_bound_value[t]);
-              cmp_bound = rewrite(cmp_bound);
               // Should be comparing two constant values which should rewrite
               // to a constant. If a step failed, we assume that this is not
               // the new best bound. We might not be comparing constant
               // values (for instance if transcendental functions are
-              // involved), in which case we do update the best bound value.
-              if (!cmp_bound.isConst() || !cmp_bound.getConst<bool>())
+              // involved), in which case we do not update the best bound value.
+              if (!value[t].isConst() || !best_bound_value[t].isConst())
+              {
+                new_best = false;
+                break;
+              }
+              Rational rt = value[t].getConst<Rational>();
+              Rational brt = best_bound_value[t].getConst<Rational>();
+              bool cmp = rr == 0 ? rt >= brt : rt <= brt;
+              if (!cmp)
               {
                 new_best = false;
                 break;
index ac0aca8935c8d5c4c179de131efc50f605f5b399..f09faf22fd93d18431bceb8e51a65733a66b6ce8 100644 (file)
@@ -1146,6 +1146,7 @@ bool CegInstantiator::canApplyBasicSubstitution( Node n, std::vector< Node >& no
 
 Node CegInstantiator::applySubstitution( TypeNode tn, Node n, std::vector< Node >& vars, std::vector< Node >& subs, std::vector< TermProperties >& prop, 
                                          std::vector< Node >& non_basic, TermProperties& pv_prop, bool try_coeff ) {
+  NodeManager* nm = NodeManager::currentNM();
   n = rewrite(n);
   computeProgVars( n );
   bool is_basic = canApplyBasicSubstitution( n, non_basic );
@@ -1216,31 +1217,31 @@ Node CegInstantiator::applySubstitution( TypeNode tn, Node n, std::vector< Node
           pv_prop.d_coeff = rewrite(pv_prop.d_coeff);
           Trace("sygus-si-apply-subs-debug") << "Combined coeff : " << pv_prop.d_coeff << std::endl;
           std::vector< Node > children;
+          TypeNode type = n.getType();
           for( std::map< Node, Node >::iterator it = msum.begin(); it != msum.end(); ++it ){
-            Node c_coeff;
-            if( !msum_coeff[it->first].isNull() ){
-              c_coeff = rewrite(NodeManager::currentNM()->mkConstReal(
-                  pv_prop.d_coeff.getConst<Rational>()
-                  / msum_coeff[it->first].getConst<Rational>()));
-            }else{
-              c_coeff = pv_prop.d_coeff;
+            Rational c_coeff = pv_prop.d_coeff.getConst<Rational>();
+            Node mc = msum_coeff[it->first];
+            if (!mc.isNull())
+            {
+              Assert(mc.isConst());
+              c_coeff = c_coeff / mc.getConst<Rational>();
             }
             if( !it->second.isNull() ){
-              c_coeff = NodeManager::currentNM()->mkNode( MULT, c_coeff, it->second );
+              Assert(it->second.isConst());
+              c_coeff = c_coeff * it->second.getConst<Rational>();
             }
-            Assert(!c_coeff.isNull());
-            Node c;
-            if( msum_term[it->first].isNull() ){
-              c = c_coeff;
-            }else{
-              c = NodeManager::currentNM()->mkNode( MULT, c_coeff, msum_term[it->first] );
+            Node c = nm->mkConstRealOrInt(type, c_coeff);
+            Node v = msum_term[it->first];
+            if (!v.isNull())
+            {
+              Assert(v.getType() == type);
+              c = nm->mkNode(MULT, c, v);
             }
             children.push_back( c );
             Trace("sygus-si-apply-subs-debug") << "Add child : " << c << std::endl;
           }
-          Node nretc = children.size() == 1
-                           ? children[0]
-                           : NodeManager::currentNM()->mkNode(ADD, children);
+          Node nretc =
+              children.size() == 1 ? children[0] : nm->mkNode(ADD, children);
           nretc = rewrite(nretc);
           //ensure that nret does not contain vars
           if (!expr::hasSubterm(nretc, vars))