New C++ API: Add templated getIndices method for OpTerm (#3073)
authormakaimann <makaim@stanford.edu>
Sun, 11 Aug 2019 16:07:31 +0000 (09:07 -0700)
committerGitHub <noreply@github.com>
Sun, 11 Aug 2019 16:07:31 +0000 (09:07 -0700)
* Implement templated getIndices method for OpTerm

* Add getIndices unit tests

* Update src/api/cvc4cpp.cpp

Co-Authored-By: Aina Niemetz <aina.niemetz@gmail.com>
* Update src/api/cvc4cpp.cpp

Co-Authored-By: Aina Niemetz <aina.niemetz@gmail.com>
* Add comment about DIVISIBLE_OP

* Update test/unit/api/opterm_black.h

Co-Authored-By: Aina Niemetz <aina.niemetz@gmail.com>
* Update test/unit/api/opterm_black.h

Co-Authored-By: Aina Niemetz <aina.niemetz@gmail.com>
* Update test/unit/api/opterm_black.h

Co-Authored-By: Aina Niemetz <aina.niemetz@gmail.com>
* Update test/unit/api/opterm_black.h

Co-Authored-By: Aina Niemetz <aina.niemetz@gmail.com>
* Add exception checks to other unit tests (instead of having its own function)

* Fix unit test names in opterm_black.h

* Add description to docstring for getIndices

* Formatting

* Clang format older commits

* Use '-' in docstring list to match other docstrings

* Support creating DIVISIBLE_OP with a string (for arbitrary precision integers)

* Move mkOpTerm(DIVISIBLE_OP, <str>) test to solver_black.h

* Fix pointer access

* Replace switch statement with if statement

* Guard string input for CVC4::Integer in mkOpTerm for consistency on GMP/CLN back-end

src/api/cvc4cpp.cpp
src/api/cvc4cpp.h
test/unit/api/opterm_black.h
test/unit/api/solver_black.h

index bdb5f2f59f8e70d01ce3caadaec88a16818c82d4..942235e9c0c921b4a7d52a1a1630b3218b614130 100644 (file)
@@ -1288,6 +1288,144 @@ Sort OpTerm::getSort() const
 
 bool OpTerm::isNull() const { return d_expr->isNull(); }
 
+template <>
+std::string OpTerm::getIndices() const
+{
+  CVC4_API_CHECK_NOT_NULL;
+  std::string i;
+  Kind k = intToExtKind(d_expr->getKind());
+
+  if (k == DIVISIBLE_OP)
+  {
+    // DIVISIBLE_OP returns a string index to support
+    // arbitrary precision integers
+    CVC4::Integer _int = d_expr->getConst<Divisible>().k;
+    i = _int.toString();
+  }
+  else if (k == RECORD_UPDATE_OP)
+  {
+    i = d_expr->getConst<RecordUpdate>().getField();
+  }
+  else
+  {
+    CVC4_API_CHECK(false) << "Can't get string index from"
+                          << " kind " << kindToString(k);
+  }
+
+  return i;
+}
+
+template <>
+Kind OpTerm::getIndices() const
+{
+  CVC4_API_CHECK_NOT_NULL;
+  Kind kind = intToExtKind(d_expr->getKind());
+  CVC4_API_KIND_CHECK_EXPECTED(kind == CHAIN_OP, kind) << "CHAIN_OP";
+  return intToExtKind(d_expr->getConst<Chain>().getOperator());
+}
+
+template <>
+uint32_t OpTerm::getIndices() const
+{
+  CVC4_API_CHECK_NOT_NULL;
+  uint32_t i;
+  Kind k = intToExtKind(d_expr->getKind());
+  switch (k)
+  {
+    case BITVECTOR_REPEAT_OP:
+      i = d_expr->getConst<BitVectorRepeat>().repeatAmount;
+      break;
+    case BITVECTOR_ZERO_EXTEND_OP:
+      i = d_expr->getConst<BitVectorZeroExtend>().zeroExtendAmount;
+      break;
+    case BITVECTOR_SIGN_EXTEND_OP:
+      i = d_expr->getConst<BitVectorSignExtend>().signExtendAmount;
+      break;
+    case BITVECTOR_ROTATE_LEFT_OP:
+      i = d_expr->getConst<BitVectorRotateLeft>().rotateLeftAmount;
+      break;
+    case BITVECTOR_ROTATE_RIGHT_OP:
+      i = d_expr->getConst<BitVectorRotateRight>().rotateRightAmount;
+      break;
+    case INT_TO_BITVECTOR_OP:
+      i = d_expr->getConst<IntToBitVector>().size;
+      break;
+    case FLOATINGPOINT_TO_UBV_OP:
+      i = d_expr->getConst<FloatingPointToUBV>().bvs.size;
+      break;
+    case FLOATINGPOINT_TO_UBV_TOTAL_OP:
+      i = d_expr->getConst<FloatingPointToUBVTotal>().bvs.size;
+      break;
+    case FLOATINGPOINT_TO_SBV_OP:
+      i = d_expr->getConst<FloatingPointToSBV>().bvs.size;
+      break;
+    case FLOATINGPOINT_TO_SBV_TOTAL_OP:
+      i = d_expr->getConst<FloatingPointToSBVTotal>().bvs.size;
+      break;
+    case TUPLE_UPDATE_OP: i = d_expr->getConst<TupleUpdate>().getIndex(); break;
+    default:
+      CVC4ApiExceptionStream().ostream() << "Can't get uint32_t index from"
+                                         << " kind " << kindToString(k);
+  }
+  return i;
+}
+
+template <>
+std::pair<uint32_t, uint32_t> OpTerm::getIndices() const
+{
+  CVC4_API_CHECK_NOT_NULL;
+  std::pair<uint32_t, uint32_t> indices;
+  Kind k = intToExtKind(d_expr->getKind());
+
+  // using if/else instead of case statement because want local variables
+  if (k == BITVECTOR_EXTRACT_OP)
+  {
+    CVC4::BitVectorExtract ext = d_expr->getConst<BitVectorExtract>();
+    indices = std::make_pair(ext.high, ext.low);
+  }
+  else if (k == FLOATINGPOINT_TO_FP_IEEE_BITVECTOR_OP)
+  {
+    CVC4::FloatingPointToFPIEEEBitVector ext =
+        d_expr->getConst<FloatingPointToFPIEEEBitVector>();
+    indices = std::make_pair(ext.t.exponent(), ext.t.significand());
+  }
+  else if (k == FLOATINGPOINT_TO_FP_FLOATINGPOINT_OP)
+  {
+    CVC4::FloatingPointToFPFloatingPoint ext =
+        d_expr->getConst<FloatingPointToFPFloatingPoint>();
+    indices = std::make_pair(ext.t.exponent(), ext.t.significand());
+  }
+  else if (k == FLOATINGPOINT_TO_FP_REAL_OP)
+  {
+    CVC4::FloatingPointToFPReal ext = d_expr->getConst<FloatingPointToFPReal>();
+    indices = std::make_pair(ext.t.exponent(), ext.t.significand());
+  }
+  else if (k == FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR_OP)
+  {
+    CVC4::FloatingPointToFPSignedBitVector ext =
+        d_expr->getConst<FloatingPointToFPSignedBitVector>();
+    indices = std::make_pair(ext.t.exponent(), ext.t.significand());
+  }
+  else if (k == FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR_OP)
+  {
+    CVC4::FloatingPointToFPUnsignedBitVector ext =
+        d_expr->getConst<FloatingPointToFPUnsignedBitVector>();
+    indices = std::make_pair(ext.t.exponent(), ext.t.significand());
+  }
+  else if (k == FLOATINGPOINT_TO_FP_GENERIC_OP)
+  {
+    CVC4::FloatingPointToFPGeneric ext =
+        d_expr->getConst<FloatingPointToFPGeneric>();
+    indices = std::make_pair(ext.t.exponent(), ext.t.significand());
+  }
+  else
+  {
+    CVC4_API_CHECK(false) << "Can't get pair<uint32_t, uint32_t> indices from"
+                          << " kind " << kindToString(k);
+  }
+  return indices;
+}
+
 std::string OpTerm::toString() const { return d_expr->toString(); }
 
 // !!! This is only temporarily available until the parser is fully migrated
@@ -2724,10 +2862,26 @@ OpTerm Solver::mkOpTerm(Kind kind, Kind k) const
 OpTerm Solver::mkOpTerm(Kind kind, const std::string& arg) const
 {
   CVC4_API_SOLVER_TRY_CATCH_BEGIN;
-  CVC4_API_KIND_CHECK_EXPECTED(kind == RECORD_UPDATE_OP, kind)
-      << "RECORD_UPDATE_OP";
-
-  return *mkValHelper<CVC4::RecordUpdate>(CVC4::RecordUpdate(arg)).d_expr.get();
+  CVC4_API_KIND_CHECK_EXPECTED(
+      (kind == RECORD_UPDATE_OP) || (kind == DIVISIBLE_OP), kind)
+      << "RECORD_UPDATE_OP or DIVISIBLE_OP";
+  OpTerm res;
+  if (kind == RECORD_UPDATE_OP)
+  {
+    res =
+        *mkValHelper<CVC4::RecordUpdate>(CVC4::RecordUpdate(arg)).d_expr.get();
+  }
+  else
+  {
+    /* CLN and GMP handle this case differently, CLN interprets it as 0, GMP
+     * throws an std::invalid_argument exception. For consistency, we treat it
+     * as invalid. */
+    CVC4_API_ARG_CHECK_EXPECTED(arg != ".", arg)
+        << "a string representing an integer, real or rational value.";
+    res = *mkValHelper<CVC4::Divisible>(CVC4::Divisible(CVC4::Integer(arg)))
+               .d_expr.get();
+  }
+  return res;
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
 
index 2ff1cb91d595acb3a3e029feecd30c47a492b8a5..67e8bb6e7ccd955331fb7a1da6a64429b225369e 100644 (file)
@@ -872,6 +872,19 @@ class CVC4_PUBLIC OpTerm
    */
   bool isNull() const;
 
+  /**
+   * Get the indices used to create this OpTerm.
+   * Supports the following template arguments:
+   *   - string
+   *   - Kind
+   *   - uint32_t
+   *   - pair<uint32_t, uint32_t>
+   * Check the OpTerm Kind with getKind() to determine which argument to use.
+   * @return the indices used to create this OpTerm
+   */
+  template <typename T>
+  T getIndices() const;
+
   /**
    * @return a string representation of this operator term
    */
@@ -1818,6 +1831,7 @@ class CVC4_PUBLIC Solver
   /**
    * Create operator of kind:
    *   - RECORD_UPDATE_OP
+   *   - DIVISIBLE_OP (to support arbitrary precision integers)
    * See enum Kind for a description of the parameters.
    * @param kind the kind of the operator
    * @param arg the string argument to this operator
index 395ee8451c1325992ff8e9a2ce5a66b35082d592..150cebcbfbecfbc51e3f2479ea7a6bd173827220 100644 (file)
@@ -27,6 +27,10 @@ class OpTermBlack : public CxxTest::TestSuite
   void testGetKind();
   void testGetSort();
   void testIsNull();
+  void testGetIndicesString();
+  void testGetIndicesKind();
+  void testGetIndicesUint();
+  void testGetIndicesPairUint();
 
  private:
   Solver d_solver;
@@ -55,3 +59,150 @@ void OpTermBlack::testIsNull()
   x = d_solver.mkOpTerm(BITVECTOR_EXTRACT_OP, 31, 1);
   TS_ASSERT(!x.isNull());
 }
+
+void OpTermBlack::testGetIndicesString()
+{
+  OpTerm x;
+  TS_ASSERT_THROWS(x.getIndices<std::string>(), CVC4ApiException&);
+
+  OpTerm divisible_ot = d_solver.mkOpTerm(DIVISIBLE_OP, 4);
+  std::string divisible_idx = divisible_ot.getIndices<std::string>();
+  TS_ASSERT(divisible_idx == "4");
+
+  OpTerm record_update_ot = d_solver.mkOpTerm(RECORD_UPDATE_OP, "test");
+  std::string record_update_idx = record_update_ot.getIndices<std::string>();
+  TS_ASSERT(record_update_idx == "test");
+  TS_ASSERT_THROWS(record_update_ot.getIndices<uint32_t>(), CVC4ApiException&);
+}
+
+void OpTermBlack::testGetIndicesKind()
+{
+  OpTerm chain_ot = d_solver.mkOpTerm(CHAIN_OP, AND);
+  Kind chain_idx = chain_ot.getIndices<Kind>();
+  TS_ASSERT(chain_idx == AND);
+}
+
+void OpTermBlack::testGetIndicesUint()
+{
+  OpTerm bitvector_repeat_ot = d_solver.mkOpTerm(BITVECTOR_REPEAT_OP, 5);
+  uint32_t bitvector_repeat_idx = bitvector_repeat_ot.getIndices<uint32_t>();
+  TS_ASSERT(bitvector_repeat_idx == 5);
+  TS_ASSERT_THROWS(
+      (bitvector_repeat_ot.getIndices<std::pair<uint32_t, uint32_t>>()),
+      CVC4ApiException&);
+
+  OpTerm bitvector_zero_extend_ot =
+      d_solver.mkOpTerm(BITVECTOR_ZERO_EXTEND_OP, 6);
+  uint32_t bitvector_zero_extend_idx =
+      bitvector_zero_extend_ot.getIndices<uint32_t>();
+  TS_ASSERT(bitvector_zero_extend_idx == 6);
+
+  OpTerm bitvector_sign_extend_ot =
+      d_solver.mkOpTerm(BITVECTOR_SIGN_EXTEND_OP, 7);
+  uint32_t bitvector_sign_extend_idx =
+      bitvector_sign_extend_ot.getIndices<uint32_t>();
+  TS_ASSERT(bitvector_sign_extend_idx == 7);
+
+  OpTerm bitvector_rotate_left_ot =
+      d_solver.mkOpTerm(BITVECTOR_ROTATE_LEFT_OP, 8);
+  uint32_t bitvector_rotate_left_idx =
+      bitvector_rotate_left_ot.getIndices<uint32_t>();
+  TS_ASSERT(bitvector_rotate_left_idx == 8);
+
+  OpTerm bitvector_rotate_right_ot =
+      d_solver.mkOpTerm(BITVECTOR_ROTATE_RIGHT_OP, 9);
+  uint32_t bitvector_rotate_right_idx =
+      bitvector_rotate_right_ot.getIndices<uint32_t>();
+  TS_ASSERT(bitvector_rotate_right_idx == 9);
+
+  OpTerm int_to_bitvector_ot = d_solver.mkOpTerm(INT_TO_BITVECTOR_OP, 10);
+  uint32_t int_to_bitvector_idx = int_to_bitvector_ot.getIndices<uint32_t>();
+  TS_ASSERT(int_to_bitvector_idx == 10);
+
+  OpTerm floatingpoint_to_ubv_ot =
+      d_solver.mkOpTerm(FLOATINGPOINT_TO_UBV_OP, 11);
+  uint32_t floatingpoint_to_ubv_idx =
+      floatingpoint_to_ubv_ot.getIndices<uint32_t>();
+  TS_ASSERT(floatingpoint_to_ubv_idx == 11);
+
+  OpTerm floatingpoint_to_ubv_total_ot =
+      d_solver.mkOpTerm(FLOATINGPOINT_TO_UBV_TOTAL_OP, 12);
+  uint32_t floatingpoint_to_ubv_total_idx =
+      floatingpoint_to_ubv_total_ot.getIndices<uint32_t>();
+  TS_ASSERT(floatingpoint_to_ubv_total_idx == 12);
+
+  OpTerm floatingpoint_to_sbv_ot =
+      d_solver.mkOpTerm(FLOATINGPOINT_TO_SBV_OP, 13);
+  uint32_t floatingpoint_to_sbv_idx =
+      floatingpoint_to_sbv_ot.getIndices<uint32_t>();
+  TS_ASSERT(floatingpoint_to_sbv_idx == 13);
+
+  OpTerm floatingpoint_to_sbv_total_ot =
+      d_solver.mkOpTerm(FLOATINGPOINT_TO_SBV_TOTAL_OP, 14);
+  uint32_t floatingpoint_to_sbv_total_idx =
+      floatingpoint_to_sbv_total_ot.getIndices<uint32_t>();
+  TS_ASSERT(floatingpoint_to_sbv_total_idx == 14);
+
+  OpTerm tuple_update_ot = d_solver.mkOpTerm(TUPLE_UPDATE_OP, 5);
+  uint32_t tuple_update_idx = tuple_update_ot.getIndices<uint32_t>();
+  TS_ASSERT(tuple_update_idx == 5);
+  TS_ASSERT_THROWS(tuple_update_ot.getIndices<std::string>(),
+                   CVC4ApiException&);
+}
+
+void OpTermBlack::testGetIndicesPairUint()
+{
+  OpTerm bitvector_extract_ot = d_solver.mkOpTerm(BITVECTOR_EXTRACT_OP, 4, 0);
+  std::pair<uint32_t, uint32_t> bitvector_extract_indices =
+      bitvector_extract_ot.getIndices<std::pair<uint32_t, uint32_t>>();
+  TS_ASSERT((bitvector_extract_indices == std::pair<uint32_t, uint32_t>{4, 0}));
+
+  OpTerm floatingpoint_to_fp_ieee_bitvector_ot =
+      d_solver.mkOpTerm(FLOATINGPOINT_TO_FP_IEEE_BITVECTOR_OP, 4, 25);
+  std::pair<uint32_t, uint32_t> floatingpoint_to_fp_ieee_bitvector_indices =
+      floatingpoint_to_fp_ieee_bitvector_ot
+          .getIndices<std::pair<uint32_t, uint32_t>>();
+  TS_ASSERT((floatingpoint_to_fp_ieee_bitvector_indices
+             == std::pair<uint32_t, uint32_t>{4, 25}));
+
+  OpTerm floatingpoint_to_fp_floatingpoint_ot =
+      d_solver.mkOpTerm(FLOATINGPOINT_TO_FP_FLOATINGPOINT_OP, 4, 25);
+  std::pair<uint32_t, uint32_t> floatingpoint_to_fp_floatingpoint_indices =
+      floatingpoint_to_fp_floatingpoint_ot
+          .getIndices<std::pair<uint32_t, uint32_t>>();
+  TS_ASSERT((floatingpoint_to_fp_floatingpoint_indices
+             == std::pair<uint32_t, uint32_t>{4, 25}));
+
+  OpTerm floatingpoint_to_fp_real_ot =
+      d_solver.mkOpTerm(FLOATINGPOINT_TO_FP_REAL_OP, 4, 25);
+  std::pair<uint32_t, uint32_t> floatingpoint_to_fp_real_indices =
+      floatingpoint_to_fp_real_ot.getIndices<std::pair<uint32_t, uint32_t>>();
+  TS_ASSERT((floatingpoint_to_fp_real_indices
+             == std::pair<uint32_t, uint32_t>{4, 25}));
+
+  OpTerm floatingpoint_to_fp_signed_bitvector_ot =
+      d_solver.mkOpTerm(FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR_OP, 4, 25);
+  std::pair<uint32_t, uint32_t> floatingpoint_to_fp_signed_bitvector_indices =
+      floatingpoint_to_fp_signed_bitvector_ot
+          .getIndices<std::pair<uint32_t, uint32_t>>();
+  TS_ASSERT((floatingpoint_to_fp_signed_bitvector_indices
+             == std::pair<uint32_t, uint32_t>{4, 25}));
+
+  OpTerm floatingpoint_to_fp_unsigned_bitvector_ot =
+      d_solver.mkOpTerm(FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR_OP, 4, 25);
+  std::pair<uint32_t, uint32_t> floatingpoint_to_fp_unsigned_bitvector_indices =
+      floatingpoint_to_fp_unsigned_bitvector_ot
+          .getIndices<std::pair<uint32_t, uint32_t>>();
+  TS_ASSERT((floatingpoint_to_fp_unsigned_bitvector_indices
+             == std::pair<uint32_t, uint32_t>{4, 25}));
+
+  OpTerm floatingpoint_to_fp_generic_ot =
+      d_solver.mkOpTerm(FLOATINGPOINT_TO_FP_GENERIC_OP, 4, 25);
+  std::pair<uint32_t, uint32_t> floatingpoint_to_fp_generic_indices =
+      floatingpoint_to_fp_generic_ot
+          .getIndices<std::pair<uint32_t, uint32_t>>();
+  TS_ASSERT((floatingpoint_to_fp_generic_indices
+             == std::pair<uint32_t, uint32_t>{4, 25}));
+  TS_ASSERT_THROWS(floatingpoint_to_fp_generic_ot.getIndices<std::string>(),
+                   CVC4ApiException&);
+}
index a82807b3ba7ad2efb0098b342a377797efc48763..3782b900a71142f4ae070b9aa7ddca400699fc80 100644 (file)
@@ -451,6 +451,7 @@ void SolverBlack::testMkOpTerm()
 
   // mkOpTerm(Kind kind, const std::string& arg)
   TS_ASSERT_THROWS_NOTHING(d_solver->mkOpTerm(RECORD_UPDATE_OP, "asdf"));
+  TS_ASSERT_THROWS_NOTHING(d_solver->mkOpTerm(DIVISIBLE_OP, "2147483648"));
   TS_ASSERT_THROWS(d_solver->mkOpTerm(BITVECTOR_EXTRACT_OP, "asdf"),
                    CVC4ApiException&);