Eliminating some uses of const rational in arithmetic (#7846)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 20 Dec 2021 19:10:30 +0000 (13:10 -0600)
committerGitHub <noreply@github.com>
Mon, 20 Dec 2021 19:10:30 +0000 (19:10 +0000)
Note that there are several nested dependencies in arithmetic for constructing constants Constant::mkConstant ---> mkRationalNode ---> mkConst(CONST_RATIONAL, r)

This starts to disambiguate these calls.

src/preprocessing/passes/pseudo_boolean_processor.cpp
src/theory/arith/arith_rewriter.cpp
src/theory/arith/branch_and_bound.cpp
src/theory/arith/dio_solver.cpp
src/theory/arith/nl/transcendental/sine_solver.cpp
src/theory/arith/operator_elim.cpp

index 0e7ac9c79e74b7fc04ccd8eb76e1256583a0a674..eae1d00fd5a296a8feb1b0486d183d0e805fd080 100644 (file)
@@ -301,7 +301,8 @@ void PseudoBooleanProcessor::learn(Node assertion)
 Node PseudoBooleanProcessor::mkGeqOne(Node v)
 {
   NodeManager* nm = NodeManager::currentNM();
-  return nm->mkNode(kind::GEQ, v, mkRationalNode(Rational(1)));
+  return nm->mkNode(
+      kind::GEQ, v, nm->mkConstRealOrInt(v.getType(), Rational(1)));
 }
 
 void PseudoBooleanProcessor::learn(const std::vector<Node>& assertions)
index af6f23c1fd2ab3bea34ea2c40ebc36d9bbc99d13..0268a9eb1d688432d5315c762fe69922bd8d5e73 100644 (file)
@@ -49,7 +49,7 @@ bool ArithRewriter::isAtom(TNode n) {
 
 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
   Assert(t.isConst());
-  Assert(t.getKind() == kind::CONST_RATIONAL);
+  Assert(t.getKind() == CONST_RATIONAL || t.getKind() == CONST_INTEGER);
 
   return RewriteResponse(REWRITE_DONE, t);
 }
@@ -66,7 +66,8 @@ RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
   if(pre){
     if(t[0] == t[1]){
       Rational zero(0);
-      Node zeroNode  = mkRationalNode(zero);
+      Node zeroNode =
+          NodeManager::currentNM()->mkConstRealOrInt(t.getType(), zero);
       return RewriteResponse(REWRITE_DONE, zeroNode);
     }else{
       Node noMinus = makeSubtractionNode(t[0],t[1]);
@@ -83,9 +84,12 @@ RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
   Assert(t.getKind() == kind::UMINUS);
 
-  if(t[0].getKind() == kind::CONST_RATIONAL){
+  if (t[0].isConst())
+  {
     Rational neg = -(t[0].getConst<Rational>());
-    return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
+    NodeManager* nm = NodeManager::currentNM();
+    return RewriteResponse(REWRITE_DONE,
+                           nm->mkConstRealOrInt(t[0].getType(), neg));
   }
 
   Node noUminus = makeUnaryMinusNode(t[0]);
@@ -142,7 +146,7 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
         } else {
           return RewriteResponse(
               REWRITE_DONE,
-              NodeManager::currentNM()->mkConst(CONST_RATIONAL, -rat));
+              NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
         }
       }
       return RewriteResponse(REWRITE_DONE, t);
@@ -208,7 +212,7 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
         } else {
           return RewriteResponse(
               REWRITE_DONE,
-              NodeManager::currentNM()->mkConst(CONST_RATIONAL, -rat));
+              NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
         }
       }
       return RewriteResponse(REWRITE_DONE, t);
@@ -217,7 +221,8 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
     case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
     case kind::POW:
       {
-        if(t[1].getKind() == kind::CONST_RATIONAL){
+        if (t[1].isConst())
+        {
           const Rational& exp = t[1].getConst<Rational>();
           TNode base = t[0];
           if(exp.sgn() == 0){
@@ -241,8 +246,9 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
             }
           }
         }
-        else if (t[0].getKind() == kind::CONST_RATIONAL
-                 && t[0].getConst<Rational>().getNumerator().toUnsignedInt() == 2)
+        else if (t[0].isConst()
+                 && t[0].getConst<Rational>().getNumerator().toUnsignedInt()
+                        == 2)
         {
           return RewriteResponse(
               REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1]));
@@ -270,19 +276,20 @@ RewriteResponse ArithRewriter::preRewriteMult(TNode t){
   Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
 
   if(t.getNumChildren() == 2){
-    if(t[0].getKind() == kind::CONST_RATIONAL
-       && t[0].getConst<Rational>().isOne()){
+    if (t[0].isConst() && t[0].getConst<Rational>().isOne())
+    {
       return RewriteResponse(REWRITE_DONE, t[1]);
     }
-    if(t[1].getKind() == kind::CONST_RATIONAL
-       && t[1].getConst<Rational>().isOne()){
+    if (t[1].isConst() && t[1].getConst<Rational>().isOne())
+    {
       return RewriteResponse(REWRITE_DONE, t[0]);
     }
   }
 
   // Rewrite multiplications with a 0 argument and to 0
   for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
-    if((*i).getKind() == kind::CONST_RATIONAL) {
+    if ((*i).isConst())
+    {
       if((*i).getConst<Rational>().isZero()) {
         TNode zero = (*i);
         return RewriteResponse(REWRITE_DONE, zero);
@@ -387,13 +394,10 @@ RewriteResponse ArithRewriter::postRewritePow2(TNode t)
     Integer i = t[0].getConst<Rational>().getNumerator();
     if (i < 0)
     {
-      return RewriteResponse(
-          REWRITE_DONE,
-          nm->mkConst(CONST_RATIONAL, Rational(Integer(0), Integer(1))));
+      return RewriteResponse(REWRITE_DONE, nm->mkConstInt(Rational(0)));
     }
     unsigned long k = i.getUnsignedLong();
-    Node ret =
-        nm->mkConst(CONST_RATIONAL, Rational(Integer(2).pow(k), Integer(1)));
+    Node ret = nm->mkConstInt(Rational(Integer(2).pow(k)));
     return RewriteResponse(REWRITE_DONE, ret);
   }
   return RewriteResponse(REWRITE_DONE, t);
@@ -455,8 +459,9 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
   NodeManager* nm = NodeManager::currentNM();
   switch( t.getKind() ){
   case kind::EXPONENTIAL: {
-    if(t[0].getKind() == kind::CONST_RATIONAL){
-      Node one = nm->mkConst(CONST_RATIONAL, Rational(1));
+    if (t[0].isConst())
+    {
+      Node one = nm->mkConstReal(Rational(1));
       if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
         return RewriteResponse(
             REWRITE_AGAIN,
@@ -480,17 +485,16 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
   }
     break;
   case kind::SINE:
-    if(t[0].getKind() == kind::CONST_RATIONAL){
+    if (t[0].isConst())
+    {
       const Rational& rat = t[0].getConst<Rational>();
       if(rat.sgn() == 0){
-        return RewriteResponse(REWRITE_DONE,
-                               nm->mkConst(CONST_RATIONAL, Rational(0)));
+        return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
       }
       else if (rat.sgn() == -1)
       {
-        Node ret = nm->mkNode(
-            kind::UMINUS,
-            nm->mkNode(kind::SINE, nm->mkConst(CONST_RATIONAL, -rat)));
+        Node ret = nm->mkNode(kind::UMINUS,
+                              nm->mkNode(kind::SINE, nm->mkConstReal(-rat)));
         return RewriteResponse(REWRITE_AGAIN_FULL, ret);
       }
     }else{
@@ -507,7 +511,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
         {
           if (itm->second.isNull())
           {
-            pi_factor = mkRationalNode(Rational(1));
+            pi_factor = nm->mkConstReal(Rational(1));
           }
           else
           {
@@ -564,7 +568,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
           // sin( PI + x ) = -sin( x )
           if (rem.isNull())
           {
-            return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
+            return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
           }
           else
           {
@@ -584,7 +588,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
           {
             Assert(r_abs.getNumerator() == one);
             return RewriteResponse(REWRITE_DONE,
-                                   mkRationalNode(Rational(r.sgn())));
+                                   nm->mkConstReal(Rational(r.sgn())));
           }
           else if (r_abs.getDenominator() == six)
           {
@@ -593,7 +597,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
             {
               return RewriteResponse(
                   REWRITE_DONE,
-                  mkRationalNode(Rational(r.sgn()) / Rational(2)));
+                  nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
             }
           }
         }
@@ -603,13 +607,13 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
   case kind::COSINE: {
     return RewriteResponse(
         REWRITE_AGAIN_FULL,
-        nm->mkNode(kind::SINE,
-                   nm->mkNode(kind::MINUS,
-                              nm->mkNode(kind::MULT,
-                                         nm->mkConst(CONST_RATIONAL,
-                                                     Rational(1) / Rational(2)),
-                                         mkPi()),
-                              t[0])));
+        nm->mkNode(
+            kind::SINE,
+            nm->mkNode(kind::MINUS,
+                       nm->mkNode(kind::MULT,
+                                  nm->mkConstReal(Rational(1) / Rational(2)),
+                                  mkPi()),
+                       t[0])));
   }
   break;
   case kind::TANGENT:
@@ -624,7 +628,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
   {
     return RewriteResponse(REWRITE_AGAIN_FULL,
                            nm->mkNode(kind::DIVISION,
-                                      mkRationalNode(Rational(1)),
+                                      nm->mkConstReal(Rational(1)),
                                       nm->mkNode(kind::SINE, t[0])));
   }
   break;
@@ -632,7 +636,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
   {
     return RewriteResponse(REWRITE_AGAIN_FULL,
                            nm->mkNode(kind::DIVISION,
-                                      mkRationalNode(Rational(1)),
+                                      nm->mkConstReal(Rational(1)),
                                       nm->mkNode(kind::COSINE, t[0])));
   }
   break;
@@ -660,17 +664,15 @@ RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
     if(atom.getOperator().getConst<Divisible>().k.isOne()) {
       return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
     }
+    NodeManager* nm = NodeManager::currentNM();
     return RewriteResponse(
         REWRITE_AGAIN,
-        NodeManager::currentNM()->mkNode(
-            kind::EQUAL,
-            NodeManager::currentNM()->mkNode(
-                kind::INTS_MODULUS_TOTAL,
-                atom[0],
-                NodeManager::currentNM()->mkConst(
-                    CONST_RATIONAL,
-                    Rational(atom.getOperator().getConst<Divisible>().k))),
-            NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(0))));
+        nm->mkNode(kind::EQUAL,
+                   nm->mkNode(kind::INTS_MODULUS_TOTAL,
+                              atom[0],
+                              nm->mkConstInt(Rational(
+                                  atom.getOperator().getConst<Divisible>().k))),
+                   nm->mkConstInt(Rational(0))));
   }
 
   // left |><| right
@@ -747,8 +749,9 @@ RewriteResponse ArithRewriter::preRewrite(TNode t){
 }
 
 Node ArithRewriter::makeUnaryMinusNode(TNode n){
+  NodeManager* nm = NodeManager::currentNM();
   Rational qNegOne(-1);
-  return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
+  return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n);
 }
 
 Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
@@ -763,12 +766,14 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
 
   Node left = t[0];
   Node right = t[1];
-  if(right.getKind() == kind::CONST_RATIONAL){
+  if (right.isConst())
+  {
+    NodeManager* nm = NodeManager::currentNM();
     const Rational& den = right.getConst<Rational>();
 
     if(den.isZero()){
       if(t.getKind() == kind::DIVISION_TOTAL){
-        return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+        return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
       }else{
         // This is unsupported, but this is not a good place to complain
         return RewriteResponse(REWRITE_DONE, t);
@@ -776,16 +781,17 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
     }
     Assert(den != Rational(0));
 
-    if(left.getKind() == kind::CONST_RATIONAL){
+    if (left.isConst())
+    {
       const Rational& num = left.getConst<Rational>();
       Rational div = num / den;
-      Node result =  mkRationalNode(div);
+      Node result = nm->mkConstReal(div);
       return RewriteResponse(REWRITE_DONE, result);
     }
 
     Rational div = den.inverse();
 
-    Node result = mkRationalNode(div);
+    Node result = nm->mkConstReal(div);
 
     Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
     if(pre){
@@ -793,16 +799,14 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
     }else{
       return RewriteResponse(REWRITE_AGAIN, mult);
     }
-  }else{
-    return RewriteResponse(REWRITE_DONE, t);
   }
+  return RewriteResponse(REWRITE_DONE, t);
 }
 
 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
 {
   NodeManager* nm = NodeManager::currentNM();
   Kind k = t.getKind();
-  Node zero = nm->mkConst(CONST_RATIONAL, Rational(0));
   if (k == kind::INTS_MODULUS)
   {
     if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
@@ -867,10 +871,10 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
   Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
   TNode n = t[0];
   TNode d = t[1];
-  bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
+  bool dIsConstant = d.isConst();
   if(dIsConstant && d.getConst<Rational>().isZero()){
     // (div x 0) ---> 0 or (mod x 0) ---> 0
-    return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
+    return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO);
   }else if(dIsConstant && d.getConst<Rational>().isOne()){
     if (k == kind::INTS_MODULUS_TOTAL)
     {
@@ -886,14 +890,13 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
     // pull negation
     // (div x (- c)) ---> (- (div x c))
     // (mod x (- c)) ---> (mod x c)
-    Node nn = nm->mkNode(
-        k, t[0], nm->mkConst(CONST_RATIONAL, -t[1].getConst<Rational>()));
+    Node nn = nm->mkNode(k, t[0], nm->mkConstInt(-t[1].getConst<Rational>()));
     Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
                    ? nm->mkNode(kind::UMINUS, nn)
                    : nn;
     return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
   }
-  else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL)
+  else if (dIsConstant && n.isConst())
   {
     Assert(d.getConst<Rational>().isIntegral());
     Assert(n.getConst<Rational>().isIntegral());
@@ -907,7 +910,7 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
 
     // constant evaluation
     // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
-    Node resultNode = mkRationalNode(Rational(result));
+    Node resultNode = nm->mkConstInt(Rational(result));
     return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
   }
   if (k == kind::INTS_MODULUS_TOTAL)
@@ -953,7 +956,7 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
     if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
     {
       // (div (mod x c) c) --> 0
-      Node ret = mkRationalNode(0);
+      Node ret = nm->mkConstInt(0);
       return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
     }
   }
index eb02339bb3934ad7e0a94b2d69c3bc2a51053bc1..6d9a71722ecf6a857c992615299cd961b1412846 100644 (file)
@@ -58,10 +58,10 @@ TrustNode BranchAndBound::branchIntegerVariable(TNode var, Rational value)
 
     // Prioritize trying a simple rounding of the real solution first,
     // it that fails, fall back on original branch and bound strategy.
-    Node ub = rewrite(nm->mkNode(LEQ, var, mkRationalNode(nearest - 1)));
-    Node lb = rewrite(nm->mkNode(GEQ, var, mkRationalNode(nearest + 1)));
+    Node ub = rewrite(nm->mkNode(LEQ, var, nm->mkConstInt(nearest - 1)));
+    Node lb = rewrite(nm->mkNode(GEQ, var, nm->mkConstInt(nearest + 1)));
     Node right = nm->mkNode(OR, ub, lb);
-    Node rawEq = nm->mkNode(EQUAL, var, mkRationalNode(nearest));
+    Node rawEq = nm->mkNode(EQUAL, var, nm->mkConstInt(nearest));
     Node eq = rewrite(rawEq);
     // Also preprocess it before we send it out. This is important since
     // arithmetic may prefer eliminating equalities.
@@ -78,8 +78,8 @@ TrustNode BranchAndBound::branchIntegerVariable(TNode var, Rational value)
     Trace("integers") << "l: " << l << std::endl;
     if (proofsEnabled())
     {
-      Node less = nm->mkNode(LT, var, mkRationalNode(nearest));
-      Node greater = nm->mkNode(GT, var, mkRationalNode(nearest));
+      Node less = nm->mkNode(LT, var, nm->mkConstInt(nearest));
+      Node greater = nm->mkNode(GT, var, nm->mkConstInt(nearest));
       // TODO (project #37): justify. Thread proofs through *ensureLiteral*.
       Debug("integers::pf") << "less: " << less << std::endl;
       Debug("integers::pf") << "greater: " << greater << std::endl;
@@ -119,7 +119,7 @@ TrustNode BranchAndBound::branchIntegerVariable(TNode var, Rational value)
   }
   else
   {
-    Node ub = rewrite(nm->mkNode(LEQ, var, mkRationalNode(floor)));
+    Node ub = rewrite(nm->mkNode(LEQ, var, nm->mkConstInt(floor)));
     Node lb = ub.notNode();
     if (proofsEnabled())
     {
index 99dcc93caf170994cc2d752c9b6120402762c662..af3d8a692e4a22d4f3b895e0452ae5b9e8fb47e3 100644 (file)
@@ -820,8 +820,10 @@ void DioSolver::addTrailElementAsLemma(TrailIndex i) {
 
 Node DioSolver::trailIndexToEquality(TrailIndex i) const {
   const SumPair& sp = d_trail[i].d_eq;
-  Node zero = mkRationalNode(0);
-  Node eq = (sp.getNode()).eqNode(zero);
+  Node n = sp.getNode();
+  Node zero =
+      NodeManager::currentNM()->mkConstRealOrInt(n.getType(), Rational(0));
+  Node eq = n.eqNode(zero);
   return eq;
 }
 
index 6c1bec647b78f6cac3d55f0d65d824776021536c..d574a95725e7d301bcb66ab174111a886d59cc6e 100644 (file)
@@ -45,10 +45,9 @@ namespace {
  */
 inline Node mkValidPhase(TNode a, TNode pi)
 {
+  NodeManager* nm = NodeManager::currentNM();
   return mkBounded(
-      NodeManager::currentNM()->mkNode(Kind::MULT, mkRationalNode(-1), pi),
-      a,
-      pi);
+      nm->mkNode(Kind::MULT, nm->mkConstReal(Rational(-1)), pi), a, pi);
 }
 }  // namespace
 
index 99f5621d6b4959c445fee9257993a0cae92eaca8..05a83c81c1f0a4146276a59bbf834448598f17a2 100644 (file)
@@ -122,8 +122,8 @@ Node OperatorElim::eliminateOperators(Node node,
       // 0 <= node[0] - toIntSkolem < 1
       Node v =
           bvm->mkBoundVar<ToIntWitnessVarAttribute>(node[0], nm->integerType());
-      Node one = mkRationalNode(1);
-      Node zero = mkRationalNode(0);
+      Node one = nm->mkConstReal(Rational(1));
+      Node zero = nm->mkConstReal(Rational(0));
       Node diff = nm->mkNode(MINUS, node[0], v);
       Node lem = mkInRange(diff, zero, one);
       Node toIntSkolem =