Eliminate more uses of CONST_RATIONAL (#8590)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 8 Apr 2022 20:58:02 +0000 (15:58 -0500)
committerGitHub <noreply@github.com>
Fri, 8 Apr 2022 20:58:02 +0000 (20:58 +0000)
To eliminate arithmetic subtyping, we require distinguishing CONST_RATIONAL and CONST_INTEGER internally. Code should avoid usage of these kinds and use trusted utilities instead (e.g. mkConstReal, mkConstInst, isConst).

src/theory/arith/arith_ite_utils.cpp
src/theory/arith/arith_utilities.cpp
src/theory/arith/arith_utilities.h
src/theory/arith/nl/ext_theory_callback.cpp
src/theory/arith/nl/ext_theory_callback.h
src/theory/arith/nl/nl_model.cpp
src/theory/arith/type_enumerator.h
src/theory/bags/bags_utils.cpp
src/theory/bags/theory_bags.cpp
src/theory/strings/theory_strings.cpp

index a30d312e77b38cc84f472d42ecdf3ebf3b61a05e..6bf79b4ea39c692b6c8ed104015f3c123a097d8c 100644 (file)
@@ -181,7 +181,8 @@ const Integer& ArithIteUtils::gcdIte(Node n){
   if(d_gcds.find(n) != d_gcds.end()){
     return d_gcds[n];
   }
-  if(n.getKind() == kind::CONST_RATIONAL){
+  if (n.isConst())
+  {
     const Rational& q = n.getConst<Rational>();
     if(q.isIntegral()){
       d_gcds[n] = q.getNumerator();
index 7b9db4beb6507f863c410e613089f3ae20f76d51..b0709c1b35a931b18966ba652289ba9f814767f7 100644 (file)
@@ -99,6 +99,12 @@ Kind transKinds(Kind k1, Kind k2)
   return UNDEFINED_KIND;
 }
 
+bool isZero(const Node& n)
+{
+  Assert(n.getType().isRealOrInt());
+  return n.isConst() && n.getConst<Rational>().sgn() == 0;
+}
+
 bool isTranscendentalKind(Kind k)
 {
   // many operators are eliminated during rewriting
index d4df80e3d8b383bf57f26a091b35e78db562f325..98afea98f9446a0be05eda4dc25421f39b828684 100644 (file)
@@ -298,6 +298,9 @@ Kind joinKinds(Kind k1, Kind k2);
  */
 Kind transKinds(Kind k1, Kind k2);
 
+/** Is n (integer or real) zero? */
+bool isZero(const Node& n);
+
 /** Is k a transcendental function kind? */
 bool isTranscendentalKind(Kind k);
 /**
index d6993398440bb57e9b5b751a2ec4e40f14fb3577..bc21c6b4bb6496c5be620b259e01436efe8b1932 100644 (file)
@@ -27,7 +27,6 @@ namespace nl {
 
 NlExtTheoryCallback::NlExtTheoryCallback(eq::EqualityEngine* ee) : d_ee(ee)
 {
-  d_zero = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(0));
 }
 
 bool NlExtTheoryCallback::getCurrentSubstitution(
@@ -73,7 +72,7 @@ bool NlExtTheoryCallback::isExtfReduced(
     // we do not handle reductions of transcendental functions here
     return false;
   }
-  if (n != d_zero)
+  if (!isZero(n))
   {
     Kind k = n.getKind();
     if (k != NONLINEAR_MULT && !isTranscendentalKind(k) && k != IAND
@@ -91,7 +90,6 @@ bool NlExtTheoryCallback::isExtfReduced(
   // As an optimization, we minimize the explanation for why a term can be
   // simplified to zero, for example, if (= x 0) ^ (= y 5) => (= (* x y) 0),
   // we minimize the explanation to (= x 0) => (= (* x y) 0).
-  Assert(n == d_zero);
   id = ExtReducedId::ARITH_SR_ZERO;
   if (on.getKind() == NONLINEAR_MULT)
   {
@@ -124,7 +122,7 @@ bool NlExtTheoryCallback::isExtfReduced(
       {
         for (unsigned r = 0; r < 2; r++)
         {
-          if (eqs[j][r] == d_zero && vars.find(eqs[j][1 - r]) != vars.end())
+          if (isZero(eqs[j][r]) && vars.find(eqs[j][1 - r]) != vars.end())
           {
             Trace("nl-ext-zero-exp")
                 << "...single exp : " << eqs[j] << std::endl;
index c27fef929471f73da6208286d03d145a0540f20c..c458e55f944f85eed9698c5c8ef68b38d47bdcbf 100644 (file)
@@ -68,8 +68,6 @@ class NlExtTheoryCallback : public ExtTheoryCallback
  private:
   /** The underlying equality engine. */
   eq::EqualityEngine* d_ee;
-  /** Commonly used nodes */
-  Node d_zero;
 };
 
 }  // namespace nl
index a73cf4771fe40c3d8d4f5269424a6e64e972cf5d..e13294f10faf1af4dc8b10a8b141d9f47e425b0c 100644 (file)
@@ -324,6 +324,8 @@ bool NlModel::addSubstitution(TNode v, TNode s)
 
 bool NlModel::addBound(TNode v, TNode l, TNode u)
 {
+  Assert(l.getType().isSubtypeOf(v.getType()));
+  Assert(u.getType().isSubtypeOf(v.getType()));
   Trace("nl-ext-model") << "* check model bound : " << v << " -> [" << l << " "
                         << u << "]" << std::endl;
   if (l == u)
index 0e5c0d5f09b47b25b08b13da1d72895b09d8bc91..cb116e97db46da87d084ecbc0b5b4ce0f97575b7 100644 (file)
@@ -41,7 +41,7 @@ class RationalEnumerator : public TypeEnumeratorBase<RationalEnumerator> {
 
   Node operator*() override
   {
-    return NodeManager::currentNM()->mkConst(kind::CONST_RATIONAL, d_rat);
+    return NodeManager::currentNM()->mkConstReal(d_rat);
   }
   RationalEnumerator& operator++() override
   {
@@ -85,8 +85,7 @@ class IntegerEnumerator : public TypeEnumeratorBase<IntegerEnumerator> {
 
   Node operator*() override
   {
-    return NodeManager::currentNM()->mkConst(kind::CONST_RATIONAL,
-                                             Rational(d_int));
+    return NodeManager::currentNM()->mkConstInt(Rational(d_int));
   }
 
   IntegerEnumerator& operator++() override
index 2c1b5126ef83fa8320863b2745e48019a63af190..3c5089943ef672602a94cf8f033661e9e87f3b9c 100644 (file)
@@ -738,7 +738,7 @@ Node BagsUtils::evaluateBagFilter(TNode n)
 
   for (const auto& [e, count] : elements)
   {
-    Node multiplicity = nm->mkConst(CONST_RATIONAL, count);
+    Node multiplicity = nm->mkConstInt(count);
     Node bag = nm->mkBag(bagType.getBagElementType(), e, multiplicity);
     Node pOfe = nm->mkNode(APPLY_UF, P, e);
     Node ite = nm->mkNode(ITE, pOfe, bag, empty);
index 183b3213bf789dae829af377010431a7e731828e..4307dcbe345217aee0dc2fb2838ac3b193177a47 100644 (file)
@@ -401,7 +401,7 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
                 nm->getSkolemManager()->mkDummySkolem("slack", elementType);
             Trace("bags-model") << "newElement is " << newElement << std::endl;
             Rational difference = rCardRational - constructedRational;
-            Node multiplicity = nm->mkConst(CONST_RATIONAL, difference);
+            Node multiplicity = nm->mkConstInt(difference);
             Node slackBag = nm->mkBag(elementType, newElement, multiplicity);
             constructedBag =
                 nm->mkNode(kind::BAG_UNION_DISJOINT, constructedBag, slackBag);
index 479aa870b3bc76b24559a9f35634bd76ef645e5f..e1627de97e920205f0bed17d3c17b5c558fb1287 100644 (file)
@@ -775,7 +775,7 @@ Node TheoryStrings::mkSkeletonFromBase(Node r,
   TypeNode etn = r.getType().getSequenceElementType();
   for (size_t i = currIndex; i < nextIndex; i++)
   {
-    cacheVals.push_back(nm->mkConst(CONST_RATIONAL, Rational(currIndex)));
+    cacheVals.push_back(nm->mkConstInt(Rational(currIndex)));
     Node kv = sm->mkSkolemFunction(
         SkolemFunId::SEQ_MODEL_BASE_ELEMENT, etn, cacheVals);
     skChildren.push_back(nm->mkSeqUnit(etn, kv));