From c4306288347e043091628b63797f9f54b0359a7c Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 13 Sep 2017 19:26:35 -0500 Subject: [PATCH] Add isConst check for lambda expressions. (#1084) Add isConst check for lambda expressions by conversions to and from an Array representation where isConst is implemented. This enables check-model to succeed on higher-order benchmarks. Change the builtin rewriter for lambda to attempt to put lambdas into constant form. Update regression. --- src/expr/node_manager.cpp | 21 ++ src/expr/node_manager.h | 3 + src/theory/builtin/kinds | 3 + .../builtin/theory_builtin_rewriter.cpp | 187 ++++++++++++++++++ src/theory/builtin/theory_builtin_rewriter.h | 35 +++- .../builtin/theory_builtin_type_rules.h | 31 +++ test/regress/regress0/print_lambda.cvc | 3 +- 7 files changed, 279 insertions(+), 4 deletions(-) diff --git a/src/expr/node_manager.cpp b/src/expr/node_manager.cpp index 33f057274..85f5e3c75 100644 --- a/src/expr/node_manager.cpp +++ b/src/expr/node_manager.cpp @@ -83,6 +83,12 @@ struct NVReclaim { } // namespace +namespace attr { + struct LambdaBoundVarListTag { }; +}/* CVC4::attr namespace */ + +// attribute that stores the canonical bound variable list for function types +typedef expr::Attribute LambdaBoundVarListAttr; NodeManager::NodeManager(ExprManager* exprManager) : d_options(new Options()), @@ -692,6 +698,21 @@ Node* NodeManager::mkBoundVarPtr(const std::string& name, return n; } +Node NodeManager::getBoundVarListForFunctionType( TypeNode tn ) { + Assert( tn.isFunction() ); + Node bvl = tn.getAttribute(LambdaBoundVarListAttr()); + if( bvl.isNull() ){ + std::vector< Node > vars; + for( unsigned i=0; imkBoundVar( tn[i] ) ); + } + bvl = NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, vars ); + Trace("functions") << "Make standard bound var list " << bvl << " for " << tn << std::endl; + tn.setAttribute(LambdaBoundVarListAttr(),bvl); + } + return bvl; +} + Node NodeManager::mkVar(const TypeNode& type, uint32_t flags) { Node n = NodeBuilder<0>(this, kind::VARIABLE); setAttribute(n, TypeAttr(), type); diff --git a/src/expr/node_manager.h b/src/expr/node_manager.h index b1b0bc974..d5d296579 100644 --- a/src/expr/node_manager.h +++ b/src/expr/node_manager.h @@ -491,6 +491,9 @@ public: Node mkBoundVar(const TypeNode& type); Node* mkBoundVarPtr(const TypeNode& type); + /** get the canonical bound variable list for function type tn */ + static Node getBoundVarListForFunctionType( TypeNode tn ); + /** * Optional flags used to control behavior of NodeManager::mkSkolem(). * They should be composed with a bitwise OR (e.g., diff --git a/src/theory/builtin/kinds b/src/theory/builtin/kinds index 12e897189..6b7b952e2 100644 --- a/src/theory/builtin/kinds +++ b/src/theory/builtin/kinds @@ -336,4 +336,7 @@ typerule LAMBDA ::CVC4::theory::builtin::LambdaTypeRule typerule CHAIN ::CVC4::theory::builtin::ChainTypeRule typerule CHAIN_OP ::CVC4::theory::builtin::ChainedOperatorTypeRule +# lambda expressions that are isomorphic to array constants can be considered constants +construle LAMBDA ::CVC4::theory::builtin::LambdaTypeRule + endtheory diff --git a/src/theory/builtin/theory_builtin_rewriter.cpp b/src/theory/builtin/theory_builtin_rewriter.cpp index 32b35dfe8..57249e181 100644 --- a/src/theory/builtin/theory_builtin_rewriter.cpp +++ b/src/theory/builtin/theory_builtin_rewriter.cpp @@ -15,6 +15,7 @@ ** \todo document this file **/ +#include "expr/attribute.h" #include "theory/builtin/theory_builtin_rewriter.h" #include "expr/chain.h" @@ -70,6 +71,192 @@ Node TheoryBuiltinRewriter::blastChain(TNode in) { } } +RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node) { + if( node.getKind()==kind::LAMBDA ){ + Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl; + Node anode = getArrayRepresentationForLambda( node ); + if( !anode.isNull() ){ + anode = Rewriter::rewrite( anode ); + Assert( anode.getType().isArray() ); + //must get the standard bound variable list + Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType( node.getType() ); + Node retNode = getLambdaForArrayRepresentation( anode, varList ); + if( !retNode.isNull() && retNode!=node ){ + Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl; + Trace("builtin-rewrite") << " input : " << node << std::endl; + Trace("builtin-rewrite") << " output : " << retNode << ", constant = " << retNode.isConst() << std::endl; + Trace("builtin-rewrite") << " array rep : " << anode << ", constant = " << anode.isConst() << std::endl; + Assert( anode.isConst()==retNode.isConst() ); + Assert( retNode.getType()==node.getType() ); + return RewriteResponse(REWRITE_DONE, retNode); + } + }else{ + Trace("builtin-rewrite-debug") << "...failed to get array representation." << std::endl; + } + return RewriteResponse(REWRITE_DONE, node); + }else{ + return doRewrite(node); + } +} + +Node TheoryBuiltinRewriter::getLambdaForArrayRepresentationRec( TNode a, TNode bvl, unsigned bvlIndex, + std::unordered_map< TNode, Node, TNodeHashFunction >& visited ){ + std::unordered_map< TNode, Node, TNodeHashFunction >::iterator it = visited.find( a ); + if( it==visited.end() ){ + Node ret; + if( bvlIndexmkNode( kind::ITE, cond, val, body ); + } + } + }else if( a.getKind()==kind::STORE_ALL ){ + ArrayStoreAll storeAll = a.getConst(); + Node sa = Node::fromExpr(storeAll.getExpr()); + // convert the default value recursively (bounded by the number of arguments in bvl) + ret = getLambdaForArrayRepresentationRec( sa, bvl, bvlIndex+1, visited ); + } + }else{ + ret = a; + } + visited[a] = ret; + return ret; + }else{ + return it->second; + } +} + +Node TheoryBuiltinRewriter::getLambdaForArrayRepresentation( TNode a, TNode bvl ){ + Assert( a.getType().isArray() ); + std::unordered_map< TNode, Node, TNodeHashFunction > visited; + Trace("builtin-rewrite-debug") << "Get lambda for : " << a << ", with variables " << bvl << std::endl; + Node body = getLambdaForArrayRepresentationRec( a, bvl, 0, visited ); + if( !body.isNull() ){ + body = Rewriter::rewrite( body ); + Trace("builtin-rewrite-debug") << "...got lambda body " << body << std::endl; + return NodeManager::currentNM()->mkNode( kind::LAMBDA, bvl, body ); + }else{ + Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl; + return Node::null(); + } +} + +Node TheoryBuiltinRewriter::getArrayRepresentationForLambda( TNode n, bool reqConst ){ + Assert( n.getKind()==kind::LAMBDA ); + Trace("builtin-rewrite-debug") << "Get array representation for : " << n << std::endl; + + Node first_arg = n[0][0]; + Node rec_bvl; + if( n[0].getNumChildren()>1 ){ + std::vector< Node > args; + for( unsigned i=1; imkNode( kind::BOUND_VAR_LIST, args ); + } + + Trace("builtin-rewrite-debug2") << " process body..." << std::endl; + TypeNode retType; + std::vector< Node > conds; + std::vector< Node > vals; + Node curr = n[1]; + while( curr.getKind()==kind::ITE || curr.getKind()==kind::EQUAL || curr.getKind()==kind::NOT ){ + Trace("builtin-rewrite-debug2") << " process condition : " << curr[0] << std::endl; + Node index_eq; + Node curr_val; + Node next; + if( curr.getKind()==kind::ITE ){ + index_eq = curr[0]; + curr_val = curr[1]; + next = curr[2]; + }else{ + bool pol = curr.getKind()!=kind::NOT; + //Boolean case, e.g. lambda x. (= x v) is lambda x. (ite (= x v) true false) + index_eq = curr.getKind()==kind::NOT ? curr[0] : curr; + curr_val = NodeManager::currentNM()->mkConst( pol ); + next = NodeManager::currentNM()->mkConst( !pol ); + } + if( index_eq.getKind()!=kind::EQUAL ){ + // non-equality condition + Trace("builtin-rewrite-debug2") << " ...non-equality condition." << std::endl; + return Node::null(); + }else if( reqConst && Rewriter::rewrite( index_eq )!=index_eq ){ + // equality must be oriented correctly based on rewriter + Trace("builtin-rewrite-debug2") << " ...equality not oriented properly." << std::endl; + return Node::null(); + } + + Node curr_index; + for( unsigned r=0; r<2; r++ ){ + Node arg = index_eq[r]; + Node val = index_eq[1-r]; + if( arg==first_arg ){ + if( reqConst && !val.isConst() ){ + // non-constant value + Trace("builtin-rewrite-debug2") << " ...non-constant value." << std::endl; + return Node::null(); + }else{ + curr_index = val; + Trace("builtin-rewrite-debug2") << " " << arg << " -> " << val << std::endl; + break; + } + } + } + if( !curr_index.isNull() ){ + if( !rec_bvl.isNull() ){ + curr_val = NodeManager::currentNM()->mkNode( kind::LAMBDA, rec_bvl, curr_val ); + curr_val = getArrayRepresentationForLambda( curr_val, reqConst ); + if( curr_val.isNull() ){ + Trace("builtin-rewrite-debug2") << " ...non-constant value." << std::endl; + return Node::null(); + } + } + Trace("builtin-rewrite-debug2") << " ...condition is index " << curr_val << std::endl; + }else{ + Trace("builtin-rewrite-debug2") << " ...non-constant value." << std::endl; + return Node::null(); + } + conds.push_back( curr_index ); + vals.push_back( curr_val ); + TypeNode vtype = curr_val.getType(); + retType = retType.isNull() ? vtype : TypeNode::leastCommonTypeNode( retType, vtype ); + //recurse + curr = next; + } + if( !rec_bvl.isNull() ){ + curr = NodeManager::currentNM()->mkNode( kind::LAMBDA, rec_bvl, curr ); + curr = getArrayRepresentationForLambda( curr ); + } + if( !curr.isNull() && curr.isConst() ){ + TypeNode ctype = curr.getType(); + retType = retType.isNull() ? ctype : TypeNode::leastCommonTypeNode( retType, ctype ); + TypeNode array_type = NodeManager::currentNM()->mkArrayType( first_arg.getType(), retType ); + curr = NodeManager::currentNM()->mkConst(ArrayStoreAll(((ArrayType)array_type.toType()), curr.toExpr())); + Trace("builtin-rewrite-debug2") << " build array..." << std::endl; + // can only build if default value is constant (since array store all must be constant) + Trace("builtin-rewrite-debug2") << " got constant base " << curr << std::endl; + // construct store chain + for( int i=((int)conds.size()-1); i>=0; i-- ){ + Assert( conds[i].getType().isSubtypeOf( first_arg.getType() ) ); + curr = NodeManager::currentNM()->mkNode( kind::STORE, curr, conds[i], vals[i] ); + } + Trace("builtin-rewrite-debug") << "...got array " << curr << " for " << n << std::endl; + return curr; + }else{ + Trace("builtin-rewrite-debug") << "...failed to get array (cannot get constant default value)" << std::endl; + return Node::null(); + } +} + }/* CVC4::theory::builtin namespace */ }/* CVC4::theory namespace */ }/* CVC4 namespace */ diff --git a/src/theory/builtin/theory_builtin_rewriter.h b/src/theory/builtin/theory_builtin_rewriter.h index 8ca2c538e..d9ae5b447 100644 --- a/src/theory/builtin/theory_builtin_rewriter.h +++ b/src/theory/builtin/theory_builtin_rewriter.h @@ -45,9 +45,7 @@ public: } } - static inline RewriteResponse postRewrite(TNode node) { - return doRewrite(node); - } + static RewriteResponse postRewrite(TNode node); static inline RewriteResponse preRewrite(TNode node) { return doRewrite(node); @@ -56,6 +54,37 @@ public: static inline void init() {} static inline void shutdown() {} +// conversions between lambdas and arrays +private: + /** recursive helper for getLambdaForArrayRepresentation */ + static Node getLambdaForArrayRepresentationRec( TNode a, TNode bvl, unsigned bvlIndex, + std::unordered_map< TNode, Node, TNodeHashFunction >& visited ); +public: + /** + * Given an array constant a, returns a lambda expression that it corresponds to, with bound variable list bvl. + * Examples: + * + * (store (storeall (Array Int Int) 2) 0 1) + * becomes + * ((lambda x. (ite (= x 0) 1 2)) + * + * (store (storeall (Array Int (Array Int Int)) (storeall (Array Int Int) 4)) 0 (store (storeall (Array Int Int) 3) 1 2)) + * becomes + * (lambda xy. (ite (= x 0) (ite (= x 1) 2 3) 4)) + * + * (store (store (storeall (Array Int Bool) false) 2 true) 1 true) + * becomes + * (lambda x. (ite (= x 1) true (ite (= x 2) true false))) + * + * Notice that the return body of the lambda is rewritten to ensure that the representation is canonical. Hence the last + * example will in fact be returned as: + * (lambda x. (ite (= x 1) true (= x 2))) + */ + static Node getLambdaForArrayRepresentation( TNode a, TNode bvl ); + /** given a lambda expression n, returns an array term. reqConst is true if we require the return value to be a constant. + * This does the opposite direction of the examples described above. + */ + static Node getArrayRepresentationForLambda( TNode n, bool reqConst = false ); };/* class TheoryBuiltinRewriter */ }/* CVC4::theory::builtin namespace */ diff --git a/src/theory/builtin/theory_builtin_type_rules.h b/src/theory/builtin/theory_builtin_type_rules.h index d8893d441..370e5d348 100644 --- a/src/theory/builtin/theory_builtin_type_rules.h +++ b/src/theory/builtin/theory_builtin_type_rules.h @@ -23,6 +23,7 @@ #include "expr/type_node.h" #include "expr/expr.h" #include "theory/rewriter.h" +#include "theory/builtin/theory_builtin_rewriter.h" // for array and lambda representation #include @@ -161,6 +162,36 @@ public: TypeNode rangeType = n[1].getType(check); return nodeManager->mkFunctionType(argTypes, rangeType); } + // computes whether a lambda is a constant value, via conversion to array representation + inline static bool computeIsConst(NodeManager* nodeManager, TNode n) + throw (AssertionException) { + Assert(n.getKind() == kind::LAMBDA); + //get array representation of this function, if possible + Node na = TheoryBuiltinRewriter::getArrayRepresentationForLambda( n, true ); + if( !na.isNull() ){ + Assert( na.getType().isArray() ); + Trace("lambda-const") << "Array representation for " << n << " is " << na << " " << na.getType() << std::endl; + // must have the standard bound variable list + Node bvl = NodeManager::currentNM()->getBoundVarListForFunctionType( n.getType() ); + if( bvl==n[0] ){ + //array must be constant + if( na.isConst() ){ + Trace("lambda-const") << "*** Constant lambda : " << n; + Trace("lambda-const") << " since its array representation : " << na << " is constant." << std::endl; + return true; + }else{ + Trace("lambda-const") << "Non-constant lambda : " << n << " since array is not constant." << std::endl; + } + }else{ + Trace("lambda-const") << "Non-constant lambda : " << n << " since its varlist is not standard." << std::endl; + Trace("lambda-const") << " standard : " << bvl << std::endl; + Trace("lambda-const") << " current : " << n[0] << std::endl; + } + }else{ + Trace("lambda-const") << "Non-constant lambda : " << n << " since it has no array representation." << std::endl; + } + return false; + } };/* class LambdaTypeRule */ class ChainTypeRule { diff --git a/test/regress/regress0/print_lambda.cvc b/test/regress/regress0/print_lambda.cvc index aee61a533..548623954 100644 --- a/test/regress/regress0/print_lambda.cvc +++ b/test/regress/regress0/print_lambda.cvc @@ -1,6 +1,7 @@ +% SCRUBBER: sed -e 's/f : (INT) -> INT = (LAMBDA(.*:INT): 0);$/f : (INT) -> INT = (LAMBDA(VAR:INT): 0);/' % COMMAND-LINE: --produce-models % EXPECT: sat -% EXPECT: f : (INT) -> INT = (LAMBDA(_ufmt_1:INT): 0); +% EXPECT: f : (INT) -> INT = (LAMBDA(VAR:INT): 0); f : INT -> INT; ASSERT f(1) = 0; -- 2.30.2