From: Andrew Reynolds Date: Sat, 28 Sep 2019 02:01:54 +0000 (-0500) Subject: Support smt2 language "match" term (#3258) X-Git-Tag: cvc5-1.0.0~3928 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=e25f99329c9905c67a565481dcb0d6a4499a7557;p=cvc5.git Support smt2 language "match" term (#3258) --- diff --git a/src/expr/node.h b/src/expr/node.h index f0ee7a56c..b8a665f0c 100644 --- a/src/expr/node.h +++ b/src/expr/node.h @@ -468,7 +468,8 @@ public: inline bool isClosure() const { assertTNodeNotExpired(); return getKind() == kind::LAMBDA || getKind() == kind::FORALL - || getKind() == kind::EXISTS || getKind() == kind::CHOICE; + || getKind() == kind::EXISTS || getKind() == kind::CHOICE + || getKind() == kind::MATCH_BIND_CASE; } /** diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 21e09317d..9a8232df9 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -1810,15 +1810,14 @@ termNonVariable[CVC4::Expr& expr, CVC4::Expr& expr2] std::string attr; Expr attexpr; std::vector patexprs; - std::vector patconds; + std::vector matchcases; std::unordered_set names; std::vector< std::pair > binders; - int match_vindex = -1; - std::vector match_ptypes; Type type; Type type2; api::Term atomTerm; ParseOp p; + std::vector argTypes; } : LPAREN_TOK quantOp[kind] LPAREN_TOK sortedVarList[sortedVarNames] RPAREN_TOK @@ -1912,107 +1911,92 @@ termNonVariable[CVC4::Expr& expr, CVC4::Expr& expr2] } LPAREN_TOK ( - /* match cases */ - LPAREN_TOK INDEX_TOK term[f, f2] { - if( match_vindex==-1 ){ - match_vindex = (int)patexprs.size(); + // case with non-nullary pattern + LPAREN_TOK LPAREN_TOK term[f, f2] { + args.clear(); + PARSER_STATE->pushScope(true); + // f should be a constructor + type = f.getType(); + Debug("parser-dt") << "Pattern head : " << f << " " << type << std::endl; + if (!type.isConstructor()) + { + PARSER_STATE->parseError("Pattern must be application of a constructor or a variable."); + } + if (Datatype::datatypeOf(f).isParametric()) + { + type = Datatype::datatypeOf(f)[Datatype::indexOf(f)].getSpecializedConstructorType(expr.getType()); + } + argTypes = static_cast(type).getArgTypes(); + } + // arguments of the pattern + ( symbol[name,CHECK_NONE,SYM_VARIABLE] { + if (args.size() >= argTypes.size()) + { + PARSER_STATE->parseError("Too many arguments for pattern."); + } + //make of proper type + Expr arg = PARSER_STATE->mkBoundVar(name, argTypes[args.size()]); + args.push_back( arg ); } - patexprs.push_back( f ); - patconds.push_back(MK_CONST(bool(true))); + )* + RPAREN_TOK term[f3, f2] { + // make the match case + std::vector cargs; + cargs.push_back(f); + cargs.insert(cargs.end(),args.begin(),args.end()); + Expr c = MK_EXPR(kind::APPLY_CONSTRUCTOR,cargs); + Expr bvl = MK_EXPR(kind::BOUND_VAR_LIST,args); + Expr mc = MK_EXPR(kind::MATCH_BIND_CASE, bvl, c, f3); + matchcases.push_back(mc); + // now, pop the scope + PARSER_STATE->popScope(); } RPAREN_TOK - | LPAREN_TOK LPAREN_TOK term[f, f2] { - args.clear(); - PARSER_STATE->pushScope(true); - //f should be a constructor - type = f.getType(); - Debug("parser-dt") << "Pattern head : " << f << " " << f.getType() << std::endl; - if( !type.isConstructor() ){ - PARSER_STATE->parseError("Pattern must be application of a constructor or a variable."); - } - if( Datatype::datatypeOf(f).isParametric() ){ - type = Datatype::datatypeOf(f)[Datatype::indexOf(f)].getSpecializedConstructorType(expr.getType()); - } - match_ptypes = ((ConstructorType)type).getArgTypes(); - } - //arguments - ( symbol[name,CHECK_NONE,SYM_VARIABLE] { - if( args.size()>=match_ptypes.size() ){ - PARSER_STATE->parseError("Too many arguments for pattern."); - } - //make of proper type - Expr arg = PARSER_STATE->mkBoundVar(name, match_ptypes[args.size()]); - args.push_back( arg ); - } - )* - RPAREN_TOK - term[f3, f2] { - const DatatypeConstructor& dtc = Datatype::datatypeOf(f)[Datatype::indexOf(f)]; - if( args.size()!=dtc.getNumArgs() ){ - PARSER_STATE->parseError("Bad number of arguments for application of constructor in pattern."); - } - //FIXME: make MATCH a kind and make this a rewrite - // build a lambda - std::vector largs; - largs.push_back( MK_EXPR( CVC4::kind::BOUND_VAR_LIST, args ) ); - largs.push_back( f3 ); - std::vector< Expr > aargs; - aargs.push_back( MK_EXPR( CVC4::kind::LAMBDA, largs ) ); - for( unsigned i=0; ipopScope(); } - | LPAREN_TOK symbol[name,CHECK_DECLARED,SYM_VARIABLE] { - f = PARSER_STATE->getVariable(name); - type = f.getType(); - if( !type.isConstructor() || !((ConstructorType)type).getArgTypes().empty() ){ - PARSER_STATE->parseError("Must apply constructors of arity greater than 0 to arguments in pattern."); - } - } - term[f3, f2] { - const DatatypeConstructor& dtc = Datatype::datatypeOf(f)[Datatype::indexOf(f)]; - patexprs.push_back( f3 ); - patconds.push_back( MK_EXPR( CVC4::kind::APPLY_TESTER, dtc.getTester(), expr ) ); - } - RPAREN_TOK - )+ - RPAREN_TOK RPAREN_TOK { - if( match_vindex==-1 ){ - const Datatype& dt = ((DatatypeType)expr.getType()).getDatatype(); - std::map< unsigned, bool > processed; - unsigned count = 0; - //ensure that all datatype constructors are matched (to ensure exhaustiveness) - for( unsigned i=0; i=dt.getNumConstructors() ){ - PARSER_STATE->parseError("Pattern is not legal for the head of a match."); + // case with nullary or variable pattern + | LPAREN_TOK symbol[name,CHECK_NONE,SYM_VARIABLE] { + if (PARSER_STATE->isDeclared(name,SYM_VARIABLE)) + { + f = PARSER_STATE->getVariable(name); + type = f.getType(); + if (!type.isConstructor() || + !((ConstructorType)type).getArgTypes().empty()) + { + PARSER_STATE->parseError("Must apply constructors of arity greater than 0 to arguments in pattern."); + } + // make nullary constructor application + f = MK_EXPR(kind::APPLY_CONSTRUCTOR, f); } - if( processed.find( curr_index )==processed.end() ){ - processed[curr_index] = true; - count++; + else + { + // it has the type of the head expr + f = PARSER_STATE->mkBoundVar(name, expr.getType()); } } - if( count!=dt.getNumConstructors() ){ - PARSER_STATE->parseError("Patterns are not exhaustive in a match construct."); - } - } - //now, make the ITE - int end_index = match_vindex==-1 ? patexprs.size()-1 : match_vindex; - bool first_time = true; - for( int index = end_index; index>=0; index-- ){ - if( first_time ){ - expr = patexprs[index]; - first_time = false; - }else{ - expr = MK_EXPR( CVC4::kind::ITE, patconds[index], patexprs[index], expr ); + term[f3, f2] { + Expr mc; + if (f.getKind() == kind::BOUND_VARIABLE) + { + Expr bvl = MK_EXPR(kind::BOUND_VAR_LIST, f); + mc = MK_EXPR(kind::MATCH_BIND_CASE, bvl, f, f3); + } + else + { + mc = MK_EXPR(kind::MATCH_CASE, f, f3); + } + matchcases.push_back(mc); } + RPAREN_TOK + )+ + RPAREN_TOK RPAREN_TOK { + //now, make the match + if (matchcases.empty()) + { + PARSER_STATE->parseError("Must have at least one case in match."); } + std::vector mchildren; + mchildren.push_back(expr); + mchildren.insert(mchildren.end(), matchcases.begin(), matchcases.end()); + expr = MK_EXPR(kind::MATCH, mchildren); } /* attributed expressions */ diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index df9bee981..013288880 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -540,7 +540,31 @@ void Smt2Printer::toStream(std::ostream& out, } return; - case kind::LAMBDA: + case kind::LAMBDA: out << smtKindString(k, d_variant) << " "; break; + case kind::MATCH: + out << smtKindString(k, d_variant) << " "; + toStream(out, n[0], toDepth, types, TypeNode::null()); + out << " ("; + for (size_t i = 1, nchild = n.getNumChildren(); i < nchild; i++) + { + if (i > 1) + { + out << " "; + } + toStream(out, n[i], toDepth, types, TypeNode::null()); + } + out << "))"; + return; + case kind::MATCH_BIND_CASE: + // ignore the binder + toStream(out, n[1], toDepth, types, TypeNode::null()); + out << " "; + toStream(out, n[2], toDepth, types, TypeNode::null()); + out << ")"; + return; + case kind::MATCH_CASE: + // do nothing + break; case kind::CHOICE: out << smtKindString(k, d_variant) << " "; break; // arith theory @@ -1030,6 +1054,7 @@ static string smtKindString(Kind k, Variant v) case kind::LAMBDA: return "lambda"; + case kind::MATCH: return "match"; case kind::CHOICE: return "choice"; // arith theory diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index ac3bff21b..802dedcbd 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -148,6 +148,96 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) return RewriteResponse(REWRITE_AGAIN_FULL, ret); } } + else if (k == MATCH) + { + Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl; + Node h = in[0]; + std::vector cases; + std::vector rets; + TypeNode t = h.getType(); + const Datatype& dt = t.getDatatype(); + for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++) + { + Node c = in[k]; + Node cons; + Kind ck = c.getKind(); + if (ck == MATCH_CASE) + { + Assert(c[0].getKind() == APPLY_CONSTRUCTOR); + cons = c[0].getOperator(); + } + else if (ck == MATCH_BIND_CASE) + { + if (c[1].getKind() == APPLY_CONSTRUCTOR) + { + cons = c[1].getOperator(); + } + } + else + { + AlwaysAssert(false); + } + size_t cindex = 0; + // cons is null in the default case + if (!cons.isNull()) + { + cindex = Datatype::indexOf(cons.toExpr()); + } + Node body; + if (ck == MATCH_CASE) + { + body = c[1]; + } + else if (ck == MATCH_BIND_CASE) + { + std::vector vars; + std::vector subs; + if (cons.isNull()) + { + Assert(c[1].getKind() == BOUND_VARIABLE); + vars.push_back(c[1]); + subs.push_back(h); + } + else + { + for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++) + { + vars.push_back(c[0][i]); + Node sc = nm->mkNode( + APPLY_SELECTOR_TOTAL, + Node::fromExpr(dt[cindex].getSelectorInternal(t.toType(), i)), + h); + subs.push_back(sc); + } + } + body = + c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); + } + if (!cons.isNull()) + { + cases.push_back(mkTester(h, cindex, dt)); + } + else + { + // variables have no constraints + cases.push_back(nm->mkConst(true)); + } + rets.push_back(body); + } + Assert(!cases.empty()); + // now make the ITE + std::reverse(cases.begin(), cases.end()); + std::reverse(rets.begin(), rets.end()); + Node ret = rets[0]; + AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors()); + for (unsigned i = 1, ncases = cases.size(); i < ncases; i++) + { + ret = nm->mkNode(ITE, cases[i], rets[i], ret); + } + Trace("dt-rewrite-match") + << "Rewrite match: " << in << " ... " << ret << std::endl; + return RewriteResponse(REWRITE_AGAIN_FULL, ret); + } if (k == kind::EQUAL) { diff --git a/src/theory/datatypes/kinds b/src/theory/datatypes/kinds index a0b00bcb0..22d13da0c 100644 --- a/src/theory/datatypes/kinds +++ b/src/theory/datatypes/kinds @@ -119,4 +119,23 @@ typerule DT_SYGUS_BOUND ::CVC4::theory::datatypes::DtSygusBoundTypeRule operator DT_SYGUS_EVAL 1: "datatypes sygus evaluation function" typerule DT_SYGUS_EVAL ::CVC4::theory::datatypes::DtSyguEvalTypeRule + +# Kinds for match terms. For example, the match term +# (match l (((cons h t) h) (nil 0))) +# is represented by the AST +# (MATCH l +# (MATCH_BIND_CASE (BOUND_VAR_LIST h t) (cons h t) h) +# (MATCH_CASE nil 0) +# ) +# where notice that patterns with free variables use MATCH_BIND_CASE whereas +# patterns with no free variables use MATCH_CASE. + +operator MATCH 2: "match construct" +operator MATCH_CASE 2 "a match case" +operator MATCH_BIND_CASE 3 "a match case with bound variables" + +typerule MATCH ::CVC4::theory::datatypes::MatchTypeRule +typerule MATCH_CASE ::CVC4::theory::datatypes::MatchCaseTypeRule +typerule MATCH_BIND_CASE ::CVC4::theory::datatypes::MatchBindCaseTypeRule + endtheory diff --git a/src/theory/datatypes/theory_datatypes_type_rules.h b/src/theory/datatypes/theory_datatypes_type_rules.h index 22ac074f0..c8c16f368 100644 --- a/src/theory/datatypes/theory_datatypes_type_rules.h +++ b/src/theory/datatypes/theory_datatypes_type_rules.h @@ -427,7 +427,165 @@ class DtSyguEvalTypeRule } return TypeNode::fromType(dt.getSygusType()); } -}; /* class DtSygusBoundTypeRule */ +}; /* class DtSyguEvalTypeRule */ + +class MatchTypeRule +{ + public: + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) + { + Assert(n.getKind() == kind::MATCH); + + TypeNode retType; + + TypeNode headType = n[0].getType(check); + if (!headType.isDatatype()) + { + throw TypeCheckingExceptionPrivate(n, "expecting datatype head in match"); + } + const Datatype& hdt = headType.getDatatype(); + + std::unordered_set patIndices; + bool patHasVariable = false; + // the type of a match case list is the least common type of its cases + for (unsigned i = 1, nchildren = n.getNumChildren(); i < nchildren; i++) + { + Node nc = n[i]; + if (check) + { + Kind nck = nc.getKind(); + std::unordered_set bvs; + if (nck == kind::MATCH_BIND_CASE) + { + for (const Node& v : nc[0]) + { + Assert(v.getKind() == kind::BOUND_VARIABLE); + bvs.insert(v); + } + } + else if (nck != kind::MATCH_CASE) + { + throw TypeCheckingExceptionPrivate( + n, "expected a match case in match expression"); + } + // get the pattern type + unsigned pindex = nck == kind::MATCH_CASE ? 0 : 1; + TypeNode patType = nc[pindex].getType(); + // should be caught in the above call + if (!patType.isDatatype()) + { + throw TypeCheckingExceptionPrivate( + n, "expecting datatype pattern in match"); + } + Kind ncpk = nc[pindex].getKind(); + if (ncpk == kind::APPLY_CONSTRUCTOR) + { + for (const Node& arg : nc[pindex]) + { + if (bvs.find(arg) == bvs.end()) + { + throw TypeCheckingExceptionPrivate( + n, + "expecting distinct bound variable as argument to " + "constructor in pattern of match"); + } + bvs.erase(arg); + } + unsigned ci = Datatype::indexOf(nc[pindex].getOperator().toExpr()); + patIndices.insert(ci); + } + else if (ncpk == kind::BOUND_VARIABLE) + { + patHasVariable = true; + } + else + { + throw TypeCheckingExceptionPrivate( + n, "unexpected kind of term in pattern in match"); + } + const Datatype& pdt = patType.getDatatype(); + // compare datatypes instead of the types to catch parametric case, + // where the pattern has parametric type. + if (hdt != pdt) + { + std::stringstream ss; + ss << "pattern of a match case does not match the head type in match"; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + TypeNode currType = nc.getType(check); + if (i == 1) + { + retType = currType; + } + else + { + retType = TypeNode::leastCommonTypeNode(retType, currType); + if (retType.isNull()) + { + throw TypeCheckingExceptionPrivate( + n, "incomparable types in match case list"); + } + } + } + if (check) + { + if (!patHasVariable && patIndices.size() < hdt.getNumConstructors()) + { + throw TypeCheckingExceptionPrivate( + n, "cases for match term are not exhaustive"); + } + } + return retType; + } +}; /* class MatchTypeRule */ + +class MatchCaseTypeRule +{ + public: + inline static TypeNode computeType(NodeManager* nodeManager, + TNode n, + bool check) + { + Assert(n.getKind() == kind::MATCH_CASE); + if (check) + { + TypeNode patType = n[0].getType(check); + if (!patType.isDatatype()) + { + throw TypeCheckingExceptionPrivate( + n, "expecting datatype pattern in match case"); + } + } + return n[1].getType(check); + } +}; /* class MatchCaseTypeRule */ + +class MatchBindCaseTypeRule +{ + public: + inline static TypeNode computeType(NodeManager* nodeManager, + TNode n, + bool check) + { + Assert(n.getKind() == kind::MATCH_BIND_CASE); + if (check) + { + if (n[0].getKind() != kind::BOUND_VAR_LIST) + { + throw TypeCheckingExceptionPrivate( + n, "expected a bound variable list in match bind case"); + } + TypeNode patType = n[1].getType(check); + if (!patType.isDatatype()) + { + throw TypeCheckingExceptionPrivate( + n, "expecting datatype pattern in match bind case"); + } + } + return n[2].getType(check); + } +}; /* class MatchBindCaseTypeRule */ } /* CVC4::theory::datatypes namespace */ } /* CVC4::theory namespace */ diff --git a/test/regress/regress1/datatypes/dt-color-2.6.smt2 b/test/regress/regress1/datatypes/dt-color-2.6.smt2 index f6148994e..f5d8129ae 100644 --- a/test/regress/regress1/datatypes/dt-color-2.6.smt2 +++ b/test/regress/regress1/datatypes/dt-color-2.6.smt2 @@ -11,7 +11,7 @@ (assert (or (distinct a b c d) (< (match a ((red 5) (green 3) (blue 2))) 0) - (< (match b ((red 2) (_ 1))) 0) + (< (match b ((red 2) (x 1))) 0) )) (check-sat)