Adding optional 'check' parameter to getType() methods
authorChristopher L. Conway <christopherleeconway@gmail.com>
Tue, 27 Jul 2010 20:54:33 +0000 (20:54 +0000)
committerChristopher L. Conway <christopherleeconway@gmail.com>
Tue, 27 Jul 2010 20:54:33 +0000 (20:54 +0000)
14 files changed:
src/expr/expr_manager_template.cpp
src/expr/expr_manager_template.h
src/expr/expr_template.cpp
src/expr/expr_template.h
src/expr/node.h
src/expr/node_manager.cpp
src/expr/node_manager.h
src/theory/arith/theory_arith_type_rules.h
src/theory/arrays/theory_arrays_type_rules.h
src/theory/booleans/theory_bool_type_rules.h
src/theory/builtin/theory_builtin_type_rules.h
src/theory/bv/theory_bv_type_rules.h
src/theory/uf/theory_uf_type_rules.h
test/unit/expr/expr_public.h

index f28729b94d8167869b5ad63db527e1771f84e92e..5fcbad3a2d1df9f653c3c7545ba46497bd3b71c5 100644 (file)
@@ -226,11 +226,36 @@ SortType ExprManager::mkSort(const std::string& name) const {
   return Type(d_nodeManager, new TypeNode(d_nodeManager->mkSort(name)));
 }
 
-Type ExprManager::getType(const Expr& e) throw (TypeCheckingException) {
+/**
+ * Get the type for the given Expr and optionally do type checking.
+ *
+ * Initial type computation will be near-constant time if
+ * type checking is not requested. Results are memoized, so that
+ * subsequent calls to getType() without type checking will be
+ * constant time.
+ *
+ * Initial type checking is linear in the size of the expression.
+ * Again, the results are memoized, so that subsequent calls to
+ * getType(), with or without type checking, will be constant
+ * time.
+ *
+ * NOTE: A TypeCheckingException can be thrown even when type
+ * checking is not requested. getType() will always return a
+ * valid and correct type and, thus, an exception will be thrown
+ * when no valid or correct type can be computed (e.g., if the
+ * arguments to a bit-vector operation aren't bit-vectors). When
+ * type checking is not requested, getType() will do the minimum
+ * amount of checking required to return a valid result.
+ *
+ * @param n the Expr for which we want a type
+ * @param check whether we should check the type as we compute it 
+ * (default: false)
+ */
+Type ExprManager::getType(const Expr& e, bool check) throw (TypeCheckingException) {
   NodeManagerScope nms(d_nodeManager);
   Type t;
   try {
-    t = Type(d_nodeManager, new TypeNode(d_nodeManager->getType(e.getNode())));
+    t = Type(d_nodeManager, new TypeNode(d_nodeManager->getType(e.getNode(), check)));
   } catch (const TypeCheckingExceptionPrivate& e) {
     throw TypeCheckingException(Expr(this, new Node(e.getNode())), e.getMessage());
   }
index 450d7fc4d8640e95d724169c33cf22a2316e684b..3b5b0e0f48b4f8bc06f6d92a58eecbcfba6a4d23 100644 (file)
@@ -221,7 +221,8 @@ public:
   SortType mkSort(const std::string& name) const;
 
   /** Get the type of an expression */
-  Type getType(const Expr& e) throw (TypeCheckingException);
+  Type getType(const Expr& e, bool check = false) 
+    throw (TypeCheckingException);
 
   // variables are special, because duplicates are permitted
   Expr mkVar(const std::string& name, const Type& type);
index fc67bcba1530561c43c896184eca42a54f54472c..48acd25889ccf0d42d5b32a843c7a42bb96c994e 100644 (file)
@@ -181,10 +181,10 @@ Expr Expr::getOperator() const {
   return Expr(d_exprManager, new Node(d_node->getOperator()));
 }
 
-Type Expr::getType() const throw (TypeCheckingException) {
+Type Expr::getType(bool check) const throw (TypeCheckingException) {
   ExprManagerScope ems(*this);
   Assert(d_node != NULL, "Unexpected NULL expression pointer!");
-  return d_exprManager->getType(*this);
+  return d_exprManager->getType(*this, check);
 }
 
 std::string Expr::toString() const {
index 517931477f61d35b3252d5fe8cbbc29de94b7305..becdd46e23ecdece751f5455b5fce70ee5ff2e11 100644 (file)
@@ -224,10 +224,31 @@ public:
    */
   Expr getOperator() const;
 
-  /** Returns the type of the expression, if it has been computed.
-   * Returns NULL if the type of the expression is not known.
+  /**
+   * Get the type for this Expr and optionally do type checking.
+   *
+   * Initial type computation will be near-constant time if
+   * type checking is not requested. Results are memoized, so that
+   * subsequent calls to getType() without type checking will be
+   * constant time.
+   *
+   * Initial type checking is linear in the size of the expression.
+   * Again, the results are memoized, so that subsequent calls to
+   * getType(), with or without type checking, will be constant
+   * time.
+   *
+   * NOTE: A TypeCheckingException can be thrown even when type
+   * checking is not requested. getType() will always return a
+   * valid and correct type and, thus, an exception will be thrown
+   * when no valid or correct type can be computed (e.g., if the
+   * arguments to a bit-vector operation aren't bit-vectors). When
+   * type checking is not requested, getType() will do the minimum
+   * amount of checking required to return a valid result.
+   *
+   * @param check whether we should check the type as we compute it 
+   * (default: false)
    */
-  Type getType() const throw (TypeCheckingException);
+  Type getType(bool check = false) const throw (TypeCheckingException);
 
   /**
    * Returns the string representation of the expression.
index 218b9a3ea3a932afc11a5fb633e424bf55621eb8..4b1a0e5bec7c4bee8160390f55d50e1cf06a6240 100644 (file)
@@ -334,10 +334,30 @@ public:
   inline bool hasOperator() const;
 
   /**
-   * Returns the type of this node.
-   * @return the type
+   * Get the type for the node and optionally do type checking.
+   *
+   * Initial type computation will be near-constant time if
+   * type checking is not requested. Results are memoized, so that
+   * subsequent calls to getType() without type checking will be
+   * constant time.
+   *
+   * Initial type checking is linear in the size of the expression.
+   * Again, the results are memoized, so that subsequent calls to
+   * getType(), with or without type checking, will be constant
+   * time.
+   *
+   * NOTE: A TypeCheckingException can be thrown even when type
+   * checking is not requested. getType() will always return a
+   * valid and correct type and, thus, an exception will be thrown
+   * when no valid or correct type can be computed (e.g., if the
+   * arguments to a bit-vector operation aren't bit-vectors). When
+   * type checking is not requested, getType() will do the minimum
+   * amount of checking required to return a valid result.
+   *
+   * @param check whether we should check the type as we compute it 
+   * (default: false)
    */
-  TypeNode getType() const
+  TypeNode getType(bool check = false) const
     throw (CVC4::TypeCheckingExceptionPrivate, CVC4::AssertionException);
 
   /**
@@ -893,7 +913,7 @@ inline bool NodeTemplate<ref_count>::hasOperator() const {
 }
 
 template <bool ref_count>
-TypeNode NodeTemplate<ref_count>::getType() const
+TypeNode NodeTemplate<ref_count>::getType(bool check) const
   throw (CVC4::TypeCheckingExceptionPrivate, CVC4::AssertionException) {
   Assert( NodeManager::currentNM() != NULL,
           "There is no current CVC4::NodeManager associated to this thread.\n"
@@ -903,7 +923,7 @@ TypeNode NodeTemplate<ref_count>::getType() const
     Assert( d_nv->d_rc > 0, "TNode pointing to an expired NodeValue" );
   }
 
-  return NodeManager::currentNM()->getType(*this);
+  return NodeManager::currentNM()->getType(*this, check);
 }
 
 #ifdef CVC4_DEBUG
index 2e45fe9d0cf901f307a7250a9c4b88f0d75cf341..fbfffe87d73518967b4a052879ff11f2d09d2db2 100644 (file)
@@ -27,6 +27,8 @@
 #include "theory/arrays/theory_arrays_type_rules.h"
 #include "theory/bv/theory_bv_type_rules.h"
 
+#include "util/Assert.h"
+
 #include <ext/hash_set>
 #include <algorithm>
 
@@ -182,202 +184,211 @@ void NodeManager::reclaimZombies() {
   }
 }/* NodeManager::reclaimZombies() */
 
-TypeNode NodeManager::getType(TNode n)
+TypeNode NodeManager::getType(TNode n, bool check)
   throw (TypeCheckingExceptionPrivate, AssertionException) {
   TypeNode typeNode;
   bool hasType = getAttribute(n, TypeAttr(), typeNode);
+  bool needsCheck = check && !getAttribute(n, TypeCheckedAttr());
+
   Debug("getType") << "getting type for " << n << std::endl;
-  if(!hasType) {
+  if(!hasType || needsCheck) {
+    TypeNode oldType = typeNode;
+
     // Infer the type
     switch(n.getKind()) {
     case kind::SORT_TYPE:
       typeNode = kindType();
       break;
     case kind::EQUAL:
-      typeNode = CVC4::theory::builtin::EqualityTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::builtin::EqualityTypeRule::computeType(this, n, check);
       break;
     case kind::DISTINCT:
-      typeNode = CVC4::theory::builtin::DistinctTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::builtin::DistinctTypeRule::computeType(this, n, check);
       break;
     case kind::CONST_BOOLEAN:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
       break;
     case kind::NOT:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
       break;
     case kind::AND:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
       break;
     case kind::IFF:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
       break;
     case kind::IMPLIES:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
       break;
     case kind::OR:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
       break;
     case kind::XOR:
-      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check);
       break;
     case kind::ITE:
-      typeNode = CVC4::theory::boolean::IteTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::boolean::IteTypeRule::computeType(this, n, check);
       break;
     case kind::APPLY_UF:
-      typeNode = CVC4::theory::uf::UfTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::uf::UfTypeRule::computeType(this, n, check);
       break;
     case kind::PLUS:
-      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
       break;
     case kind::MULT:
-      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
       break;
     case kind::MINUS:
-      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
       break;
     case kind::UMINUS:
-      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
       break;
     case kind::DIVISION:
-      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
       break;
     case kind::CONST_RATIONAL:
-      typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n, check);
       break;
     case kind::CONST_INTEGER:
-      typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n, check);
       break;
     case kind::LT:
-      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::LEQ:
-      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::GT:
-      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::GEQ:
-      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::SELECT:
-      typeNode = CVC4::theory::arrays::ArraySelectTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arrays::ArraySelectTypeRule::computeType(this, n, check);
       break;
     case kind::STORE:
-      typeNode = CVC4::theory::arrays::ArrayStoreTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::arrays::ArrayStoreTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_CONST:
-      typeNode = CVC4::theory::bv::BitVectorConstantTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorConstantTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_AND:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_OR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_XOR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_NOT:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_NAND:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_NOR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_XNOR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_COMP:
-      typeNode = CVC4::theory::bv::BitVectorCompRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorCompRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_MULT:
-      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_PLUS:
-      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_SUB:
-      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_NEG:
-      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_UDIV:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_UREM:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_SDIV:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_SREM:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_SMOD:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_SHL:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_LSHR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_ASHR:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_ROTATE_LEFT:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_ROTATE_RIGHT:
-      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_ULT:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_ULE:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_UGT:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_UGE:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_SLT:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_SLE:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_SGT:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_SGE:
-      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_EXTRACT:
-      typeNode = CVC4::theory::bv::BitVectorExtractTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorExtractTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_CONCAT:
-      typeNode = CVC4::theory::bv::BitVectorConcatRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorConcatRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_REPEAT:
-      typeNode = CVC4::theory::bv::BitVectorRepeatTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorRepeatTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_ZERO_EXTEND:
-      typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n, check);
       break;
     case kind::BITVECTOR_SIGN_EXTEND:
-      typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n);
+      typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n, check);
       break;
     default:
       Debug("getType") << "FAILURE" << std::endl;
       Unhandled(n.getKind());
     }
+
+    // DebugAssert( !hasType || oldType == typeNode,
+    //              "Type re-computation yielded a different type" );
+
     setAttribute(n, TypeAttr(), typeNode);
+    setAttribute(n, TypeCheckedAttr(), check);
   }
   Debug("getType") << "type of " << n << " is " << typeNode << std::endl;
   return typeNode;
index d01e22abd41f47a6089bd6e93a2e32191abf2ec1..7a53cabfc261f9381ff9de987390062f253c6186 100644 (file)
@@ -221,10 +221,12 @@ class NodeManager {
 
   // attribute tags
   struct TypeTag {};
+  struct TypeCheckedTag;
 
   // NodeManager's attributes.  These aren't exposed outside of this
   // class; use the getters.
   typedef expr::Attribute<TypeTag, TypeNode> TypeAttr;
+  typedef expr::Attribute<TypeCheckedTag, bool> TypeCheckedAttr;
 
   /* A note on isAtomic() and isAtomicFormula() (in CVC3 parlance)..
    *
@@ -527,11 +529,32 @@ public:
   inline TypeNode mkSort(const std::string& name);
 
   /**
-   * Get the type for the given node.
+   * Get the type for the given node and optionally do type checking.
+   *
+   * Initial type computation will be near-constant time if
+   * type checking is not requested. Results are memoized, so that
+   * subsequent calls to getType() without type checking will be
+   * constant time.
+   *
+   * Initial type checking is linear in the size of the expression.
+   * Again, the results are memoized, so that subsequent calls to
+   * getType(), with or without type checking, will be constant
+   * time.
+   *
+   * NOTE: A TypeCheckingException can be thrown even when type
+   * checking is not requested. getType() will always return a
+   * valid and correct type and, thus, an exception will be thrown
+   * when no valid or correct type can be computed (e.g., if the
+   * arguments to a bit-vector operation aren't bit-vectors). When
+   * type checking is not requested, getType() will do the minimum
+   * amount of checking required to return a valid result.
+   *
+   * @param n the Node for which we want a type
+   * @param check whether we should check the type as we compute it 
+   * (default: false)
    */
-  TypeNode getType(TNode n)
+  TypeNode getType(TNode n, bool check = false)
     throw (TypeCheckingExceptionPrivate, AssertionException);
-
 };
 
 /**
@@ -888,18 +911,21 @@ inline Node* NodeManager::mkVarPtr(const std::string& name,
 inline Node NodeManager::mkVar(const TypeNode& type) {
   Node n = NodeBuilder<0>(this, kind::VARIABLE);
   n.setAttribute(TypeAttr(), type);
+  n.setAttribute(TypeCheckedAttr(), true);
   return n;
 }
 
 inline Node* NodeManager::mkVarPtr(const TypeNode& type) {
   Node* n = NodeBuilder<0>(this, kind::VARIABLE).constructNodePtr();
   n->setAttribute(TypeAttr(), type);
+  n->setAttribute(TypeCheckedAttr(), true);
   return n;
 }
 
 inline Node NodeManager::mkSkolem(const TypeNode& type) {
   Node n = NodeBuilder<0>(this, kind::SKOLEM);
   n.setAttribute(TypeAttr(), type);
+  n.setAttribute(TypeCheckedAttr(), true);
   return n;
 }
 
index 9fb30bdb4565f565949bf18cb1a3ac4249b17cc1..b8fa85c033b3fc0a511e6517a017bf411a32c3ef 100644 (file)
@@ -28,7 +28,7 @@ namespace arith {
 
 class ArithConstantTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
     if (n.getKind() == kind::CONST_RATIONAL) return nodeManager->realType();
     return nodeManager->integerType();
@@ -37,7 +37,7 @@ public:
 
 class ArithOperatorTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
     TypeNode integerType = nodeManager->integerType();
     TypeNode realType = nodeManager->realType();
@@ -45,10 +45,17 @@ public:
     TNode::iterator child_it_end = n.end();
     bool isInteger = true;
     for(; child_it != child_it_end; ++child_it) {
-      TypeNode childType = (*child_it).getType();
-      if (!childType.isInteger()) isInteger = false;
-      if(childType != integerType && childType != realType) {
-        throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic subterm");
+      TypeNode childType = (*child_it).getType(check);
+      if (!childType.isInteger()) {
+        isInteger = false;
+        if( !check ) { // if we're not checking, nothing left to do
+          break;
+        }
+      }
+      if( check ) {
+        if(childType != integerType && childType != realType) {
+          throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic subterm");
+        }
       }
     }
     return (isInteger ? integerType : realType);
@@ -57,17 +64,19 @@ public:
 
 class ArithPredicateTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
-    TypeNode integerType = nodeManager->integerType();
-    TypeNode realType = nodeManager->realType();
-    TypeNode lhsType = n[0].getType();
-    if (lhsType != integerType && lhsType != realType) {
-      throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic term on the left-hand-side");
-    }
-    TypeNode rhsType = n[1].getType();
-    if (rhsType != integerType && rhsType != realType) {
-      throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic term on the right-hand-side");
+    if( check ) {
+      TypeNode integerType = nodeManager->integerType();
+      TypeNode realType = nodeManager->realType();
+      TypeNode lhsType = n[0].getType(check);
+      if (lhsType != integerType && lhsType != realType) {
+        throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic term on the left-hand-side");
+      }
+      TypeNode rhsType = n[1].getType(check);
+      if (rhsType != integerType && rhsType != realType) {
+        throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic term on the right-hand-side");
+      }
     }
     return nodeManager->booleanType();
   }
index 0eb88d800c00561c2a48f30947e529d5408e21f5..5d0713a894773456fa4ea02f1e19dbca7d2a0358 100644 (file)
@@ -26,30 +26,34 @@ namespace theory {
 namespace arrays {
 
 struct ArraySelectTypeRule {
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
     throw (TypeCheckingExceptionPrivate) {
     Assert(n.getKind() == kind::SELECT);
-    TypeNode arrayType = n[0].getType();
-    TypeNode indexType = n[1].getType();
-    if(arrayType.getArrayIndexType() != indexType) {
-      throw TypeCheckingExceptionPrivate(n, "array select not indexed with correct type for array");
+    TypeNode arrayType = n[0].getType(check);
+    if( check ) {
+      TypeNode indexType = n[1].getType(check);
+      if(arrayType.getArrayIndexType() != indexType) {
+        throw TypeCheckingExceptionPrivate(n, "array select not indexed with correct type for array");
+      }
     }
     return arrayType.getArrayConstituentType();
   }
 };/* struct ArraySelectTypeRule */
 
 struct ArrayStoreTypeRule {
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
     throw (TypeCheckingExceptionPrivate) {
     Assert(n.getKind() == kind::STORE);
-    TypeNode arrayType = n[0].getType();
-    TypeNode indexType = n[1].getType();
-    TypeNode valueType = n[2].getType();
-    if(arrayType.getArrayIndexType() != indexType) {
-      throw TypeCheckingExceptionPrivate(n, "array store not indexed with correct type for array");
-    }
-    if(arrayType.getArrayConstituentType() != valueType) {
-      throw TypeCheckingExceptionPrivate(n, "array store not assigned with correct type for array");
+    TypeNode arrayType = n[0].getType(check);
+    if( check ) {
+      TypeNode indexType = n[1].getType(check);
+      TypeNode valueType = n[2].getType(check);
+      if(arrayType.getArrayIndexType() != indexType) {
+        throw TypeCheckingExceptionPrivate(n, "array store not indexed with correct type for array");
+      }
+      if(arrayType.getArrayConstituentType() != valueType) {
+        throw TypeCheckingExceptionPrivate(n, "array store not assigned with correct type for array");
+      }
     }
     return arrayType;
   }
index b947cee109cca11d0e1eb934e19f64bb6e7a9d4c..fddced8efc702ac35c4fbf5f3b9554eb84a883ef 100644 (file)
@@ -27,30 +27,35 @@ namespace boolean {
 
 class BooleanTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
     TypeNode booleanType = nodeManager->booleanType();
-    TNode::iterator child_it = n.begin();
-    TNode::iterator child_it_end = n.end();
-    for(; child_it != child_it_end; ++child_it)
-      if((*child_it).getType() != booleanType) {
-        throw TypeCheckingExceptionPrivate(n, "expecting a Boolean subexpression");
+    if( check ) {
+      TNode::iterator child_it = n.begin();
+      TNode::iterator child_it_end = n.end();
+      for(; child_it != child_it_end; ++child_it) {
+        if((*child_it).getType(check) != booleanType) {
+          throw TypeCheckingExceptionPrivate(n, "expecting a Boolean subexpression");
+        }
       }
+    }
     return booleanType;
   }
 };
 
 class IteTypeRule {
   public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
-    TypeNode booleanType = nodeManager->booleanType();
-    if (n[0].getType() != booleanType) {
-      throw TypeCheckingExceptionPrivate(n, "condition of ITE is not Boolean");
-    }
-    TypeNode iteType = n[1].getType();
-    if (iteType != n[2].getType()) {
-      throw TypeCheckingExceptionPrivate(n, "both branches of the ITE must be of the same type");
+    TypeNode iteType = n[1].getType(check);
+    if( check ) {
+      TypeNode booleanType = nodeManager->booleanType();
+      if (n[0].getType(check) != booleanType) {
+        throw TypeCheckingExceptionPrivate(n, "condition of ITE is not Boolean");
+      }
+      if (iteType != n[2].getType(check)) {
+        throw TypeCheckingExceptionPrivate(n, "both branches of the ITE must be of the same type");
+      }
     }
     return iteType;
   }
index 19d6e268b0c1841335353f5899033ec6803581e5..4458931a9de6dd0f29473de6a98fe3e002b2b353 100644 (file)
@@ -31,9 +31,11 @@ namespace builtin {
 
 class EqualityTypeRule {
   public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n) throw (TypeCheckingExceptionPrivate) {
-    if (n[0].getType() != n[1].getType()) {
-      throw TypeCheckingExceptionPrivate(n, "Left and right hand side of the equation are not of the same type");
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) {
+    if( check ) {
+      if (n[0].getType(check) != n[1].getType(check)) {
+        throw TypeCheckingExceptionPrivate(n, "Left and right hand side of the equation are not of the same type");
+      }
     }
     return nodeManager->booleanType();
   }
@@ -41,13 +43,15 @@ class EqualityTypeRule {
 
 class DistinctTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n) {
-    TNode::iterator child_it = n.begin();
-    TNode::iterator child_it_end = n.end();
-    TypeNode firstType = (*child_it).getType();
-    for (++child_it; child_it != child_it_end; ++child_it) {
-      if ((*child_it).getType() != firstType) {
-        throw TypeCheckingExceptionPrivate(n, "Not all arguments are of the same type");
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) {
+    if( check ) {
+      TNode::iterator child_it = n.begin();
+      TNode::iterator child_it_end = n.end();
+      TypeNode firstType = (*child_it).getType(check);
+      for (++child_it; child_it != child_it_end; ++child_it) {
+        if ((*child_it).getType() != firstType) {
+          throw TypeCheckingExceptionPrivate(n, "Not all arguments are of the same type");
+        }
       }
     }
     return nodeManager->booleanType();
index 7aaae7349a0cf311c94e8a7588d735cc0b033383..9bb9e61dfa6eceff53aa361570972cb5a4029692 100644 (file)
@@ -18,6 +18,8 @@
 
 #include "cvc4_private.h"
 
+#include <algorithm>
+
 #ifndef __CVC4__THEORY__BV__THEORY_BV_TYPE_RULES_H
 #define __CVC4__THEORY__BV__THEORY_BV_TYPE_RULES_H
 
@@ -27,7 +29,7 @@ namespace bv {
 
 class BitVectorConstantTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
     return nodeManager->mkBitVectorType(n.getConst<BitVector>().getSize());
   }
@@ -35,12 +37,14 @@ public:
 
 class BitVectorCompRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
-    TypeNode lhs = n[0].getType();
-    TypeNode rhs = n[1].getType();
-    if (!lhs.isBitVector() || lhs != rhs) {
-      throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width");
+    if( check ) {
+      TypeNode lhs = n[0].getType(check);
+      TypeNode rhs = n[1].getType(check);
+      if (!lhs.isBitVector() || lhs != rhs) {
+        throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width");
+      }
     }
     return nodeManager->mkBitVectorType(1);
   }
@@ -48,18 +52,18 @@ public:
 
 class BitVectorArithRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
     unsigned maxWidth = 0;
     TNode::iterator it = n.begin();
     TNode::iterator it_end = n.end();
     // TODO: optimize unary neg
     for (; it != it_end; ++ it) {
-      TypeNode t = (*it).getType();
-      if (!t.isBitVector()) {
+      TypeNode t = (*it).getType(check);
+      if (check && !t.isBitVector()) {
         throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms");
       }
-      if (maxWidth < t.getBitVectorSize()) maxWidth = t.getBitVectorSize();
+      maxWidth = std::max( maxWidth, t.getBitVectorSize() );
     }
     return nodeManager->mkBitVectorType(maxWidth);
   }
@@ -67,17 +71,19 @@ public:
 
 class BitVectorFixedWidthTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
     TNode::iterator it = n.begin();
-    TNode::iterator it_end = n.end();
-    TypeNode t = (*it).getType();
-    if (!t.isBitVector()) {
-      throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms");
-    }
-    for (++ it; it != it_end; ++ it) {
-      if ((*it).getType() != t) {
-        throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width");
+    TypeNode t = (*it).getType(check);
+    if( check ) {
+      if (!t.isBitVector()) {
+        throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms");
+      }
+      TNode::iterator it_end = n.end();
+      for (++ it; it != it_end; ++ it) {
+        if ((*it).getType(check) != t) {
+          throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width");
+        }
       }
     }
     return t;
@@ -86,15 +92,17 @@ public:
 
 class BitVectorPredicateTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
-    TypeNode lhsType = n[0].getType();
-    if (!lhsType.isBitVector()) {
-      throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms");
-    }
-    TypeNode rhsType = n[1].getType();
-    if (lhsType != rhsType) {
-      throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width");
+    if( check ) {
+      TypeNode lhsType = n[0].getType(check);
+      if (!lhsType.isBitVector()) {
+        throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms");
+      }
+      TypeNode rhsType = n[1].getType(check);
+      if (lhsType != rhsType) {
+        throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width");
+      }
     }
     return nodeManager->booleanType();
   }
@@ -102,18 +110,25 @@ public:
 
 class BitVectorExtractTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
-    TypeNode t = n[0].getType();
-    if (!t.isBitVector()) {
-      throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term");
-    }
     BitVectorExtract extractInfo = n.getOperator().getConst<BitVectorExtract>();
+
+    // NOTE: We're throwing a type-checking exception here even
+    // if check is false, bc if we allow high < low the resulting
+    // type will be illegal
     if (extractInfo.high < extractInfo.low) {
       throw TypeCheckingExceptionPrivate(n, "high extract index is smaller than the low extract index");
     }
-    if (extractInfo.high >= t.getBitVectorSize()) {
-      throw TypeCheckingExceptionPrivate(n, "high extract index is bigger than the size of the bit-vector");
+
+    if( check ) {
+      TypeNode t = n[0].getType(check);
+      if (!t.isBitVector()) {
+        throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term");
+      }
+      if (extractInfo.high >= t.getBitVectorSize()) {
+        throw TypeCheckingExceptionPrivate(n, "high extract index is bigger than the size of the bit-vector");
+      }
     }
     return nodeManager->mkBitVectorType(extractInfo.high - extractInfo.low + 1);
   }
@@ -121,13 +136,16 @@ public:
 
 class BitVectorConcatRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
     unsigned size = 0;
     TNode::iterator it = n.begin();
     TNode::iterator it_end = n.end();
     for (; it != it_end; ++ it) {
-       TypeNode t = n[0].getType();
+       TypeNode t = n[0].getType(check);
+       // NOTE: We're throwing a type-checking exception here even
+       // when check is false, bc if we don't check that the arguments
+       // are bit-vectors the result type will be inaccurate
        if (!t.isBitVector()) {
          throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms");
        }
@@ -139,9 +157,12 @@ public:
 
 class BitVectorRepeatTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
-    TypeNode t = n[0].getType();
+    TypeNode t = n[0].getType(check);
+    // NOTE: We're throwing a type-checking exception here even
+    // when check is false, bc if the argument isn't a bit-vector
+    // the result type will be inaccurate
     if (!t.isBitVector()) {
       throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term");
     }
@@ -152,9 +173,12 @@ public:
 
 class BitVectorExtendTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
-    TypeNode t = n[0].getType();
+    TypeNode t = n[0].getType(check);
+    // NOTE: We're throwing a type-checking exception here even
+    // when check is false, bc if the argument isn't a bit-vector
+    // the result type will be inaccurate
     if (!t.isBitVector()) {
       throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term");
     }
index f09a44d509869183e545f23776260785a3e12ffc..38018112a4d1c8eff4e1f142a80a208900e38858 100644 (file)
@@ -27,20 +27,26 @@ namespace uf {
 
 class UfTypeRule {
 public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n)
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
       throw (TypeCheckingExceptionPrivate) {
     TNode f = n.getOperator();
-    TypeNode fType = f.getType();
-    if (n.getNumChildren() != fType.getNumChildren() - 1) {
-      throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the function type");
+    TypeNode fType = f.getType(check);
+    if( !fType.isFunction() ) {
+      throw TypeCheckingExceptionPrivate(n, "operator does not have function type");
     }
-    TNode::iterator argument_it = n.begin();
-    TNode::iterator argument_it_end = n.end();
-    TypeNode::iterator argument_type_it = fType.begin();
-    for(; argument_it != argument_it_end; ++argument_it)
-      if((*argument_it).getType() != *argument_type_it) {
-        throw TypeCheckingExceptionPrivate(n, "argument types do not match the function type");
+    if( check ) {
+      if (n.getNumChildren() != fType.getNumChildren() - 1) {
+        throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the function type");
       }
+      TNode::iterator argument_it = n.begin();
+      TNode::iterator argument_it_end = n.end();
+      TypeNode::iterator argument_type_it = fType.begin();
+      for(; argument_it != argument_it_end; ++argument_it) {
+        if((*argument_it).getType() != *argument_type_it) {
+          throw TypeCheckingExceptionPrivate(n, "argument types do not match the function type");
+        }
+      }
+    }
     return fType.getRangeType();
   }
 };/* class UfTypeRule */
index 7900057e15ac2694ee77355f65fcd109fb5b6798..4849e55cb54a34a1aa6c4dae8607160ffb276bb8 100644 (file)
@@ -265,9 +265,11 @@ public:
   void testGetType() {
     /* Type getType(); */
 
-    TS_ASSERT(a_bool->getType() == d_em->booleanType());
-    TS_ASSERT(b_bool->getType() == d_em->booleanType());
-    TS_ASSERT_THROWS(c_bool_mult->getType(), TypeCheckingException);
+    TS_ASSERT(a_bool->getType(false) == d_em->booleanType());
+    TS_ASSERT(a_bool->getType(true) == d_em->booleanType());
+    TS_ASSERT(b_bool->getType(false) == d_em->booleanType());
+    TS_ASSERT(b_bool->getType(true) == d_em->booleanType());
+    TS_ASSERT_THROWS(c_bool_mult->getType(true), TypeCheckingException);
 // These need better support for operators
 //    TS_ASSERT(mult_op->getType().isNull());
 //    TS_ASSERT(plus_op->getType().isNull());