Fix rewriter for lambda (#2211)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 26 Jul 2018 19:48:40 +0000 (14:48 -0500)
committerAndres Noetzli <andres.noetzli@gmail.com>
Thu, 26 Jul 2018 19:48:40 +0000 (12:48 -0700)
The rewriter for lambda is currently too aggressive, there are cases like:

lambda xy. x = y

that are converted into an array representation that when indexing based on x gives (store y true false), which is subsequently converted to:

lambda fv_1 fv_2. fv_1 = y

where fv_1 and fv_2 are canonical free variables. Here, y is used as index but was not substituted hence is incorrectly made free.

To make things simpler, this PR disables any rewriting for lambda unless the array representation of the lambda is a constant, which hardcodes/simplifies a previous argument (reqConst=true). This fixes a sygus issue I ran into yesterday (regression added in this PR).

Some parts of the code were formatted as a result.

src/theory/builtin/theory_builtin_rewriter.cpp
src/theory/builtin/theory_builtin_rewriter.h
src/theory/builtin/theory_builtin_type_rules.h
test/regress/Makefile.tests
test/regress/regress1/sygus/sygus-lambda-fv.sy [new file with mode: 0644]

index 3228d55f6efd86fcc74e56a5a153acbdcf366e41..da28b1ffdc6cf40d2abae72e31ba3c0e6b3d795d 100644 (file)
@@ -88,6 +88,7 @@ RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node) {
         Trace("builtin-rewrite") << "  array rep : " << anode << ", constant = " << anode.isConst() << std::endl;
         Assert( anode.isConst()==retNode.isConst() );
         Assert( retNode.getType()==node.getType() );
+        Assert(node.hasFreeVar() == retNode.hasFreeVar());
         return RewriteResponse(REWRITE_DONE, retNode);
       } 
     }else{
@@ -192,7 +193,9 @@ Node TheoryBuiltinRewriter::getLambdaForArrayRepresentation( TNode a, TNode bvl
   }
 }
 
-Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec( TNode n, bool reqConst, TypeNode retType ){
+Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec(TNode n,
+                                                               TypeNode retType)
+{
   Assert( n.getKind()==kind::LAMBDA );
   Trace("builtin-rewrite-debug") << "Get array representation for : " << n << std::endl;
 
@@ -230,7 +233,9 @@ Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec( TNode n, bool re
       // non-equality condition
       Trace("builtin-rewrite-debug2") << "  ...non-equality condition." << std::endl;
       return Node::null();
-    }else if( reqConst && Rewriter::rewrite( index_eq )!=index_eq ){
+    }
+    else if (Rewriter::rewrite(index_eq) != index_eq)
+    {
       // equality must be oriented correctly based on rewriter
       Trace("builtin-rewrite-debug2") << "  ...equality not oriented properly." << std::endl;
       return Node::null();
@@ -241,7 +246,8 @@ Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec( TNode n, bool re
       Node arg = index_eq[r];
       Node val = index_eq[1-r];
       if( arg==first_arg ){
-        if( reqConst && !val.isConst() ){
+        if (!val.isConst())
+        {
           // non-constant value
           Trace("builtin-rewrite-debug2") << "  ...non-constant value." << std::endl;
           return Node::null();
@@ -255,7 +261,7 @@ Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec( TNode n, bool re
     if( !curr_index.isNull() ){
       if( !rec_bvl.isNull() ){
         curr_val = NodeManager::currentNM()->mkNode( kind::LAMBDA, rec_bvl, curr_val );
-        curr_val = getArrayRepresentationForLambdaRec( curr_val, reqConst, retType );
+        curr_val = getArrayRepresentationForLambdaRec(curr_val, retType);
         if( curr_val.isNull() ){
           Trace("builtin-rewrite-debug2") << "  ...non-constant value." << std::endl;
           return Node::null();
@@ -274,7 +280,7 @@ Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec( TNode n, bool re
   }
   if( !rec_bvl.isNull() ){
     curr = NodeManager::currentNM()->mkNode( kind::LAMBDA, rec_bvl, curr );
-    curr = getArrayRepresentationForLambdaRec( curr, reqConst, retType );
+    curr = getArrayRepresentationForLambdaRec(curr, retType);
   }
   if( !curr.isNull() && curr.isConst() ){
     // compute the return type
@@ -302,11 +308,12 @@ Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec( TNode n, bool re
   }
 }
 
-Node TheoryBuiltinRewriter::getArrayRepresentationForLambda( TNode n, bool reqConst ){
+Node TheoryBuiltinRewriter::getArrayRepresentationForLambda(TNode n)
+{
   Assert( n.getKind()==kind::LAMBDA );
   // must carry the overall return type to deal with cases like (lambda ((x Int)(y Int)) (ite (= x _) 0.5 0.0)),
   //  where the inner construction for the else case about should be (arraystoreall (Array Int Real) 0.0)
-  return getArrayRepresentationForLambdaRec( n, reqConst, n[1].getType() );
+  return getArrayRepresentationForLambdaRec(n, n[1].getType());
 }
 
 }/* CVC4::theory::builtin namespace */
index 79ae825e994786cf20ed15cd46ed38ec3c6960e5..8f45cc0fd23fed4ac6c21b1803cecf2c60deecb8 100644 (file)
@@ -60,60 +60,70 @@ private:
   static Node getLambdaForArrayRepresentationRec( TNode a, TNode bvl, unsigned bvlIndex, 
                                                   std::unordered_map< TNode, Node, TNodeHashFunction >& visited );
   /** recursive helper for getArrayRepresentationForLambda */
-  static Node getArrayRepresentationForLambdaRec( TNode n, bool reqConst, TypeNode retType );
-public:
- /** Get function type for array type
-  *
-  * This returns the function type of terms returned by the function
-  * getLambdaForArrayRepresentation( t, bvl ),
-  * where t.getType()=atn.
-  *
-  * bvl should be a bound variable list whose variables correspond in-order
-  * to the index types of the (curried) Array type. For example, a bound
-  * variable list bvl whose variables have types (Int, Real) can be given as
-  * input when paired with atn = (Array Int (Array Real Bool)), or (Array Int
-  * (Array Real (Array Bool Bool))). This function returns (-> Int Real Bool)
-  * and (-> Int Real (Array Bool Bool)) respectively in these cases.
-  * On the other hand, the above bvl is not a proper input for
-  * atn = (Array Int (Array Bool Bool)) or (Array Int Int).
-  * If the types of bvl and atn do not match, we throw an assertion failure.
-  */
- static TypeNode getFunctionTypeForArrayType(TypeNode atn, Node bvl);
- /** Get array type for function type
-  *
-  * This returns the array type of terms returned by
-  * getArrayRepresentationForLambda( t ), where t.getType()=ftn.
-  */
- static TypeNode getArrayTypeForFunctionType(TypeNode ftn);
- /**
-  * Given an array constant a, returns a lambda expression that it corresponds
-  * to, with bound variable list bvl.
-  * Examples:
-  *
-  * (store (storeall (Array Int Int) 2) 0 1)
-  * becomes
-  * ((lambda x. (ite (= x 0) 1 2))
-  *
-  * (store (storeall (Array Int (Array Int Int)) (storeall (Array Int Int) 4)) 0
-  * (store (storeall (Array Int Int) 3) 1 2))
-  * becomes
-  * (lambda xy. (ite (= x 0) (ite (= x 1) 2 3) 4))
-  *
-  * (store (store (storeall (Array Int Bool) false) 2 true) 1 true)
-  * becomes
-  * (lambda x. (ite (= x 1) true (ite (= x 2) true false)))
-  *
-  * Notice that the return body of the lambda is rewritten to ensure that the
-  * representation is canonical. Hence the last
-  * example will in fact be returned as:
-  * (lambda x. (ite (= x 1) true (= x 2)))
-  */
- static Node getLambdaForArrayRepresentation(TNode a, TNode bvl);
- /** given a lambda expression n, returns an array term. reqConst is true if we
-  * require the return value to be a constant.
+  static Node getArrayRepresentationForLambdaRec(TNode n, TypeNode retType);
+
+ public:
+  /** Get function type for array type
+   *
+   * This returns the function type of terms returned by the function
+   * getLambdaForArrayRepresentation( t, bvl ),
+   * where t.getType()=atn.
+   *
+   * bvl should be a bound variable list whose variables correspond in-order
+   * to the index types of the (curried) Array type. For example, a bound
+   * variable list bvl whose variables have types (Int, Real) can be given as
+   * input when paired with atn = (Array Int (Array Real Bool)), or (Array Int
+   * (Array Real (Array Bool Bool))). This function returns (-> Int Real Bool)
+   * and (-> Int Real (Array Bool Bool)) respectively in these cases.
+   * On the other hand, the above bvl is not a proper input for
+   * atn = (Array Int (Array Bool Bool)) or (Array Int Int).
+   * If the types of bvl and atn do not match, we throw an assertion failure.
+   */
+  static TypeNode getFunctionTypeForArrayType(TypeNode atn, Node bvl);
+  /** Get array type for function type
+   *
+   * This returns the array type of terms returned by
+   * getArrayRepresentationForLambda( t ), where t.getType()=ftn.
+   */
+  static TypeNode getArrayTypeForFunctionType(TypeNode ftn);
+  /**
+   * Given an array constant a, returns a lambda expression that it corresponds
+   * to, with bound variable list bvl.
+   * Examples:
+   *
+   * (store (storeall (Array Int Int) 2) 0 1)
+   * becomes
+   * ((lambda x. (ite (= x 0) 1 2))
+   *
+   * (store (storeall (Array Int (Array Int Int)) (storeall (Array Int Int) 4))
+   * 0 (store (storeall (Array Int Int) 3) 1 2)) becomes (lambda xy. (ite (= x
+   * 0) (ite (= x 1) 2 3) 4))
+   *
+   * (store (store (storeall (Array Int Bool) false) 2 true) 1 true)
+   * becomes
+   * (lambda x. (ite (= x 1) true (ite (= x 2) true false)))
+   *
+   * Notice that the return body of the lambda is rewritten to ensure that the
+   * representation is canonical. Hence the last
+   * example will in fact be returned as:
+   * (lambda x. (ite (= x 1) true (= x 2)))
+   */
 static Node getLambdaForArrayRepresentation(TNode a, TNode bvl);
+  /**
+   * Given a lambda expression n, returns an array term that corresponds to n.
    * This does the opposite direction of the examples described above.
+   *
+   * We limit the return values of this method to be almost constant functions,
+   * that is, arrays of the form:
+   *   (store ... (store (storeall _ b) i1 e1) ... in en)
+   * where b, i1, e1, ..., in, en are constants.
+   * Notice however that the return value of this form need not be a (canonical)
+   * array constant.
+   *
+   * If it is not possible to construct an array of this form that corresponds
+   * to n, this method returns null.
    */
static Node getArrayRepresentationForLambda(TNode n, bool reqConst = false);
 static Node getArrayRepresentationForLambda(TNode n);
 };/* class TheoryBuiltinRewriter */
 
 }/* CVC4::theory::builtin namespace */
index bd3e5faa4ed6dc83da1f2195c1a4a72918a001a3..c471caf867a34b2a3fc91305f2b9e6f46bcea8cb 100644 (file)
@@ -170,7 +170,7 @@ class LambdaTypeRule {
   {
     Assert(n.getKind() == kind::LAMBDA);
     //get array representation of this function, if possible
-    Node na = TheoryBuiltinRewriter::getArrayRepresentationForLambda( n, true );
+    Node na = TheoryBuiltinRewriter::getArrayRepresentationForLambda(n);
     if( !na.isNull() ){
       Assert( na.getType().isArray() );
       Trace("lambda-const") << "Array representation for " << n << " is " << na << " " << na.getType() << std::endl;
index 48b290bfe7202f451d0fc89227467e7d793ccee5..c43e083f38884d3769cf007e026493a4bcf5bc4b 100644 (file)
@@ -1547,6 +1547,7 @@ REG1_TESTS = \
        regress1/sygus/strings-trivial-two-type.sy \
        regress1/sygus/strings-trivial.sy \
        regress1/sygus/sygus-dt.sy \
+       regress1/sygus/sygus-lambda-fv.sy \
        regress1/sygus/sygus-uf-ex.sy \
        regress1/sygus/t8.sy \
        regress1/sygus/tl-type-0.sy \
diff --git a/test/regress/regress1/sygus/sygus-lambda-fv.sy b/test/regress/regress1/sygus/sygus-lambda-fv.sy
new file mode 100644 (file)
index 0000000..d2a3700
--- /dev/null
@@ -0,0 +1,21 @@
+; EXPECT: unsat
+; COMMAND-LINE: --sygus-out=status
+(set-logic ALL)
+
+(synth-fun SC ((y (BitVec 32)) (w (BitVec 32)) ) (BitVec 32)
+  (
+   (Start (BitVec 32) (
+     y
+     w
+     #x00000000
+     (bvadd Start Start)
+     (ite StartBool Start Start)
+   ))
+   (StartBool Bool ((= Start #x10000000) (= Start #x00000000)))
+))
+
+(constraint (= (SC #x00000000 #x00001000) #x00001000))
+(constraint (= (SC #x00001000 #x00001000) #x00001000))
+(constraint (= (SC #x01001000 #x00001000) #x01001000))
+
+(check-synth)