From 2564d8730f768a8305325d4b6cc08211d8a3281d Mon Sep 17 00:00:00 2001 From: "Christopher L. Conway" Date: Tue, 27 Jul 2010 20:54:33 +0000 Subject: [PATCH] Adding optional 'check' parameter to getType() methods --- src/expr/expr_manager_template.cpp | 29 +++- src/expr/expr_manager_template.h | 3 +- src/expr/expr_template.cpp | 4 +- src/expr/expr_template.h | 27 +++- src/expr/node.h | 30 +++- src/expr/node_manager.cpp | 135 ++++++++++-------- src/expr/node_manager.h | 32 ++++- src/theory/arith/theory_arith_type_rules.h | 41 +++--- src/theory/arrays/theory_arrays_type_rules.h | 32 +++-- src/theory/booleans/theory_bool_type_rules.h | 33 +++-- .../builtin/theory_builtin_type_rules.h | 24 ++-- src/theory/bv/theory_bv_type_rules.h | 104 ++++++++------ src/theory/uf/theory_uf_type_rules.h | 26 ++-- test/unit/expr/expr_public.h | 8 +- 14 files changed, 343 insertions(+), 185 deletions(-) diff --git a/src/expr/expr_manager_template.cpp b/src/expr/expr_manager_template.cpp index f28729b94..5fcbad3a2 100644 --- a/src/expr/expr_manager_template.cpp +++ b/src/expr/expr_manager_template.cpp @@ -226,11 +226,36 @@ SortType ExprManager::mkSort(const std::string& name) const { return Type(d_nodeManager, new TypeNode(d_nodeManager->mkSort(name))); } -Type ExprManager::getType(const Expr& e) throw (TypeCheckingException) { +/** + * Get the type for the given Expr and optionally do type checking. + * + * Initial type computation will be near-constant time if + * type checking is not requested. Results are memoized, so that + * subsequent calls to getType() without type checking will be + * constant time. + * + * Initial type checking is linear in the size of the expression. + * Again, the results are memoized, so that subsequent calls to + * getType(), with or without type checking, will be constant + * time. + * + * NOTE: A TypeCheckingException can be thrown even when type + * checking is not requested. getType() will always return a + * valid and correct type and, thus, an exception will be thrown + * when no valid or correct type can be computed (e.g., if the + * arguments to a bit-vector operation aren't bit-vectors). When + * type checking is not requested, getType() will do the minimum + * amount of checking required to return a valid result. + * + * @param n the Expr for which we want a type + * @param check whether we should check the type as we compute it + * (default: false) + */ +Type ExprManager::getType(const Expr& e, bool check) throw (TypeCheckingException) { NodeManagerScope nms(d_nodeManager); Type t; try { - t = Type(d_nodeManager, new TypeNode(d_nodeManager->getType(e.getNode()))); + t = Type(d_nodeManager, new TypeNode(d_nodeManager->getType(e.getNode(), check))); } catch (const TypeCheckingExceptionPrivate& e) { throw TypeCheckingException(Expr(this, new Node(e.getNode())), e.getMessage()); } diff --git a/src/expr/expr_manager_template.h b/src/expr/expr_manager_template.h index 450d7fc4d..3b5b0e0f4 100644 --- a/src/expr/expr_manager_template.h +++ b/src/expr/expr_manager_template.h @@ -221,7 +221,8 @@ public: SortType mkSort(const std::string& name) const; /** Get the type of an expression */ - Type getType(const Expr& e) throw (TypeCheckingException); + Type getType(const Expr& e, bool check = false) + throw (TypeCheckingException); // variables are special, because duplicates are permitted Expr mkVar(const std::string& name, const Type& type); diff --git a/src/expr/expr_template.cpp b/src/expr/expr_template.cpp index fc67bcba1..48acd2588 100644 --- a/src/expr/expr_template.cpp +++ b/src/expr/expr_template.cpp @@ -181,10 +181,10 @@ Expr Expr::getOperator() const { return Expr(d_exprManager, new Node(d_node->getOperator())); } -Type Expr::getType() const throw (TypeCheckingException) { +Type Expr::getType(bool check) const throw (TypeCheckingException) { ExprManagerScope ems(*this); Assert(d_node != NULL, "Unexpected NULL expression pointer!"); - return d_exprManager->getType(*this); + return d_exprManager->getType(*this, check); } std::string Expr::toString() const { diff --git a/src/expr/expr_template.h b/src/expr/expr_template.h index 517931477..becdd46e2 100644 --- a/src/expr/expr_template.h +++ b/src/expr/expr_template.h @@ -224,10 +224,31 @@ public: */ Expr getOperator() const; - /** Returns the type of the expression, if it has been computed. - * Returns NULL if the type of the expression is not known. + /** + * Get the type for this Expr and optionally do type checking. + * + * Initial type computation will be near-constant time if + * type checking is not requested. Results are memoized, so that + * subsequent calls to getType() without type checking will be + * constant time. + * + * Initial type checking is linear in the size of the expression. + * Again, the results are memoized, so that subsequent calls to + * getType(), with or without type checking, will be constant + * time. + * + * NOTE: A TypeCheckingException can be thrown even when type + * checking is not requested. getType() will always return a + * valid and correct type and, thus, an exception will be thrown + * when no valid or correct type can be computed (e.g., if the + * arguments to a bit-vector operation aren't bit-vectors). When + * type checking is not requested, getType() will do the minimum + * amount of checking required to return a valid result. + * + * @param check whether we should check the type as we compute it + * (default: false) */ - Type getType() const throw (TypeCheckingException); + Type getType(bool check = false) const throw (TypeCheckingException); /** * Returns the string representation of the expression. diff --git a/src/expr/node.h b/src/expr/node.h index 218b9a3ea..4b1a0e5be 100644 --- a/src/expr/node.h +++ b/src/expr/node.h @@ -334,10 +334,30 @@ public: inline bool hasOperator() const; /** - * Returns the type of this node. - * @return the type + * Get the type for the node and optionally do type checking. + * + * Initial type computation will be near-constant time if + * type checking is not requested. Results are memoized, so that + * subsequent calls to getType() without type checking will be + * constant time. + * + * Initial type checking is linear in the size of the expression. + * Again, the results are memoized, so that subsequent calls to + * getType(), with or without type checking, will be constant + * time. + * + * NOTE: A TypeCheckingException can be thrown even when type + * checking is not requested. getType() will always return a + * valid and correct type and, thus, an exception will be thrown + * when no valid or correct type can be computed (e.g., if the + * arguments to a bit-vector operation aren't bit-vectors). When + * type checking is not requested, getType() will do the minimum + * amount of checking required to return a valid result. + * + * @param check whether we should check the type as we compute it + * (default: false) */ - TypeNode getType() const + TypeNode getType(bool check = false) const throw (CVC4::TypeCheckingExceptionPrivate, CVC4::AssertionException); /** @@ -893,7 +913,7 @@ inline bool NodeTemplate::hasOperator() const { } template -TypeNode NodeTemplate::getType() const +TypeNode NodeTemplate::getType(bool check) const throw (CVC4::TypeCheckingExceptionPrivate, CVC4::AssertionException) { Assert( NodeManager::currentNM() != NULL, "There is no current CVC4::NodeManager associated to this thread.\n" @@ -903,7 +923,7 @@ TypeNode NodeTemplate::getType() const Assert( d_nv->d_rc > 0, "TNode pointing to an expired NodeValue" ); } - return NodeManager::currentNM()->getType(*this); + return NodeManager::currentNM()->getType(*this, check); } #ifdef CVC4_DEBUG diff --git a/src/expr/node_manager.cpp b/src/expr/node_manager.cpp index 2e45fe9d0..fbfffe87d 100644 --- a/src/expr/node_manager.cpp +++ b/src/expr/node_manager.cpp @@ -27,6 +27,8 @@ #include "theory/arrays/theory_arrays_type_rules.h" #include "theory/bv/theory_bv_type_rules.h" +#include "util/Assert.h" + #include #include @@ -182,202 +184,211 @@ void NodeManager::reclaimZombies() { } }/* NodeManager::reclaimZombies() */ -TypeNode NodeManager::getType(TNode n) +TypeNode NodeManager::getType(TNode n, bool check) throw (TypeCheckingExceptionPrivate, AssertionException) { TypeNode typeNode; bool hasType = getAttribute(n, TypeAttr(), typeNode); + bool needsCheck = check && !getAttribute(n, TypeCheckedAttr()); + Debug("getType") << "getting type for " << n << std::endl; - if(!hasType) { + if(!hasType || needsCheck) { + TypeNode oldType = typeNode; + // Infer the type switch(n.getKind()) { case kind::SORT_TYPE: typeNode = kindType(); break; case kind::EQUAL: - typeNode = CVC4::theory::builtin::EqualityTypeRule::computeType(this, n); + typeNode = CVC4::theory::builtin::EqualityTypeRule::computeType(this, n, check); break; case kind::DISTINCT: - typeNode = CVC4::theory::builtin::DistinctTypeRule::computeType(this, n); + typeNode = CVC4::theory::builtin::DistinctTypeRule::computeType(this, n, check); break; case kind::CONST_BOOLEAN: - typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n); + typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check); break; case kind::NOT: - typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n); + typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check); break; case kind::AND: - typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n); + typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check); break; case kind::IFF: - typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n); + typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check); break; case kind::IMPLIES: - typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n); + typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check); break; case kind::OR: - typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n); + typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check); break; case kind::XOR: - typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n); + typeNode = CVC4::theory::boolean::BooleanTypeRule::computeType(this, n, check); break; case kind::ITE: - typeNode = CVC4::theory::boolean::IteTypeRule::computeType(this, n); + typeNode = CVC4::theory::boolean::IteTypeRule::computeType(this, n, check); break; case kind::APPLY_UF: - typeNode = CVC4::theory::uf::UfTypeRule::computeType(this, n); + typeNode = CVC4::theory::uf::UfTypeRule::computeType(this, n, check); break; case kind::PLUS: - typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check); break; case kind::MULT: - typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check); break; case kind::MINUS: - typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check); break; case kind::UMINUS: - typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check); break; case kind::DIVISION: - typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check); break; case kind::CONST_RATIONAL: - typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n, check); break; case kind::CONST_INTEGER: - typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithConstantTypeRule::computeType(this, n, check); break; case kind::LT: - typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check); break; case kind::LEQ: - typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check); break; case kind::GT: - typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check); break; case kind::GEQ: - typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::arith::ArithPredicateTypeRule::computeType(this, n, check); break; case kind::SELECT: - typeNode = CVC4::theory::arrays::ArraySelectTypeRule::computeType(this, n); + typeNode = CVC4::theory::arrays::ArraySelectTypeRule::computeType(this, n, check); break; case kind::STORE: - typeNode = CVC4::theory::arrays::ArrayStoreTypeRule::computeType(this, n); + typeNode = CVC4::theory::arrays::ArrayStoreTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_CONST: - typeNode = CVC4::theory::bv::BitVectorConstantTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorConstantTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_AND: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_OR: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_XOR: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_NOT: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_NAND: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_NOR: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_XNOR: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_COMP: - typeNode = CVC4::theory::bv::BitVectorCompRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorCompRule::computeType(this, n, check); break; case kind::BITVECTOR_MULT: - typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check); break; case kind::BITVECTOR_PLUS: - typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check); break; case kind::BITVECTOR_SUB: - typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check); break; case kind::BITVECTOR_NEG: - typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorArithRule::computeType(this, n, check); break; case kind::BITVECTOR_UDIV: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_UREM: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_SDIV: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_SREM: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_SMOD: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_SHL: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_LSHR: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_ASHR: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_ROTATE_LEFT: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_ROTATE_RIGHT: - typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorFixedWidthTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_ULT: - typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_ULE: - typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_UGT: - typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_UGE: - typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_SLT: - typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_SLE: - typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_SGT: - typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_SGE: - typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorPredicateTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_EXTRACT: - typeNode = CVC4::theory::bv::BitVectorExtractTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorExtractTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_CONCAT: - typeNode = CVC4::theory::bv::BitVectorConcatRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorConcatRule::computeType(this, n, check); break; case kind::BITVECTOR_REPEAT: - typeNode = CVC4::theory::bv::BitVectorRepeatTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorRepeatTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_ZERO_EXTEND: - typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n, check); break; case kind::BITVECTOR_SIGN_EXTEND: - typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n); + typeNode = CVC4::theory::bv::BitVectorExtendTypeRule::computeType(this, n, check); break; default: Debug("getType") << "FAILURE" << std::endl; Unhandled(n.getKind()); } + + // DebugAssert( !hasType || oldType == typeNode, + // "Type re-computation yielded a different type" ); + setAttribute(n, TypeAttr(), typeNode); + setAttribute(n, TypeCheckedAttr(), check); } Debug("getType") << "type of " << n << " is " << typeNode << std::endl; return typeNode; diff --git a/src/expr/node_manager.h b/src/expr/node_manager.h index d01e22abd..7a53cabfc 100644 --- a/src/expr/node_manager.h +++ b/src/expr/node_manager.h @@ -221,10 +221,12 @@ class NodeManager { // attribute tags struct TypeTag {}; + struct TypeCheckedTag; // NodeManager's attributes. These aren't exposed outside of this // class; use the getters. typedef expr::Attribute TypeAttr; + typedef expr::Attribute TypeCheckedAttr; /* A note on isAtomic() and isAtomicFormula() (in CVC3 parlance).. * @@ -527,11 +529,32 @@ public: inline TypeNode mkSort(const std::string& name); /** - * Get the type for the given node. + * Get the type for the given node and optionally do type checking. + * + * Initial type computation will be near-constant time if + * type checking is not requested. Results are memoized, so that + * subsequent calls to getType() without type checking will be + * constant time. + * + * Initial type checking is linear in the size of the expression. + * Again, the results are memoized, so that subsequent calls to + * getType(), with or without type checking, will be constant + * time. + * + * NOTE: A TypeCheckingException can be thrown even when type + * checking is not requested. getType() will always return a + * valid and correct type and, thus, an exception will be thrown + * when no valid or correct type can be computed (e.g., if the + * arguments to a bit-vector operation aren't bit-vectors). When + * type checking is not requested, getType() will do the minimum + * amount of checking required to return a valid result. + * + * @param n the Node for which we want a type + * @param check whether we should check the type as we compute it + * (default: false) */ - TypeNode getType(TNode n) + TypeNode getType(TNode n, bool check = false) throw (TypeCheckingExceptionPrivate, AssertionException); - }; /** @@ -888,18 +911,21 @@ inline Node* NodeManager::mkVarPtr(const std::string& name, inline Node NodeManager::mkVar(const TypeNode& type) { Node n = NodeBuilder<0>(this, kind::VARIABLE); n.setAttribute(TypeAttr(), type); + n.setAttribute(TypeCheckedAttr(), true); return n; } inline Node* NodeManager::mkVarPtr(const TypeNode& type) { Node* n = NodeBuilder<0>(this, kind::VARIABLE).constructNodePtr(); n->setAttribute(TypeAttr(), type); + n->setAttribute(TypeCheckedAttr(), true); return n; } inline Node NodeManager::mkSkolem(const TypeNode& type) { Node n = NodeBuilder<0>(this, kind::SKOLEM); n.setAttribute(TypeAttr(), type); + n.setAttribute(TypeCheckedAttr(), true); return n; } diff --git a/src/theory/arith/theory_arith_type_rules.h b/src/theory/arith/theory_arith_type_rules.h index 9fb30bdb4..b8fa85c03 100644 --- a/src/theory/arith/theory_arith_type_rules.h +++ b/src/theory/arith/theory_arith_type_rules.h @@ -28,7 +28,7 @@ namespace arith { class ArithConstantTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { if (n.getKind() == kind::CONST_RATIONAL) return nodeManager->realType(); return nodeManager->integerType(); @@ -37,7 +37,7 @@ public: class ArithOperatorTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { TypeNode integerType = nodeManager->integerType(); TypeNode realType = nodeManager->realType(); @@ -45,10 +45,17 @@ public: TNode::iterator child_it_end = n.end(); bool isInteger = true; for(; child_it != child_it_end; ++child_it) { - TypeNode childType = (*child_it).getType(); - if (!childType.isInteger()) isInteger = false; - if(childType != integerType && childType != realType) { - throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic subterm"); + TypeNode childType = (*child_it).getType(check); + if (!childType.isInteger()) { + isInteger = false; + if( !check ) { // if we're not checking, nothing left to do + break; + } + } + if( check ) { + if(childType != integerType && childType != realType) { + throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic subterm"); + } } } return (isInteger ? integerType : realType); @@ -57,17 +64,19 @@ public: class ArithPredicateTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { - TypeNode integerType = nodeManager->integerType(); - TypeNode realType = nodeManager->realType(); - TypeNode lhsType = n[0].getType(); - if (lhsType != integerType && lhsType != realType) { - throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic term on the left-hand-side"); - } - TypeNode rhsType = n[1].getType(); - if (rhsType != integerType && rhsType != realType) { - throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic term on the right-hand-side"); + if( check ) { + TypeNode integerType = nodeManager->integerType(); + TypeNode realType = nodeManager->realType(); + TypeNode lhsType = n[0].getType(check); + if (lhsType != integerType && lhsType != realType) { + throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic term on the left-hand-side"); + } + TypeNode rhsType = n[1].getType(check); + if (rhsType != integerType && rhsType != realType) { + throw TypeCheckingExceptionPrivate(n, "expecting an arithmetic term on the right-hand-side"); + } } return nodeManager->booleanType(); } diff --git a/src/theory/arrays/theory_arrays_type_rules.h b/src/theory/arrays/theory_arrays_type_rules.h index 0eb88d800..5d0713a89 100644 --- a/src/theory/arrays/theory_arrays_type_rules.h +++ b/src/theory/arrays/theory_arrays_type_rules.h @@ -26,30 +26,34 @@ namespace theory { namespace arrays { struct ArraySelectTypeRule { - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { Assert(n.getKind() == kind::SELECT); - TypeNode arrayType = n[0].getType(); - TypeNode indexType = n[1].getType(); - if(arrayType.getArrayIndexType() != indexType) { - throw TypeCheckingExceptionPrivate(n, "array select not indexed with correct type for array"); + TypeNode arrayType = n[0].getType(check); + if( check ) { + TypeNode indexType = n[1].getType(check); + if(arrayType.getArrayIndexType() != indexType) { + throw TypeCheckingExceptionPrivate(n, "array select not indexed with correct type for array"); + } } return arrayType.getArrayConstituentType(); } };/* struct ArraySelectTypeRule */ struct ArrayStoreTypeRule { - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { Assert(n.getKind() == kind::STORE); - TypeNode arrayType = n[0].getType(); - TypeNode indexType = n[1].getType(); - TypeNode valueType = n[2].getType(); - if(arrayType.getArrayIndexType() != indexType) { - throw TypeCheckingExceptionPrivate(n, "array store not indexed with correct type for array"); - } - if(arrayType.getArrayConstituentType() != valueType) { - throw TypeCheckingExceptionPrivate(n, "array store not assigned with correct type for array"); + TypeNode arrayType = n[0].getType(check); + if( check ) { + TypeNode indexType = n[1].getType(check); + TypeNode valueType = n[2].getType(check); + if(arrayType.getArrayIndexType() != indexType) { + throw TypeCheckingExceptionPrivate(n, "array store not indexed with correct type for array"); + } + if(arrayType.getArrayConstituentType() != valueType) { + throw TypeCheckingExceptionPrivate(n, "array store not assigned with correct type for array"); + } } return arrayType; } diff --git a/src/theory/booleans/theory_bool_type_rules.h b/src/theory/booleans/theory_bool_type_rules.h index b947cee10..fddced8ef 100644 --- a/src/theory/booleans/theory_bool_type_rules.h +++ b/src/theory/booleans/theory_bool_type_rules.h @@ -27,30 +27,35 @@ namespace boolean { class BooleanTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { TypeNode booleanType = nodeManager->booleanType(); - TNode::iterator child_it = n.begin(); - TNode::iterator child_it_end = n.end(); - for(; child_it != child_it_end; ++child_it) - if((*child_it).getType() != booleanType) { - throw TypeCheckingExceptionPrivate(n, "expecting a Boolean subexpression"); + if( check ) { + TNode::iterator child_it = n.begin(); + TNode::iterator child_it_end = n.end(); + for(; child_it != child_it_end; ++child_it) { + if((*child_it).getType(check) != booleanType) { + throw TypeCheckingExceptionPrivate(n, "expecting a Boolean subexpression"); + } } + } return booleanType; } }; class IteTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { - TypeNode booleanType = nodeManager->booleanType(); - if (n[0].getType() != booleanType) { - throw TypeCheckingExceptionPrivate(n, "condition of ITE is not Boolean"); - } - TypeNode iteType = n[1].getType(); - if (iteType != n[2].getType()) { - throw TypeCheckingExceptionPrivate(n, "both branches of the ITE must be of the same type"); + TypeNode iteType = n[1].getType(check); + if( check ) { + TypeNode booleanType = nodeManager->booleanType(); + if (n[0].getType(check) != booleanType) { + throw TypeCheckingExceptionPrivate(n, "condition of ITE is not Boolean"); + } + if (iteType != n[2].getType(check)) { + throw TypeCheckingExceptionPrivate(n, "both branches of the ITE must be of the same type"); + } } return iteType; } diff --git a/src/theory/builtin/theory_builtin_type_rules.h b/src/theory/builtin/theory_builtin_type_rules.h index 19d6e268b..4458931a9 100644 --- a/src/theory/builtin/theory_builtin_type_rules.h +++ b/src/theory/builtin/theory_builtin_type_rules.h @@ -31,9 +31,11 @@ namespace builtin { class EqualityTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) throw (TypeCheckingExceptionPrivate) { - if (n[0].getType() != n[1].getType()) { - throw TypeCheckingExceptionPrivate(n, "Left and right hand side of the equation are not of the same type"); + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { + if( check ) { + if (n[0].getType(check) != n[1].getType(check)) { + throw TypeCheckingExceptionPrivate(n, "Left and right hand side of the equation are not of the same type"); + } } return nodeManager->booleanType(); } @@ -41,13 +43,15 @@ class EqualityTypeRule { class DistinctTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) { - TNode::iterator child_it = n.begin(); - TNode::iterator child_it_end = n.end(); - TypeNode firstType = (*child_it).getType(); - for (++child_it; child_it != child_it_end; ++child_it) { - if ((*child_it).getType() != firstType) { - throw TypeCheckingExceptionPrivate(n, "Not all arguments are of the same type"); + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) { + if( check ) { + TNode::iterator child_it = n.begin(); + TNode::iterator child_it_end = n.end(); + TypeNode firstType = (*child_it).getType(check); + for (++child_it; child_it != child_it_end; ++child_it) { + if ((*child_it).getType() != firstType) { + throw TypeCheckingExceptionPrivate(n, "Not all arguments are of the same type"); + } } } return nodeManager->booleanType(); diff --git a/src/theory/bv/theory_bv_type_rules.h b/src/theory/bv/theory_bv_type_rules.h index 7aaae7349..9bb9e61df 100644 --- a/src/theory/bv/theory_bv_type_rules.h +++ b/src/theory/bv/theory_bv_type_rules.h @@ -18,6 +18,8 @@ #include "cvc4_private.h" +#include + #ifndef __CVC4__THEORY__BV__THEORY_BV_TYPE_RULES_H #define __CVC4__THEORY__BV__THEORY_BV_TYPE_RULES_H @@ -27,7 +29,7 @@ namespace bv { class BitVectorConstantTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { return nodeManager->mkBitVectorType(n.getConst().getSize()); } @@ -35,12 +37,14 @@ public: class BitVectorCompRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { - TypeNode lhs = n[0].getType(); - TypeNode rhs = n[1].getType(); - if (!lhs.isBitVector() || lhs != rhs) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width"); + if( check ) { + TypeNode lhs = n[0].getType(check); + TypeNode rhs = n[1].getType(check); + if (!lhs.isBitVector() || lhs != rhs) { + throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width"); + } } return nodeManager->mkBitVectorType(1); } @@ -48,18 +52,18 @@ public: class BitVectorArithRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { unsigned maxWidth = 0; TNode::iterator it = n.begin(); TNode::iterator it_end = n.end(); // TODO: optimize unary neg for (; it != it_end; ++ it) { - TypeNode t = (*it).getType(); - if (!t.isBitVector()) { + TypeNode t = (*it).getType(check); + if (check && !t.isBitVector()) { throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms"); } - if (maxWidth < t.getBitVectorSize()) maxWidth = t.getBitVectorSize(); + maxWidth = std::max( maxWidth, t.getBitVectorSize() ); } return nodeManager->mkBitVectorType(maxWidth); } @@ -67,17 +71,19 @@ public: class BitVectorFixedWidthTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { TNode::iterator it = n.begin(); - TNode::iterator it_end = n.end(); - TypeNode t = (*it).getType(); - if (!t.isBitVector()) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms"); - } - for (++ it; it != it_end; ++ it) { - if ((*it).getType() != t) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width"); + TypeNode t = (*it).getType(check); + if( check ) { + if (!t.isBitVector()) { + throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms"); + } + TNode::iterator it_end = n.end(); + for (++ it; it != it_end; ++ it) { + if ((*it).getType(check) != t) { + throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width"); + } } } return t; @@ -86,15 +92,17 @@ public: class BitVectorPredicateTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { - TypeNode lhsType = n[0].getType(); - if (!lhsType.isBitVector()) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms"); - } - TypeNode rhsType = n[1].getType(); - if (lhsType != rhsType) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width"); + if( check ) { + TypeNode lhsType = n[0].getType(check); + if (!lhsType.isBitVector()) { + throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms"); + } + TypeNode rhsType = n[1].getType(check); + if (lhsType != rhsType) { + throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms of the same width"); + } } return nodeManager->booleanType(); } @@ -102,18 +110,25 @@ public: class BitVectorExtractTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { - TypeNode t = n[0].getType(); - if (!t.isBitVector()) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); - } BitVectorExtract extractInfo = n.getOperator().getConst(); + + // NOTE: We're throwing a type-checking exception here even + // if check is false, bc if we allow high < low the resulting + // type will be illegal if (extractInfo.high < extractInfo.low) { throw TypeCheckingExceptionPrivate(n, "high extract index is smaller than the low extract index"); } - if (extractInfo.high >= t.getBitVectorSize()) { - throw TypeCheckingExceptionPrivate(n, "high extract index is bigger than the size of the bit-vector"); + + if( check ) { + TypeNode t = n[0].getType(check); + if (!t.isBitVector()) { + throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); + } + if (extractInfo.high >= t.getBitVectorSize()) { + throw TypeCheckingExceptionPrivate(n, "high extract index is bigger than the size of the bit-vector"); + } } return nodeManager->mkBitVectorType(extractInfo.high - extractInfo.low + 1); } @@ -121,13 +136,16 @@ public: class BitVectorConcatRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { unsigned size = 0; TNode::iterator it = n.begin(); TNode::iterator it_end = n.end(); for (; it != it_end; ++ it) { - TypeNode t = n[0].getType(); + TypeNode t = n[0].getType(check); + // NOTE: We're throwing a type-checking exception here even + // when check is false, bc if we don't check that the arguments + // are bit-vectors the result type will be inaccurate if (!t.isBitVector()) { throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms"); } @@ -139,9 +157,12 @@ public: class BitVectorRepeatTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { - TypeNode t = n[0].getType(); + TypeNode t = n[0].getType(check); + // NOTE: We're throwing a type-checking exception here even + // when check is false, bc if the argument isn't a bit-vector + // the result type will be inaccurate if (!t.isBitVector()) { throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); } @@ -152,9 +173,12 @@ public: class BitVectorExtendTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { - TypeNode t = n[0].getType(); + TypeNode t = n[0].getType(check); + // NOTE: We're throwing a type-checking exception here even + // when check is false, bc if the argument isn't a bit-vector + // the result type will be inaccurate if (!t.isBitVector()) { throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); } diff --git a/src/theory/uf/theory_uf_type_rules.h b/src/theory/uf/theory_uf_type_rules.h index f09a44d50..38018112a 100644 --- a/src/theory/uf/theory_uf_type_rules.h +++ b/src/theory/uf/theory_uf_type_rules.h @@ -27,20 +27,26 @@ namespace uf { class UfTypeRule { public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n) + inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) throw (TypeCheckingExceptionPrivate) { TNode f = n.getOperator(); - TypeNode fType = f.getType(); - if (n.getNumChildren() != fType.getNumChildren() - 1) { - throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the function type"); + TypeNode fType = f.getType(check); + if( !fType.isFunction() ) { + throw TypeCheckingExceptionPrivate(n, "operator does not have function type"); } - TNode::iterator argument_it = n.begin(); - TNode::iterator argument_it_end = n.end(); - TypeNode::iterator argument_type_it = fType.begin(); - for(; argument_it != argument_it_end; ++argument_it) - if((*argument_it).getType() != *argument_type_it) { - throw TypeCheckingExceptionPrivate(n, "argument types do not match the function type"); + if( check ) { + if (n.getNumChildren() != fType.getNumChildren() - 1) { + throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the function type"); } + TNode::iterator argument_it = n.begin(); + TNode::iterator argument_it_end = n.end(); + TypeNode::iterator argument_type_it = fType.begin(); + for(; argument_it != argument_it_end; ++argument_it) { + if((*argument_it).getType() != *argument_type_it) { + throw TypeCheckingExceptionPrivate(n, "argument types do not match the function type"); + } + } + } return fType.getRangeType(); } };/* class UfTypeRule */ diff --git a/test/unit/expr/expr_public.h b/test/unit/expr/expr_public.h index 7900057e1..4849e55cb 100644 --- a/test/unit/expr/expr_public.h +++ b/test/unit/expr/expr_public.h @@ -265,9 +265,11 @@ public: void testGetType() { /* Type getType(); */ - TS_ASSERT(a_bool->getType() == d_em->booleanType()); - TS_ASSERT(b_bool->getType() == d_em->booleanType()); - TS_ASSERT_THROWS(c_bool_mult->getType(), TypeCheckingException); + TS_ASSERT(a_bool->getType(false) == d_em->booleanType()); + TS_ASSERT(a_bool->getType(true) == d_em->booleanType()); + TS_ASSERT(b_bool->getType(false) == d_em->booleanType()); + TS_ASSERT(b_bool->getType(true) == d_em->booleanType()); + TS_ASSERT_THROWS(c_bool_mult->getType(true), TypeCheckingException); // These need better support for operators // TS_ASSERT(mult_op->getType().isNull()); // TS_ASSERT(plus_op->getType().isNull()); -- 2.30.2