Assert(n.getKind() == kind::BAG_MAP);
if (n[1].isConst())
{
- // (bag.map f bag.empty) = bag.empty
+ // (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2))
// (bag.map f (bag "a" 3)) = (bag (f "a") 3)
std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
std::map<Node, Rational> mappedElements;
{
case BAG_MAKE:
{
+ // (bag.map f (bag x y)) = (bag (apply f x) y)
Node mappedElement = d_nm->mkNode(APPLY_UF, n[0], n[1][0]);
Node ret =
d_nm->mkBag(n[0].getType().getRangeType(), mappedElement, n[1][1]);
case BAG_UNION_DISJOINT:
{
- Node a = d_nm->mkNode(BAG_MAP, n[1][0]);
- Node b = d_nm->mkNode(BAG_MAP, n[1][1]);
+ // (bag.map f (bag.union_disjoint A B)) =
+ // (bag.union_disjoint (bag.map f A) (bag.map f B))
+ Node a = d_nm->mkNode(BAG_MAP, n[0], n[1][0]);
+ Node b = d_nm->mkNode(BAG_MAP, n[0], n[1][1]);
Node ret = d_nm->mkNode(BAG_UNION_DISJOINT, a, b);
return BagsRewriteResponse(ret, Rewrite::MAP_UNION_DISJOINT);
}
/**
* rewrites for n include:
- * - (bag.map (lambda ((x U)) t) bag.empty) = bag.empty
- * - (bag.map (lambda ((x U)) t) (bag y z)) = (bag (apply (lambda ((x U)) t)
- * y) z)
- * - (bag.map (lambda ((x U)) t) (bag.union_disjoint A B)) =
- * (bag.union_disjoint
- * (bag ((lambda ((x U)) t) "a") 3)
- * (bag ((lambda ((x U)) t) "b") 4))
- *
+ * - (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2))
+ * - (bag.map f (bag x y)) = (bag (apply f x) y)
+ * - (bag.map f (bag.union_disjoint A B)) =
+ * (bag.union_disjoint (bag.map f A) (bag.map f B))
+ * where f: T1 -> T2
*/
BagsRewriteResponse postRewriteMap(const TNode& n) const;
d_nodeManager->mkBag(d_nodeManager->stringType(),
empty,
d_nodeManager->mkConst(CONST_RATIONAL, Rational(7)));
- ASSERT_TRUE(rewritten == bag);
+ // - (bag.map f (bag.union_disjoint K1 K2)) =
+ // (bag.union_disjoint (bag.map f K1) (bag.map f K2))
+ Node k1 = d_skolemManager->mkDummySkolem("K1", A.getType());
+ Node k2 = d_skolemManager->mkDummySkolem("K2", A.getType());
+ Node f = d_skolemManager->mkDummySkolem("f", lambda.getType());
+ Node unionDisjointK1K2 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, k1, k2);
+ Node n3 = d_nodeManager->mkNode(BAG_MAP, f, unionDisjointK1K2);
+ Node rewritten3 = Rewriter::rewrite(n3);
+ Node mapK1 = d_nodeManager->mkNode(BAG_MAP, f, k1);
+ Node mapK2 = d_nodeManager->mkNode(BAG_MAP, f, k2);
+ Node unionDisjointMapK1K2 =
+ d_nodeManager->mkNode(BAG_UNION_DISJOINT, mapK1, mapK2);
+ ASSERT_TRUE(rewritten3 == unionDisjointMapK1K2);
}
} // namespace test