Update to standard implementation of contains term (#3270)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 12 Sep 2019 22:31:22 +0000 (17:31 -0500)
committerGitHub <noreply@github.com>
Thu, 12 Sep 2019 22:31:22 +0000 (17:31 -0500)
src/expr/node_algorithm.cpp
src/expr/node_algorithm.h
src/theory/quantifiers/cegqi/ceg_instantiator.cpp
src/theory/quantifiers/instantiate.cpp
src/theory/quantifiers/quantifiers_rewriter.cpp
src/theory/quantifiers/term_util.cpp
src/theory/quantifiers/term_util.h

index 50ac8297ceeeca0ef0ba59834a8d785997685313..59e3d3b0313eca4df04f0db9f679e9fd2e90fa04 100644 (file)
@@ -34,16 +34,26 @@ bool hasSubterm(TNode n, TNode t, bool strict)
 
   toProcess.push_back(n);
 
+  // incrementally iterate and add to toProcess
   for (unsigned i = 0; i < toProcess.size(); ++i)
   {
     TNode current = toProcess[i];
-    if (current.hasOperator() && current.getOperator() == t)
+    for (unsigned j = 0, j_end = current.getNumChildren(); j <= j_end; ++j)
     {
-      return true;
-    }
-    for (unsigned j = 0, j_end = current.getNumChildren(); j < j_end; ++j)
-    {
-      TNode child = current[j];
+      TNode child;
+      // try children then operator
+      if (j < j_end)
+      {
+        child = current[j];
+      }
+      else if (current.hasOperator())
+      {
+        child = current.getOperator();
+      }
+      else
+      {
+        break;
+      }
       if (child == t)
       {
         return true;
@@ -118,6 +128,61 @@ bool hasSubtermMulti(TNode n, TNode t)
   return false;
 }
 
+bool hasSubterm(TNode n, const std::vector<Node>& t, bool strict)
+{
+  if (t.empty())
+  {
+    return false;
+  }
+  if (!strict && std::find(t.begin(), t.end(), n) != t.end())
+  {
+    return true;
+  }
+
+  std::unordered_set<TNode, TNodeHashFunction> visited;
+  std::vector<TNode> toProcess;
+
+  toProcess.push_back(n);
+
+  // incrementally iterate and add to toProcess
+  for (unsigned i = 0; i < toProcess.size(); ++i)
+  {
+    TNode current = toProcess[i];
+    for (unsigned j = 0, j_end = current.getNumChildren(); j <= j_end; ++j)
+    {
+      TNode child;
+      // try children then operator
+      if (j < j_end)
+      {
+        child = current[j];
+      }
+      else if (current.hasOperator())
+      {
+        child = current.getOperator();
+      }
+      else
+      {
+        break;
+      }
+      if (std::find(t.begin(), t.end(), child) != t.end())
+      {
+        return true;
+      }
+      if (visited.find(child) != visited.end())
+      {
+        continue;
+      }
+      else
+      {
+        visited.insert(child);
+        toProcess.push_back(child);
+      }
+    }
+  }
+
+  return false;
+}
+
 struct HasBoundVarTag
 {
 };
index 17d7d951b61373ce502975b7a388ae44ba9772bb..e5a21d5652f25bf687f5418238b4e368e19b3b43 100644 (file)
@@ -44,6 +44,15 @@ bool hasSubterm(TNode n, TNode t, bool strict = false);
  */
 bool hasSubtermMulti(TNode n, TNode t);
 
+/**
+ * Check if the node n has a subterm that occurs in t.
+ * @param n The node to search in
+ * @param t The set of subterms to search for
+ * @param strict If true, a term is not considered to be a subterm of itself
+ * @return true iff there is a term in t that is a subterm in n
+ */
+bool hasSubterm(TNode n, const std::vector<Node>& t, bool strict = false);
+
 /**
  * Returns true iff the node n contains a bound variable, that is a node of
  * kind BOUND_VARIABLE. This bound variable may or may not be free.
index 104e40d8b9539e9874f9060ce9d7ed7f11a738a6..1713c21e2f2968e2dad7cc0b5d4698c689306b8f 100644 (file)
@@ -1245,7 +1245,8 @@ Node CegInstantiator::applySubstitution( TypeNode tn, Node n, std::vector< Node
           Node nretc = children.size()==1 ? children[0] : NodeManager::currentNM()->mkNode( PLUS, children );
           nretc = Rewriter::rewrite( nretc );
           //ensure that nret does not contain vars
-          if( !TermUtil::containsTerms( nretc, vars ) ){
+          if (!expr::hasSubterm(nretc, vars))
+          {
             //result is ( nret / pv_prop.d_coeff )
             nret = nretc;
           }else{
index ea90ddd66349bf8a7164cc692b56117d68ba6f0f..c6427a4c48619d2b4a8ac9e9f05f2615c1535fe7 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "theory/quantifiers/instantiate.h"
 
+#include "expr/node_algorithm.h"
 #include "options/quantifiers_options.h"
 #include "smt/smt_statistics_registry.h"
 #include "theory/quantifiers/cegqi/inst_strategy_cegqi.h"
@@ -170,8 +171,7 @@ bool Instantiate::addInstantiation(
                         << terms[i] << std::endl;
           bad_inst = true;
         }
-        else if (quantifiers::TermUtil::containsTerms(
-                     terms[i], d_term_util->d_inst_constants[q]))
+        else if (expr::hasSubterm(terms[i], d_term_util->d_inst_constants[q]))
         {
           Trace("inst") << "***& inst contains inst constants : " << terms[i]
                         << std::endl;
index f5159a63025478f3ed2078985e0939fb6b1c5997..33da466754200a74542c822096282a1151a9fe9a 100644 (file)
@@ -1808,9 +1808,10 @@ Node QuantifiersRewriter::computeMiniscoping( std::vector< Node >& args, Node bo
       Node newBody = body;
       NodeBuilder<> body_split(kind::OR);
       NodeBuilder<> tb(kind::OR);
-      for( unsigned i=0; i<body.getNumChildren(); i++ ){
-        Node trm = body[i];
-        if( TermUtil::containsTerms( body[i], args ) ){
+      for (const Node& trm : body)
+      {
+        if (expr::hasSubterm(trm, args))
+        {
           tb << trm;
         }else{
           body_split << trm;
index ffd808ed3716b6d98978c5550f1e052067db8ba0..48dc88537517e97de82b7d2401d5b49cd56f737c 100644 (file)
@@ -489,15 +489,17 @@ Node TermUtil::rewriteVtsSymbols( Node n ) {
 bool TermUtil::containsVtsTerm( Node n, bool isFree ) {
   std::vector< Node > t;
   getVtsTerms( t, isFree, false );
-  return containsTerms( n, t );
+  return expr::hasSubterm(n, t);
 }
 
 bool TermUtil::containsVtsTerm( std::vector< Node >& n, bool isFree ) {
   std::vector< Node > t;
   getVtsTerms( t, isFree, false );
   if( !t.empty() ){
-    for( unsigned i=0; i<n.size(); i++ ){
-      if( containsTerms( n[i], t ) ){
+    for (const Node& nc : n)
+    {
+      if (expr::hasSubterm(nc, t))
+      {
         return true;
       }
     }
@@ -508,7 +510,7 @@ bool TermUtil::containsVtsTerm( std::vector< Node >& n, bool isFree ) {
 bool TermUtil::containsVtsInfinity( Node n, bool isFree ) {
   std::vector< Node > t;
   getVtsTerms( t, isFree, false, false );
-  return containsTerms( n, t );
+  return expr::hasSubterm(n, t);
 }
 
 Node TermUtil::ensureType( Node n, TypeNode tn ) {
@@ -524,40 +526,6 @@ Node TermUtil::ensureType( Node n, TypeNode tn ) {
   }
 }
 
-bool TermUtil::containsTerms2( Node n, std::vector< Node >& t, std::map< Node, bool >& visited ) {
-  if (visited.find(n) == visited.end())
-  {
-    if( std::find( t.begin(), t.end(), n )!=t.end() ){
-      return true;
-    }
-    visited[n] = true;
-    if (n.hasOperator())
-    {
-      if (containsTerms2(n.getOperator(), t, visited))
-      {
-        return true;
-      }
-    }
-    for (const Node& nc : n)
-    {
-      if (containsTerms2(nc, t, visited))
-      {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-bool TermUtil::containsTerms( Node n, std::vector< Node >& t ) {
-  if( t.empty() ){
-    return false;
-  }else{
-    std::map< Node, bool > visited;
-    return containsTerms2( n, t, visited );
-  }
-}
-
 int TermUtil::getTermDepth( Node n ) {
   if (!n.hasAttribute(TermDepthAttribute()) ){
     int maxDepth = -1;
index b39a4e129cfc9e6ad6a61c764f7e154779e8c108..99ea483d91113d380d406988874f9bf228bf0f29 100644 (file)
@@ -219,8 +219,6 @@ public:
 //general utilities
   // TODO #1216 : promote these?
  private:
-  //helper for contains term
-  static bool containsTerms2( Node n, std::vector< Node >& t, std::map< Node, bool >& visited );
   /** cache for getTypeValue */
   std::unordered_map<TypeNode,
                      std::unordered_map<int, Node>,
@@ -244,8 +242,6 @@ public:
       d_type_value_offset_status;
 
  public:
-  /** simple check for contains term, true if contains at least one term in t */
-  static bool containsTerms( Node n, std::vector< Node >& t );
   /** contains uninterpreted constant */
   static bool containsUninterpretedConstant( Node n );
   /** get the term depth of n */