Eliminate match from LFSC proofs (#8090)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 23 Feb 2022 20:08:49 +0000 (14:08 -0600)
committerGitHub <noreply@github.com>
Wed, 23 Feb 2022 20:08:49 +0000 (20:08 +0000)
The smt 2.6 term match is very hard to represent in LFSC, this eliminates it in favor of the ITE term that it is syntax sugar for. Like other aspects of the LFSC conversion, this is part of the trusted core.

This avoids internal type errors in the LFSC node converter.

src/proof/lfsc/lfsc_node_converter.cpp
src/proof/lfsc/lfsc_node_converter.h
src/theory/datatypes/datatypes_rewriter.cpp
src/theory/datatypes/datatypes_rewriter.h

index 3dc0aa3aa7c3741eb2433de72bf7cebe0a930967..865e471ef8e0b365776c4c15e7223d8639afb815 100644 (file)
@@ -27,6 +27,7 @@
 #include "expr/skolem_manager.h"
 #include "printer/smt2/smt2_printer.h"
 #include "theory/bv/theory_bv_utils.h"
+#include "theory/datatypes/datatypes_rewriter.h"
 #include "theory/strings/word.h"
 #include "theory/uf/theory_uf_rewriter.h"
 #include "util/bitvector.h"
@@ -69,6 +70,18 @@ LfscNodeConverter::LfscNodeConverter()
       getSymbolInternal(FUNCTION_TYPE, setType, "Seq");
 }
 
+Node LfscNodeConverter::preConvert(Node n)
+{
+  // match is not supported in LFSC syntax, we eliminate it at pre-order
+  // traversal, which avoids type-checking errors during conversion, since e.g.
+  // match case nodes are required but cannot be preserved
+  if (n.getKind() == MATCH)
+  {
+    return theory::datatypes::DatatypesRewriter::expandMatch(n);
+  }
+  return n;
+}
+
 Node LfscNodeConverter::postConvert(Node n)
 {
   NodeManager* nm = NodeManager::currentNM();
index bbfbaba8e0496f1e5ef7803145ea8c4fb21a717d..bd65af5032361b5b3b5e9c8b707e1ccef8b46bde 100644 (file)
@@ -36,9 +36,11 @@ class LfscNodeConverter : public NodeConverter
  public:
   LfscNodeConverter();
   ~LfscNodeConverter() {}
-  /** convert to internal */
+  /** convert at pre-order traversal */
+  Node preConvert(Node n) override;
+  /** convert at post-order traversal */
   Node postConvert(Node n) override;
-  /** convert to internal */
+  /** convert type at post-order traversal */
   TypeNode postConvertType(TypeNode tn) override;
   /**
    * Get the null terminator for kind k and type tn. The type tn can be
index 8879c43b820ad6a766c7a219e380f840db84db5e..12acc1402005b2b7372552be1bd5de2042bb3013 100644 (file)
@@ -153,90 +153,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
   else if (kind == MATCH)
   {
     Trace("dt-rewrite-match") << "Rewrite match: " << in << std::endl;
-    // ensure we've type checked
-    TypeNode tin = in.getType();
-    Node h = in[0];
-    std::vector<Node> cases;
-    std::vector<Node> rets;
-    TypeNode t = h.getType();
-    const DType& dt = t.getDType();
-    for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++)
-    {
-      Node c = in[k];
-      Node cons;
-      Kind ck = c.getKind();
-      if (ck == MATCH_CASE)
-      {
-        Assert(c[0].getKind() == APPLY_CONSTRUCTOR);
-        cons = c[0].getOperator();
-      }
-      else if (ck == MATCH_BIND_CASE)
-      {
-        if (c[1].getKind() == APPLY_CONSTRUCTOR)
-        {
-          cons = c[1].getOperator();
-        }
-      }
-      else
-      {
-        AlwaysAssert(false);
-      }
-      size_t cindex = 0;
-      // cons is null in the default case
-      if (!cons.isNull())
-      {
-        cindex = utils::indexOf(cons);
-      }
-      Node body;
-      if (ck == MATCH_CASE)
-      {
-        body = c[1];
-      }
-      else if (ck == MATCH_BIND_CASE)
-      {
-        std::vector<Node> vars;
-        std::vector<Node> subs;
-        if (cons.isNull())
-        {
-          Assert(c[1].getKind() == BOUND_VARIABLE);
-          vars.push_back(c[1]);
-          subs.push_back(h);
-        }
-        else
-        {
-          for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++)
-          {
-            vars.push_back(c[0][i]);
-            Node sc =
-                nm->mkNode(APPLY_SELECTOR, dt[cindex][i].getSelector(), h);
-            subs.push_back(sc);
-          }
-        }
-        body =
-            c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
-      }
-      if (!cons.isNull())
-      {
-        cases.push_back(utils::mkTester(h, cindex, dt));
-      }
-      else
-      {
-        // variables have no constraints
-        cases.push_back(nm->mkConst(true));
-      }
-      rets.push_back(body);
-    }
-    Assert(!cases.empty());
-    // now make the ITE
-    std::reverse(cases.begin(), cases.end());
-    std::reverse(rets.begin(), rets.end());
-    Node ret = rets[0];
-    // notice that due to our type checker, either there is a variable pattern
-    // or all constructors are present in the match.
-    for (size_t i = 1, ncases = cases.size(); i < ncases; i++)
-    {
-      ret = nm->mkNode(ITE, cases[i], rets[i], ret);
-    }
+    Node ret = expandMatch(in);
     Trace("dt-rewrite-match")
         << "Rewrite match: " << in << " ... " << ret << std::endl;
     return RewriteResponse(REWRITE_AGAIN_FULL, ret);
@@ -308,6 +225,94 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
 
   return RewriteResponse(REWRITE_DONE, in);
 }
+Node DatatypesRewriter::expandMatch(Node in)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  // ensure we've type checked
+  TypeNode tin = in.getType();
+  Node h = in[0];
+  std::vector<Node> cases;
+  std::vector<Node> rets;
+  TypeNode t = h.getType();
+  const DType& dt = t.getDType();
+  for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++)
+  {
+    Node c = in[k];
+    Node cons;
+    Kind ck = c.getKind();
+    if (ck == MATCH_CASE)
+    {
+      Assert(c[0].getKind() == APPLY_CONSTRUCTOR);
+      cons = c[0].getOperator();
+    }
+    else if (ck == MATCH_BIND_CASE)
+    {
+      if (c[1].getKind() == APPLY_CONSTRUCTOR)
+      {
+        cons = c[1].getOperator();
+      }
+    }
+    else
+    {
+      AlwaysAssert(false) << "Bad case for match term";
+    }
+    size_t cindex = 0;
+    // cons is null in the default case
+    if (!cons.isNull())
+    {
+      cindex = utils::indexOf(cons);
+    }
+    Node body;
+    if (ck == MATCH_CASE)
+    {
+      body = c[1];
+    }
+    else if (ck == MATCH_BIND_CASE)
+    {
+      std::vector<Node> vars;
+      std::vector<Node> subs;
+      if (cons.isNull())
+      {
+        Assert(c[1].getKind() == BOUND_VARIABLE);
+        vars.push_back(c[1]);
+        subs.push_back(h);
+      }
+      else
+      {
+        for (size_t i = 0, vsize = c[0].getNumChildren(); i < vsize; i++)
+        {
+          vars.push_back(c[0][i]);
+          Node sc = nm->mkNode(APPLY_SELECTOR, dt[cindex][i].getSelector(), h);
+          subs.push_back(sc);
+        }
+      }
+      body =
+          c[2].substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+    }
+    if (!cons.isNull())
+    {
+      cases.push_back(utils::mkTester(h, cindex, dt));
+    }
+    else
+    {
+      // variables have no constraints
+      cases.push_back(nm->mkConst(true));
+    }
+    rets.push_back(body);
+  }
+  Assert(!cases.empty());
+  // now make the ITE
+  std::reverse(cases.begin(), cases.end());
+  std::reverse(rets.begin(), rets.end());
+  Node ret = rets[0];
+  // notice that due to our type checker, either there is a variable pattern
+  // or all constructors are present in the match.
+  for (size_t i = 1, ncases = cases.size(); i < ncases; i++)
+  {
+    ret = nm->mkNode(ITE, cases[i], rets[i], ret);
+  }
+  return ret;
+}
 
 RewriteResponse DatatypesRewriter::preRewrite(TNode in)
 {
index 31e2a1befa70f51cfe8161a367f4329b56eb53da..86ff493ea014bf476f0ff1567f42575cf2a580cb 100644 (file)
@@ -71,6 +71,14 @@ class DatatypesRewriter : public TheoryRewriter
    * internal selector function for selC (possibly a shared selector).
    */
   static Node expandApplySelector(Node n);
+  /**
+   * Expand a match term into its definition.
+   * For example
+   *   (MATCH x (((APPLY_CONSTRUCTOR CONS y z) z) (APPLY_CONSTRUCTOR NIL x)))
+   * returns
+   *   (ITE (APPLY_TESTER CONS x) (APPLY_SELECTOR x) x)
+   */
+  static Node expandMatch(Node n);
   /** expand defintions */
   TrustNode expandDefinition(Node n) override;