This PR implements rewrite rules for bags. This PR focuses on rewrite rules for non constant nodes.
Rewriting nodes with constant children is delegated to bags::NormalForm class (future PR).
theory/assertion.h
theory/atom_requests.cpp
theory/atom_requests.h
+ theory/bags/bags_rewriter.cpp
+ theory/bags/bags_rewriter.h
+ theory/bags/bags_statistics.cpp
+ theory/bags/bags_statistics.h
theory/bags/inference_manager.cpp
theory/bags/inference_manager.h
theory/bags/normal_form.cpp
theory/bags/normal_form.h
+ theory/bags/rewrites.cpp
+ theory/bags/rewrites.h
theory/bags/solver_state.cpp
theory/bags/solver_state.h
theory/bags/term_registry.cpp
theory/bags/term_registry.h
theory/bags/theory_bags.cpp
theory/bags/theory_bags.h
- theory/bags/theory_bags_rewriter.cpp
- theory/bags/theory_bags_rewriter.h
theory/bags/theory_bags_type_enumerator.cpp
theory/bags/theory_bags_type_enumerator.h
theory/bags/theory_bags_type_rules.h
--- /dev/null
+/********************* */
+/*! \file bags_rewriter.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Bags theory rewriter.
+ **/
+
+#include "theory/bags/bags_rewriter.h"
+
+#include "normal_form.h"
+
+using namespace CVC4::kind;
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+BagsRewriteResponse::BagsRewriteResponse()
+ : d_node(Node::null()), d_rewrite(Rewrite::NONE)
+{
+}
+
+BagsRewriteResponse::BagsRewriteResponse(Node n, Rewrite rewrite)
+ : d_node(n), d_rewrite(rewrite)
+{
+}
+
+BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r)
+ : d_node(r.d_node), d_rewrite(r.d_rewrite)
+{
+}
+
+BagsRewriter::BagsRewriter(HistogramStat<Rewrite>* statistics)
+ : d_statistics(statistics)
+{
+ d_nm = NodeManager::currentNM();
+}
+
+RewriteResponse BagsRewriter::postRewrite(TNode n)
+{
+ BagsRewriteResponse response;
+ if (n.isConst())
+ {
+ // no need to rewrite n if it is already in a normal form
+ response = BagsRewriteResponse(n, Rewrite::NONE);
+ }
+ else if (NormalForm::AreChildrenConstants(n))
+ {
+ Node value = NormalForm::evaluate(n);
+ response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
+ }
+ else
+ {
+ Kind k = n.getKind();
+ switch (k)
+ {
+ case MK_BAG: response = rewriteMakeBag(n); break;
+ case BAG_COUNT: response = rewriteBagCount(n); break;
+ case UNION_MAX: response = rewriteUnionMax(n); break;
+ case UNION_DISJOINT: response = rewriteUnionDisjoint(n); break;
+ case INTERSECTION_MIN: response = rewriteIntersectionMin(n); break;
+ case DIFFERENCE_SUBTRACT: response = rewriteDifferenceSubtract(n); break;
+ case DIFFERENCE_REMOVE: response = rewriteDifferenceRemove(n); break;
+ case BAG_CHOOSE: response = rewriteChoose(n); break;
+ case BAG_CARD: response = rewriteCard(n); break;
+ case BAG_IS_SINGLETON: response = rewriteIsSingleton(n); break;
+ default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
+ }
+ }
+
+ Trace("bags-rewrite") << "postRewrite " << n << " to " << response.d_node
+ << " by " << response.d_rewrite << "." << std::endl;
+
+ if (d_statistics != nullptr)
+ {
+ (*d_statistics) << response.d_rewrite;
+ }
+ if (response.d_node != n)
+ {
+ return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
+ }
+ return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
+}
+
+RewriteResponse BagsRewriter::preRewrite(TNode n)
+{
+ BagsRewriteResponse response;
+ Kind k = n.getKind();
+ switch (k)
+ {
+ case EQUAL: response = rewriteEqual(n); break;
+ case BAG_IS_INCLUDED: response = rewriteIsIncluded(n); break;
+ default: response = BagsRewriteResponse(n, Rewrite::NONE);
+ }
+
+ Trace("bags-rewrite") << "preRewrite " << n << " to " << response.d_node
+ << " by " << response.d_rewrite << "." << std::endl;
+
+ if (d_statistics != nullptr)
+ {
+ (*d_statistics) << response.d_rewrite;
+ }
+ if (response.d_node != n)
+ {
+ return RewriteResponse(RewriteStatus::REWRITE_AGAIN_FULL, response.d_node);
+ }
+ return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteEqual(const TNode& n) const
+{
+ Assert(n.getKind() == EQUAL);
+ if (n[0] == n[1])
+ {
+ // (= A A) = true where A is a bag
+ return BagsRewriteResponse(d_nm->mkConst(true), Rewrite::IDENTICAL_NODES);
+ }
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteIsIncluded(const TNode& n) const
+{
+ Assert(n.getKind() == BAG_IS_INCLUDED);
+
+ // (bag.is_included A B) = ((difference_subtract A B) == emptybag)
+ Node emptybag = d_nm->mkConst(EmptyBag(n[0].getType()));
+ Node subtract = d_nm->mkNode(DIFFERENCE_SUBTRACT, n[0], n[1]);
+ Node equal = subtract.eqNode(emptybag);
+ return BagsRewriteResponse(equal, Rewrite::SUB_BAG);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteMakeBag(const TNode& n) const
+{
+ Assert(n.getKind() == MK_BAG);
+ // return emptybag for negative or zero multiplicity
+ if (n[1].isConst() && n[1].getConst<Rational>().sgn() != 1)
+ {
+ // (mkBag x c) = emptybag where c <= 0
+ Node emptybag = d_nm->mkConst(EmptyBag(n.getType()));
+ return BagsRewriteResponse(emptybag, Rewrite::MK_BAG_COUNT_NEGATIVE);
+ }
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
+{
+ Assert(n.getKind() == BAG_COUNT);
+ if (n[1].isConst() && n[1].getKind() == EMPTYBAG)
+ {
+ // (bag.count x emptybag) = 0
+ return BagsRewriteResponse(d_nm->mkConst(Rational(0)),
+ Rewrite::COUNT_EMPTY);
+ }
+ if (n[1].getKind() == MK_BAG && n[0] == n[1][0])
+ {
+ // (bag.count x (mkBag x c) = c where c > 0 is a constant
+ return BagsRewriteResponse(n[1][1], Rewrite::COUNT_MK_BAG);
+ }
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteUnionMax(const TNode& n) const
+{
+ Assert(n.getKind() == UNION_MAX);
+ if (n[1].getKind() == EMPTYBAG || n[0] == n[1])
+ {
+ // (union_max A A) = A
+ // (union_max A emptybag) = A
+ return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_SAME_OR_EMPTY);
+ }
+ if (n[0].getKind() == EMPTYBAG)
+ {
+ // (union_max emptybag A) = A
+ return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_EMPTY);
+ }
+
+ if ((n[1].getKind() == UNION_MAX || n[1].getKind() == UNION_DISJOINT)
+ && (n[0] == n[1][0] || n[0] == n[1][1]))
+ {
+ // (union_max A (union_max A B)) = (union_max A B)
+ // (union_max A (union_max B A)) = (union_max B A)
+ // (union_max A (union_disjoint A B)) = (union_disjoint A B)
+ // (union_max A (union_disjoint B A)) = (union_disjoint B A)
+ return BagsRewriteResponse(n[1], Rewrite::UNION_MAX_UNION_LEFT);
+ }
+
+ if ((n[0].getKind() == UNION_MAX || n[0].getKind() == UNION_DISJOINT)
+ && (n[0][0] == n[1] || n[0][1] == n[1]))
+ {
+ // (union_max (union_max A B) A)) = (union_max A B)
+ // (union_max (union_max B A) A)) = (union_max B A)
+ // (union_max (union_disjoint A B) A)) = (union_disjoint A B)
+ // (union_max (union_disjoint B A) A)) = (union_disjoint B A)
+ return BagsRewriteResponse(n[0], Rewrite::UNION_MAX_UNION_RIGHT);
+ }
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
+{
+ Assert(n.getKind() == UNION_DISJOINT);
+ if (n[1].getKind() == EMPTYBAG)
+ {
+ // (union_disjoint A emptybag) = A
+ return BagsRewriteResponse(n[0], Rewrite::UNION_DISJOINT_EMPTY_RIGHT);
+ }
+ if (n[0].getKind() == EMPTYBAG)
+ {
+ // (union_disjoint emptybag A) = A
+ return BagsRewriteResponse(n[1], Rewrite::UNION_DISJOINT_EMPTY_LEFT);
+ }
+ if ((n[0].getKind() == UNION_MAX && n[1].getKind() == INTERSECTION_MIN)
+ || (n[1].getKind() == UNION_MAX && n[0].getKind() == INTERSECTION_MIN))
+
+ {
+ // (union_disjoint (union_max A B) (intersection_min A B)) =
+ // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
+ // check if the operands of union_max and intersection_min are the same
+ std::set<Node> left(n[0].begin(), n[0].end());
+ std::set<Node> right(n[0].begin(), n[0].end());
+ if (left == right)
+ {
+ Node rewritten = d_nm->mkNode(UNION_DISJOINT, n[0][0], n[0][1]);
+ return BagsRewriteResponse(rewritten, Rewrite::UNION_DISJOINT_MAX_MIN);
+ }
+ }
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteIntersectionMin(const TNode& n) const
+{
+ Assert(n.getKind() == INTERSECTION_MIN);
+ if (n[0].getKind() == EMPTYBAG)
+ {
+ // (intersection_min emptybag A) = emptybag
+ return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_EMPTY_LEFT);
+ }
+ if (n[1].getKind() == EMPTYBAG)
+ {
+ // (intersection_min A emptybag) = emptybag
+ return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_EMPTY_RIGHT);
+ }
+ if (n[0] == n[1])
+ {
+ // (intersection_min A A) = A
+ return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SAME);
+ }
+ if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
+ {
+ if (n[0] == n[1][0] || n[0] == n[1][1])
+ {
+ // (intersection_min A (union_disjoint A B)) = A
+ // (intersection_min A (union_disjoint B A)) = A
+ // (intersection_min A (union_max A B)) = A
+ // (intersection_min A (union_max B A)) = A
+ return BagsRewriteResponse(n[0], Rewrite::INTERSECTION_SHARED_LEFT);
+ }
+ }
+
+ if (n[0].getKind() == UNION_DISJOINT || n[0].getKind() == UNION_MAX)
+ {
+ if (n[1] == n[0][0] || n[1] == n[0][1])
+ {
+ // (intersection_min (union_disjoint A B) A) = A
+ // (intersection_min (union_disjoint B A) A) = A
+ // (intersection_min (union_max A B) A) = A
+ // (intersection_min (union_max B A) A) = A
+ return BagsRewriteResponse(n[1], Rewrite::INTERSECTION_SHARED_RIGHT);
+ }
+ }
+
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteDifferenceSubtract(
+ const TNode& n) const
+{
+ Assert(n.getKind() == DIFFERENCE_SUBTRACT);
+ if (n[0].getKind() == EMPTYBAG || n[1].getKind() == EMPTYBAG)
+ {
+ // (difference_subtract A emptybag) = A
+ // (difference_subtract emptybag A) = emptybag
+ return BagsRewriteResponse(n[0], Rewrite::SUBTRACT_RETURN_LEFT);
+ }
+ if (n[0] == n[1])
+ {
+ // (difference_subtract A A) = emptybag
+ Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+ return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_SAME);
+ }
+
+ if (n[0].getKind() == UNION_DISJOINT)
+ {
+ if (n[1] == n[0][0])
+ {
+ // (difference_subtract (union_disjoint A B) A) = B
+ return BagsRewriteResponse(n[0][1],
+ Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT);
+ }
+ if (n[1] == n[0][1])
+ {
+ // (difference_subtract (union_disjoint B A) A) = B
+ return BagsRewriteResponse(n[0][0],
+ Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT);
+ }
+ }
+
+ if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
+ {
+ if (n[0] == n[1][0] || n[0] == n[1][1])
+ {
+ // (difference_subtract A (union_disjoint A B)) = emptybag
+ // (difference_subtract A (union_disjoint B A)) = emptybag
+ // (difference_subtract A (union_max A B)) = emptybag
+ // (difference_subtract A (union_max B A)) = emptybag
+ Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+ return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_FROM_UNION);
+ }
+ }
+
+ if (n[0].getKind() == INTERSECTION_MIN)
+ {
+ if (n[1] == n[0][0] || n[1] == n[0][1])
+ {
+ // (difference_subtract (intersection_min A B) A) = emptybag
+ // (difference_subtract (intersection_min B A) A) = emptybag
+ Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+ return BagsRewriteResponse(emptyBag, Rewrite::SUBTRACT_MIN);
+ }
+ }
+
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const
+{
+ Assert(n.getKind() == DIFFERENCE_REMOVE);
+
+ if (n[0].getKind() == EMPTYBAG || n[1].getKind() == EMPTYBAG)
+ {
+ // (difference_remove A emptybag) = A
+ // (difference_remove emptybag B) = emptybag
+ return BagsRewriteResponse(n[0], Rewrite::REMOVE_RETURN_LEFT);
+ }
+
+ if (n[0] == n[1])
+ {
+ // (difference_remove A A) = emptybag
+ Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+ return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_SAME);
+ }
+
+ if (n[1].getKind() == UNION_DISJOINT || n[1].getKind() == UNION_MAX)
+ {
+ if (n[0] == n[1][0] || n[0] == n[1][1])
+ {
+ // (difference_remove A (union_disjoint A B)) = emptybag
+ // (difference_remove A (union_disjoint B A)) = emptybag
+ // (difference_remove A (union_max A B)) = emptybag
+ // (difference_remove A (union_max B A)) = emptybag
+ Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+ return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_FROM_UNION);
+ }
+ }
+
+ if (n[0].getKind() == INTERSECTION_MIN)
+ {
+ if (n[1] == n[0][0] || n[1] == n[0][1])
+ {
+ // (difference_remove (intersection_min A B) A) = emptybag
+ // (difference_remove (intersection_min B A) A) = emptybag
+ Node emptyBag = d_nm->mkConst(EmptyBag(n.getType()));
+ return BagsRewriteResponse(emptyBag, Rewrite::REMOVE_MIN);
+ }
+ }
+
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const
+{
+ Assert(n.getKind() == BAG_CHOOSE);
+ if (n[0].getKind() == MK_BAG && n[0][1].isConst())
+ {
+ // (bag.choose (mkBag x c)) = x where c is a constant > 0
+ return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_MK_BAG);
+ }
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const
+{
+ Assert(n.getKind() == BAG_CARD);
+ if (n[0].getKind() == MK_BAG && n[0][1].isConst())
+ {
+ // (bag.card (mkBag x c)) = c where c is a constant > 0
+ return BagsRewriteResponse(n[0][1], Rewrite::CARD_MK_BAG);
+ }
+
+ if (n[0].getKind() == UNION_DISJOINT)
+ {
+ // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
+ Node A = d_nm->mkNode(BAG_CARD, n[0][0]);
+ Node B = d_nm->mkNode(BAG_CARD, n[0][1]);
+ Node plus = d_nm->mkNode(PLUS, A, B);
+ return BagsRewriteResponse(plus, Rewrite::CARD_DISJOINT);
+ }
+
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+BagsRewriteResponse BagsRewriter::rewriteIsSingleton(const TNode& n) const
+{
+ Assert(n.getKind() == BAG_IS_SINGLETON);
+ if (n[0].getKind() == MK_BAG)
+ {
+ // (bag.is_singleton (mkBag x c)) = (c == 1)
+ Node one = d_nm->mkConst(Rational(1));
+ Node equal = n[0][1].eqNode(one);
+ return BagsRewriteResponse(equal, Rewrite::IS_SINGLETON_MK_BAG);
+ }
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
+} // namespace bags
+} // namespace theory
+} // namespace CVC4
--- /dev/null
+/********************* */
+/*! \file bags_rewriter.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Bags theory rewriter.
+ **/
+
+#include "cvc4_private.h"
+#include "theory/bags/rewrites.h"
+
+#ifndef CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H
+#define CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H
+
+#include "theory/rewriter.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+/** a class represents the result of rewriting bag nodes */
+struct BagsRewriteResponse
+{
+ BagsRewriteResponse();
+ BagsRewriteResponse(Node n, Rewrite rewrite);
+ BagsRewriteResponse(const BagsRewriteResponse& r);
+ /** the rewritten node */
+ Node d_node;
+ /** type of rewrite used by bags */
+ Rewrite d_rewrite;
+
+}; /* struct BagsRewriteResponse */
+
+class BagsRewriter : public TheoryRewriter
+{
+ public:
+ BagsRewriter(HistogramStat<Rewrite>* statistics = nullptr);
+
+ /**
+ * postRewrite nodes with kinds: MK_BAG, BAG_COUNT, UNION_MAX, UNION_DISJOINT,
+ * INTERSECTION_MIN, DIFFERENCE_SUBTRACT, DIFFERENCE_REMOVE, BAG_CHOOSE,
+ * BAG_CARD, BAG_IS_SINGLETON.
+ * See the rewrite rules for these kinds below.
+ */
+ RewriteResponse postRewrite(TNode n) override;
+ /**
+ * preRewrite nodes with kinds: EQUAL, BAG_IS_INCLUDED.
+ * See the rewrite rules for these kinds below.
+ */
+ RewriteResponse preRewrite(TNode n) override;
+
+ private:
+ /**
+ * rewrites for n include:
+ * - (= A A) = true where A is a bag
+ */
+ BagsRewriteResponse rewriteEqual(const TNode& n) const;
+
+ /**
+ * rewrites for n include:
+ * - (bag.is_included A B) = ((difference_subtract A B) == emptybag)
+ */
+ BagsRewriteResponse rewriteIsIncluded(const TNode& n) const;
+
+ /**
+ * rewrites for n include:
+ * - (mkBag x 0) = (emptybag T) where T is the type of x
+ * - (mkBag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a
+ * constant
+ * - otherwise = n
+ */
+ BagsRewriteResponse rewriteMakeBag(const TNode& n) const;
+ /**
+ * rewrites for n include:
+ * - (bag.count x emptybag) = 0
+ * - (bag.count x (mkBag x c) = c where c > 0 is a constant
+ * - otherwise = n
+ */
+ BagsRewriteResponse rewriteBagCount(const TNode& n) const;
+
+ /**
+ * rewrites for n include:
+ * - (union_max A emptybag) = A
+ * - (union_max emptybag A) = A
+ * - (union_max A A) = A
+ * - (union_max A (union_max A B)) = (union_max A B)
+ * - (union_max A (union_max B A)) = (union_max B A)
+ * - (union_max (union_max A B) A) = (union_max A B)
+ * - (union_max (union_max B A) A) = (union_max B A)
+ * - (union_max A (union_disjoint A B)) = (union_disjoint A B)
+ * - (union_max A (union_disjoint B A)) = (union_disjoint B A)
+ * - (union_max (union_disjoint A B) A) = (union_disjoint A B)
+ * - (union_max (union_disjoint B A) A) = (union_disjoint B A)
+ * - otherwise = n
+ */
+ BagsRewriteResponse rewriteUnionMax(const TNode& n) const;
+
+ /**
+ * rewrites for n include:
+ * - (union_disjoint A emptybag) = A
+ * - (union_disjoint emptybag A) = A
+ * - (union_disjoint (union_max A B) (intersection_min A B)) =
+ * (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
+ * - other permutations of the above like swapping A and B, or swapping
+ * intersection_min and union_max
+ * - otherwise = n
+ */
+ BagsRewriteResponse rewriteUnionDisjoint(const TNode& n) const;
+
+ /**
+ * rewrites for n include:
+ * - (intersection_min A emptybag) = emptybag
+ * - (intersection_min emptybag A) = emptybag
+ * - (intersection_min A A) = A
+ * - (intersection_min A (union_disjoint A B)) = A
+ * - (intersection_min A (union_disjoint B A)) = A
+ * - (intersection_min (union_disjoint A B) A) = A
+ * - (intersection_min (union_disjoint B A) A) = A
+ * - (intersection_min A (union_max A B)) = A
+ * - (intersection_min A (union_max B A)) = A
+ * - (intersection_min (union_max A B) A) = A
+ * - (intersection_min (union_max B A) A) = A
+ * - otherwise = n
+ */
+ BagsRewriteResponse rewriteIntersectionMin(const TNode& n) const;
+
+ /**
+ * rewrites for n include:
+ * - (difference_subtract A emptybag) = A
+ * - (difference_subtract emptybag A) = emptybag
+ * - (difference_subtract A A) = emptybag
+ * - (difference_subtract (union_disjoint A B) A) = B
+ * - (difference_subtract (union_disjoint B A) A) = B
+ * - (difference_subtract A (union_disjoint A B)) = emptybag
+ * - (difference_subtract A (union_disjoint B A)) = emptybag
+ * - (difference_subtract A (union_max A B)) = emptybag
+ * - (difference_subtract A (union_max B A)) = emptybag
+ * - (difference_subtract (intersection_min A B) A) = emptybag
+ * - (difference_subtract (intersection_min B A) A) = emptybag
+ * - otherwise = n
+ */
+ BagsRewriteResponse rewriteDifferenceSubtract(const TNode& n) const;
+
+ /**
+ * rewrites for n include:
+ * - (difference_remove A emptybag) = A
+ * - (difference_remove emptybag A) = emptybag
+ * - (difference_remove A A) = emptybag
+ * - (difference_remove A (union_disjoint A B)) = emptybag
+ * - (difference_remove A (union_disjoint B A)) = emptybag
+ * - (difference_remove A (union_max A B)) = emptybag
+ * - (difference_remove A (union_max B A)) = emptybag
+ * - (difference_remove (intersection_min A B) A) = emptybag
+ * - (difference_remove (intersection_min B A) A) = emptybag
+ * - otherwise = n
+ */
+ BagsRewriteResponse rewriteDifferenceRemove(const TNode& n) const;
+ /**
+ * rewrites for n include:
+ * - (bag.choose (mkBag x c)) = x where c is a constant > 0
+ * - otherwise = n
+ */
+ BagsRewriteResponse rewriteChoose(const TNode& n) const;
+ /**
+ * rewrites for n include:
+ * - (bag.card (mkBag x c)) = c where c is a constant > 0
+ * - (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
+ * - otherwise = n
+ */
+ BagsRewriteResponse rewriteCard(const TNode& n) const;
+
+ /**
+ * rewrites for n include:
+ * - (bag.is_singleton (mkBag x c)) = (c == 1)
+ */
+ BagsRewriteResponse rewriteIsSingleton(const TNode& n) const;
+
+ private:
+ /** Reference to the rewriter statistics. */
+ NodeManager* d_nm;
+ /** Reference to the rewriter statistics. */
+ HistogramStat<Rewrite>* d_statistics;
+}; /* class TheoryBagsRewriter */
+
+} // namespace bags
+} // namespace theory
+} // namespace CVC4
+
+#endif /* CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H */
--- /dev/null
+/********************* */
+/*! \file bags_statistics.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Statistics for the theory of bags
+ **/
+
+#include "theory/bags/bags_statistics.h"
+
+#include "smt/smt_statistics_registry.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+BagsStatistics::BagsStatistics() : d_rewrites("theory::bags::rewrites")
+{
+ smtStatisticsRegistry()->registerStat(&d_rewrites);
+}
+
+BagsStatistics::~BagsStatistics()
+{
+ smtStatisticsRegistry()->unregisterStat(&d_rewrites);
+}
+
+} // namespace bags
+} // namespace theory
+} // namespace CVC4
--- /dev/null
+/********************* */
+/*! \file bags_statistics.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Statistics for the theory of bags
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__THEORY__BAGS_STATISTICS_H
+#define CVC4__THEORY__BAGS_STATISTICS_H
+
+#include "expr/kind.h"
+#include "theory/bags/rewrites.h"
+#include "util/statistics_registry.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+/**
+ * Statistics for the theory of bags.
+ */
+class BagsStatistics
+{
+ public:
+ BagsStatistics();
+ ~BagsStatistics();
+
+ /** Counts the number of applications of each type of rewrite rule */
+ HistogramStat<Rewrite> d_rewrites;
+};
+
+} // namespace bags
+} // namespace theory
+} // namespace CVC4
+
+#endif /* CVC4__THEORY__BAGS_STATISTICS_H */
::CVC4::theory::bags::TheoryBags \
"theory/bags/theory_bags.h"
typechecker "theory/bags/theory_bags_type_rules.h"
-rewriter ::CVC4::theory::bags::TheoryBagsRewriter \
- "theory/bags/theory_bags_rewriter.h"
+rewriter ::CVC4::theory::bags::BagsRewriter \
+ "theory/bags/bags_rewriter.h"
properties parametric
properties check propagate presolve
return false;
}
+bool NormalForm::AreChildrenConstants(TNode n)
+{
+ return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
+}
+
+Node NormalForm::evaluate(TNode n)
+{
+ // TODO(projects#223): complete this function
+ return CVC4::Node();
+}
} // namespace bags
} // namespace theory
} // namespace CVC4
\ No newline at end of file
** \brief Normal form for bag constants.
**/
-#include "cvc4_private.h"
#include <expr/node.h>
+#include "cvc4_private.h"
+
#ifndef CVC4__THEORY__BAGS__NORMAL_FORM_H
#define CVC4__THEORY__BAGS__NORMAL_FORM_H
* Also handles the corner cases of empty bag and singleton bag.
*/
static bool checkNormalConstant(TNode n);
+ /**
+ * check whether all children of the given node are in normal form
+ */
+ static bool AreChildrenConstants(TNode n);
+ /**
+ * evaluate the node n to a constant value
+ */
+ static Node evaluate(TNode n);
};
} // namespace bags
} // namespace theory
--- /dev/null
+/********************* */
+/*! \file rewrites.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Implementation of inference information utility.
+ **/
+
+#include "theory/bags/rewrites.h"
+
+#include <iostream>
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+const char* toString(Rewrite r)
+{
+ switch (r)
+ {
+ case Rewrite::NONE: return "NONE";
+ case Rewrite::CARD_DISJOINT: return "CARD_DISJOINT";
+ case Rewrite::CARD_MK_BAG: return "CARD_MK_BAG";
+ case Rewrite::CHOOSE_MK_BAG: return "CHOOSE_MK_BAG";
+ case Rewrite::CONSTANT_EVALUATION: return "CONSTANT_EVALUATION";
+ case Rewrite::COUNT_EMPTY: return "COUNT_EMPTY";
+ case Rewrite::COUNT_MK_BAG: return "COUNT_MK_BAG";
+ case Rewrite::IDENTICAL_NODES: return "IDENTICAL_NODES";
+ case Rewrite::INTERSECTION_EMPTY_LEFT: return "INTERSECTION_EMPTY_LEFT";
+ case Rewrite::INTERSECTION_EMPTY_RIGHT: return "INTERSECTION_EMPTY_RIGHT";
+ case Rewrite::INTERSECTION_SAME: return "INTERSECTION_SAME";
+ case Rewrite::INTERSECTION_SHARED_LEFT: return "INTERSECTION_SHARED_LEFT";
+ case Rewrite::INTERSECTION_SHARED_RIGHT: return "INTERSECTION_SHARED_RIGHT";
+ case Rewrite::IS_SINGLETON_MK_BAG: return "IS_SINGLETON_MK_BAG";
+ case Rewrite::MK_BAG_COUNT_NEGATIVE: return "MK_BAG_COUNT_NEGATIVE";
+ case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION";
+ case Rewrite::REMOVE_MIN: return "REMOVE_MIN";
+ case Rewrite::REMOVE_RETURN_LEFT: return "REMOVE_RETURN_LEFT";
+ case Rewrite::REMOVE_SAME: return "REMOVE_SAME";
+ case Rewrite::SUB_BAG: return "SUB_BAG";
+ case Rewrite::SUBTRACT_DISJOINT_SHARED_LEFT:
+ return "SUBTRACT_DISJOINT_SHARED_LEFT";
+ case Rewrite::SUBTRACT_DISJOINT_SHARED_RIGHT:
+ return "SUBTRACT_DISJOINT_SHARED_RIGHT";
+ case Rewrite::SUBTRACT_FROM_UNION: return "SUBTRACT_FROM_UNION";
+ case Rewrite::SUBTRACT_MIN: return "SUBTRACT_MIN";
+ case Rewrite::SUBTRACT_RETURN_LEFT: return "SUBTRACT_RETURN_LEFT";
+ case Rewrite::SUBTRACT_SAME: return "SUBTRACT_SAME";
+ case Rewrite::UNION_DISJOINT_EMPTY_LEFT: return "UNION_DISJOINT_EMPTY_LEFT";
+ case Rewrite::UNION_DISJOINT_EMPTY_RIGHT:
+ return "UNION_DISJOINT_EMPTY_RIGHT";
+ case Rewrite::UNION_DISJOINT_MAX_MIN: return "UNION_DISJOINT_MAX_MIN";
+ case Rewrite::UNION_MAX_EMPTY: return "UNION_MAX_EMPTY";
+ case Rewrite::UNION_MAX_SAME_OR_EMPTY: return "UNION_MAX_SAME_OR_EMPTY";
+ case Rewrite::UNION_MAX_UNION_LEFT: return "UNION_MAX_UNION_LEFT";
+ case Rewrite::UNION_MAX_UNION_RIGHT: return "UNION_MAX_UNION_RIGHT";
+
+ default: return "?";
+ }
+}
+
+std::ostream& operator<<(std::ostream& out, Rewrite r)
+{
+ out << toString(r);
+ return out;
+}
+
+} // namespace bags
+} // namespace theory
+} // namespace CVC4
--- /dev/null
+/********************* */
+/*! \file rewrites.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Type for rewrites for bags.
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__THEORY__BAGS__REWRITES_H
+#define CVC4__THEORY__BAGS__REWRITES_H
+
+#include <iosfwd>
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+/** Types of rewrites used by bags
+ *
+ * This rewrites are documented where they are used in the rewriter.
+ */
+enum class Rewrite : uint32_t
+{
+ NONE, // no rewrite happened
+ CARD_DISJOINT,
+ CARD_MK_BAG,
+ CHOOSE_MK_BAG,
+ CONSTANT_EVALUATION,
+ COUNT_EMPTY,
+ COUNT_MK_BAG,
+ IDENTICAL_NODES,
+ INTERSECTION_EMPTY_LEFT,
+ INTERSECTION_EMPTY_RIGHT,
+ INTERSECTION_SAME,
+ INTERSECTION_SHARED_LEFT,
+ INTERSECTION_SHARED_RIGHT,
+ IS_SINGLETON_MK_BAG,
+ MK_BAG_COUNT_NEGATIVE,
+ REMOVE_FROM_UNION,
+ REMOVE_MIN,
+ REMOVE_RETURN_LEFT,
+ REMOVE_SAME,
+ SUB_BAG,
+ SUBTRACT_DISJOINT_SHARED_LEFT,
+ SUBTRACT_DISJOINT_SHARED_RIGHT,
+ SUBTRACT_FROM_UNION,
+ SUBTRACT_MIN,
+ SUBTRACT_RETURN_LEFT,
+ SUBTRACT_SAME,
+ UNION_DISJOINT_EMPTY_LEFT,
+ UNION_DISJOINT_EMPTY_RIGHT,
+ UNION_DISJOINT_MAX_MIN,
+ UNION_MAX_EMPTY,
+ UNION_MAX_SAME_OR_EMPTY,
+ UNION_MAX_UNION_LEFT,
+ UNION_MAX_UNION_RIGHT
+};
+
+/**
+ * Converts an rewrite to a string. Note: This function is also used in
+ * `safe_print()`. Changing this functions name or signature will result in
+ * `safe_print()` printing "<unsupported>" instead of the proper strings for
+ * the enum values.
+ *
+ * @param r The rewrite
+ * @return The name of the rewrite
+ */
+const char* toString(Rewrite r);
+
+/**
+ * Writes an rewrite name to a stream.
+ *
+ * @param out The stream to write to
+ * @param r The rewrite to write to the stream
+ * @return The stream
+ */
+std::ostream& operator<<(std::ostream& out, Rewrite r);
+
+} // namespace bags
+} // namespace theory
+} // namespace CVC4
+
+#endif /* CVC4__THEORY__BAGS__REWRITES_H */
: Theory(THEORY_BAGS, c, u, out, valuation, logicInfo, pnm),
d_state(c, u, valuation),
d_im(*this, d_state, pnm),
- d_rewriter(),
- d_notify(*this, d_im)
+ d_notify(*this, d_im),
+ d_statistics(),
+ d_rewriter(&d_statistics.d_rewrites)
{
// use the official theory state and inference manager objects
d_theoryState = &d_state;
#include <memory>
+#include "theory/bags/bags_rewriter.h"
+#include "theory/bags/bags_statistics.h"
#include "theory/bags/inference_manager.h"
#include "theory/bags/solver_state.h"
-#include "theory/bags/theory_bags_rewriter.h"
#include "theory/theory.h"
#include "theory/theory_eq_notify.h"
#include "theory/uf/equality_engine.h"
InferenceManager d_im;
/** Instance of the above class */
NotifyClass d_notify;
+ /** Statistics for the theory of bags. */
+ BagsStatistics d_statistics;
/** The theory rewriter for this theory. */
- TheoryBagsRewriter d_rewriter;
+ BagsRewriter d_rewriter;
void eqNotifyNewClass(TNode t);
void eqNotifyMerge(TNode t1, TNode t2);
+++ /dev/null
-/********************* */
-/*! \file theory_bags_rewriter.cpp
- ** \verbatim
- ** Top contributors (to current version):
- ** Mudathir Mohamed
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved. See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Bags theory rewriter.
- **/
-
-#include "theory/bags/theory_bags_rewriter.h"
-
-using namespace CVC4::kind;
-
-namespace CVC4 {
-namespace theory {
-namespace bags {
-
-RewriteResponse TheoryBagsRewriter::postRewrite(TNode node)
-{
- // TODO(projects#225): complete the code here
- return RewriteResponse(REWRITE_DONE, node);
-}
-
-RewriteResponse TheoryBagsRewriter::preRewrite(TNode node)
-{
- // TODO(projects#225): complete the code here
- return RewriteResponse(REWRITE_DONE, node);
-}
-
-} // namespace bags
-} // namespace theory
-} // namespace CVC4
+++ /dev/null
-/********************* */
-/*! \file theory_bags_rewriter.h
- ** \verbatim
- ** Top contributors (to current version):
- ** Mudathir Mohamed
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved. See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Bags theory rewriter.
- **/
-
-#include "cvc4_private.h"
-
-#ifndef CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H
-#define CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H
-
-#include "theory/rewriter.h"
-
-namespace CVC4 {
-namespace theory {
-namespace bags {
-
-class TheoryBagsRewriter : public TheoryRewriter
-{
- public:
- RewriteResponse postRewrite(TNode node) override;
-
- RewriteResponse preRewrite(TNode node) override;
-}; /* class TheoryBagsRewriter */
-
-} // namespace bags
-} // namespace theory
-} // namespace CVC4
-
-#endif /* CVC4__THEORY__BAGS__THEORY_BAGS_REWRITER_H */
*
* This seems too expensive to implement.
* For now we are implementing an obvious solution
- * {(1,1)}, {(1,2)}, {(1,3)}, ... which works for both fininte and infinite
+ * {(1,1)}, {(1,2)}, {(1,3)}, ... which works for both finite and infinite
* types
*/
BagEnumerator& operator++() override;
{
static Cardinality computeCardinality(TypeNode type)
{
- return Cardinality::UNKNOWN_CARD;
+ return Cardinality::INTEGERS;
}
static bool isWellFounded(TypeNode type) { return type[0].isWellFounded(); }
cvc4_add_unit_test_white(logic_info_white theory)
cvc4_add_unit_test_white(sequences_rewriter_white theory)
cvc4_add_unit_test_white(theory_arith_white theory)
+cvc4_add_unit_test_white(theory_bags_rewriter_black theory)
cvc4_add_unit_test_white(theory_bags_type_rules_black theory)
cvc4_add_unit_test_white(theory_bv_rewriter_white theory)
cvc4_add_unit_test_white(theory_bv_white theory)
--- /dev/null
+/********************* */
+/*! \file theory_bags_rewriter_black.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Black box testing of bags rewriter
+ **/
+
+#include <cxxtest/TestSuite.h>
+
+#include "expr/dtype.h"
+#include "smt/smt_engine.h"
+#include "theory/bags/bags_rewriter.h"
+#include "theory/strings/type_enumerator.h"
+
+using namespace CVC4;
+using namespace CVC4::smt;
+using namespace CVC4::theory;
+using namespace CVC4::kind;
+using namespace CVC4::theory::bags;
+using namespace std;
+
+typedef expr::Attribute<Node, Node> attribute;
+
+class BagsTypeRuleBlack : public CxxTest::TestSuite
+{
+ public:
+ void setUp() override
+ {
+ d_em.reset(new ExprManager());
+ d_smt.reset(new SmtEngine(d_em.get()));
+ d_nm.reset(NodeManager::fromExprManager(d_em.get()));
+ d_smt->finishInit();
+ d_rewriter.reset(new BagsRewriter(nullptr));
+ }
+
+ void tearDown() override
+ {
+ d_rewriter.reset();
+ d_smt.reset();
+ d_nm.release();
+ d_em.reset();
+ }
+
+ std::vector<Node> getNStrings(size_t n)
+ {
+ std::vector<Node> elements(n);
+ for (size_t i = 0; i < n; i++)
+ {
+ elements[i] = d_nm->mkSkolem("x", d_nm->stringType());
+ }
+ return elements;
+ }
+
+ void testEmptyBagNormalForm()
+ {
+ Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType()));
+ // empty bags are in normal form
+ TS_ASSERT(emptybag.isConst());
+ RewriteResponse response = d_rewriter->postRewrite(emptybag);
+ TS_ASSERT(emptybag == response.d_node && response.d_status == REWRITE_DONE);
+ }
+
+ void testBagEquality()
+ {
+ vector<Node> elements = getNStrings(2);
+ Node x = elements[0];
+ Node y = elements[1];
+ Node c = d_nm->mkSkolem("c", d_nm->integerType());
+ Node d = d_nm->mkSkolem("d", d_nm->integerType());
+ Node bagX = d_nm->mkNode(MK_BAG, x, c);
+ Node bagY = d_nm->mkNode(MK_BAG, y, d);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+
+ // (= A A) = true where A is a bag
+ Node n1 = emptyBag.eqNode(emptyBag);
+ RewriteResponse response1 = d_rewriter->preRewrite(n1);
+ TS_ASSERT(response1.d_node == d_nm->mkConst(true)
+ && response1.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testMkBagConstantElement()
+ {
+ vector<Node> elements = getNStrings(1);
+ Node negative =
+ d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(-1)));
+ Node zero = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(0)));
+ Node positive =
+ d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(1)));
+ Node emptybag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
+ RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
+ RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
+
+ // bags with non-positive multiplicity are rewritten as empty bags
+ TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
+ && negativeResponse.d_node == emptybag);
+ TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
+ && zeroResponse.d_node == emptybag);
+
+ // no change for positive
+ TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
+ && positive == positiveResponse.d_node);
+ }
+
+ void testMkBagVariableElement()
+ {
+ Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
+ Node variable = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1)));
+ Node negative = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(-1)));
+ Node zero = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(0)));
+ Node positive = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(1)));
+ Node emptybag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ RewriteResponse negativeResponse = d_rewriter->postRewrite(negative);
+ RewriteResponse zeroResponse = d_rewriter->postRewrite(zero);
+ RewriteResponse positiveResponse = d_rewriter->postRewrite(positive);
+
+ // bags with non-positive multiplicity are rewritten as empty bags
+ TS_ASSERT(negativeResponse.d_status == REWRITE_AGAIN_FULL
+ && negativeResponse.d_node == emptybag);
+ TS_ASSERT(zeroResponse.d_status == REWRITE_AGAIN_FULL
+ && zeroResponse.d_node == emptybag);
+
+ // no change for positive
+ TS_ASSERT(positiveResponse.d_status == REWRITE_DONE
+ && positive == positiveResponse.d_node);
+ }
+
+ void testBagCount()
+ {
+ int n = 3;
+ Node skolem = d_nm->mkSkolem("x", d_nm->stringType());
+ Node emptyBag = d_nm->mkConst(EmptyBag(d_nm->mkBagType(skolem.getType())));
+ Node bag = d_nm->mkNode(MK_BAG, skolem, d_nm->mkConst(Rational(n)));
+
+ // (bag.count x emptybag) = 0
+ Node n1 = d_nm->mkNode(BAG_COUNT, skolem, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ TS_ASSERT(response1.d_status == REWRITE_AGAIN_FULL
+ && response1.d_node == d_nm->mkConst(Rational(0)));
+
+ // (bag.count x (mkBag x c) = c where c > 0 is a constant
+ Node n2 = d_nm->mkNode(BAG_COUNT, skolem, bag);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ TS_ASSERT(response2.d_status == REWRITE_AGAIN_FULL
+ && response2.d_node == d_nm->mkConst(Rational(n)));
+ }
+
+ void testUnionMax()
+ {
+ int n = 3;
+ vector<Node> elements = getNStrings(2);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
+ Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+
+ // (union_max A emptybag) = A
+ Node unionMax1 = d_nm->mkNode(UNION_MAX, A, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(unionMax1);
+ TS_ASSERT(response1.d_node == A
+ && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max emptybag A) = A
+ Node unionMax2 = d_nm->mkNode(UNION_MAX, emptyBag, A);
+ RewriteResponse response2 = d_rewriter->postRewrite(unionMax2);
+ TS_ASSERT(response2.d_node == A
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max A A) = A
+ Node unionMax3 = d_nm->mkNode(UNION_MAX, A, A);
+ RewriteResponse response3 = d_rewriter->postRewrite(unionMax3);
+ TS_ASSERT(response3.d_node == A
+ && response3.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max A (union_max A B)) = (union_max A B)
+ Node unionMax4 = d_nm->mkNode(UNION_MAX, A, unionMaxAB);
+ RewriteResponse response4 = d_rewriter->postRewrite(unionMax4);
+ TS_ASSERT(response4.d_node == unionMaxAB
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max A (union_max B A)) = (union_max B A)
+ Node unionMax5 = d_nm->mkNode(UNION_MAX, A, unionMaxBA);
+ RewriteResponse response5 = d_rewriter->postRewrite(unionMax5);
+ TS_ASSERT(response5.d_node == unionMaxBA
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max (union_max A B) A) = (union_max A B)
+ Node unionMax6 = d_nm->mkNode(UNION_MAX, unionMaxAB, A);
+ RewriteResponse response6 = d_rewriter->postRewrite(unionMax6);
+ TS_ASSERT(response6.d_node == unionMaxAB
+ && response6.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max (union_max B A) A) = (union_max B A)
+ Node unionMax7 = d_nm->mkNode(UNION_MAX, unionMaxBA, A);
+ RewriteResponse response7 = d_rewriter->postRewrite(unionMax7);
+ TS_ASSERT(response7.d_node == unionMaxBA
+ && response7.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max A (union_disjoint A B)) = (union_disjoint A B)
+ Node unionMax8 = d_nm->mkNode(UNION_MAX, A, unionDisjointAB);
+ RewriteResponse response8 = d_rewriter->postRewrite(unionMax8);
+ TS_ASSERT(response8.d_node == unionDisjointAB
+ && response8.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max A (union_disjoint B A)) = (union_disjoint B A)
+ Node unionMax9 = d_nm->mkNode(UNION_MAX, A, unionDisjointBA);
+ RewriteResponse response9 = d_rewriter->postRewrite(unionMax9);
+ TS_ASSERT(response9.d_node == unionDisjointBA
+ && response9.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max (union_disjoint A B) A) = (union_disjoint A B)
+ Node unionMax10 = d_nm->mkNode(UNION_MAX, unionDisjointAB, A);
+ RewriteResponse response10 = d_rewriter->postRewrite(unionMax10);
+ TS_ASSERT(response10.d_node == unionDisjointAB
+ && response10.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_max (union_disjoint B A) A) = (union_disjoint B A)
+ Node unionMax11 = d_nm->mkNode(UNION_MAX, unionDisjointBA, A);
+ RewriteResponse response11 = d_rewriter->postRewrite(unionMax11);
+ TS_ASSERT(response11.d_node == unionDisjointBA
+ && response11.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testUnionDisjoint()
+ {
+ int n = 3;
+ vector<Node> elements = getNStrings(2);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
+ Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+ Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+ Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+ Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+ // (union_disjoint A emptybag) = A
+ Node unionDisjoint1 = d_nm->mkNode(UNION_DISJOINT, A, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(unionDisjoint1);
+ TS_ASSERT(response1.d_node == A
+ && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_disjoint emptybag A) = A
+ Node unionDisjoint2 = d_nm->mkNode(UNION_DISJOINT, emptyBag, A);
+ RewriteResponse response2 = d_rewriter->postRewrite(unionDisjoint2);
+ TS_ASSERT(response2.d_node == A
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_disjoint (union_max A B) (intersection_min B A)) =
+ // (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
+ Node unionDisjoint3 =
+ d_nm->mkNode(UNION_DISJOINT, unionMaxAB, intersectionBA);
+ RewriteResponse response3 = d_rewriter->postRewrite(unionDisjoint3);
+ TS_ASSERT(response3.d_node == unionDisjointAB
+ && response3.d_status == REWRITE_AGAIN_FULL);
+
+ // (union_disjoint (intersection_min B A)) (union_max A B) =
+ // (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
+ Node unionDisjoint4 =
+ d_nm->mkNode(UNION_DISJOINT, unionMaxBA, intersectionBA);
+ RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
+ TS_ASSERT(response4.d_node == unionDisjointBA
+ && response4.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testIntersectionMin()
+ {
+ int n = 3;
+ vector<Node> elements = getNStrings(2);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
+ Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+
+ // (intersection_min A emptybag) = emptyBag
+ Node n1 = d_nm->mkNode(INTERSECTION_MIN, A, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ TS_ASSERT(response1.d_node == emptyBag
+ && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min emptybag A) = emptyBag
+ Node n2 = d_nm->mkNode(INTERSECTION_MIN, emptyBag, A);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ TS_ASSERT(response2.d_node == emptyBag
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min A A) = A
+ Node n3 = d_nm->mkNode(INTERSECTION_MIN, A, A);
+ RewriteResponse response3 = d_rewriter->postRewrite(n3);
+ TS_ASSERT(response3.d_node == A
+ && response3.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min A (union_max A B) = A
+ Node n4 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxAB);
+ RewriteResponse response4 = d_rewriter->postRewrite(n4);
+ TS_ASSERT(response4.d_node == A
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min A (union_max B A) = A
+ Node n5 = d_nm->mkNode(INTERSECTION_MIN, A, unionMaxBA);
+ RewriteResponse response5 = d_rewriter->postRewrite(n5);
+ TS_ASSERT(response5.d_node == A
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min (union_max A B) A) = A
+ Node n6 = d_nm->mkNode(INTERSECTION_MIN, unionMaxAB, A);
+ RewriteResponse response6 = d_rewriter->postRewrite(n6);
+ TS_ASSERT(response6.d_node == A
+ && response6.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min (union_max B A) A) = A
+ Node n7 = d_nm->mkNode(INTERSECTION_MIN, unionMaxBA, A);
+ RewriteResponse response7 = d_rewriter->postRewrite(n7);
+ TS_ASSERT(response7.d_node == A
+ && response7.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min A (union_disjoint A B) = A
+ Node n8 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointAB);
+ RewriteResponse response8 = d_rewriter->postRewrite(n8);
+ TS_ASSERT(response8.d_node == A
+ && response8.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min A (union_disjoint B A) = A
+ Node n9 = d_nm->mkNode(INTERSECTION_MIN, A, unionDisjointBA);
+ RewriteResponse response9 = d_rewriter->postRewrite(n9);
+ TS_ASSERT(response9.d_node == A
+ && response9.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min (union_disjoint A B) A) = A
+ Node n10 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointAB, A);
+ RewriteResponse response10 = d_rewriter->postRewrite(n10);
+ TS_ASSERT(response10.d_node == A
+ && response10.d_status == REWRITE_AGAIN_FULL);
+
+ // (intersection_min (union_disjoint B A) A) = A
+ Node n11 = d_nm->mkNode(INTERSECTION_MIN, unionDisjointBA, A);
+ RewriteResponse response11 = d_rewriter->postRewrite(n11);
+ TS_ASSERT(response11.d_node == A
+ && response11.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testDifferenceSubtract()
+ {
+ int n = 3;
+ vector<Node> elements = getNStrings(2);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
+ Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+ Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+ Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+ // (difference_subtract A emptybag) = A
+ Node n1 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ TS_ASSERT(response1.d_node == A
+ && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract emptybag A) = emptyBag
+ Node n2 = d_nm->mkNode(DIFFERENCE_SUBTRACT, emptyBag, A);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ TS_ASSERT(response2.d_node == emptyBag
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract A A) = emptybag
+ Node n3 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, A);
+ RewriteResponse response3 = d_rewriter->postRewrite(n3);
+ TS_ASSERT(response3.d_node == emptyBag
+ && response3.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract (union_disjoint A B) A) = B
+ Node n4 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointAB, A);
+ RewriteResponse response4 = d_rewriter->postRewrite(n4);
+ TS_ASSERT(response4.d_node == B
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract (union_disjoint B A) A) = B
+ Node n5 = d_nm->mkNode(DIFFERENCE_SUBTRACT, unionDisjointBA, A);
+ RewriteResponse response5 = d_rewriter->postRewrite(n5);
+ TS_ASSERT(response5.d_node == B
+ && response4.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract A (union_disjoint A B)) = emptybag
+ Node n6 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointAB);
+ RewriteResponse response6 = d_rewriter->postRewrite(n6);
+ TS_ASSERT(response6.d_node == emptyBag
+ && response6.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract A (union_disjoint B A)) = emptybag
+ Node n7 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionDisjointBA);
+ RewriteResponse response7 = d_rewriter->postRewrite(n7);
+ TS_ASSERT(response7.d_node == emptyBag
+ && response7.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract A (union_max A B)) = emptybag
+ Node n8 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxAB);
+ RewriteResponse response8 = d_rewriter->postRewrite(n8);
+ TS_ASSERT(response8.d_node == emptyBag
+ && response8.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract A (union_max B A)) = emptybag
+ Node n9 = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, unionMaxBA);
+ RewriteResponse response9 = d_rewriter->postRewrite(n9);
+ TS_ASSERT(response9.d_node == emptyBag
+ && response9.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract (intersection_min A B) A) = emptybag
+ Node n10 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionAB, A);
+ RewriteResponse response10 = d_rewriter->postRewrite(n10);
+ TS_ASSERT(response10.d_node == emptyBag
+ && response10.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_subtract (intersection_min B A) A) = emptybag
+ Node n11 = d_nm->mkNode(DIFFERENCE_SUBTRACT, intersectionBA, A);
+ RewriteResponse response11 = d_rewriter->postRewrite(n11);
+ TS_ASSERT(response11.d_node == emptyBag
+ && response11.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testDifferenceRemove()
+ {
+ int n = 3;
+ vector<Node> elements = getNStrings(2);
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(n)));
+ Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(n + 1)));
+ Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+ Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+ Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
+ Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
+
+ // (difference_remove A emptybag) = A
+ Node n1 = d_nm->mkNode(DIFFERENCE_REMOVE, A, emptyBag);
+ RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ TS_ASSERT(response1.d_node == A
+ && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove emptybag A) = emptyBag
+ Node n2 = d_nm->mkNode(DIFFERENCE_REMOVE, emptyBag, A);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ TS_ASSERT(response2.d_node == emptyBag
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove A A) = emptybag
+ Node n3 = d_nm->mkNode(DIFFERENCE_REMOVE, A, A);
+ RewriteResponse response3 = d_rewriter->postRewrite(n3);
+ TS_ASSERT(response3.d_node == emptyBag
+ && response3.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove A (union_disjoint A B)) = emptybag
+ Node n6 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointAB);
+ RewriteResponse response6 = d_rewriter->postRewrite(n6);
+ TS_ASSERT(response6.d_node == emptyBag
+ && response6.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove A (union_disjoint B A)) = emptybag
+ Node n7 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionDisjointBA);
+ RewriteResponse response7 = d_rewriter->postRewrite(n7);
+ TS_ASSERT(response7.d_node == emptyBag
+ && response7.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove A (union_max A B)) = emptybag
+ Node n8 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxAB);
+ RewriteResponse response8 = d_rewriter->postRewrite(n8);
+ TS_ASSERT(response8.d_node == emptyBag
+ && response8.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove A (union_max B A)) = emptybag
+ Node n9 = d_nm->mkNode(DIFFERENCE_REMOVE, A, unionMaxBA);
+ RewriteResponse response9 = d_rewriter->postRewrite(n9);
+ TS_ASSERT(response9.d_node == emptyBag
+ && response9.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove (intersection_min A B) A) = emptybag
+ Node n10 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionAB, A);
+ RewriteResponse response10 = d_rewriter->postRewrite(n10);
+ TS_ASSERT(response10.d_node == emptyBag
+ && response10.d_status == REWRITE_AGAIN_FULL);
+
+ // (difference_remove (intersection_min B A) A) = emptybag
+ Node n11 = d_nm->mkNode(DIFFERENCE_REMOVE, intersectionBA, A);
+ RewriteResponse response11 = d_rewriter->postRewrite(n11);
+ TS_ASSERT(response11.d_node == emptyBag
+ && response11.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testChoose()
+ {
+ Node x = d_nm->mkSkolem("x", d_nm->stringType());
+ Node c = d_nm->mkConst(Rational(3));
+ Node bag = d_nm->mkNode(MK_BAG, x, c);
+
+ // (bag.choose (mkBag x c)) = x where c is a constant > 0
+ Node n1 = d_nm->mkNode(BAG_CHOOSE, bag);
+ RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ TS_ASSERT(response1.d_node == x
+ && response1.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testBagCard()
+ {
+ Node x = d_nm->mkSkolem("x", d_nm->stringType());
+ Node emptyBag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node zero = d_nm->mkConst(Rational(0));
+ Node c = d_nm->mkConst(Rational(3));
+ Node bag = d_nm->mkNode(MK_BAG, x, c);
+ vector<Node> elements = getNStrings(2);
+ Node A = d_nm->mkNode(MK_BAG, elements[0], d_nm->mkConst(Rational(4)));
+ Node B = d_nm->mkNode(MK_BAG, elements[1], d_nm->mkConst(Rational(5)));
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+
+ // TODO(projects#223): enable this test after implementing bags normal form
+ // // (bag.card emptybag) = 0
+ // Node n1 = d_nm->mkNode(BAG_CARD, emptyBag);
+ // RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ // TS_ASSERT(response1.d_node == zero && response1.d_status ==
+ // REWRITE_AGAIN_FULL);
+
+ // (bag.card (mkBag x c)) = c where c is a constant > 0
+ Node n2 = d_nm->mkNode(BAG_CARD, bag);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ TS_ASSERT(response2.d_node == c
+ && response2.d_status == REWRITE_AGAIN_FULL);
+
+ // (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
+ Node n3 = d_nm->mkNode(BAG_CARD, unionDisjointAB);
+ Node cardA = d_nm->mkNode(BAG_CARD, A);
+ Node cardB = d_nm->mkNode(BAG_CARD, B);
+ Node plus = d_nm->mkNode(PLUS, cardA, cardB);
+ RewriteResponse response3 = d_rewriter->postRewrite(n3);
+ TS_ASSERT(response3.d_node == plus
+ && response3.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ void testIsSingleton()
+ {
+ Node emptybag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node x = d_nm->mkSkolem("x", d_nm->stringType());
+ Node c = d_nm->mkSkolem("c", d_nm->integerType());
+ Node bag = d_nm->mkNode(MK_BAG, x, c);
+
+ // TODO(projects#223): complete this function
+ // (bag.is_singleton emptybag) = false
+ // Node n1 = d_nm->mkNode(BAG_IS_SINGLETON, emptybag);
+ // RewriteResponse response1 = d_rewriter->postRewrite(n1);
+ // TS_ASSERT(response1.d_node == d_nm->mkConst(false)
+ // && response1.d_status == REWRITE_AGAIN_FULL);
+
+ // (bag.is_singleton (mkBag x c) = (c == 1)
+ Node n2 = d_nm->mkNode(BAG_IS_SINGLETON, bag);
+ RewriteResponse response2 = d_rewriter->postRewrite(n2);
+ Node one = d_nm->mkConst(Rational(1));
+ Node equal = c.eqNode(one);
+ TS_ASSERT(response2.d_node == equal
+ && response2.d_status == REWRITE_AGAIN_FULL);
+ }
+
+ private:
+ std::unique_ptr<ExprManager> d_em;
+ std::unique_ptr<SmtEngine> d_smt;
+ std::unique_ptr<NodeManager> d_nm;
+
+ std::unique_ptr<BagsRewriter> d_rewriter;
+}; /* class BagsTypeRuleBlack */