From e23377411d993e126403eb186c80f664419d512c Mon Sep 17 00:00:00 2001 From: Martin Brain Date: Tue, 26 Sep 2017 17:21:51 -0700 Subject: [PATCH] Improve FP rewriter: const folding, other (#1126) --- src/theory/fp/theory_fp_rewriter.cpp | 643 +++++++++++++++++++++++---- src/theory/fp/theory_fp_rewriter.h | 2 + 2 files changed, 566 insertions(+), 79 deletions(-) diff --git a/src/theory/fp/theory_fp_rewriter.cpp b/src/theory/fp/theory_fp_rewriter.cpp index 747aaeac6..ec42099c2 100644 --- a/src/theory/fp/theory_fp_rewriter.cpp +++ b/src/theory/fp/theory_fp_rewriter.cpp @@ -12,17 +12,22 @@ ** ** \brief [[ Rewrite rules for floating point theories. ]] ** - ** \todo [[ Constant folding - ** Push negations up through arithmetic operators (include max and min? maybe not due to +0/-0) + ** \todo [[ Single argument constant propagate / simplify + Push negations through arithmetic operators (include max and min? maybe not due to +0/-0) ** classifications to normal tests (maybe) ** (= x (fp.neg x)) --> (isNaN x) ** (fp.eq x (fp.neg x)) --> (isZero x) (previous and reorganise should be sufficient) - ** (fp.eq x const) --> various = depending on const + ** (fp.eq x const) --> various = depending on const ** (fp.abs (fp.neg x)) --> (fp.abs x) ** (fp.isPositive (fp.neg x)) --> (fp.isNegative x) ** (fp.isNegative (fp.neg x)) --> (fp.isPositive x) ** (fp.isPositive (fp.abs x)) --> (not (isNaN x)) ** (fp.isNegative (fp.abs x)) --> false + ** A -> castA --> A + ** A -> castB -> castC --> A -> castC if A <= B <= C + ** A -> castB -> castA --> A if A <= B + ** promotion converts can ignore rounding mode + ** Samuel Figuer results ** ]] **/ @@ -137,7 +142,7 @@ namespace rewrite { } RewriteResponse removed (TNode node, bool) { - Unreachable("kind (%d) should have been removed?",node.getKind()); + Unreachable("kind (%s) should have been removed?",kindToString(node.getKind()).c_str()); return RewriteResponse(REWRITE_DONE, node); } @@ -150,10 +155,10 @@ namespace rewrite { return RewriteResponse(REWRITE_DONE, node); } - RewriteResponse equal (TNode node, bool isPreRewrite) { - // We should only get equalities of floating point or rounding mode types. + RewriteResponse equal (TNode node, bool isPreRewrite) { Assert(node.getKind() == kind::EQUAL); - + + // We should only get equalities of floating point or rounding mode types. TypeNode tn = node[0].getType(true); Assert(tn.isFloatingPoint() || tn.isRoundingMode()); @@ -169,74 +174,6 @@ namespace rewrite { } } - RewriteResponse convertFromRealLiteral (TNode node, bool) { - Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL); - - // \todo Honour the rounding mode and work for something other than doubles! - - if (node[1].getKind() == kind::CONST_RATIONAL) { - TNode op = node.getOperator(); - const FloatingPointToFPReal ¶m = op.getConst(); - Node lit = - NodeManager::currentNM()->mkConst(FloatingPoint(param.t.exponent(), - param.t.significand(), - node[1].getConst().getDouble())); - - return RewriteResponse(REWRITE_DONE, lit); - } else { - return RewriteResponse(REWRITE_DONE, node); - } - } - - RewriteResponse convertFromIEEEBitVectorLiteral (TNode node, bool) { - Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR); - - // \todo Handle arbitrary length bit vectors without using strings! - - if (node[0].getKind() == kind::CONST_BITVECTOR) { - TNode op = node.getOperator(); - const FloatingPointToFPIEEEBitVector ¶m = op.getConst(); - const BitVector &bv = node[0].getConst(); - std::string bitString(bv.toString()); - - Node lit = - NodeManager::currentNM()->mkConst(FloatingPoint(param.t.exponent(), - param.t.significand(), - bitString)); - - return RewriteResponse(REWRITE_DONE, lit); - } else { - return RewriteResponse(REWRITE_DONE, node); - } - } - - RewriteResponse convertFromLiteral (TNode node, bool) { - Assert(node.getKind() == kind::FLOATINGPOINT_FP); - - // \todo Handle arbitrary length bit vectors without using strings! - - if ((node[0].getKind() == kind::CONST_BITVECTOR) && - (node[1].getKind() == kind::CONST_BITVECTOR) && - (node[2].getKind() == kind::CONST_BITVECTOR)) { - - BitVector bv(node[0].getConst()); - bv = bv.concat(node[1].getConst()); - bv = bv.concat(node[2].getConst()); - - std::string bitString(bv.toString()); - std::reverse(bitString.begin(), bitString.end()); - - // +1 to support the hidden bit - Node lit = - NodeManager::currentNM()->mkConst(FloatingPoint(node[1].getConst().getSize(), - node[2].getConst().getSize() + 1, - bitString)); - - return RewriteResponse(REWRITE_DONE, lit); - } else { - return RewriteResponse(REWRITE_DONE, node); - } - } // Note these cannot be assumed to be symmetric for +0/-0, thus no symmetry reorder RewriteResponse compactMinMax (TNode node, bool isPreRewrite) { @@ -309,11 +246,410 @@ namespace rewrite { } } - }; /* CVC4::theory::fp::rewrite */ + +namespace constantFold { + + + RewriteResponse fpLiteral (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_FP); + + BitVector bv(node[0].getConst()); + bv = bv.concat(node[1].getConst()); + bv = bv.concat(node[2].getConst()); + + // +1 to support the hidden bit + Node lit = + NodeManager::currentNM()->mkConst(FloatingPoint(node[1].getConst().getSize(), + node[2].getConst().getSize() + 1, + bv)); + + return RewriteResponse(REWRITE_DONE, lit); + } + + RewriteResponse abs (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_ABS); + Assert(node.getNumChildren() == 1); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst().absolute())); + } + + + RewriteResponse neg (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_NEG); + Assert(node.getNumChildren() == 1); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst().negate())); + } + + + RewriteResponse plus (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_PLUS); + Assert(node.getNumChildren() == 3); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg1(node[1].getConst()); + FloatingPoint arg2(node[2].getConst()); + + Assert(arg1.t == arg2.t); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.plus(rm, arg2))); + } + + RewriteResponse mult (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_MULT); + Assert(node.getNumChildren() == 3); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg1(node[1].getConst()); + FloatingPoint arg2(node[2].getConst()); + + Assert(arg1.t == arg2.t); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.mult(rm, arg2))); + } + + RewriteResponse fma (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_FMA); + Assert(node.getNumChildren() == 4); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg1(node[1].getConst()); + FloatingPoint arg2(node[2].getConst()); + FloatingPoint arg3(node[3].getConst()); + + Assert(arg1.t == arg2.t); + Assert(arg1.t == arg3.t); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.fma(rm, arg2, arg3))); + } + + RewriteResponse div (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_DIV); + Assert(node.getNumChildren() == 3); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg1(node[1].getConst()); + FloatingPoint arg2(node[2].getConst()); + + Assert(arg1.t == arg2.t); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.div(rm, arg2))); + } + + RewriteResponse sqrt (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_SQRT); + Assert(node.getNumChildren() == 2); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg(node[1].getConst()); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg.sqrt(rm))); + } + + RewriteResponse rti (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_RTI); + Assert(node.getNumChildren() == 2); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg(node[1].getConst()); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg.rti(rm))); + } + + RewriteResponse rem (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_REM); + Assert(node.getNumChildren() == 2); + + FloatingPoint arg1(node[0].getConst()); + FloatingPoint arg2(node[1].getConst()); + + Assert(arg1.t == arg2.t); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.rem(arg2))); + } + + RewriteResponse min (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_MIN); + Assert(node.getNumChildren() == 2); + + FloatingPoint arg1(node[0].getConst()); + FloatingPoint arg2(node[1].getConst()); + + Assert(arg1.t == arg2.t); + + FloatingPoint::PartialFloatingPoint res(arg1.min(arg2)); + + if (res.second) { + Node lit = NodeManager::currentNM()->mkConst(res.first); + return RewriteResponse(REWRITE_DONE, lit); + } else { + // Can't constant fold the underspecified case + return RewriteResponse(REWRITE_DONE, node); + } + } + + RewriteResponse max (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_MAX); + Assert(node.getNumChildren() == 2); + + FloatingPoint arg1(node[0].getConst()); + FloatingPoint arg2(node[1].getConst()); + + Assert(arg1.t == arg2.t); + + FloatingPoint::PartialFloatingPoint res(arg1.max(arg2)); + + if (res.second) { + Node lit = NodeManager::currentNM()->mkConst(res.first); + return RewriteResponse(REWRITE_DONE, lit); + } else { + // Can't constant fold the underspecified case + return RewriteResponse(REWRITE_DONE, node); + } + } + + + RewriteResponse equal (TNode node, bool isPreRewrite) { + Assert(node.getKind() == kind::EQUAL); + + // We should only get equalities of floating point or rounding mode types. + TypeNode tn = node[0].getType(true); + + if (tn.isFloatingPoint()) { + FloatingPoint arg1(node[0].getConst()); + FloatingPoint arg2(node[1].getConst()); + + Assert(arg1.t == arg2.t); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 == arg2)); + + } else if (tn.isRoundingMode()) { + RoundingMode arg1(node[0].getConst()); + RoundingMode arg2(node[1].getConst()); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 == arg2)); + + } else { + Unreachable("Equality of unknown type"); + } + + return RewriteResponse(REWRITE_DONE, node); + } + + + RewriteResponse leq (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_LEQ); + Assert(node.getNumChildren() == 2); + + FloatingPoint arg1(node[0].getConst()); + FloatingPoint arg2(node[1].getConst()); + + Assert(arg1.t == arg2.t); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 <= arg2)); + } + + + RewriteResponse lt (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_LT); + Assert(node.getNumChildren() == 2); + + FloatingPoint arg1(node[0].getConst()); + FloatingPoint arg2(node[1].getConst()); + + Assert(arg1.t == arg2.t); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 < arg2)); + } + + + RewriteResponse isNormal (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_ISN); + Assert(node.getNumChildren() == 1); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst().isNormal())); + } + + RewriteResponse isSubnormal (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_ISSN); + Assert(node.getNumChildren() == 1); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst().isSubnormal())); + } + + RewriteResponse isZero (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_ISZ); + Assert(node.getNumChildren() == 1); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst().isZero())); + } + + RewriteResponse isInfinite (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_ISINF); + Assert(node.getNumChildren() == 1); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst().isInfinite())); + } + + RewriteResponse isNaN (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_ISNAN); + Assert(node.getNumChildren() == 1); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst().isNaN())); + } + + RewriteResponse isNegative (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_ISNEG); + Assert(node.getNumChildren() == 1); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst().isNegative())); + } + + RewriteResponse isPositive (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_ISPOS); + Assert(node.getNumChildren() == 1); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst().isPositive())); + } + + RewriteResponse convertFromIEEEBitVectorLiteral (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR); + + TNode op = node.getOperator(); + const FloatingPointToFPIEEEBitVector ¶m = op.getConst(); + const BitVector &bv = node[0].getConst(); + + Node lit = + NodeManager::currentNM()->mkConst(FloatingPoint(param.t.exponent(), + param.t.significand(), + bv)); + + return RewriteResponse(REWRITE_DONE, lit); + } + + RewriteResponse constantConvert (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT); + Assert(node.getNumChildren() == 2); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg1(node[1].getConst()); + FloatingPointToFPFloatingPoint info = node.getOperator().getConst(); + + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.convert(info.t,rm))); + } + + RewriteResponse convertFromRealLiteral (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL); + + TNode op = node.getOperator(); + const FloatingPointToFPReal ¶m = op.getConst(); + + RoundingMode rm(node[0].getConst()); + Rational arg(node[1].getConst()); + + FloatingPoint res(param.t, rm, arg); + + Node lit = NodeManager::currentNM()->mkConst(res); + + return RewriteResponse(REWRITE_DONE, lit); + } + + RewriteResponse convertFromSBV (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR); + + TNode op = node.getOperator(); + const FloatingPointToFPSignedBitVector ¶m = op.getConst(); + + RoundingMode rm(node[0].getConst()); + BitVector arg(node[1].getConst()); + + FloatingPoint res(param.t, rm, arg, true); + + Node lit = NodeManager::currentNM()->mkConst(res); + + return RewriteResponse(REWRITE_DONE, lit); + } + + RewriteResponse convertFromUBV (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR); + + TNode op = node.getOperator(); + const FloatingPointToFPUnsignedBitVector ¶m = op.getConst(); + + RoundingMode rm(node[0].getConst()); + BitVector arg(node[1].getConst()); + + FloatingPoint res(param.t, rm, arg, false); + + Node lit = NodeManager::currentNM()->mkConst(res); + + return RewriteResponse(REWRITE_DONE, lit); + } + + RewriteResponse convertToUBV (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV); + + TNode op = node.getOperator(); + const FloatingPointToUBV ¶m = op.getConst(); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg(node[1].getConst()); + + FloatingPoint::PartialBitVector res(arg.convertToBV(param.bvs, rm, false)); + + if (res.second) { + Node lit = NodeManager::currentNM()->mkConst(res.first); + return RewriteResponse(REWRITE_DONE, lit); + } else { + // Can't constant fold the underspecified case + return RewriteResponse(REWRITE_DONE, node); + } + } + + RewriteResponse convertToSBV (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV); + + TNode op = node.getOperator(); + const FloatingPointToSBV ¶m = op.getConst(); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg(node[1].getConst()); + + FloatingPoint::PartialBitVector res(arg.convertToBV(param.bvs, rm, true)); + + if (res.second) { + Node lit = NodeManager::currentNM()->mkConst(res.first); + return RewriteResponse(REWRITE_DONE, lit); + } else { + // Can't constant fold the underspecified case + return RewriteResponse(REWRITE_DONE, node); + } + } + + RewriteResponse convertToReal (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_REAL); + + FloatingPoint arg(node[0].getConst()); + + FloatingPoint::PartialRational res(arg.convertToRational()); + + if (res.second) { + Node lit = NodeManager::currentNM()->mkConst(res.first); + return RewriteResponse(REWRITE_DONE, lit); + } else { + // Can't constant fold the underspecified case + return RewriteResponse(REWRITE_DONE, node); + } + } + +}; /* CVC4::theory::fp::constantFold */ + + RewriteFunction TheoryFpRewriter::preRewriteTable[kind::LAST_KIND]; RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND]; +RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; /** @@ -381,6 +717,7 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND]; /******** Variables ********/ preRewriteTable[kind::VARIABLE] = rewrite::variable; preRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable; + preRewriteTable[kind::SKOLEM] = rewrite::variable; preRewriteTable[kind::EQUAL] = rewrite::equal; @@ -403,7 +740,7 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND]; postRewriteTable[kind::FLOATINGPOINT_TYPE] = rewrite::type; /******** Operations ********/ - postRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::convertFromLiteral; + postRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity; postRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::identity; postRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation; postRewriteTable[kind::FLOATINGPOINT_PLUS] = rewrite::reorderBinaryOperation; @@ -434,9 +771,9 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND]; postRewriteTable[kind::FLOATINGPOINT_ISPOS] = rewrite::identity; /******** Conversions ********/ - postRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = rewrite::convertFromIEEEBitVectorLiteral; + postRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = rewrite::identity; postRewriteTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::convertFromRealLiteral; + postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity; postRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = rewrite::identity; postRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = rewrite::identity; postRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed; @@ -447,10 +784,81 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND]; /******** Variables ********/ postRewriteTable[kind::VARIABLE] = rewrite::variable; postRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable; + postRewriteTable[kind::SKOLEM] = rewrite::variable; postRewriteTable[kind::EQUAL] = rewrite::equal; + + + /* Set up the post-rewrite constant fold table */ + for (unsigned i = 0; i < kind::LAST_KIND; ++i) { + // Note that this is identity, not notFP + // Constant folding is called after post-rewrite + // So may have to deal with cases of things being + // re-written to non-floating-point sorts (i.e. true). + constantFoldTable[i] = rewrite::identity; + } + + /******** Constants ********/ + /* Already folded! */ + constantFoldTable[kind::CONST_FLOATINGPOINT] = rewrite::identity; + constantFoldTable[kind::CONST_ROUNDINGMODE] = rewrite::identity; + + /******** Sorts(?) ********/ + /* These kinds should only appear in types */ + constantFoldTable[kind::FLOATINGPOINT_TYPE] = rewrite::type; + + /******** Operations ********/ + constantFoldTable[kind::FLOATINGPOINT_FP] = constantFold::fpLiteral; + constantFoldTable[kind::FLOATINGPOINT_ABS] = constantFold::abs; + constantFoldTable[kind::FLOATINGPOINT_NEG] = constantFold::neg; + constantFoldTable[kind::FLOATINGPOINT_PLUS] = constantFold::plus; + constantFoldTable[kind::FLOATINGPOINT_SUB] = rewrite::removed; + constantFoldTable[kind::FLOATINGPOINT_MULT] = constantFold::mult; + constantFoldTable[kind::FLOATINGPOINT_DIV] = constantFold::div; + constantFoldTable[kind::FLOATINGPOINT_FMA] = constantFold::fma; + constantFoldTable[kind::FLOATINGPOINT_SQRT] = constantFold::sqrt; + constantFoldTable[kind::FLOATINGPOINT_REM] = constantFold::rem; + constantFoldTable[kind::FLOATINGPOINT_RTI] = constantFold::rti; + constantFoldTable[kind::FLOATINGPOINT_MIN] = constantFold::min; + constantFoldTable[kind::FLOATINGPOINT_MAX] = constantFold::max; + + /******** Comparisons ********/ + constantFoldTable[kind::FLOATINGPOINT_EQ] = rewrite::removed; + constantFoldTable[kind::FLOATINGPOINT_LEQ] = constantFold::leq; + constantFoldTable[kind::FLOATINGPOINT_LT] = constantFold::lt; + constantFoldTable[kind::FLOATINGPOINT_GEQ] = rewrite::removed; + constantFoldTable[kind::FLOATINGPOINT_GT] = rewrite::removed; + + /******** Classifications ********/ + constantFoldTable[kind::FLOATINGPOINT_ISN] = constantFold::isNormal; + constantFoldTable[kind::FLOATINGPOINT_ISSN] = constantFold::isSubnormal; + constantFoldTable[kind::FLOATINGPOINT_ISZ] = constantFold::isZero; + constantFoldTable[kind::FLOATINGPOINT_ISINF] = constantFold::isInfinite; + constantFoldTable[kind::FLOATINGPOINT_ISNAN] = constantFold::isNaN; + constantFoldTable[kind::FLOATINGPOINT_ISNEG] = constantFold::isNegative; + constantFoldTable[kind::FLOATINGPOINT_ISPOS] = constantFold::isPositive; + + /******** Conversions ********/ + constantFoldTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = constantFold::convertFromIEEEBitVectorLiteral; + constantFoldTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = constantFold::constantConvert; + constantFoldTable[kind::FLOATINGPOINT_TO_FP_REAL] = constantFold::convertFromRealLiteral; + constantFoldTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = constantFold::convertFromSBV; + constantFoldTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = constantFold::convertFromUBV; + constantFoldTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed; + constantFoldTable[kind::FLOATINGPOINT_TO_UBV] = constantFold::convertToUBV; + constantFoldTable[kind::FLOATINGPOINT_TO_SBV] = constantFold::convertToSBV; + constantFoldTable[kind::FLOATINGPOINT_TO_REAL] = constantFold::convertToReal; + + /******** Variables ********/ + constantFoldTable[kind::VARIABLE] = rewrite::variable; + constantFoldTable[kind::BOUND_VARIABLE] = rewrite::variable; + + constantFoldTable[kind::EQUAL] = constantFold::equal; + + + } @@ -506,6 +914,83 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND]; Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): before " << node << std::endl; Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): after " << res.node << std::endl; } + + if (res.status == REWRITE_DONE) { + bool allChildrenConst = true; + bool apartFromRoundingMode = false; + for (Node::const_iterator i = res.node.begin(); + i != res.node.end(); + ++i) { + + if ((*i).getMetaKind() != kind::metakind::CONSTANT) { + if ((*i).getType().isRoundingMode() && !apartFromRoundingMode) { + apartFromRoundingMode = true; + } else { + allChildrenConst = false; + break; + } + } + } + + if (allChildrenConst) { + RewriteStatus rs = REWRITE_DONE; // This is a bit messy because + Node rn = res.node; // RewriteResponse is too functional.. + + if (apartFromRoundingMode) { + if (!(res.node.getKind() == kind::EQUAL)) { // Avoid infinite recursion... + // We are close to being able to constant fold this + // and in many cases the rounding mode really doesn't matter. + // So we can try brute forcing our way through them. + + NodeManager *nm = NodeManager::currentNM(); + + Node RNE(nm->mkConst(roundNearestTiesToEven)); + Node RNA(nm->mkConst(roundNearestTiesToAway)); + Node RTZ(nm->mkConst(roundTowardPositive)); + Node RTN(nm->mkConst(roundTowardNegative)); + Node RTP(nm->mkConst(roundTowardZero)); + + TNode RM(res.node[0]); + + Node wRNE(res.node.substitute(RM, TNode(RNE))); + Node wRNA(res.node.substitute(RM, TNode(RNA))); + Node wRTZ(res.node.substitute(RM, TNode(RTZ))); + Node wRTN(res.node.substitute(RM, TNode(RTN))); + Node wRTP(res.node.substitute(RM, TNode(RTP))); + + + rs = REWRITE_AGAIN_FULL; + rn = nm->mkNode(kind::ITE, + nm->mkNode(kind::EQUAL, RM, RNE), + wRNE, + nm->mkNode(kind::ITE, + nm->mkNode(kind::EQUAL, RM, RNA), + wRNA, + nm->mkNode(kind::ITE, + nm->mkNode(kind::EQUAL, RM, RTZ), + wRTZ, + nm->mkNode(kind::ITE, + nm->mkNode(kind::EQUAL, RM, RTN), + wRTN, + wRTP)))); + } + } else { + RewriteResponse tmp = constantFoldTable [res.node.getKind()] (res.node, false); + rs = tmp.status; + rn = tmp.node; + } + + RewriteResponse constRes(rs,rn); + + if (constRes.node != res.node) { + Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): before constant fold " << res.node << std::endl; + Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): after constant fold " << constRes.node << std::endl; + } + + return constRes; + } + } + return res; } diff --git a/src/theory/fp/theory_fp_rewriter.h b/src/theory/fp/theory_fp_rewriter.h index d2a9a0466..56492f921 100644 --- a/src/theory/fp/theory_fp_rewriter.h +++ b/src/theory/fp/theory_fp_rewriter.h @@ -32,6 +32,8 @@ class TheoryFpRewriter { protected : static RewriteFunction preRewriteTable[kind::LAST_KIND]; static RewriteFunction postRewriteTable[kind::LAST_KIND]; + static RewriteFunction constantFoldTable[kind::LAST_KIND]; + public: -- 2.30.2