Modifying getType to use a non-recursive algorithm (Fixes: #228)
authorChristopher L. Conway <christopherleeconway@gmail.com>
Wed, 27 Oct 2010 20:10:12 +0000 (20:10 +0000)
committerChristopher L. Conway <christopherleeconway@gmail.com>
Wed, 27 Oct 2010 20:10:12 +0000 (20:10 +0000)
src/expr/node_manager.cpp
src/expr/node_manager.h

index edb3a1592adaab8c7ea883b6bab242ec7df5eac4..c1456050df40e2dda2e527f07cc0e4402570e2e7 100644 (file)
@@ -30,8 +30,9 @@
 #include "util/Assert.h"
 #include "util/tls.h"
 
-#include <ext/hash_set>
 #include <algorithm>
+#include <list>
+#include <ext/hash_set>
 
 using namespace std;
 using namespace CVC4::expr;
@@ -209,6 +210,213 @@ void NodeManager::reclaimZombies() {
   }
 }/* NodeManager::reclaimZombies() */
 
+TypeNode NodeManager::computeType(TNode n, bool check)
+  throw (TypeCheckingExceptionPrivate, AssertionException) {  
+  TypeNode typeNode;
+
+  // Infer the type
+  switch(n.getKind()) {
+  case kind::SORT_TYPE:
+    typeNode = kindType();
+    break;
+  case kind::APPLY:
+    typeNode = CVC4::theory::builtin::ApplyTypeRule::computeType(this, n, check);
+    break;
+  case kind::EQUAL:
+    typeNode = CVC4::theory::builtin::EqualityTypeRule::computeType(this, n, check);
+    break;
+  case kind::DISTINCT:
+    typeNode = CVC4::theory::builtin::DistinctTypeRule::computeType(this, n, check);
+    break;
+  case kind::TUPLE:
+    typeNode = CVC4::theory::builtin::TupleTypeRule::computeType(this, n, check);
+    break;
+  case kind::CONST_BOOLEAN:
+    typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
+    break;
+  case kind::NOT:
+    typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
+    break;
+  case kind::AND:
+    typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
+    break;
+  case kind::IFF:
+    typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
+    break;
+  case kind::IMPLIES:
+    typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
+    break;
+  case kind::OR:
+    typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
+    break;
+  case kind::XOR:
+    typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
+    break;
+  case kind::ITE:
+    typeNode = CVC4::theory::boolean::IteTypeRule::computeType(this, n, check);
+    break;
+  case kind::APPLY_UF:
+    typeNode = CVC4::theory::uf::UfTypeRule::computeType(this, n, check);
+    break;
+  case kind::PLUS:
+    typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
+    break;
+  case kind::MULT:
+    typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
+    break;
+  case kind::MINUS:
+    typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
+    break;
+  case kind::UMINUS:
+    typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
+    break;
+  case kind::DIVISION:
+    typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
+    break;
+  case kind::CONST_RATIONAL:
+    typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n, check);
+    break;
+  case kind::CONST_INTEGER:
+    typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n, check);
+    break;
+  case kind::LT:
+    typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::LEQ:
+    typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::GT:
+    typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::GEQ:
+    typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::SELECT:
+    typeNode = CVC4::theory::arrays::ArraySelectTypeRule::computeType(this, n, check);
+    break;
+  case kind::STORE:
+    typeNode = CVC4::theory::arrays::ArrayStoreTypeRule::computeType(this, n, check);
+    break;
+  case kind::CONST_BITVECTOR:
+    typeNode = CVC4::theory::bv::BitVectorConstantTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_AND:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_OR:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_XOR:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_NOT:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_NAND:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_NOR:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_XNOR:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_COMP:
+    typeNode = CVC4::theory::bv::BitVectorCompRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_MULT:
+    typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_PLUS:
+    typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_SUB:
+    typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_NEG:
+    typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_UDIV:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_UREM:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_SDIV:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_SREM:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_SMOD:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_SHL:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_LSHR:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_ASHR:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_ROTATE_LEFT:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_ROTATE_RIGHT:
+    typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_ULT:
+    typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_ULE:
+    typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_UGT:
+    typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_UGE:
+    typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_SLT:
+    typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_SLE:
+    typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_SGT:
+    typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_SGE:
+    typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_EXTRACT:
+    typeNode = CVC4::theory::bv::BitVectorExtractTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_CONCAT:
+    typeNode = CVC4::theory::bv::BitVectorConcatRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_REPEAT:
+    typeNode = CVC4::theory::bv::BitVectorRepeatTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_ZERO_EXTEND:
+    typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n, check);
+    break;
+  case kind::BITVECTOR_SIGN_EXTEND:
+    typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n, check);
+    break;
+  default:
+    Debug("getType") << "FAILURE" << std::endl;
+    Unhandled(n.getKind());
+  }
+
+  setAttribute(n, TypeAttr(), typeNode);
+  setAttribute(n, TypeCheckedAttr(), 
+               check || getAttribute(n, TypeCheckedAttr()));
+  
+  return typeNode;
+}
+
 TypeNode NodeManager::getType(TNode n, bool check)
   throw (TypeCheckingExceptionPrivate, AssertionException) {
   TypeNode typeNode;
@@ -217,209 +425,49 @@ TypeNode NodeManager::getType(TNode n, bool check)
 
   Debug("getType") << "getting type for " << n << std::endl;
   if(!hasType || needsCheck) {
-    TypeNode oldType = typeNode;
-
-    // Infer the type
-    switch(n.getKind()) {
-    case kind::SORT_TYPE:
-      typeNode = kindType();
-      break;
-    case kind::APPLY:
-      typeNode = CVC4::theory::builtin::ApplyTypeRule::computeType(this, n, check);
-      break;
-    case kind::EQUAL:
-      typeNode = CVC4::theory::builtin::EqualityTypeRule::computeType(this, n, check);
-      break;
-    case kind::DISTINCT:
-      typeNode = CVC4::theory::builtin::DistinctTypeRule::computeType(this, n, check);
-      break;
-    case kind::TUPLE:
-      typeNode = CVC4::theory::builtin::TupleTypeRule::computeType(this, n, check);
-      break;
-    case kind::CONST_BOOLEAN:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
-      break;
-    case kind::NOT:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
-      break;
-    case kind::AND:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
-      break;
-    case kind::IFF:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
-      break;
-    case kind::IMPLIES:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
-      break;
-    case kind::OR:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
-      break;
-    case kind::XOR:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
-      break;
-    case kind::ITE:
-      typeNode = CVC4::theory::boolean::IteTypeRule::computeType(this, n, check);
-      break;
-    case kind::APPLY_UF:
-      typeNode = CVC4::theory::uf::UfTypeRule::computeType(this, n, check);
-      break;
-    case kind::PLUS:
-      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
-      break;
-    case kind::MULT:
-      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
-      break;
-    case kind::MINUS:
-      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
-      break;
-    case kind::UMINUS:
-      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
-      break;
-    case kind::DIVISION:
-      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
-      break;
-    case kind::CONST_RATIONAL:
-      typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n, check);
-      break;
-    case kind::CONST_INTEGER:
-      typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n, check);
-      break;
-    case kind::LT:
-      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::LEQ:
-      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::GT:
-      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::GEQ:
-      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::SELECT:
-      typeNode = CVC4::theory::arrays::ArraySelectTypeRule::computeType(this, n, check);
-      break;
-    case kind::STORE:
-      typeNode = CVC4::theory::arrays::ArrayStoreTypeRule::computeType(this, n, check);
-      break;
-    case kind::CONST_BITVECTOR:
-      typeNode = CVC4::theory::bv::BitVectorConstantTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_AND:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_OR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_XOR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_NOT:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_NAND:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_NOR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_XNOR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_COMP:
-      typeNode = CVC4::theory::bv::BitVectorCompRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_MULT:
-      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_PLUS:
-      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_SUB:
-      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_NEG:
-      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_UDIV:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_UREM:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_SDIV:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_SREM:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_SMOD:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_SHL:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_LSHR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_ASHR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_ROTATE_LEFT:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_ROTATE_RIGHT:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_ULT:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_ULE:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_UGT:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_UGE:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_SLT:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_SLE:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_SGT:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_SGE:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_EXTRACT:
-      typeNode = CVC4::theory::bv::BitVectorExtractTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_CONCAT:
-      typeNode = CVC4::theory::bv::BitVectorConcatRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_REPEAT:
-      typeNode = CVC4::theory::bv::BitVectorRepeatTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_ZERO_EXTEND:
-      typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n, check);
-      break;
-    case kind::BITVECTOR_SIGN_EXTEND:
-      typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n, check);
-      break;
-    default:
-      Debug("getType") << "FAILURE" << std::endl;
-      Unhandled(n.getKind());
+    // TypeNode oldType = typeNode;
+
+    list<TNode> worklist;
+    worklist.push_back(n);
+
+    /* Iterate and compute the children bottom up.  This iteration is
+       very inefficient: it would be better to top-sort the Nodes so
+       that the leaves are always computed first. */
+    while( !worklist.empty() ) {
+      TNode m = worklist.front();
+      worklist.pop_front();
+
+      bool readyToCompute = true;
+      TNode::iterator it = m.begin();
+      TNode::iterator end = m.end();
+
+      for( ; it != end; ++it ) {
+        if( !hasAttribute(*it, TypeAttr()) 
+            || (check && !getAttribute(*it, TypeCheckedAttr())) ) {
+          readyToCompute = false;
+          worklist.push_back(*it);
+        }
+      }
+
+      if( readyToCompute ) {
+        /* All the children have types, time to compute */
+        computeType(m,check);
+      } else {
+        /* Wait until the children have been computed. */
+        worklist.push_back(m);
+      }
     }
 
+    /* Retrieve the type computed in the loop */
+    hasType = getAttribute(n, TypeAttr(), typeNode);
+
+    /* Type should be there and the check should have happened if we
+       asked for it. */
+    Assert( hasType );
+    Assert( !check || hasAttribute(n, TypeCheckedAttr()) );
+
     // DebugAssert( !hasType || oldType == typeNode,
     //              "Type re-computation yielded a different type" );
-
-    setAttribute(n, TypeAttr(), typeNode);
-    setAttribute(n, TypeCheckedAttr(), check);
   }
   Debug("getType") << "type of " << n << " is " << typeNode << std::endl;
   return typeNode;
index 6c7bf500bbcc7ee810637bf7524b6124bb6f62a7..c262a4847b4347de9309efb21ad5b5c1487753dc 100644 (file)
@@ -242,6 +242,9 @@ class NodeManager {
   // undefined private copy constructor (disallow copy)
   NodeManager(const NodeManager&) CVC4_UNDEFINED;
 
+  TypeNode computeType(TNode n, bool check = false)
+    throw (TypeCheckingExceptionPrivate, AssertionException);
+
 public:
 
   explicit NodeManager(context::Context* ctxt, bool earlyTypeChecking = true);