Support smt2 language "match" term (#3258)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 28 Sep 2019 02:01:54 +0000 (21:01 -0500)
committerGitHub <noreply@github.com>
Sat, 28 Sep 2019 02:01:54 +0000 (21:01 -0500)
src/expr/node.h
src/parser/smt2/Smt2.g
src/printer/smt2/smt2_printer.cpp
src/theory/datatypes/datatypes_rewriter.cpp
src/theory/datatypes/kinds
src/theory/datatypes/theory_datatypes_type_rules.h
test/regress/regress1/datatypes/dt-color-2.6.smt2

index f0ee7a56c6a18edffd29639fa50a010f3f88f241..b8a665f0ce80bdfd07de3fd6cc6405fc1bc65d9c 100644 (file)
@@ -468,7 +468,8 @@ public:
   inline bool isClosure() const {
     assertTNodeNotExpired();
     return getKind() == kind::LAMBDA || getKind() == kind::FORALL
-           || getKind() == kind::EXISTS || getKind() == kind::CHOICE;
+           || getKind() == kind::EXISTS || getKind() == kind::CHOICE
+           || getKind() == kind::MATCH_BIND_CASE;
   }
 
   /**
index 21e09317d8de46018a1ded5ea7fb21691494b8b6..9a8232df9a7b50f7916737e0d4080b0e1e05a7e5 100644 (file)
@@ -1810,15 +1810,14 @@ termNonVariable[CVC4::Expr& expr, CVC4::Expr& expr2]
   std::string attr;
   Expr attexpr;
   std::vector<Expr> patexprs;
-  std::vector<Expr> patconds;
+  std::vector<Expr> matchcases;
   std::unordered_set<std::string> names;
   std::vector< std::pair<std::string, Expr> > binders;
-  int match_vindex = -1;
-  std::vector<Type> match_ptypes;
   Type type;
   Type type2;
   api::Term atomTerm;
   ParseOp p;
+  std::vector<Type> argTypes;
 }
   : LPAREN_TOK quantOp[kind]
     LPAREN_TOK sortedVarList[sortedVarNames] RPAREN_TOK
@@ -1912,107 +1911,92 @@ termNonVariable[CVC4::Expr& expr, CVC4::Expr& expr2]
     }
     LPAREN_TOK
     (
-      /* match cases */
-       LPAREN_TOK INDEX_TOK term[f, f2] {
-          if( match_vindex==-1 ){
-            match_vindex = (int)patexprs.size();
+      // case with non-nullary pattern
+      LPAREN_TOK LPAREN_TOK term[f, f2] {
+          args.clear();
+          PARSER_STATE->pushScope(true);
+          // f should be a constructor
+          type = f.getType();
+          Debug("parser-dt") << "Pattern head : " << f << " " << type << std::endl;
+          if (!type.isConstructor())
+          {
+            PARSER_STATE->parseError("Pattern must be application of a constructor or a variable.");
+          }
+          if (Datatype::datatypeOf(f).isParametric())
+          {
+            type = Datatype::datatypeOf(f)[Datatype::indexOf(f)].getSpecializedConstructorType(expr.getType());
+          }
+          argTypes = static_cast<ConstructorType>(type).getArgTypes();
+        }
+        // arguments of the pattern
+        ( symbol[name,CHECK_NONE,SYM_VARIABLE] {
+            if (args.size() >= argTypes.size())
+            {
+              PARSER_STATE->parseError("Too many arguments for pattern.");
+            }
+            //make of proper type
+            Expr arg = PARSER_STATE->mkBoundVar(name, argTypes[args.size()]);
+            args.push_back( arg );
           }
-          patexprs.push_back( f );
-          patconds.push_back(MK_CONST(bool(true)));
+        )*
+        RPAREN_TOK term[f3, f2] {
+          // make the match case
+          std::vector<Expr> cargs;
+          cargs.push_back(f);
+          cargs.insert(cargs.end(),args.begin(),args.end());
+          Expr c = MK_EXPR(kind::APPLY_CONSTRUCTOR,cargs);
+          Expr bvl = MK_EXPR(kind::BOUND_VAR_LIST,args);
+          Expr mc = MK_EXPR(kind::MATCH_BIND_CASE, bvl, c, f3);
+          matchcases.push_back(mc);
+          // now, pop the scope
+          PARSER_STATE->popScope();
         }
         RPAREN_TOK
-      | LPAREN_TOK LPAREN_TOK term[f, f2] {
-           args.clear();
-           PARSER_STATE->pushScope(true);
-           //f should be a constructor
-           type = f.getType();
-           Debug("parser-dt") << "Pattern head : " << f << " " << f.getType() << std::endl;
-           if( !type.isConstructor() ){
-             PARSER_STATE->parseError("Pattern must be application of a constructor or a variable.");
-           }
-           if( Datatype::datatypeOf(f).isParametric() ){
-             type = Datatype::datatypeOf(f)[Datatype::indexOf(f)].getSpecializedConstructorType(expr.getType());
-           }
-           match_ptypes = ((ConstructorType)type).getArgTypes();
-         }
-         //arguments
-         ( symbol[name,CHECK_NONE,SYM_VARIABLE] {
-             if( args.size()>=match_ptypes.size() ){
-               PARSER_STATE->parseError("Too many arguments for pattern.");
-             }
-             //make of proper type
-             Expr arg = PARSER_STATE->mkBoundVar(name, match_ptypes[args.size()]);
-             args.push_back( arg );
-           }
-         )*
-         RPAREN_TOK
-         term[f3, f2] {
-           const DatatypeConstructor& dtc = Datatype::datatypeOf(f)[Datatype::indexOf(f)];
-           if( args.size()!=dtc.getNumArgs() ){
-             PARSER_STATE->parseError("Bad number of arguments for application of constructor in pattern.");
-           }
-           //FIXME: make MATCH a kind and make this a rewrite
-           // build a lambda
-           std::vector<Expr> largs;
-           largs.push_back( MK_EXPR( CVC4::kind::BOUND_VAR_LIST, args ) );
-           largs.push_back( f3 );
-           std::vector< Expr > aargs;
-           aargs.push_back( MK_EXPR( CVC4::kind::LAMBDA, largs ) );
-           for( unsigned i=0; i<dtc.getNumArgs(); i++ ){
-             //can apply total version since we will be guarded by ITE condition
-             // however, we need to apply partial version since we don't have the internal selector available
-             aargs.push_back( MK_EXPR( CVC4::kind::APPLY_SELECTOR, dtc[i].getSelector(), expr ) );
-           }
-           patexprs.push_back( MK_EXPR( CVC4::kind::APPLY_UF, aargs ) );
-           patconds.push_back( MK_EXPR( CVC4::kind::APPLY_TESTER, dtc.getTester(), expr ) );
-         }
-         RPAREN_TOK
-         { PARSER_STATE->popScope(); }
-       | LPAREN_TOK symbol[name,CHECK_DECLARED,SYM_VARIABLE] {
-           f = PARSER_STATE->getVariable(name);
-           type = f.getType();
-           if( !type.isConstructor() || !((ConstructorType)type).getArgTypes().empty() ){
-             PARSER_STATE->parseError("Must apply constructors of arity greater than 0 to arguments in pattern.");
-           }
-         }
-         term[f3, f2] {
-           const DatatypeConstructor& dtc = Datatype::datatypeOf(f)[Datatype::indexOf(f)];
-           patexprs.push_back( f3 );
-           patconds.push_back( MK_EXPR( CVC4::kind::APPLY_TESTER, dtc.getTester(), expr ) );
-         }
-         RPAREN_TOK
-    )+
-    RPAREN_TOK RPAREN_TOK  {
-      if( match_vindex==-1 ){
-        const Datatype& dt = ((DatatypeType)expr.getType()).getDatatype();
-        std::map< unsigned, bool > processed;
-        unsigned count = 0;
-        //ensure that all datatype constructors are matched (to ensure exhaustiveness)
-        for( unsigned i=0; i<patconds.size(); i++ ){
-          unsigned curr_index = Datatype::indexOf(patconds[i].getOperator());
-          if( curr_index<0 && curr_index>=dt.getNumConstructors() ){
-            PARSER_STATE->parseError("Pattern is not legal for the head of a match.");
+      // case with nullary or variable pattern
+      | LPAREN_TOK symbol[name,CHECK_NONE,SYM_VARIABLE] {
+          if (PARSER_STATE->isDeclared(name,SYM_VARIABLE))
+          {
+            f = PARSER_STATE->getVariable(name);
+            type = f.getType();
+            if (!type.isConstructor() || 
+                !((ConstructorType)type).getArgTypes().empty())
+            {
+              PARSER_STATE->parseError("Must apply constructors of arity greater than 0 to arguments in pattern.");
+            }
+            // make nullary constructor application
+            f = MK_EXPR(kind::APPLY_CONSTRUCTOR, f);
           }
-          if( processed.find( curr_index )==processed.end() ){
-            processed[curr_index] = true;
-            count++;
+          else
+          {
+            // it has the type of the head expr
+            f = PARSER_STATE->mkBoundVar(name, expr.getType());
           }
         }
-        if( count!=dt.getNumConstructors() ){
-          PARSER_STATE->parseError("Patterns are not exhaustive in a match construct.");
-        }
-      }
-      //now, make the ITE
-      int end_index = match_vindex==-1 ? patexprs.size()-1 : match_vindex;
-      bool first_time = true;
-      for( int index = end_index; index>=0; index-- ){
-        if( first_time ){
-          expr = patexprs[index];
-          first_time = false;
-        }else{
-          expr = MK_EXPR( CVC4::kind::ITE, patconds[index], patexprs[index], expr );
+        term[f3, f2] {
+          Expr mc;
+          if (f.getKind() == kind::BOUND_VARIABLE)
+          {
+            Expr bvl = MK_EXPR(kind::BOUND_VAR_LIST, f);
+            mc = MK_EXPR(kind::MATCH_BIND_CASE, bvl, f, f3);
+          }
+          else
+          {
+            mc = MK_EXPR(kind::MATCH_CASE, f, f3);
+          }
+          matchcases.push_back(mc);
         }
+        RPAREN_TOK
+    )+
+    RPAREN_TOK RPAREN_TOK  {
+      //now, make the match
+      if (matchcases.empty())
+      {
+        PARSER_STATE->parseError("Must have at least one case in match.");
       }
+      std::vector<Expr> mchildren;
+      mchildren.push_back(expr);
+      mchildren.insert(mchildren.end(), matchcases.begin(), matchcases.end());
+      expr = MK_EXPR(kind::MATCH, mchildren);
     }
 
     /* attributed expressions */
index df9bee98139a30094b46f6a167b66b7d101911c8..013288880923db93aa97561ec513ae76f8d20f13 100644 (file)
@@ -540,7 +540,31 @@ void Smt2Printer::toStream(std::ostream& out,
     }
     return;
 
-  case kind::LAMBDA:
+  case kind::LAMBDA: out << smtKindString(k, d_variant) << " "; break;
+  case kind::MATCH:
+    out << smtKindString(k, d_variant) << " ";
+    toStream(out, n[0], toDepth, types, TypeNode::null());
+    out << " (";
+    for (size_t i = 1, nchild = n.getNumChildren(); i < nchild; i++)
+    {
+      if (i > 1)
+      {
+        out << " ";
+      }
+      toStream(out, n[i], toDepth, types, TypeNode::null());
+    }
+    out << "))";
+    return;
+  case kind::MATCH_BIND_CASE:
+    // ignore the binder
+    toStream(out, n[1], toDepth, types, TypeNode::null());
+    out << " ";
+    toStream(out, n[2], toDepth, types, TypeNode::null());
+    out << ")";
+    return;
+  case kind::MATCH_CASE:
+    // do nothing
+    break;
   case kind::CHOICE: out << smtKindString(k, d_variant) << " "; break;
 
   // arith theory
@@ -1030,6 +1054,7 @@ static string smtKindString(Kind k, Variant v)
 
   case kind::LAMBDA:
     return "lambda";
+  case kind::MATCH: return "match";
   case kind::CHOICE: return "choice";
 
   // arith theory
index ac3bff21bffa74267d74ec33a207718a45751301..802dedcbdbac0040cd710005c021db6d4d2ee9e7 100644 (file)
@@ -148,6 +148,96 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
       return RewriteResponse(REWRITE_AGAIN_FULL, ret);
     }
   }
+  else if (k == MATCH)
+  {
+    Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl;
+    Node h = in[0];
+    std::vector<Node> cases;
+    std::vector<Node> rets;
+    TypeNode t = h.getType();
+    const Datatype& dt = t.getDatatype();
+    for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++)
+    {
+      Node c = in[k];
+      Node cons;
+      Kind ck = c.getKind();
+      if (ck == MATCH_CASE)
+      {
+        Assert(c[0].getKind() == APPLY_CONSTRUCTOR);
+        cons = c[0].getOperator();
+      }
+      else if (ck == MATCH_BIND_CASE)
+      {
+        if (c[1].getKind() == APPLY_CONSTRUCTOR)
+        {
+          cons = c[1].getOperator();
+        }
+      }
+      else
+      {
+        AlwaysAssert(false);
+      }
+      size_t cindex = 0;
+      // cons is null in the default case
+      if (!cons.isNull())
+      {
+        cindex = Datatype::indexOf(cons.toExpr());
+      }
+      Node body;
+      if (ck == MATCH_CASE)
+      {
+        body = c[1];
+      }
+      else if (ck == MATCH_BIND_CASE)
+      {
+        std::vector<Node> vars;
+        std::vector<Node> subs;
+        if (cons.isNull())
+        {
+          Assert(c[1].getKind() == BOUND_VARIABLE);
+          vars.push_back(c[1]);
+          subs.push_back(h);
+        }
+        else
+        {
+          for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++)
+          {
+            vars.push_back(c[0][i]);
+            Node sc = nm->mkNode(
+                APPLY_SELECTOR_TOTAL,
+                Node::fromExpr(dt[cindex].getSelectorInternal(t.toType(), i)),
+                h);
+            subs.push_back(sc);
+          }
+        }
+        body =
+            c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+      }
+      if (!cons.isNull())
+      {
+        cases.push_back(mkTester(h, cindex, dt));
+      }
+      else
+      {
+        // variables have no constraints
+        cases.push_back(nm->mkConst(true));
+      }
+      rets.push_back(body);
+    }
+    Assert(!cases.empty());
+    // now make the ITE
+    std::reverse(cases.begin(), cases.end());
+    std::reverse(rets.begin(), rets.end());
+    Node ret = rets[0];
+    AlwaysAssert(cases[0].isConst() || cases.size() == dt.getNumConstructors());
+    for (unsigned i = 1, ncases = cases.size(); i < ncases; i++)
+    {
+      ret = nm->mkNode(ITE, cases[i], rets[i], ret);
+    }
+    Trace("dt-rewrite-match")
+        << "Rewrite match: " << in << " ... " << ret << std::endl;
+    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+  }
 
   if (k == kind::EQUAL)
   {
index a0b00bcb0a31e2af3475dd60d8241afba1ce2481..22d13da0c72231f6a6cd7f0e1ca3f4b90413a074 100644 (file)
@@ -119,4 +119,23 @@ typerule DT_SYGUS_BOUND ::CVC4::theory::datatypes::DtSygusBoundTypeRule
 operator DT_SYGUS_EVAL 1: "datatypes sygus evaluation function"
 typerule DT_SYGUS_EVAL ::CVC4::theory::datatypes::DtSyguEvalTypeRule
 
+
+# Kinds for match terms. For example, the match term
+#   (match l (((cons h t) h) (nil 0))) 
+# is represented by the AST
+#   (MATCH l 
+#      (MATCH_BIND_CASE (BOUND_VAR_LIST h t) (cons h t) h)
+#      (MATCH_CASE nil 0)
+#   )
+# where notice that patterns with free variables use MATCH_BIND_CASE whereas
+# patterns with no free variables use MATCH_CASE.
+
+operator MATCH 2: "match construct"
+operator MATCH_CASE 2 "a match case"
+operator MATCH_BIND_CASE 3 "a match case with bound variables"
+
+typerule MATCH ::CVC4::theory::datatypes::MatchTypeRule
+typerule MATCH_CASE ::CVC4::theory::datatypes::MatchCaseTypeRule
+typerule MATCH_BIND_CASE ::CVC4::theory::datatypes::MatchBindCaseTypeRule
+
 endtheory
index 22ac074f0d0ed9630abf5765239de89eebdb69ca..c8c16f368c46b639833134ce31aa552d4d410633 100644 (file)
@@ -427,7 +427,165 @@ class DtSyguEvalTypeRule
     }
     return TypeNode::fromType(dt.getSygusType());
   }
-}; /* class DtSygusBoundTypeRule */
+}; /* class DtSyguEvalTypeRule */
+
+class MatchTypeRule
+{
+ public:
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
+  {
+    Assert(n.getKind() == kind::MATCH);
+
+    TypeNode retType;
+
+    TypeNode headType = n[0].getType(check);
+    if (!headType.isDatatype())
+    {
+      throw TypeCheckingExceptionPrivate(n, "expecting datatype head in match");
+    }
+    const Datatype& hdt = headType.getDatatype();
+
+    std::unordered_set<unsigned> patIndices;
+    bool patHasVariable = false;
+    // the type of a match case list is the least common type of its cases
+    for (unsigned i = 1, nchildren = n.getNumChildren(); i < nchildren; i++)
+    {
+      Node nc = n[i];
+      if (check)
+      {
+        Kind nck = nc.getKind();
+        std::unordered_set<Node, NodeHashFunction> bvs;
+        if (nck == kind::MATCH_BIND_CASE)
+        {
+          for (const Node& v : nc[0])
+          {
+            Assert(v.getKind() == kind::BOUND_VARIABLE);
+            bvs.insert(v);
+          }
+        }
+        else if (nck != kind::MATCH_CASE)
+        {
+          throw TypeCheckingExceptionPrivate(
+              n, "expected a match case in match expression");
+        }
+        // get the pattern type
+        unsigned pindex = nck == kind::MATCH_CASE ? 0 : 1;
+        TypeNode patType = nc[pindex].getType();
+        // should be caught in the above call
+        if (!patType.isDatatype())
+        {
+          throw TypeCheckingExceptionPrivate(
+              n, "expecting datatype pattern in match");
+        }
+        Kind ncpk = nc[pindex].getKind();
+        if (ncpk == kind::APPLY_CONSTRUCTOR)
+        {
+          for (const Node& arg : nc[pindex])
+          {
+            if (bvs.find(arg) == bvs.end())
+            {
+              throw TypeCheckingExceptionPrivate(
+                  n,
+                  "expecting distinct bound variable as argument to "
+                  "constructor in pattern of match");
+            }
+            bvs.erase(arg);
+          }
+          unsigned ci = Datatype::indexOf(nc[pindex].getOperator().toExpr());
+          patIndices.insert(ci);
+        }
+        else if (ncpk == kind::BOUND_VARIABLE)
+        {
+          patHasVariable = true;
+        }
+        else
+        {
+          throw TypeCheckingExceptionPrivate(
+              n, "unexpected kind of term in pattern in match");
+        }
+        const Datatype& pdt = patType.getDatatype();
+        // compare datatypes instead of the types to catch parametric case,
+        // where the pattern has parametric type.
+        if (hdt != pdt)
+        {
+          std::stringstream ss;
+          ss << "pattern of a match case does not match the head type in match";
+          throw TypeCheckingExceptionPrivate(n, ss.str());
+        }
+      }
+      TypeNode currType = nc.getType(check);
+      if (i == 1)
+      {
+        retType = currType;
+      }
+      else
+      {
+        retType = TypeNode::leastCommonTypeNode(retType, currType);
+        if (retType.isNull())
+        {
+          throw TypeCheckingExceptionPrivate(
+              n, "incomparable types in match case list");
+        }
+      }
+    }
+    if (check)
+    {
+      if (!patHasVariable && patIndices.size() < hdt.getNumConstructors())
+      {
+        throw TypeCheckingExceptionPrivate(
+            n, "cases for match term are not exhaustive");
+      }
+    }
+    return retType;
+  }
+}; /* class MatchTypeRule */
+
+class MatchCaseTypeRule
+{
+ public:
+  inline static TypeNode computeType(NodeManager* nodeManager,
+                                     TNode n,
+                                     bool check)
+  {
+    Assert(n.getKind() == kind::MATCH_CASE);
+    if (check)
+    {
+      TypeNode patType = n[0].getType(check);
+      if (!patType.isDatatype())
+      {
+        throw TypeCheckingExceptionPrivate(
+            n, "expecting datatype pattern in match case");
+      }
+    }
+    return n[1].getType(check);
+  }
+}; /* class MatchCaseTypeRule */
+
+class MatchBindCaseTypeRule
+{
+ public:
+  inline static TypeNode computeType(NodeManager* nodeManager,
+                                     TNode n,
+                                     bool check)
+  {
+    Assert(n.getKind() == kind::MATCH_BIND_CASE);
+    if (check)
+    {
+      if (n[0].getKind() != kind::BOUND_VAR_LIST)
+      {
+        throw TypeCheckingExceptionPrivate(
+            n, "expected a bound variable list in match bind case");
+      }
+      TypeNode patType = n[1].getType(check);
+      if (!patType.isDatatype())
+      {
+        throw TypeCheckingExceptionPrivate(
+            n, "expecting datatype pattern in match bind case");
+      }
+    }
+    return n[2].getType(check);
+  }
+}; /* class MatchBindCaseTypeRule */
 
 } /* CVC4::theory::datatypes namespace */
 } /* CVC4::theory namespace */
index f6148994eab1ffbca1b8ed502d53cc7db332277d..f5d8129aea67d1d76195ac1d7bb19af8e11ba4e9 100644 (file)
@@ -11,7 +11,7 @@
 
 (assert (or (distinct a b c d)
  (< (match a ((red 5) (green 3) (blue 2))) 0)
- (< (match b ((red 2) (_ 1))) 0)
+ (< (match b ((red 2) (x 1))) 0)
  ))
 
 (check-sat)