Refactor rewriting of arithmetic atoms (#8175)
authorGereon Kremer <gereon.kremer@cs.rwth-aachen.de>
Mon, 28 Feb 2022 23:15:27 +0000 (00:15 +0100)
committerGitHub <noreply@github.com>
Mon, 28 Feb 2022 23:15:27 +0000 (23:15 +0000)
This PR uses the new utilities for atom rewriting in the arithmetic rewriter.

src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h

index 38fb0d162905ced381c4d3cc5e6c1f5a4c16ee62..863eaee1d8794b86606aa8a71349d12d77ead221 100644 (file)
@@ -33,6 +33,7 @@
 #include "theory/arith/rewriter/addition.h"
 #include "theory/arith/rewriter/node_utils.h"
 #include "theory/arith/rewriter/ordering.h"
+#include "theory/arith/rewriter/rewrite_atom.h"
 #include "theory/theory.h"
 #include "util/bitvector.h"
 #include "util/divisible.h"
@@ -47,20 +48,6 @@ namespace arith {
 
 namespace {
 
-template <typename L, typename R>
-bool evaluateRelation(Kind rel, const L& l, const R& r)
-{
-  switch (rel)
-  {
-    case Kind::LT: return l < r;
-    case Kind::LEQ: return l <= r;
-    case Kind::EQUAL: return l == r;
-    case Kind::GEQ: return l >= r;
-    case Kind::GT: return l > r;
-    default: Unreachable(); return false;
-  }
-}
-
 /**
  * Check whether the parent has a child that is a constant zero.
  * If so, return this child. Otherwise, return std::nullopt.
@@ -85,7 +72,7 @@ ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
 RewriteResponse ArithRewriter::preRewrite(TNode t)
 {
   Trace("arith-rewriter") << "preRewrite(" << t << ")" << std::endl;
-  if (isAtom(t))
+  if (rewriter::isAtom(t))
   {
     auto res = preRewriteAtom(t);
     Trace("arith-rewriter")
@@ -100,7 +87,7 @@ RewriteResponse ArithRewriter::preRewrite(TNode t)
 RewriteResponse ArithRewriter::postRewrite(TNode t)
 {
   Trace("arith-rewriter") << "postRewrite(" << t << ")" << std::endl;
-  if (isAtom(t))
+  if (rewriter::isAtom(t))
   {
     auto res = postRewriteAtom(t);
     Trace("arith-rewriter")
@@ -114,41 +101,33 @@ RewriteResponse ArithRewriter::postRewrite(TNode t)
 
 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
 {
-  Assert(isAtom(atom));
-
-  NodeManager* nm = NodeManager::currentNM();
+  Assert(rewriter::isAtom(atom));
 
-  if (isRelationOperator(atom.getKind()) && atom[0] == atom[1])
+  if (auto response = rewriter::tryEvaluateRelationReflexive(atom); response)
   {
-    switch (atom.getKind())
-    {
-      case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
-      case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
-      case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
-      case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
-      case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
-      default:;
-    }
+    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response));
   }
 
   switch (atom.getKind())
   {
     case Kind::GT:
-      return RewriteResponse(REWRITE_DONE,
-                             nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode());
+      return RewriteResponse(
+          REWRITE_DONE,
+          rewriter::buildRelation(kind::LEQ, atom[0], atom[1], true));
     case Kind::LT:
-      return RewriteResponse(REWRITE_DONE,
-                             nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode());
+      return RewriteResponse(
+          REWRITE_DONE,
+          rewriter::buildRelation(kind::GEQ, atom[0], atom[1], true));
     case Kind::IS_INTEGER:
       if (atom[0].getType().isInteger())
       {
-        return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+        return RewriteResponse(REWRITE_DONE, rewriter::mkConst(true));
       }
       break;
     case Kind::DIVISIBLE:
       if (atom.getOperator().getConst<Divisible>().k.isOne())
       {
-        return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+        return RewriteResponse(REWRITE_DONE, rewriter::mkConst(true));
       }
       break;
     default:;
@@ -159,93 +138,91 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
 
 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
 {
-  Assert(isAtom(atom));
+  Assert(rewriter::isAtom(atom));
+
   if (atom.getKind() == kind::IS_INTEGER)
   {
     return rewriteExtIntegerOp(atom);
   }
   else if (atom.getKind() == kind::DIVISIBLE)
   {
+    const Integer& k = atom.getOperator().getConst<Divisible>().k;
     if (atom[0].isConst())
     {
+      const Rational& num = atom[0].getConst<Rational>();
       return RewriteResponse(REWRITE_DONE,
-                             NodeManager::currentNM()->mkConst(bool(
-                                 (atom[0].getConst<Rational>()
-                                  / atom.getOperator().getConst<Divisible>().k)
-                                     .isIntegral())));
+                             rewriter::mkConst((num / k).isIntegral()));
     }
-    if (atom.getOperator().getConst<Divisible>().k.isOne())
+    if (k.isOne())
     {
-      return RewriteResponse(REWRITE_DONE,
-                             NodeManager::currentNM()->mkConst(true));
+      return RewriteResponse(REWRITE_DONE, rewriter::mkConst(true));
     }
     NodeManager* nm = NodeManager::currentNM();
     return RewriteResponse(
         REWRITE_AGAIN,
-        nm->mkNode(kind::EQUAL,
-                   nm->mkNode(kind::INTS_MODULUS_TOTAL,
-                              atom[0],
-                              nm->mkConstInt(Rational(
-                                  atom.getOperator().getConst<Divisible>().k))),
-                   nm->mkConstInt(Rational(0))));
+        nm->mkNode(
+            kind::EQUAL,
+            nm->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], rewriter::mkConst(k)),
+            rewriter::mkConst(Integer(0))));
+  }
+
+  if (auto response = rewriter::tryEvaluateRelationReflexive(atom); response)
+  {
+    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response));
   }
 
   // left |><| right
+  Kind kind = atom.getKind();
   TNode left = atom[0];
   TNode right = atom[1];
+  Assert(isRelationOperator(kind));
 
-  auto* nm = NodeManager::currentNM();
-  if (left.isConst())
+  if (auto response = rewriter::tryEvaluateRelation(kind, left, right);
+      response)
   {
-    const Rational& l = left.getConst<Rational>();
-    if (right.isConst())
-    {
-      const Rational& r = right.getConst<Rational>();
-      return RewriteResponse(
-          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
-    }
-    else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
-    {
-      const RealAlgebraicNumber& r =
-          right.getOperator().getConst<RealAlgebraicNumber>();
-      return RewriteResponse(
-          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
-    }
+    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response));
   }
-  else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+
+  bool negate = false;
+
+  switch (atom.getKind())
   {
-    const RealAlgebraicNumber& l =
-        left.getOperator().getConst<RealAlgebraicNumber>();
-    if (right.isConst())
+    case Kind::LEQ:
+      kind = Kind::GEQ;
+      negate = true;
+      break;
+    case Kind::LT:
+      kind = Kind::GT;
+      negate = true;
+      break;
+    default: break;
+  }
+
+  rewriter::Sum sum;
+  rewriter::addToSum(sum, left, negate);
+  rewriter::addToSum(sum, right, !negate);
+
+  // Now we have (rsum <kind> 0)
+  if (rewriter::isIntegral(atom))
+  {
+    if (kind == Kind::EQUAL)
     {
-      const Rational& r = right.getConst<Rational>();
-      return RewriteResponse(
-          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+      return RewriteResponse(REWRITE_DONE,
+                             rewriter::buildIntegerEquality(std::move(sum)));
     }
-    else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    return RewriteResponse(
+        REWRITE_DONE, rewriter::buildIntegerInequality(std::move(sum), kind));
+  }
+  else
+  {
+    if (kind == Kind::EQUAL)
     {
-      const RealAlgebraicNumber& r =
-          right.getOperator().getConst<RealAlgebraicNumber>();
-      return RewriteResponse(
-          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+      return RewriteResponse(REWRITE_DONE,
+                             rewriter::buildRealEquality(std::move(sum)));
     }
+    return RewriteResponse(REWRITE_DONE,
+                           rewriter::buildRealInequality(std::move(sum), kind));
   }
-
-  Polynomial pleft = Polynomial::parsePolynomial(left);
-  Polynomial pright = Polynomial::parsePolynomial(right);
-
-  Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
-  Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
-
-  Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
-  Assert(cmp.isNormalForm());
-  return RewriteResponse(REWRITE_DONE, cmp.getNode());
-}
-
-bool ArithRewriter::isAtom(TNode n) {
-  Kind k = n.getKind();
-  return arith::isRelationOperator(k) || k == kind::IS_INTEGER
-      || k == kind::DIVISIBLE;
 }
 
 RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
index c4590ffdc58d241af6c852704ec7840380e2f38f..3ffa90015b27f1f975e9ee0ac3c060581db5a835 100644 (file)
@@ -43,7 +43,9 @@ class ArithRewriter : public TheoryRewriter
   TrustNode expandDefinition(Node node) override;
 
  private:
+  /** preRewrite for atoms */
   static RewriteResponse preRewriteAtom(TNode t);
+  /** postRewrite for atoms */
   static RewriteResponse postRewriteAtom(TNode t);
 
   static RewriteResponse preRewriteTerm(TNode t);
@@ -82,11 +84,6 @@ class ArithRewriter : public TheoryRewriter
   /** postRewrite transcendental functions */
   static RewriteResponse postRewriteTranscendental(TNode t);
 
-  static bool isAtom(TNode n);
-
-  static inline bool isTerm(TNode n) {
-    return !isAtom(n);
-  }
   /** return rewrite */
   static RewriteResponse returnRewrite(TNode t, Node ret, Rewrite r);
   /** The operator elimination utility */