From f02b9bca33f9092c6d793e977326acc315f3b073 Mon Sep 17 00:00:00 2001 From: Andres Noetzli Date: Tue, 29 Mar 2022 07:43:02 -0700 Subject: [PATCH] [API] Add `{is,get}RoundingModeValue()` (#8429) It also fixes a wrong entry in s_rmodes_internal. --- src/api/cpp/cvc5.cpp | 75 +++++++++++++++-------- src/api/cpp/cvc5.h | 10 +++ src/api/java/io/github/cvc5/api/Term.java | 22 +++++++ src/api/java/jni/term.cpp | 28 +++++++++ src/api/python/cvc5.pxd | 7 ++- src/api/python/cvc5.pxi | 11 ++++ test/unit/api/cpp/term_black.cpp | 29 +++++++++ test/unit/api/java/TermTest.java | 25 ++++++++ test/unit/api/python/test_term.py | 30 ++++++++- 9 files changed, 207 insertions(+), 30 deletions(-) diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 487977963..4f92c98c6 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -719,6 +719,31 @@ const static std::unordered_set s_indexed_kinds( FLOATINGPOINT_TO_FP_FROM_SBV, FLOATINGPOINT_TO_FP_FROM_UBV}); +/* -------------------------------------------------------------------------- */ +/* Rounding Mode for Floating Points */ +/* -------------------------------------------------------------------------- */ + +const static std::unordered_map s_rmodes{ + {ROUND_NEAREST_TIES_TO_EVEN, + cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_EVEN}, + {ROUND_TOWARD_POSITIVE, cvc5::RoundingMode::ROUND_TOWARD_POSITIVE}, + {ROUND_TOWARD_NEGATIVE, cvc5::RoundingMode::ROUND_TOWARD_NEGATIVE}, + {ROUND_TOWARD_ZERO, cvc5::RoundingMode::ROUND_TOWARD_ZERO}, + {ROUND_NEAREST_TIES_TO_AWAY, + cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_AWAY}, +}; + +const static std::unordered_map + s_rmodes_internal{ + {cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_EVEN, + ROUND_NEAREST_TIES_TO_EVEN}, + {cvc5::RoundingMode::ROUND_TOWARD_POSITIVE, ROUND_TOWARD_POSITIVE}, + {cvc5::RoundingMode::ROUND_TOWARD_NEGATIVE, ROUND_TOWARD_NEGATIVE}, + {cvc5::RoundingMode::ROUND_TOWARD_ZERO, ROUND_TOWARD_ZERO}, + {cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_AWAY, + ROUND_NEAREST_TIES_TO_AWAY}, + }; + namespace { /** Convert a cvc5::Kind (internal) to a cvc5::api::Kind (external). */ @@ -3009,6 +3034,29 @@ std::vector Term::getTupleValue() const CVC5_API_TRY_CATCH_END; } +bool Term::isRoundingModeValue() const +{ + CVC5_API_TRY_CATCH_BEGIN; + CVC5_API_CHECK_NOT_NULL; + //////// all checks before this line + return d_node->getKind() == cvc5::Kind::CONST_ROUNDINGMODE; + //////// + CVC5_API_TRY_CATCH_END; +} +RoundingMode Term::getRoundingModeValue() const +{ + CVC5_API_TRY_CATCH_BEGIN; + CVC5_API_CHECK_NOT_NULL; + CVC5_API_ARG_CHECK_EXPECTED( + d_node->getKind() == cvc5::Kind::CONST_ROUNDINGMODE, *d_node) + << "Term to be a floating-point rounding mode value when calling " + "getRoundingModeValue()"; + //////// all checks before this line + return s_rmodes_internal.at(d_node->getConst()); + //////// + CVC5_API_TRY_CATCH_END; +} + bool Term::isFloatingPointPosZero() const { CVC5_API_TRY_CATCH_BEGIN; @@ -4543,33 +4591,6 @@ std::ostream& operator<<(std::ostream& out, const Grammar& grammar) return out << grammar.toString(); } -/* -------------------------------------------------------------------------- */ -/* Rounding Mode for Floating Points */ -/* -------------------------------------------------------------------------- */ - -const static std::unordered_map s_rmodes{ - {ROUND_NEAREST_TIES_TO_EVEN, - cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_EVEN}, - {ROUND_TOWARD_POSITIVE, cvc5::RoundingMode::ROUND_TOWARD_POSITIVE}, - {ROUND_TOWARD_NEGATIVE, cvc5::RoundingMode::ROUND_TOWARD_NEGATIVE}, - {ROUND_TOWARD_ZERO, cvc5::RoundingMode::ROUND_TOWARD_ZERO}, - {ROUND_NEAREST_TIES_TO_AWAY, - cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_AWAY}, -}; - -const static std::unordered_map - s_rmodes_internal{ - {cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_EVEN, - ROUND_NEAREST_TIES_TO_EVEN}, - {cvc5::RoundingMode::ROUND_TOWARD_POSITIVE, ROUND_TOWARD_POSITIVE}, - {cvc5::RoundingMode::ROUND_TOWARD_POSITIVE, ROUND_TOWARD_NEGATIVE}, - {cvc5::RoundingMode::ROUND_TOWARD_ZERO, ROUND_TOWARD_ZERO}, - {cvc5::RoundingMode::ROUND_NEAREST_TIES_TO_AWAY, - ROUND_NEAREST_TIES_TO_AWAY}, - }; - /* -------------------------------------------------------------------------- */ /* Options */ /* -------------------------------------------------------------------------- */ diff --git a/src/api/cpp/cvc5.h b/src/api/cpp/cvc5.h index 18f66dfa8..3fd0dad4a 100644 --- a/src/api/cpp/cvc5.h +++ b/src/api/cpp/cvc5.h @@ -1526,6 +1526,16 @@ class CVC5_EXPORT Term */ std::vector getTupleValue() const; + /** + * @return true if the term is a floating-point rounding mode value. + */ + bool isRoundingModeValue() const; + /** + * Asserts isRoundingModeValue(). + * @return the floating-point rounding mode value held by the term. + */ + RoundingMode getRoundingModeValue() const; + /** * @return true if the term is the floating-point value for positive zero. */ diff --git a/src/api/java/io/github/cvc5/api/Term.java b/src/api/java/io/github/cvc5/api/Term.java index 7fddc49be..ba95d3458 100644 --- a/src/api/java/io/github/cvc5/api/Term.java +++ b/src/api/java/io/github/cvc5/api/Term.java @@ -522,6 +522,28 @@ public class Term extends AbstractPointer implements Comparable, Iterable< private native String getUninterpretedSortValue(long pointer); + /** + * @return true if the term is a floating-point rounding mode value. + */ + public boolean isRoundingModeValue() + { + return isRoundingModeValue(pointer); + } + + private native boolean isRoundingModeValue(long pointer); + + /** + * Asserts isRoundingModeValue(). + * @return the floating-point rounding mode value held by the term. + */ + public RoundingMode getRoundingModeValue() throws CVC5ApiException + { + int value = getRoundingModeValue(pointer); + return RoundingMode.fromInt(value); + } + + private native int getRoundingModeValue(long pointer); + /** * @return true if the term is a tuple value. */ diff --git a/src/api/java/jni/term.cpp b/src/api/java/jni/term.cpp index c45bb4d6a..a0f5f1bc4 100644 --- a/src/api/java/jni/term.cpp +++ b/src/api/java/jni/term.cpp @@ -663,6 +663,34 @@ Java_io_github_cvc5_api_Term_getUninterpretedSortValue(JNIEnv* env, CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, nullptr); } +/* + * Class: io_github_cvc5_api_Term + * Method: isRoundingModeValue + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL Java_io_github_cvc5_api_Term_isRoundingModeValue( + JNIEnv* env, jobject, jlong pointer) +{ + CVC5_JAVA_API_TRY_CATCH_BEGIN; + Term* current = reinterpret_cast(pointer); + return static_cast(current->isRoundingModeValue()); + CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, static_cast(false)); +} + +/* + * Class: io_github_cvc5_api_Term + * Method: getRoundingModeValue + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_io_github_cvc5_api_Term_getRoundingModeValue( + JNIEnv* env, jobject, jlong pointer) +{ + CVC5_JAVA_API_TRY_CATCH_BEGIN; + Term* current = reinterpret_cast(pointer); + return static_cast(current->getRoundingModeValue()); + CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, 0); +} + /* * Class: io_github_cvc5_api_Term * Method: isTupleValue diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd index f03941c27..ce5dd100c 100644 --- a/src/api/python/cvc5.pxd +++ b/src/api/python/cvc5.pxd @@ -505,6 +505,11 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": string getBitVectorValue(uint32_t base) except + bint isUninterpretedSortValue() except + string getUninterpretedSortValue() except + + bint isTupleValue() except + + vector[Term] getTupleValue() except + + bint isRoundingModeValue() except + + RoundingMode getRoundingModeValue() except + + bint isFloatingPointPosZero() except + bint isFloatingPointNegZero() except + bint isFloatingPointPosInf() except + @@ -517,8 +522,6 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": set[Term] getSetValue() except + bint isSequenceValue() except + vector[Term] getSequenceValue() except + - bint isTupleValue() except + - vector[Term] getTupleValue() except + cdef cppclass TermHashFunction: diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi index 23fc5fa69..a378434ab 100644 --- a/src/api/python/cvc5.pxi +++ b/src/api/python/cvc5.pxi @@ -3531,6 +3531,17 @@ cdef class Term: """:return: True iff this term is a tuple value.""" return self.cterm.isTupleValue() + def isRoundingModeValue(self): + """:return: True if the term is a floating-point rounding mode value.""" + return self.cterm.isRoundingModeValue() + + def getRoundingModeValue(self): + """ + Asserts isRoundingModeValue(). + :return: the floating-point rounding mode value held by the term. + """ + return RoundingMode( self.cterm.getRoundingModeValue()) + def getTupleValue(self): """ Asserts :py:meth:`isTupleValue()`. diff --git a/test/unit/api/cpp/term_black.cpp b/test/unit/api/cpp/term_black.cpp index 81dfd2e43..070ccacab 100644 --- a/test/unit/api/cpp/term_black.cpp +++ b/test/unit/api/cpp/term_black.cpp @@ -920,6 +920,35 @@ TEST_F(TestApiBlackTerm, getUninterpretedSortValue) ASSERT_EQ(vx.getUninterpretedSortValue(), vy.getUninterpretedSortValue()); } +TEST_F(TestApiBlackTerm, isRoundingModeValue) +{ + ASSERT_FALSE(d_solver.mkInteger(15).isRoundingModeValue()); + ASSERT_TRUE(d_solver.mkRoundingMode(RoundingMode::ROUND_NEAREST_TIES_TO_EVEN) + .isRoundingModeValue()); + ASSERT_FALSE( + d_solver.mkConst(d_solver.getRoundingModeSort()).isRoundingModeValue()); +} + +TEST_F(TestApiBlackTerm, getRoundingModeValue) +{ + ASSERT_THROW(d_solver.mkInteger(15).getRoundingModeValue(), CVC5ApiException); + ASSERT_EQ(d_solver.mkRoundingMode(RoundingMode::ROUND_NEAREST_TIES_TO_EVEN) + .getRoundingModeValue(), + RoundingMode::ROUND_NEAREST_TIES_TO_EVEN); + ASSERT_EQ(d_solver.mkRoundingMode(RoundingMode::ROUND_TOWARD_POSITIVE) + .getRoundingModeValue(), + RoundingMode::ROUND_TOWARD_POSITIVE); + ASSERT_EQ(d_solver.mkRoundingMode(RoundingMode::ROUND_TOWARD_NEGATIVE) + .getRoundingModeValue(), + RoundingMode::ROUND_TOWARD_NEGATIVE); + ASSERT_EQ(d_solver.mkRoundingMode(RoundingMode::ROUND_TOWARD_ZERO) + .getRoundingModeValue(), + RoundingMode::ROUND_TOWARD_ZERO); + ASSERT_EQ(d_solver.mkRoundingMode(RoundingMode::ROUND_NEAREST_TIES_TO_AWAY) + .getRoundingModeValue(), + RoundingMode::ROUND_NEAREST_TIES_TO_AWAY); +} + TEST_F(TestApiBlackTerm, getTuple) { Sort s1 = d_solver.getIntegerSort(); diff --git a/test/unit/api/java/TermTest.java b/test/unit/api/java/TermTest.java index 9baa290af..e9875d362 100644 --- a/test/unit/api/java/TermTest.java +++ b/test/unit/api/java/TermTest.java @@ -870,6 +870,31 @@ class TermTest assertDoesNotThrow(() -> vy.getUninterpretedSortValue()); } + @Test void isRoundingModeValue() throws CVC5ApiException + { + assertFalse(d_solver.mkInteger(15).isRoundingModeValue()); + assertTrue( + d_solver.mkRoundingMode(RoundingMode.ROUND_NEAREST_TIES_TO_EVEN).isRoundingModeValue()); + assertFalse(d_solver.mkConst(d_solver.getRoundingModeSort()).isRoundingModeValue()); + } + + @Test void getRoundingModeValue() throws CVC5ApiException + { + assertThrows(CVC5ApiException.class, () -> d_solver.mkInteger(15).getRoundingModeValue()); + assertEquals( + d_solver.mkRoundingMode(RoundingMode.ROUND_NEAREST_TIES_TO_EVEN).getRoundingModeValue(), + RoundingMode.ROUND_NEAREST_TIES_TO_EVEN); + assertEquals(d_solver.mkRoundingMode(RoundingMode.ROUND_TOWARD_POSITIVE).getRoundingModeValue(), + RoundingMode.ROUND_TOWARD_POSITIVE); + assertEquals(d_solver.mkRoundingMode(RoundingMode.ROUND_TOWARD_NEGATIVE).getRoundingModeValue(), + RoundingMode.ROUND_TOWARD_NEGATIVE); + assertEquals(d_solver.mkRoundingMode(RoundingMode.ROUND_TOWARD_ZERO).getRoundingModeValue(), + RoundingMode.ROUND_TOWARD_ZERO); + assertEquals( + d_solver.mkRoundingMode(RoundingMode.ROUND_NEAREST_TIES_TO_AWAY).getRoundingModeValue(), + RoundingMode.ROUND_NEAREST_TIES_TO_AWAY); + } + @Test void getTuple() { Sort s1 = d_solver.getIntegerSort(); diff --git a/test/unit/api/python/test_term.py b/test/unit/api/python/test_term.py index 27d1d913e..16f3fa180 100644 --- a/test/unit/api/python/test_term.py +++ b/test/unit/api/python/test_term.py @@ -13,7 +13,7 @@ import pytest import cvc5 -from cvc5 import Kind +from cvc5 import Kind, RoundingMode from cvc5 import Sort, Term from fractions import Fraction @@ -972,6 +972,34 @@ def test_get_uninterpreted_sort_value(solver): assert vx.getUninterpretedSortValue() == vy.getUninterpretedSortValue() +def test_is_rounding_mode_value(solver): + assert not solver.mkInteger(15).isRoundingModeValue() + assert solver.mkRoundingMode( + RoundingMode.RoundNearestTiesToEven).isRoundingModeValue() + assert not solver.mkConst( + solver.getRoundingModeSort()).isRoundingModeValue() + + +def test_get_rounding_mode_value(solver): + with pytest.raises(RuntimeError): + solver.mkInteger(15).getRoundingModeValue() + assert solver.mkRoundingMode( + RoundingMode.RoundNearestTiesToEven).getRoundingModeValue( + ) == RoundingMode.RoundNearestTiesToEven + assert solver.mkRoundingMode( + RoundingMode.RoundTowardPositive).getRoundingModeValue( + ) == RoundingMode.RoundTowardPositive + assert solver.mkRoundingMode( + RoundingMode.RoundTowardNegative).getRoundingModeValue( + ) == RoundingMode.RoundTowardNegative + assert solver.mkRoundingMode( + RoundingMode.RoundTowardZero).getRoundingModeValue( + ) == RoundingMode.RoundTowardZero + assert solver.mkRoundingMode( + RoundingMode.RoundNearestTiesToAway).getRoundingModeValue( + ) == RoundingMode.RoundNearestTiesToAway + + def test_get_tuple(solver): s1 = solver.getIntegerSort() s2 = solver.getRealSort() -- 2.30.2