From 805205a2047eeae7842b1c534859b52fa204ee0e Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Sat, 13 Nov 2021 13:33:34 -0600 Subject: [PATCH] Fix type error for rewriting bag.map bag.union_disjoint (#7640) Fix type error for rewriting bag.map bag.union_disjoint --- src/theory/bags/bags_rewriter.cpp | 9 ++++++--- src/theory/bags/bags_rewriter.h | 13 +++++-------- test/unit/theory/theory_bags_rewriter_white.cpp | 14 +++++++++++++- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index f9cc990f4..7093d52fc 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -521,7 +521,7 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const Assert(n.getKind() == kind::BAG_MAP); if (n[1].isConst()) { - // (bag.map f bag.empty) = bag.empty + // (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2)) // (bag.map f (bag "a" 3)) = (bag (f "a") 3) std::map elements = NormalForm::getBagElements(n[1]); std::map mappedElements; @@ -541,6 +541,7 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const { case BAG_MAKE: { + // (bag.map f (bag x y)) = (bag (apply f x) y) Node mappedElement = d_nm->mkNode(APPLY_UF, n[0], n[1][0]); Node ret = d_nm->mkBag(n[0].getType().getRangeType(), mappedElement, n[1][1]); @@ -549,8 +550,10 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const case BAG_UNION_DISJOINT: { - Node a = d_nm->mkNode(BAG_MAP, n[1][0]); - Node b = d_nm->mkNode(BAG_MAP, n[1][1]); + // (bag.map f (bag.union_disjoint A B)) = + // (bag.union_disjoint (bag.map f A) (bag.map f B)) + Node a = d_nm->mkNode(BAG_MAP, n[0], n[1][0]); + Node b = d_nm->mkNode(BAG_MAP, n[0], n[1][1]); Node ret = d_nm->mkNode(BAG_UNION_DISJOINT, a, b); return BagsRewriteResponse(ret, Rewrite::MAP_UNION_DISJOINT); } diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index 36958a491..a938b3bd4 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -214,14 +214,11 @@ class BagsRewriter : public TheoryRewriter /** * rewrites for n include: - * - (bag.map (lambda ((x U)) t) bag.empty) = bag.empty - * - (bag.map (lambda ((x U)) t) (bag y z)) = (bag (apply (lambda ((x U)) t) - * y) z) - * - (bag.map (lambda ((x U)) t) (bag.union_disjoint A B)) = - * (bag.union_disjoint - * (bag ((lambda ((x U)) t) "a") 3) - * (bag ((lambda ((x U)) t) "b") 4)) - * + * - (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2)) + * - (bag.map f (bag x y)) = (bag (apply f x) y) + * - (bag.map f (bag.union_disjoint A B)) = + * (bag.union_disjoint (bag.map f A) (bag.map f B)) + * where f: T1 -> T2 */ BagsRewriteResponse postRewriteMap(const TNode& n) const; diff --git a/test/unit/theory/theory_bags_rewriter_white.cpp b/test/unit/theory/theory_bags_rewriter_white.cpp index ca142c6b9..ee1e89448 100644 --- a/test/unit/theory/theory_bags_rewriter_white.cpp +++ b/test/unit/theory/theory_bags_rewriter_white.cpp @@ -785,7 +785,19 @@ TEST_F(TestTheoryWhiteBagsRewriter, map) d_nodeManager->mkBag(d_nodeManager->stringType(), empty, d_nodeManager->mkConst(CONST_RATIONAL, Rational(7))); - ASSERT_TRUE(rewritten == bag); + // - (bag.map f (bag.union_disjoint K1 K2)) = + // (bag.union_disjoint (bag.map f K1) (bag.map f K2)) + Node k1 = d_skolemManager->mkDummySkolem("K1", A.getType()); + Node k2 = d_skolemManager->mkDummySkolem("K2", A.getType()); + Node f = d_skolemManager->mkDummySkolem("f", lambda.getType()); + Node unionDisjointK1K2 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, k1, k2); + Node n3 = d_nodeManager->mkNode(BAG_MAP, f, unionDisjointK1K2); + Node rewritten3 = Rewriter::rewrite(n3); + Node mapK1 = d_nodeManager->mkNode(BAG_MAP, f, k1); + Node mapK2 = d_nodeManager->mkNode(BAG_MAP, f, k2); + Node unionDisjointMapK1K2 = + d_nodeManager->mkNode(BAG_UNION_DISJOINT, mapK1, mapK2); + ASSERT_TRUE(rewritten3 == unionDisjointMapK1K2); } } // namespace test -- 2.30.2