Add std::vector<Term> Op:: getIndices() and operator[] for Op (#6397)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Thu, 22 Jul 2021 07:10:39 +0000 (02:10 -0500)
committerGitHub <noreply@github.com>
Thu, 22 Jul 2021 07:10:39 +0000 (07:10 +0000)
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5.h
test/unit/api/op_black.cpp

index 4e14b8e34068e3f5c5cce6fb0631116d8158b201..9b920b8753248ede5d0c9fd6230b865e4ca15696 100644 (file)
@@ -570,6 +570,7 @@ const static std::unordered_map<cvc5::Kind, Kind, cvc5::kind::KindHashFunction>
         {cvc5::Kind::APPLY_UPDATER, APPLY_UPDATER},
         {cvc5::Kind::DT_SIZE, DT_SIZE},
         {cvc5::Kind::TUPLE_PROJECT, TUPLE_PROJECT},
+        {cvc5::Kind::TUPLE_PROJECT_OP, TUPLE_PROJECT},
         /* Separation Logic ------------------------------------------------ */
         {cvc5::Kind::SEP_NIL, SEP_NIL},
         {cvc5::Kind::SEP_EMP, SEP_EMP},
@@ -1865,6 +1866,14 @@ size_t Op::getNumIndices() const
 {
   CVC5_API_TRY_CATCH_BEGIN;
   CVC5_API_CHECK_NOT_NULL;
+  //////// all checks before this line
+  return getNumIndicesHelper();
+  ////////
+  CVC5_API_TRY_CATCH_END;
+}
+
+size_t Op::getNumIndicesHelper() const
+{
   if (!isIndexedHelper())
   {
     return 0;
@@ -1898,9 +1907,185 @@ size_t Op::getNumIndices() const
       break;
     default: CVC5_API_CHECK(false) << "Unhandled kind " << kindToString(k);
   }
+  return size;
+}
+
+Term Op::operator[](size_t index) const
+{
+  return getIndexHelper(index);
+}
+
+Term Op::getIndexHelper(size_t index) const
+{
+  CVC5_API_TRY_CATCH_BEGIN;
+  CVC5_API_CHECK_NOT_NULL;
+  CVC5_API_CHECK(!d_node->isNull())
+      << "Expecting a non-null internal expression. This Op is not indexed.";
+  CVC5_API_CHECK(index < getNumIndicesHelper()) << "index out of bound";
+  Kind k = intToExtKind(d_node->getKind());
+  Term t;
+  switch (k)
+  {
+    case DIVISIBLE:
+    {
+      t = d_solver->mkValHelper<Rational>(
+          Rational(d_node->getConst<Divisible>().k));
+      break;
+    }
+    case BITVECTOR_REPEAT:
+    {
+      t = d_solver->mkValHelper<cvc5::Rational>(
+          d_node->getConst<BitVectorRepeat>().d_repeatAmount);
+      break;
+    }
+    case BITVECTOR_ZERO_EXTEND:
+    {
+      t = d_solver->mkValHelper<cvc5::Rational>(
+          d_node->getConst<BitVectorZeroExtend>().d_zeroExtendAmount);
+      break;
+    }
+    case BITVECTOR_SIGN_EXTEND:
+    {
+      t = d_solver->mkValHelper<cvc5::Rational>(
+          d_node->getConst<BitVectorSignExtend>().d_signExtendAmount);
+      break;
+    }
+    case BITVECTOR_ROTATE_LEFT:
+    {
+      t = d_solver->mkValHelper<cvc5::Rational>(
+          d_node->getConst<BitVectorRotateLeft>().d_rotateLeftAmount);
+      break;
+    }
+    case BITVECTOR_ROTATE_RIGHT:
+    {
+      t = d_solver->mkValHelper<cvc5::Rational>(
+          d_node->getConst<BitVectorRotateRight>().d_rotateRightAmount);
+      break;
+    }
+    case INT_TO_BITVECTOR:
+    {
+      t = d_solver->mkValHelper<cvc5::Rational>(
+          d_node->getConst<IntToBitVector>().d_size);
+      break;
+    }
+    case IAND:
+    {
+      t = d_solver->mkValHelper<cvc5::Rational>(
+          d_node->getConst<IntAnd>().d_size);
+      break;
+    }
+    case FLOATINGPOINT_TO_UBV:
+    {
+      t = d_solver->mkValHelper<cvc5::Rational>(
+          d_node->getConst<FloatingPointToUBV>().d_bv_size.d_size);
+      break;
+    }
+    case FLOATINGPOINT_TO_SBV:
+    {
+      t = d_solver->mkValHelper<cvc5::Rational>(
+          d_node->getConst<FloatingPointToSBV>().d_bv_size.d_size);
+      break;
+    }
+    case REGEXP_REPEAT:
+    {
+      t = d_solver->mkValHelper<cvc5::Rational>(
+          d_node->getConst<RegExpRepeat>().d_repeatAmount);
+      break;
+    }
+    case BITVECTOR_EXTRACT:
+    {
+      cvc5::BitVectorExtract ext = d_node->getConst<BitVectorExtract>();
+      t = index == 0 ? d_solver->mkValHelper<cvc5::Rational>(ext.d_high)
+                     : d_solver->mkValHelper<cvc5::Rational>(ext.d_low);
+      break;
+    }
+    case FLOATINGPOINT_TO_FP_IEEE_BITVECTOR:
+    {
+      cvc5::FloatingPointToFPIEEEBitVector ext =
+          d_node->getConst<FloatingPointToFPIEEEBitVector>();
+
+      t = index == 0 ? d_solver->mkValHelper<cvc5::Rational>(
+              ext.getSize().exponentWidth())
+                     : d_solver->mkValHelper<cvc5::Rational>(
+                         ext.getSize().significandWidth());
+      break;
+    }
+    case FLOATINGPOINT_TO_FP_FLOATINGPOINT:
+    {
+      cvc5::FloatingPointToFPFloatingPoint ext =
+          d_node->getConst<FloatingPointToFPFloatingPoint>();
+      t = index == 0 ? d_solver->mkValHelper<cvc5::Rational>(
+              ext.getSize().exponentWidth())
+                     : d_solver->mkValHelper<cvc5::Rational>(
+                         ext.getSize().significandWidth());
+      break;
+    }
+    case FLOATINGPOINT_TO_FP_REAL:
+    {
+      cvc5::FloatingPointToFPReal ext =
+          d_node->getConst<FloatingPointToFPReal>();
+
+      t = index == 0 ? d_solver->mkValHelper<cvc5::Rational>(
+              ext.getSize().exponentWidth())
+                     : d_solver->mkValHelper<cvc5::Rational>(
+                         ext.getSize().significandWidth());
+      break;
+    }
+    case FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR:
+    {
+      cvc5::FloatingPointToFPSignedBitVector ext =
+          d_node->getConst<FloatingPointToFPSignedBitVector>();
+      t = index == 0 ? d_solver->mkValHelper<cvc5::Rational>(
+              ext.getSize().exponentWidth())
+                     : d_solver->mkValHelper<cvc5::Rational>(
+                         ext.getSize().significandWidth());
+      break;
+    }
+    case FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR:
+    {
+      cvc5::FloatingPointToFPUnsignedBitVector ext =
+          d_node->getConst<FloatingPointToFPUnsignedBitVector>();
+      t = index == 0 ? d_solver->mkValHelper<cvc5::Rational>(
+              ext.getSize().exponentWidth())
+                     : d_solver->mkValHelper<cvc5::Rational>(
+                         ext.getSize().significandWidth());
+      break;
+    }
+    case FLOATINGPOINT_TO_FP_GENERIC:
+    {
+      cvc5::FloatingPointToFPGeneric ext =
+          d_node->getConst<FloatingPointToFPGeneric>();
+      t = index == 0 ? d_solver->mkValHelper<cvc5::Rational>(
+              ext.getSize().exponentWidth())
+                     : d_solver->mkValHelper<cvc5::Rational>(
+                         ext.getSize().significandWidth());
+      break;
+    }
+    case REGEXP_LOOP:
+    {
+      cvc5::RegExpLoop ext = d_node->getConst<RegExpLoop>();
+      t = index == 0 ? d_solver->mkValHelper<cvc5::Rational>(ext.d_loopMinOcc)
+                     : d_solver->mkValHelper<cvc5::Rational>(ext.d_loopMaxOcc);
+
+      break;
+    }
+
+    case TUPLE_PROJECT:
+    {
+      const std::vector<uint32_t>& projectionIndices =
+          d_node->getConst<TupleProjectOp>().getIndices();
+      t = d_solver->mkValHelper<cvc5::Rational>(projectionIndices[index]);
+      break;
+    }
+    default:
+    {
+      CVC5_API_CHECK(false) << "Unhandled kind " << kindToString(k);
+      break;
+    }
+  }
 
   //////// all checks before this line
-  return size;
+  return t;
   ////////
   CVC5_API_TRY_CATCH_END;
 }
@@ -2043,6 +2228,25 @@ std::pair<uint32_t, uint32_t> Op::getIndices() const
   CVC5_API_TRY_CATCH_END;
 }
 
+template <>
+std::vector<api::Term> Op::getIndices() const
+{
+  CVC5_API_TRY_CATCH_BEGIN;
+  CVC5_API_CHECK_NOT_NULL;
+  CVC5_API_CHECK(!d_node->isNull())
+      << "Expecting a non-null internal expression. This Op is not indexed.";
+  size_t size = getNumIndicesHelper();
+  std::vector<Term> terms(getNumIndices());
+  for (size_t i = 0; i < size; i++)
+  {
+    terms[i] = getIndexHelper(i);
+  }
+  //////// all checks before this line
+  return terms;
+  ////////
+  CVC5_API_TRY_CATCH_END;
+}
+
 std::string Op::toString() const
 {
   CVC5_API_TRY_CATCH_BEGIN;
index 0ee5990da7545f92aff437537c8cb21dfd539890..672ccf8f8242f8bd7e28ea061f5594068b44f55a 100644 (file)
@@ -782,6 +782,8 @@ namespace cvc5::api {
 /* Op                                                                     */
 /* -------------------------------------------------------------------------- */
 
+class Term;
+
 /**
  * A cvc5 operator.
  * An operator is a term that represents certain operators, instantiated
@@ -842,6 +844,14 @@ class CVC5_EXPORT Op
    */
   size_t getNumIndices() const;
 
+  /**
+   * Get the index at position i.
+   * @param i the position of the index to return
+   * @return the index at position i
+   */
+
+  Term operator[](size_t i) const;
+
   /**
    * Get the indices used to create this Op.
    * Supports the following template arguments:
@@ -891,6 +901,19 @@ class CVC5_EXPORT Op
    */
   bool isIndexedHelper() const;
 
+  /**
+   * Helper for getNumIndices
+   * @return the number of indices of this op
+   */
+  size_t getNumIndicesHelper() const;
+
+  /**
+   * Helper for operator[](size_t i).
+   * @param i position of the index. Should be less than getNumIndicesHelper().
+   * @return the index at position i
+   */
+  Term getIndexHelper(size_t index) const;
+
   /**
    * The associated solver object.
    */
index 8c4e31c22303c1a6e09300ea71ca9d9f24bcd23b..fd45b1c225242e8861e827cab87ccbcfdfc980bc 100644 (file)
@@ -92,6 +92,52 @@ TEST_F(TestApiBlackOp, getNumIndices)
   ASSERT_EQ(2, regexp_loop.getNumIndices());
 }
 
+TEST_F(TestApiBlackOp, subscriptOperator)
+{
+  Op plus = d_solver.mkOp(PLUS);
+  Op divisible = d_solver.mkOp(DIVISIBLE, 4);
+  Op bitvector_repeat = d_solver.mkOp(BITVECTOR_REPEAT, 4);
+  Op bitvector_zero_extend = d_solver.mkOp(BITVECTOR_ZERO_EXTEND, 4);
+  Op bitvector_sign_extend = d_solver.mkOp(BITVECTOR_SIGN_EXTEND, 4);
+  Op bitvector_rotate_left = d_solver.mkOp(BITVECTOR_ROTATE_LEFT, 4);
+  Op bitvector_rotate_right = d_solver.mkOp(BITVECTOR_ROTATE_RIGHT, 4);
+  Op int_to_bitvector = d_solver.mkOp(INT_TO_BITVECTOR, 4);
+  Op iand = d_solver.mkOp(IAND, 4);
+  Op floatingpoint_to_ubv = d_solver.mkOp(FLOATINGPOINT_TO_UBV, 4);
+  Op floatingopint_to_sbv = d_solver.mkOp(FLOATINGPOINT_TO_SBV, 4);
+  Op floatingpoint_to_fp_ieee_bitvector =
+      d_solver.mkOp(FLOATINGPOINT_TO_FP_IEEE_BITVECTOR, 4, 5);
+  Op floatingpoint_to_fp_floatingpoint =
+      d_solver.mkOp(FLOATINGPOINT_TO_FP_FLOATINGPOINT, 4, 5);
+  Op floatingpoint_to_fp_real = d_solver.mkOp(FLOATINGPOINT_TO_FP_REAL, 4, 5);
+  Op floatingpoint_to_fp_signed_bitvector =
+      d_solver.mkOp(FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR, 4, 5);
+  Op floatingpoint_to_fp_unsigned_bitvector =
+      d_solver.mkOp(FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR, 4, 5);
+  Op floatingpoint_to_fp_generic =
+      d_solver.mkOp(FLOATINGPOINT_TO_FP_GENERIC, 4, 5);
+  Op regexp_loop = d_solver.mkOp(REGEXP_LOOP, 4, 5);
+
+  ASSERT_THROW(plus[0], CVC5ApiException);
+  ASSERT_EQ(4, divisible[0].getUInt32Value());
+  ASSERT_EQ(4, bitvector_repeat[0].getUInt32Value());
+  ASSERT_EQ(4, bitvector_zero_extend[0].getUInt32Value());
+  ASSERT_EQ(4, bitvector_sign_extend[0].getUInt32Value());
+  ASSERT_EQ(4, bitvector_rotate_left[0].getUInt32Value());
+  ASSERT_EQ(4, bitvector_rotate_right[0].getUInt32Value());
+  ASSERT_EQ(4, int_to_bitvector[0].getUInt32Value());
+  ASSERT_EQ(4, iand[0].getUInt32Value());
+  ASSERT_EQ(4, floatingpoint_to_ubv[0].getUInt32Value());
+  ASSERT_EQ(4, floatingopint_to_sbv[0].getUInt32Value());
+  ASSERT_EQ(4, floatingpoint_to_fp_ieee_bitvector[0].getUInt32Value());
+  ASSERT_EQ(4, floatingpoint_to_fp_floatingpoint[0].getUInt32Value());
+  ASSERT_EQ(4, floatingpoint_to_fp_real[0].getUInt32Value());
+  ASSERT_EQ(4, floatingpoint_to_fp_signed_bitvector[0].getUInt32Value());
+  ASSERT_EQ(4, floatingpoint_to_fp_unsigned_bitvector[0].getUInt32Value());
+  ASSERT_EQ(4, floatingpoint_to_fp_generic[0].getUInt32Value());
+  ASSERT_EQ(4, regexp_loop[0].getUInt32Value());
+}
+
 TEST_F(TestApiBlackOp, getIndicesString)
 {
   Op x;
@@ -207,6 +253,22 @@ TEST_F(TestApiBlackOp, getIndicesPairUint)
                CVC5ApiException);
 }
 
+TEST_F(TestApiBlackOp, getIndicesVector)
+{
+  std::vector<uint32_t> indices = {0, 3, 2, 0, 1, 2};
+  Op tuple_project_op = d_solver.mkOp(TUPLE_PROJECT, indices);
+
+  ASSERT_TRUE(tuple_project_op.isIndexed());
+  std::vector<Term> tuple_project_extract_indices =
+      tuple_project_op.getIndices<std::vector<Term>>();
+  ASSERT_THROW(tuple_project_op.getIndices<std::string>(), CVC5ApiException);
+  for (size_t i = 0; i < indices.size(); i++)
+  {
+    ASSERT_EQ(indices[i], tuple_project_extract_indices[i].getUInt32Value());
+    ASSERT_EQ(indices[i], tuple_project_op[i].getUInt32Value());
+  }
+}
+
 TEST_F(TestApiBlackOp, opScopingToString)
 {
   Op bitvector_repeat_ot = d_solver.mkOp(BITVECTOR_REPEAT, 5);