From 5b494f33f146677386385fc7d055f95d9aae08d5 Mon Sep 17 00:00:00 2001 From: yoni206 Date: Mon, 22 Jul 2019 22:39:59 -0700 Subject: [PATCH] Get operators in node (#3094) This commit adds a function to node_algorithm.{h,cpp} that returns the operators that occur in a given node. --- src/expr/node_algorithm.cpp | 41 ++++++++ src/expr/node_algorithm.h | 33 ++++++- test/unit/expr/CMakeLists.txt | 1 + test/unit/expr/node_algorithm_black.h | 133 ++++++++++++++++++++++++++ 4 files changed, 207 insertions(+), 1 deletion(-) create mode 100644 test/unit/expr/node_algorithm_black.h diff --git a/src/expr/node_algorithm.cpp b/src/expr/node_algorithm.cpp index 841f9ea28..c20dddbcc 100644 --- a/src/expr/node_algorithm.cpp +++ b/src/expr/node_algorithm.cpp @@ -304,6 +304,47 @@ void getSymbols(TNode n, } while (!visit.empty()); } +void getOperatorsMap( + TNode n, + std::map>& ops) +{ + std::unordered_set visited; + getOperatorsMap(n, ops, visited); +} + +void getOperatorsMap( + TNode n, + std::map>& ops, + std::unordered_set& visited) +{ + // nodes that we still need to visit + std::vector visit; + // current node + TNode cur; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + // if cur is in the cache, do nothing + if (visited.find(cur) == visited.end()) + { + // fetch the correct type + TypeNode tn = cur.getType(); + // add the current operator to the result + if (cur.hasOperator()) + { + ops[tn].insert(NodeManager::currentNM()->operatorOf(cur.getKind())); + } + // add children to visit in the future + for (TNode cn : cur) + { + visit.push_back(cn); + } + } + } while (!visit.empty()); +} + Node substituteCaptureAvoiding(TNode n, Node src, Node dest) { if (n == src) diff --git a/src/expr/node_algorithm.h b/src/expr/node_algorithm.h index 727f5ba75..3686b6686 100644 --- a/src/expr/node_algorithm.h +++ b/src/expr/node_algorithm.h @@ -25,6 +25,7 @@ #include #include "expr/node.h" +#include "expr/type_node.h" namespace CVC4 { namespace expr { @@ -87,10 +88,40 @@ bool getVariables(TNode n, std::unordered_set& vs); * @param syms The set which the symbols of n are added to */ void getSymbols(TNode n, std::unordered_set& syms); -/** Same as above, with a visited cache */ + +/** + * For term n, this function collects the symbols that occur as a subterms + * of n. A symbol is a variable that does not have kind BOUND_VARIABLE. + * @param n The node under investigation + * @param syms The set which the symbols of n are added to + * @param visited A cache to be used for visited nodes. + */ void getSymbols(TNode n, std::unordered_set& syms, std::unordered_set& visited); + +/** + * For term n, this function collects the operators that occur in n. + * @param n The node under investigation + * @param ops The map (from each type to operators of that type) which the + * operators of n are added to + */ +void getOperatorsMap( + TNode n, + std::map>& ops); + +/** + * For term n, this function collects the operators that occur in n. + * @param n The node under investigation + * @param ops The map (from each type to operators of that type) which the + * operators of n are added to + * @param visited A cache to be used for visited nodes. + */ +void getOperatorsMap( + TNode n, + std::map>& ops, + std::unordered_set& visited); + /** * Substitution of Nodes in a capture avoiding way. */ diff --git a/test/unit/expr/CMakeLists.txt b/test/unit/expr/CMakeLists.txt index 575268cae..d487bf560 100644 --- a/test/unit/expr/CMakeLists.txt +++ b/test/unit/expr/CMakeLists.txt @@ -8,6 +8,7 @@ cvc4_add_unit_test_black(expr_public expr) cvc4_add_unit_test_black(kind_black expr) cvc4_add_unit_test_black(kind_map_black expr) cvc4_add_unit_test_black(node_black expr) +cvc4_add_unit_test_black(node_algorithm_black expr) cvc4_add_unit_test_black(node_builder_black expr) cvc4_add_unit_test_black(node_manager_black expr) cvc4_add_unit_test_white(node_manager_white expr) diff --git a/test/unit/expr/node_algorithm_black.h b/test/unit/expr/node_algorithm_black.h new file mode 100644 index 000000000..2151eef72 --- /dev/null +++ b/test/unit/expr/node_algorithm_black.h @@ -0,0 +1,133 @@ +/********************* */ +/*! \file node_algorithm_black.h + ** \verbatim + ** Top contributors (to current version): + ** Yoni Zohar + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Black box testing of utility functions in node_algorithm.{h,cpp} + ** + ** Black box testing of node_algorithm.{h,cpp} + **/ + +#include + +#include +#include + +#include "base/output.h" +#include "expr/node_algorithm.h" +#include "expr/node_manager.h" +#include "util/integer.h" +#include "util/rational.h" + +using namespace CVC4; +using namespace CVC4::expr; +using namespace CVC4::kind; + +class NodeAlgorithmBlack : public CxxTest::TestSuite +{ + private: + NodeManager* d_nodeManager; + NodeManagerScope* d_scope; + TypeNode* d_intTypeNode; + TypeNode* d_boolTypeNode; + + public: + void setUp() override + { + d_nodeManager = new NodeManager(NULL); + d_scope = new NodeManagerScope(d_nodeManager); + d_intTypeNode = new TypeNode(d_nodeManager->integerType()); + d_boolTypeNode = new TypeNode(d_nodeManager->booleanType()); + } + + void tearDown() override + { + delete d_intTypeNode; + delete d_boolTypeNode; + delete d_scope; + delete d_nodeManager; + } + + // The only symbol in ~x (x is a boolean varible) should be x + void testGetSymbols1() + { + Node x = d_nodeManager->mkSkolem("x", d_nodeManager->booleanType()); + Node n = d_nodeManager->mkNode(NOT, x); + std::unordered_set syms; + getSymbols(n, syms); + TS_ASSERT_EQUALS(syms.size(), 1); + TS_ASSERT(syms.find(x) != syms.end()); + } + + // the only symbols in x=y ^ (exists var. var+var = x) are x and y, because + // "var" is bound. + void testGetSymbols2() + { + // left conjunct + Node x = d_nodeManager->mkSkolem("x", d_nodeManager->integerType()); + Node y = d_nodeManager->mkSkolem("y", d_nodeManager->integerType()); + Node left = d_nodeManager->mkNode(EQUAL, x, y); + + // right conjunct + Node var = d_nodeManager->mkBoundVar(*d_intTypeNode); + std::vector vars; + vars.push_back(var); + Node sum = d_nodeManager->mkNode(PLUS, var, var); + Node qeq = d_nodeManager->mkNode(EQUAL, x, sum); + Node bvl = d_nodeManager->mkNode(BOUND_VAR_LIST, vars); + Node right = d_nodeManager->mkNode(EXISTS, bvl, qeq); + + // conjunction + Node res = d_nodeManager->mkNode(AND, left, right); + + // symbols + std::unordered_set syms; + getSymbols(res, syms); + + // assertions + TS_ASSERT_EQUALS(syms.size(), 2); + TS_ASSERT(syms.find(x) != syms.end()); + TS_ASSERT(syms.find(y) != syms.end()); + TS_ASSERT(syms.find(var) == syms.end()); + } + + void testGetOperatorsMap() + { + // map to store result + std::map > result = + std::map >(); + + // create test formula + Node x = d_nodeManager->mkSkolem("x", d_nodeManager->integerType()); + Node plus = d_nodeManager->mkNode(PLUS, x, x); + Node mul = d_nodeManager->mkNode(MULT, x, x); + Node eq = d_nodeManager->mkNode(EQUAL, plus, mul); + + // call function + expr::getOperatorsMap(eq, result); + + // Verify result + // We should have only integer and boolean as types + TS_ASSERT(result.size() == 2); + TS_ASSERT(result.find(*d_intTypeNode) != result.end()); + TS_ASSERT(result.find(*d_boolTypeNode) != result.end()); + + // in integers, we should only have plus and mult as operators + TS_ASSERT(result[*d_intTypeNode].size() == 2); + TS_ASSERT(result[*d_intTypeNode].find(d_nodeManager->operatorOf(PLUS)) + != result[*d_intTypeNode].end()); + TS_ASSERT(result[*d_intTypeNode].find(d_nodeManager->operatorOf(MULT)) + != result[*d_intTypeNode].end()); + + // in booleans, we should only have "=" as an operator. + TS_ASSERT(result[*d_boolTypeNode].size() == 1); + TS_ASSERT(result[*d_boolTypeNode].find(d_nodeManager->operatorOf(EQUAL)) + != result[*d_boolTypeNode].end()); + } +}; -- 2.30.2