Add utilities for flattening nodes (#7961)
authorGereon Kremer <gkremer@stanford.edu>
Mon, 31 Jan 2022 21:49:21 +0000 (13:49 -0800)
committerGitHub <noreply@github.com>
Mon, 31 Jan 2022 21:49:21 +0000 (21:49 +0000)
This PR adds new utilities for flattening nodes and checking whether they can be flattened. They replace the custom implementation we used in the arithmetic rewriter.

src/expr/CMakeLists.txt
src/expr/algorithm/flatten.h [new file with mode: 0644]
src/theory/arith/arith_rewriter.cpp
test/unit/node/CMakeLists.txt
test/unit/node/node_algorithms_black.cpp [new file with mode: 0644]

index ab66aa236b9cf9df5bf9b69eaad38b01c949b1dd..67e67ca44f13ace45ecd6f565631f5a2c2f7676d 100644 (file)
@@ -14,6 +14,7 @@
 ##
 
 libcvc5_add_sources(
+  algorithm/flatten.h
   array_store_all.cpp
   array_store_all.h
   ascription_type.cpp
diff --git a/src/expr/algorithm/flatten.h b/src/expr/algorithm/flatten.h
new file mode 100644 (file)
index 0000000..f481946
--- /dev/null
@@ -0,0 +1,121 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Gereon Kremer
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Utilities for flattening nodes.
+ */
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__EXPR__ALGORITHMS__FLATTEN_H
+#define CVC5__EXPR__ALGORITHMS__FLATTEN_H
+
+#include <algorithm>
+
+#include "expr/node.h"
+
+namespace cvc5::expr::algorithm {
+
+/**
+ * Flatten a node into a vector of its (direct or indirect) children.
+ * Optionally, a sequence of kinds that should be flattened can be passed. If no
+ * kinds are given, flattening is done based on the kind of t.
+ * Note that flatten(t, c) is equivalent to flatten(t, c, t.getKind()).
+ * @param t The node to be flattened
+ * @param children The resulting list of children
+ * @param kinds Optional sequence of kinds to consider for flattening
+ */
+template <typename... Kinds>
+void flatten(TNode t, std::vector<TNode>& children, Kinds... kinds)
+{
+  std::vector<TNode> queue = {t};
+  while (!queue.empty())
+  {
+    TNode cur = queue.back();
+    queue.pop_back();
+    bool recurse = false;
+    // figure out whether to recurse into cur
+    if constexpr (sizeof...(kinds) == 0)
+    {
+      recurse = t.getKind() == cur.getKind();
+    }
+    else
+    {
+      recurse = ((kinds == cur.getKind()) || ...);
+    }
+    if (recurse)
+    {
+      queue.insert(queue.end(), cur.rbegin(), cur.rend());
+    }
+    else
+    {
+      children.emplace_back(cur);
+    }
+  }
+}
+
+/**
+ * Check whether a node can be flattened, that is whether calling flatten()
+ * returns something other than its direct children. If no kinds are passed
+ * explicitly, this simply checks whether any of the children has the same kind
+ * as t. If a sequence of kinds is passed, this checks whether any of the
+ * children has one of these kinds.
+ * Note that canFlatten(t) is equivalent to canFlatten(t, t.getKind()).
+ * @param t The node that should be checked
+ * @param kinds Optional sequence of kinds
+ * @return true iff t could be flattened
+ */
+template <typename... Kinds>
+bool canFlatten(TNode t, Kinds... kinds)
+{
+  if constexpr (sizeof...(kinds) == 0)
+  {
+    return std::any_of(t.begin(), t.end(), [k = t.getKind()](TNode child) {
+      return child.getKind() == k;
+    });
+  }
+  else
+  {
+    if (!((t.getKind() == kinds) || ...))
+    {
+      return false;
+    }
+    return std::any_of(t.begin(), t.end(), [=](TNode child) {
+      return ((child.getKind() == kinds) || ...);
+    });
+  }
+}
+
+/**
+ * If t can be flattened, return a new node of the same kind as t with the
+ * flattened children. Otherwise, return t.
+ * If a sequence of kinds is given, the flattening (and the respective check)
+ * are done with respect to these kinds: see the documentation of flatten()
+ * and canFlatten() for more details.
+ * @param t The node to be flattened
+ * @param kinds Optional sequence of kinds
+ * @return A flattened version of t
+ */
+template <typename... Kinds>
+Node flatten(TNode t, Kinds... kinds)
+{
+  if (!canFlatten(t, kinds...))
+  {
+    return t;
+  }
+  std::vector<TNode> children;
+  flatten(t, children, kinds...);
+  return NodeManager::currentNM()->mkNode(t.getKind(), children);
+}
+
+}  // namespace cvc5::expr
+
+#endif
index 9f1fa40bb8280ee841c9b35cd7d659a52b25d565..e157484afbb9f204d9ddaed82b964607c859401c 100644 (file)
@@ -24,6 +24,7 @@
 #include <stack>
 #include <vector>
 
+#include "expr/algorithm/flatten.h"
 #include "smt/logic_exception.h"
 #include "theory/arith/arith_msum.h"
 #include "theory/arith/arith_utilities.h"
@@ -455,52 +456,24 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
   }
 }
 
-static bool canFlatten(Kind k, TNode t){
-  for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
-    TNode child = *i;
-    if(child.getKind() == k){
-      return true;
-    }
-  }
-  return false;
-}
-
-static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
-  if(t.getKind() == k){
-    for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
-      TNode child = *i;
-      if(child.getKind() == k){
-        flatten(pb, k, child);
-      }else{
-        pb.push_back(child);
-      }
-    }
-  }else{
-    pb.push_back(t);
-  }
-}
-
-static Node flatten(Kind k, TNode t){
-  std::vector<TNode> pb;
-  flatten(pb, k, t);
-  Assert(pb.size() >= 2);
-  return NodeManager::currentNM()->mkNode(k, pb);
-}
 
 RewriteResponse ArithRewriter::preRewritePlus(TNode t){
   Assert(t.getKind() == kind::PLUS);
-
-  if(canFlatten(kind::PLUS, t)){
-    return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
-  }else{
-    return RewriteResponse(REWRITE_DONE, t);
-  }
+  return RewriteResponse(REWRITE_DONE, expr::algorithm::flatten(t));
 }
 
 RewriteResponse ArithRewriter::postRewritePlus(TNode t){
   Assert(t.getKind() == kind::PLUS);
   Assert(t.getNumChildren() > 1);
 
+  {
+    Node flat = expr::algorithm::flatten(t);
+    if (flat != t)
+    {
+      return RewriteResponse(REWRITE_AGAIN, flat);
+    }
+  }
+
   Rational rational;
   RealAlgebraicNumber ran;
   std::vector<Monomial> monomials;
index c42e0452a2d38a71910be9ef1ce31c5972df7898..989e408d2d656de467c46db737802a33f16211a9 100644 (file)
@@ -20,6 +20,7 @@ cvc5_add_unit_test_black(kind_black expr)
 cvc5_add_unit_test_black(kind_map_black expr)
 cvc5_add_unit_test_black(node_black expr)
 cvc5_add_unit_test_black(node_algorithm_black expr)
+cvc5_add_unit_test_black(node_algorithms_black expr)
 cvc5_add_unit_test_black(node_builder_black expr)
 cvc5_add_unit_test_black(node_manager_black expr)
 cvc5_add_unit_test_white(node_manager_white expr)
diff --git a/test/unit/node/node_algorithms_black.cpp b/test/unit/node/node_algorithms_black.cpp
new file mode 100644 (file)
index 0000000..ebcb9b6
--- /dev/null
@@ -0,0 +1,163 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Gereon Kremer
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Black box testing of expr/algorithms/
+ */
+
+#include "base/output.h"
+#include "expr/algorithm/flatten.h"
+#include "expr/node_manager.h"
+#include "test_node.h"
+
+namespace cvc5 {
+
+using namespace expr;
+using namespace kind;
+
+namespace test {
+
+class TestNodeBlackNodeAlgorithms : public TestNode
+{
+};
+
+TEST_F(TestNodeBlackNodeAlgorithms, flatten)
+{
+  {
+    Node x = d_nodeManager->mkBoundVar(*d_realTypeNode);
+    Node n = d_nodeManager->mkNode(Kind::PLUS, x, x);
+    EXPECT_FALSE(expr::algorithm::canFlatten(n));
+    EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::PLUS));
+    EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::MULT));
+    EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::PLUS, Kind::MULT));
+    EXPECT_EQ(expr::algorithm::flatten(n), n);
+    EXPECT_EQ(expr::algorithm::flatten(n, Kind::PLUS), n);
+    EXPECT_EQ(expr::algorithm::flatten(n, Kind::MULT), n);
+    EXPECT_EQ(expr::algorithm::flatten(n, Kind::PLUS, Kind::MULT), n);
+
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children);
+      EXPECT_EQ(children.size(), 2);
+      EXPECT_EQ(children[0], x);
+      EXPECT_EQ(children[1], x);
+    }
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children, Kind::PLUS);
+      EXPECT_EQ(children.size(), 2);
+      EXPECT_EQ(children[0], x);
+      EXPECT_EQ(children[1], x);
+    }
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children, Kind::MULT);
+      EXPECT_EQ(children.size(), 1);
+      EXPECT_EQ(children[0], n);
+    }
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children, Kind::PLUS, Kind::MULT);
+      EXPECT_EQ(children.size(), 2);
+      EXPECT_EQ(children[0], x);
+      EXPECT_EQ(children[1], x);
+    }
+  }
+  {
+    Node x = d_nodeManager->mkBoundVar(*d_realTypeNode);
+    Node n = d_nodeManager->mkNode(
+        Kind::PLUS, x, d_nodeManager->mkNode(Kind::PLUS, x, x));
+    EXPECT_TRUE(expr::algorithm::canFlatten(n));
+    EXPECT_TRUE(expr::algorithm::canFlatten(n, Kind::PLUS));
+    EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::MULT));
+    EXPECT_TRUE(expr::algorithm::canFlatten(n, Kind::PLUS, Kind::MULT));
+    EXPECT_NE(expr::algorithm::flatten(n), n);
+    EXPECT_NE(expr::algorithm::flatten(n, Kind::PLUS), n);
+    EXPECT_EQ(expr::algorithm::flatten(n, Kind::MULT), n);
+    EXPECT_NE(expr::algorithm::flatten(n, Kind::PLUS, Kind::MULT), n);
+
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children);
+      EXPECT_EQ(children.size(), 3);
+      EXPECT_EQ(children[0], x);
+      EXPECT_EQ(children[1], x);
+      EXPECT_EQ(children[2], x);
+    }
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children, Kind::PLUS);
+      EXPECT_EQ(children.size(), 3);
+      EXPECT_EQ(children[0], x);
+      EXPECT_EQ(children[1], x);
+      EXPECT_EQ(children[2], x);
+    }
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children, Kind::MULT);
+      EXPECT_EQ(children.size(), 1);
+      EXPECT_EQ(children[0], n);
+    }
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children, Kind::PLUS, Kind::MULT);
+      EXPECT_EQ(children.size(), 3);
+      EXPECT_EQ(children[0], x);
+      EXPECT_EQ(children[1], x);
+      EXPECT_EQ(children[2], x);
+    }
+  }
+  {
+    Node x = d_nodeManager->mkBoundVar(*d_realTypeNode);
+    Node n = d_nodeManager->mkNode(
+        Kind::PLUS, x, d_nodeManager->mkNode(Kind::MULT, x, x));
+    EXPECT_FALSE(expr::algorithm::canFlatten(n));
+    EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::PLUS));
+    EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::MULT));
+    EXPECT_TRUE(expr::algorithm::canFlatten(n, Kind::PLUS, Kind::MULT));
+    EXPECT_EQ(expr::algorithm::flatten(n), n);
+    EXPECT_EQ(expr::algorithm::flatten(n, Kind::PLUS), n);
+    EXPECT_EQ(expr::algorithm::flatten(n, Kind::MULT), n);
+    EXPECT_NE(expr::algorithm::flatten(n, Kind::PLUS, Kind::MULT), n);
+
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children);
+      EXPECT_EQ(children.size(), 2);
+      EXPECT_EQ(children[0], x);
+      EXPECT_EQ(children[1], n[1]);
+    }
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children, Kind::PLUS);
+      EXPECT_EQ(children.size(), 2);
+      EXPECT_EQ(children[0], x);
+      EXPECT_EQ(children[1], n[1]);
+    }
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children, Kind::MULT);
+      EXPECT_EQ(children.size(), 1);
+      EXPECT_EQ(children[0], n);
+    }
+    {
+      std::vector<TNode> children;
+      expr::algorithm::flatten(n, children, Kind::PLUS, Kind::MULT);
+      EXPECT_EQ(children.size(), 3);
+      EXPECT_EQ(children[0], x);
+      EXPECT_EQ(children[1], x);
+      EXPECT_EQ(children[2], x);
+    }
+  }
+}
+
+}  // namespace test
+}  // namespace cvc5