From: Martin Date: Tue, 3 Oct 2017 00:41:24 +0000 (+0100) Subject: Add 5 FP kinds for partial to total fn conversion (#1128) X-Git-Tag: cvc5-1.0.0~5587 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=6861f66d2e2b54fc31d9151b4dbeb2964ea07f94;p=cvc5.git Add 5 FP kinds for partial to total fn conversion (#1128) - Add new kinds for partially defined functions - Print the new kinds - Type rules for the new total kinds - Constant folding and rewrites for the new total kinds --- diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 105e2c0dd..c7d6b34ab 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -899,6 +899,8 @@ static string smtKindString(Kind k) throw() { case kind::FLOATINGPOINT_RTI: return "fp.roundToIntegral"; case kind::FLOATINGPOINT_MIN: return "fp.min"; case kind::FLOATINGPOINT_MAX: return "fp.max"; + case kind::FLOATINGPOINT_MIN_TOTAL: return "fp.min_total"; + case kind::FLOATINGPOINT_MAX_TOTAL: return "fp.max_total"; case kind::FLOATINGPOINT_LEQ: return "fp.leq"; case kind::FLOATINGPOINT_LT: return "fp.lt"; @@ -920,8 +922,11 @@ static string smtKindString(Kind k) throw() { case kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR: return "to_fp_unsigned"; case kind::FLOATINGPOINT_TO_FP_GENERIC: return "to_fp_unsigned"; case kind::FLOATINGPOINT_TO_UBV: return "fp.to_ubv"; + case kind::FLOATINGPOINT_TO_UBV_TOTAL: return "fp.to_ubv_total"; case kind::FLOATINGPOINT_TO_SBV: return "fp.to_sbv"; + case kind::FLOATINGPOINT_TO_SBV_TOTAL: return "fp.to_sbv_total"; case kind::FLOATINGPOINT_TO_REAL: return "fp.to_real"; + case kind::FLOATINGPOINT_TO_REAL_TOTAL: return "fp.to_real_total"; //string theory case kind::STRING_CONCAT: return "str.++"; @@ -1043,6 +1048,14 @@ static void printFpParameterizedOp(std::ostream& out, TNode n) throw() { out << "fp.to_sbv " << n.getOperator().getConst().bvs.size; break; + case kind::FLOATINGPOINT_TO_UBV_TOTAL: + out << "fp.to_ubv_total " + << n.getOperator().getConst().bvs.size; + break; + case kind::FLOATINGPOINT_TO_SBV_TOTAL: + out << "fp.to_sbv_total " + << n.getOperator().getConst().bvs.size; + break; default: out << n.getKind(); } diff --git a/src/theory/fp/kinds b/src/theory/fp/kinds index 61a291b53..144e5736f 100644 --- a/src/theory/fp/kinds +++ b/src/theory/fp/kinds @@ -110,6 +110,12 @@ typerule FLOATINGPOINT_MIN ::CVC4::theory::fp::FloatingPointOperationTypeRule operator FLOATINGPOINT_MAX 2 "floating-point maximum" typerule FLOATINGPOINT_MAX ::CVC4::theory::fp::FloatingPointOperationTypeRule +operator FLOATINGPOINT_MIN_TOTAL 3 "floating-point minimum (defined for all inputs)" +typerule FLOATINGPOINT_MIN_TOTAL ::CVC4::theory::fp::FloatingPointPartialOperationTypeRule + +operator FLOATINGPOINT_MAX_TOTAL 3 "floating-point maximum (defined for all inputs)" +typerule FLOATINGPOINT_MAX_TOTAL ::CVC4::theory::fp::FloatingPointPartialOperationTypeRule + operator FLOATINGPOINT_LEQ 2: "floating-point less than or equal" typerule FLOATINGPOINT_LEQ ::CVC4::theory::fp::FloatingPointTestTypeRule @@ -236,6 +242,17 @@ parameterized FLOATINGPOINT_TO_UBV FLOATINGPOINT_TO_UBV_OP 2 "convert a floating typerule FLOATINGPOINT_TO_UBV ::CVC4::theory::fp::FloatingPointToUBVTypeRule +constant FLOATINGPOINT_TO_UBV_TOTAL_OP \ + ::CVC4::FloatingPointToUBVTotal \ + "::CVC4::FloatingPointToBVHashFunction<0x4>" \ + "util/floatingpoint.h" \ + "operator for to_ubv_total" +typerule FLOATINGPOINT_TO_UBV_TOTAL_OP ::CVC4::theory::fp::FloatingPointParametricOpTypeRule + +parameterized FLOATINGPOINT_TO_UBV_TOTAL FLOATINGPOINT_TO_UBV_TOTAL_OP 3 "convert a floating-point value to an unsigned bit vector (defined for all inputs)" +typerule FLOATINGPOINT_TO_UBV_TOTAL ::CVC4::theory::fp::FloatingPointToUBVTotalTypeRule + + constant FLOATINGPOINT_TO_SBV_OP \ ::CVC4::FloatingPointToSBV \ @@ -248,8 +265,22 @@ parameterized FLOATINGPOINT_TO_SBV FLOATINGPOINT_TO_SBV_OP 2 "convert a floating typerule FLOATINGPOINT_TO_SBV ::CVC4::theory::fp::FloatingPointToSBVTypeRule +constant FLOATINGPOINT_TO_SBV_TOTAL_OP \ + ::CVC4::FloatingPointToSBVTotal \ + "::CVC4::FloatingPointToBVHashFunction<0x8>" \ + "util/floatingpoint.h" \ + "operator for to_sbv_total" +typerule FLOATINGPOINT_TO_SBV_TOTAL_OP ::CVC4::theory::fp::FloatingPointParametricOpTypeRule + +parameterized FLOATINGPOINT_TO_SBV_TOTAL FLOATINGPOINT_TO_SBV_TOTAL_OP 3 "convert a floating-point value to a signed bit vector (defined for all inputs)" +typerule FLOATINGPOINT_TO_SBV_TOTAL ::CVC4::theory::fp::FloatingPointToSBVTotalTypeRule + + operator FLOATINGPOINT_TO_REAL 1 "floating-point to real" typerule FLOATINGPOINT_TO_REAL ::CVC4::theory::fp::FloatingPointToRealTypeRule +operator FLOATINGPOINT_TO_REAL_TOTAL 2 "floating-point to real (defined for all inputs)" +typerule FLOATINGPOINT_TO_REAL_TOTAL ::CVC4::theory::fp::FloatingPointToRealTotalTypeRule + endtheory diff --git a/src/theory/fp/theory_fp_rewriter.cpp b/src/theory/fp/theory_fp_rewriter.cpp index ec42099c2..98ac536ec 100644 --- a/src/theory/fp/theory_fp_rewriter.cpp +++ b/src/theory/fp/theory_fp_rewriter.cpp @@ -179,7 +179,8 @@ namespace rewrite { RewriteResponse compactMinMax (TNode node, bool isPreRewrite) { #ifdef CVC4_ASSERTIONS Kind k = node.getKind(); - Assert((k == kind::FLOATINGPOINT_MIN) || (k == kind::FLOATINGPOINT_MAX)); + Assert((k == kind::FLOATINGPOINT_MIN) || (k == kind::FLOATINGPOINT_MAX) || + (k == kind::FLOATINGPOINT_MIN_TOTAL) || (k == kind::FLOATINGPOINT_MAX_TOTAL)); #endif if (node[0] == node[1]) { return RewriteResponse(REWRITE_DONE, node[0]); @@ -410,6 +411,66 @@ namespace constantFold { } } + RewriteResponse minTotal (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_MIN_TOTAL); + Assert(node.getNumChildren() == 3); + + FloatingPoint arg1(node[0].getConst()); + FloatingPoint arg2(node[1].getConst()); + + Assert(arg1.t == arg2.t); + + // Can be called with the third argument non-constant + if (node[2].getMetaKind() == kind::metakind::CONSTANT) { + BitVector arg3(node[2].getConst()); + + FloatingPoint folded(arg1.minTotal(arg2, arg3.isBitSet(0))); + Node lit = NodeManager::currentNM()->mkConst(folded); + return RewriteResponse(REWRITE_DONE, lit); + + } else { + FloatingPoint::PartialFloatingPoint res(arg1.min(arg2)); + + if (res.second) { + Node lit = NodeManager::currentNM()->mkConst(res.first); + return RewriteResponse(REWRITE_DONE, lit); + } else { + // Can't constant fold the underspecified case + return RewriteResponse(REWRITE_DONE, node); + } + } + } + + RewriteResponse maxTotal (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_MAX_TOTAL); + Assert(node.getNumChildren() == 3); + + FloatingPoint arg1(node[0].getConst()); + FloatingPoint arg2(node[1].getConst()); + + Assert(arg1.t == arg2.t); + + // Can be called with the third argument non-constant + if (node[2].getMetaKind() == kind::metakind::CONSTANT) { + BitVector arg3(node[2].getConst()); + + FloatingPoint folded(arg1.maxTotal(arg2, arg3.isBitSet(0))); + Node lit = NodeManager::currentNM()->mkConst(folded); + return RewriteResponse(REWRITE_DONE, lit); + + } else { + FloatingPoint::PartialFloatingPoint res(arg1.max(arg2)); + + if (res.second) { + Node lit = NodeManager::currentNM()->mkConst(res.first); + return RewriteResponse(REWRITE_DONE, lit); + } else { + // Can't constant fold the underspecified case + return RewriteResponse(REWRITE_DONE, node); + } + } + } + RewriteResponse equal (TNode node, bool isPreRewrite) { Assert(node.getKind() == kind::EQUAL); @@ -644,6 +705,94 @@ namespace constantFold { } } + RewriteResponse convertToUBVTotal (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV_TOTAL); + + TNode op = node.getOperator(); + const FloatingPointToUBVTotal ¶m = op.getConst(); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg(node[1].getConst()); + + // Can be called with the third argument non-constant + if (node[2].getMetaKind() == kind::metakind::CONSTANT) { + BitVector partialValue(node[2].getConst()); + + BitVector folded(arg.convertToBVTotal(param.bvs, rm, false, partialValue)); + Node lit = NodeManager::currentNM()->mkConst(folded); + return RewriteResponse(REWRITE_DONE, lit); + + } else { + FloatingPoint::PartialBitVector res(arg.convertToBV(param.bvs, rm, false)); + + if (res.second) { + Node lit = NodeManager::currentNM()->mkConst(res.first); + return RewriteResponse(REWRITE_DONE, lit); + } else { + // Can't constant fold the underspecified case + return RewriteResponse(REWRITE_DONE, node); + } + } + } + + RewriteResponse convertToSBVTotal (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV_TOTAL); + + TNode op = node.getOperator(); + const FloatingPointToSBVTotal ¶m = op.getConst(); + + RoundingMode rm(node[0].getConst()); + FloatingPoint arg(node[1].getConst()); + + // Can be called with the third argument non-constant + if (node[2].getMetaKind() == kind::metakind::CONSTANT) { + BitVector partialValue(node[2].getConst()); + + BitVector folded(arg.convertToBVTotal(param.bvs, rm, true, partialValue)); + Node lit = NodeManager::currentNM()->mkConst(folded); + return RewriteResponse(REWRITE_DONE, lit); + + } else { + + FloatingPoint::PartialBitVector res(arg.convertToBV(param.bvs, rm, true)); + + if (res.second) { + Node lit = NodeManager::currentNM()->mkConst(res.first); + return RewriteResponse(REWRITE_DONE, lit); + } else { + // Can't constant fold the underspecified case + return RewriteResponse(REWRITE_DONE, node); + } + } + } + + RewriteResponse convertToRealTotal (TNode node, bool) { + Assert(node.getKind() == kind::FLOATINGPOINT_TO_REAL_TOTAL); + + FloatingPoint arg(node[0].getConst()); + + // Can be called with the third argument non-constant + if (node[1].getMetaKind() == kind::metakind::CONSTANT) { + Rational partialValue(node[1].getConst()); + + Rational folded(arg.convertToRationalTotal(partialValue)); + Node lit = NodeManager::currentNM()->mkConst(folded); + return RewriteResponse(REWRITE_DONE, lit); + + } else { + FloatingPoint::PartialRational res(arg.convertToRational()); + + if (res.second) { + Node lit = NodeManager::currentNM()->mkConst(res.first); + return RewriteResponse(REWRITE_DONE, lit); + } else { + // Can't constant fold the underspecified case + return RewriteResponse(REWRITE_DONE, node); + } + } + } + + }; /* CVC4::theory::fp::constantFold */ @@ -686,6 +835,8 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; preRewriteTable[kind::FLOATINGPOINT_RTI] = rewrite::identity; preRewriteTable[kind::FLOATINGPOINT_MIN] = rewrite::compactMinMax; preRewriteTable[kind::FLOATINGPOINT_MAX] = rewrite::compactMinMax; + preRewriteTable[kind::FLOATINGPOINT_MIN_TOTAL] = rewrite::compactMinMax; + preRewriteTable[kind::FLOATINGPOINT_MAX_TOTAL] = rewrite::compactMinMax; /******** Comparisons ********/ preRewriteTable[kind::FLOATINGPOINT_EQ] = rewrite::then; @@ -713,6 +864,9 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; preRewriteTable[kind::FLOATINGPOINT_TO_UBV] = rewrite::identity; preRewriteTable[kind::FLOATINGPOINT_TO_SBV] = rewrite::identity; preRewriteTable[kind::FLOATINGPOINT_TO_REAL] = rewrite::identity; + preRewriteTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = rewrite::identity; + preRewriteTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = rewrite::identity; + preRewriteTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = rewrite::identity; /******** Variables ********/ preRewriteTable[kind::VARIABLE] = rewrite::variable; @@ -753,6 +907,8 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; postRewriteTable[kind::FLOATINGPOINT_RTI] = rewrite::identity; postRewriteTable[kind::FLOATINGPOINT_MIN] = rewrite::compactMinMax; postRewriteTable[kind::FLOATINGPOINT_MAX] = rewrite::compactMinMax; + postRewriteTable[kind::FLOATINGPOINT_MIN_TOTAL] = rewrite::compactMinMax; + postRewriteTable[kind::FLOATINGPOINT_MAX_TOTAL] = rewrite::compactMinMax; /******** Comparisons ********/ postRewriteTable[kind::FLOATINGPOINT_EQ] = rewrite::removed; @@ -780,6 +936,9 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; postRewriteTable[kind::FLOATINGPOINT_TO_UBV] = rewrite::identity; postRewriteTable[kind::FLOATINGPOINT_TO_SBV] = rewrite::identity; postRewriteTable[kind::FLOATINGPOINT_TO_REAL] = rewrite::identity; + postRewriteTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = rewrite::identity; + postRewriteTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = rewrite::identity; + postRewriteTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = rewrite::identity; /******** Variables ********/ postRewriteTable[kind::VARIABLE] = rewrite::variable; @@ -823,6 +982,8 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; constantFoldTable[kind::FLOATINGPOINT_RTI] = constantFold::rti; constantFoldTable[kind::FLOATINGPOINT_MIN] = constantFold::min; constantFoldTable[kind::FLOATINGPOINT_MAX] = constantFold::max; + constantFoldTable[kind::FLOATINGPOINT_MIN_TOTAL] = constantFold::minTotal; + constantFoldTable[kind::FLOATINGPOINT_MAX_TOTAL] = constantFold::maxTotal; /******** Comparisons ********/ constantFoldTable[kind::FLOATINGPOINT_EQ] = rewrite::removed; @@ -850,6 +1011,9 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; constantFoldTable[kind::FLOATINGPOINT_TO_UBV] = constantFold::convertToUBV; constantFoldTable[kind::FLOATINGPOINT_TO_SBV] = constantFold::convertToSBV; constantFoldTable[kind::FLOATINGPOINT_TO_REAL] = constantFold::convertToReal; + constantFoldTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = constantFold::convertToUBVTotal; + constantFoldTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = constantFold::convertToSBVTotal; + constantFoldTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = constantFold::convertToRealTotal; /******** Variables ********/ constantFoldTable[kind::VARIABLE] = rewrite::variable; @@ -918,6 +1082,7 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; if (res.status == REWRITE_DONE) { bool allChildrenConst = true; bool apartFromRoundingMode = false; + bool apartFromPartiallyDefinedArgument = false; for (Node::const_iterator i = res.node.begin(); i != res.node.end(); ++i) { @@ -925,6 +1090,15 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; if ((*i).getMetaKind() != kind::metakind::CONSTANT) { if ((*i).getType().isRoundingMode() && !apartFromRoundingMode) { apartFromRoundingMode = true; + } else if ((res.node.getKind() == kind::FLOATINGPOINT_MIN_TOTAL || + res.node.getKind() == kind::FLOATINGPOINT_MAX_TOTAL || + res.node.getKind() == kind::FLOATINGPOINT_TO_UBV_TOTAL || + res.node.getKind() == kind::FLOATINGPOINT_TO_SBV_TOTAL || + res.node.getKind() == kind::FLOATINGPOINT_TO_REAL_TOTAL) && + ((*i).getType().isBitVector() || + (*i).getType().isReal()) && + !apartFromPartiallyDefinedArgument) { + apartFromPartiallyDefinedArgument = true; } else { allChildrenConst = false; break; diff --git a/src/theory/fp/theory_fp_type_rules.h b/src/theory/fp/theory_fp_type_rules.h index fe39993d4..aa213f84f 100644 --- a/src/theory/fp/theory_fp_type_rules.h +++ b/src/theory/fp/theory_fp_type_rules.h @@ -189,6 +189,42 @@ class FloatingPointRoundingOperationTypeRule { } }; +class FloatingPointPartialOperationTypeRule { + public: + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, + bool check) { + TRACE("FloatingPointOperationTypeRule"); + + TypeNode firstOperand = n[0].getType(check); + + if (check) { + if (!firstOperand.isFloatingPoint()) { + throw TypeCheckingExceptionPrivate( + n, "floating-point operation applied to a non floating-point sort"); + } + + size_t children = n.getNumChildren(); + for (size_t i = 1; i < children - 1; ++i) { + if (!(n[i].getType(check) == firstOperand)) { + throw TypeCheckingExceptionPrivate( + n, "floating-point partial operation applied to mixed sorts"); + } + } + + TypeNode UFValueType = n[children - 1].getType(check); + + if (!(UFValueType.isBitVector()) || + !(UFValueType.getBitVectorSize() == 1)) { + throw TypeCheckingExceptionPrivate( + n, "floating-point partial operation final argument must be a bit-vector of length 1"); + } + } + + return firstOperand; + } +}; + + class FloatingPointParametricOpTypeRule { public: inline static TypeNode computeType(NodeManager* nodeManager, TNode n, @@ -438,6 +474,86 @@ class FloatingPointToSBVTypeRule { } }; +class FloatingPointToUBVTotalTypeRule { + public: + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, + bool check) { + TRACE("FloatingPointToUBVTotalTypeRule"); + + FloatingPointToUBVTotal info = n.getOperator().getConst(); + + if (check) { + TypeNode roundingModeType = n[0].getType(check); + + if (!roundingModeType.isRoundingMode()) { + throw TypeCheckingExceptionPrivate( + n, "first argument must be a rounding mode"); + } + + TypeNode operandType = n[1].getType(check); + + if (!(operandType.isFloatingPoint())) { + throw TypeCheckingExceptionPrivate(n, + "conversion to unsigned bit vector total" + "used with a sort other than " + "floating-point"); + } + + TypeNode defaultValueType = n[2].getType(check); + + if (!(defaultValueType.isBitVector()) || + !(defaultValueType.getBitVectorSize() == info)) { + throw TypeCheckingExceptionPrivate(n, + "conversion to unsigned bit vector total" + "needs a bit vector of the same length" + "as last argument"); + } + } + + return nodeManager->mkBitVectorType(info.bvs); + } +}; + +class FloatingPointToSBVTotalTypeRule { + public: + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, + bool check) { + TRACE("FloatingPointToSBVTotalTypeRule"); + + FloatingPointToSBVTotal info = n.getOperator().getConst(); + + if (check) { + TypeNode roundingModeType = n[0].getType(check); + + if (!roundingModeType.isRoundingMode()) { + throw TypeCheckingExceptionPrivate( + n, "first argument must be a rounding mode"); + } + + TypeNode operandType = n[1].getType(check); + + if (!(operandType.isFloatingPoint())) { + throw TypeCheckingExceptionPrivate(n, + "conversion to signed bit vector " + "used with a sort other than " + "floating-point"); + } + + TypeNode defaultValueType = n[2].getType(check); + + if (!(defaultValueType.isBitVector()) || + !(defaultValueType.getBitVectorSize() == info)) { + throw TypeCheckingExceptionPrivate(n, + "conversion to signed bit vector total" + "needs a bit vector of the same length" + "as last argument"); + } + } + + return nodeManager->mkBitVectorType(info.bvs); + } +}; + class FloatingPointToRealTypeRule { public: inline static TypeNode computeType(NodeManager* nodeManager, TNode n, @@ -457,6 +573,33 @@ class FloatingPointToRealTypeRule { } }; +class FloatingPointToRealTotalTypeRule { + public: + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, + bool check) { + TRACE("FloatingPointToRealTotalTypeRule"); + + if (check) { + TypeNode operandType = n[0].getType(check); + + if (!operandType.isFloatingPoint()) { + throw TypeCheckingExceptionPrivate( + n, "floating-point to real total applied to a non floating-point sort"); + } + + TypeNode defaultValueType = n[1].getType(check); + + if (!defaultValueType.isReal()) { + throw TypeCheckingExceptionPrivate( + n, "floating-point to real total needs a real second argument"); + } + + } + + return nodeManager->realType(); + } +}; + class CardinalityComputer { public: