Feature/fp rewrite improvement (#2154)
authorMartin <martin.brain@cs.ox.ac.uk>
Fri, 6 Jul 2018 18:22:57 +0000 (19:22 +0100)
committerAndres Noetzli <andres.noetzli@gmail.com>
Fri, 6 Jul 2018 18:22:57 +0000 (11:22 -0700)
src/theory/fp/theory_fp_rewriter.cpp

index 14853612603eb035fb46b823e899cc4fa5b092b3..372db40b188b467beb2ba65b3d4a00e3b581f15b 100644 (file)
@@ -18,7 +18,6 @@
  **          (= x (fp.neg x)) --> (isNaN x)
  **          (fp.eq x (fp.neg x)) --> (isZero x)   (previous and reorganise should be sufficient)
  **          (fp.eq x const) --> various = depending on const
- **          (fp.abs (fp.neg x)) --> (fp.abs x)
  **          (fp.isPositive (fp.neg x)) --> (fp.isNegative x)
  **          (fp.isNegative (fp.neg x)) --> (fp.isPositive x)
  **          (fp.isPositive (fp.abs x)) --> (not (isNaN x))
@@ -69,7 +68,17 @@ namespace rewrite {
   RewriteResponse removeDoubleNegation (TNode node, bool) {
     Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
     if (node[0].getKind() == kind::FLOATINGPOINT_NEG) {
-      RewriteResponse(REWRITE_AGAIN, node[0][0]);
+      return RewriteResponse(REWRITE_AGAIN, node[0][0]);
+    }
+
+    return RewriteResponse(REWRITE_DONE, node);
+  }
+
+  RewriteResponse compactAbs (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
+    if (node[0].getKind() == kind::FLOATINGPOINT_NEG ||
+       node[0].getKind() == kind::FLOATINGPOINT_ABS) {
+      return RewriteResponse(REWRITE_AGAIN, node[0][0]);
     }
 
     return RewriteResponse(REWRITE_DONE, node);
@@ -246,6 +255,34 @@ namespace rewrite {
     } 
   }
 
+  RewriteResponse compactRemainder (TNode node, bool isPreRewrite) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_REM);
+    Assert(!isPreRewrite);  // status assumes parts have been rewritten
+
+    Node working = node;
+
+    // (fp.rem (fp.rem X Y) Y) == (fp.rem X Y)
+    if (working[0].getKind() == kind::FLOATINGPOINT_REM && // short-cut matters!
+       working[0][1] == working[1]) {
+      working = working[0];
+    }
+
+    // Sign of the RHS does not matter
+    if (working[1].getKind() == kind::FLOATINGPOINT_NEG ||
+       working[1].getKind() == kind::FLOATINGPOINT_ABS) {
+      working[1] = working[1][0];
+    }
+
+    // Lift negation out of the LHS so it can be cancelled out
+    if (working[0].getKind() == kind::FLOATINGPOINT_NEG) {
+      NodeManager * nm = NodeManager::currentNM();
+      working = nm->mkNode(kind::FLOATINGPOINT_NEG,
+                          nm->mkNode(kind::FLOATINGPOINT_REM, working[0][0], working[1]));
+    }
+
+    return RewriteResponse(REWRITE_DONE, working);
+  }
+
 }; /* CVC4::theory::fp::rewrite */
 
 
@@ -931,7 +968,7 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND];
       
     /******** Operations ********/
     preRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity;
-    preRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::identity;
+    preRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::compactAbs;
     preRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation;
     preRewriteTable[kind::FLOATINGPOINT_PLUS] = rewrite::identity;
     preRewriteTable[kind::FLOATINGPOINT_SUB] = rewrite::convertSubtractionToAddition;
@@ -1014,7 +1051,7 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND];
       
     /******** Operations ********/
     postRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity;
-    postRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::identity;
+    postRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::compactAbs;
     postRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation;
     postRewriteTable[kind::FLOATINGPOINT_PLUS] = rewrite::reorderBinaryOperation;
     postRewriteTable[kind::FLOATINGPOINT_SUB] = rewrite::removed;
@@ -1022,7 +1059,7 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND];
     postRewriteTable[kind::FLOATINGPOINT_DIV] = rewrite::identity;
     postRewriteTable[kind::FLOATINGPOINT_FMA] = rewrite::reorderFMA;
     postRewriteTable[kind::FLOATINGPOINT_SQRT] = rewrite::identity;
-    postRewriteTable[kind::FLOATINGPOINT_REM] = rewrite::identity;
+    postRewriteTable[kind::FLOATINGPOINT_REM] = rewrite::compactRemainder;
     postRewriteTable[kind::FLOATINGPOINT_RTI] = rewrite::identity;
     postRewriteTable[kind::FLOATINGPOINT_MIN] = rewrite::compactMinMax;
     postRewriteTable[kind::FLOATINGPOINT_MAX] = rewrite::compactMinMax;