Refactor and fixes related to getSpecializedConstructorTerm (#7774)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 10 Dec 2021 21:47:18 +0000 (15:47 -0600)
committerGitHub <noreply@github.com>
Fri, 10 Dec 2021 21:47:18 +0000 (21:47 +0000)
Fixes cvc5/cvc5-projects#381.

13 files changed:
src/api/cpp/cvc5.cpp
src/expr/dtype_cons.cpp
src/expr/dtype_cons.h
src/theory/datatypes/datatypes_rewriter.cpp
src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes_utils.cpp
src/theory/datatypes/type_enumerator.cpp
src/theory/quantifiers/cegqi/ceg_instantiator.cpp
src/theory/quantifiers/quant_split.cpp
src/theory/quantifiers/quantifiers_rewriter.cpp
src/theory/quantifiers/skolemize.cpp
src/theory/quantifiers/sygus/sygus_grammar_cons.cpp
test/unit/api/cpp/solver_black.cpp

index e062e60ed15e1f73675d5fb109ce450f17cb0d92..aa3e7fa3f3652a7a701271ee8493e9968208b240 100644 (file)
@@ -3815,13 +3815,7 @@ Term DatatypeConstructor::getSpecializedConstructorTerm(
       << "Cannot get specialized constructor type for non-datatype type "
       << retSort;
   //////// all checks before this line
-
-  NodeManager* nm = d_solver->getNodeManager();
-  Node ret =
-      nm->mkNode(kind::APPLY_TYPE_ASCRIPTION,
-                 nm->mkConst(AscriptionType(
-                     d_ctor->getSpecializedConstructorType(*retSort.d_type))),
-                 d_ctor->getConstructor());
+  Node ret = d_ctor->getInstantiatedConstructor(*retSort.d_type);
   (void)ret.getType(true); /* kick off type checking */
   // apply type ascription to the operator
   Term sctor = api::Term(d_solver, ret);
index 6ba3970c9696d808ba7e1ebc37742eaf43e257d2..a054dffb890c2ff101f111e544e0ac48a97e020b 100644 (file)
@@ -83,6 +83,16 @@ Node DTypeConstructor::getConstructor() const
   return d_constructor;
 }
 
+Node DTypeConstructor::getInstantiatedConstructor(TypeNode returnType) const
+{
+  Assert(isResolved());
+  NodeManager* nm = NodeManager::currentNM();
+  return nm->mkNode(
+      kind::APPLY_TYPE_ASCRIPTION,
+      nm->mkConst(AscriptionType(getInstantiatedConstructorType(returnType))),
+      d_constructor);
+}
+
 Node DTypeConstructor::getTester() const
 {
   Assert(isResolved());
@@ -116,12 +126,12 @@ unsigned DTypeConstructor::getWeight() const
 
 size_t DTypeConstructor::getNumArgs() const { return d_args.size(); }
 
-TypeNode DTypeConstructor::getSpecializedConstructorType(
+TypeNode DTypeConstructor::getInstantiatedConstructorType(
     TypeNode returnType) const
 {
   Assert(isResolved());
   Assert(returnType.isDatatype())
-      << "DTypeConstructor::getSpecializedConstructorType: expected datatype, "
+      << "DTypeConstructor::getInstantiatedConstructorType: expected datatype, "
          "got "
       << returnType;
   TypeNode ctn = d_constructor.getType();
@@ -439,7 +449,7 @@ Node DTypeConstructor::computeGroundTerm(TypeNode t,
                             << ", ascribe to " << t << std::endl;
     groundTerms[0] = nm->mkNode(
         APPLY_TYPE_ASCRIPTION,
-        nm->mkConst(AscriptionType(getSpecializedConstructorType(t))),
+        nm->mkConst(AscriptionType(getInstantiatedConstructorType(t))),
         groundTerms[0]);
     groundTerm = nm->mkNode(APPLY_CONSTRUCTOR, groundTerms);
   }
@@ -456,7 +466,7 @@ void DTypeConstructor::computeSharedSelectors(TypeNode domainType) const
     TypeNode ctype;
     if (domainType.isParametricDatatype())
     {
-      ctype = getSpecializedConstructorType(domainType);
+      ctype = getInstantiatedConstructorType(domainType);
     }
     else
     {
index a6268aad1454178a623203b1c643d73a0800a341..657f6b7b8536c7d5f0194e8b610452911de0869a 100644 (file)
@@ -85,6 +85,12 @@ class DTypeConstructor
    * DType must be resolved.
    */
   Node getConstructor() const;
+  /**
+   * Get the specialized constructor term of this constructor, which is
+   * the constructor wrapped in a APPLY_TYPE_ASCRIPTION. This is required
+   * for constructing applications of constructors for parametric datatypes.
+   */
+  Node getInstantiatedConstructor(TypeNode returnType) const;
 
   /**
    * Get the tester operator of this constructor.  The
@@ -139,7 +145,7 @@ class DTypeConstructor
    * "cons" constructor type for lists of int---namely,
    * "int -> list[int] -> list[int]".
    */
-  TypeNode getSpecializedConstructorType(TypeNode returnType) const;
+  TypeNode getInstantiatedConstructorType(TypeNode returnType) const;
 
   /**
    * Return the cardinality of this constructor (the product of the
index 903a08bb4d5b44c7b56d087bfbc02d6e5b802b4f..b475d51e7a22309c3c83997158cec0f992f2f903 100644 (file)
@@ -329,10 +329,7 @@ RewriteResponse DatatypesRewriter::preRewrite(TNode in)
         // get the constructor object
         const DTypeConstructor& dtc = utils::datatypeOf(op)[utils::indexOf(op)];
         // create ascribed constructor type
-        Node tc = NodeManager::currentNM()->mkConst(
-            AscriptionType(dtc.getSpecializedConstructorType(tn)));
-        Node op_new = NodeManager::currentNM()->mkNode(
-            kind::APPLY_TYPE_ASCRIPTION, tc, op);
+        Node op_new = dtc.getInstantiatedConstructor(tn);
         // make new node
         std::vector<Node> children;
         children.push_back(op_new);
@@ -891,7 +888,14 @@ TrustNode DatatypesRewriter::expandDefinition(Node n)
       size_t cindex = utils::cindexOf(op);
       const DTypeConstructor& dc = dt[cindex];
       NodeBuilder b(APPLY_CONSTRUCTOR);
-      b << dc.getConstructor();
+      if (tn.isParametricDatatype())
+      {
+        b << dc.getInstantiatedConstructor(n[0].getType());
+      }
+      else
+      {
+        b << dc.getConstructor();
+      }
       Trace("dt-expand") << "Expand updater " << n << std::endl;
       Trace("dt-expand") << "expr is " << n << std::endl;
       Trace("dt-expand") << "updateIndex is " << updateIndex << std::endl;
index a9f0c3198e7e47a616a80ff905d6ec0a3779b09a..3f11ab1da7bf05e39f2737671df467ae7c6a01f5 100644 (file)
@@ -1244,7 +1244,7 @@ bool TheoryDatatypes::collectModelValues(TheoryModel* m,
           for( unsigned i=0; i<pcons.size(); i++ ){
             // must try the infinite ones first
             bool cfinite =
-                d_env.isFiniteType(dt[i].getSpecializedConstructorType(tt));
+                d_env.isFiniteType(dt[i].getInstantiatedConstructorType(tt));
             if( pcons[i] && (r==1)==cfinite ){
               neqc = utils::getInstCons(eqc, dt, i);
               break;
index 3b36ad2f294f227a4c8b8e2703819c4aa81fe0f5..a429f83333349be26053a6dd1bd4be1708c2d749 100644 (file)
@@ -63,12 +63,7 @@ Node mkApplyCons(TypeNode tn,
     // add type ascription for ambiguous constructor types
     Debug("datatypes-parametric")
         << "Constructor is " << dt[index] << std::endl;
-    TypeNode tspec = dt[index].getSpecializedConstructorType(tn);
-    Debug("datatypes-parametric")
-        << "Type specification is " << tspec << std::endl;
-    cchildren[0] = nm->mkNode(APPLY_TYPE_ASCRIPTION,
-                              nm->mkConst(AscriptionType(tspec)),
-                              cchildren[0]);
+    cchildren[0] = dt[index].getInstantiatedConstructor(tn);
   }
   return nm->mkNode(APPLY_CONSTRUCTOR, cchildren);
 }
index 6528f1052ac83baa544433db42ccba6a46876e80..69ebc9c78f903c24c2f753c30630827233c2c12f 100644 (file)
@@ -143,11 +143,7 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){
      NodeBuilder b(kind::APPLY_CONSTRUCTOR);
      if (d_datatype.isParametric())
      {
-       NodeManager* nm = NodeManager::currentNM();
-       TypeNode typ = ctor.getSpecializedConstructorType(d_type);
-       b << nm->mkNode(kind::APPLY_TYPE_ASCRIPTION,
-                       nm->mkConst(AscriptionType(typ)),
-                       ctor.getConstructor());
+       b << ctor.getInstantiatedConstructor(d_type);
      }
      else
      {
@@ -245,7 +241,7 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){
      TypeNode typ;
      if (d_datatype.isParametric())
      {
-       typ = ctor.getSpecializedConstructorType(d_type);
+       typ = ctor.getInstantiatedConstructorType(d_type);
      }
      for (unsigned a = 0; a < ctor.getNumArgs(); ++a)
      {
index 9556d3f9c381c6f0374085a3f8394b5a9391affa..ec33fe5fd55a029960e9ec9489bd3e830e9e01ca 100644 (file)
@@ -357,7 +357,7 @@ CegHandledStatus CegInstantiator::isCbqiSort(
       if (dt.isParametric())
       {
         // if parametric, must instantiate the argument types
-        consType = dt[i].getSpecializedConstructorType(tn);
+        consType = dt[i].getInstantiatedConstructorType(tn);
       }
       else
       {
index 55fa2a1e5a6416c74bbbaa6ef3a02291fddcb143..e6cee778bc16f673816bff276fc72179d10123cc 100644 (file)
@@ -167,7 +167,7 @@ void QuantDSplit::check(Theory::Effort e, QEffort quant_e)
       for (unsigned j = 0, ncons = dt.getNumConstructors(); j < ncons; j++)
       {
         std::vector<Node> vars;
-        TypeNode dtjtn = dt[j].getSpecializedConstructorType(tn);
+        TypeNode dtjtn = dt[j].getInstantiatedConstructorType(tn);
         Assert(dtjtn.getNumChildren() == dt[j].getNumArgs() + 1);
         for (unsigned k = 0, nargs = dt[j].getNumArgs(); k < nargs; k++)
         {
index ba10a2efcc0f3fd3fe4411c6ab77d4169fdd3a44..2002c73dbb3114cdb282e307ae4d08bf081ad42d 100644 (file)
@@ -906,9 +906,9 @@ bool QuantifiersRewriter::getVarElimLit(Node body,
       // take into account if parametric
       if (dt.isParametric())
       {
-        tspec = c.getSpecializedConstructorType(lit[0].getType());
-        cons = nm->mkNode(
-            APPLY_TYPE_ASCRIPTION, nm->mkConst(AscriptionType(tspec)), cons);
+        TypeNode ltn = lit[0].getType();
+        tspec = c.getInstantiatedConstructorType(ltn);
+        cons = c.getInstantiatedConstructor(ltn);
       }
       else
       {
index f116b2f3cb6099fe9eaaf0a3b2650ac53b074083..9f2f9c91c568e4271248c5f73c9edbd4634e2dfb 100644 (file)
@@ -137,8 +137,8 @@ void Skolemize::getSelfSel(const DType& dt,
   TypeNode tspec;
   if (dt.isParametric())
   {
-    tspec = dc.getSpecializedConstructorType(n.getType());
-    Trace("sk-ind-debug") << "Specialized constructor type : " << tspec
+    tspec = dc.getInstantiatedConstructorType(n.getType());
+    Trace("sk-ind-debug") << "Instantiated constructor type : " << tspec
                           << std::endl;
     Assert(tspec.getNumChildren() == dc.getNumArgs());
   }
index 438afbe8268f5023d15f29087a4cafb430856586..95d3a5ab5b834bb1558524f0f7c0fa7e0f5636c2 100644 (file)
@@ -483,7 +483,7 @@ void CegGrammarConstructor::collectSygusGrammarTypesFor(
         {
           // get the specialized constructor type, which accounts for
           // parametric datatypes
-          TypeNode ctn = dt[i].getSpecializedConstructorType(range);
+          TypeNode ctn = dt[i].getInstantiatedConstructorType(range);
           std::vector<TypeNode> argTypes = ctn.getArgTypes();
           for (size_t j = 0, nargs = argTypes.size(); j < nargs; ++j)
           {
@@ -1010,12 +1010,11 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
       {
         Trace("sygus-grammar-def") << "...for " << dt[l].getName() << std::endl;
         Node cop = dt[l].getConstructor();
-        TypeNode tspec = dt[l].getSpecializedConstructorType(types[i]);
+        TypeNode tspec = dt[l].getInstantiatedConstructorType(types[i]);
         // must specialize if a parametric datatype
         if (dt.isParametric())
         {
-          cop = nm->mkNode(
-              APPLY_TYPE_ASCRIPTION, nm->mkConst(AscriptionType(tspec)), cop);
+          cop = dt[l].getInstantiatedConstructor(types[i]);
         }
         if (dt[l].getNumArgs() == 0)
         {
index c268ee4f8d3267adcdc65a1132944c616e16457d..2df5de4b87e7e4cbd38a424373270b3769aa27c7 100644 (file)
@@ -2750,5 +2750,29 @@ TEST_F(TestApiBlackSolver, getDatatypeArity)
   ASSERT_EQ(s3.getDatatypeArity(), 0);
 }
 
+TEST_F(TestApiBlackSolver, proj_issue381)
+{
+  Sort s1 = d_solver.getBooleanSort();
+
+  Sort psort = d_solver.mkParamSort("_x9");
+  DatatypeDecl dtdecl = d_solver.mkDatatypeDecl("_x8", psort);
+  DatatypeConstructorDecl ctor = d_solver.mkDatatypeConstructorDecl("_x22");
+  ctor.addSelector("_x19", s1);
+  dtdecl.addConstructor(ctor);
+  Sort s3 = d_solver.mkDatatypeSort(dtdecl);
+  Sort s6 = s3.instantiate({s1});
+  Term t26 = d_solver.mkConst(s6, "_x63");
+  Term t5 = d_solver.mkTrue();
+  Term t187 = d_solver.mkTerm(APPLY_UPDATER,
+                              t26.getSort()
+                                  .getDatatype()
+                                  .getConstructor("_x22")
+                                  .getSelector("_x19")
+                                  .getUpdaterTerm(),
+                              t26,
+                              t5);
+  ASSERT_NO_THROW(d_solver.simplify(t187));
+}
+
 }  // namespace test
 }  // namespace cvc5