Make floating-point comparison operators chainable (#1101)
authorMartin <martin.brain@diffblue.com>
Fri, 15 Sep 2017 03:04:20 +0000 (04:04 +0100)
committerAndres Noetzli <andres.noetzli@gmail.com>
Fri, 15 Sep 2017 03:04:20 +0000 (20:04 -0700)
Floating-point comparison operators are chainable according to the standard.

src/theory/fp/kinds
src/theory/fp/theory_fp_rewriter.cpp

index 983d5aa5c15878f85e4dd17f9cd9c999430741c1..a1b28d885be5f96a6f107583fb5413392ced4d1d 100644 (file)
@@ -94,16 +94,16 @@ operator FLOATINGPOINT_MAX 2 "floating-point maximum"
 typerule FLOATINGPOINT_MAX   ::CVC4::theory::fp::FloatingPointOperationTypeRule
 
 
-operator FLOATINGPOINT_LEQ 2 "floating-point less than or equal"
+operator FLOATINGPOINT_LEQ 2: "floating-point less than or equal"
 typerule FLOATINGPOINT_LEQ   ::CVC4::theory::fp::FloatingPointTestTypeRule
 
-operator FLOATINGPOINT_LT 2 "floating-point less than"
+operator FLOATINGPOINT_LT 2: "floating-point less than"
 typerule FLOATINGPOINT_LT   ::CVC4::theory::fp::FloatingPointTestTypeRule
 
-operator FLOATINGPOINT_GEQ 2 "floating-point greater than or equal"
+operator FLOATINGPOINT_GEQ 2: "floating-point greater than or equal"
 typerule FLOATINGPOINT_GEQ   ::CVC4::theory::fp::FloatingPointTestTypeRule
 
-operator FLOATINGPOINT_GT 2 "floating-point greater than"
+operator FLOATINGPOINT_GT 2: "floating-point greater than"
 typerule FLOATINGPOINT_GT   ::CVC4::theory::fp::FloatingPointTestTypeRule
 
 
index ba4bf9228e2cc574c3c7e3415f161fbee7f9adfa..747aaeac6b3d5da5f8ce4a35f843d083b26e5252 100644 (file)
@@ -37,6 +37,16 @@ namespace fp {
 
 namespace rewrite {
   /** Rewrite rules **/
+  template <RewriteFunction first, RewriteFunction second>
+  RewriteResponse then (TNode node, bool isPreRewrite) {
+    RewriteResponse result(first(node, isPreRewrite));
+
+    if (result.status == REWRITE_DONE) {
+      return second(result.node, isPreRewrite);
+    } else {
+      return result;
+    }
+  }
 
   RewriteResponse notFP (TNode node, bool) {
     Unreachable("non floating-point kind (%d) in floating point rewrite?",node.getKind());
@@ -67,6 +77,34 @@ namespace rewrite {
     return RewriteResponse(REWRITE_DONE, addition);
   }
 
+  RewriteResponse breakChain (TNode node, bool isPreRewrite) {
+    Assert(isPreRewrite);  // Should be run first
+
+    Kind k = node.getKind();
+    Assert(k == kind::FLOATINGPOINT_EQ ||
+          k == kind::FLOATINGPOINT_GEQ ||
+          k == kind::FLOATINGPOINT_LEQ ||
+          k == kind::FLOATINGPOINT_GT ||
+          k == kind::FLOATINGPOINT_LT);
+
+
+    size_t children = node.getNumChildren();
+    if (children > 2) {
+
+      NodeBuilder<> conjunction(kind::AND);
+
+      for (size_t i = 0; i < children - 1; ++i) {
+       for (size_t j = i + 1; j < children; ++j) {
+         conjunction << NodeManager::currentNM()->mkNode(k, node[i], node[j]);
+       }
+      }
+      return RewriteResponse(REWRITE_AGAIN_FULL, conjunction);
+
+    } else {
+      return RewriteResponse(REWRITE_DONE, node);
+    }
+  }
+
 
   /* Implies (fp.eq x x) --> (not (isNaN x))
    */
@@ -314,11 +352,11 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
     preRewriteTable[kind::FLOATINGPOINT_MAX] = rewrite::compactMinMax;
 
     /******** Comparisons ********/
-    preRewriteTable[kind::FLOATINGPOINT_EQ] = rewrite::ieeeEqToEq;
-    preRewriteTable[kind::FLOATINGPOINT_LEQ] = rewrite::identity;
-    preRewriteTable[kind::FLOATINGPOINT_LT] = rewrite::identity;
-    preRewriteTable[kind::FLOATINGPOINT_GEQ] = rewrite::geqToleq;
-    preRewriteTable[kind::FLOATINGPOINT_GT] = rewrite::gtTolt;
+    preRewriteTable[kind::FLOATINGPOINT_EQ] = rewrite::then<rewrite::breakChain,rewrite::ieeeEqToEq>;
+    preRewriteTable[kind::FLOATINGPOINT_LEQ] = rewrite::breakChain;
+    preRewriteTable[kind::FLOATINGPOINT_LT] = rewrite::breakChain;
+    preRewriteTable[kind::FLOATINGPOINT_GEQ] = rewrite::then<rewrite::breakChain,rewrite::geqToleq>;
+    preRewriteTable[kind::FLOATINGPOINT_GT] = rewrite::then<rewrite::breakChain,rewrite::gtTolt>;
 
     /******** Classifications ********/
     preRewriteTable[kind::FLOATINGPOINT_ISN] = rewrite::identity;