Node SmtEnginePrivate::expandDefinitions(TNode n, hash_map<Node, Node, NodeHashFunction>& cache)
throw(TypeCheckingException, LogicException) {
- Kind k = n.getKind();
-
- if(k != kind::APPLY && n.getNumChildren() == 0) {
- SmtEngine::DefinedFunctionMap::const_iterator i = d_smt.d_definedFunctions->find(n);
- if(i != d_smt.d_definedFunctions->end()) {
- // replacement must be closed
- if((*i).second.getFormals().size() > 0) {
- return d_smt.d_nodeManager->mkNode(kind::LAMBDA, d_smt.d_nodeManager->mkNode(kind::BOUND_VAR_LIST, (*i).second.getFormals()), (*i).second.getFormula());
+ stack< triple<Node, Node, bool> > worklist;
+ stack<Node> result;
+ worklist.push(make_triple(Node(n), Node(n), false));
+
+ do {
+ n = worklist.top().first;
+ Node node = worklist.top().second;
+ bool childrenPushed = worklist.top().third;
+ worklist.pop();
+
+ if(!childrenPushed) {
+ Kind k = n.getKind();
+
+ if(k != kind::APPLY && n.getNumChildren() == 0) {
+ SmtEngine::DefinedFunctionMap::const_iterator i = d_smt.d_definedFunctions->find(n);
+ if(i != d_smt.d_definedFunctions->end()) {
+ // replacement must be closed
+ if((*i).second.getFormals().size() > 0) {
+ result.push(d_smt.d_nodeManager->mkNode(kind::LAMBDA, d_smt.d_nodeManager->mkNode(kind::BOUND_VAR_LIST, (*i).second.getFormals()), (*i).second.getFormula()));
+ continue;
+ }
+ // don't bother putting in the cache
+ result.push((*i).second.getFormula());
+ continue;
+ }
+ // don't bother putting in the cache
+ result.push(n);
+ continue;
}
- // don't bother putting in the cache
- return (*i).second.getFormula();
- }
- // don't bother putting in the cache
- return n;
- }
- // maybe it's in the cache
- hash_map<Node, Node, NodeHashFunction>::iterator cacheHit = cache.find(n);
- if(cacheHit != cache.end()) {
- TNode ret = (*cacheHit).second;
- return ret.isNull() ? n : ret;
- }
+ // maybe it's in the cache
+ hash_map<Node, Node, NodeHashFunction>::iterator cacheHit = cache.find(n);
+ if(cacheHit != cache.end()) {
+ TNode ret = (*cacheHit).second;
+ result.push(ret.isNull() ? n : ret);
+ continue;
+ }
- // otherwise expand it
+ // otherwise expand it
- Node node = n;
- NodeManager* nm = d_smt.d_nodeManager;
- // FIXME: this theory specific code should be factored out of the SmtEngine, somehow
- switch(k) {
- case kind::BITVECTOR_SDIV:
- case kind::BITVECTOR_SREM:
- case kind::BITVECTOR_SMOD: {
- node = bv::TheoryBVRewriter::eliminateBVSDiv(node);
- break;
- }
-
- case kind::BITVECTOR_UDIV:
- case kind::BITVECTOR_UREM: {
- node = expandBVDivByZero(node);
- break;
- }
-
- case kind::DIVISION: {
- // partial function: division
- if(d_divByZero.isNull()) {
- d_divByZero = nm->mkSkolem("divByZero", nm->mkFunctionType(nm->realType(), nm->realType()),
- "partial real division", NodeManager::SKOLEM_EXACT_NAME);
- if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
- d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
- d_smt.d_logic.enableTheory(THEORY_UF);
- d_smt.d_logic.lock();
+ NodeManager* nm = d_smt.d_nodeManager;
+ // FIXME: this theory specific code should be factored out of the SmtEngine, somehow
+ switch(k) {
+ case kind::BITVECTOR_SDIV:
+ case kind::BITVECTOR_SREM:
+ case kind::BITVECTOR_SMOD: {
+ node = bv::TheoryBVRewriter::eliminateBVSDiv(node);
+ break;
}
- }
- TNode num = n[0], den = n[1];
- Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
- Node divByZeroNum = nm->mkNode(kind::APPLY_UF, d_divByZero, num);
- Node divTotalNumDen = nm->mkNode(kind::DIVISION_TOTAL, num, den);
- node = nm->mkNode(kind::ITE, den_eq_0, divByZeroNum, divTotalNumDen);
- break;
- }
-
- case kind::INTS_DIVISION: {
- // partial function: integer div
- if(d_intDivByZero.isNull()) {
- d_intDivByZero = nm->mkSkolem("intDivByZero", nm->mkFunctionType(nm->integerType(), nm->integerType()),
- "partial integer division", NodeManager::SKOLEM_EXACT_NAME);
- if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
- d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
- d_smt.d_logic.enableTheory(THEORY_UF);
- d_smt.d_logic.lock();
+
+ case kind::BITVECTOR_UDIV:
+ case kind::BITVECTOR_UREM: {
+ node = expandBVDivByZero(node);
+ break;
}
- }
- TNode num = n[0], den = n[1];
- Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
- Node intDivByZeroNum = nm->mkNode(kind::APPLY_UF, d_intDivByZero, num);
- Node intDivTotalNumDen = nm->mkNode(kind::INTS_DIVISION_TOTAL, num, den);
- node = nm->mkNode(kind::ITE, den_eq_0, intDivByZeroNum, intDivTotalNumDen);
- break;
- }
-
- case kind::INTS_MODULUS: {
- // partial function: mod
- if(d_modZero.isNull()) {
- d_modZero = nm->mkSkolem("modZero", nm->mkFunctionType(nm->integerType(), nm->integerType()),
- "partial modulus", NodeManager::SKOLEM_EXACT_NAME);
- if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
- d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
- d_smt.d_logic.enableTheory(THEORY_UF);
- d_smt.d_logic.lock();
+
+ case kind::DIVISION: {
+ // partial function: division
+ if(d_divByZero.isNull()) {
+ d_divByZero = nm->mkSkolem("divByZero", nm->mkFunctionType(nm->realType(), nm->realType()),
+ "partial real division", NodeManager::SKOLEM_EXACT_NAME);
+ if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
+ d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
+ d_smt.d_logic.enableTheory(THEORY_UF);
+ d_smt.d_logic.lock();
+ }
+ }
+ TNode num = n[0], den = n[1];
+ Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
+ Node divByZeroNum = nm->mkNode(kind::APPLY_UF, d_divByZero, num);
+ Node divTotalNumDen = nm->mkNode(kind::DIVISION_TOTAL, num, den);
+ node = nm->mkNode(kind::ITE, den_eq_0, divByZeroNum, divTotalNumDen);
+ break;
}
- }
- TNode num = n[0], den = n[1];
- Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
- Node modZeroNum = nm->mkNode(kind::APPLY_UF, d_modZero, num);
- Node modTotalNumDen = nm->mkNode(kind::INTS_MODULUS_TOTAL, num, den);
- node = nm->mkNode(kind::ITE, den_eq_0, modZeroNum, modTotalNumDen);
- break;
- }
-
- case kind::ABS: {
- Node out = nm->mkNode(kind::ITE, nm->mkNode(kind::LT, node[0], nm->mkConst(Rational(0))), nm->mkNode(kind::UMINUS, node[0]), node[0]);
- cache[n] = out;
- return out;
- }
-
- case kind::APPLY: {
- // application of a user-defined symbol
- TNode func = n.getOperator();
- SmtEngine::DefinedFunctionMap::const_iterator i =
- d_smt.d_definedFunctions->find(func);
- DefinedFunction def = (*i).second;
- vector<Node> formals = def.getFormals();
-
- if(Debug.isOn("expand")) {
- Debug("expand") << "found: " << n << endl;
- Debug("expand") << " func: " << func << endl;
- string name = func.getAttribute(expr::VarNameAttr());
- Debug("expand") << " : \"" << name << "\"" << endl;
- }
- if(i == d_smt.d_definedFunctions->end()) {
- throw TypeCheckingException(n.toExpr(), string("Undefined function: `") + func.toString() + "'");
- }
- if(Debug.isOn("expand")) {
- Debug("expand") << " defn: " << def.getFunction() << endl
- << " [";
- if(formals.size() > 0) {
- copy( formals.begin(), formals.end() - 1,
- ostream_iterator<Node>(Debug("expand"), ", ") );
- Debug("expand") << formals.back();
+
+ case kind::INTS_DIVISION: {
+ // partial function: integer div
+ if(d_intDivByZero.isNull()) {
+ d_intDivByZero = nm->mkSkolem("intDivByZero", nm->mkFunctionType(nm->integerType(), nm->integerType()),
+ "partial integer division", NodeManager::SKOLEM_EXACT_NAME);
+ if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
+ d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
+ d_smt.d_logic.enableTheory(THEORY_UF);
+ d_smt.d_logic.lock();
+ }
+ }
+ TNode num = n[0], den = n[1];
+ Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
+ Node intDivByZeroNum = nm->mkNode(kind::APPLY_UF, d_intDivByZero, num);
+ Node intDivTotalNumDen = nm->mkNode(kind::INTS_DIVISION_TOTAL, num, den);
+ node = nm->mkNode(kind::ITE, den_eq_0, intDivByZeroNum, intDivTotalNumDen);
+ break;
+ }
+
+ case kind::INTS_MODULUS: {
+ // partial function: mod
+ if(d_modZero.isNull()) {
+ d_modZero = nm->mkSkolem("modZero", nm->mkFunctionType(nm->integerType(), nm->integerType()),
+ "partial modulus", NodeManager::SKOLEM_EXACT_NAME);
+ if(!d_smt.d_logic.isTheoryEnabled(THEORY_UF)) {
+ d_smt.d_logic = d_smt.d_logic.getUnlockedCopy();
+ d_smt.d_logic.enableTheory(THEORY_UF);
+ d_smt.d_logic.lock();
+ }
+ }
+ TNode num = n[0], den = n[1];
+ Node den_eq_0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0)));
+ Node modZeroNum = nm->mkNode(kind::APPLY_UF, d_modZero, num);
+ Node modTotalNumDen = nm->mkNode(kind::INTS_MODULUS_TOTAL, num, den);
+ node = nm->mkNode(kind::ITE, den_eq_0, modZeroNum, modTotalNumDen);
+ break;
}
- Debug("expand") << "]" << endl
- << " " << def.getFunction().getType() << endl
- << " " << def.getFormula() << endl;
- }
- TNode fm = def.getFormula();
- Node instance = fm.substitute(formals.begin(), formals.end(),
- n.begin(), n.end());
- Debug("expand") << "made : " << instance << endl;
+ case kind::ABS: {
+ Node out = nm->mkNode(kind::ITE, nm->mkNode(kind::LT, node[0], nm->mkConst(Rational(0))), nm->mkNode(kind::UMINUS, node[0]), node[0]);
+ cache[n] = out;
+ result.push(out);
+ continue;
+ }
- Node expanded = expandDefinitions(instance, cache);
- cache[n] = (n == expanded ? Node::null() : expanded);
- return expanded;
- }
+ case kind::APPLY: {
+ // application of a user-defined symbol
+ TNode func = n.getOperator();
+ SmtEngine::DefinedFunctionMap::const_iterator i =
+ d_smt.d_definedFunctions->find(func);
+ DefinedFunction def = (*i).second;
+ vector<Node> formals = def.getFormals();
+
+ if(Debug.isOn("expand")) {
+ Debug("expand") << "found: " << n << endl;
+ Debug("expand") << " func: " << func << endl;
+ string name = func.getAttribute(expr::VarNameAttr());
+ Debug("expand") << " : \"" << name << "\"" << endl;
+ }
+ if(i == d_smt.d_definedFunctions->end()) {
+ throw TypeCheckingException(n.toExpr(), string("Undefined function: `") + func.toString() + "'");
+ }
+ if(Debug.isOn("expand")) {
+ Debug("expand") << " defn: " << def.getFunction() << endl
+ << " [";
+ if(formals.size() > 0) {
+ copy( formals.begin(), formals.end() - 1,
+ ostream_iterator<Node>(Debug("expand"), ", ") );
+ Debug("expand") << formals.back();
+ }
+ Debug("expand") << "]" << endl
+ << " " << def.getFunction().getType() << endl
+ << " " << def.getFormula() << endl;
+ }
- default:
- // unknown kind for expansion, just iterate over the children
- node = n;
- }
+ TNode fm = def.getFormula();
+ Node instance = fm.substitute(formals.begin(), formals.end(),
+ n.begin(), n.end());
+ Debug("expand") << "made : " << instance << endl;
- // there should be children here, otherwise we short-circuited a return, above
- Assert(node.getNumChildren() > 0);
+ Node expanded = expandDefinitions(instance, cache);
+ cache[n] = (n == expanded ? Node::null() : expanded);
+ result.push(expanded);
+ continue;
+ }
- // the partial functions can fall through, in which case we still
- // consider their children
- Debug("expand") << "cons : " << node << endl;
- NodeBuilder<> nb(node.getKind());
- if(node.getMetaKind() == kind::metakind::PARAMETERIZED) {
- Debug("expand") << "op : " << node.getOperator() << endl;
- nb << node.getOperator();
- }
- for(Node::iterator i = node.begin(),
- iend = node.end();
- i != iend;
- ++i) {
- Node expanded = expandDefinitions(*i, cache);
- Debug("expand") << "exchld: " << expanded << endl;
- nb << expanded;
- }
- node = nb;
- cache[n] = n == node ? Node::null() : node;
- return node;
+ default:
+ // unknown kind for expansion, just iterate over the children
+ node = n;
+ }
+
+ // there should be children here, otherwise we short-circuited a result-push/continue, above
+ Assert(node.getNumChildren() > 0);
+
+ // the partial functions can fall through, in which case we still
+ // consider their children
+ worklist.push(make_triple(Node(n), node, true));
+
+ for(size_t i = 0; i < node.getNumChildren(); ++i) {
+ worklist.push(make_triple(node[i], node[i], false));
+ }
+
+ } else {
+
+ Debug("expand") << "cons : " << node << endl;
+ //cout << "cons : " << node << endl;
+ NodeBuilder<> nb(node.getKind());
+ if(node.getMetaKind() == kind::metakind::PARAMETERIZED) {
+ Debug("expand") << "op : " << node.getOperator() << endl;
+ //cout << "op : " << node.getOperator() << endl;
+ nb << node.getOperator();
+ }
+ for(size_t i = 0; i < node.getNumChildren(); ++i) {
+ Assert(!result.empty());
+ Node expanded = result.top();
+ result.pop();
+ //cout << "exchld : " << expanded << endl;
+ Debug("expand") << "exchld : " << expanded << endl;
+ nb << expanded;
+ }
+ node = nb;
+ cache[n] = n == node ? Node::null() : node;
+ result.push(node);
+ }
+ } while(!worklist.empty());
+
+ AlwaysAssert(result.size() == 1);
+
+ return result.top();
}