Refactor abs rewriting (#7935)
authorGereon Kremer <gkremer@stanford.edu>
Thu, 20 Jan 2022 18:29:25 +0000 (10:29 -0800)
committerGitHub <noreply@github.com>
Thu, 20 Jan 2022 18:29:25 +0000 (18:29 +0000)
This PR refactors rewriting for ABS to also support real algebraic number. We also generalize the ABS operator to real arguments in general, instead of integer arguments.

src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h
src/theory/arith/kinds
test/unit/theory/theory_arith_rewriter_black.cpp

index c698490b0ef234370aabf2ae8ce97de3460fa193..4eba5dbfc006b92dc532a7a26ec960ae66323539 100644 (file)
@@ -331,22 +331,7 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
       case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
       case kind::INTS_DIVISION_TOTAL:
       case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);
-      case kind::ABS:
-        if (t[0].isConst())
-        {
-          const Rational& rat = t[0].getConst<Rational>();
-          if (rat >= 0)
-          {
-            return RewriteResponse(REWRITE_DONE, t[0]);
-          }
-          else
-          {
-            return RewriteResponse(REWRITE_DONE,
-                                   NodeManager::currentNM()->mkConstRealOrInt(
-                                       t[0].getType(), -rat));
-          }
-        }
-        return RewriteResponse(REWRITE_DONE, t);
+      case kind::ABS: return rewriteAbs(t);
       case kind::IS_INTEGER:
       case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);
       case kind::TO_REAL:
@@ -394,22 +379,7 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
       case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
       case kind::INTS_DIVISION_TOTAL:
       case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
-      case kind::ABS:
-        if (t[0].isConst())
-        {
-          const Rational& rat = t[0].getConst<Rational>();
-          if (rat >= 0)
-          {
-            return RewriteResponse(REWRITE_DONE, t[0]);
-          }
-          else
-          {
-            return RewriteResponse(REWRITE_DONE,
-                                   NodeManager::currentNM()->mkConstRealOrInt(
-                                       t[0].getType(), -rat));
-          }
-        }
-        return RewriteResponse(REWRITE_DONE, t);
+      case kind::ABS: return rewriteAbs(t);
       case kind::TO_REAL:
       case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
       case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
@@ -997,6 +967,36 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
   return RewriteResponse(REWRITE_DONE, t);
 }
 
+RewriteResponse ArithRewriter::rewriteAbs(TNode t)
+{
+  Assert(t.getKind() == Kind::ABS);
+  Assert(t.getNumChildren() == 1);
+
+  if (t[0].isConst())
+  {
+    const Rational& rat = t[0].getConst<Rational>();
+    if (rat >= 0)
+    {
+      return RewriteResponse(REWRITE_DONE, t[0]);
+    }
+    return RewriteResponse(
+        REWRITE_DONE,
+        NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
+  }
+  if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+  {
+    const RealAlgebraicNumber& ran =
+        t[0].getOperator().getConst<RealAlgebraicNumber>();
+    if (ran >= RealAlgebraicNumber())
+    {
+      return RewriteResponse(REWRITE_DONE, t[0]);
+    }
+    return RewriteResponse(
+        REWRITE_DONE, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran));
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
 RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
 {
   NodeManager* nm = NodeManager::currentNM();
index 2e89432f8f40be68c7a4977f307a3d2921a4538b..90140cc18713c00be3fc7e2921844ebeb1efb953 100644 (file)
@@ -59,6 +59,7 @@ class ArithRewriter : public TheoryRewriter
   static RewriteResponse rewriteMinus(TNode t);
   static RewriteResponse rewriteUMinus(TNode t, bool pre);
   static RewriteResponse rewriteDiv(TNode t, bool pre);
+  static RewriteResponse rewriteAbs(TNode t);
   static RewriteResponse rewriteIntsDivMod(TNode t, bool pre);
   static RewriteResponse rewriteIntsDivModTotal(TNode t, bool pre);
   /** Entry for applications of to_int and is_int */
index 40557bc657d3b55d73721f7b645109b46cb80c93..ba326ba5603a5be576457808b1465630efd435f8 100644 (file)
@@ -145,7 +145,7 @@ typerule CAST_TO_REAL ::cvc5::theory::arith::ArithOperatorTypeRule
 typerule TO_INTEGER ::cvc5::theory::arith::ArithOperatorTypeRule
 typerule IS_INTEGER "SimpleTypeRule<RBool, AReal>"
 
-typerule ABS "SimpleTypeRule<RInteger, AInteger>"
+typerule ABS ::cvc5::theory::arith::ArithOperatorTypeRule
 typerule INTS_DIVISION "SimpleTypeRule<RInteger, AInteger, AInteger>"
 typerule INTS_MODULUS "SimpleTypeRule<RInteger, AInteger, AInteger>"
 typerule DIVISIBLE "SimpleTypeRule<RBool, AInteger>"
index 0147b591f60b49e9753347ee1593e403e574ee73..3fcd74356fb052bb8e9e9e6ce7f69a2204905adb 100644 (file)
@@ -79,5 +79,41 @@ TEST_F(TestTheoryArithRewriterBlack, RealAlgebraicNumber)
   }
 }
 
+TEST_F(TestTheoryArithRewriterBlack, Abs)
+{
+  {
+    Node a = d_nodeManager->mkConstReal(10);
+    Node b = d_nodeManager->mkConstReal(-10);
+    Node m = d_nodeManager->mkNode(Kind::ABS, a);
+    Node n = d_nodeManager->mkNode(Kind::ABS, b);
+    m = d_slvEngine->getRewriter()->rewrite(m);
+    n = d_slvEngine->getRewriter()->rewrite(n);
+    EXPECT_EQ(m, a);
+    EXPECT_EQ(n, a);
+  }
+  {
+    Node a = d_nodeManager->mkConstReal(Rational(3,2));
+    Node b = d_nodeManager->mkConstReal(Rational(-3,2));
+    Node m = d_nodeManager->mkNode(Kind::ABS, a);
+    Node n = d_nodeManager->mkNode(Kind::ABS, b);
+    m = d_slvEngine->getRewriter()->rewrite(m);
+    n = d_slvEngine->getRewriter()->rewrite(n);
+    EXPECT_EQ(m, a);
+    EXPECT_EQ(n, a);
+  }
+  {
+    RealAlgebraicNumber msqrt2({-2, 0, 1}, -2, -1);
+    RealAlgebraicNumber sqrt2({-2, 0, 1}, 1, 2);
+    Node a = d_nodeManager->mkRealAlgebraicNumber(msqrt2);
+    Node b = d_nodeManager->mkRealAlgebraicNumber(sqrt2);
+    Node m = d_nodeManager->mkNode(Kind::ABS, a);
+    Node n = d_nodeManager->mkNode(Kind::ABS, b);
+    m = d_slvEngine->getRewriter()->rewrite(m);
+    n = d_slvEngine->getRewriter()->rewrite(n);
+    EXPECT_EQ(m, b);
+    EXPECT_EQ(n, b);
+  }
+}
+
 }  // namespace test
 }  // namespace cvc5