** \todo document this file
**/
+#include "expr/attribute.h"
#include "theory/builtin/theory_builtin_rewriter.h"
#include "expr/chain.h"
}
}
+RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node) {
+ if( node.getKind()==kind::LAMBDA ){
+ Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl;
+ Node anode = getArrayRepresentationForLambda( node );
+ if( !anode.isNull() ){
+ anode = Rewriter::rewrite( anode );
+ Assert( anode.getType().isArray() );
+ //must get the standard bound variable list
+ Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType( node.getType() );
+ Node retNode = getLambdaForArrayRepresentation( anode, varList );
+ if( !retNode.isNull() && retNode!=node ){
+ Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl;
+ Trace("builtin-rewrite") << " input : " << node << std::endl;
+ Trace("builtin-rewrite") << " output : " << retNode << ", constant = " << retNode.isConst() << std::endl;
+ Trace("builtin-rewrite") << " array rep : " << anode << ", constant = " << anode.isConst() << std::endl;
+ Assert( anode.isConst()==retNode.isConst() );
+ Assert( retNode.getType()==node.getType() );
+ return RewriteResponse(REWRITE_DONE, retNode);
+ }
+ }else{
+ Trace("builtin-rewrite-debug") << "...failed to get array representation." << std::endl;
+ }
+ return RewriteResponse(REWRITE_DONE, node);
+ }else{
+ return doRewrite(node);
+ }
+}
+
+Node TheoryBuiltinRewriter::getLambdaForArrayRepresentationRec( TNode a, TNode bvl, unsigned bvlIndex,
+ std::unordered_map< TNode, Node, TNodeHashFunction >& visited ){
+ std::unordered_map< TNode, Node, TNodeHashFunction >::iterator it = visited.find( a );
+ if( it==visited.end() ){
+ Node ret;
+ if( bvlIndex<bvl.getNumChildren() ){
+ Assert( a.getType().isArray() );
+ if( a.getKind()==kind::STORE ){
+ // convert the array recursively
+ Node body = getLambdaForArrayRepresentationRec( a[0], bvl, bvlIndex, visited );
+ if( !body.isNull() ){
+ // convert the value recursively (bounded by the number of arguments in bvl)
+ Node val = getLambdaForArrayRepresentationRec( a[2], bvl, bvlIndex+1, visited );
+ if( !val.isNull() ){
+ Assert( !TypeNode::leastCommonTypeNode( a[1].getType(), bvl[bvlIndex].getType() ).isNull() );
+ Assert( !TypeNode::leastCommonTypeNode( val.getType(), body.getType() ).isNull() );
+ Node cond = bvl[bvlIndex].eqNode( a[1] );
+ ret = NodeManager::currentNM()->mkNode( kind::ITE, cond, val, body );
+ }
+ }
+ }else if( a.getKind()==kind::STORE_ALL ){
+ ArrayStoreAll storeAll = a.getConst<ArrayStoreAll>();
+ Node sa = Node::fromExpr(storeAll.getExpr());
+ // convert the default value recursively (bounded by the number of arguments in bvl)
+ ret = getLambdaForArrayRepresentationRec( sa, bvl, bvlIndex+1, visited );
+ }
+ }else{
+ ret = a;
+ }
+ visited[a] = ret;
+ return ret;
+ }else{
+ return it->second;
+ }
+}
+
+Node TheoryBuiltinRewriter::getLambdaForArrayRepresentation( TNode a, TNode bvl ){
+ Assert( a.getType().isArray() );
+ std::unordered_map< TNode, Node, TNodeHashFunction > visited;
+ Trace("builtin-rewrite-debug") << "Get lambda for : " << a << ", with variables " << bvl << std::endl;
+ Node body = getLambdaForArrayRepresentationRec( a, bvl, 0, visited );
+ if( !body.isNull() ){
+ body = Rewriter::rewrite( body );
+ Trace("builtin-rewrite-debug") << "...got lambda body " << body << std::endl;
+ return NodeManager::currentNM()->mkNode( kind::LAMBDA, bvl, body );
+ }else{
+ Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl;
+ return Node::null();
+ }
+}
+
+Node TheoryBuiltinRewriter::getArrayRepresentationForLambda( TNode n, bool reqConst ){
+ Assert( n.getKind()==kind::LAMBDA );
+ Trace("builtin-rewrite-debug") << "Get array representation for : " << n << std::endl;
+
+ Node first_arg = n[0][0];
+ Node rec_bvl;
+ if( n[0].getNumChildren()>1 ){
+ std::vector< Node > args;
+ for( unsigned i=1; i<n[0].getNumChildren(); i++ ){
+ args.push_back( n[0][i] );
+ }
+ rec_bvl = NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, args );
+ }
+
+ Trace("builtin-rewrite-debug2") << " process body..." << std::endl;
+ TypeNode retType;
+ std::vector< Node > conds;
+ std::vector< Node > vals;
+ Node curr = n[1];
+ while( curr.getKind()==kind::ITE || curr.getKind()==kind::EQUAL || curr.getKind()==kind::NOT ){
+ Trace("builtin-rewrite-debug2") << " process condition : " << curr[0] << std::endl;
+ Node index_eq;
+ Node curr_val;
+ Node next;
+ if( curr.getKind()==kind::ITE ){
+ index_eq = curr[0];
+ curr_val = curr[1];
+ next = curr[2];
+ }else{
+ bool pol = curr.getKind()!=kind::NOT;
+ //Boolean case, e.g. lambda x. (= x v) is lambda x. (ite (= x v) true false)
+ index_eq = curr.getKind()==kind::NOT ? curr[0] : curr;
+ curr_val = NodeManager::currentNM()->mkConst( pol );
+ next = NodeManager::currentNM()->mkConst( !pol );
+ }
+ if( index_eq.getKind()!=kind::EQUAL ){
+ // non-equality condition
+ Trace("builtin-rewrite-debug2") << " ...non-equality condition." << std::endl;
+ return Node::null();
+ }else if( reqConst && Rewriter::rewrite( index_eq )!=index_eq ){
+ // equality must be oriented correctly based on rewriter
+ Trace("builtin-rewrite-debug2") << " ...equality not oriented properly." << std::endl;
+ return Node::null();
+ }
+
+ Node curr_index;
+ for( unsigned r=0; r<2; r++ ){
+ Node arg = index_eq[r];
+ Node val = index_eq[1-r];
+ if( arg==first_arg ){
+ if( reqConst && !val.isConst() ){
+ // non-constant value
+ Trace("builtin-rewrite-debug2") << " ...non-constant value." << std::endl;
+ return Node::null();
+ }else{
+ curr_index = val;
+ Trace("builtin-rewrite-debug2") << " " << arg << " -> " << val << std::endl;
+ break;
+ }
+ }
+ }
+ if( !curr_index.isNull() ){
+ if( !rec_bvl.isNull() ){
+ curr_val = NodeManager::currentNM()->mkNode( kind::LAMBDA, rec_bvl, curr_val );
+ curr_val = getArrayRepresentationForLambda( curr_val, reqConst );
+ if( curr_val.isNull() ){
+ Trace("builtin-rewrite-debug2") << " ...non-constant value." << std::endl;
+ return Node::null();
+ }
+ }
+ Trace("builtin-rewrite-debug2") << " ...condition is index " << curr_val << std::endl;
+ }else{
+ Trace("builtin-rewrite-debug2") << " ...non-constant value." << std::endl;
+ return Node::null();
+ }
+ conds.push_back( curr_index );
+ vals.push_back( curr_val );
+ TypeNode vtype = curr_val.getType();
+ retType = retType.isNull() ? vtype : TypeNode::leastCommonTypeNode( retType, vtype );
+ //recurse
+ curr = next;
+ }
+ if( !rec_bvl.isNull() ){
+ curr = NodeManager::currentNM()->mkNode( kind::LAMBDA, rec_bvl, curr );
+ curr = getArrayRepresentationForLambda( curr );
+ }
+ if( !curr.isNull() && curr.isConst() ){
+ TypeNode ctype = curr.getType();
+ retType = retType.isNull() ? ctype : TypeNode::leastCommonTypeNode( retType, ctype );
+ TypeNode array_type = NodeManager::currentNM()->mkArrayType( first_arg.getType(), retType );
+ curr = NodeManager::currentNM()->mkConst(ArrayStoreAll(((ArrayType)array_type.toType()), curr.toExpr()));
+ Trace("builtin-rewrite-debug2") << " build array..." << std::endl;
+ // can only build if default value is constant (since array store all must be constant)
+ Trace("builtin-rewrite-debug2") << " got constant base " << curr << std::endl;
+ // construct store chain
+ for( int i=((int)conds.size()-1); i>=0; i-- ){
+ Assert( conds[i].getType().isSubtypeOf( first_arg.getType() ) );
+ curr = NodeManager::currentNM()->mkNode( kind::STORE, curr, conds[i], vals[i] );
+ }
+ Trace("builtin-rewrite-debug") << "...got array " << curr << " for " << n << std::endl;
+ return curr;
+ }else{
+ Trace("builtin-rewrite-debug") << "...failed to get array (cannot get constant default value)" << std::endl;
+ return Node::null();
+ }
+}
+
}/* CVC4::theory::builtin namespace */
}/* CVC4::theory namespace */
}/* CVC4 namespace */
}
}
- static inline RewriteResponse postRewrite(TNode node) {
- return doRewrite(node);
- }
+ static RewriteResponse postRewrite(TNode node);
static inline RewriteResponse preRewrite(TNode node) {
return doRewrite(node);
static inline void init() {}
static inline void shutdown() {}
+// conversions between lambdas and arrays
+private:
+ /** recursive helper for getLambdaForArrayRepresentation */
+ static Node getLambdaForArrayRepresentationRec( TNode a, TNode bvl, unsigned bvlIndex,
+ std::unordered_map< TNode, Node, TNodeHashFunction >& visited );
+public:
+ /**
+ * Given an array constant a, returns a lambda expression that it corresponds to, with bound variable list bvl.
+ * Examples:
+ *
+ * (store (storeall (Array Int Int) 2) 0 1)
+ * becomes
+ * ((lambda x. (ite (= x 0) 1 2))
+ *
+ * (store (storeall (Array Int (Array Int Int)) (storeall (Array Int Int) 4)) 0 (store (storeall (Array Int Int) 3) 1 2))
+ * becomes
+ * (lambda xy. (ite (= x 0) (ite (= x 1) 2 3) 4))
+ *
+ * (store (store (storeall (Array Int Bool) false) 2 true) 1 true)
+ * becomes
+ * (lambda x. (ite (= x 1) true (ite (= x 2) true false)))
+ *
+ * Notice that the return body of the lambda is rewritten to ensure that the representation is canonical. Hence the last
+ * example will in fact be returned as:
+ * (lambda x. (ite (= x 1) true (= x 2)))
+ */
+ static Node getLambdaForArrayRepresentation( TNode a, TNode bvl );
+ /** given a lambda expression n, returns an array term. reqConst is true if we require the return value to be a constant.
+ * This does the opposite direction of the examples described above.
+ */
+ static Node getArrayRepresentationForLambda( TNode n, bool reqConst = false );
};/* class TheoryBuiltinRewriter */
}/* CVC4::theory::builtin namespace */
#include "expr/type_node.h"
#include "expr/expr.h"
#include "theory/rewriter.h"
+#include "theory/builtin/theory_builtin_rewriter.h" // for array and lambda representation
#include <sstream>
TypeNode rangeType = n[1].getType(check);
return nodeManager->mkFunctionType(argTypes, rangeType);
}
+ // computes whether a lambda is a constant value, via conversion to array representation
+ inline static bool computeIsConst(NodeManager* nodeManager, TNode n)
+ throw (AssertionException) {
+ Assert(n.getKind() == kind::LAMBDA);
+ //get array representation of this function, if possible
+ Node na = TheoryBuiltinRewriter::getArrayRepresentationForLambda( n, true );
+ if( !na.isNull() ){
+ Assert( na.getType().isArray() );
+ Trace("lambda-const") << "Array representation for " << n << " is " << na << " " << na.getType() << std::endl;
+ // must have the standard bound variable list
+ Node bvl = NodeManager::currentNM()->getBoundVarListForFunctionType( n.getType() );
+ if( bvl==n[0] ){
+ //array must be constant
+ if( na.isConst() ){
+ Trace("lambda-const") << "*** Constant lambda : " << n;
+ Trace("lambda-const") << " since its array representation : " << na << " is constant." << std::endl;
+ return true;
+ }else{
+ Trace("lambda-const") << "Non-constant lambda : " << n << " since array is not constant." << std::endl;
+ }
+ }else{
+ Trace("lambda-const") << "Non-constant lambda : " << n << " since its varlist is not standard." << std::endl;
+ Trace("lambda-const") << " standard : " << bvl << std::endl;
+ Trace("lambda-const") << " current : " << n[0] << std::endl;
+ }
+ }else{
+ Trace("lambda-const") << "Non-constant lambda : " << n << " since it has no array representation." << std::endl;
+ }
+ return false;
+ }
};/* class LambdaTypeRule */
class ChainTypeRule {