Ensure match terms are exhaustive in its type rule (#7807)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 15 Dec 2021 23:49:19 +0000 (17:49 -0600)
committerGitHub <noreply@github.com>
Wed, 15 Dec 2021 23:49:19 +0000 (23:49 +0000)
Fixes cvc5/cvc5-projects#382.

Makes it so that we always fully type check match terms before they are rewritten, which guards potential unsoundness.

src/theory/datatypes/datatypes_rewriter.cpp
src/theory/datatypes/theory_datatypes_type_rules.cpp
test/unit/api/cpp/solver_black.cpp

index 196b4f01dfbce5d475b51a22b14fe6aed63566df..93449b637834291283f09139b84a748aff28304c 100644 (file)
@@ -152,6 +152,8 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
   else if (kind == MATCH)
   {
     Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl;
+    // ensure we've type checked
+    TypeNode tin = in.getType();
     Node h = in[0];
     std::vector<Node> cases;
     std::vector<Node> rets;
@@ -228,8 +230,9 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
     std::reverse(cases.begin(), cases.end());
     std::reverse(rets.begin(), rets.end());
     Node ret = rets[0];
-    AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors());
-    for (unsigned i = 1, ncases = cases.size(); i < ncases; i++)
+    // notice that due to our type checker, either there is a variable pattern
+    // or all constructors are present in the match.
+    for (size_t i = 1, ncases = cases.size(); i < ncases; i++)
     {
       ret = nm->mkNode(ITE, cases[i], rets[i], ret);
     }
index 422af1731ab8582bb2393bfbb6cf3ba1b9f2420c..7ce5c0df60a5a94e47cae02aac9ff4b47a14cbab 100644 (file)
@@ -398,67 +398,64 @@ TypeNode MatchTypeRule::computeType(NodeManager* nodeManager,
   for (unsigned i = 1, nchildren = n.getNumChildren(); i < nchildren; i++)
   {
     Node nc = n[i];
-    if (check)
+    Kind nck = nc.getKind();
+    std::unordered_set<Node> bvs;
+    if (nck == kind::MATCH_BIND_CASE)
     {
-      Kind nck = nc.getKind();
-      std::unordered_set<Node> bvs;
-      if (nck == kind::MATCH_BIND_CASE)
-      {
-        for (const Node& v : nc[0])
-        {
-          Assert(v.getKind() == kind::BOUND_VARIABLE);
-          bvs.insert(v);
-        }
-      }
-      else if (nck != kind::MATCH_CASE)
+      for (const Node& v : nc[0])
       {
-        throw TypeCheckingExceptionPrivate(
-            n, "expected a match case in match expression");
+        Assert(v.getKind() == kind::BOUND_VARIABLE);
+        bvs.insert(v);
       }
-      // get the pattern type
-      unsigned pindex = nck == kind::MATCH_CASE ? 0 : 1;
-      TypeNode patType = nc[pindex].getType();
-      // should be caught in the above call
-      if (!patType.isDatatype())
-      {
-        throw TypeCheckingExceptionPrivate(
-            n, "expecting datatype pattern in match");
-      }
-      Kind ncpk = nc[pindex].getKind();
-      if (ncpk == kind::APPLY_CONSTRUCTOR)
+    }
+    else if (nck != kind::MATCH_CASE)
+    {
+      throw TypeCheckingExceptionPrivate(
+          n, "expected a match case in match expression");
+    }
+    // get the pattern type
+    uint32_t pindex = nck == kind::MATCH_CASE ? 0 : 1;
+    TypeNode patType = nc[pindex].getType();
+    // should be caught in the above call
+    if (!patType.isDatatype())
+    {
+      throw TypeCheckingExceptionPrivate(
+          n, "expecting datatype pattern in match");
+    }
+    Kind ncpk = nc[pindex].getKind();
+    if (ncpk == kind::APPLY_CONSTRUCTOR)
+    {
+      for (const Node& arg : nc[pindex])
       {
-        for (const Node& arg : nc[pindex])
+        if (bvs.find(arg) == bvs.end())
         {
-          if (bvs.find(arg) == bvs.end())
-          {
-            throw TypeCheckingExceptionPrivate(
-                n,
-                "expecting distinct bound variable as argument to "
-                "constructor in pattern of match");
-          }
-          bvs.erase(arg);
+          throw TypeCheckingExceptionPrivate(
+              n,
+              "expecting distinct bound variable as argument to "
+              "constructor in pattern of match");
         }
-        unsigned ci = utils::indexOf(nc[pindex].getOperator());
-        patIndices.insert(ci);
-      }
-      else if (ncpk == kind::BOUND_VARIABLE)
-      {
-        patHasVariable = true;
-      }
-      else
-      {
-        throw TypeCheckingExceptionPrivate(
-            n, "unexpected kind of term in pattern in match");
-      }
-      const DType& pdt = patType.getDType();
-      // compare datatypes instead of the types to catch parametric case,
-      // where the pattern has parametric type.
-      if (hdt.getTypeNode() != pdt.getTypeNode())
-      {
-        std::stringstream ss;
-        ss << "pattern of a match case does not match the head type in match";
-        throw TypeCheckingExceptionPrivate(n, ss.str());
+        bvs.erase(arg);
       }
+      size_t ci = utils::indexOf(nc[pindex].getOperator());
+      patIndices.insert(ci);
+    }
+    else if (ncpk == kind::BOUND_VARIABLE)
+    {
+      patHasVariable = true;
+    }
+    else
+    {
+      throw TypeCheckingExceptionPrivate(
+          n, "unexpected kind of term in pattern in match");
+    }
+    const DType& pdt = patType.getDType();
+    // compare datatypes instead of the types to catch parametric case,
+    // where the pattern has parametric type.
+    if (hdt.getTypeNode() != pdt.getTypeNode())
+    {
+      std::stringstream ss;
+      ss << "pattern of a match case does not match the head type in match";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
     }
     TypeNode currType = nc.getType(check);
     if (i == 1)
@@ -475,13 +472,11 @@ TypeNode MatchTypeRule::computeType(NodeManager* nodeManager,
       }
     }
   }
-  if (check)
+  // it is mandatory to check this here to ensure the match is exhaustive
+  if (!patHasVariable && patIndices.size() < hdt.getNumConstructors())
   {
-    if (!patHasVariable && patIndices.size() < hdt.getNumConstructors())
-    {
-      throw TypeCheckingExceptionPrivate(
-          n, "cases for match term are not exhaustive");
-    }
+    throw TypeCheckingExceptionPrivate(
+        n, "cases for match term are not exhaustive");
   }
   return retType;
 }
index b170546374dc7bea012f7ea18af29366335f99d8..31034a15e2bcbed78371805d995bb457ec88df0b 100644 (file)
@@ -2777,6 +2777,40 @@ TEST_F(TestApiBlackSolver, proj_issue381)
   ASSERT_NO_THROW(d_solver.simplify(t187));
 }
 
+
+TEST_F(TestApiBlackSolver, proj_issue382)
+{
+  Sort s1 = d_solver.getBooleanSort();
+  Sort psort = d_solver.mkParamSort("_x1");
+  DatatypeConstructorDecl ctor = d_solver.mkDatatypeConstructorDecl("_x20");
+  ctor.addSelector("_x19", psort);
+  DatatypeDecl dtdecl = d_solver.mkDatatypeDecl("_x0", psort);
+  dtdecl.addConstructor(ctor);
+  Sort s2 = d_solver.mkDatatypeSort(dtdecl);
+  Sort s6 = s2.instantiate({s1});
+  Term t13 = d_solver.mkVar(s6, "_x58");
+  Term t18 = d_solver.mkConst(s6, "_x63");
+  Term t52 = d_solver.mkVar(s6, "_x70");
+  Term t53 = d_solver.mkTerm(
+      MATCH_BIND_CASE, d_solver.mkTerm(VARIABLE_LIST, t52), t52, t18);
+  Term t73 = d_solver.mkVar(s1, "_x78");
+  Term t81 =
+      d_solver.mkTerm(MATCH_BIND_CASE,
+                      d_solver.mkTerm(VARIABLE_LIST, t73),
+                      d_solver.mkTerm(APPLY_CONSTRUCTOR,
+                                      s6.getDatatype()
+                                          .getConstructor("_x20")
+                                          .getInstantiatedConstructorTerm(s6),
+                                      t73),
+                      t18);
+  Term t82 = d_solver.mkTerm(MATCH, {t13, t53, t53, t53, t81});
+  Term t325 = d_solver.mkTerm(
+      APPLY_SELECTOR,
+      t82.getSort().getDatatype().getSelector("_x19").getSelectorTerm(),
+      t82);
+  ASSERT_NO_THROW(d_solver.simplify(t325));
+}
+
 TEST_F(TestApiBlackSolver, proj_issue383)
 {
   d_solver.setOption("produce-models", "true");