FP: Rewrite to_fp conversion from signed bit-vector. (#6472)
authorAina Niemetz <aina.niemetz@gmail.com>
Mon, 3 May 2021 20:27:02 +0000 (13:27 -0700)
committerGitHub <noreply@github.com>
Mon, 3 May 2021 20:27:02 +0000 (20:27 +0000)
SymFPU does not allow to_fp conversion from signed bv of size 1. This
adds rewrites for this case.

Rewrites for the constant and the non-constant cases were tested in
isolation.

src/theory/fp/theory_fp_rewriter.cpp
test/regress/CMakeLists.txt
test/regress/regress0/fp/from_sbv.smt2 [new file with mode: 0644]

index e431ffa09948629155a18439f4868e06eae293c5..74e1ff52616583a488e3f86bc1053119ac778a8c 100644 (file)
  *       - Samuel Figuer results
  */
 
+#include "theory/fp/theory_fp_rewriter.h"
+
 #include <algorithm>
 
 #include "base/check.h"
+#include "theory/bv/theory_bv_utils.h"
 #include "theory/fp/fp_converter.h"
-#include "theory/fp/theory_fp_rewriter.h"
 
 namespace cvc5 {
 namespace theory {
@@ -333,6 +335,28 @@ namespace rewrite {
     return RewriteResponse(REWRITE_DONE, node);
   }
 
+  RewriteResponse toFPSignedBV(TNode node, bool isPreRewrite)
+  {
+    Assert(!isPreRewrite);
+    Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR);
+
+    /* symFPU does not allow conversions from signed bit-vector of size 1 */
+    if (node[1].getType().getBitVectorSize() == 1)
+    {
+      NodeManager* nm = NodeManager::currentNM();
+      Node op = nm->mkConst(FloatingPointToFPUnsignedBitVector(
+          node.getOperator().getConst<FloatingPointToFPSignedBitVector>()));
+      Node fromubv = nm->mkNode(op, node[0], node[1]);
+      return RewriteResponse(
+          REWRITE_AGAIN_FULL,
+          nm->mkNode(kind::ITE,
+                     node[1].eqNode(bv::utils::mkOne(1)),
+                     nm->mkNode(kind::FLOATINGPOINT_NEG, fromubv),
+                     fromubv));
+    }
+    return RewriteResponse(REWRITE_DONE, node);
+  }
+
   };  // namespace rewrite
 
 namespace constantFold {
@@ -736,15 +760,16 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL);
 
     TNode op = node.getOperator();
-    const FloatingPointToFPReal &param = op.getConst<FloatingPointToFPReal>();
+    const FloatingPointSize& size =
+        op.getConst<FloatingPointToFPReal>().getSize();
 
     RoundingMode rm(node[0].getConst<RoundingMode>());
     Rational arg(node[1].getConst<Rational>());
 
-    FloatingPoint res(param.getSize(), rm, arg);
+    FloatingPoint res(size, rm, arg);
 
     Node lit = NodeManager::currentNM()->mkConst(res);
-    
+
     return RewriteResponse(REWRITE_DONE, lit);
   }
 
@@ -753,16 +778,27 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR);
 
     TNode op = node.getOperator();
-    const FloatingPointToFPSignedBitVector &param = op.getConst<FloatingPointToFPSignedBitVector>();
+    const FloatingPointSize& size =
+        op.getConst<FloatingPointToFPSignedBitVector>().getSize();
 
     RoundingMode rm(node[0].getConst<RoundingMode>());
-    BitVector arg(node[1].getConst<BitVector>());
+    BitVector sbv(node[1].getConst<BitVector>());
 
-    FloatingPoint res(param.getSize(), rm, arg, true);
+    NodeManager* nm = NodeManager::currentNM();
 
-    Node lit = NodeManager::currentNM()->mkConst(res);
-    
-    return RewriteResponse(REWRITE_DONE, lit);
+    /* symFPU does not allow conversions from signed bit-vector of size 1 */
+    if (sbv.getSize() == 1)
+    {
+      FloatingPoint fromubv(size, rm, sbv, false);
+      if (sbv.isBitSet(0))
+      {
+        return RewriteResponse(REWRITE_DONE, nm->mkConst(fromubv.negate()));
+      }
+      return RewriteResponse(REWRITE_DONE, nm->mkConst(fromubv));
+    }
+
+    return RewriteResponse(REWRITE_DONE,
+                           nm->mkConst(FloatingPoint(size, rm, sbv, true)));
   }
 
   RewriteResponse convertFromUBV(TNode node, bool isPreRewrite)
@@ -770,15 +806,16 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR);
 
     TNode op = node.getOperator();
-    const FloatingPointToFPUnsignedBitVector &param = op.getConst<FloatingPointToFPUnsignedBitVector>();
+    const FloatingPointSize& size =
+        op.getConst<FloatingPointToFPUnsignedBitVector>().getSize();
 
     RoundingMode rm(node[0].getConst<RoundingMode>());
     BitVector arg(node[1].getConst<BitVector>());
 
-    FloatingPoint res(param.getSize(), rm, arg, false);
+    FloatingPoint res(size, rm, arg, false);
 
     Node lit = NodeManager::currentNM()->mkConst(res);
-    
+
     return RewriteResponse(REWRITE_DONE, lit);
   }
 
@@ -787,13 +824,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV);
 
     TNode op = node.getOperator();
-    const FloatingPointToUBV &param = op.getConst<FloatingPointToUBV>();
+    const BitVectorSize& size = op.getConst<FloatingPointToUBV>().d_bv_size;
 
     RoundingMode rm(node[0].getConst<RoundingMode>());
     FloatingPoint arg(node[1].getConst<FloatingPoint>());
 
-    FloatingPoint::PartialBitVector res(
-        arg.convertToBV(param.d_bv_size, rm, false));
+    FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, false));
 
     if (res.second) {
       Node lit = NodeManager::currentNM()->mkConst(res.first);
@@ -809,13 +845,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV);
 
     TNode op = node.getOperator();
-    const FloatingPointToSBV &param = op.getConst<FloatingPointToSBV>();
+    const BitVectorSize& size = op.getConst<FloatingPointToSBV>().d_bv_size;
 
     RoundingMode rm(node[0].getConst<RoundingMode>());
     FloatingPoint arg(node[1].getConst<FloatingPoint>());
 
-    FloatingPoint::PartialBitVector res(
-        arg.convertToBV(param.d_bv_size, rm, true));
+    FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, true));
 
     if (res.second) {
       Node lit = NodeManager::currentNM()->mkConst(res.first);
@@ -848,7 +883,8 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV_TOTAL);
 
     TNode op = node.getOperator();
-    const FloatingPointToUBVTotal &param = op.getConst<FloatingPointToUBVTotal>();
+    const BitVectorSize& size =
+        op.getConst<FloatingPointToUBVTotal>().d_bv_size;
 
     RoundingMode rm(node[0].getConst<RoundingMode>());
     FloatingPoint arg(node[1].getConst<FloatingPoint>());
@@ -857,14 +893,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
     if (node[2].getMetaKind() == kind::metakind::CONSTANT) {
       BitVector partialValue(node[2].getConst<BitVector>());
 
-      BitVector folded(
-          arg.convertToBVTotal(param.d_bv_size, rm, false, partialValue));
+      BitVector folded(arg.convertToBVTotal(size, rm, false, partialValue));
       Node lit = NodeManager::currentNM()->mkConst(folded);
       return RewriteResponse(REWRITE_DONE, lit);
 
     } else {
-      FloatingPoint::PartialBitVector res(
-          arg.convertToBV(param.d_bv_size, rm, false));
+      FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, false));
 
       if (res.second) {
        Node lit = NodeManager::currentNM()->mkConst(res.first);
@@ -881,7 +915,8 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
     Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV_TOTAL);
 
     TNode op = node.getOperator();
-    const FloatingPointToSBVTotal &param = op.getConst<FloatingPointToSBVTotal>();
+    const BitVectorSize& size =
+        op.getConst<FloatingPointToSBVTotal>().d_bv_size;
 
     RoundingMode rm(node[0].getConst<RoundingMode>());
     FloatingPoint arg(node[1].getConst<FloatingPoint>());
@@ -890,14 +925,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
     if (node[2].getMetaKind() == kind::metakind::CONSTANT) {
       BitVector partialValue(node[2].getConst<BitVector>());
 
-      BitVector folded(
-          arg.convertToBVTotal(param.d_bv_size, rm, true, partialValue));
+      BitVector folded(arg.convertToBVTotal(size, rm, true, partialValue));
       Node lit = NodeManager::currentNM()->mkConst(folded);
       return RewriteResponse(REWRITE_DONE, lit);
 
     } else {
-      FloatingPoint::PartialBitVector res(
-          arg.convertToBV(param.d_bv_size, rm, true));
+      FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, true));
 
       if (res.second) {
        Node lit = NodeManager::currentNM()->mkConst(res.first);
@@ -1049,7 +1082,7 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
 TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u)
 {
   /* Set up the pre-rewrite dispatch table */
-  for (unsigned i = 0; i < kind::LAST_KIND; ++i)
+  for (uint32_t i = 0; i < kind::LAST_KIND; ++i)
   {
     d_preRewriteTable[i] = rewrite::notFP;
   }
@@ -1140,7 +1173,7 @@ TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u)
   d_preRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity;
 
   /* Set up the post-rewrite dispatch table */
-  for (unsigned i = 0; i < kind::LAST_KIND; ++i)
+  for (uint32_t i = 0; i < kind::LAST_KIND; ++i)
   {
     d_postRewriteTable[i] = rewrite::notFP;
   }
@@ -1197,7 +1230,7 @@ TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u)
       rewrite::identity;
   d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity;
   d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] =
-      rewrite::identity;
+      rewrite::toFPSignedBV;
   d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] =
       rewrite::identity;
   d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::identity;
@@ -1228,7 +1261,7 @@ TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u)
   d_postRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity;
 
   /* Set up the post-rewrite constant fold table */
-  for (unsigned i = 0; i < kind::LAST_KIND; ++i)
+  for (uint32_t i = 0; i < kind::LAST_KIND; ++i)
   {
     // Note that this is identity, not notFP
     // Constant folding is called after post-rewrite
index 765b12e90c474946d04f6c0450eb932fd75925cc..685ac8b4ee431afd6113c2d0c7adbb934f5f6cac 100644 (file)
@@ -574,6 +574,7 @@ set(regress_0_tests
   regress0/fp/down-cast-RNA.smt2
   regress0/fp/ext-rew-test.smt2
   regress0/fp/from_ubv.smt2
+  regress0/fp/from_sbv.smt2
   regress0/fp/issue-5524.smt2
   regress0/fp/issue3536.smt2
   regress0/fp/issue3582.smt2
diff --git a/test/regress/regress0/fp/from_sbv.smt2 b/test/regress/regress0/fp/from_sbv.smt2
new file mode 100644 (file)
index 0000000..226d658
--- /dev/null
@@ -0,0 +1,15 @@
+; COMMAND-LINE: --fp-exp
+; EXPECT: unsat
+(set-logic QF_BVFP)
+(declare-const x (_ BitVec 1))
+(declare-const rm RoundingMode)
+(assert (or
+  (distinct ((_ to_fp 5 11) rm #b1) (fp #b1 #b01111 #b0000000000))
+  (distinct ((_ to_fp 5 11) rm #b0) (_ +zero 5 11))
+  (ite
+     (= x #b1)
+     (= ((_ to_fp 5 11) rm x) ((_ to_fp_unsigned 5 11) rm x))
+     (distinct ((_ to_fp 5 11) rm x) ((_ to_fp_unsigned 5 11) rm x))
+  )
+  ))
+(check-sat)