From: Gereon Kremer Date: Mon, 31 Jan 2022 21:49:21 +0000 (-0800) Subject: Add utilities for flattening nodes (#7961) X-Git-Tag: cvc5-1.0.0~490 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=ac2a185f2895095f2c8e1ddb52aa33c40ab07e2e;p=cvc5.git Add utilities for flattening nodes (#7961) This PR adds new utilities for flattening nodes and checking whether they can be flattened. They replace the custom implementation we used in the arithmetic rewriter. --- diff --git a/src/expr/CMakeLists.txt b/src/expr/CMakeLists.txt index ab66aa236..67e67ca44 100644 --- a/src/expr/CMakeLists.txt +++ b/src/expr/CMakeLists.txt @@ -14,6 +14,7 @@ ## libcvc5_add_sources( + algorithm/flatten.h array_store_all.cpp array_store_all.h ascription_type.cpp diff --git a/src/expr/algorithm/flatten.h b/src/expr/algorithm/flatten.h new file mode 100644 index 000000000..f48194695 --- /dev/null +++ b/src/expr/algorithm/flatten.h @@ -0,0 +1,121 @@ +/****************************************************************************** + * Top contributors (to current version): + * Gereon Kremer + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Utilities for flattening nodes. + */ + +#include "cvc5_private.h" + +#ifndef CVC5__EXPR__ALGORITHMS__FLATTEN_H +#define CVC5__EXPR__ALGORITHMS__FLATTEN_H + +#include + +#include "expr/node.h" + +namespace cvc5::expr::algorithm { + +/** + * Flatten a node into a vector of its (direct or indirect) children. + * Optionally, a sequence of kinds that should be flattened can be passed. If no + * kinds are given, flattening is done based on the kind of t. + * Note that flatten(t, c) is equivalent to flatten(t, c, t.getKind()). + * @param t The node to be flattened + * @param children The resulting list of children + * @param kinds Optional sequence of kinds to consider for flattening + */ +template +void flatten(TNode t, std::vector& children, Kinds... kinds) +{ + std::vector queue = {t}; + while (!queue.empty()) + { + TNode cur = queue.back(); + queue.pop_back(); + bool recurse = false; + // figure out whether to recurse into cur + if constexpr (sizeof...(kinds) == 0) + { + recurse = t.getKind() == cur.getKind(); + } + else + { + recurse = ((kinds == cur.getKind()) || ...); + } + if (recurse) + { + queue.insert(queue.end(), cur.rbegin(), cur.rend()); + } + else + { + children.emplace_back(cur); + } + } +} + +/** + * Check whether a node can be flattened, that is whether calling flatten() + * returns something other than its direct children. If no kinds are passed + * explicitly, this simply checks whether any of the children has the same kind + * as t. If a sequence of kinds is passed, this checks whether any of the + * children has one of these kinds. + * Note that canFlatten(t) is equivalent to canFlatten(t, t.getKind()). + * @param t The node that should be checked + * @param kinds Optional sequence of kinds + * @return true iff t could be flattened + */ +template +bool canFlatten(TNode t, Kinds... kinds) +{ + if constexpr (sizeof...(kinds) == 0) + { + return std::any_of(t.begin(), t.end(), [k = t.getKind()](TNode child) { + return child.getKind() == k; + }); + } + else + { + if (!((t.getKind() == kinds) || ...)) + { + return false; + } + return std::any_of(t.begin(), t.end(), [=](TNode child) { + return ((child.getKind() == kinds) || ...); + }); + } +} + +/** + * If t can be flattened, return a new node of the same kind as t with the + * flattened children. Otherwise, return t. + * If a sequence of kinds is given, the flattening (and the respective check) + * are done with respect to these kinds: see the documentation of flatten() + * and canFlatten() for more details. + * @param t The node to be flattened + * @param kinds Optional sequence of kinds + * @return A flattened version of t + */ +template +Node flatten(TNode t, Kinds... kinds) +{ + if (!canFlatten(t, kinds...)) + { + return t; + } + std::vector children; + flatten(t, children, kinds...); + return NodeManager::currentNM()->mkNode(t.getKind(), children); +} + +} // namespace cvc5::expr + +#endif diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 9f1fa40bb..e157484af 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -24,6 +24,7 @@ #include #include +#include "expr/algorithm/flatten.h" #include "smt/logic_exception.h" #include "theory/arith/arith_msum.h" #include "theory/arith/arith_utilities.h" @@ -455,52 +456,24 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ } } -static bool canFlatten(Kind k, TNode t){ - for(TNode::iterator i = t.begin(); i != t.end(); ++i) { - TNode child = *i; - if(child.getKind() == k){ - return true; - } - } - return false; -} - -static void flatten(std::vector& pb, Kind k, TNode t){ - if(t.getKind() == k){ - for(TNode::iterator i = t.begin(); i != t.end(); ++i) { - TNode child = *i; - if(child.getKind() == k){ - flatten(pb, k, child); - }else{ - pb.push_back(child); - } - } - }else{ - pb.push_back(t); - } -} - -static Node flatten(Kind k, TNode t){ - std::vector pb; - flatten(pb, k, t); - Assert(pb.size() >= 2); - return NodeManager::currentNM()->mkNode(k, pb); -} RewriteResponse ArithRewriter::preRewritePlus(TNode t){ Assert(t.getKind() == kind::PLUS); - - if(canFlatten(kind::PLUS, t)){ - return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t)); - }else{ - return RewriteResponse(REWRITE_DONE, t); - } + return RewriteResponse(REWRITE_DONE, expr::algorithm::flatten(t)); } RewriteResponse ArithRewriter::postRewritePlus(TNode t){ Assert(t.getKind() == kind::PLUS); Assert(t.getNumChildren() > 1); + { + Node flat = expr::algorithm::flatten(t); + if (flat != t) + { + return RewriteResponse(REWRITE_AGAIN, flat); + } + } + Rational rational; RealAlgebraicNumber ran; std::vector monomials; diff --git a/test/unit/node/CMakeLists.txt b/test/unit/node/CMakeLists.txt index c42e0452a..989e408d2 100644 --- a/test/unit/node/CMakeLists.txt +++ b/test/unit/node/CMakeLists.txt @@ -20,6 +20,7 @@ cvc5_add_unit_test_black(kind_black expr) cvc5_add_unit_test_black(kind_map_black expr) cvc5_add_unit_test_black(node_black expr) cvc5_add_unit_test_black(node_algorithm_black expr) +cvc5_add_unit_test_black(node_algorithms_black expr) cvc5_add_unit_test_black(node_builder_black expr) cvc5_add_unit_test_black(node_manager_black expr) cvc5_add_unit_test_white(node_manager_white expr) diff --git a/test/unit/node/node_algorithms_black.cpp b/test/unit/node/node_algorithms_black.cpp new file mode 100644 index 000000000..ebcb9b6c1 --- /dev/null +++ b/test/unit/node/node_algorithms_black.cpp @@ -0,0 +1,163 @@ +/****************************************************************************** + * Top contributors (to current version): + * Gereon Kremer + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Black box testing of expr/algorithms/ + */ + +#include "base/output.h" +#include "expr/algorithm/flatten.h" +#include "expr/node_manager.h" +#include "test_node.h" + +namespace cvc5 { + +using namespace expr; +using namespace kind; + +namespace test { + +class TestNodeBlackNodeAlgorithms : public TestNode +{ +}; + +TEST_F(TestNodeBlackNodeAlgorithms, flatten) +{ + { + Node x = d_nodeManager->mkBoundVar(*d_realTypeNode); + Node n = d_nodeManager->mkNode(Kind::PLUS, x, x); + EXPECT_FALSE(expr::algorithm::canFlatten(n)); + EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::PLUS)); + EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::MULT)); + EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::PLUS, Kind::MULT)); + EXPECT_EQ(expr::algorithm::flatten(n), n); + EXPECT_EQ(expr::algorithm::flatten(n, Kind::PLUS), n); + EXPECT_EQ(expr::algorithm::flatten(n, Kind::MULT), n); + EXPECT_EQ(expr::algorithm::flatten(n, Kind::PLUS, Kind::MULT), n); + + { + std::vector children; + expr::algorithm::flatten(n, children); + EXPECT_EQ(children.size(), 2); + EXPECT_EQ(children[0], x); + EXPECT_EQ(children[1], x); + } + { + std::vector children; + expr::algorithm::flatten(n, children, Kind::PLUS); + EXPECT_EQ(children.size(), 2); + EXPECT_EQ(children[0], x); + EXPECT_EQ(children[1], x); + } + { + std::vector children; + expr::algorithm::flatten(n, children, Kind::MULT); + EXPECT_EQ(children.size(), 1); + EXPECT_EQ(children[0], n); + } + { + std::vector children; + expr::algorithm::flatten(n, children, Kind::PLUS, Kind::MULT); + EXPECT_EQ(children.size(), 2); + EXPECT_EQ(children[0], x); + EXPECT_EQ(children[1], x); + } + } + { + Node x = d_nodeManager->mkBoundVar(*d_realTypeNode); + Node n = d_nodeManager->mkNode( + Kind::PLUS, x, d_nodeManager->mkNode(Kind::PLUS, x, x)); + EXPECT_TRUE(expr::algorithm::canFlatten(n)); + EXPECT_TRUE(expr::algorithm::canFlatten(n, Kind::PLUS)); + EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::MULT)); + EXPECT_TRUE(expr::algorithm::canFlatten(n, Kind::PLUS, Kind::MULT)); + EXPECT_NE(expr::algorithm::flatten(n), n); + EXPECT_NE(expr::algorithm::flatten(n, Kind::PLUS), n); + EXPECT_EQ(expr::algorithm::flatten(n, Kind::MULT), n); + EXPECT_NE(expr::algorithm::flatten(n, Kind::PLUS, Kind::MULT), n); + + { + std::vector children; + expr::algorithm::flatten(n, children); + EXPECT_EQ(children.size(), 3); + EXPECT_EQ(children[0], x); + EXPECT_EQ(children[1], x); + EXPECT_EQ(children[2], x); + } + { + std::vector children; + expr::algorithm::flatten(n, children, Kind::PLUS); + EXPECT_EQ(children.size(), 3); + EXPECT_EQ(children[0], x); + EXPECT_EQ(children[1], x); + EXPECT_EQ(children[2], x); + } + { + std::vector children; + expr::algorithm::flatten(n, children, Kind::MULT); + EXPECT_EQ(children.size(), 1); + EXPECT_EQ(children[0], n); + } + { + std::vector children; + expr::algorithm::flatten(n, children, Kind::PLUS, Kind::MULT); + EXPECT_EQ(children.size(), 3); + EXPECT_EQ(children[0], x); + EXPECT_EQ(children[1], x); + EXPECT_EQ(children[2], x); + } + } + { + Node x = d_nodeManager->mkBoundVar(*d_realTypeNode); + Node n = d_nodeManager->mkNode( + Kind::PLUS, x, d_nodeManager->mkNode(Kind::MULT, x, x)); + EXPECT_FALSE(expr::algorithm::canFlatten(n)); + EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::PLUS)); + EXPECT_FALSE(expr::algorithm::canFlatten(n, Kind::MULT)); + EXPECT_TRUE(expr::algorithm::canFlatten(n, Kind::PLUS, Kind::MULT)); + EXPECT_EQ(expr::algorithm::flatten(n), n); + EXPECT_EQ(expr::algorithm::flatten(n, Kind::PLUS), n); + EXPECT_EQ(expr::algorithm::flatten(n, Kind::MULT), n); + EXPECT_NE(expr::algorithm::flatten(n, Kind::PLUS, Kind::MULT), n); + + { + std::vector children; + expr::algorithm::flatten(n, children); + EXPECT_EQ(children.size(), 2); + EXPECT_EQ(children[0], x); + EXPECT_EQ(children[1], n[1]); + } + { + std::vector children; + expr::algorithm::flatten(n, children, Kind::PLUS); + EXPECT_EQ(children.size(), 2); + EXPECT_EQ(children[0], x); + EXPECT_EQ(children[1], n[1]); + } + { + std::vector children; + expr::algorithm::flatten(n, children, Kind::MULT); + EXPECT_EQ(children.size(), 1); + EXPECT_EQ(children[0], n); + } + { + std::vector children; + expr::algorithm::flatten(n, children, Kind::PLUS, Kind::MULT); + EXPECT_EQ(children.size(), 3); + EXPECT_EQ(children[0], x); + EXPECT_EQ(children[1], x); + EXPECT_EQ(children[2], x); + } + } +} + +} // namespace test +} // namespace cvc5