Add isConst check for lambda expressions. (#1084)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 14 Sep 2017 00:26:35 +0000 (19:26 -0500)
committerAina Niemetz <aina.niemetz@gmail.com>
Thu, 14 Sep 2017 00:26:35 +0000 (17:26 -0700)
Add isConst check for lambda expressions by conversions to and from an Array representation where isConst is implemented. This enables check-model to succeed on higher-order benchmarks. Change the builtin rewriter for lambda to attempt to put lambdas into constant form. Update regression.

src/expr/node_manager.cpp
src/expr/node_manager.h
src/theory/builtin/kinds
src/theory/builtin/theory_builtin_rewriter.cpp
src/theory/builtin/theory_builtin_rewriter.h
src/theory/builtin/theory_builtin_type_rules.h
test/regress/regress0/print_lambda.cvc

index 33f0572747083f7c93848a0927950956d6427397..85f5e3c7560c070d4e339dea08b86641c511797b 100644 (file)
@@ -83,6 +83,12 @@ struct NVReclaim {
 
 } // namespace
 
+namespace attr {
+  struct LambdaBoundVarListTag { };
+}/* CVC4::attr namespace */
+
+// attribute that stores the canonical bound variable list for function types
+typedef expr::Attribute<attr::LambdaBoundVarListTag, Node> LambdaBoundVarListAttr;
 
 NodeManager::NodeManager(ExprManager* exprManager) :
   d_options(new Options()),
@@ -692,6 +698,21 @@ Node* NodeManager::mkBoundVarPtr(const std::string& name,
   return n;
 }
 
+Node NodeManager::getBoundVarListForFunctionType( TypeNode tn ) {
+  Assert( tn.isFunction() );
+  Node bvl = tn.getAttribute(LambdaBoundVarListAttr());
+  if( bvl.isNull() ){
+    std::vector< Node > vars;
+    for( unsigned i=0; i<tn.getNumChildren()-1; i++ ){
+      vars.push_back( NodeManager::currentNM()->mkBoundVar( tn[i] ) );
+    }
+    bvl = NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, vars );
+    Trace("functions") << "Make standard bound var list " << bvl << " for " << tn << std::endl;
+    tn.setAttribute(LambdaBoundVarListAttr(),bvl);
+  }
+  return bvl;
+}
+
 Node NodeManager::mkVar(const TypeNode& type, uint32_t flags) {
   Node n = NodeBuilder<0>(this, kind::VARIABLE);
   setAttribute(n, TypeAttr(), type);
index b1b0bc974d7061b6ac36bd3c0124008dc8026172..d5d296579d7eea448223baaedae94cd157543a34 100644 (file)
@@ -491,6 +491,9 @@ public:
   Node mkBoundVar(const TypeNode& type);
   Node* mkBoundVarPtr(const TypeNode& type);
 
+  /** get the canonical bound variable list for function type tn */
+  static Node getBoundVarListForFunctionType( TypeNode tn );
+
   /**
    * Optional flags used to control behavior of NodeManager::mkSkolem().
    * They should be composed with a bitwise OR (e.g.,
index 12e8971892090451efe2aaba55ca6fa4921b43d0..6b7b952e2cca1dab0e397f856d775e642cc14ede 100644 (file)
@@ -336,4 +336,7 @@ typerule LAMBDA ::CVC4::theory::builtin::LambdaTypeRule
 typerule CHAIN ::CVC4::theory::builtin::ChainTypeRule
 typerule CHAIN_OP ::CVC4::theory::builtin::ChainedOperatorTypeRule
 
+# lambda expressions that are isomorphic to array constants can be considered constants
+construle LAMBDA ::CVC4::theory::builtin::LambdaTypeRule
+
 endtheory
index 32b35dfe850171fd68778c860e075dd5e2d39a48..57249e181629e6f985f2f8fe28de798057eb80db 100644 (file)
@@ -15,6 +15,7 @@
  ** \todo document this file
  **/
 
+#include "expr/attribute.h"
 #include "theory/builtin/theory_builtin_rewriter.h"
 
 #include "expr/chain.h"
@@ -70,6 +71,192 @@ Node TheoryBuiltinRewriter::blastChain(TNode in) {
   }
 }
 
+RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node) {
+  if( node.getKind()==kind::LAMBDA ){
+    Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl;
+    Node anode = getArrayRepresentationForLambda( node );
+    if( !anode.isNull() ){
+      anode = Rewriter::rewrite( anode );
+      Assert( anode.getType().isArray() );
+      //must get the standard bound variable list
+      Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType( node.getType() );
+      Node retNode = getLambdaForArrayRepresentation( anode, varList );
+      if( !retNode.isNull() && retNode!=node ){
+        Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl;
+        Trace("builtin-rewrite") << "     input  : " << node << std::endl;
+        Trace("builtin-rewrite") << "     output : " << retNode << ", constant = " << retNode.isConst() << std::endl;
+        Trace("builtin-rewrite") << "  array rep : " << anode << ", constant = " << anode.isConst() << std::endl;
+        Assert( anode.isConst()==retNode.isConst() );
+        Assert( retNode.getType()==node.getType() );
+        return RewriteResponse(REWRITE_DONE, retNode);
+      } 
+    }else{
+      Trace("builtin-rewrite-debug") << "...failed to get array representation." << std::endl;
+    }
+    return RewriteResponse(REWRITE_DONE, node);
+  }else{ 
+    return doRewrite(node);
+  }
+}
+
+Node TheoryBuiltinRewriter::getLambdaForArrayRepresentationRec( TNode a, TNode bvl, unsigned bvlIndex, 
+                                                                std::unordered_map< TNode, Node, TNodeHashFunction >& visited ){
+  std::unordered_map< TNode, Node, TNodeHashFunction >::iterator it = visited.find( a );
+  if( it==visited.end() ){
+    Node ret;
+    if( bvlIndex<bvl.getNumChildren() ){
+      Assert( a.getType().isArray() );
+      if( a.getKind()==kind::STORE ){
+        // convert the array recursively
+        Node body = getLambdaForArrayRepresentationRec( a[0], bvl, bvlIndex, visited );
+        if( !body.isNull() ){
+          // convert the value recursively (bounded by the number of arguments in bvl)
+          Node val = getLambdaForArrayRepresentationRec( a[2], bvl, bvlIndex+1, visited );
+          if( !val.isNull() ){
+            Assert( !TypeNode::leastCommonTypeNode( a[1].getType(), bvl[bvlIndex].getType() ).isNull() );
+            Assert( !TypeNode::leastCommonTypeNode( val.getType(), body.getType() ).isNull() );
+            Node cond = bvl[bvlIndex].eqNode( a[1] );
+            ret = NodeManager::currentNM()->mkNode( kind::ITE, cond, val, body );
+          }
+        }
+      }else if( a.getKind()==kind::STORE_ALL ){
+        ArrayStoreAll storeAll = a.getConst<ArrayStoreAll>();
+        Node sa = Node::fromExpr(storeAll.getExpr());
+        // convert the default value recursively (bounded by the number of arguments in bvl)
+        ret = getLambdaForArrayRepresentationRec( sa, bvl, bvlIndex+1, visited );
+      }
+    }else{
+      ret = a;
+    }
+    visited[a] = ret;
+    return ret;
+  }else{
+    return it->second;
+  }
+}
+
+Node TheoryBuiltinRewriter::getLambdaForArrayRepresentation( TNode a, TNode bvl ){
+  Assert( a.getType().isArray() );
+  std::unordered_map< TNode, Node, TNodeHashFunction > visited;
+  Trace("builtin-rewrite-debug") << "Get lambda for : " << a << ", with variables " << bvl << std::endl;
+  Node body = getLambdaForArrayRepresentationRec( a, bvl, 0, visited );
+  if( !body.isNull() ){
+    body = Rewriter::rewrite( body );
+    Trace("builtin-rewrite-debug") << "...got lambda body " << body << std::endl;
+    return NodeManager::currentNM()->mkNode( kind::LAMBDA, bvl, body );
+  }else{
+    Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl;
+    return Node::null();
+  }
+}
+
+Node TheoryBuiltinRewriter::getArrayRepresentationForLambda( TNode n, bool reqConst ){
+  Assert( n.getKind()==kind::LAMBDA );
+  Trace("builtin-rewrite-debug") << "Get array representation for : " << n << std::endl;
+
+  Node first_arg = n[0][0];
+  Node rec_bvl;
+  if( n[0].getNumChildren()>1 ){
+    std::vector< Node > args;
+    for( unsigned i=1; i<n[0].getNumChildren(); i++ ){
+      args.push_back( n[0][i] );
+    }
+    rec_bvl = NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, args );
+  }
+
+  Trace("builtin-rewrite-debug2") << "  process body..." << std::endl;
+  TypeNode retType;
+  std::vector< Node > conds;
+  std::vector< Node > vals;
+  Node curr = n[1];
+  while( curr.getKind()==kind::ITE || curr.getKind()==kind::EQUAL || curr.getKind()==kind::NOT ){
+    Trace("builtin-rewrite-debug2") << "  process condition : " << curr[0] << std::endl;
+    Node index_eq;
+    Node curr_val;
+    Node next;
+    if( curr.getKind()==kind::ITE ){
+      index_eq = curr[0];
+      curr_val = curr[1];
+      next = curr[2];
+    }else{
+      bool pol = curr.getKind()!=kind::NOT;
+      //Boolean case, e.g. lambda x. (= x v) is lambda x. (ite (= x v) true false)
+      index_eq = curr.getKind()==kind::NOT ? curr[0] : curr;
+      curr_val = NodeManager::currentNM()->mkConst( pol );
+      next = NodeManager::currentNM()->mkConst( !pol );
+    }
+    if( index_eq.getKind()!=kind::EQUAL ){
+      // non-equality condition
+      Trace("builtin-rewrite-debug2") << "  ...non-equality condition." << std::endl;
+      return Node::null();
+    }else if( reqConst && Rewriter::rewrite( index_eq )!=index_eq ){
+      // equality must be oriented correctly based on rewriter
+      Trace("builtin-rewrite-debug2") << "  ...equality not oriented properly." << std::endl;
+      return Node::null();
+    }
+
+    Node curr_index;
+    for( unsigned r=0; r<2; r++ ){
+      Node arg = index_eq[r];
+      Node val = index_eq[1-r];
+      if( arg==first_arg ){
+        if( reqConst && !val.isConst() ){
+          // non-constant value
+          Trace("builtin-rewrite-debug2") << "  ...non-constant value." << std::endl;
+          return Node::null();
+        }else{
+          curr_index = val;
+          Trace("builtin-rewrite-debug2") << "    " << arg << " -> " << val << std::endl;
+          break;
+        }
+      }
+    }
+    if( !curr_index.isNull() ){
+      if( !rec_bvl.isNull() ){
+        curr_val = NodeManager::currentNM()->mkNode( kind::LAMBDA, rec_bvl, curr_val );
+        curr_val = getArrayRepresentationForLambda( curr_val, reqConst );
+        if( curr_val.isNull() ){
+          Trace("builtin-rewrite-debug2") << "  ...non-constant value." << std::endl;
+          return Node::null();
+        }
+      }      
+      Trace("builtin-rewrite-debug2") << "  ...condition is index " << curr_val << std::endl;
+    }else{
+      Trace("builtin-rewrite-debug2") << "  ...non-constant value." << std::endl;
+      return Node::null();
+    }
+    conds.push_back( curr_index );
+    vals.push_back( curr_val );
+    TypeNode vtype = curr_val.getType();
+    retType = retType.isNull() ? vtype : TypeNode::leastCommonTypeNode( retType, vtype );
+    //recurse
+    curr = next;
+  }
+  if( !rec_bvl.isNull() ){
+    curr = NodeManager::currentNM()->mkNode( kind::LAMBDA, rec_bvl, curr );
+    curr = getArrayRepresentationForLambda( curr );
+  }
+  if( !curr.isNull() && curr.isConst() ){
+    TypeNode ctype = curr.getType();
+    retType = retType.isNull() ? ctype : TypeNode::leastCommonTypeNode( retType, ctype );
+    TypeNode array_type = NodeManager::currentNM()->mkArrayType( first_arg.getType(), retType );
+    curr = NodeManager::currentNM()->mkConst(ArrayStoreAll(((ArrayType)array_type.toType()), curr.toExpr()));
+    Trace("builtin-rewrite-debug2") << "  build array..." << std::endl;
+    // can only build if default value is constant (since array store all must be constant)
+    Trace("builtin-rewrite-debug2") << "  got constant base " << curr << std::endl;
+    // construct store chain
+    for( int i=((int)conds.size()-1); i>=0; i-- ){
+      Assert( conds[i].getType().isSubtypeOf( first_arg.getType() ) );
+      curr = NodeManager::currentNM()->mkNode( kind::STORE, curr, conds[i], vals[i] );
+    }
+    Trace("builtin-rewrite-debug") << "...got array " << curr << " for " << n << std::endl;
+    return curr;
+  }else{
+    Trace("builtin-rewrite-debug") << "...failed to get array (cannot get constant default value)" << std::endl;
+    return Node::null();    
+  }
+}
+
 }/* CVC4::theory::builtin namespace */
 }/* CVC4::theory namespace */
 }/* CVC4 namespace */
index 8ca2c538e9dfe1cbd6a3f4e1584f28753eb5276f..d9ae5b447b40ab3306c29f7ed1b5f5391a84a928 100644 (file)
@@ -45,9 +45,7 @@ public:
     }
   }
 
-  static inline RewriteResponse postRewrite(TNode node) {
-    return doRewrite(node);
-  }
+  static RewriteResponse postRewrite(TNode node);
 
   static inline RewriteResponse preRewrite(TNode node) {
     return doRewrite(node);
@@ -56,6 +54,37 @@ public:
   static inline void init() {}
   static inline void shutdown() {}
 
+// conversions between lambdas and arrays
+private:  
+  /** recursive helper for getLambdaForArrayRepresentation */
+  static Node getLambdaForArrayRepresentationRec( TNode a, TNode bvl, unsigned bvlIndex, 
+                                                  std::unordered_map< TNode, Node, TNodeHashFunction >& visited );
+public:
+  /** 
+   * Given an array constant a, returns a lambda expression that it corresponds to, with bound variable list bvl. 
+   * Examples:
+   *
+   * (store (storeall (Array Int Int) 2) 0 1) 
+   * becomes
+   * ((lambda x. (ite (= x 0) 1 2))
+   *
+   * (store (storeall (Array Int (Array Int Int)) (storeall (Array Int Int) 4)) 0 (store (storeall (Array Int Int) 3) 1 2))
+   * becomes
+   * (lambda xy. (ite (= x 0) (ite (= x 1) 2 3) 4))
+   *
+   * (store (store (storeall (Array Int Bool) false) 2 true) 1 true)
+   * becomes
+   * (lambda x. (ite (= x 1) true (ite (= x 2) true false)))
+   *
+   * Notice that the return body of the lambda is rewritten to ensure that the representation is canonical. Hence the last
+   * example will in fact be returned as:
+   * (lambda x. (ite (= x 1) true (= x 2)))
+   */
+  static Node getLambdaForArrayRepresentation( TNode a, TNode bvl );
+  /** given a lambda expression n, returns an array term. reqConst is true if we require the return value to be a constant. 
+    * This does the opposite direction of the examples described above.
+    */
+  static Node getArrayRepresentationForLambda( TNode n, bool reqConst = false );
 };/* class TheoryBuiltinRewriter */
 
 }/* CVC4::theory::builtin namespace */
index d8893d441fc36fa95c6232df1606e94921b50e3b..370e5d3488d66aef0def0ef62b9933929779a619 100644 (file)
@@ -23,6 +23,7 @@
 #include "expr/type_node.h"
 #include "expr/expr.h"
 #include "theory/rewriter.h"
+#include "theory/builtin/theory_builtin_rewriter.h" // for array and lambda representation
 
 #include <sstream>
 
@@ -161,6 +162,36 @@ public:
     TypeNode rangeType = n[1].getType(check);
     return nodeManager->mkFunctionType(argTypes, rangeType);
   }
+  // computes whether a lambda is a constant value, via conversion to array representation
+  inline static bool computeIsConst(NodeManager* nodeManager, TNode n)
+    throw (AssertionException) {
+    Assert(n.getKind() == kind::LAMBDA);
+    //get array representation of this function, if possible
+    Node na = TheoryBuiltinRewriter::getArrayRepresentationForLambda( n, true );
+    if( !na.isNull() ){
+      Assert( na.getType().isArray() );
+      Trace("lambda-const") << "Array representation for " << n << " is " << na << " " << na.getType() << std::endl;
+      // must have the standard bound variable list
+      Node bvl = NodeManager::currentNM()->getBoundVarListForFunctionType( n.getType() );
+      if( bvl==n[0] ){
+        //array must be constant
+        if( na.isConst() ){
+          Trace("lambda-const") << "*** Constant lambda : " << n;
+          Trace("lambda-const") << " since its array representation : " << na << " is constant." << std::endl;
+          return true;
+        }else{
+          Trace("lambda-const") << "Non-constant lambda : " << n << " since array is not constant." << std::endl;
+        } 
+      }else{
+        Trace("lambda-const") << "Non-constant lambda : " << n << " since its varlist is not standard." << std::endl;
+        Trace("lambda-const") << "  standard : " << bvl << std::endl;
+        Trace("lambda-const") << "   current : " << n[0] << std::endl;
+      } 
+    }else{
+      Trace("lambda-const") << "Non-constant lambda : " << n << " since it has no array representation." << std::endl;
+    } 
+    return false;
+  }
 };/* class LambdaTypeRule */
 
 class ChainTypeRule {
index aee61a5335a6ee17ade3d7d7feec30607520526c..548623954749a2d2dded5ef99bc8ef2a443902c7 100644 (file)
@@ -1,6 +1,7 @@
+% SCRUBBER: sed -e 's/f : (INT) -> INT = (LAMBDA(.*:INT): 0);$/f : (INT) -> INT = (LAMBDA(VAR:INT): 0);/'
 % COMMAND-LINE: --produce-models
 % EXPECT: sat
-% EXPECT: f : (INT) -> INT = (LAMBDA(_ufmt_1:INT): 0);
+% EXPECT: f : (INT) -> INT = (LAMBDA(VAR:INT): 0);
 
 f : INT -> INT;
 ASSERT f(1) = 0;