Improve FP rewriter: const folding, other (#1126)
authorMartin Brain <martin.brain@cs.ox.ac.uk>
Wed, 27 Sep 2017 00:21:51 +0000 (17:21 -0700)
committerAndres Noetzli <noetzli@stanford.edu>
Wed, 27 Sep 2017 00:24:55 +0000 (17:24 -0700)
src/theory/fp/theory_fp_rewriter.cpp
src/theory/fp/theory_fp_rewriter.h

index 747aaeac6b3d5da5f8ce4a35f843d083b26e5252..ec42099c23fed6ab0bddb5d259d120310e483e01 100644 (file)
  **
  ** \brief [[ Rewrite rules for floating point theories. ]]
  **
- ** \todo [[ Constant folding
**          Push negations up through arithmetic operators (include max and min? maybe not due to +0/-0)
+ ** \todo [[ Single argument constant propagate / simplify
            Push negations through arithmetic operators (include max and min? maybe not due to +0/-0)
  **          classifications to normal tests (maybe)
  **          (= x (fp.neg x)) --> (isNaN x)
  **          (fp.eq x (fp.neg x)) --> (isZero x)   (previous and reorganise should be sufficient)
- **          (fp.eq x const) --> various = depending on const 
+ **          (fp.eq x const) --> various = depending on const
  **          (fp.abs (fp.neg x)) --> (fp.abs x)
  **          (fp.isPositive (fp.neg x)) --> (fp.isNegative x)
  **          (fp.isNegative (fp.neg x)) --> (fp.isPositive x)
  **          (fp.isPositive (fp.abs x)) --> (not (isNaN x))
  **          (fp.isNegative (fp.abs x)) --> false
+ **          A -> castA --> A
+ **          A -> castB -> castC  -->  A -> castC if A <= B <= C
+ **          A -> castB -> castA  -->  A if A <= B
+ **          promotion converts can ignore rounding mode
+ **          Samuel Figuer results
  **       ]]
  **/
 
@@ -137,7 +142,7 @@ namespace rewrite {
   }
 
   RewriteResponse removed (TNode node, bool) {  
-    Unreachable("kind (%d) should have been removed?",node.getKind());
+    Unreachable("kind (%s) should have been removed?",kindToString(node.getKind()).c_str());
     return RewriteResponse(REWRITE_DONE, node);
   }
 
@@ -150,10 +155,10 @@ namespace rewrite {
     return RewriteResponse(REWRITE_DONE, node);
   }
 
-  RewriteResponse equal (TNode node, bool isPreRewrite) {  
-    // We should only get equalities of floating point or rounding mode types.
+  RewriteResponse equal (TNode node, bool isPreRewrite) {
     Assert(node.getKind() == kind::EQUAL);
-
+  
+    // We should only get equalities of floating point or rounding mode types.
     TypeNode tn = node[0].getType(true);
 
     Assert(tn.isFloatingPoint() || tn.isRoundingMode());
@@ -169,74 +174,6 @@ namespace rewrite {
     }
   }
 
-  RewriteResponse convertFromRealLiteral (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL);
-
-    // \todo Honour the rounding mode and work for something other than doubles!
-
-    if (node[1].getKind() == kind::CONST_RATIONAL) {
-      TNode op = node.getOperator();
-      const FloatingPointToFPReal &param = op.getConst<FloatingPointToFPReal>();
-      Node lit =
-       NodeManager::currentNM()->mkConst(FloatingPoint(param.t.exponent(),
-                                                       param.t.significand(),
-                                                       node[1].getConst<Rational>().getDouble()));
-      
-      return RewriteResponse(REWRITE_DONE, lit);
-    } else {
-      return RewriteResponse(REWRITE_DONE, node);
-    }
-  }
-
-  RewriteResponse convertFromIEEEBitVectorLiteral (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR);
-
-    // \todo Handle arbitrary length bit vectors without using strings!
-
-    if (node[0].getKind() == kind::CONST_BITVECTOR) {
-      TNode op = node.getOperator();
-      const FloatingPointToFPIEEEBitVector &param = op.getConst<FloatingPointToFPIEEEBitVector>();
-      const BitVector &bv = node[0].getConst<BitVector>();
-      std::string bitString(bv.toString());
-
-      Node lit =
-       NodeManager::currentNM()->mkConst(FloatingPoint(param.t.exponent(),
-                                                       param.t.significand(),
-                                                       bitString));
-
-      return RewriteResponse(REWRITE_DONE, lit);
-    } else {
-      return RewriteResponse(REWRITE_DONE, node);
-    }
-  }
-
-  RewriteResponse convertFromLiteral (TNode node, bool) {
-    Assert(node.getKind() == kind::FLOATINGPOINT_FP);
-
-    // \todo Handle arbitrary length bit vectors without using strings!
-
-    if ((node[0].getKind() == kind::CONST_BITVECTOR) &&
-       (node[1].getKind() == kind::CONST_BITVECTOR) &&
-       (node[2].getKind() == kind::CONST_BITVECTOR)) {
-
-      BitVector bv(node[0].getConst<BitVector>());
-      bv = bv.concat(node[1].getConst<BitVector>());
-      bv = bv.concat(node[2].getConst<BitVector>());
-
-      std::string bitString(bv.toString());
-      std::reverse(bitString.begin(), bitString.end());
-
-      // +1 to support the hidden bit
-      Node lit =
-       NodeManager::currentNM()->mkConst(FloatingPoint(node[1].getConst<BitVector>().getSize(),
-                                                       node[2].getConst<BitVector>().getSize() + 1,
-                                                       bitString));
-
-      return RewriteResponse(REWRITE_DONE, lit);
-    } else {
-      return RewriteResponse(REWRITE_DONE, node);
-    }
-  }
 
   // Note these cannot be assumed to be symmetric for +0/-0, thus no symmetry reorder
   RewriteResponse compactMinMax (TNode node, bool isPreRewrite) {
@@ -309,11 +246,410 @@ namespace rewrite {
     } 
   }
 
-
 }; /* CVC4::theory::fp::rewrite */
 
+
+namespace constantFold {
+
+
+  RewriteResponse fpLiteral (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_FP);
+    
+    BitVector bv(node[0].getConst<BitVector>());
+    bv = bv.concat(node[1].getConst<BitVector>());
+    bv = bv.concat(node[2].getConst<BitVector>());
+    
+    // +1 to support the hidden bit
+    Node lit =
+      NodeManager::currentNM()->mkConst(FloatingPoint(node[1].getConst<BitVector>().getSize(),
+                                                     node[2].getConst<BitVector>().getSize() + 1,
+                                                     bv));
+    
+    return RewriteResponse(REWRITE_DONE, lit);
+  }
+
+  RewriteResponse abs (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
+    Assert(node.getNumChildren() == 1);
+
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().absolute()));
+  }
+
+
+  RewriteResponse neg (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
+    Assert(node.getNumChildren() == 1);
+
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().negate()));
+  }
+
+
+  RewriteResponse plus (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_PLUS);
+    Assert(node.getNumChildren() == 3);
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+    FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+    
+    Assert(arg1.t == arg2.t);
+      
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.plus(rm, arg2)));
+  }
+
+  RewriteResponse mult (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_MULT);
+    Assert(node.getNumChildren() == 3);
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+    FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+    
+    Assert(arg1.t == arg2.t);
+    
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.mult(rm, arg2)));
+  }
+
+  RewriteResponse fma (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_FMA);
+    Assert(node.getNumChildren() == 4);
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+    FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+    FloatingPoint arg3(node[3].getConst<FloatingPoint>());
+    
+    Assert(arg1.t == arg2.t);
+    Assert(arg1.t == arg3.t);
+    
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.fma(rm, arg2, arg3)));
+  }
+
+  RewriteResponse div (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_DIV);
+    Assert(node.getNumChildren() == 3);
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+    FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+    
+    Assert(arg1.t == arg2.t);
+    
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.div(rm, arg2)));
+  }
+  
+  RewriteResponse sqrt (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_SQRT);
+    Assert(node.getNumChildren() == 2);
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    FloatingPoint arg(node[1].getConst<FloatingPoint>());
+    
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg.sqrt(rm)));
+  }
+
+  RewriteResponse rti (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_RTI);
+    Assert(node.getNumChildren() == 2);
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    FloatingPoint arg(node[1].getConst<FloatingPoint>());
+    
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg.rti(rm)));
+  }
+
+  RewriteResponse rem (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_REM);
+    Assert(node.getNumChildren() == 2);
+
+    FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+    FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+    
+    Assert(arg1.t == arg2.t);
+    
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.rem(arg2)));
+  }
+
+  RewriteResponse min (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_MIN);
+    Assert(node.getNumChildren() == 2);
+
+    FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+    FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+    
+    Assert(arg1.t == arg2.t);
+
+    FloatingPoint::PartialFloatingPoint res(arg1.min(arg2));
+
+    if (res.second) {
+      Node lit = NodeManager::currentNM()->mkConst(res.first);
+      return RewriteResponse(REWRITE_DONE, lit);
+    } else {
+      // Can't constant fold the underspecified case
+      return RewriteResponse(REWRITE_DONE, node);
+    }
+  }
+
+  RewriteResponse max (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_MAX);
+    Assert(node.getNumChildren() == 2);
+
+    FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+    FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+    
+    Assert(arg1.t == arg2.t);
+
+    FloatingPoint::PartialFloatingPoint res(arg1.max(arg2));
+
+    if (res.second) {
+      Node lit = NodeManager::currentNM()->mkConst(res.first);
+      return RewriteResponse(REWRITE_DONE, lit);
+    } else {
+      // Can't constant fold the underspecified case
+      return RewriteResponse(REWRITE_DONE, node);
+    }
+  }
+
+  
+  RewriteResponse equal (TNode node, bool isPreRewrite) {
+    Assert(node.getKind() == kind::EQUAL);
+  
+    // We should only get equalities of floating point or rounding mode types.
+    TypeNode tn = node[0].getType(true);
+
+    if (tn.isFloatingPoint()) {
+      FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+      FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+    
+      Assert(arg1.t == arg2.t);
+    
+      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 == arg2));
+
+    } else if (tn.isRoundingMode()) {
+      RoundingMode arg1(node[0].getConst<RoundingMode>());
+      RoundingMode arg2(node[1].getConst<RoundingMode>());
+    
+      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 == arg2));
+
+    } else {
+      Unreachable("Equality of unknown type");
+    }
+
+    return RewriteResponse(REWRITE_DONE, node);
+  }
+
+
+  RewriteResponse leq (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_LEQ);
+    Assert(node.getNumChildren() == 2);
+
+    FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+    FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+    
+    Assert(arg1.t == arg2.t);
+    
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 <= arg2));
+  }
+
+
+  RewriteResponse lt (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_LT);
+    Assert(node.getNumChildren() == 2);
+
+    FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+    FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+    
+    Assert(arg1.t == arg2.t);
+    
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 < arg2));
+  }
+
+
+  RewriteResponse isNormal (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_ISN);
+    Assert(node.getNumChildren() == 1);
+
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNormal()));
+  }
+
+  RewriteResponse isSubnormal (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_ISSN);
+    Assert(node.getNumChildren() == 1);
+
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isSubnormal()));
+  }
+
+  RewriteResponse isZero (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_ISZ);
+    Assert(node.getNumChildren() == 1);
+
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isZero()));
+  }
+
+  RewriteResponse isInfinite (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_ISINF);
+    Assert(node.getNumChildren() == 1);
+
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isInfinite()));
+  }
+
+  RewriteResponse isNaN (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_ISNAN);
+    Assert(node.getNumChildren() == 1);
+
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNaN()));
+  }
+
+  RewriteResponse isNegative (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_ISNEG);
+    Assert(node.getNumChildren() == 1);
+
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNegative()));
+  }
+
+  RewriteResponse isPositive (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_ISPOS);
+    Assert(node.getNumChildren() == 1);
+
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isPositive()));
+  }
+
+  RewriteResponse convertFromIEEEBitVectorLiteral (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR);
+
+    TNode op = node.getOperator();
+    const FloatingPointToFPIEEEBitVector &param = op.getConst<FloatingPointToFPIEEEBitVector>();
+    const BitVector &bv = node[0].getConst<BitVector>();
+    
+    Node lit =
+      NodeManager::currentNM()->mkConst(FloatingPoint(param.t.exponent(),
+                                                     param.t.significand(),
+                                                     bv));
+    
+    return RewriteResponse(REWRITE_DONE, lit);
+  }
+
+  RewriteResponse constantConvert (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT);
+    Assert(node.getNumChildren() == 2);
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+    FloatingPointToFPFloatingPoint info = node.getOperator().getConst<FloatingPointToFPFloatingPoint>();
+
+    return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.convert(info.t,rm)));
+  }
+
+  RewriteResponse convertFromRealLiteral (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL);
+  
+    TNode op = node.getOperator();
+    const FloatingPointToFPReal &param = op.getConst<FloatingPointToFPReal>();
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    Rational arg(node[1].getConst<Rational>());
+
+    FloatingPoint res(param.t, rm, arg);
+    
+    Node lit = NodeManager::currentNM()->mkConst(res);
+    
+    return RewriteResponse(REWRITE_DONE, lit);
+  }
+
+  RewriteResponse convertFromSBV (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR);
+  
+    TNode op = node.getOperator();
+    const FloatingPointToFPSignedBitVector &param = op.getConst<FloatingPointToFPSignedBitVector>();
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    BitVector arg(node[1].getConst<BitVector>());
+
+    FloatingPoint res(param.t, rm, arg, true);
+    
+    Node lit = NodeManager::currentNM()->mkConst(res);
+    
+    return RewriteResponse(REWRITE_DONE, lit);
+  }
+
+  RewriteResponse convertFromUBV (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR);
+  
+    TNode op = node.getOperator();
+    const FloatingPointToFPUnsignedBitVector &param = op.getConst<FloatingPointToFPUnsignedBitVector>();
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    BitVector arg(node[1].getConst<BitVector>());
+
+    FloatingPoint res(param.t, rm, arg, false);
+    
+    Node lit = NodeManager::currentNM()->mkConst(res);
+    
+    return RewriteResponse(REWRITE_DONE, lit);
+  }
+
+  RewriteResponse convertToUBV (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV);
+  
+    TNode op = node.getOperator();
+    const FloatingPointToUBV &param = op.getConst<FloatingPointToUBV>();
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    FloatingPoint arg(node[1].getConst<FloatingPoint>());
+
+    FloatingPoint::PartialBitVector res(arg.convertToBV(param.bvs, rm, false));
+
+    if (res.second) {
+      Node lit = NodeManager::currentNM()->mkConst(res.first);
+      return RewriteResponse(REWRITE_DONE, lit);
+    } else {
+      // Can't constant fold the underspecified case
+      return RewriteResponse(REWRITE_DONE, node);
+    }
+  }
+
+  RewriteResponse convertToSBV (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV);
+  
+    TNode op = node.getOperator();
+    const FloatingPointToSBV &param = op.getConst<FloatingPointToSBV>();
+
+    RoundingMode rm(node[0].getConst<RoundingMode>());
+    FloatingPoint arg(node[1].getConst<FloatingPoint>());
+
+    FloatingPoint::PartialBitVector res(arg.convertToBV(param.bvs, rm, true));
+
+    if (res.second) {
+      Node lit = NodeManager::currentNM()->mkConst(res.first);
+      return RewriteResponse(REWRITE_DONE, lit);
+    } else {
+      // Can't constant fold the underspecified case
+      return RewriteResponse(REWRITE_DONE, node);
+    }
+  }
+
+  RewriteResponse convertToReal (TNode node, bool) {
+    Assert(node.getKind() == kind::FLOATINGPOINT_TO_REAL);
+  
+    FloatingPoint arg(node[0].getConst<FloatingPoint>());
+
+    FloatingPoint::PartialRational res(arg.convertToRational());
+
+    if (res.second) {
+      Node lit = NodeManager::currentNM()->mkConst(res.first);
+      return RewriteResponse(REWRITE_DONE, lit);
+    } else {
+      // Can't constant fold the underspecified case
+      return RewriteResponse(REWRITE_DONE, node);
+    }
+  }
+
+};  /* CVC4::theory::fp::constantFold */
+
+
 RewriteFunction TheoryFpRewriter::preRewriteTable[kind::LAST_KIND]; 
 RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND]; 
+RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; 
 
 
   /**
@@ -381,6 +717,7 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
     /******** Variables ********/
     preRewriteTable[kind::VARIABLE] = rewrite::variable;
     preRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable;
+    preRewriteTable[kind::SKOLEM] = rewrite::variable;
 
     preRewriteTable[kind::EQUAL] = rewrite::equal;
 
@@ -403,7 +740,7 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
     postRewriteTable[kind::FLOATINGPOINT_TYPE] = rewrite::type;
       
     /******** Operations ********/
-    postRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::convertFromLiteral;
+    postRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity;
     postRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::identity;
     postRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation;
     postRewriteTable[kind::FLOATINGPOINT_PLUS] = rewrite::reorderBinaryOperation;
@@ -434,9 +771,9 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
     postRewriteTable[kind::FLOATINGPOINT_ISPOS] = rewrite::identity;
 
     /******** Conversions ********/
-    postRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = rewrite::convertFromIEEEBitVectorLiteral;
+    postRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = rewrite::identity;
     postRewriteTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = rewrite::identity;
-    postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::convertFromRealLiteral;
+    postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity;
     postRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = rewrite::identity;
     postRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = rewrite::identity;
     postRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed;
@@ -447,10 +784,81 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
     /******** Variables ********/
     postRewriteTable[kind::VARIABLE] = rewrite::variable;
     postRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable;
+    postRewriteTable[kind::SKOLEM] = rewrite::variable;
 
     postRewriteTable[kind::EQUAL] = rewrite::equal;
 
 
+
+
+    /* Set up the post-rewrite constant fold table */
+    for (unsigned i = 0; i < kind::LAST_KIND; ++i) {
+      // Note that this is identity, not notFP
+      // Constant folding is called after post-rewrite
+      // So may have to deal with cases of things being
+      // re-written to non-floating-point sorts (i.e. true).
+      constantFoldTable[i] = rewrite::identity;
+    }
+
+    /******** Constants ********/
+    /* Already folded! */
+    constantFoldTable[kind::CONST_FLOATINGPOINT] = rewrite::identity;
+    constantFoldTable[kind::CONST_ROUNDINGMODE] = rewrite::identity; 
+
+    /******** Sorts(?) ********/
+    /* These kinds should only appear in types */
+    constantFoldTable[kind::FLOATINGPOINT_TYPE] = rewrite::type;
+      
+    /******** Operations ********/
+    constantFoldTable[kind::FLOATINGPOINT_FP] = constantFold::fpLiteral;
+    constantFoldTable[kind::FLOATINGPOINT_ABS] = constantFold::abs;
+    constantFoldTable[kind::FLOATINGPOINT_NEG] = constantFold::neg;
+    constantFoldTable[kind::FLOATINGPOINT_PLUS] = constantFold::plus;
+    constantFoldTable[kind::FLOATINGPOINT_SUB] = rewrite::removed;
+    constantFoldTable[kind::FLOATINGPOINT_MULT] = constantFold::mult;
+    constantFoldTable[kind::FLOATINGPOINT_DIV] = constantFold::div;
+    constantFoldTable[kind::FLOATINGPOINT_FMA] = constantFold::fma;
+    constantFoldTable[kind::FLOATINGPOINT_SQRT] = constantFold::sqrt;
+    constantFoldTable[kind::FLOATINGPOINT_REM] = constantFold::rem;
+    constantFoldTable[kind::FLOATINGPOINT_RTI] = constantFold::rti;
+    constantFoldTable[kind::FLOATINGPOINT_MIN] = constantFold::min;
+    constantFoldTable[kind::FLOATINGPOINT_MAX] = constantFold::max;
+
+    /******** Comparisons ********/
+    constantFoldTable[kind::FLOATINGPOINT_EQ] = rewrite::removed;
+    constantFoldTable[kind::FLOATINGPOINT_LEQ] = constantFold::leq;
+    constantFoldTable[kind::FLOATINGPOINT_LT] = constantFold::lt;
+    constantFoldTable[kind::FLOATINGPOINT_GEQ] = rewrite::removed;
+    constantFoldTable[kind::FLOATINGPOINT_GT] = rewrite::removed;
+
+    /******** Classifications ********/
+    constantFoldTable[kind::FLOATINGPOINT_ISN] = constantFold::isNormal;
+    constantFoldTable[kind::FLOATINGPOINT_ISSN] = constantFold::isSubnormal;
+    constantFoldTable[kind::FLOATINGPOINT_ISZ] = constantFold::isZero;
+    constantFoldTable[kind::FLOATINGPOINT_ISINF] = constantFold::isInfinite;
+    constantFoldTable[kind::FLOATINGPOINT_ISNAN] = constantFold::isNaN;
+    constantFoldTable[kind::FLOATINGPOINT_ISNEG] = constantFold::isNegative;
+    constantFoldTable[kind::FLOATINGPOINT_ISPOS] = constantFold::isPositive;
+
+    /******** Conversions ********/
+    constantFoldTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = constantFold::convertFromIEEEBitVectorLiteral;
+    constantFoldTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = constantFold::constantConvert;
+    constantFoldTable[kind::FLOATINGPOINT_TO_FP_REAL] = constantFold::convertFromRealLiteral;
+    constantFoldTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = constantFold::convertFromSBV;
+    constantFoldTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = constantFold::convertFromUBV;
+    constantFoldTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed;
+    constantFoldTable[kind::FLOATINGPOINT_TO_UBV] = constantFold::convertToUBV;
+    constantFoldTable[kind::FLOATINGPOINT_TO_SBV] = constantFold::convertToSBV;
+    constantFoldTable[kind::FLOATINGPOINT_TO_REAL] = constantFold::convertToReal;
+
+    /******** Variables ********/
+    constantFoldTable[kind::VARIABLE] = rewrite::variable;
+    constantFoldTable[kind::BOUND_VARIABLE] = rewrite::variable;
+
+    constantFoldTable[kind::EQUAL] = constantFold::equal;
+
+
+
   }
 
 
@@ -506,6 +914,83 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
       Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): before " << node << std::endl;
       Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): after  " << res.node << std::endl;
     }
+
+    if (res.status == REWRITE_DONE) {
+      bool allChildrenConst = true;
+      bool apartFromRoundingMode = false;
+      for (Node::const_iterator i = res.node.begin();
+          i != res.node.end();
+          ++i) {
+
+       if ((*i).getMetaKind() != kind::metakind::CONSTANT) {
+         if ((*i).getType().isRoundingMode() && !apartFromRoundingMode) {
+           apartFromRoundingMode = true;
+         } else {
+           allChildrenConst = false;
+           break;
+         }
+       }
+      }
+
+      if (allChildrenConst) {
+       RewriteStatus rs = REWRITE_DONE;    // This is a bit messy because
+       Node rn = res.node;                 // RewriteResponse is too functional..
+
+       if (apartFromRoundingMode) {
+         if (!(res.node.getKind() == kind::EQUAL)) {  // Avoid infinite recursion...
+           // We are close to being able to constant fold this
+           // and in many cases the rounding mode really doesn't matter.
+           // So we can try brute forcing our way through them.
+
+           NodeManager *nm = NodeManager::currentNM();
+
+           Node RNE(nm->mkConst(roundNearestTiesToEven));
+           Node RNA(nm->mkConst(roundNearestTiesToAway));
+           Node RTZ(nm->mkConst(roundTowardPositive));
+           Node RTN(nm->mkConst(roundTowardNegative));
+           Node RTP(nm->mkConst(roundTowardZero));
+
+           TNode RM(res.node[0]);
+
+           Node wRNE(res.node.substitute(RM, TNode(RNE)));
+           Node wRNA(res.node.substitute(RM, TNode(RNA)));
+           Node wRTZ(res.node.substitute(RM, TNode(RTZ)));
+           Node wRTN(res.node.substitute(RM, TNode(RTN)));
+           Node wRTP(res.node.substitute(RM, TNode(RTP)));
+
+
+           rs = REWRITE_AGAIN_FULL;
+           rn = nm->mkNode(kind::ITE,
+                           nm->mkNode(kind::EQUAL, RM, RNE),
+                           wRNE,
+                           nm->mkNode(kind::ITE,
+                                      nm->mkNode(kind::EQUAL, RM, RNA),
+                                      wRNA,
+                                      nm->mkNode(kind::ITE,
+                                                 nm->mkNode(kind::EQUAL, RM, RTZ),
+                                                 wRTZ,
+                                                 nm->mkNode(kind::ITE,
+                                                            nm->mkNode(kind::EQUAL, RM, RTN),
+                                                            wRTN,
+                                                            wRTP))));
+         }
+       } else {
+         RewriteResponse tmp = constantFoldTable [res.node.getKind()] (res.node, false);
+         rs = tmp.status;
+         rn = tmp.node;
+       }
+
+       RewriteResponse constRes(rs,rn);
+
+       if (constRes.node != res.node) {
+         Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): before constant fold " << res.node << std::endl;
+         Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): after constant fold " << constRes.node << std::endl;
+       }
+
+       return constRes;
+      }
+    }
+
     return res;
   }
 
index d2a9a046688ba7e67fad28651f1f0fcb5efe033a..56492f921e5bcdef8e558c508468edbff8cfeb05 100644 (file)
@@ -32,6 +32,8 @@ class TheoryFpRewriter {
  protected :
   static RewriteFunction preRewriteTable[kind::LAST_KIND];
   static RewriteFunction postRewriteTable[kind::LAST_KIND];
+  static RewriteFunction constantFoldTable[kind::LAST_KIND]; 
+
 
  public: