Further refactoring in preparation for CONST_INTEGER (#8687)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 2 May 2022 19:11:22 +0000 (14:11 -0500)
committerGitHub <noreply@github.com>
Mon, 2 May 2022 19:11:22 +0000 (19:11 +0000)
Miscellaneous refactorings from trying to enable CONST_INTEGER.

src/expr/node_manager_template.cpp
src/expr/node_manager_template.h
src/theory/arith/arith_rewriter.cpp
src/theory/arith/linear/infer_bounds.cpp
src/theory/arith/linear/normal_form.cpp
src/theory/arith/linear/normal_form.h
src/theory/arith/rewriter/node_utils.h
src/theory/arith/theory_arith_type_rules.cpp
src/theory/quantifiers/instantiate.cpp

index 24e1f01f3a4842df59e6ed7505249e76f4ab8721..c27679b4f8f5d7cb4d5ed92db6aa938fd55ba777 100644 (file)
@@ -1314,6 +1314,15 @@ Node NodeManager::mkConstInt(const Rational& r)
   return mkConst(kind::CONST_RATIONAL, r);
 }
 
+Node NodeManager::mkConstRealOrInt(const Rational& r)
+{
+  if (r.isIntegral())
+  {
+    return mkConstInt(r);
+  }
+  return mkConstReal(r);
+}
+
 Node NodeManager::mkConstRealOrInt(const TypeNode& tn, const Rational& r)
 {
   Assert(tn.isRealOrInt()) << "Expected real or int for mkConstRealOrInt, got "
@@ -1329,7 +1338,8 @@ Node NodeManager::mkRealAlgebraicNumber(const RealAlgebraicNumber& ran)
 {
   if (ran.isRational())
   {
-    return mkConstReal(ran.toRational());
+    // may generate an integer it is it integral
+    return mkConstRealOrInt(ran.toRational());
   }
   // Creating this node may refine the ran to the point where isRational returns
   // true
@@ -1341,7 +1351,8 @@ Node NodeManager::mkRealAlgebraicNumber(const RealAlgebraicNumber& ran)
     const RealAlgebraicNumber& cur = inner.getConst<RealAlgebraicNumber>();
     if (cur.isRational())
     {
-      return mkConstReal(cur.toRational());
+      // may generate an integer it is it integral
+      return mkConstRealOrInt(cur.toRational());
     }
     if (cur == ran) break;
     inner = mkConst(Kind::REAL_ALGEBRAIC_NUMBER_OP, cur);
index 678728c780b667f4c8920ff494e33b2e67868b62..e10f448b4945a0cec9c11bdf111e5d28fd67b53f 100644 (file)
@@ -686,6 +686,12 @@ class NodeManager
    */
   Node mkConstInt(const Rational& r);
 
+  /**
+   * Make constant real or int, which calls one of the above methods based
+   * on whether r is integral.
+   */
+  Node mkConstRealOrInt(const Rational& r);
+
   /**
    * Make constant real or int, which calls one of the above methods based
    * on the type tn.
index 128b2e84cf5476e39a8e3bdfc76ad92767f654fb..804e476fbbd71b85558d91c6d5cdf013c5783b8b 100644 (file)
@@ -899,6 +899,12 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t)
   Trace("arith-tf-rewrite")
       << "Rewrite transcendental function : " << t << std::endl;
   NodeManager* nm = NodeManager::currentNM();
+  if (t[0].getKind() == TO_REAL)
+  {
+    // always strip TO_REAL from argument.
+    Node ret = nm->mkNode(t.getKind(), t[0][0]);
+    return RewriteResponse(REWRITE_AGAIN, ret);
+  }
   switch (t.getKind())
   {
     case kind::EXPONENTIAL:
@@ -1019,7 +1025,6 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t)
             {
               new_arg = nm->mkNode(kind::ADD, new_arg, rem);
             }
-            new_arg = ensureReal(new_arg);
             // sin( 2*n*PI + x ) = sin( x )
             return RewriteResponse(REWRITE_AGAIN_FULL,
                                    nm->mkNode(kind::SINE, new_arg));
@@ -1049,8 +1054,8 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t)
             if (r_abs.getDenominator() == two)
             {
               Assert(r_abs.getNumerator() == one);
-              return RewriteResponse(
-                  REWRITE_DONE, ensureReal(nm->mkConstReal(Rational(r.sgn()))));
+              return RewriteResponse(REWRITE_DONE,
+                                     nm->mkConstReal(Rational(r.sgn())));
             }
             else if (r_abs.getDenominator() == six)
             {
index ec2843aa2834cecc6fefafa79530e438dcb4c63c..8666921f52b053e1be4cd2f99dacae8ad85c7872 100644 (file)
@@ -151,7 +151,7 @@ Node InferBoundsResult::getTerm() const { return d_term; }
 Node InferBoundsResult::getLiteral() const{
   const Rational& q = getValue().getNoninfinitesimalPart();
   NodeManager* nm = NodeManager::currentNM();
-  Node qnode = nm->mkConst(CONST_RATIONAL, q);
+  Node qnode = nm->mkConstReal(q);
 
   Kind k;
   if(d_upperBound){
index 81bb2383300a3f43534bd5570aad8f688be94946..ecbf10b1b9fb1ea1228567a9936db0ff50b12a03 100644 (file)
@@ -220,13 +220,16 @@ VarList VarList::operator*(const VarList& other) const {
 }
 
 bool Monomial::isMember(TNode n){
-  if(n.getKind() == kind::CONST_RATIONAL) {
+  Kind k = n.getKind();
+  if (k == kind::CONST_RATIONAL || k == kind::CONST_INTEGER)
+  {
     return true;
-  } else if(multStructured(n)) {
+  }
+  else if (multStructured(n))
+  {
     return VarList::isMember(n[1]);
-  } else {
-    return VarList::isMember(n);
   }
+  return VarList::isMember(n);
 }
 
 Monomial Monomial::mkMonomial(const Constant& c, const VarList& vl) {
@@ -249,13 +252,16 @@ Monomial Monomial::mkMonomial(const VarList& vl) {
 }
 
 Monomial Monomial::parseMonomial(Node n) {
-  if(n.getKind() == kind::CONST_RATIONAL) {
+  Kind k = n.getKind();
+  if (k == kind::CONST_RATIONAL || k == kind::CONST_INTEGER)
+  {
     return Monomial(Constant(n));
-  } else if(multStructured(n)) {
+  }
+  else if (multStructured(n))
+  {
     return Monomial::mkMonomial(Constant(n[0]),VarList::parseVarList(n[1]));
-  } else {
-    return Monomial(VarList::parseVarList(n));
   }
+  return Monomial(VarList::parseVarList(n));
 }
 Monomial Monomial::operator*(const Rational& q) const {
   if(q.isZero()){
@@ -699,7 +705,10 @@ SumPair SumPair::mkSumPair(const Polynomial& p){
   }
 }
 
-Comparison::Comparison(TNode n) : NodeWrapper(n) { Assert(isNormalForm()); }
+Comparison::Comparison(TNode n) : NodeWrapper(n)
+{
+  Assert(isNormalForm()) << "Bad comparison normal form: " << n;
+}
 
 SumPair Comparison::toSumPair() const {
   Kind cmpKind = comparisonKind();
@@ -719,8 +728,8 @@ SumPair Comparison::toSumPair() const {
         return SumPair(-p, c);
       }
     }
-  case kind::EQUAL:
-  case kind::DISTINCT:
+    case kind::EQUAL:
+    case kind::DISTINCT:
     {
       Polynomial left = getLeft();
       Polynomial right = getRight();
@@ -758,8 +767,8 @@ Polynomial Comparison::normalizedVariablePart() const {
         return -p;
       }
     }
-  case kind::EQUAL:
-  case kind::DISTINCT:
+    case kind::EQUAL:
+    case kind::DISTINCT:
     {
       Polynomial left = getLeft();
       Polynomial right = getRight();
@@ -798,8 +807,8 @@ DeltaRational Comparison::normalizedDeltaRational() const {
         return DeltaRational(-q, -delta);
       }
     }
-  case kind::EQUAL:
-  case kind::DISTINCT:
+    case kind::EQUAL:
+    case kind::DISTINCT:
     {
       Polynomial right = getRight();
       Monomial firstRight = right.getHead();
@@ -914,19 +923,20 @@ Node Comparison::toNode(Kind k, const Polynomial& l, const Polynomial& r) {
     return toNode(kind::GEQ, r, l).notNode();
   case kind::LT:
     return toNode(kind::GT, r, l).notNode();
-  case kind::DISTINCT:
-    return toNode(kind::EQUAL, r, l).notNode();
+  case kind::DISTINCT: return toNode(kind::EQUAL, r, l).notNode();
   default:
     Unreachable();
   }
 }
 
 bool Comparison::rightIsConstant() const {
+  Kind k;
   if(getNode().getKind() == kind::NOT){
-    return getNode()[0][1].getKind() == kind::CONST_RATIONAL;
+    k = getNode()[0][1].getKind();
   }else{
-    return getNode()[1].getKind() == kind::CONST_RATIONAL;
+    k = getNode()[1].getKind();
   }
+  return k == kind::CONST_RATIONAL || k == kind::CONST_INTEGER;
 }
 
 size_t Comparison::getComplexity() const{
@@ -1004,8 +1014,7 @@ bool Comparison::isNormalForm() const {
     return isNormalGT();
   case kind::GEQ:
     return isNormalGEQ();
-  case kind::EQUAL:
-    return isNormalEquality();
+  case kind::EQUAL: return isNormalEquality();
   case kind::LT:
     return isNormalLT();
   case kind::LEQ:
@@ -1304,7 +1313,9 @@ Node Comparison::mkIntEquality(const Polynomial& p){
 Comparison Comparison::mkComparison(Kind k, const Polynomial& l, const Polynomial& r){
 
   //Make this special case fast for sharing!
-  if((k == kind::EQUAL || k == kind::DISTINCT) && l.isVarList() && r.isVarList()){
+  if ((k == kind::EQUAL || k == kind::DISTINCT) && l.isVarList()
+      && r.isVarList())
+  {
     VarList vLeft = l.asVarList();
     VarList vRight = r.asVarList();
 
@@ -1312,7 +1323,8 @@ Comparison Comparison::mkComparison(Kind k, const Polynomial& l, const Polynomia
       // return true for equalities and false for disequalities
       return Comparison(k == kind::EQUAL);
     }else{
-      Node eqNode = vLeft < vRight ? toNode( kind::EQUAL, l, r) : toNode( kind::EQUAL, r, l);
+      Node eqNode = vLeft < vRight ? toNode(kind::EQUAL, l, r)
+                                   : toNode(kind::EQUAL, r, l);
       Node forK = (k == kind::DISTINCT) ? eqNode.notNode() : eqNode;
       return Comparison(forK);
     }
@@ -1327,10 +1339,10 @@ Comparison Comparison::mkComparison(Kind k, const Polynomial& l, const Polynomia
     Node result = Node::null();
     bool isInteger = diff.allIntegralVariables();
     switch(k){
-    case kind::EQUAL:
-      result = isInteger ? mkIntEquality(diff) : mkRatEquality(diff);
-      break;
-    case kind::DISTINCT:
+      case kind::EQUAL:
+        result = isInteger ? mkIntEquality(diff) : mkRatEquality(diff);
+        break;
+      case kind::DISTINCT:
       {
         Node eq = isInteger ? mkIntEquality(diff) : mkRatEquality(diff);
         result = eq.notNode();
@@ -1377,8 +1389,7 @@ Kind Comparison::comparisonKind(TNode literal){
   case kind::CONST_BOOLEAN:
   case kind::GT:
   case kind::GEQ:
-  case kind::EQUAL:
-    return literal.getKind();
+  case kind::EQUAL: return literal.getKind();
   case  kind::NOT:
     {
       TNode negatedAtom = literal[0];
@@ -1387,8 +1398,7 @@ Kind Comparison::comparisonKind(TNode literal){
         return kind::LEQ;
       case kind::GEQ: //(not (GEQ x c)) <=> (LT x c)
         return kind::LT;
-      case kind::EQUAL:
-        return kind::DISTINCT;
+      case kind::EQUAL: return kind::DISTINCT;
       default:
         return  kind::UNDEFINED_KIND;
       }
index 9656e2876a333737344cd938fa852a25058e7099..c9f8eb72b5ea0ae5f05e4cd8c87fdc79d84ee20a 100644 (file)
@@ -231,6 +231,7 @@ public:
    Kind k = n.getKind();
    switch (k)
    {
+     case kind::CONST_INTEGER:
      case kind::CONST_RATIONAL: return false;
      case kind::INTS_DIVISION:
      case kind::INTS_MODULUS:
@@ -347,13 +348,18 @@ class Constant : public NodeWrapper {
 public:
  Constant(Node n) : NodeWrapper(n) { Assert(isMember(getNode())); }
 
- static bool isMember(Node n) { return n.getKind() == kind::CONST_RATIONAL; }
+ static bool isMember(Node n)
+ {
+   Kind k = n.getKind();
+   return k == kind::CONST_RATIONAL || k == kind::CONST_INTEGER;
+ }
 
  bool isNormalForm() { return isMember(getNode()); }
 
  static Constant mkConstant(Node n)
  {
-   Assert(n.getKind() == kind::CONST_RATIONAL);
+   Assert(n.getKind() == kind::CONST_RATIONAL
+          || n.getKind() == kind::CONST_INTEGER);
    return Constant(n);
  }
 
@@ -633,9 +639,8 @@ private:
   }
 
   static bool multStructured(Node n) {
-    return n.getKind() ==  kind::MULT &&
-      n[0].getKind() == kind::CONST_RATIONAL &&
-      n.getNumChildren() == 2;
+    return n.getKind() == kind::MULT && n[0].isConst()
+           && n.getNumChildren() == 2;
   }
 
   Monomial(const Constant& c):
@@ -794,7 +799,7 @@ private:
   bool d_singleton;
 
   Polynomial(TNode n) : NodeWrapper(n), d_singleton(Monomial::isMember(n)) {
-    Assert(isMember(getNode()));
+    Assert(isMember(getNode())) << "Bad polynomial member " << n;
   }
 
   static Node makePlusNode(const std::vector<Monomial>& m) {
index bd6bff961defb3b0a2510989b4de84a840dadf22..54f8a16960ef4bb6dfb40ad2beddea3981919d79 100644 (file)
@@ -91,15 +91,7 @@ inline Node mkConst(const Integer& value)
 {
   return NodeManager::currentNM()->mkConstInt(value);
 }
-/** Create an integer or rational constant node */
-inline Node mkConst(const Rational& value)
-{
-  if (value.isIntegral())
-  {
-    return NodeManager::currentNM()->mkConstInt(value);
-  }
-  return NodeManager::currentNM()->mkConstReal(value);
-}
+
 /** Create a real algebraic number node */
 inline Node mkConst(const RealAlgebraicNumber& value)
 {
index eaf3315cbc07cabd73936eeb36cb908a82a65819..31dede3a953d20b7bd5b5b94d2d2cf5c0989762d 100644 (file)
@@ -30,10 +30,7 @@ TypeNode ArithConstantTypeRule::computeType(NodeManager* nodeManager,
   {
     return nodeManager->integerType();
   }
-  else
-  {
-    return nodeManager->realType();
-  }
+  return nodeManager->realType();
 }
 
 TypeNode ArithRealAlgebraicNumberOpTypeRule::computeType(
index 23abba94aca320fc3145c9f345c2e2a08fcc9cc7..bc0b2ba97e93d2ddd3917482dc411da3554d0e85 100644 (file)
@@ -152,7 +152,7 @@ bool Instantiate::addInstantiation(Node q,
                     << terms[i] << std::endl;
       bad_inst = true;
     }
-    else if (!terms[i].getType().isSubtypeOf(q[0][i].getType()))
+    else if (terms[i].getType() != q[0][i].getType())
     {
       Trace("inst") << "***& inst bad type : " << terms[i] << " "
                     << terms[i].getType() << "/" << q[0][i].getType()
@@ -761,7 +761,7 @@ Node Instantiate::ensureType(Node n, TypeNode tn)
   Trace("inst-add-debug2") << "Ensure " << n << " : " << tn << std::endl;
   TypeNode ntn = n.getType();
   Assert(ntn.isComparableTo(tn));
-  if (ntn.isSubtypeOf(tn))
+  if (ntn == tn)
   {
     return n;
   }
@@ -769,6 +769,10 @@ Node Instantiate::ensureType(Node n, TypeNode tn)
   {
     return NodeManager::currentNM()->mkNode(TO_INTEGER, n);
   }
+  else if (tn.isReal())
+  {
+    return NodeManager::currentNM()->mkNode(TO_REAL, n);
+  }
   return Node::null();
 }