Change recursive expandDefinitions() to an interative worklist-based one; we were...
authorMorgan Deters <mdeters@cs.nyu.edu>
Tue, 20 Aug 2013 20:32:31 +0000 (16:32 -0400)
committerMorgan Deters <mdeters@cs.nyu.edu>
Tue, 20 Aug 2013 20:54:06 +0000 (16:54 -0400)
src/smt/smt_engine.cpp

index 0a9afab3c374796ad6f9a7fd09cb5eb242831d4d..df253bef5d9b40f6fde4eb52fd63434f69476fbb 100644 (file)
@@ -1333,178 +1333,210 @@ Node SmtEnginePrivate::expandBVDivByZero(TNode n) {
 Node SmtEnginePrivate::expandDefinitions(TNode n, hash_map<Node, Node, NodeHashFunction>& cache)
   throw(TypeCheckingException, LogicException) {
 
-  Kind k = n.getKind();
-
-  if(k != kind::APPLY && n.getNumChildren() == 0) {
-    SmtEngine::DefinedFunctionMap::const_iterator i = d_smt.d_definedFunctions->find(n);
-    if(i != d_smt.d_definedFunctions->end()) {
-      // replacement must be closed
-      if((*i).second.getFormals().size() > 0) {
-        return d_smt.d_nodeManager->mkNode(kind::LAMBDA, d_smt.d_nodeManager->mkNode(kind::BOUND_VAR_LIST, (*i).second.getFormals()), (*i).second.getFormula());
+  stack< triple<Node, Node, bool> > worklist;
+  stack<Node> result;
+  worklist.push(make_triple(Node(n), Node(n), false));
+
+  do {
+    n = worklist.top().first;
+    Node node = worklist.top().second;
+    bool childrenPushed = worklist.top().third;
+    worklist.pop();
+
+    if(!childrenPushed) {
+      Kind k = n.getKind();
+
+      if(k != kind::APPLY && n.getNumChildren() == 0) {
+       SmtEngine::DefinedFunctionMap::const_iterator i = d_smt.d_definedFunctions->find(n);
+       if(i != d_smt.d_definedFunctions->end()) {
+         // replacement must be closed
+         if((*i).second.getFormals().size() > 0) {
+           result.push(d_smt.d_nodeManager->mkNode(kind::LAMBDA, d_smt.d_nodeManager->mkNode(kind::BOUND_VAR_LIST, (*i).second.getFormals()), (*i).second.getFormula()));
+           continue;
+         }
+         // don't bother putting in the cache
+         result.push((*i).second.getFormula());
+         continue;
+       }
+       // don't bother putting in the cache
+       result.push(n);
+       continue;
       }
-      // don't bother putting in the cache
-      return (*i).second.getFormula();
-    }
-    // don't bother putting in the cache
-    return n;
-  }
 
-  // maybe it's in the cache
-  hash_map<Node, Node, NodeHashFunction>::iterator cacheHit = cache.find(n);
-  if(cacheHit != cache.end()) {
-    TNode ret = (*cacheHit).second;
-    return ret.isNull() ? n : ret;
-  }
+      // maybe it's in the cache
+      hash_map<Node, Node, NodeHashFunction>::iterator cacheHit = cache.find(n);
+      if(cacheHit != cache.end()) {
+        TNode ret = (*cacheHit).second;
+        result.push(ret.isNull() ? n : ret);
+        continue;
+      }
 
-  // otherwise expand it
+      // otherwise expand it
 
-  Node node = n;
-  NodeManager* nm = d_smt.d_nodeManager;
-  // FIXME: this theory specific code should be factored out of the SmtEngine, somehow
-  switch(k) {
-  case kind::BITVECTOR_SDIV:
-  case kind::BITVECTOR_SREM:
-  case kind::BITVECTOR_SMOD: {
-    node = bv::TheoryBVRewriter::eliminateBVSDiv(node);
-    break;
-  }
-
-  case kind::BITVECTOR_UDIV:
-  case kind::BITVECTOR_UREM: {
-    node = expandBVDivByZero(node);
-    break;
-  }
-
-  case kind::DIVISION: {
-    // partial function: division
-    if(d_divByZero.isNull()) {
-      d_divByZero = nm->mkSkolem("divByZero", nm->mkFunctionType(nm->realType(), nm->realType()),
-                                 "partial real division", NodeManager::SKOLEM_EXACT_NAME);
-      if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
-        d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
-        d_smt.d_logic.enableTheory(THEORY_UF);
-        d_smt.d_logic.lock();
+      NodeManager* nm = d_smt.d_nodeManager;
+      // FIXME: this theory specific code should be factored out of the SmtEngine, somehow
+      switch(k) {
+      case kind::BITVECTOR_SDIV:
+      case kind::BITVECTOR_SREM:
+      case kind::BITVECTOR_SMOD: {
+        node = bv::TheoryBVRewriter::eliminateBVSDiv(node);
+        break;
       }
-    }
-    TNode num = n[0], den = n[1];
-    Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
-    Node divByZeroNum = nm->mkNode(kind::APPLY_UF, d_divByZero, num);
-    Node divTotalNumDen = nm->mkNode(kind::DIVISION_TOTAL, num, den);
-    node = nm->mkNode(kind::ITE, den_eq_0, divByZeroNum, divTotalNumDen);
-    break;
-  }
-
-  case kind::INTS_DIVISION: {
-    // partial function: integer div
-    if(d_intDivByZero.isNull()) {
-      d_intDivByZero = nm->mkSkolem("intDivByZero", nm->mkFunctionType(nm->integerType(), nm->integerType()),
-                                    "partial integer division", NodeManager::SKOLEM_EXACT_NAME);
-      if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
-        d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
-        d_smt.d_logic.enableTheory(THEORY_UF);
-        d_smt.d_logic.lock();
+
+      case kind::BITVECTOR_UDIV:
+      case kind::BITVECTOR_UREM: {
+        node = expandBVDivByZero(node);
+        break;
       }
-    }
-    TNode num = n[0], den = n[1];
-    Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
-    Node intDivByZeroNum = nm->mkNode(kind::APPLY_UF, d_intDivByZero, num);
-    Node intDivTotalNumDen = nm->mkNode(kind::INTS_DIVISION_TOTAL, num, den);
-    node = nm->mkNode(kind::ITE, den_eq_0, intDivByZeroNum, intDivTotalNumDen);
-    break;
-  }
-
-  case kind::INTS_MODULUS: {
-    // partial function: mod
-    if(d_modZero.isNull()) {
-      d_modZero = nm->mkSkolem("modZero", nm->mkFunctionType(nm->integerType(), nm->integerType()),
-                               "partial modulus", NodeManager::SKOLEM_EXACT_NAME);
-      if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
-        d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
-        d_smt.d_logic.enableTheory(THEORY_UF);
-        d_smt.d_logic.lock();
+
+      case kind::DIVISION: {
+        // partial function: division
+        if(d_divByZero.isNull()) {
+          d_divByZero = nm->mkSkolem("divByZero", nm->mkFunctionType(nm->realType(), nm->realType()),
+                                     "partial real division", NodeManager::SKOLEM_EXACT_NAME);
+          if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
+            d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
+            d_smt.d_logic.enableTheory(THEORY_UF);
+            d_smt.d_logic.lock();
+          }
+        }
+        TNode num = n[0], den = n[1];
+        Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
+        Node divByZeroNum = nm->mkNode(kind::APPLY_UF, d_divByZero, num);
+        Node divTotalNumDen = nm->mkNode(kind::DIVISION_TOTAL, num, den);
+        node = nm->mkNode(kind::ITE, den_eq_0, divByZeroNum, divTotalNumDen);
+        break;
       }
-    }
-    TNode num = n[0], den = n[1];
-    Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
-    Node modZeroNum = nm->mkNode(kind::APPLY_UF, d_modZero, num);
-    Node modTotalNumDen = nm->mkNode(kind::INTS_MODULUS_TOTAL, num, den);
-    node = nm->mkNode(kind::ITE, den_eq_0, modZeroNum, modTotalNumDen);
-    break;
-  }
-
-  case kind::ABS: {
-    Node out = nm->mkNode(kind::ITE, nm->mkNode(kind::LT, node[0], nm->mkConst(Rational(0))), nm->mkNode(kind::UMINUS, node[0]), node[0]);
-    cache[n] = out;
-    return out;
-  }
-
-  case kind::APPLY: {
-    // application of a user-defined symbol
-    TNode func = n.getOperator();
-    SmtEngine::DefinedFunctionMap::const_iterator i =
-      d_smt.d_definedFunctions->find(func);
-    DefinedFunction def = (*i).second;
-    vector<Node> formals = def.getFormals();
-
-    if(Debug.isOn("expand")) {
-      Debug("expand") << "found: " << n << endl;
-      Debug("expand") << " func: " << func << endl;
-      string name = func.getAttribute(expr::VarNameAttr());
-      Debug("expand") << "     : \"" << name << "\"" << endl;
-    }
-    if(i == d_smt.d_definedFunctions->end()) {
-      throw TypeCheckingException(n.toExpr(), string("Undefined function: `") + func.toString() + "'");
-    }
-    if(Debug.isOn("expand")) {
-      Debug("expand") << " defn: " << def.getFunction() << endl
-                      << "       [";
-      if(formals.size() > 0) {
-        copy( formals.begin(), formals.end() - 1,
-              ostream_iterator<Node>(Debug("expand"), ", ") );
-        Debug("expand") << formals.back();
+
+      case kind::INTS_DIVISION: {
+        // partial function: integer div
+        if(d_intDivByZero.isNull()) {
+          d_intDivByZero = nm->mkSkolem("intDivByZero", nm->mkFunctionType(nm->integerType(), nm->integerType()),
+                                        "partial integer division", NodeManager::SKOLEM_EXACT_NAME);
+          if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
+            d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
+            d_smt.d_logic.enableTheory(THEORY_UF);
+            d_smt.d_logic.lock();
+          }
+        }
+        TNode num = n[0], den = n[1];
+        Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
+        Node intDivByZeroNum = nm->mkNode(kind::APPLY_UF, d_intDivByZero, num);
+        Node intDivTotalNumDen = nm->mkNode(kind::INTS_DIVISION_TOTAL, num, den);
+        node = nm->mkNode(kind::ITE, den_eq_0, intDivByZeroNum, intDivTotalNumDen);
+        break;
+      }
+
+      case kind::INTS_MODULUS: {
+        // partial function: mod
+        if(d_modZero.isNull()) {
+          d_modZero = nm->mkSkolem("modZero", nm->mkFunctionType(nm->integerType(), nm->integerType()),
+                                   "partial modulus", NodeManager::SKOLEM_EXACT_NAME);
+          if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
+            d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
+            d_smt.d_logic.enableTheory(THEORY_UF);
+            d_smt.d_logic.lock();
+          }
+        }
+        TNode num = n[0], den = n[1];
+        Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
+        Node modZeroNum = nm->mkNode(kind::APPLY_UF, d_modZero, num);
+        Node modTotalNumDen = nm->mkNode(kind::INTS_MODULUS_TOTAL, num, den);
+        node = nm->mkNode(kind::ITE, den_eq_0, modZeroNum, modTotalNumDen);
+        break;
       }
-      Debug("expand") << "]" << endl
-                      << "       " << def.getFunction().getType() << endl
-                      << "       " << def.getFormula() << endl;
-    }
 
-    TNode fm = def.getFormula();
-    Node instance = fm.substitute(formals.begin(), formals.end(),
-                                  n.begin(), n.end());
-    Debug("expand") << "made : " << instance << endl;
+      case kind::ABS: {
+        Node out = nm->mkNode(kind::ITE, nm->mkNode(kind::LT, node[0], nm->mkConst(Rational(0))), nm->mkNode(kind::UMINUS, node[0]), node[0]);
+        cache[n] = out;
+        result.push(out);
+        continue;
+      }
 
-    Node expanded = expandDefinitions(instance, cache);
-    cache[n] = (n == expanded ? Node::null() : expanded);
-    return expanded;
-  }
+      case kind::APPLY: {
+        // application of a user-defined symbol
+        TNode func = n.getOperator();
+        SmtEngine::DefinedFunctionMap::const_iterator i =
+          d_smt.d_definedFunctions->find(func);
+        DefinedFunction def = (*i).second;
+        vector<Node> formals = def.getFormals();
+
+        if(Debug.isOn("expand")) {
+          Debug("expand") << "found: " << n << endl;
+          Debug("expand") << " func: " << func << endl;
+          string name = func.getAttribute(expr::VarNameAttr());
+          Debug("expand") << "     : \"" << name << "\"" << endl;
+        }
+        if(i == d_smt.d_definedFunctions->end()) {
+          throw TypeCheckingException(n.toExpr(), string("Undefined function: `") + func.toString() + "'");
+        }
+        if(Debug.isOn("expand")) {
+          Debug("expand") << " defn: " << def.getFunction() << endl
+                          << "       [";
+          if(formals.size() > 0) {
+            copy( formals.begin(), formals.end() - 1,
+                  ostream_iterator<Node>(Debug("expand"), ", ") );
+            Debug("expand") << formals.back();
+          }
+          Debug("expand") << "]" << endl
+                          << "       " << def.getFunction().getType() << endl
+                          << "       " << def.getFormula() << endl;
+        }
 
-  default:
-    // unknown kind for expansion, just iterate over the children
-    node = n;
-  }
+        TNode fm = def.getFormula();
+        Node instance = fm.substitute(formals.begin(), formals.end(),
+                                      n.begin(), n.end());
+        Debug("expand") << "made : " << instance << endl;
 
-  // there should be children here, otherwise we short-circuited a return, above
-  Assert(node.getNumChildren() > 0);
+        Node expanded = expandDefinitions(instance, cache);
+        cache[n] = (n == expanded ? Node::null() : expanded);
+        result.push(expanded);
+        continue;
+      }
 
-  // the partial functions can fall through, in which case we still
-  // consider their children
-  Debug("expand") << "cons : " << node << endl;
-  NodeBuilder<> nb(node.getKind());
-  if(node.getMetaKind() == kind::metakind::PARAMETERIZED) {
-    Debug("expand") << "op   : " << node.getOperator() << endl;
-    nb << node.getOperator();
-  }
-  for(Node::iterator i = node.begin(),
-        iend = node.end();
-      i != iend;
-      ++i) {
-    Node expanded = expandDefinitions(*i, cache);
-    Debug("expand") << "exchld: " << expanded << endl;
-    nb << expanded;
-  }
-  node = nb;
-  cache[n] = n == node ? Node::null() : node;
-  return node;
+      default:
+        // unknown kind for expansion, just iterate over the children
+        node = n;
+      }
+
+      // there should be children here, otherwise we short-circuited a result-push/continue, above
+      Assert(node.getNumChildren() > 0);
+
+      // the partial functions can fall through, in which case we still
+      // consider their children
+      worklist.push(make_triple(Node(n), node, true));
+
+      for(size_t i = 0; i < node.getNumChildren(); ++i) {
+        worklist.push(make_triple(node[i], node[i], false));
+      }
+
+    } else {
+
+      Debug("expand") << "cons : " << node << endl;
+      //cout << "cons : " << node << endl;
+      NodeBuilder<> nb(node.getKind());
+      if(node.getMetaKind() == kind::metakind::PARAMETERIZED) {
+        Debug("expand") << "op   : " << node.getOperator() << endl;
+        //cout << "op   : " << node.getOperator() << endl;
+        nb << node.getOperator();
+      }
+      for(size_t i = 0; i < node.getNumChildren(); ++i) {
+        Assert(!result.empty());
+        Node expanded = result.top();
+        result.pop();
+        //cout << "exchld : " << expanded << endl;
+        Debug("expand") << "exchld : " << expanded << endl;
+        nb << expanded;
+      }
+      node = nb;
+      cache[n] = n == node ? Node::null() : node;
+      result.push(node);
+    }
+  } while(!worklist.empty());
+
+  AlwaysAssert(result.size() == 1);
+
+  return result.top();
 }