From 4a7c0c73f69aabb20be4c79c47047ce23d3358b0 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 15 Dec 2021 17:49:19 -0600 Subject: [PATCH] Ensure match terms are exhaustive in its type rule (#7807) Fixes cvc5/cvc5-projects#382. Makes it so that we always fully type check match terms before they are rewritten, which guards potential unsoundness. --- src/theory/datatypes/datatypes_rewriter.cpp | 7 +- .../datatypes/theory_datatypes_type_rules.cpp | 115 +++++++++--------- test/unit/api/cpp/solver_black.cpp | 34 ++++++ 3 files changed, 94 insertions(+), 62 deletions(-) diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index 196b4f01d..93449b637 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -152,6 +152,8 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) else if (kind == MATCH) { Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl; + // ensure we've type checked + TypeNode tin = in.getType(); Node h = in[0]; std::vector cases; std::vector rets; @@ -228,8 +230,9 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) std::reverse(cases.begin(), cases.end()); std::reverse(rets.begin(), rets.end()); Node ret = rets[0]; - AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors()); - for (unsigned i = 1, ncases = cases.size(); i < ncases; i++) + // notice that due to our type checker, either there is a variable pattern + // or all constructors are present in the match. + for (size_t i = 1, ncases = cases.size(); i < ncases; i++) { ret = nm->mkNode(ITE, cases[i], rets[i], ret); } diff --git a/src/theory/datatypes/theory_datatypes_type_rules.cpp b/src/theory/datatypes/theory_datatypes_type_rules.cpp index 422af1731..7ce5c0df6 100644 --- a/src/theory/datatypes/theory_datatypes_type_rules.cpp +++ b/src/theory/datatypes/theory_datatypes_type_rules.cpp @@ -398,67 +398,64 @@ TypeNode MatchTypeRule::computeType(NodeManager* nodeManager, for (unsigned i = 1, nchildren = n.getNumChildren(); i < nchildren; i++) { Node nc = n[i]; - if (check) + Kind nck = nc.getKind(); + std::unordered_set bvs; + if (nck == kind::MATCH_BIND_CASE) { - Kind nck = nc.getKind(); - std::unordered_set bvs; - if (nck == kind::MATCH_BIND_CASE) - { - for (const Node& v : nc[0]) - { - Assert(v.getKind() == kind::BOUND_VARIABLE); - bvs.insert(v); - } - } - else if (nck != kind::MATCH_CASE) + for (const Node& v : nc[0]) { - throw TypeCheckingExceptionPrivate( - n, "expected a match case in match expression"); + Assert(v.getKind() == kind::BOUND_VARIABLE); + bvs.insert(v); } - // get the pattern type - unsigned pindex = nck == kind::MATCH_CASE ? 0 : 1; - TypeNode patType = nc[pindex].getType(); - // should be caught in the above call - if (!patType.isDatatype()) - { - throw TypeCheckingExceptionPrivate( - n, "expecting datatype pattern in match"); - } - Kind ncpk = nc[pindex].getKind(); - if (ncpk == kind::APPLY_CONSTRUCTOR) + } + else if (nck != kind::MATCH_CASE) + { + throw TypeCheckingExceptionPrivate( + n, "expected a match case in match expression"); + } + // get the pattern type + uint32_t pindex = nck == kind::MATCH_CASE ? 0 : 1; + TypeNode patType = nc[pindex].getType(); + // should be caught in the above call + if (!patType.isDatatype()) + { + throw TypeCheckingExceptionPrivate( + n, "expecting datatype pattern in match"); + } + Kind ncpk = nc[pindex].getKind(); + if (ncpk == kind::APPLY_CONSTRUCTOR) + { + for (const Node& arg : nc[pindex]) { - for (const Node& arg : nc[pindex]) + if (bvs.find(arg) == bvs.end()) { - if (bvs.find(arg) == bvs.end()) - { - throw TypeCheckingExceptionPrivate( - n, - "expecting distinct bound variable as argument to " - "constructor in pattern of match"); - } - bvs.erase(arg); + throw TypeCheckingExceptionPrivate( + n, + "expecting distinct bound variable as argument to " + "constructor in pattern of match"); } - unsigned ci = utils::indexOf(nc[pindex].getOperator()); - patIndices.insert(ci); - } - else if (ncpk == kind::BOUND_VARIABLE) - { - patHasVariable = true; - } - else - { - throw TypeCheckingExceptionPrivate( - n, "unexpected kind of term in pattern in match"); - } - const DType& pdt = patType.getDType(); - // compare datatypes instead of the types to catch parametric case, - // where the pattern has parametric type. - if (hdt.getTypeNode() != pdt.getTypeNode()) - { - std::stringstream ss; - ss << "pattern of a match case does not match the head type in match"; - throw TypeCheckingExceptionPrivate(n, ss.str()); + bvs.erase(arg); } + size_t ci = utils::indexOf(nc[pindex].getOperator()); + patIndices.insert(ci); + } + else if (ncpk == kind::BOUND_VARIABLE) + { + patHasVariable = true; + } + else + { + throw TypeCheckingExceptionPrivate( + n, "unexpected kind of term in pattern in match"); + } + const DType& pdt = patType.getDType(); + // compare datatypes instead of the types to catch parametric case, + // where the pattern has parametric type. + if (hdt.getTypeNode() != pdt.getTypeNode()) + { + std::stringstream ss; + ss << "pattern of a match case does not match the head type in match"; + throw TypeCheckingExceptionPrivate(n, ss.str()); } TypeNode currType = nc.getType(check); if (i == 1) @@ -475,13 +472,11 @@ TypeNode MatchTypeRule::computeType(NodeManager* nodeManager, } } } - if (check) + // it is mandatory to check this here to ensure the match is exhaustive + if (!patHasVariable && patIndices.size() < hdt.getNumConstructors()) { - if (!patHasVariable && patIndices.size() < hdt.getNumConstructors()) - { - throw TypeCheckingExceptionPrivate( - n, "cases for match term are not exhaustive"); - } + throw TypeCheckingExceptionPrivate( + n, "cases for match term are not exhaustive"); } return retType; } diff --git a/test/unit/api/cpp/solver_black.cpp b/test/unit/api/cpp/solver_black.cpp index b17054637..31034a15e 100644 --- a/test/unit/api/cpp/solver_black.cpp +++ b/test/unit/api/cpp/solver_black.cpp @@ -2777,6 +2777,40 @@ TEST_F(TestApiBlackSolver, proj_issue381) ASSERT_NO_THROW(d_solver.simplify(t187)); } + +TEST_F(TestApiBlackSolver, proj_issue382) +{ + Sort s1 = d_solver.getBooleanSort(); + Sort psort = d_solver.mkParamSort("_x1"); + DatatypeConstructorDecl ctor = d_solver.mkDatatypeConstructorDecl("_x20"); + ctor.addSelector("_x19", psort); + DatatypeDecl dtdecl = d_solver.mkDatatypeDecl("_x0", psort); + dtdecl.addConstructor(ctor); + Sort s2 = d_solver.mkDatatypeSort(dtdecl); + Sort s6 = s2.instantiate({s1}); + Term t13 = d_solver.mkVar(s6, "_x58"); + Term t18 = d_solver.mkConst(s6, "_x63"); + Term t52 = d_solver.mkVar(s6, "_x70"); + Term t53 = d_solver.mkTerm( + MATCH_BIND_CASE, d_solver.mkTerm(VARIABLE_LIST, t52), t52, t18); + Term t73 = d_solver.mkVar(s1, "_x78"); + Term t81 = + d_solver.mkTerm(MATCH_BIND_CASE, + d_solver.mkTerm(VARIABLE_LIST, t73), + d_solver.mkTerm(APPLY_CONSTRUCTOR, + s6.getDatatype() + .getConstructor("_x20") + .getInstantiatedConstructorTerm(s6), + t73), + t18); + Term t82 = d_solver.mkTerm(MATCH, {t13, t53, t53, t53, t81}); + Term t325 = d_solver.mkTerm( + APPLY_SELECTOR, + t82.getSort().getDatatype().getSelector("_x19").getSelectorTerm(), + t82); + ASSERT_NO_THROW(d_solver.simplify(t325)); +} + TEST_F(TestApiBlackSolver, proj_issue383) { d_solver.setOption("produce-models", "true"); -- 2.30.2