Add missing datatype functions to new API (#3930)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 11 Mar 2020 21:05:59 +0000 (16:05 -0500)
committerGitHub <noreply@github.com>
Wed, 11 Mar 2020 21:05:59 +0000 (14:05 -0700)
This is in preparation for migrating the parser to use the Term-level API for datatypes.

Notably, this adds the function mkDatatypeSorts for making mutually recursive datatypes. I've added a unit test that demonstrates this method (which mirrors the Expr-level datatype API).

src/api/cvc4cpp.cpp
src/api/cvc4cpp.h
test/unit/api/datatype_api_black.h

index 3b28e2f5cd4b2a40c5a35d9f5237e6fcd869585a..ff25bbabb85545231718abca03eacffab29d3002 100644 (file)
@@ -1878,11 +1878,17 @@ std::string DatatypeDecl::toString() const
   return ss.str();
 }
 
+std::string DatatypeDecl::getName() const
+{
+  CVC4_API_CHECK_NOT_NULL;
+  return d_dtype->getName();
+}
+
 bool DatatypeDecl::isNull() const { return isNullHelper(); }
 
 // !!! This is only temporarily available until the parser is fully migrated
 // to the new API. !!!
-const CVC4::Datatype& DatatypeDecl::getDatatype(void) const { return *d_dtype; }
+CVC4::Datatype& DatatypeDecl::getDatatype(void) const { return *d_dtype; }
 
 std::ostream& operator<<(std::ostream& out,
                          const DatatypeSelectorDecl& stordecl)
@@ -1911,6 +1917,11 @@ Term DatatypeSelector::getSelectorTerm() const
   return sel;
 }
 
+Sort DatatypeSelector::getRangeSort() const
+{
+  return Sort(d_stor->getRangeType());
+}
+
 std::string DatatypeSelector::toString() const
 {
   std::stringstream ss;
@@ -2458,6 +2469,33 @@ Term Solver::mkTermInternal(Kind kind, const std::vector<Term>& children) const
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
 
+std::vector<Sort> Solver::mkDatatypeSortsInternal(
+    std::vector<DatatypeDecl>& dtypedecls,
+    std::set<Sort>& unresolvedSorts) const
+{
+  CVC4_API_SOLVER_TRY_CATCH_BEGIN;
+
+  std::vector<CVC4::Datatype> datatypes;
+  for (size_t i = 0, ndts = dtypedecls.size(); i < ndts; i++)
+  {
+    CVC4_API_ARG_CHECK_EXPECTED(dtypedecls[i].getNumConstructors() > 0,
+                                dtypedecls[i])
+        << "a datatype declaration with at least one constructor";
+    datatypes.push_back(dtypedecls[i].getDatatype());
+  }
+  std::set<Type> utypes = sortSetToTypes(unresolvedSorts);
+  std::vector<CVC4::DatatypeType> dtypes =
+      d_exprMgr->mkMutualDatatypeTypes(datatypes, utypes);
+  std::vector<Sort> retTypes;
+  for (CVC4::DatatypeType t : dtypes)
+  {
+    retTypes.push_back(Sort(t));
+  }
+  return retTypes;
+
+  CVC4_API_SOLVER_TRY_CATCH_END;
+}
+
 /* Helpers for converting vectors.                                            */
 /* .......................................................................... */
 
@@ -2604,6 +2642,19 @@ Sort Solver::mkDatatypeSort(DatatypeDecl dtypedecl) const
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
 
+std::vector<Sort> Solver::mkDatatypeSorts(
+    std::vector<DatatypeDecl>& dtypedecls) const
+{
+  std::set<Sort> unresolvedSorts;
+  return mkDatatypeSortsInternal(dtypedecls, unresolvedSorts);
+}
+
+std::vector<Sort> Solver::mkDatatypeSorts(std::vector<DatatypeDecl>& dtypedecls,
+                                          std::set<Sort>& unresolvedSorts) const
+{
+  return mkDatatypeSortsInternal(dtypedecls, unresolvedSorts);
+}
+
 Sort Solver::mkFunctionSort(Sort domain, Sort codomain) const
 {
   CVC4_API_SOLVER_TRY_CATCH_BEGIN;
index db29359c5cd1c4f8f6c11c9dc649a93454d369ba..dcf787b8eaa95f6dd645b72cc22f0a852fe7ee4b 100644 (file)
@@ -1260,9 +1260,12 @@ class CVC4_PUBLIC DatatypeDecl
    */
   std::string toString() const;
 
+  /** @return the name of this datatype declaration. */
+  std::string getName() const;
+
   // !!! This is only temporarily available until the parser is fully migrated
   // to the new API. !!!
-  const CVC4::Datatype& getDatatype(void) const;
+  CVC4::Datatype& getDatatype(void) const;
 
  private:
   /**
@@ -1350,6 +1353,9 @@ class CVC4_PUBLIC DatatypeSelector
    */
   Term getSelectorTerm() const;
 
+  /** @return the range sort of this argument. */
+  Sort getRangeSort() const;
+
   /**
    * @return a string representation of this datatype selector
    */
@@ -1920,6 +1926,37 @@ class CVC4_PUBLIC Solver
    */
   Sort mkDatatypeSort(DatatypeDecl dtypedecl) const;
 
+  /**
+   * Create a vector of datatype sorts. The names of the datatype declarations
+   * must be distinct.
+   *
+   * @param dtypedecls the datatype declarations from which the sort is created
+   * @return the datatype sorts
+   */
+  std::vector<Sort> mkDatatypeSorts(
+      std::vector<DatatypeDecl>& dtypedecls) const;
+
+  /**
+   * Create a vector of datatype sorts using unresolved sorts. The names of
+   * the datatype declarations in dtypedecls must be distinct.
+   *
+   * This method is called when the DatatypeDecl objects dtypedecls have been
+   * built using "unresolved" sorts.
+   *
+   * We associate each sort in unresolvedSorts with exacly one datatype from
+   * dtypedecls. In particular, it must have the same name as exactly one
+   * datatype declaration in dtypedecls.
+   *
+   * When constructing datatypes, unresolved sorts are replaced by the datatype
+   * sort constructed for the datatype declaration it is associated with.
+   *
+   * @param dtypedecls the datatype declarations from which the sort is created
+   * @param unresolvedSorts the list of unresolved sorts
+   * @return the datatype sorts
+   */
+  std::vector<Sort> mkDatatypeSorts(std::vector<DatatypeDecl>& dtypedecls,
+                                    std::set<Sort>& unresolvedSorts) const;
+
   /**
    * Create function sort.
    * @param domain the sort of the fuction argument
@@ -2867,6 +2904,16 @@ class CVC4_PUBLIC Solver
    */
   Term mkTermInternal(Kind kind, const std::vector<Term>& children) const;
 
+  /**
+   * Create a vector of datatype sorts, using unresolved sorts.
+   * @param dtypedecls the datatype declarations from which the sort is created
+   * @param unresolvedSorts the list of unresolved sorts
+   * @return the datatype sorts
+   */
+  std::vector<Sort> mkDatatypeSortsInternal(
+      std::vector<DatatypeDecl>& dtypedecls,
+      std::set<Sort>& unresolvedSorts) const;
+
   /* The expression manager of this solver. */
   std::unique_ptr<ExprManager> d_exprMgr;
   /* The SMT engine of this solver. */
index dcccd2628b05840bf7bb33fadd8fc8e8fe1bf17f..f252826246e2b9981f5a1d5e17f0a1e3c4d1e3a5 100644 (file)
@@ -27,6 +27,7 @@ class DatatypeBlack : public CxxTest::TestSuite
   void tearDown() override;
 
   void testMkDatatypeSort();
+  void testMkDatatypeSorts();
 
   void testDatatypeStructs();
   void testDatatypeNames();
@@ -57,6 +58,77 @@ void DatatypeBlack::testMkDatatypeSort()
   TS_ASSERT_THROWS_NOTHING(nilConstr.getConstructorTerm());
 }
 
+void DatatypeBlack::testMkDatatypeSorts()
+{
+  /* Create two mutual datatypes corresponding to this definition
+   * block:
+   *
+   *   DATATYPE
+   *     tree = node(left: tree, right: tree) | leaf(data: list),
+   *     list = cons(car: tree, cdr: list) | nil
+   *   END;
+   */
+  // Make unresolved types as placeholders
+  std::set<Sort> unresTypes;
+  Sort unresTree = d_solver.mkUninterpretedSort("tree");
+  Sort unresList = d_solver.mkUninterpretedSort("list");
+  unresTypes.insert(unresTree);
+  unresTypes.insert(unresList);
+
+  DatatypeDecl tree = d_solver.mkDatatypeDecl("tree");
+  DatatypeConstructorDecl node("node");
+  DatatypeSelectorDecl left("left", unresTree);
+  node.addSelector(left);
+  DatatypeSelectorDecl right("right", unresTree);
+  node.addSelector(right);
+  tree.addConstructor(node);
+
+  DatatypeConstructorDecl leaf("leaf");
+  DatatypeSelectorDecl data("data", unresList);
+  leaf.addSelector(data);
+  tree.addConstructor(leaf);
+
+  DatatypeDecl list = d_solver.mkDatatypeDecl("list");
+  DatatypeConstructorDecl cons("cons");
+  DatatypeSelectorDecl car("car", unresTree);
+  cons.addSelector(car);
+  DatatypeSelectorDecl cdr("cdr", unresTree);
+  cons.addSelector(cdr);
+  list.addConstructor(cons);
+
+  DatatypeConstructorDecl nil("nil");
+  list.addConstructor(nil);
+
+  std::vector<DatatypeDecl> dtdecls;
+  dtdecls.push_back(tree);
+  dtdecls.push_back(list);
+  std::vector<Sort> dtsorts;
+  TS_ASSERT_THROWS_NOTHING(dtsorts =
+                               d_solver.mkDatatypeSorts(dtdecls, unresTypes));
+  TS_ASSERT(dtsorts.size() == dtdecls.size());
+  for (unsigned i = 0, ndecl = dtdecls.size(); i < ndecl; i++)
+  {
+    TS_ASSERT(dtsorts[i].isDatatype());
+    TS_ASSERT(!dtsorts[i].getDatatype().isFinite());
+    TS_ASSERT(dtsorts[i].getDatatype().getName() == dtdecls[i].getName());
+  }
+  // verify the resolution was correct
+  Datatype dtTree = dtsorts[0].getDatatype();
+  DatatypeConstructor dtcTreeNode = dtTree[0];
+  TS_ASSERT(dtcTreeNode.getName() == "node");
+  DatatypeSelector dtsTreeNodeLeft = dtcTreeNode[0];
+  TS_ASSERT(dtsTreeNodeLeft.getName() == "left");
+  // argument type should have resolved to be recursive
+  TS_ASSERT(dtsTreeNodeLeft.getRangeSort().isDatatype());
+  TS_ASSERT(dtsTreeNodeLeft.getRangeSort() == dtsorts[0]);
+
+  // fails due to empty datatype
+  std::vector<DatatypeDecl> dtdeclsBad;
+  DatatypeDecl emptyD = d_solver.mkDatatypeDecl("emptyD");
+  dtdeclsBad.push_back(emptyD);
+  TS_ASSERT_THROWS(d_solver.mkDatatypeSorts(dtdeclsBad), CVC4ApiException&);
+}
+
 void DatatypeBlack::testDatatypeStructs()
 {
   Sort intSort = d_solver.getIntegerSort();
@@ -138,6 +210,8 @@ void DatatypeBlack::testDatatypeNames()
 
   // create datatype sort to test
   DatatypeDecl dtypeSpec = d_solver.mkDatatypeDecl("list");
+  TS_ASSERT_THROWS_NOTHING(dtypeSpec.getName());
+  TS_ASSERT(dtypeSpec.getName() == std::string("list"));
   DatatypeConstructorDecl cons("cons");
   DatatypeSelectorDecl head("head", intSort);
   cons.addSelector(head);
@@ -163,4 +237,8 @@ void DatatypeBlack::testDatatypeNames()
   // get selector
   DatatypeSelector dselTail = dcons[1];
   TS_ASSERT(dselTail.getName() == std::string("tail"));
+  TS_ASSERT(dselTail.getRangeSort() == dtypeSort);
+
+  // possible to construct null datatype declarations if not using solver
+  TS_ASSERT_THROWS(DatatypeDecl().getName(), CVC4ApiException&);
 }