Adding capture avoiding substitution (#2867)
authorHaniel Barbosa <hanielbbarbosa@gmail.com>
Fri, 15 Mar 2019 22:01:42 +0000 (17:01 -0500)
committerGitHub <noreply@github.com>
Fri, 15 Mar 2019 22:01:42 +0000 (17:01 -0500)
src/expr/node.h
src/expr/node_algorithm.cpp
src/expr/node_algorithm.h

index 50add7b1779a30c89676a33935c37cd6ed11db8d..003863c8e6ec0cd8b941799d2f5974c1a01cbee6 100644 (file)
@@ -459,7 +459,7 @@ public:
     assertTNodeNotExpired();
     return getMetaKind() == kind::metakind::VARIABLE;
   }
-  
+
   /**
    * Returns true if this node represents a nullary operator
    */
@@ -467,12 +467,11 @@ public:
     assertTNodeNotExpired();
     return getMetaKind() == kind::metakind::NULLARY_OPERATOR;
   }
-  
+
   inline bool isClosure() const {
     assertTNodeNotExpired();
     return getKind() == kind::LAMBDA || getKind() == kind::FORALL
-           || getKind() == kind::EXISTS || getKind() == kind::REWRITE_RULE
-           || getKind() == kind::CHOICE;
+           || getKind() == kind::EXISTS || getKind() == kind::CHOICE;
   }
 
   /**
index 6923efec279ac8174d2336e822a3443ed2384aa0..dcf78fb37a68f550c0bb5d91ca3b043a9206d11d 100644 (file)
@@ -273,5 +273,117 @@ void getSymbols(TNode n,
   } while (!visit.empty());
 }
 
+Node substituteCaptureAvoiding(TNode n, Node src, Node dest)
+{
+  if (n == src)
+  {
+    return dest;
+  }
+  if (src == dest)
+  {
+    return n;
+  }
+  std::vector<Node> srcs;
+  std::vector<Node> dests;
+  srcs.push_back(src);
+  dests.push_back(dest);
+  return substituteCaptureAvoiding(n, srcs, dests);
+}
+
+Node substituteCaptureAvoiding(TNode n,
+                               std::vector<Node>& src,
+                               std::vector<Node>& dest)
+{
+  std::unordered_map<TNode, Node, TNodeHashFunction> visited;
+  std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
+  std::vector<TNode> visit;
+  TNode curr;
+  visit.push_back(n);
+  Assert(src.size() == dest.size(),
+         "Substitution domain and range must be equal size");
+  do
+  {
+    curr = visit.back();
+    visit.pop_back();
+    it = visited.find(curr);
+
+    if (it == visited.end())
+    {
+      auto itt = std::find(src.rbegin(), src.rend(), curr);
+      if (itt != src.rend())
+      {
+        Assert(
+            (std::distance(src.begin(), itt.base()) - 1) >= 0
+            && static_cast<unsigned>(std::distance(src.begin(), itt.base()) - 1)
+                   < dest.size());
+        Node n = dest[std::distance(src.begin(), itt.base()) - 1];
+        visited[curr] = n;
+        continue;
+      }
+      if (curr.getNumChildren() == 0)
+      {
+        visited[curr] = curr;
+        continue;
+      }
+
+      visited[curr] = Node::null();
+      // if binder, rename variables to avoid capture
+      if (curr.isClosure())
+      {
+        NodeManager* nm = NodeManager::currentNM();
+        // have new vars -> renames subs in the end of current sub
+        for (const Node& v : curr[0])
+        {
+          src.push_back(v);
+          dest.push_back(nm->mkBoundVar(v.getType()));
+        }
+      }
+      // save for post-visit
+      visit.push_back(curr);
+      // visit children
+      if (curr.getMetaKind() == kind::metakind::PARAMETERIZED)
+      {
+        // push the operator
+        visit.push_back(curr.getOperator());
+      }
+      for (unsigned i = 0, size = curr.getNumChildren(); i < size; ++i)
+      {
+        visit.push_back(curr[i]);
+      }
+    }
+    else if (it->second.isNull())
+    {
+      // build node
+      NodeBuilder<> nb(curr.getKind());
+      if (curr.getMetaKind() == kind::metakind::PARAMETERIZED)
+      {
+        // push the operator
+        Assert(visited.find(curr.getOperator()) != visited.end());
+        nb << visited[curr.getOperator()];
+      }
+      // collect substituted children
+      for (unsigned i = 0, size = curr.getNumChildren(); i < size; ++i)
+      {
+        Assert(visited.find(curr[i]) != visited.end());
+        nb << visited[curr[i]];
+      }
+      Node n = nb;
+      visited[curr] = n;
+
+      // remove renaming
+      if (curr.isClosure())
+      {
+        // remove beginning of sub which correspond to renaming of variables in
+        // this binder
+        unsigned nchildren = curr[0].getNumChildren();
+        src.resize(src.size() - nchildren);
+        dest.resize(dest.size() - nchildren);
+      }
+    }
+  } while (!visit.empty());
+  Assert(visited.find(n) != visited.end());
+  return visited[n];
+}
+
 }  // namespace expr
 }  // namespace CVC4
index bf2cb5877c689b50ab1d13b098b5e07c90572c5f..7cc12b6645dddfad77374719a13b0a478b922d0c 100644 (file)
@@ -83,6 +83,19 @@ void getSymbols(TNode n, std::unordered_set<Node, NodeHashFunction>& syms);
 void getSymbols(TNode n,
                 std::unordered_set<Node, NodeHashFunction>& syms,
                 std::unordered_set<TNode, TNodeHashFunction>& visited);
+/**
+ * Substitution of Nodes in a capture avoiding way.
+ */
+Node substituteCaptureAvoiding(TNode n, Node src, Node dest);
+
+/**
+ * Simultaneous substitution of Nodes in a capture avoiding way.  Elements in
+ * source will be replaced by their corresponding element in dest.  Both
+ * vectors should have the same size.
+ */
+Node substituteCaptureAvoiding(TNode n,
+                               std::vector<Node>& src,
+                               std::vector<Node>& dest);
 
 }  // namespace expr
 }  // namespace CVC4