Fix type error for rewriting bag.map bag.union_disjoint (#7640)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Sat, 13 Nov 2021 19:33:34 +0000 (13:33 -0600)
committerGitHub <noreply@github.com>
Sat, 13 Nov 2021 19:33:34 +0000 (19:33 +0000)
Fix type error for rewriting bag.map bag.union_disjoint

src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
test/unit/theory/theory_bags_rewriter_white.cpp

index f9cc990f407c8e6dbc3bc942e5ec9385ef9e2f42..7093d52fcf51578e6e9ce83243991091983040b2 100644 (file)
@@ -521,7 +521,7 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
   Assert(n.getKind() == kind::BAG_MAP);
   if (n[1].isConst())
   {
-    // (bag.map f bag.empty) = bag.empty
+    // (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2))
     // (bag.map f (bag "a" 3)) = (bag (f "a") 3)
     std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
     std::map<Node, Rational> mappedElements;
@@ -541,6 +541,7 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
   {
     case BAG_MAKE:
     {
+      // (bag.map f (bag x y)) = (bag (apply f x) y)
       Node mappedElement = d_nm->mkNode(APPLY_UF, n[0], n[1][0]);
       Node ret =
           d_nm->mkBag(n[0].getType().getRangeType(), mappedElement, n[1][1]);
@@ -549,8 +550,10 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
 
     case BAG_UNION_DISJOINT:
     {
-      Node a = d_nm->mkNode(BAG_MAP, n[1][0]);
-      Node b = d_nm->mkNode(BAG_MAP, n[1][1]);
+      // (bag.map f (bag.union_disjoint A B)) =
+      //    (bag.union_disjoint (bag.map f A) (bag.map f B))
+      Node a = d_nm->mkNode(BAG_MAP, n[0], n[1][0]);
+      Node b = d_nm->mkNode(BAG_MAP, n[0], n[1][1]);
       Node ret = d_nm->mkNode(BAG_UNION_DISJOINT, a, b);
       return BagsRewriteResponse(ret, Rewrite::MAP_UNION_DISJOINT);
     }
index 36958a491dd363d7fb47745d1413660b1e82b934..a938b3bd49473ec7877ecc47ba4c68142de583aa 100644 (file)
@@ -214,14 +214,11 @@ class BagsRewriter : public TheoryRewriter
 
   /**
    *  rewrites for n include:
-   *  - (bag.map (lambda ((x U)) t) bag.empty) = bag.empty
-   *  - (bag.map (lambda ((x U)) t) (bag y z)) = (bag (apply (lambda ((x U)) t)
-   * y) z)
-   *  - (bag.map (lambda ((x U)) t) (bag.union_disjoint A B)) =
-   *       (bag.union_disjoint
-   *          (bag ((lambda ((x U)) t) "a") 3)
-   *          (bag ((lambda ((x U)) t) "b") 4))
-   *
+   *  - (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2))
+   *  - (bag.map f (bag x y)) = (bag (apply f x) y)
+   *  - (bag.map f (bag.union_disjoint A B)) =
+   *       (bag.union_disjoint (bag.map f A) (bag.map f B))
+   *  where f: T1 -> T2
    */
   BagsRewriteResponse postRewriteMap(const TNode& n) const;
 
index ca142c6b9c5145b669fea5256ab2a1af79c1744f..ee1e894482974b53b8a665e45a61eae3df04bc33 100644 (file)
@@ -785,7 +785,19 @@ TEST_F(TestTheoryWhiteBagsRewriter, map)
       d_nodeManager->mkBag(d_nodeManager->stringType(),
                            empty,
                            d_nodeManager->mkConst(CONST_RATIONAL, Rational(7)));
-  ASSERT_TRUE(rewritten == bag);
+  //  - (bag.map f (bag.union_disjoint K1 K2)) =
+  //      (bag.union_disjoint (bag.map f K1) (bag.map f K2))
+  Node k1 = d_skolemManager->mkDummySkolem("K1", A.getType());
+  Node k2 = d_skolemManager->mkDummySkolem("K2", A.getType());
+  Node f = d_skolemManager->mkDummySkolem("f", lambda.getType());
+  Node unionDisjointK1K2 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, k1, k2);
+  Node n3 = d_nodeManager->mkNode(BAG_MAP, f, unionDisjointK1K2);
+  Node rewritten3 = Rewriter::rewrite(n3);
+  Node mapK1 = d_nodeManager->mkNode(BAG_MAP, f, k1);
+  Node mapK2 = d_nodeManager->mkNode(BAG_MAP, f, k2);
+  Node unionDisjointMapK1K2 =
+      d_nodeManager->mkNode(BAG_UNION_DISJOINT, mapK1, mapK2);
+  ASSERT_TRUE(rewritten3 == unionDisjointMapK1K2);
 }
 
 }  // namespace test