Fixes and additions for API for parametric datatypes (#7760)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 13 Dec 2021 18:24:07 +0000 (12:24 -0600)
committerGitHub <noreply@github.com>
Mon, 13 Dec 2021 18:24:07 +0000 (18:24 +0000)
14 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5.h
src/api/java/io/github/cvc5/api/Datatype.java
src/api/java/io/github/cvc5/api/DatatypeConstructor.java
src/api/java/jni/datatype.cpp
src/api/java/jni/datatype_constructor.cpp
src/api/python/cvc5.pxd
src/api/python/cvc5.pxi
src/parser/parser.cpp
src/parser/smt2/Smt2.g
src/theory/arith/nl/iand_solver.cpp
test/unit/api/cpp/datatype_api_black.cpp
test/unit/api/java/DatatypeTest.java
test/unit/api/python/test_datatype_api.py

index aa3e7fa3f3652a7a701271ee8493e9968208b240..1e5738ed8fcb969a5df7ce2d444b9b93dd68a607 100644 (file)
@@ -3804,7 +3804,7 @@ Term DatatypeConstructor::getConstructorTerm() const
   CVC5_API_TRY_CATCH_END;
 }
 
-Term DatatypeConstructor::getSpecializedConstructorTerm(
+Term DatatypeConstructor::getInstantiatedConstructorTerm(
     const Sort& retSort) const
 {
   CVC5_API_TRY_CATCH_BEGIN;
@@ -4109,6 +4109,18 @@ size_t Datatype::getNumConstructors() const
   CVC5_API_TRY_CATCH_END;
 }
 
+std::vector<Sort> Datatype::getParameters() const
+{
+  CVC5_API_TRY_CATCH_BEGIN;
+  CVC5_API_CHECK_NOT_NULL;
+  CVC5_API_CHECK(isParametric()) << "Expected parametric datatype";
+  //////// all checks before this line
+  std::vector<cvc5::TypeNode> params = d_dtype->getParameters();
+  return Sort::typeNodeVectorToSorts(d_solver, params);
+  ////////
+  CVC5_API_TRY_CATCH_END;
+}
+
 bool Datatype::isParametric() const
 {
   CVC5_API_TRY_CATCH_BEGIN;
index b65db16a359a542caa083659736445396a788e6f..e38030abe719bdb1b1c236fb381aa82e26fdda3f 100644 (file)
@@ -315,6 +315,7 @@ class CVC5_EXPORT Sort
   friend class DatatypeConstructorDecl;
   friend class DatatypeSelector;
   friend class DatatypeDecl;
+  friend class Datatype;
   friend class Op;
   friend class Solver;
   friend class Grammar;
@@ -434,7 +435,10 @@ class CVC5_EXPORT Sort
   bool isDatatype() const;
 
   /**
-   * Is this a parametric datatype sort?
+   * Is this a parametric datatype sort? A parametric datatype sort is either
+   * one that is returned by a call to Solver::mkDatatypeSort() or Solver::mkDatatypeSorts() for a
+   * parametric datatype, or an instantiated datatype sort returned by
+   * Sort::instantiate() for parametric datatype sort `s`.
    * @return true if the sort is a parametric datatype sort
    */
   bool isParametricDatatype() const;
@@ -1993,7 +1997,7 @@ class CVC5_EXPORT DatatypeConstructor
    * @param retSort the desired return sort of the constructor
    * @return the constructor term
    */
-  Term getSpecializedConstructorTerm(const Sort& retSort) const;
+  Term getInstantiatedConstructorTerm(const Sort& retSort) const;
 
   /**
    * Get the tester operator of this datatype constructor.
@@ -2244,6 +2248,12 @@ class CVC5_EXPORT Datatype
   /** @return the number of constructors for this Datatype. */
   size_t getNumConstructors() const;
 
+  /**
+   * @return the parameters of this datatype, if it is parametric. An exception
+   * is thrown if this datatype is not parametric.
+   */
+  std::vector<Sort> getParameters() const;
+
   /** @return true if this datatype is parametric */
   bool isParametric() const;
 
index bc33ba10b2bad005684316f318af6dd105738e95..39ea2bb193ee20257612e119755a3e72118dd069 100644 (file)
@@ -108,6 +108,18 @@ public class Datatype extends AbstractPointer implements Iterable<DatatypeConstr
 
   private native int getNumConstructors(long pointer);
 
+  /**
+   * @return the parameters of this datatype, if it is parametric. An exception
+   * is thrown if this datatype is not parametric.
+   */
+  public Sort[] getParameters() {
+    long[] sortPointers = getParameters(pointer);
+    Sort[] sorts = Utils.getSorts(solver, sortPointers);
+    return sorts;
+  }
+
+  private native long[] getParameters(long pointer);
+
   /** @return true if this datatype is parametric */
   public boolean isParametric()
   {
index 5fd9d8407d7b34230f75315329c9c6a92c371e7c..4a041dd992d3fc0d690b5189c9c213c85ca47449 100644 (file)
@@ -78,13 +78,14 @@ public class DatatypeConstructor extends AbstractPointer implements Iterable<Dat
    * @param retSort the desired return sort of the constructor
    * @return the constructor term
    */
-  public Term getSpecializedConstructorTerm(Sort retSort)
-  {
-    long termPointer = getSpecializedConstructorTerm(pointer, retSort.getPointer());
+  public Term getInstantiatedConstructorTerm(Sort retSort) {
+    long termPointer =
+        getInstantiatedConstructorTerm(pointer, retSort.getPointer());
     return new Term(solver, termPointer);
   }
 
-  private native long getSpecializedConstructorTerm(long pointer, long retSortPointer);
+  private native long getInstantiatedConstructorTerm(
+      long pointer, long retSortPointer);
 
   /**
    * Get the tester operator of this datatype constructor.
index 3f340c93ce07dc4c967157afb524ce75c6d1a07b..6cc8ba58e91e8f48cf305f968a86f97de74fb0b6 100644 (file)
@@ -112,6 +112,28 @@ JNIEXPORT jint JNICALL Java_io_github_cvc5_api_Datatype_getNumConstructors(
   CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, 0);
 }
 
+/*
+ * Class:     io_github_cvc5_api_Datatype
+ * Method:    getParameters
+ * Signature: (J)[J
+ */
+JNIEXPORT jlongArray JNICALL Java_io_github_cvc5_api_Datatype_getParameters(
+    JNIEnv* env, jobject, jlong pointer)
+{
+  CVC5_JAVA_API_TRY_CATCH_BEGIN;
+  Datatype* current = (Datatype*)pointer;
+  std::vector<Sort> sorts = current->getParameters();
+  std::vector<jlong> sortPointers(sorts.size());
+  for (size_t i = 0; i < sorts.size(); i++)
+  {
+    sortPointers[i] = reinterpret_cast<jlong>(new Sort(sorts[i]));
+  }
+  jlongArray ret = env->NewLongArray(sorts.size());
+  env->SetLongArrayRegion(ret, 0, sorts.size(), sortPointers.data());
+  return ret;
+  CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, nullptr);
+}
+
 /*
  * Class:     io_github_cvc5_api_Datatype
  * Method:    isParametric
index 7fe5f21c6457bf405be5dd6afab752654108a906..de9e4fa4c5d90a8840f29add4135b1d4cb2b074f 100644 (file)
@@ -65,17 +65,17 @@ Java_io_github_cvc5_api_DatatypeConstructor_getConstructorTerm(JNIEnv* env,
 
 /*
  * Class:     io_github_cvc5_api_DatatypeConstructor
- * Method:    getSpecializedConstructorTerm
+ * Method:    getInstantiatedConstructorTerm
  * Signature: (JJ)J
  */
 JNIEXPORT jlong JNICALL
-Java_io_github_cvc5_api_DatatypeConstructor_getSpecializedConstructorTerm(
+Java_io_github_cvc5_api_DatatypeConstructor_getInstantiatedConstructorTerm(
     JNIEnv* env, jobject, jlong pointer, jlong retSortPointer)
 {
   CVC5_JAVA_API_TRY_CATCH_BEGIN;
   DatatypeConstructor* current = (DatatypeConstructor*)pointer;
   Sort* sort = (Sort*)retSortPointer;
-  Term* retPointer = new Term(current->getSpecializedConstructorTerm(*sort));
+  Term* retPointer = new Term(current->getInstantiatedConstructorTerm(*sort));
   return (jlong)retPointer;
   CVC5_JAVA_API_TRY_CATCH_END_RETURN(env, 0);
 }
index f5dc2aca22b2870aa7cc988ca0cd4f93875da747..2baed575a8c8fda6286f2798869f69e19a46550e 100644 (file)
@@ -51,6 +51,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         DatatypeSelector getSelector(const string& name) except +
         string getName() except +
         size_t getNumConstructors() except +
+        vector[Sort] getParameters() except +
         bint isParametric() except +
         bint isCodatatype() except +
         bint isTuple() except +
@@ -76,7 +77,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         DatatypeSelector operator[](const string& name) except +
         string getName() except +
         Term getConstructorTerm() except +
-        Term getSpecializedConstructorTerm(const Sort& retSort) except +
+        Term getInstantiatedConstructorTerm(const Sort& retSort) except +
         Term getTesterTerm() except +
         size_t getNumSelectors() except +
         DatatypeSelector getSelector(const string& name) except +
index 6220545a105dddbad7e8be682becbd4214237d31..5c75047ef79d9a73c9dd16105b0129f28a546d73 100644 (file)
@@ -153,6 +153,18 @@ cdef class Datatype:
         """
         return self.cd.getNumConstructors()
 
+    def getParameters(self):
+        """
+            :return: the parameters of this datatype, if it is parametric. An
+            exception is thrown if this datatype is not parametric.
+        """
+        param_sorts = []
+        for s in self.cd.getParameters():
+            sort = Sort(self.solver)
+            sort.csort = s
+            param_sorts.append(sort)
+        return param_sorts
+
     def isParametric(self):
         """:return: True if this datatype is parametric."""
         return self.cd.isParametric()
@@ -233,15 +245,17 @@ cdef class DatatypeConstructor:
         term.cterm = self.cdc.getConstructorTerm()
         return term
 
-    def getSpecializedConstructorTerm(self, Sort retSort):
+    def getInstantiatedConstructorTerm(self, Sort retSort):
         """
-            Specialized method for parametric datatypes (see :cpp:func:`DatatypeConstructor::getSpecializedConstructorTerm() <cvc5::api::DatatypeConstructor::getSpecializedConstructorTerm>`).
+            Specialized method for parametric datatypes (see
+            :cpp:func:`DatatypeConstructor::getInstantiatedConstructorTerm()
+            <cvc5::api::DatatypeConstructor::getInstantiatedConstructorTerm>`).
 
             :param retSort: the desired return sort of the constructor
             :return: the constructor operator as a term.
         """
         cdef Term term = Term(self.solver)
-        term.cterm = self.cdc.getSpecializedConstructorTerm(retSort.csort)
+        term.cterm = self.cdc.getInstantiatedConstructorTerm(retSort.csort)
         return term
 
     def getTesterTerm(self):
index c94e3974828b46cb441088b2fd5cb931ae44a513..1481d66fe531ca2406b5f41a625ae45fdd62f4c3 100644 (file)
@@ -583,7 +583,7 @@ api::Term Parser::applyTypeAscription(api::Term t, api::Sort s)
       // lookup by name
       api::DatatypeConstructor dc = d.getConstructor(t.toString());
       // ask the constructor for the specialized constructor term
-      t = dc.getSpecializedConstructorTerm(s);
+      t = dc.getInstantiatedConstructorTerm(s);
     }
     // the type of t does not match the sort s by design (constructor type
     // vs datatype type), thus we use an alternative check here.
index fe401b3697de92b4d7416e4e3a93af0e02b08408..41cd4869a9fd8dfa8c5adef7be592814f872102f 100644 (file)
@@ -1364,7 +1364,7 @@ termNonVariable[cvc5::api::Term& expr, cvc5::api::Term& expr2]
           {
             // lookup constructor by name
             api::DatatypeConstructor dc = dt.getConstructor(f.toString());
-            api::Term scons = dc.getSpecializedConstructorTerm(expr.getSort());
+            api::Term scons = dc.getInstantiatedConstructorTerm(expr.getSort());
             // take the type of the specialized constructor instead
             type = scons.getSort();
           }
index 92ecaf4e8de22cf4c8981537fa7b325c65a216e3..b5278a353ebf6a6639a1341fa1645aae33f799f5 100644 (file)
@@ -296,7 +296,9 @@ Node IAndSolver::bitwiseLemma(Node i)
       bitIAnd = d_iandUtils.createBitwiseIAndNode(x, y, high_bit, j);
       // enforce bitwise equality
       lem = nm->mkNode(
-          AND, lem, rewrite(d_iandUtils.iextract(high_bit, j, i)).eqNode(bitIAnd));
+          AND,
+          lem,
+          rewrite(d_iandUtils.iextract(high_bit, j, i)).eqNode(bitIAnd));
     }
   }
   return lem;
index 745abc17ca9f073c5956d4dc93a0a8440b792d27..fecf228a5ddd6a8bafd97c752a4765a6341a491a 100644 (file)
@@ -231,6 +231,7 @@ TEST_F(TestApiBlackDatatype, datatypeNames)
   dtypeSpec.addConstructor(nil);
   Sort dtypeSort = d_solver.mkDatatypeSort(dtypeSpec);
   Datatype dt = dtypeSort.getDatatype();
+  ASSERT_THROW(dt.getParameters(), CVC5ApiException);
   ASSERT_EQ(dt.getName(), std::string("list"));
   ASSERT_NO_THROW(dt.getConstructor("nil"));
   ASSERT_NO_THROW(dt["cons"]);
@@ -274,6 +275,8 @@ TEST_F(TestApiBlackDatatype, parametricDatatype)
   Sort pairType = d_solver.mkDatatypeSort(pairSpec);
 
   ASSERT_TRUE(pairType.getDatatype().isParametric());
+  std::vector<Sort> dparams = pairType.getDatatype().getParameters();
+  ASSERT_TRUE(dparams[0] == t1 && dparams[1] == t2);
 
   v.clear();
   v.push_back(d_solver.getIntegerSort());
@@ -576,12 +579,16 @@ TEST_F(TestApiBlackDatatype, datatypeSpecializedCons)
   iargs.push_back(isort);
   Sort listInt = dtsorts[0].instantiate(iargs);
 
+  std::vector<Sort> liparams = listInt.getDatatype().getParameters();
+  // the parameter of the datatype is not instantiated
+  ASSERT_TRUE(liparams.size() == 1 && liparams[0] == x);
+
   Term testConsTerm;
   // get the specialized constructor term for list[Int]
-  ASSERT_NO_THROW(testConsTerm = nilc.getSpecializedConstructorTerm(listInt));
+  ASSERT_NO_THROW(testConsTerm = nilc.getInstantiatedConstructorTerm(listInt));
   ASSERT_NE(testConsTerm, nilc.getConstructorTerm());
   // error to get the specialized constructor term for Int
-  ASSERT_THROW(nilc.getSpecializedConstructorTerm(isort), CVC5ApiException);
+  ASSERT_THROW(nilc.getInstantiatedConstructorTerm(isort), CVC5ApiException);
 }
 }  // namespace test
 }  // namespace cvc5
index fb23ea51586ee36bc1c7f0ebbbd6fbd98adb0bbf..e94785b085f93fde9f74bdffeb25b856fd08e96f 100644 (file)
@@ -248,6 +248,8 @@ class DatatypeTest
     Sort pairType = d_solver.mkDatatypeSort(pairSpec);
 
     assertTrue(pairType.getDatatype().isParametric());
+    Sort[] dparams = pairType.getDatatype().getParameters();
+    assertTrue(dparams[0].equals(t1) && dparams[1].equals(t2));
 
     v.clear();
     v.add(d_solver.getIntegerSort());
@@ -562,10 +564,12 @@ class DatatypeTest
 
     AtomicReference<Term> atomicTerm = new AtomicReference<>();
     // get the specialized constructor term for list[Int]
-    assertDoesNotThrow(() -> atomicTerm.set(nilc.getSpecializedConstructorTerm(listInt)));
+    assertDoesNotThrow(
+        () -> atomicTerm.set(nilc.getInstantiatedConstructorTerm(listInt)));
     Term testConsTerm = atomicTerm.get();
     assertNotEquals(testConsTerm, nilc.getConstructorTerm());
     // error to get the specialized constructor term for Int
-    assertThrows(CVC5ApiException.class, () -> nilc.getSpecializedConstructorTerm(isort));
+    assertThrows(CVC5ApiException.class,
+        () -> nilc.getInstantiatedConstructorTerm(isort));
   }
 }
index 43124d4dc17de0b43953cdd1bfa9926b6d8f73cc..af34e098e129aa71841c3f304aec68c82f5173a7 100644 (file)
@@ -155,6 +155,8 @@ def test_datatype_structs(solver):
     dtypeSpec.addConstructor(nil)
     dtypeSort = solver.mkDatatypeSort(dtypeSpec)
     dt = dtypeSort.getDatatype()
+    # not parametric datatype
+    with pytest.raises(RuntimeError): dt.getParameters()
     assert not dt.isCodatatype()
     assert not dt.isTuple()
     assert not dt.isRecord()
@@ -262,7 +264,7 @@ def test_parametric_datatype(solver):
     v.append(t1)
     v.append(t2)
     pairSpec = solver.mkDatatypeDecl("pair", v)
-
+        
     mkpair = solver.mkDatatypeConstructorDecl("mk-pair")
     mkpair.addSelector("first", t1)
     mkpair.addSelector("second", t2)
@@ -271,6 +273,8 @@ def test_parametric_datatype(solver):
     pairType = solver.mkDatatypeSort(pairSpec)
 
     assert pairType.getDatatype().isParametric()
+    dparams = pairType.getDatatype().getParameters()
+    assert dparams[0]==t1 and dparams[1]==t2
 
     v.clear()
     v.append(solver.getIntegerSort())
@@ -558,8 +562,8 @@ def test_datatype_specialized_cons(solver):
 
     testConsTerm = Term(solver)
     # get the specialized constructor term for list[Int]
-    testConsTerm = nilc.getSpecializedConstructorTerm(listInt)
+    testConsTerm = nilc.getInstantiatedConstructorTerm(listInt)
     assert testConsTerm != nilc.getConstructorTerm()
     # error to get the specialized constructor term for Int
     with pytest.raises(RuntimeError):
-        nilc.getSpecializedConstructorTerm(isort)
+        nilc.getInstantiatedConstructorTerm(isort)