Add parameter name for argument `isPreRewrite` for FP rewrites. (#6469)
authorAina Niemetz <aina.niemetz@gmail.com>
Fri, 30 Apr 2021 22:06:30 +0000 (15:06 -0700)
committerGitHub <noreply@github.com>
Fri, 30 Apr 2021 22:06:30 +0000 (22:06 +0000)
src/theory/fp/fp_expand_defs.cpp
src/theory/fp/theory_fp_rewriter.cpp

index 4e9803bf74be6bd17b983ebf91faedaed4143549..34cc1ed5d573c4a17f2b0a05b4f989a5245b12f9 100644 (file)
@@ -260,22 +260,23 @@ TrustNode FpExpandDefs::expandDefinition(Node node)
       << "FpExpandDefs::expandDefinition(): " << node << std::endl;
 
   Node res = node;
+  Kind kind = node.getKind();
 
-  if (node.getKind() == kind::FLOATINGPOINT_TO_FP_GENERIC)
+  if (kind == kind::FLOATINGPOINT_TO_FP_GENERIC)
   {
     res = removeToFPGeneric::removeToFPGeneric(node);
   }
-  else if (node.getKind() == kind::FLOATINGPOINT_MIN)
+  else if (kind == kind::FLOATINGPOINT_MIN)
   {
     res = NodeManager::currentNM()->mkNode(
         kind::FLOATINGPOINT_MIN_TOTAL, node[0], node[1], minUF(node));
   }
-  else if (node.getKind() == kind::FLOATINGPOINT_MAX)
+  else if (kind == kind::FLOATINGPOINT_MAX)
   {
     res = NodeManager::currentNM()->mkNode(
         kind::FLOATINGPOINT_MAX_TOTAL, node[0], node[1], maxUF(node));
   }
-  else if (node.getKind() == kind::FLOATINGPOINT_TO_UBV)
+  else if (kind == kind::FLOATINGPOINT_TO_UBV)
   {
     FloatingPointToUBV info = node.getOperator().getConst<FloatingPointToUBV>();
     FloatingPointToUBVTotal newInfo(info);
@@ -287,7 +288,7 @@ TrustNode FpExpandDefs::expandDefinition(Node node)
             node[1],
             toUBVUF(node));
   }
-  else if (node.getKind() == kind::FLOATINGPOINT_TO_SBV)
+  else if (kind == kind::FLOATINGPOINT_TO_SBV)
   {
     FloatingPointToSBV info = node.getOperator().getConst<FloatingPointToSBV>();
     FloatingPointToSBVTotal newInfo(info);
@@ -299,15 +300,11 @@ TrustNode FpExpandDefs::expandDefinition(Node node)
             node[1],
             toSBVUF(node));
   }
-  else if (node.getKind() == kind::FLOATINGPOINT_TO_REAL)
+  else if (kind == kind::FLOATINGPOINT_TO_REAL)
   {
     res = NodeManager::currentNM()->mkNode(
         kind::FLOATINGPOINT_TO_REAL_TOTAL, node[0], toRealUF(node));
   }
-  else
-  {
-    // Do nothing
-  }
 
   if (res != node)
   {
index 07fde6a88e39397b627520aeffbbe4e1906a65d0..e431ffa09948629155a18439f4868e06eae293c5 100644 (file)
@@ -58,21 +58,25 @@ namespace rewrite {
     }
   }
 
-  RewriteResponse notFP (TNode node, bool) {
+  RewriteResponse notFP(TNode node, bool isPreRewrite)
+  {
     Unreachable() << "non floating-point kind (" << node.getKind()
                   << ") in floating point rewrite?";
   }
 
-  RewriteResponse identity (TNode node, bool) {
+  RewriteResponse identity(TNode node, bool isPreRewrite)
+  {
     return RewriteResponse(REWRITE_DONE, node);
   }
 
-  RewriteResponse type (TNode node, bool) {
+  RewriteResponse type(TNode node, bool isPreRewrite)
+  {
     Unreachable() << "sort kind (" << node.getKind()
                   << ") found in expression?";
   }
 
-  RewriteResponse removeDoubleNegation (TNode node, bool) {
+  RewriteResponse removeDoubleNegation(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
     if (node[0].getKind() == kind::FLOATINGPOINT_NEG) {
       return RewriteResponse(REWRITE_AGAIN, node[0][0]);
@@ -81,7 +85,8 @@ namespace rewrite {
     return RewriteResponse(REWRITE_DONE, node);
   }
 
-  RewriteResponse compactAbs (TNode node, bool) {
+  RewriteResponse compactAbs(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
     if (node[0].getKind() == kind::FLOATINGPOINT_NEG
         || node[0].getKind() == kind::FLOATINGPOINT_ABS)
@@ -94,7 +99,8 @@ namespace rewrite {
     return RewriteResponse(REWRITE_DONE, node);
   }
 
-  RewriteResponse convertSubtractionToAddition (TNode node, bool) {
+  RewriteResponse convertSubtractionToAddition(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_SUB);
     Node negation = NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_NEG,node[2]);
     Node addition = NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_PLUS,node[0],node[1],negation);
@@ -129,7 +135,8 @@ namespace rewrite {
   /* Implies (fp.eq x x) --> (not (isNaN x))
    */
 
-  RewriteResponse ieeeEqToEq (TNode node, bool) {
+  RewriteResponse ieeeEqToEq(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_EQ);
     NodeManager *nm = NodeManager::currentNM();
 
@@ -145,24 +152,26 @@ namespace rewrite {
                                                            nm->mkNode(kind::FLOATINGPOINT_ISZ, node[1])))));
   }
 
-
-  RewriteResponse geqToleq (TNode node, bool) {
+  RewriteResponse geqToleq(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_GEQ);
     return RewriteResponse(REWRITE_DONE,NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_LEQ,node[1],node[0]));
   }
 
-  RewriteResponse gtTolt (TNode node, bool) {
+  RewriteResponse gtTolt(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_GT);
     return RewriteResponse(REWRITE_DONE,NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_LT,node[1],node[0]));
   }
 
-  RewriteResponse removed(TNode node, bool)
+  RewriteResponse removed(TNode node, bool isPreRewrite)
   {
     Unreachable() << "kind (" << node.getKind()
                   << ") should have been removed?";
   }
 
-  RewriteResponse variable (TNode node, bool) {  
+  RewriteResponse variable(TNode node, bool isPreRewrite)
+  {
     // We should only get floating point and rounding mode variables to rewrite.
     TypeNode tn = node.getType(true);
     Assert(tn.isFloatingPoint() || tn.isRoundingMode());
@@ -328,226 +337,264 @@ namespace rewrite {
 
 namespace constantFold {
 
+RewriteResponse fpLiteral(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_FP);
 
-  RewriteResponse fpLiteral (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_FP);
+  BitVector bv(node[0].getConst<BitVector>());
+  bv = bv.concat(node[1].getConst<BitVector>());
+  bv = bv.concat(node[2].getConst<BitVector>());
 
-    BitVector bv(node[0].getConst<BitVector>());
-    bv = bv.concat(node[1].getConst<BitVector>());
-    bv = bv.concat(node[2].getConst<BitVector>());
-    
-    // +1 to support the hidden bit
-    Node lit =
-      NodeManager::currentNM()->mkConst(FloatingPoint(node[1].getConst<BitVector>().getSize(),
-                                                     node[2].getConst<BitVector>().getSize() + 1,
-                                                     bv));
-    
-    return RewriteResponse(REWRITE_DONE, lit);
-  }
+  // +1 to support the hidden bit
+  Node lit = NodeManager::currentNM()->mkConst(
+      FloatingPoint(node[1].getConst<BitVector>().getSize(),
+                    node[2].getConst<BitVector>().getSize() + 1,
+                    bv));
 
-  RewriteResponse abs (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
-    Assert(node.getNumChildren() == 1);
+  return RewriteResponse(REWRITE_DONE, lit);
+}
 
-    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().absolute()));
-  }
+RewriteResponse abs(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
+  Assert(node.getNumChildren() == 1);
 
+  return RewriteResponse(REWRITE_DONE,
+                         NodeManager::currentNM()->mkConst(
+                             node[0].getConst<FloatingPoint>().absolute()));
+}
 
-  RewriteResponse neg (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
-    Assert(node.getNumChildren() == 1);
+RewriteResponse neg(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
+  Assert(node.getNumChildren() == 1);
 
-    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().negate()));
-  }
+  return RewriteResponse(REWRITE_DONE,
+                         NodeManager::currentNM()->mkConst(
+                             node[0].getConst<FloatingPoint>().negate()));
+}
 
+RewriteResponse plus(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_PLUS);
+  Assert(node.getNumChildren() == 3);
 
-  RewriteResponse plus (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_PLUS);
-    Assert(node.getNumChildren() == 3);
+  RoundingMode rm(node[0].getConst<RoundingMode>());
+  FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+  FloatingPoint arg2(node[2].getConst<FloatingPoint>());
 
-    RoundingMode rm(node[0].getConst<RoundingMode>());
-    FloatingPoint arg1(node[1].getConst<FloatingPoint>());
-    FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+  Assert(arg1.getSize() == arg2.getSize());
 
-    Assert(arg1.getSize() == arg2.getSize());
+  return RewriteResponse(
+      REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.plus(rm, arg2)));
+}
 
-    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.plus(rm, arg2)));
-  }
+RewriteResponse mult(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_MULT);
+  Assert(node.getNumChildren() == 3);
 
-  RewriteResponse mult (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_MULT);
-    Assert(node.getNumChildren() == 3);
+  RoundingMode rm(node[0].getConst<RoundingMode>());
+  FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+  FloatingPoint arg2(node[2].getConst<FloatingPoint>());
 
-    RoundingMode rm(node[0].getConst<RoundingMode>());
-    FloatingPoint arg1(node[1].getConst<FloatingPoint>());
-    FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+  Assert(arg1.getSize() == arg2.getSize());
 
-    Assert(arg1.getSize() == arg2.getSize());
+  return RewriteResponse(
+      REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.mult(rm, arg2)));
+}
 
-    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.mult(rm, arg2)));
-  }
+RewriteResponse fma(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_FMA);
+  Assert(node.getNumChildren() == 4);
 
-  RewriteResponse fma (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_FMA);
-    Assert(node.getNumChildren() == 4);
+  RoundingMode rm(node[0].getConst<RoundingMode>());
+  FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+  FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+  FloatingPoint arg3(node[3].getConst<FloatingPoint>());
 
-    RoundingMode rm(node[0].getConst<RoundingMode>());
-    FloatingPoint arg1(node[1].getConst<FloatingPoint>());
-    FloatingPoint arg2(node[2].getConst<FloatingPoint>());
-    FloatingPoint arg3(node[3].getConst<FloatingPoint>());
+  Assert(arg1.getSize() == arg2.getSize());
+  Assert(arg1.getSize() == arg3.getSize());
 
-    Assert(arg1.getSize() == arg2.getSize());
-    Assert(arg1.getSize() == arg3.getSize());
+  return RewriteResponse(
+      REWRITE_DONE,
+      NodeManager::currentNM()->mkConst(arg1.fma(rm, arg2, arg3)));
+}
 
-    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.fma(rm, arg2, arg3)));
-  }
+RewriteResponse div(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_DIV);
+  Assert(node.getNumChildren() == 3);
 
-  RewriteResponse div (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_DIV);
-    Assert(node.getNumChildren() == 3);
+  RoundingMode rm(node[0].getConst<RoundingMode>());
+  FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+  FloatingPoint arg2(node[2].getConst<FloatingPoint>());
 
-    RoundingMode rm(node[0].getConst<RoundingMode>());
-    FloatingPoint arg1(node[1].getConst<FloatingPoint>());
-    FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+  Assert(arg1.getSize() == arg2.getSize());
 
-    Assert(arg1.getSize() == arg2.getSize());
+  return RewriteResponse(REWRITE_DONE,
+                         NodeManager::currentNM()->mkConst(arg1.div(rm, arg2)));
+}
 
-    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.div(rm, arg2)));
-  }
-  
-  RewriteResponse sqrt (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_SQRT);
-    Assert(node.getNumChildren() == 2);
+RewriteResponse sqrt(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_SQRT);
+  Assert(node.getNumChildren() == 2);
 
-    RoundingMode rm(node[0].getConst<RoundingMode>());
-    FloatingPoint arg(node[1].getConst<FloatingPoint>());
-    
-    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg.sqrt(rm)));
-  }
+  RoundingMode rm(node[0].getConst<RoundingMode>());
+  FloatingPoint arg(node[1].getConst<FloatingPoint>());
 
-  RewriteResponse rti (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_RTI);
-    Assert(node.getNumChildren() == 2);
+  return RewriteResponse(REWRITE_DONE,
+                         NodeManager::currentNM()->mkConst(arg.sqrt(rm)));
+}
 
-    RoundingMode rm(node[0].getConst<RoundingMode>());
-    FloatingPoint arg(node[1].getConst<FloatingPoint>());
-    
-    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg.rti(rm)));
-  }
+RewriteResponse rti(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_RTI);
+  Assert(node.getNumChildren() == 2);
 
-  RewriteResponse rem (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_REM);
-    Assert(node.getNumChildren() == 2);
+  RoundingMode rm(node[0].getConst<RoundingMode>());
+  FloatingPoint arg(node[1].getConst<FloatingPoint>());
 
-    FloatingPoint arg1(node[0].getConst<FloatingPoint>());
-    FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+  return RewriteResponse(REWRITE_DONE,
+                         NodeManager::currentNM()->mkConst(arg.rti(rm)));
+}
 
-    Assert(arg1.getSize() == arg2.getSize());
+RewriteResponse rem(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_REM);
+  Assert(node.getNumChildren() == 2);
 
-    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.rem(arg2)));
-  }
+  FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+  FloatingPoint arg2(node[1].getConst<FloatingPoint>());
 
-  RewriteResponse min (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_MIN);
-    Assert(node.getNumChildren() == 2);
+  Assert(arg1.getSize() == arg2.getSize());
 
-    FloatingPoint arg1(node[0].getConst<FloatingPoint>());
-    FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+  return RewriteResponse(REWRITE_DONE,
+                         NodeManager::currentNM()->mkConst(arg1.rem(arg2)));
+}
 
-    Assert(arg1.getSize() == arg2.getSize());
+RewriteResponse min(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_MIN);
+  Assert(node.getNumChildren() == 2);
 
-    FloatingPoint::PartialFloatingPoint res(arg1.min(arg2));
+  FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+  FloatingPoint arg2(node[1].getConst<FloatingPoint>());
 
-    if (res.second) {
-      Node lit = NodeManager::currentNM()->mkConst(res.first);
-      return RewriteResponse(REWRITE_DONE, lit);
-    } else {
-      // Can't constant fold the underspecified case
-      return RewriteResponse(REWRITE_DONE, node);
-    }
+  Assert(arg1.getSize() == arg2.getSize());
+
+  FloatingPoint::PartialFloatingPoint res(arg1.min(arg2));
+
+  if (res.second)
+  {
+    Node lit = NodeManager::currentNM()->mkConst(res.first);
+    return RewriteResponse(REWRITE_DONE, lit);
+  }
+  else
+  {
+    // Can't constant fold the underspecified case
+    return RewriteResponse(REWRITE_DONE, node);
   }
+}
 
-  RewriteResponse max (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_MAX);
-    Assert(node.getNumChildren() == 2);
+RewriteResponse max(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_MAX);
+  Assert(node.getNumChildren() == 2);
 
-    FloatingPoint arg1(node[0].getConst<FloatingPoint>());
-    FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+  FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+  FloatingPoint arg2(node[1].getConst<FloatingPoint>());
 
-    Assert(arg1.getSize() == arg2.getSize());
+  Assert(arg1.getSize() == arg2.getSize());
 
-    FloatingPoint::PartialFloatingPoint res(arg1.max(arg2));
+  FloatingPoint::PartialFloatingPoint res(arg1.max(arg2));
 
-    if (res.second) {
-      Node lit = NodeManager::currentNM()->mkConst(res.first);
-      return RewriteResponse(REWRITE_DONE, lit);
-    } else {
-      // Can't constant fold the underspecified case
-      return RewriteResponse(REWRITE_DONE, node);
-    }
+  if (res.second)
+  {
+    Node lit = NodeManager::currentNM()->mkConst(res.first);
+    return RewriteResponse(REWRITE_DONE, lit);
   }
+  else
+  {
+    // Can't constant fold the underspecified case
+    return RewriteResponse(REWRITE_DONE, node);
+  }
+}
 
-  RewriteResponse minTotal (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_MIN_TOTAL);
-    Assert(node.getNumChildren() == 3);
-
-    FloatingPoint arg1(node[0].getConst<FloatingPoint>());
-    FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+RewriteResponse minTotal(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_MIN_TOTAL);
+  Assert(node.getNumChildren() == 3);
 
-    Assert(arg1.getSize() == arg2.getSize());
+  FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+  FloatingPoint arg2(node[1].getConst<FloatingPoint>());
 
-    // Can be called with the third argument non-constant
-    if (node[2].getMetaKind() == kind::metakind::CONSTANT) {
-      BitVector arg3(node[2].getConst<BitVector>());
+  Assert(arg1.getSize() == arg2.getSize());
 
-      FloatingPoint folded(arg1.minTotal(arg2, arg3.isBitSet(0)));
-      Node lit = NodeManager::currentNM()->mkConst(folded);
-      return RewriteResponse(REWRITE_DONE, lit);
+  // Can be called with the third argument non-constant
+  if (node[2].getMetaKind() == kind::metakind::CONSTANT)
+  {
+    BitVector arg3(node[2].getConst<BitVector>());
 
-    } else {
-      FloatingPoint::PartialFloatingPoint res(arg1.min(arg2));
+    FloatingPoint folded(arg1.minTotal(arg2, arg3.isBitSet(0)));
+    Node lit = NodeManager::currentNM()->mkConst(folded);
+    return RewriteResponse(REWRITE_DONE, lit);
+  }
+  else
+  {
+    FloatingPoint::PartialFloatingPoint res(arg1.min(arg2));
 
-      if (res.second) {
-       Node lit = NodeManager::currentNM()->mkConst(res.first);
-       return RewriteResponse(REWRITE_DONE, lit);
-      } else {
-       // Can't constant fold the underspecified case
-       return RewriteResponse(REWRITE_DONE, node);
-      }
+    if (res.second)
+    {
+      Node lit = NodeManager::currentNM()->mkConst(res.first);
+      return RewriteResponse(REWRITE_DONE, lit);
+    }
+    else
+    {
+      // Can't constant fold the underspecified case
+      return RewriteResponse(REWRITE_DONE, node);
     }
   }
+}
 
-  RewriteResponse maxTotal (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_MAX_TOTAL);
-    Assert(node.getNumChildren() == 3);
+RewriteResponse maxTotal(TNode node, bool isPreRewrite)
+{
+  Assert(node.getKind() == kind::FLOATINGPOINT_MAX_TOTAL);
+  Assert(node.getNumChildren() == 3);
 
-    FloatingPoint arg1(node[0].getConst<FloatingPoint>());
-    FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+  FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+  FloatingPoint arg2(node[1].getConst<FloatingPoint>());
 
-    Assert(arg1.getSize() == arg2.getSize());
+  Assert(arg1.getSize() == arg2.getSize());
 
-    // Can be called with the third argument non-constant
-    if (node[2].getMetaKind() == kind::metakind::CONSTANT) {
-      BitVector arg3(node[2].getConst<BitVector>());
-
-      FloatingPoint folded(arg1.maxTotal(arg2, arg3.isBitSet(0)));
-      Node lit = NodeManager::currentNM()->mkConst(folded);
-     return RewriteResponse(REWRITE_DONE, lit);
+  // Can be called with the third argument non-constant
+  if (node[2].getMetaKind() == kind::metakind::CONSTANT)
+  {
+    BitVector arg3(node[2].getConst<BitVector>());
 
-    } else {
-      FloatingPoint::PartialFloatingPoint res(arg1.max(arg2));
+    FloatingPoint folded(arg1.maxTotal(arg2, arg3.isBitSet(0)));
+    Node lit = NodeManager::currentNM()->mkConst(folded);
+    return RewriteResponse(REWRITE_DONE, lit);
+  }
+  else
+  {
+    FloatingPoint::PartialFloatingPoint res(arg1.max(arg2));
 
-      if (res.second) {
-       Node lit = NodeManager::currentNM()->mkConst(res.first);
-       return RewriteResponse(REWRITE_DONE, lit);
-      } else {
-       // Can't constant fold the underspecified case
-       return RewriteResponse(REWRITE_DONE, node);
-      }
+    if (res.second)
+    {
+      Node lit = NodeManager::currentNM()->mkConst(res.first);
+      return RewriteResponse(REWRITE_DONE, lit);
+    }
+    else
+    {
+      // Can't constant fold the underspecified case
+      return RewriteResponse(REWRITE_DONE, node);
     }
   }
+}
 
-  
   RewriteResponse equal (TNode node, bool isPreRewrite) {
     Assert(node.getKind() == kind::EQUAL);
 
@@ -572,8 +619,8 @@ namespace constantFold {
     Unreachable() << "Equality of unknown type";
   }
 
-
-  RewriteResponse leq (TNode node, bool) {
+  RewriteResponse leq(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_LEQ);
     Assert(node.getNumChildren() == 2);
 
@@ -585,8 +632,8 @@ namespace constantFold {
     return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 <= arg2));
   }
 
-
-  RewriteResponse lt (TNode node, bool) {
+  RewriteResponse lt(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_LT);
     Assert(node.getNumChildren() == 2);
 
@@ -598,57 +645,64 @@ namespace constantFold {
     return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 < arg2));
   }
 
-
-  RewriteResponse isNormal (TNode node, bool) {
+  RewriteResponse isNormal(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_ISN);
     Assert(node.getNumChildren() == 1);
 
     return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNormal()));
   }
 
-  RewriteResponse isSubnormal (TNode node, bool) {
+  RewriteResponse isSubnormal(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_ISSN);
     Assert(node.getNumChildren() == 1);
 
     return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isSubnormal()));
   }
 
-  RewriteResponse isZero (TNode node, bool) {
+  RewriteResponse isZero(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_ISZ);
     Assert(node.getNumChildren() == 1);
 
     return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isZero()));
   }
 
-  RewriteResponse isInfinite (TNode node, bool) {
+  RewriteResponse isInfinite(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_ISINF);
     Assert(node.getNumChildren() == 1);
 
     return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isInfinite()));
   }
 
-  RewriteResponse isNaN (TNode node, bool) {
+  RewriteResponse isNaN(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_ISNAN);
     Assert(node.getNumChildren() == 1);
 
     return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNaN()));
   }
 
-  RewriteResponse isNegative (TNode node, bool) {
+  RewriteResponse isNegative(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_ISNEG);
     Assert(node.getNumChildren() == 1);
 
     return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNegative()));
   }
 
-  RewriteResponse isPositive (TNode node, bool) {
+  RewriteResponse isPositive(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_ISPOS);
     Assert(node.getNumChildren() == 1);
 
     return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isPositive()));
   }
 
-  RewriteResponse convertFromIEEEBitVectorLiteral (TNode node, bool) {
+  RewriteResponse convertFromIEEEBitVectorLiteral(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR);
 
     TNode op = node.getOperator();
@@ -663,7 +717,8 @@ namespace constantFold {
     return RewriteResponse(REWRITE_DONE, lit);
   }
 
-  RewriteResponse constantConvert (TNode node, bool) {
+  RewriteResponse constantConvert(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT);
     Assert(node.getNumChildren() == 2);
 
@@ -676,7 +731,8 @@ namespace constantFold {
         NodeManager::currentNM()->mkConst(arg1.convert(info.getSize(), rm)));
   }
 
-  RewriteResponse convertFromRealLiteral (TNode node, bool) {
+  RewriteResponse convertFromRealLiteral(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL);
 
     TNode op = node.getOperator();
@@ -692,7 +748,8 @@ namespace constantFold {
     return RewriteResponse(REWRITE_DONE, lit);
   }
 
-  RewriteResponse convertFromSBV (TNode node, bool) {
+  RewriteResponse convertFromSBV(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR);
 
     TNode op = node.getOperator();
@@ -708,7 +765,8 @@ namespace constantFold {
     return RewriteResponse(REWRITE_DONE, lit);
   }
 
-  RewriteResponse convertFromUBV (TNode node, bool) {
+  RewriteResponse convertFromUBV(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR);
 
     TNode op = node.getOperator();
@@ -724,7 +782,8 @@ namespace constantFold {
     return RewriteResponse(REWRITE_DONE, lit);
   }
 
-  RewriteResponse convertToUBV (TNode node, bool) {
+  RewriteResponse convertToUBV(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV);
 
     TNode op = node.getOperator();
@@ -745,7 +804,8 @@ namespace constantFold {
     }
   }
 
-  RewriteResponse convertToSBV (TNode node, bool) {
+  RewriteResponse convertToSBV(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV);
 
     TNode op = node.getOperator();
@@ -766,7 +826,8 @@ namespace constantFold {
     }
   }
 
-  RewriteResponse convertToReal (TNode node, bool) {
+  RewriteResponse convertToReal(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_REAL);
 
     FloatingPoint arg(node[0].getConst<FloatingPoint>());
@@ -782,7 +843,8 @@ namespace constantFold {
     }
   }
 
-  RewriteResponse convertToUBVTotal (TNode node, bool) {
+  RewriteResponse convertToUBVTotal(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV_TOTAL);
 
     TNode op = node.getOperator();
@@ -814,7 +876,8 @@ namespace constantFold {
     }
   }
 
-  RewriteResponse convertToSBVTotal (TNode node, bool) {
+  RewriteResponse convertToSBVTotal(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV_TOTAL);
 
     TNode op = node.getOperator();
@@ -846,7 +909,8 @@ namespace constantFold {
     }
   }
 
-  RewriteResponse convertToRealTotal (TNode node, bool) {
+  RewriteResponse convertToRealTotal(TNode node, bool isPreRewrite)
+  {
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_REAL_TOTAL);
 
     FloatingPoint arg(node[0].getConst<FloatingPoint>());
@@ -872,7 +936,7 @@ namespace constantFold {
     }
   }
 
-  RewriteResponse componentFlag(TNode node, bool)
+  RewriteResponse componentFlag(TNode node, bool isPreRewrite)
   {
     Kind k = node.getKind();
 
@@ -901,7 +965,7 @@ namespace constantFold {
                            NodeManager::currentNM()->mkConst(res));
   }
 
-  RewriteResponse componentExponent(TNode node, bool)
+  RewriteResponse componentExponent(TNode node, bool isPreRewrite)
   {
     Assert(node.getKind() == kind::FLOATINGPOINT_COMPONENT_EXPONENT);
 
@@ -918,7 +982,7 @@ namespace constantFold {
     );
   }
 
-  RewriteResponse componentSignificand(TNode node, bool)
+  RewriteResponse componentSignificand(TNode node, bool isPreRewrite)
   {
     Assert(node.getKind() == kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND);
 
@@ -934,7 +998,7 @@ namespace constantFold {
     );
   }
 
-  RewriteResponse roundingModeBitBlast(TNode node, bool)
+  RewriteResponse roundingModeBitBlast(TNode node, bool isPreRewrite)
   {
     Assert(node.getKind() == kind::ROUNDINGMODE_BITBLAST);