Get operators in node (#3094)
authoryoni206 <yoni206@users.noreply.github.com>
Tue, 23 Jul 2019 05:39:59 +0000 (22:39 -0700)
committerGitHub <noreply@github.com>
Tue, 23 Jul 2019 05:39:59 +0000 (22:39 -0700)
This commit adds a function to node_algorithm.{h,cpp} that returns the operators that occur in a given node.

src/expr/node_algorithm.cpp
src/expr/node_algorithm.h
test/unit/expr/CMakeLists.txt
test/unit/expr/node_algorithm_black.h [new file with mode: 0644]

index 841f9ea28d4fce06ca98ae69604caef1c46e939a..c20dddbcc3c7e30cdfffaaaf5b1ecee427828f42 100644 (file)
@@ -304,6 +304,47 @@ void getSymbols(TNode n,
   } while (!visit.empty());
 }
 
+void getOperatorsMap(
+    TNode n,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& ops)
+{
+  std::unordered_set<TNode, TNodeHashFunction> visited;
+  getOperatorsMap(n, ops, visited);
+}
+
+void getOperatorsMap(
+    TNode n,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& ops,
+    std::unordered_set<TNode, TNodeHashFunction>& visited)
+{
+  // nodes that we still need to visit
+  std::vector<TNode> visit;
+  // current node
+  TNode cur;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    // if cur is in the cache, do nothing
+    if (visited.find(cur) == visited.end())
+    {
+      // fetch the correct type
+      TypeNode tn = cur.getType();
+      // add the current operator to the result
+      if (cur.hasOperator())
+      {
+        ops[tn].insert(NodeManager::currentNM()->operatorOf(cur.getKind()));
+      }
+      // add children to visit in the future
+      for (TNode cn : cur)
+      {
+        visit.push_back(cn);
+      }
+    }
+  } while (!visit.empty());
+}
+
 Node substituteCaptureAvoiding(TNode n, Node src, Node dest)
 {
   if (n == src)
index 727f5ba75d6cc5762301abad2be8c514b3921b3f..3686b6686c48e7f5b0adc914fa606d09d7a33b80 100644 (file)
@@ -25,6 +25,7 @@
 #include <vector>
 
 #include "expr/node.h"
+#include "expr/type_node.h"
 
 namespace CVC4 {
 namespace expr {
@@ -87,10 +88,40 @@ bool getVariables(TNode n, std::unordered_set<TNode, TNodeHashFunction>& vs);
  * @param syms The set which the symbols of n are added to
  */
 void getSymbols(TNode n, std::unordered_set<Node, NodeHashFunction>& syms);
-/** Same as above, with a visited cache */
+
+/**
+ * For term n, this function collects the symbols that occur as a subterms
+ * of n. A symbol is a variable that does not have kind BOUND_VARIABLE.
+ * @param n The node under investigation
+ * @param syms The set which the symbols of n are added to
+ * @param visited A cache to be used for visited nodes.
+ */
 void getSymbols(TNode n,
                 std::unordered_set<Node, NodeHashFunction>& syms,
                 std::unordered_set<TNode, TNodeHashFunction>& visited);
+
+/**
+ * For term n, this function collects the operators that occur in n.
+ * @param n The node under investigation
+ * @param ops The map (from each type to operators of that type) which the
+ * operators of n are added to
+ */
+void getOperatorsMap(
+    TNode n,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& ops);
+
+/**
+ * For term n, this function collects the operators that occur in n.
+ * @param n The node under investigation
+ * @param ops The map (from each type to operators of that type) which the
+ * operators of n are added to
+ * @param visited A cache to be used for visited nodes.
+ */
+void getOperatorsMap(
+    TNode n,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& ops,
+    std::unordered_set<TNode, TNodeHashFunction>& visited);
+
 /**
  * Substitution of Nodes in a capture avoiding way.
  */
index 575268cae758e35fa11595cd228543ff6518c094..d487bf56092e9264eff988196989b717ab795415 100644 (file)
@@ -8,6 +8,7 @@ cvc4_add_unit_test_black(expr_public expr)
 cvc4_add_unit_test_black(kind_black expr)
 cvc4_add_unit_test_black(kind_map_black expr)
 cvc4_add_unit_test_black(node_black expr)
+cvc4_add_unit_test_black(node_algorithm_black expr)
 cvc4_add_unit_test_black(node_builder_black expr)
 cvc4_add_unit_test_black(node_manager_black expr)
 cvc4_add_unit_test_white(node_manager_white expr)
diff --git a/test/unit/expr/node_algorithm_black.h b/test/unit/expr/node_algorithm_black.h
new file mode 100644 (file)
index 0000000..2151eef
--- /dev/null
@@ -0,0 +1,133 @@
+/*********************                                                        */
+/*! \file node_algorithm_black.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Yoni Zohar
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Black box testing of utility functions in node_algorithm.{h,cpp}
+ **
+ ** Black box testing of node_algorithm.{h,cpp}
+ **/
+
+#include <cxxtest/TestSuite.h>
+
+#include <string>
+#include <vector>
+
+#include "base/output.h"
+#include "expr/node_algorithm.h"
+#include "expr/node_manager.h"
+#include "util/integer.h"
+#include "util/rational.h"
+
+using namespace CVC4;
+using namespace CVC4::expr;
+using namespace CVC4::kind;
+
+class NodeAlgorithmBlack : public CxxTest::TestSuite
+{
+ private:
+  NodeManager* d_nodeManager;
+  NodeManagerScope* d_scope;
+  TypeNode* d_intTypeNode;
+  TypeNode* d_boolTypeNode;
+
+ public:
+  void setUp() override
+  {
+    d_nodeManager = new NodeManager(NULL);
+    d_scope = new NodeManagerScope(d_nodeManager);
+    d_intTypeNode = new TypeNode(d_nodeManager->integerType());
+    d_boolTypeNode = new TypeNode(d_nodeManager->booleanType());
+  }
+
+  void tearDown() override
+  {
+    delete d_intTypeNode;
+    delete d_boolTypeNode;
+    delete d_scope;
+    delete d_nodeManager;
+  }
+
+  // The only symbol in ~x (x is a boolean varible) should be x
+  void testGetSymbols1()
+  {
+    Node x = d_nodeManager->mkSkolem("x", d_nodeManager->booleanType());
+    Node n = d_nodeManager->mkNode(NOT, x);
+    std::unordered_set<Node, NodeHashFunction> syms;
+    getSymbols(n, syms);
+    TS_ASSERT_EQUALS(syms.size(), 1);
+    TS_ASSERT(syms.find(x) != syms.end());
+  }
+
+  // the only symbols in x=y ^ (exists var. var+var = x) are x and y, because
+  // "var" is bound.
+  void testGetSymbols2()
+  {
+    // left conjunct
+    Node x = d_nodeManager->mkSkolem("x", d_nodeManager->integerType());
+    Node y = d_nodeManager->mkSkolem("y", d_nodeManager->integerType());
+    Node left = d_nodeManager->mkNode(EQUAL, x, y);
+
+    // right conjunct
+    Node var = d_nodeManager->mkBoundVar(*d_intTypeNode);
+    std::vector<Node> vars;
+    vars.push_back(var);
+    Node sum = d_nodeManager->mkNode(PLUS, var, var);
+    Node qeq = d_nodeManager->mkNode(EQUAL, x, sum);
+    Node bvl = d_nodeManager->mkNode(BOUND_VAR_LIST, vars);
+    Node right = d_nodeManager->mkNode(EXISTS, bvl, qeq);
+
+    // conjunction
+    Node res = d_nodeManager->mkNode(AND, left, right);
+
+    // symbols
+    std::unordered_set<Node, NodeHashFunction> syms;
+    getSymbols(res, syms);
+
+    // assertions
+    TS_ASSERT_EQUALS(syms.size(), 2);
+    TS_ASSERT(syms.find(x) != syms.end());
+    TS_ASSERT(syms.find(y) != syms.end());
+    TS_ASSERT(syms.find(var) == syms.end());
+  }
+
+  void testGetOperatorsMap()
+  {
+    // map to store result
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction> > result =
+        std::map<TypeNode, std::unordered_set<Node, NodeHashFunction> >();
+
+    // create test formula
+    Node x = d_nodeManager->mkSkolem("x", d_nodeManager->integerType());
+    Node plus = d_nodeManager->mkNode(PLUS, x, x);
+    Node mul = d_nodeManager->mkNode(MULT, x, x);
+    Node eq = d_nodeManager->mkNode(EQUAL, plus, mul);
+
+    // call function
+    expr::getOperatorsMap(eq, result);
+
+    // Verify result
+    // We should have only integer and boolean as types
+    TS_ASSERT(result.size() == 2);
+    TS_ASSERT(result.find(*d_intTypeNode) != result.end());
+    TS_ASSERT(result.find(*d_boolTypeNode) != result.end());
+
+    // in integers, we should only have plus and mult as operators
+    TS_ASSERT(result[*d_intTypeNode].size() == 2);
+    TS_ASSERT(result[*d_intTypeNode].find(d_nodeManager->operatorOf(PLUS))
+              != result[*d_intTypeNode].end());
+    TS_ASSERT(result[*d_intTypeNode].find(d_nodeManager->operatorOf(MULT))
+              != result[*d_intTypeNode].end());
+
+    // in booleans, we should only have "=" as an operator.
+    TS_ASSERT(result[*d_boolTypeNode].size() == 1);
+    TS_ASSERT(result[*d_boolTypeNode].find(d_nodeManager->operatorOf(EQUAL))
+              != result[*d_boolTypeNode].end());
+  }
+};