[API] Add `{is,get}RoundingModeValue()` (#8429)
authorAndres Noetzli <andres.noetzli@gmail.com>
Tue, 29 Mar 2022 14:43:02 +0000 (07:43 -0700)
committerGitHub <noreply@github.com>
Tue, 29 Mar 2022 14:43:02 +0000 (14:43 +0000)
It also fixes a wrong entry in s_rmodes_internal.

src/api/cpp/cvc5.cpp
src/api/cpp/cvc5.h
src/api/java/io/github/cvc5/api/Term.java
src/api/java/jni/term.cpp
src/api/python/cvc5.pxd
src/api/python/cvc5.pxi
test/unit/api/cpp/term_black.cpp
test/unit/api/java/TermTest.java
test/unit/api/python/test_term.py

index 4879779632856e597c3944540f8e3749964535cd..4f92c98c6c90fff0b978449eefa847e36661034e 100644 (file)
@@ -719,6 +719,31 @@ const static std::unordered_set<Kind> s_indexed_kinds(
      FLOATINGPOINT_TO_FP_FROM_SBV,
      FLOATINGPOINT_TO_FP_FROM_UBV});
 
+/* -------------------------------------------------------------------------- */
+/* Rounding Mode for Floating Points                                          */
+/* -------------------------------------------------------------------------- */
+
+const static std::unordered_map<RoundingMode, cvc5::RoundingMode> s_rmodes{
+    {ROUND_NEAREST_TIES_TO_EVEN,
+     cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_EVEN},
+    {ROUND_TOWARD_POSITIVE, cvc5::RoundingMode::ROUND_TOWARD_POSITIVE},
+    {ROUND_TOWARD_NEGATIVE, cvc5::RoundingMode::ROUND_TOWARD_NEGATIVE},
+    {ROUND_TOWARD_ZERO, cvc5::RoundingMode::ROUND_TOWARD_ZERO},
+    {ROUND_NEAREST_TIES_TO_AWAY,
+     cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_AWAY},
+};
+
+const static std::unordered_map<cvc5::RoundingMode, RoundingMode>
+    s_rmodes_internal{
+        {cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_EVEN,
+         ROUND_NEAREST_TIES_TO_EVEN},
+        {cvc5::RoundingMode::ROUND_TOWARD_POSITIVE, ROUND_TOWARD_POSITIVE},
+        {cvc5::RoundingMode::ROUND_TOWARD_NEGATIVE, ROUND_TOWARD_NEGATIVE},
+        {cvc5::RoundingMode::ROUND_TOWARD_ZERO, ROUND_TOWARD_ZERO},
+        {cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_AWAY,
+         ROUND_NEAREST_TIES_TO_AWAY},
+    };
+
 namespace {
 
 /** Convert a cvc5::Kind (internal) to a cvc5::api::Kind (external). */
@@ -3009,6 +3034,29 @@ std::vector<Term> Term::getTupleValue() const
   CVC5_API_TRY_CATCH_END;
 }
 
+bool Term::isRoundingModeValue() const
+{
+  CVC5_API_TRY_CATCH_BEGIN;
+  CVC5_API_CHECK_NOT_NULL;
+  //////// all checks before this line
+  return d_node->getKind() == cvc5::Kind::CONST_ROUNDINGMODE;
+  ////////
+  CVC5_API_TRY_CATCH_END;
+}
+RoundingMode Term::getRoundingModeValue() const
+{
+  CVC5_API_TRY_CATCH_BEGIN;
+  CVC5_API_CHECK_NOT_NULL;
+  CVC5_API_ARG_CHECK_EXPECTED(
+      d_node->getKind() == cvc5::Kind::CONST_ROUNDINGMODE, *d_node)
+      << "Term to be a floating-point rounding mode value when calling "
+         "getRoundingModeValue()";
+  //////// all checks before this line
+  return s_rmodes_internal.at(d_node->getConst<cvc5::RoundingMode>());
+  ////////
+  CVC5_API_TRY_CATCH_END;
+}
+
 bool Term::isFloatingPointPosZero() const
 {
   CVC5_API_TRY_CATCH_BEGIN;
@@ -4543,33 +4591,6 @@ std::ostream& operator<<(std::ostream& out, const Grammar& grammar)
   return out << grammar.toString();
 }
 
-/* -------------------------------------------------------------------------- */
-/* Rounding Mode for Floating Points                                          */
-/* -------------------------------------------------------------------------- */
-
-const static std::unordered_map<RoundingMode, cvc5::RoundingMode> s_rmodes{
-    {ROUND_NEAREST_TIES_TO_EVEN,
-     cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_EVEN},
-    {ROUND_TOWARD_POSITIVE, cvc5::RoundingMode::ROUND_TOWARD_POSITIVE},
-    {ROUND_TOWARD_NEGATIVE, cvc5::RoundingMode::ROUND_TOWARD_NEGATIVE},
-    {ROUND_TOWARD_ZERO, cvc5::RoundingMode::ROUND_TOWARD_ZERO},
-    {ROUND_NEAREST_TIES_TO_AWAY,
-     cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_AWAY},
-};
-
-const static std::unordered_map<cvc5::RoundingMode,
-                                RoundingMode,
-                                cvc5::RoundingModeHashFunction>
-    s_rmodes_internal{
-        {cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_EVEN,
-         ROUND_NEAREST_TIES_TO_EVEN},
-        {cvc5::RoundingMode::ROUND_TOWARD_POSITIVE, ROUND_TOWARD_POSITIVE},
-        {cvc5::RoundingMode::ROUND_TOWARD_POSITIVE, ROUND_TOWARD_NEGATIVE},
-        {cvc5::RoundingMode::ROUND_TOWARD_ZERO, ROUND_TOWARD_ZERO},
-        {cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_AWAY,
-         ROUND_NEAREST_TIES_TO_AWAY},
-    };
-
 /* -------------------------------------------------------------------------- */
 /* Options                                                                    */
 /* -------------------------------------------------------------------------- */
index 18f66dfa8d3c9ed9af07534629381497fce1e46e..3fd0dad4a557fddb5c2f733f032f512f71c0880f 100644 (file)
@@ -1526,6 +1526,16 @@ class CVC5_EXPORT Term
    */
   std::vector<Term> getTupleValue() const;
 
+  /**
+   * @return true if the term is a floating-point rounding mode value.
+   */
+  bool isRoundingModeValue() const;
+  /**
+   * Asserts isRoundingModeValue().
+   * @return the floating-point rounding mode value held by the term.
+   */
+  RoundingMode getRoundingModeValue() const;
+
   /**
    * @return true if the term is the floating-point value for positive zero.
    */
index 7fddc49be15dd525f71ca4969c14a18d6a8ab9a2..ba95d3458bf340f248020138dda770dcd6061623 100644 (file)
@@ -522,6 +522,28 @@ public class Term extends AbstractPointer implements Comparable<Term>, Iterable<
 
   private native String getUninterpretedSortValue(long pointer);
 
+  /**
+   * @return true if the term is a floating-point rounding mode value.
+   */
+  public boolean isRoundingModeValue()
+  {
+    return isRoundingModeValue(pointer);
+  }
+
+  private native boolean isRoundingModeValue(long pointer);
+
+  /**
+   * Asserts isRoundingModeValue().
+   * @return the floating-point rounding mode value held by the term.
+   */
+  public RoundingMode getRoundingModeValue() throws CVC5ApiException
+  {
+    int value = getRoundingModeValue(pointer);
+    return RoundingMode.fromInt(value);
+  }
+
+  private native int getRoundingModeValue(long pointer);
+
   /**
    * @return true if the term is a tuple value.
    */
index c45bb4d6a4f9130821e83f0ae35822738e15388b..a0f5f1bc4eeb01176e861be0d70d6f34a5480720 100644 (file)
@@ -663,6 +663,34 @@ Java_io_github_cvc5_api_Term_getUninterpretedSortValue(JNIEnv* env,
   CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, nullptr);
 }
 
+/*
+ * Class:     io_github_cvc5_api_Term
+ * Method:    isRoundingModeValue
+ * Signature: (J)Z
+ */
+JNIEXPORT jboolean JNICALL Java_io_github_cvc5_api_Term_isRoundingModeValue(
+    JNIEnv* env, jobject, jlong pointer)
+{
+  CVC5_JAVA_API_TRY_CATCH_BEGIN;
+  Term* current = reinterpret_cast<Term*>(pointer);
+  return static_cast<jboolean>(current->isRoundingModeValue());
+  CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, static_cast<jboolean>(false));
+}
+
+/*
+ * Class:     io_github_cvc5_api_Term
+ * Method:    getRoundingModeValue
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_io_github_cvc5_api_Term_getRoundingModeValue(
+    JNIEnv* env, jobject, jlong pointer)
+{
+  CVC5_JAVA_API_TRY_CATCH_BEGIN;
+  Term* current = reinterpret_cast<Term*>(pointer);
+  return static_cast<jint>(current->getRoundingModeValue());
+  CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, 0);
+}
+
 /*
  * Class:     io_github_cvc5_api_Term
  * Method:    isTupleValue
index f03941c27d9af437eb30d78d1da3459fe2edbf67..ce5dd100c801de358d2e89098da6872414cee6cf 100644 (file)
@@ -505,6 +505,11 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         string getBitVectorValue(uint32_t base) except +
         bint isUninterpretedSortValue() except +
         string getUninterpretedSortValue() except +
+        bint isTupleValue() except +
+        vector[Term] getTupleValue() except +
+        bint isRoundingModeValue() except +
+        RoundingMode getRoundingModeValue() except +
+
         bint isFloatingPointPosZero() except +
         bint isFloatingPointNegZero() except +
         bint isFloatingPointPosInf() except +
@@ -517,8 +522,6 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         set[Term] getSetValue() except +
         bint isSequenceValue() except +
         vector[Term] getSequenceValue() except +
-        bint isTupleValue() except +
-        vector[Term] getTupleValue() except +
 
 
     cdef cppclass TermHashFunction:
index 23fc5fa6948fe35cf709347affba5991a840d5fd..a378434ab3dd3c4ef471a912614c6409d135a053 100644 (file)
@@ -3531,6 +3531,17 @@ cdef class Term:
         """:return: True iff this term is a tuple value."""
         return self.cterm.isTupleValue()
 
+    def isRoundingModeValue(self):
+        """:return: True if the term is a floating-point rounding mode value."""
+        return self.cterm.isRoundingModeValue()
+
+    def getRoundingModeValue(self):
+        """
+        Asserts isRoundingModeValue().
+        :return: the floating-point rounding mode value held by the term.
+        """
+        return RoundingMode(<int> self.cterm.getRoundingModeValue())
+
     def getTupleValue(self):
         """
           Asserts :py:meth:`isTupleValue()`.
index 81dfd2e43cefa3a8139ce70051e4e13cb9a9cb50..070ccacab3eb5a49714eb8158c8395873176c4fc 100644 (file)
@@ -920,6 +920,35 @@ TEST_F(TestApiBlackTerm, getUninterpretedSortValue)
   ASSERT_EQ(vx.getUninterpretedSortValue(), vy.getUninterpretedSortValue());
 }
 
+TEST_F(TestApiBlackTerm, isRoundingModeValue)
+{
+  ASSERT_FALSE(d_solver.mkInteger(15).isRoundingModeValue());
+  ASSERT_TRUE(d_solver.mkRoundingMode(RoundingMode::ROUND_NEAREST_TIES_TO_EVEN)
+                  .isRoundingModeValue());
+  ASSERT_FALSE(
+      d_solver.mkConst(d_solver.getRoundingModeSort()).isRoundingModeValue());
+}
+
+TEST_F(TestApiBlackTerm, getRoundingModeValue)
+{
+  ASSERT_THROW(d_solver.mkInteger(15).getRoundingModeValue(), CVC5ApiException);
+  ASSERT_EQ(d_solver.mkRoundingMode(RoundingMode::ROUND_NEAREST_TIES_TO_EVEN)
+                .getRoundingModeValue(),
+            RoundingMode::ROUND_NEAREST_TIES_TO_EVEN);
+  ASSERT_EQ(d_solver.mkRoundingMode(RoundingMode::ROUND_TOWARD_POSITIVE)
+                .getRoundingModeValue(),
+            RoundingMode::ROUND_TOWARD_POSITIVE);
+  ASSERT_EQ(d_solver.mkRoundingMode(RoundingMode::ROUND_TOWARD_NEGATIVE)
+                .getRoundingModeValue(),
+            RoundingMode::ROUND_TOWARD_NEGATIVE);
+  ASSERT_EQ(d_solver.mkRoundingMode(RoundingMode::ROUND_TOWARD_ZERO)
+                .getRoundingModeValue(),
+            RoundingMode::ROUND_TOWARD_ZERO);
+  ASSERT_EQ(d_solver.mkRoundingMode(RoundingMode::ROUND_NEAREST_TIES_TO_AWAY)
+                .getRoundingModeValue(),
+            RoundingMode::ROUND_NEAREST_TIES_TO_AWAY);
+}
+
 TEST_F(TestApiBlackTerm, getTuple)
 {
   Sort s1 = d_solver.getIntegerSort();
index 9baa290af7492966eee4e2e2c981afbf9eb85bea..e9875d362bfe84eeb1994356e587a2685805ce6b 100644 (file)
@@ -870,6 +870,31 @@ class TermTest
     assertDoesNotThrow(() -> vy.getUninterpretedSortValue());
   }
 
+  @Test void isRoundingModeValue() throws CVC5ApiException
+  {
+    assertFalse(d_solver.mkInteger(15).isRoundingModeValue());
+    assertTrue(
+        d_solver.mkRoundingMode(RoundingMode.ROUND_NEAREST_TIES_TO_EVEN).isRoundingModeValue());
+    assertFalse(d_solver.mkConst(d_solver.getRoundingModeSort()).isRoundingModeValue());
+  }
+
+  @Test void getRoundingModeValue() throws CVC5ApiException
+  {
+    assertThrows(CVC5ApiException.class, () -> d_solver.mkInteger(15).getRoundingModeValue());
+    assertEquals(
+        d_solver.mkRoundingMode(RoundingMode.ROUND_NEAREST_TIES_TO_EVEN).getRoundingModeValue(),
+        RoundingMode.ROUND_NEAREST_TIES_TO_EVEN);
+    assertEquals(d_solver.mkRoundingMode(RoundingMode.ROUND_TOWARD_POSITIVE).getRoundingModeValue(),
+        RoundingMode.ROUND_TOWARD_POSITIVE);
+    assertEquals(d_solver.mkRoundingMode(RoundingMode.ROUND_TOWARD_NEGATIVE).getRoundingModeValue(),
+        RoundingMode.ROUND_TOWARD_NEGATIVE);
+    assertEquals(d_solver.mkRoundingMode(RoundingMode.ROUND_TOWARD_ZERO).getRoundingModeValue(),
+        RoundingMode.ROUND_TOWARD_ZERO);
+    assertEquals(
+        d_solver.mkRoundingMode(RoundingMode.ROUND_NEAREST_TIES_TO_AWAY).getRoundingModeValue(),
+        RoundingMode.ROUND_NEAREST_TIES_TO_AWAY);
+  }
+
   @Test void getTuple()
   {
     Sort s1 = d_solver.getIntegerSort();
index 27d1d913e3cddf743fabaa97389f6580172c15dc..16f3fa180afe773e0c17d4c89f563a73582dd67e 100644 (file)
@@ -13,7 +13,7 @@
 
 import pytest
 import cvc5
-from cvc5 import Kind
+from cvc5 import Kind, RoundingMode
 from cvc5 import Sort, Term
 from fractions import Fraction
 
@@ -972,6 +972,34 @@ def test_get_uninterpreted_sort_value(solver):
     assert vx.getUninterpretedSortValue() == vy.getUninterpretedSortValue()
 
 
+def test_is_rounding_mode_value(solver):
+    assert not solver.mkInteger(15).isRoundingModeValue()
+    assert solver.mkRoundingMode(
+        RoundingMode.RoundNearestTiesToEven).isRoundingModeValue()
+    assert not solver.mkConst(
+        solver.getRoundingModeSort()).isRoundingModeValue()
+
+
+def test_get_rounding_mode_value(solver):
+    with pytest.raises(RuntimeError):
+        solver.mkInteger(15).getRoundingModeValue()
+    assert solver.mkRoundingMode(
+        RoundingMode.RoundNearestTiesToEven).getRoundingModeValue(
+        ) == RoundingMode.RoundNearestTiesToEven
+    assert solver.mkRoundingMode(
+        RoundingMode.RoundTowardPositive).getRoundingModeValue(
+        ) == RoundingMode.RoundTowardPositive
+    assert solver.mkRoundingMode(
+        RoundingMode.RoundTowardNegative).getRoundingModeValue(
+        ) == RoundingMode.RoundTowardNegative
+    assert solver.mkRoundingMode(
+        RoundingMode.RoundTowardZero).getRoundingModeValue(
+        ) == RoundingMode.RoundTowardZero
+    assert solver.mkRoundingMode(
+        RoundingMode.RoundNearestTiesToAway).getRoundingModeValue(
+        ) == RoundingMode.RoundNearestTiesToAway
+
+
 def test_get_tuple(solver):
     s1 = solver.getIntegerSort()
     s2 = solver.getRealSort()